/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, SpaceToBatchND, util } from '@tensorflow/tfjs-core';
|
import { padV2 } from './PadV2';
|
import { reshape } from './Reshape';
|
import { transpose } from './Transpose';
|
/**
 * WebGL implementation of the SpaceToBatchND kernel.
 *
 * Built by composing existing WebGL kernels rather than a dedicated shader:
 *   1. zero-pad `x` (`padV2`): dim 0 is never padded, the next
 *      `blockShape.length` dims use `paddings`, and any remaining trailing
 *      dims get `[0, 0]`;
 *   2. reshape to `backend_util.getReshaped(...)`;
 *   3. transpose with `backend_util.getPermuted(...)`;
 *   4. reshape to `backend_util.getReshapedPermuted(...)`, which is returned.
 * The trailing `false` passed to each backend_util helper presumably selects
 * the space-to-batch (not batch-to-space) variant — NOTE(review): confirm.
 *
 * @param {{inputs: {x: TensorInfo}, backend: MathBackendWebGL,
 *     attrs: {blockShape: number[], paddings: number[][]}}} args
 * @returns {TensorInfo} The reshaped/permuted result of step 4.
 * @throws If `x` has rank greater than 4 (not implemented for WebGL).
 */
export const spaceToBatchND = (args) => {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockShape, paddings } = attrs;
    util.assert(x.shape.length <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
        'implemented yet');
    const prod = blockShape.reduce((a, b) => a * b);
    // Dim 0 is left unpadded, then the caller's paddings, then zero padding for
    // every dim after the blockShape dims.
    const completePaddings = [[0, 0], ...paddings];
    for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
        completePaddings.push([0, 0]);
    }
    // Each intermediate is registered as soon as it exists and released in
    // `finally`, so nothing leaks if a later kernel call throws. On success the
    // disposal order (paddedX, reshapedPaddedX, paddedXT) happens after
    // `result` has been produced, exactly as before.
    const toDispose = [];
    try {
        const paddedX = padV2({
            inputs: { x },
            backend,
            attrs: { paddings: completePaddings, constantValue: 0 }
        });
        toDispose.push(paddedX);
        const reshapedPaddedShape = backend_util.getReshaped(paddedX.shape, blockShape, prod, false);
        const permutedReshapedPaddedPermutation = backend_util.getPermuted(reshapedPaddedShape.length, blockShape.length, false);
        const flattenShape = backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false);
        const reshapedPaddedX = reshape({ inputs: { x: paddedX }, backend, attrs: { shape: reshapedPaddedShape } });
        toDispose.push(reshapedPaddedX);
        const paddedXT = transpose({
            inputs: { x: reshapedPaddedX },
            backend,
            attrs: { perm: permutedReshapedPaddedPermutation }
        });
        toDispose.push(paddedXT);
        const result = reshape({ inputs: { x: paddedXT }, backend, attrs: { shape: flattenShape } });
        return result;
    }
    finally {
        toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    }
};
|
/**
 * Kernel configuration pairing the `SpaceToBatchND` kernel name with the
 * `spaceToBatchND` implementation for the 'webgl' backend.
 * Presumably consumed by the backend's kernel registration list — NOTE(review):
 * confirm where `registerKernel` is invoked.
 */
