gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, SpaceToBatchND, util } from '@tensorflow/tfjs-core';
import { padV2 } from './PadV2';
import { reshape } from './Reshape';
import { transpose } from './Transpose';
/**
 * WebGL implementation of SpaceToBatchND, composed from existing kernels.
 *
 * The pipeline is:
 *   1. padV2: pads dims 1..blockShape.length with `paddings`, and pads the
 *      batch dim plus any trailing dims with [0, 0], using constant value 0.
 *   2. reshape to the shape from `backend_util.getReshaped(..., false)`.
 *   3. transpose with the permutation from `backend_util.getPermuted(..., false)`.
 *   4. reshape to the shape from `backend_util.getReshapedPermuted(..., false)`.
 *
 * Each intermediate tensor info is queued for disposal as soon as it is
 * created, and the queue is drained in a `finally` block. As a result, the
 * intermediates are released even if a later step throws.
 *
 * @param {{inputs: {x: TensorInfo}, backend: MathBackendWebGL,
 *          attrs: {blockShape: number[], paddings: number[][]}}} args
 * @returns {TensorInfo} The final reshaped result. It is not disposed here;
 *     the caller owns it.
 * @throws {Error} When `x` has rank greater than 4 (raised via `util.assert`).
 */
export const spaceToBatchND = (args) => {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockShape, paddings } = attrs;
    util.assert(x.shape.length <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
        'implemented yet');
    // Product of all block sizes, forwarded to the backend_util shape helpers.
    const prod = blockShape.reduce((a, b) => a * b);
    // Batch dim is never padded; the caller-supplied paddings cover the
    // spatial dims, and any remaining dims get [0, 0].
    const completePaddings = [[0, 0], ...paddings];
    for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
        completePaddings.push([0, 0]);
    }
    const toDispose = [];
    try {
        const paddedX = padV2({
            inputs: { x },
            backend,
            attrs: { paddings: completePaddings, constantValue: 0 }
        });
        toDispose.push(paddedX);
        const reshapedPaddedShape = backend_util.getReshaped(paddedX.shape, blockShape, prod, false);
        const permutedReshapedPaddedPermutation = backend_util.getPermuted(reshapedPaddedShape.length, blockShape.length, false);
        const flattenShape = backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false);
        const reshapedPaddedX = reshape({ inputs: { x: paddedX }, backend, attrs: { shape: reshapedPaddedShape } });
        toDispose.push(reshapedPaddedX);
        const paddedXT = transpose({
            inputs: { x: reshapedPaddedX },
            backend,
            attrs: { perm: permutedReshapedPaddedPermutation }
        });
        toDispose.push(paddedXT);
        return reshape({ inputs: { x: paddedXT }, backend, attrs: { shape: flattenShape } });
    }
    finally {
        // Runs after the final reshape on success, or after a failure partway
        // through; releases only the intermediates that were actually created.
        toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    }
};
/**
 * Kernel registration entry that binds the SpaceToBatchND kernel name to the
 * WebGL implementation `spaceToBatchND` above.
 */
