gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, SpaceToBatchND, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
import { padV2Config } from './PadV2';
import { reshape } from './Reshape';
import { transpose } from './Transpose';
/**
 * CPU implementation of the SpaceToBatchND kernel.
 *
 * Steps performed:
 *   1. Zero-pad `x`: dimension 0 is not padded, dimensions
 *      1..blockShape.length use the pairs from `paddings`, and every
 *      dimension after those is not padded either.
 *   2. Reshape the padded tensor to `backend_util.getReshaped(...)`.
 *   3. Transpose it with `backend_util.getPermuted(...)`.
 *   4. Reshape the result to `backend_util.getReshapedPermuted(...)`.
 * The three intermediate tensors are released via
 * `backend.disposeIntermediateTensorInfo` before returning.
 *
 * @param {{inputs: {x: TensorInfo}, backend: MathBackendCPU,
 *          attrs: {blockShape: number[], paddings: number[][]}}} args
 *     Kernel arguments. NOTE(review): `paddings` is assumed to hold one
 *     [before, after] pair per block dimension — confirm against callers.
 * @returns {TensorInfo} The output of the final reshape.
 */
export function spaceToBatchND(args) {
    const { inputs: { x }, backend, attrs: { blockShape, paddings } } = args;
    // Presumably throws for complex input; see ../cpu_util.
    assertNotComplex([x], 'spaceToBatchND');
    const blockSize = util.sizeFromShape(blockShape);
    // Leading (batch) dim and any trailing dims beyond the block dims get no padding.
    const completePaddings = [[0, 0], ...paddings];
    const trailingDimCount = x.shape.length - blockShape.length - 1;
    for (let k = 0; k < trailingDimCount; k++) {
        completePaddings.push([0, 0]);
    }
    const paddedX = padV2Config.kernelFunc({
        inputs: { x },
        backend,
        attrs: { paddings: completePaddings, constantValue: 0 }
    });
    const splitShape = backend_util.getReshaped(paddedX.shape, blockShape, blockSize, false);
    const perm = backend_util.getPermuted(splitShape.length, blockShape.length, false);
    const outputShape = backend_util.getReshapedPermuted(paddedX.shape, blockShape, blockSize, false);
    const paddedXReshaped = reshape({
        inputs: { x: paddedX },
        backend,
        attrs: { shape: splitShape }
    });
    const paddedXT = transpose({
        inputs: { x: paddedXReshaped },
        backend,
        attrs: { perm }
    });
    const result = reshape({
        inputs: { x: paddedXT },
        backend,
        attrs: { shape: outputShape }
    });
    // Release intermediates in creation order; only `result` escapes.
    for (const intermediate of [paddedX, paddedXReshaped, paddedXT]) {
        backend.disposeIntermediateTensorInfo(intermediate);
    }
    return result;
}
/**
 * Kernel registration entry that binds the `SpaceToBatchND` kernel name
 * to the `spaceToBatchND` implementation for the 'cpu' backend.
 */