export const spaceToBatchNDConfig = {
    kernelName: SpaceToBatchND,
    backendName: 'webgl',
    kernelFunc: spaceToBatchND,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhY2VUb0JhdGNoTkQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2tlcm5lbHMvU3BhY2VUb0JhdGNoTkQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBNEIsY0FBYyxFQUF5RCxJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUkxSixPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFDbEMsT0FBTyxFQUFDLFNBQVMsRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUV0QyxNQUFNLENBQUMsTUFBTSxjQUFjLEdBQUcsQ0FBQyxJQUk5QixFQUFjLEVBQUU7SUFDZixNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsVUFBVSxFQUFFLFFBQVEsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUVyQyxJQUFJLENBQUMsTUFBTSxDQUNQLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxJQUFJLENBQUMsRUFDbkIsR0FBRyxFQUFFLENBQUMsdURBQXVEO1FBQ3pELGlCQUFpQixDQUFDLENBQUM7SUFFM0IsTUFBTSxJQUFJLEdBQUcsVUFBVSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQztJQUVoRCxNQUFNLGdCQUFnQixHQUE0QixDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDM0QsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLEdBQUcsUUFBbUMsQ0FBQyxDQUFDO0lBQzlELEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxHQUFHLFVBQVUsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQzNELGdCQUFnQixDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQy9CO0lBRUQsTUFBTSxTQUFTLEdBQUcsRUFBRSxDQUFDO0lBRXJCLE1BQU0sT0FBTyxHQUFHLEtBQUssQ0FBQztRQUNwQixNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUM7UUFDWCxPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsUUFBUSxFQUFFLGdCQUFnQixFQUFFLGFBQWEsRUFBRSxDQUFDLEVBQUM7S0FDdEQsQ0FBQyxDQUFDO0lBRUgsTUFBTSxtQkFBbUIsR0FDckIsWUFBWSxDQUFDLFdBQVcsQ0FBQyxPQUFPLENBQUMsS0FBSyxFQUFFLFVBQVUsRUFBRSxJQUFJLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFckUsTUFBTSxpQ0FBaUMsR0FBRyxZQUFZLENBQUMsV0FBVyxDQUM5RCxtQkFBbUIsQ0FBQyxNQUFNLEVBQUUsVUFBVSxDQUFDLE1BQU0sRUFBRSxLQUFLLENBQUMsQ0FBQztJQU
UxRCxNQUFNLFlBQVksR0FDZCxZQUFZLENBQUMsbUJBQW1CLENBQUMsT0FBTyxDQUFDLEtBQUssRUFBRSxVQUFVLEVBQUUsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBRTdFLE1BQU0sZUFBZSxHQUFHLE9BQU8sQ0FDM0IsRUFBQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUUsT0FBTyxFQUFDLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxtQkFBbUIsRUFBQyxFQUFDLENBQUMsQ0FBQztJQUUxRSxNQUFNLFFBQVEsR0FBRyxTQUFTLENBQUM7UUFDekIsTUFBTSxFQUFFLEVBQUMsQ0FBQyxFQUFFLGVBQWUsRUFBQztRQUM1QixPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsSUFBSSxFQUFFLGlDQUFpQyxFQUFDO0tBQ2pELENBQUMsQ0FBQztJQUVILE1BQU0sTUFBTSxHQUNSLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxRQUFRLEVBQUMsRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFFLEVBQUMsS0FBSyxFQUFFLFlBQVksRUFBQyxFQUFDLENBQUMsQ0FBQztJQUU1RSxTQUFTLENBQUMsSUFBSSxDQUFDLE9BQU8sQ0FBQyxDQUFDO0lBQ3hCLFNBQVMsQ0FBQyxJQUFJLENBQUMsZUFBZSxDQUFDLENBQUM7SUFDaEMsU0FBUyxDQUFDLElBQUksQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUV6QixTQUFTLENBQUMsT0FBTyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsT0FBTyxDQUFDLDZCQUE2QixDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFFakUsT0FBTyxNQUFNLENBQUM7QUFDaEIsQ0FBQyxDQUFDO0FBRUYsTUFBTSxDQUFDLE1BQU0sb0JBQW9CLEdBQWlCO0lBQ2hELFVBQVUsRUFBRSxjQUFjO0lBQzFCLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxjQUF1QztDQUNwRCxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb2
5zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBTcGFjZVRvQmF0Y2hORCwgU3BhY2VUb0JhdGNoTkRBdHRycywgU3BhY2VUb0JhdGNoTkRJbnB1dHMsIFRlbnNvckluZm8sIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5cbmltcG9ydCB7cGFkVjJ9IGZyb20gJy4vUGFkVjInO1xuaW1wb3J0IHtyZXNoYXBlfSBmcm9tICcuL1Jlc2hhcGUnO1xuaW1wb3J0IHt0cmFuc3Bvc2V9IGZyb20gJy4vVHJhbnNwb3NlJztcblxuZXhwb3J0IGNvbnN0IHNwYWNlVG9CYXRjaE5EID0gKGFyZ3M6IHtcbiAgaW5wdXRzOiBTcGFjZVRvQmF0Y2hORElucHV0cyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTCxcbiAgYXR0cnM6IFNwYWNlVG9CYXRjaE5EQXR0cnNcbn0pOiBUZW5zb3JJbmZvID0+IHtcbiAgY29uc3Qge2lucHV0cywgYmFja2VuZCwgYXR0cnN9ID0gYXJncztcbiAgY29uc3Qge3h9ID0gaW5wdXRzO1xuICBjb25zdCB7YmxvY2tTaGFwZSwgcGFkZGluZ3N9ID0gYXR0cnM7XG5cbiAgdXRpbC5hc3NlcnQoXG4gICAgICB4LnNoYXBlLmxlbmd0aCA8PSA0LFxuICAgICAgKCkgPT4gJ3NwYWNlVG9CYXRjaE5EIGZvciByYW5rID4gNCB3aXRoIGEgV2ViR0wgYmFja2VuZCBub3QgJyArXG4gICAgICAgICAgJ2ltcGxlbWVudGVkIHlldCcpO1xuXG4gIGNvbnN0IHByb2QgPSBibG9ja1NoYXBlLnJlZHVjZSgoYSwgYikgPT4gYSAqIGIpO1xuXG4gIGNvbnN0IGNvbXBsZXRlUGFkZGluZ3M6IEFycmF5PFtudW1iZXIsIG51bWJlcl0+ID0gW1swLCAwXV07XG4gIGNvbXBsZXRlUGFkZGluZ3MucHVzaCguLi5wYWRkaW5ncyBhcyBBcnJheTxbbnVtYmVyLCBudW1iZXJdPik7XG4gIGZvciAobGV0IGkgPSAxICsgYmxvY2tTaGFwZS5sZW5ndGg7IGkgPCB4LnNoYXBlLmxlbmd0aDsgKytpKSB7XG4gICAgY29tcGxldGVQYWRkaW5ncy5wdXNoKFswLCAwXSk7XG4gIH1cblxuICBjb25zdCB0b0Rpc3Bvc2UgPSBbXTtcblxuICBjb25zdCBwYWRkZWRYID0gcGFkVjIoe1xuICAgIGlucHV0czoge3h9LFxuICAgIGJhY2tlbmQsXG4gICAgYXR0cnM6IHtwYWRkaW5nczogY29tcGxldGVQYWRkaW5ncywgY29uc3RhbnRWYWx1ZTogMH1cbiAgfSk7XG5cbiAgY29uc3QgcmVzaGFwZWRQYWRkZWRTaGFwZSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZ2V0UmVzaGFwZWQocGFkZGVkWC5zaGFwZSwgYmxvY2tTaGFwZSwgcHJvZCwgZmFsc2UpO1xuXG4gIGNvbnN0IHBlcm11dGVkUmVzaGFwZWRQYWRkZWRQZXJtdXRhdGlvbiA9IGJhY2tlbmRfdXRpbC5nZXRQZXJtdXRlZChcbiAgICAgIHJlc2hhcGVkUGFkZGVkU2hhcGUubGVuZ3RoLC
BibG9ja1NoYXBlLmxlbmd0aCwgZmFsc2UpO1xuXG4gIGNvbnN0IGZsYXR0ZW5TaGFwZSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZ2V0UmVzaGFwZWRQZXJtdXRlZChwYWRkZWRYLnNoYXBlLCBibG9ja1NoYXBlLCBwcm9kLCBmYWxzZSk7XG5cbiAgY29uc3QgcmVzaGFwZWRQYWRkZWRYID0gcmVzaGFwZShcbiAgICAgIHtpbnB1dHM6IHt4OiBwYWRkZWRYfSwgYmFja2VuZCwgYXR0cnM6IHtzaGFwZTogcmVzaGFwZWRQYWRkZWRTaGFwZX19KTtcblxuICBjb25zdCBwYWRkZWRYVCA9IHRyYW5zcG9zZSh7XG4gICAgaW5wdXRzOiB7eDogcmVzaGFwZWRQYWRkZWRYfSxcbiAgICBiYWNrZW5kLFxuICAgIGF0dHJzOiB7cGVybTogcGVybXV0ZWRSZXNoYXBlZFBhZGRlZFBlcm11dGF0aW9ufVxuICB9KTtcblxuICBjb25zdCByZXN1bHQgPVxuICAgICAgcmVzaGFwZSh7aW5wdXRzOiB7eDogcGFkZGVkWFR9LCBiYWNrZW5kLCBhdHRyczoge3NoYXBlOiBmbGF0dGVuU2hhcGV9fSk7XG5cbiAgdG9EaXNwb3NlLnB1c2gocGFkZGVkWCk7XG4gIHRvRGlzcG9zZS5wdXNoKHJlc2hhcGVkUGFkZGVkWCk7XG4gIHRvRGlzcG9zZS5wdXNoKHBhZGRlZFhUKTtcblxuICB0b0Rpc3Bvc2UuZm9yRWFjaCh0ID0+IGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8odCkpO1xuXG4gIHJldHVybiByZXN1bHQ7XG59O1xuXG5leHBvcnQgY29uc3Qgc3BhY2VUb0JhdGNoTkRDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU3BhY2VUb0JhdGNoTkQsXG4gIGJhY2tlbmROYW1lOiAnd2ViZ2wnLFxuICBrZXJuZWxGdW5jOiBzcGFjZVRvQmF0Y2hORCBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmNcbn07XG4iXX0=
|