export const spaceToBatchNDConfig = {
    kernelName: SpaceToBatchND,
    backendName: 'webgl',
    kernelFunc: spaceToBatchND,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhY2VUb0JhdGNoTkQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2tlcm5lbHMvU3BhY2VUb0JhdGNoTkQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBNEIsY0FBYyxFQUF5RCxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUkxSixPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLFNBQVMsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUV0QyxNQUFNLENBQUMsTUFBTSxjQUFjLEdBQUcsQ0FBQyxJQUk5QixFQUFjLEVBQUU7SUFDZixNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsVUFBVSxFQUFFLFFBQVEsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUVyQyxJQUFJLENBQUMsTUFBTSxDQUNQLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxJQUFJLENBQUMsRUFDbkIsR0FBRyxFQUFFLENBQUMsdURBQXVEO1FBQ3pELGlCQUFpQixDQUFDLENBQUM7SUFFM0IsTUFBTSxJQUFJLEdBQUcsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztJQUVoRCxNQUFNLGdCQUFnQixHQUE0QixDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDM0QsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLEdBQUcsUUFBbUMsQ0FBQyxDQUFDO0lBQzlELEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxHQUFHLFVBQVUsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQzNELGdCQUFnQixDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQy9CO0lBRUQsTUFBTSxTQUFTLEdBQUcsRUFBRSxDQUFDO0lBRXJCLE1BQU0sT0FBTyxHQUFHLEtBQUssQ0FBQztRQUNwQixNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUM7UUFDWCxPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsUUFBUSxFQUFFLGdCQUFnQixFQUFFLGFBQWEsRUFBRSxDQUFDLEVBQUM7S0FDdEQsQ0FBQyxDQUFDO0lBRUgsTUFBTSxtQkFBbUIsR0FDckIsWUFBWSxDQUFDLFdBQVcsQ0FBQyxPQUFPLENBQUMsS0FBSyxFQUFFLFVBQVUsRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFckUsTUFBTSxpQ0FBaUMsR0FBRyxZQUFZLENBQUMsV0FBVyxDQUM5RCxtQkFBbUIsQ0FBQyxNQUFNLEVBQUUsVUFBVSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsQ0FBQztJQU
UxRCxNQUFNLFlBQVksR0FDZCxZQUFZLENBQUMsbUJBQW1CLENBQUMsT0FBTyxDQUFDLEtBQUssRUFBRSxVQUFVLEVBQUUsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBRTdFLE1BQU0sZUFBZSxHQUFHLE9BQU8sQ0FDM0IsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsT0FBTyxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxtQkFBbUIsRUFBQyxFQUFDLENBQUMsQ0FBQztJQUUxRSxNQUFNLFFBQVEsR0FBRyxTQUFTLENBQUM7UUFDekIsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLGVBQWUsRUFBQztRQUM1QixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsSUFBSSxFQUFFLGlDQUFpQyxFQUFDO0tBQ2pELENBQUMsQ0FBQztJQUVILE1BQU0sTUFBTSxHQUNSLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxRQUFRLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFlBQVksRUFBQyxFQUFDLENBQUMsQ0FBQztJQUU1RSxTQUFTLENBQUMsSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDO0lBQ3hCLFNBQVMsQ0FBQyxJQUFJLENBQUMsZUFBZSxDQUFDLENBQUM7SUFDaEMsU0FBUyxDQUFDLElBQUksQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUV6QixTQUFTLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsT0FBTyxDQUFDLDZCQUE2QixDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFFakUsT0FBTyxNQUFNLENBQUM7QUFDaEIsQ0FBQyxDQUFDO0FBRUYsTUFBTSxDQUFDLE1BQU0sb0JBQW9CLEdBQWlCO0lBQ2hELFVBQVUsRUFBRSxjQUFjO0lBQzFCLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxjQUF1QztDQUNwRCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb2
5zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTcGFjZVRvQmF0Y2hORCwgU3BhY2VUb0JhdGNoTkRBdHRycywgU3BhY2VUb0JhdGNoTkRJbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5cbmltcG9ydCB7cGFkVjJ9IGZyb20gJy4vUGFkVjInO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL1Jlc2hhcGUnO1xuaW1wb3J0IHt0cmFuc3Bvc2V9IGZyb20gJy4vVHJhbnNwb3NlJztcblxuZXhwb3J0IGNvbnN0IHNwYWNlVG9CYXRjaE5EID0gKGFyZ3M6IHtcbiAgaW5wdXRzOiBTcGFjZVRvQmF0Y2hORElucHV0cyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTCxcbiAgYXR0cnM6IFNwYWNlVG9CYXRjaE5EQXR0cnNcbn0pOiBUZW5zb3JJbmZvID0+IHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge3h9ID0gaW5wdXRzO1xuICBjb25zdCB7YmxvY2tTaGFwZSwgcGFkZGluZ3N9ID0gYXR0cnM7XG5cbiAgdXRpbC5hc3NlcnQoXG4gICAgICB4LnNoYXBlLmxlbmd0aCA8PSA0LFxuICAgICAgKCkgPT4gJ3NwYWNlVG9CYXRjaE5EIGZvciByYW5rID4gNCB3aXRoIGEgV2ViR0wgYmFja2VuZCBub3QgJyArXG4gICAgICAgICAgJ2ltcGxlbWVudGVkIHlldCcpO1xuXG4gIGNvbnN0IHByb2QgPSBibG9ja1NoYXBlLnJlZHVjZSgoYSwgYikgPT4gYSAqIGIpO1xuXG4gIGNvbnN0IGNvbXBsZXRlUGFkZGluZ3M6IEFycmF5PFtudW1iZXIsIG51bWJlcl0+ID0gW1swLCAwXV07XG4gIGNvbXBsZXRlUGFkZGluZ3MucHVzaCguLi5wYWRkaW5ncyBhcyBBcnJheTxbbnVtYmVyLCBudW1iZXJdPik7XG4gIGZvciAobGV0IGkgPSAxICsgYmxvY2tTaGFwZS5sZW5ndGg7IGkgPCB4LnNoYXBlLmxlbmd0aDsgKytpKSB7XG4gICAgY29tcGxldGVQYWRkaW5ncy5wdXNoKFswLCAwXSk7XG4gIH1cblxuICBjb25zdCB0b0Rpc3Bvc2UgPSBbXTtcblxuICBjb25zdCBwYWRkZWRYID0gcGFkVjIoe1xuICAgIGlucHV0czoge3h9LFxuICAgIGJhY2tlbmQsXG4gICAgYXR0cnM6IHtwYWRkaW5nczogY29tcGxldGVQYWRkaW5ncywgY29uc3RhbnRWYWx1ZTogMH1cbiAgfSk7XG5cbiAgY29uc3QgcmVzaGFwZWRQYWRkZWRTaGFwZSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZ2V0UmVzaGFwZWQocGFkZGVkWC5zaGFwZSwgYmxvY2tTaGFwZSwgcHJvZCwgZmFsc2UpO1xuXG4gIGNvbnN0IHBlcm11dGVkUmVzaGFwZWRQYWRkZWRQZXJtdXRhdGlvbiA9IGJhY2tlbmRfdXRpbC5nZXRQZXJtdXRlZChcbiAgICAgIHJlc2hhcGVkUGFkZGVkU2hhcGUubGVuZ3RoLC
BibG9ja1NoYXBlLmxlbmd0aCwgZmFsc2UpO1xuXG4gIGNvbnN0IGZsYXR0ZW5TaGFwZSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZ2V0UmVzaGFwZWRQZXJtdXRlZChwYWRkZWRYLnNoYXBlLCBibG9ja1NoYXBlLCBwcm9kLCBmYWxzZSk7XG5cbiAgY29uc3QgcmVzaGFwZWRQYWRkZWRYID0gcmVzaGFwZShcbiAgICAgIHtpbnB1dHM6IHt4OiBwYWRkZWRYfSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogcmVzaGFwZWRQYWRkZWRTaGFwZX19KTtcblxuICBjb25zdCBwYWRkZWRYVCA9IHRyYW5zcG9zZSh7XG4gICAgaW5wdXRzOiB7eDogcmVzaGFwZWRQYWRkZWRYfSxcbiAgICBiYWNrZW5kLFxuICAgIGF0dHJzOiB7cGVybTogcGVybXV0ZWRSZXNoYXBlZFBhZGRlZFBlcm11dGF0aW9ufVxuICB9KTtcblxuICBjb25zdCByZXN1bHQgPVxuICAgICAgcmVzaGFwZSh7aW5wdXRzOiB7eDogcGFkZGVkWFR9LCBiYWNrZW5kLCBhdHRyczoge3NoYXBlOiBmbGF0dGVuU2hhcGV9fSk7XG5cbiAgdG9EaXNwb3NlLnB1c2gocGFkZGVkWCk7XG4gIHRvRGlzcG9zZS5wdXNoKHJlc2hhcGVkUGFkZGVkWCk7XG4gIHRvRGlzcG9zZS5wdXNoKHBhZGRlZFhUKTtcblxuICB0b0Rpc3Bvc2UuZm9yRWFjaCh0ID0+IGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8odCkpO1xuXG4gIHJldHVybiByZXN1bHQ7XG59O1xuXG5leHBvcnQgY29uc3Qgc3BhY2VUb0JhdGNoTkRDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU3BhY2VUb0JhdGNoTkQsXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiBzcGFjZVRvQmF0Y2hORCBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=