export const spaceToBatchNDConfig = { kernelName: SpaceToBatchND, backendName: 'cpu', kernelFunc: spaceToBatchND };
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhY2VUb0JhdGNoTkQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL1NwYWNlVG9CYXRjaE5ELnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQXlELGNBQWMsRUFBMEYsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHeE4sT0FBTyxFQUFDLGdCQUFnQixFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRTdDLE9BQU8sRUFBQyxXQUFXLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFDcEMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsU0FBUyxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRXRDLE1BQU0sVUFBVSxjQUFjLENBQUMsSUFJOUI7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsVUFBVSxFQUFFLFFBQVEsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUVyQyxnQkFBZ0IsQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFLGdCQUFnQixDQUFDLENBQUM7SUFFeEMsTUFBTSxJQUFJLEdBQUcsSUFBSSxDQUFDLGFBQWEsQ0FBQyxVQUFVLENBQUMsQ0FBQztJQUU1QyxNQUFNLGdCQUFnQixHQUE0QixDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDM0QsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLEdBQUksUUFBb0MsQ0FBQyxDQUFDO0lBRWhFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxHQUFHLFVBQVUsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQzNELGdCQUFnQixDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQy9CO0lBRUQsTUFBTSxPQUFPLEdBQUcsV0FBVyxDQUFDLFVBQVUsQ0FBQztRQUNyQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUM7UUFDWCxPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsUUFBUSxFQUFFLGdCQUFnQixFQUFFLGFBQWEsRUFBRSxDQUFDLEVBQUM7S0FDdEQsQ0FBZSxDQUFDO0lBRWpCLE1BQU0sbUJBQW1CLEdBQ3JCLFlBQVksQ0FBQyxXQUFXLENBQUMsT0FBTyxDQUFDLEtBQUssRUFBRSxVQUFVLEVBQUUsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBRXJFLE1BQU0saUNBQWlDLEdBQUcsWUFBWSxDQUFDLFdBQVcsQ0FDOUQsbUJBQW1CLENBQUMsTUFBTSxFQUFFLFVBQVUsQ0FBQyxNQUFNLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFMUQsTUFBTSxZQUFZLEdBQ2QsWUFBWSxDQUFDLG1CQUFtQixDQUFDLE9BQU8sQ0FBQyxLQUFLLEVBQUUsVUFBVSxFQUFFLElBQUksRUFBRSxLQUFLLENBQUMsQ0FBQztJQUU3RS
xNQUFNLGFBQWEsR0FBa0IsRUFBQyxDQUFDLEVBQUUsT0FBTyxFQUFDLENBQUM7SUFDbEQsTUFBTSxZQUFZLEdBQWlCLEVBQUMsS0FBSyxFQUFFLG1CQUFtQixFQUFDLENBQUM7SUFDaEUsTUFBTSxlQUFlLEdBQ2pCLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxhQUFhLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxZQUFZLEVBQUMsQ0FBQyxDQUFDO0lBRW5FLE1BQU0sZUFBZSxHQUFvQixFQUFDLENBQUMsRUFBRSxlQUFlLEVBQUMsQ0FBQztJQUM5RCxNQUFNLGNBQWMsR0FDQyxFQUFDLElBQUksRUFBRSxpQ0FBaUMsRUFBQyxDQUFDO0lBQy9ELE1BQU0sUUFBUSxHQUNWLFNBQVMsQ0FBQyxFQUFDLE1BQU0sRUFBRSxlQUFlLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxjQUFjLEVBQUMsQ0FBQyxDQUFDO0lBRXpFLE1BQU0sbUJBQW1CLEdBQWtCLEVBQUMsQ0FBQyxFQUFFLFFBQVEsRUFBQyxDQUFDO0lBQ3pELE1BQU0sa0JBQWtCLEdBQWlCLEVBQUMsS0FBSyxFQUFFLFlBQVksRUFBQyxDQUFDO0lBQy9ELE1BQU0sTUFBTSxHQUFHLE9BQU8sQ0FDbEIsRUFBQyxNQUFNLEVBQUUsbUJBQW1CLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxrQkFBa0IsRUFBQyxDQUFDLENBQUM7SUFFdkUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLE9BQU8sQ0FBQyxDQUFDO0lBQy9DLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxlQUFlLENBQUMsQ0FBQztJQUN2RCxPQUFPLENBQUMsNkJBQTZCLENBQUMsUUFBUSxDQUFDLENBQUM7SUFFaEQsT0FBTyxNQUFNLENBQUM7QUFDaEIsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLG9CQUFvQixHQUFpQjtJQUNoRCxVQUFVLEVBQUUsY0FBYztJQUMxQixXQUFXLEVBQUUsS0FBSztJQUNsQixVQUFVLEVBQUUsY0FBdUM7Q0FDcEQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKi
BsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgUmVzaGFwZUF0dHJzLCBSZXNoYXBlSW5wdXRzLCBTcGFjZVRvQmF0Y2hORCwgU3BhY2VUb0JhdGNoTkRBdHRycywgU3BhY2VUb0JhdGNoTkRJbnB1dHMsIFRlbnNvckluZm8sIFRyYW5zcG9zZUF0dHJzLCBUcmFuc3Bvc2VJbnB1dHMsIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuXG5pbXBvcnQge3BhZFYyQ29uZmlnfSBmcm9tICcuL1BhZFYyJztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcbmltcG9ydCB7dHJhbnNwb3NlfSBmcm9tICcuL1RyYW5zcG9zZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzcGFjZVRvQmF0Y2hORChhcmdzOiB7XG4gIGlucHV0czogU3BhY2VUb0JhdGNoTkRJbnB1dHMsXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVLFxuICBhdHRyczogU3BhY2VUb0JhdGNoTkRBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHtibG9ja1NoYXBlLCBwYWRkaW5nc30gPSBhdHRycztcblxuICBhc3NlcnROb3RDb21wbGV4KFt4XSwgJ3NwYWNlVG9CYXRjaE5EJyk7XG5cbiAgY29uc3QgcHJvZCA9IHV0aWwuc2l6ZUZyb21TaGFwZShibG9ja1NoYXBlKTtcblxuICBjb25zdCBjb21wbGV0ZVBhZGRpbmdzOiBBcnJheTxbbnVtYmVyLCBudW1iZXJdPiA9IFtbMCwgMF1dO1xuICBjb21wbGV0ZVBhZGRpbmdzLnB1c2goLi4uKHBhZGRpbmdzIGFzIEFycmF5PFtudW1iZXIsIG51bWJlcl0+KSk7XG5cbiAgZm9yIChsZXQgaSA9IDEgKyBibG9ja1NoYXBlLmxlbmd0aDsgaSA8IHguc2hhcGUubGVuZ3RoOyArK2kpIHtcbiAgICBjb21wbGV0ZVBhZGRpbmdzLnB1c2goWzAsIDBdKTtcbiAgfVxuXG4gIGNvbnN0IHBhZGRlZFggPSBwYWRWMkNvbmZpZy5rZXJuZWxGdW5jKHtcbiAgICBpbnB1dHM6IHt4fSxcbiAgICBiYWNrZW5kLFxuICAgIGF0dHJzOiB7cGFkZGluZ3M6IGNvbXBsZXRlUGFkZGluZ3MsIGNvbnN0YW50VmFsdWU6IDB9XG4gIH0pIGFzIFRlbnNvckluZm87XG5cbiAgY29uc3QgcmVzaGFwZWRQYWRkZWRTaGFwZSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZ2V0UmVzaGFwZWQocGFkZGVkWC5zaGFwZSwgYmxvY2tTaGFwZSwgcHJvZCwgZmFsc2UpO1xuXG4gIGNvbnN0IHBlcm11dGVkUmVzaGFwZWRQYWRkZWRQZXJtdXRhdGlvbiA9IGJhY2tlbmRfdXRpbC5nZXRQZXJtdXRlZChcbiAgICAgIHJlc2hhcGVkUG
FkZGVkU2hhcGUubGVuZ3RoLCBibG9ja1NoYXBlLmxlbmd0aCwgZmFsc2UpO1xuXG4gIGNvbnN0IGZsYXR0ZW5TaGFwZSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZ2V0UmVzaGFwZWRQZXJtdXRlZChwYWRkZWRYLnNoYXBlLCBibG9ja1NoYXBlLCBwcm9kLCBmYWxzZSk7XG5cbiAgY29uc3QgcmVzaGFwZUlucHV0czogUmVzaGFwZUlucHV0cyA9IHt4OiBwYWRkZWRYfTtcbiAgY29uc3QgcmVzaGFwZUF0dHJzOiBSZXNoYXBlQXR0cnMgPSB7c2hhcGU6IHJlc2hhcGVkUGFkZGVkU2hhcGV9O1xuICBjb25zdCBwYWRkZWRYUmVzaGFwZWQgPVxuICAgICAgcmVzaGFwZSh7aW5wdXRzOiByZXNoYXBlSW5wdXRzLCBiYWNrZW5kLCBhdHRyczogcmVzaGFwZUF0dHJzfSk7XG5cbiAgY29uc3QgdHJhbnNwb3NlSW5wdXRzOiBUcmFuc3Bvc2VJbnB1dHMgPSB7eDogcGFkZGVkWFJlc2hhcGVkfTtcbiAgY29uc3QgdHJhbnNwb3NlQXR0cnM6XG4gICAgICBUcmFuc3Bvc2VBdHRycyA9IHtwZXJtOiBwZXJtdXRlZFJlc2hhcGVkUGFkZGVkUGVybXV0YXRpb259O1xuICBjb25zdCBwYWRkZWRYVCA9XG4gICAgICB0cmFuc3Bvc2Uoe2lucHV0czogdHJhbnNwb3NlSW5wdXRzLCBiYWNrZW5kLCBhdHRyczogdHJhbnNwb3NlQXR0cnN9KTtcblxuICBjb25zdCByZXN1bHRSZXNoYXBlSW5wdXRzOiBSZXNoYXBlSW5wdXRzID0ge3g6IHBhZGRlZFhUfTtcbiAgY29uc3QgcmVzdWx0UmVzaGFwZUF0dHJzOiBSZXNoYXBlQXR0cnMgPSB7c2hhcGU6IGZsYXR0ZW5TaGFwZX07XG4gIGNvbnN0IHJlc3VsdCA9IHJlc2hhcGUoXG4gICAgICB7aW5wdXRzOiByZXN1bHRSZXNoYXBlSW5wdXRzLCBiYWNrZW5kLCBhdHRyczogcmVzdWx0UmVzaGFwZUF0dHJzfSk7XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwYWRkZWRYKTtcbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwYWRkZWRYUmVzaGFwZWQpO1xuICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHBhZGRlZFhUKTtcblxuICByZXR1cm4gcmVzdWx0O1xufVxuXG5leHBvcnQgY29uc3Qgc3BhY2VUb0JhdGNoTkRDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU3BhY2VUb0JhdGNoTkQsXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogc3BhY2VUb0JhdGNoTkQgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19