/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, SpaceToBatchND, util } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
import { padV2Config } from './PadV2';
|
import { reshape } from './Reshape';
|
import { transpose } from './Transpose';
|
export function spaceToBatchND(args) {
|
const { inputs, backend, attrs } = args;
|
const { x } = inputs;
|
const { blockShape, paddings } = attrs;
|
assertNotComplex([x], 'spaceToBatchND');
|
const prod = util.sizeFromShape(blockShape);
|
const completePaddings = [[0, 0]];
|
completePaddings.push(...paddings);
|
for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
|
completePaddings.push([0, 0]);
|
}
|
const paddedX = padV2Config.kernelFunc({
|
inputs: { x },
|
backend,
|
attrs: { paddings: completePaddings, constantValue: 0 }
|
});
|
const reshapedPaddedShape = backend_util.getReshaped(paddedX.shape, blockShape, prod, false);
|
const permutedReshapedPaddedPermutation = backend_util.getPermuted(reshapedPaddedShape.length, blockShape.length, false);
|
const flattenShape = backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false);
|
const reshapeInputs = { x: paddedX };
|
const reshapeAttrs = { shape: reshapedPaddedShape };
|
const paddedXReshaped = reshape({ inputs: reshapeInputs, backend, attrs: reshapeAttrs });
|
const transposeInputs = { x: paddedXReshaped };
|
const transposeAttrs = { perm: permutedReshapedPaddedPermutation };
|
const paddedXT = transpose({ inputs: transposeInputs, backend, attrs: transposeAttrs });
|
const resultReshapeInputs = { x: paddedXT };
|
const resultReshapeAttrs = { shape: flattenShape };
|
const result = reshape({ inputs: resultReshapeInputs, backend, attrs: resultReshapeAttrs });
|
backend.disposeIntermediateTensorInfo(paddedX);
|
backend.disposeIntermediateTensorInfo(paddedXReshaped);
|
backend.disposeIntermediateTensorInfo(paddedXT);
|
return result;
|
}
|
export const spaceToBatchNDConfig = {
|
kernelName: SpaceToBatchND,
|
backendName: 'cpu',
|
kernelFunc: spaceToBatchND
|
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiU3BhY2VUb0JhdGNoTkQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL1NwYWNlVG9CYXRjaE5ELnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQXlELGNBQWMsRUFBMEYsSUFBSSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHeE4sT0FBTyxFQUFDLGdCQUFnQixFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRTdDLE9BQU8sRUFBQyxXQUFXLEVBQUMsTUFBTSxTQUFTLENBQUM7QUFDcEMsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsU0FBUyxFQUFDLE1BQU0sYUFBYSxDQUFDO0FBRXRDLE1BQU0sVUFBVSxjQUFjLENBQUMsSUFJOUI7SUFDQyxNQUFNLEVBQUMsTUFBTSxFQUFFLE9BQU8sRUFBRSxLQUFLLEVBQUMsR0FBRyxJQUFJLENBQUM7SUFDdEMsTUFBTSxFQUFDLENBQUMsRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUNuQixNQUFNLEVBQUMsVUFBVSxFQUFFLFFBQVEsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUVyQyxnQkFBZ0IsQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFLGdCQUFnQixDQUFDLENBQUM7SUFFeEMsTUFBTSxJQUFJLEdBQUcsSUFBSSxDQUFDLGFBQWEsQ0FBQyxVQUFVLENBQUMsQ0FBQztJQUU1QyxNQUFNLGdCQUFnQixHQUE0QixDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDM0QsZ0JBQWdCLENBQUMsSUFBSSxDQUFDLEdBQUksUUFBb0MsQ0FBQyxDQUFDO0lBRWhFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxHQUFHLFVBQVUsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsTUFBTSxFQUFFLEVBQUUsQ0FBQyxFQUFFO1FBQzNELGdCQUFnQixDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO0tBQy9CO0lBRUQsTUFBTSxPQUFPLEdBQUcsV0FBVyxDQUFDLFVBQVUsQ0FBQztRQUNyQyxNQUFNLEVBQUUsRUFBQyxDQUFDLEVBQUM7UUFDWCxPQUFPO1FBQ1AsS0FBSyxFQUFFLEVBQUMsUUFBUSxFQUFFLGdCQUFnQixFQUFFLGFBQWEsRUFBRSxDQUFDLEVBQUM7S0FDdEQsQ0FBZSxDQUFDO0lBRWpCLE1BQU0sbUJBQW1CLEdBQ3JCLFlBQVksQ0FBQyxXQUFXLENBQUMsT0FBTyxDQUFDLEtBQUssRUFBRSxVQUFVLEVBQUUsSUFBSSxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBRXJFLE1BQU0saUNBQWlDLEdBQUcsWUFBWSxDQUFDLFdBQVcsQ0FDOUQsbUJBQW1CLENBQUMsTUFBTSxFQUFFLFVBQVUsQ0FBQyxNQUFNLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFFMUQsTUFBTSxZQUFZLEdBQ2QsWUFBWSxDQUFDLG1CQUFtQixDQUFDLE9BQU8sQ0FBQyxLQUFLLEVBQUUsVUFBVSxFQUFFLElBQUksRUFBRSxLQUFLLENBQUMsQ0FBQztJQUU3RSxNQUFNLGFBQWEsR0FBa0IsRUFBQyxDQUFDLEVBQUUsT0FBTyxFQUFDLENBQUM7SUFDbEQsTUFBTSxZQUFZLEdBQWlCLEVBQUMsS0FBSyxFQUFFLG1CQUFtQixFQUFDLENBQUM7SUFDaEUsTUFBTSxlQUFlLEdBQ2pCLE9BQU8sQ0FBQyxFQUFDLE1BQU0sRUFBRSxhQUFhLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxZQUFZLEVBQUMsQ0FBQyxDQUFDO0lBRW5FLE1BQU0sZUFBZSxHQUFvQixFQUFDLENBQUMsRUFBRSxlQUFlLEVBQUMsQ0FBQztJQUM5RCxNQUFNLGNBQWMsR0FDQyxFQUFDLElBQUksRUFBRSxpQ0FBaUMsRUFBQyxDQUFDO0lBQy9ELE1BQU0sUUFBUSxHQUNWLFNBQVMsQ0FBQyxFQUFDLE1BQU0sRUFBRSxlQUFlLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxjQUFjLEVBQUMsQ0FBQyxDQUFDO0lBRXpFLE1BQU0sbUJBQW1CLEdBQWtCLEVBQUMsQ0FBQyxFQUFFLFFBQVEsRUFBQyxDQUFDO0lBQ3pELE1BQU0sa0JBQWtCLEdBQWlCLEVBQUMsS0FBSyxFQUFFLFlBQVksRUFBQyxDQUFDO0lBQy9ELE1BQU0sTUFBTSxHQUFHLE9BQU8sQ0FDbEIsRUFBQyxNQUFNLEVBQUUsbUJBQW1CLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBRSxrQkFBa0IsRUFBQyxDQUFDLENBQUM7SUFFdkUsT0FBTyxDQUFDLDZCQUE2QixDQUFDLE9BQU8sQ0FBQyxDQUFDO0lBQy9DLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxlQUFlLENBQUMsQ0FBQztJQUN2RCxPQUFPLENBQUMsNkJBQTZCLENBQUMsUUFBUSxDQUFDLENBQUM7SUFFaEQsT0FBTyxNQUFNLENBQUM7QUFDaEIsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLG9CQUFvQixHQUFpQjtJQUNoRCxVQUFVLEVBQUUsY0FBYztJQUMxQixXQUFXLEVBQUUsS0FBSztJQUNsQixVQUFVLEVBQUUsY0FBdUM7Q0FDcEQsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIEtlcm5lbENvbmZpZywgS2VybmVsRnVuYywgUmVzaGFwZUF0dHJzLCBSZXNoYXBlSW5wdXRzLCBTcGFjZVRvQmF0Y2hORCwgU3BhY2VUb0JhdGNoTkRBdHRycywgU3BhY2VUb0JhdGNoTkRJbnB1dHMsIFRlbnNvckluZm8sIFRyYW5zcG9zZUF0dHJzLCBUcmFuc3Bvc2VJbnB1dHMsIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuXG5pbXBvcnQge3BhZFYyQ29uZmlnfSBmcm9tICcuL1BhZFYyJztcbmltcG9ydCB7cmVzaGFwZX0gZnJvbSAnLi9SZXNoYXBlJztcbmltcG9ydCB7dHJhbnNwb3NlfSBmcm9tICcuL1RyYW5zcG9zZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBzcGFjZVRvQmF0Y2hORChhcmdzOiB7XG4gIGlucHV0czogU3BhY2VUb0JhdGNoTkRJbnB1dHMsXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVLFxuICBhdHRyczogU3BhY2VUb0JhdGNoTkRBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eH0gPSBpbnB1dHM7XG4gIGNvbnN0IHtibG9ja1NoYXBlLCBwYWRkaW5nc30gPSBhdHRycztcblxuICBhc3NlcnROb3RDb21wbGV4KFt4XSwgJ3NwYWNlVG9CYXRjaE5EJyk7XG5cbiAgY29uc3QgcHJvZCA9IHV0aWwuc2l6ZUZyb21TaGFwZShibG9ja1NoYXBlKTtcblxuICBjb25zdCBjb21wbGV0ZVBhZGRpbmdzOiBBcnJheTxbbnVtYmVyLCBudW1iZXJdPiA9IFtbMCwgMF1dO1xuICBjb21wbGV0ZVBhZGRpbmdzLnB1c2goLi4uKHBhZGRpbmdzIGFzIEFycmF5PFtudW1iZXIsIG51bWJlcl0+KSk7XG5cbiAgZm9yIChsZXQgaSA9IDEgKyBibG9ja1NoYXBlLmxlbmd0aDsgaSA8IHguc2hhcGUubGVuZ3RoOyArK2kpIHtcbiAgICBjb21wbGV0ZVBhZGRpbmdzLnB1c2goWzAsIDBdKTtcbiAgfVxuXG4gIGNvbnN0IHBhZGRlZFggPSBwYWRWMkNvbmZpZy5rZXJuZWxGdW5jKHtcbiAgICBpbnB1dHM6IHt4fSxcbiAgICBiYWNrZW5kLFxuICAgIGF0dHJzOiB7cGFkZGluZ3M6IGNvbXBsZXRlUGFkZGluZ3MsIGNvbnN0YW50VmFsdWU6IDB9XG4gIH0pIGFzIFRlbnNvckluZm87XG5cbiAgY29uc3QgcmVzaGFwZWRQYWRkZWRTaGFwZSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZ2V0UmVzaGFwZWQocGFkZGVkWC5zaGFwZSwgYmxvY2tTaGFwZSwgcHJvZCwgZmFsc2UpO1xuXG4gIGNvbnN0IHBlcm11dGVkUmVzaGFwZWRQYWRkZWRQZXJtdXRhdGlvbiA9IGJhY2tlbmRfdXRpbC5nZXRQZXJtdXRlZChcbiAgICAgIHJlc2hhcGVkUGFkZGVkU2hhcGUubGVuZ3RoLCBibG9ja1NoYXBlLmxlbmd0aCwgZmFsc2UpO1xuXG4gIGNvbnN0IGZsYXR0ZW5TaGFwZSA9XG4gICAgICBiYWNrZW5kX3V0aWwuZ2V0UmVzaGFwZWRQZXJtdXRlZChwYWRkZWRYLnNoYXBlLCBibG9ja1NoYXBlLCBwcm9kLCBmYWxzZSk7XG5cbiAgY29uc3QgcmVzaGFwZUlucHV0czogUmVzaGFwZUlucHV0cyA9IHt4OiBwYWRkZWRYfTtcbiAgY29uc3QgcmVzaGFwZUF0dHJzOiBSZXNoYXBlQXR0cnMgPSB7c2hhcGU6IHJlc2hhcGVkUGFkZGVkU2hhcGV9O1xuICBjb25zdCBwYWRkZWRYUmVzaGFwZWQgPVxuICAgICAgcmVzaGFwZSh7aW5wdXRzOiByZXNoYXBlSW5wdXRzLCBiYWNrZW5kLCBhdHRyczogcmVzaGFwZUF0dHJzfSk7XG5cbiAgY29uc3QgdHJhbnNwb3NlSW5wdXRzOiBUcmFuc3Bvc2VJbnB1dHMgPSB7eDogcGFkZGVkWFJlc2hhcGVkfTtcbiAgY29uc3QgdHJhbnNwb3NlQXR0cnM6XG4gICAgICBUcmFuc3Bvc2VBdHRycyA9IHtwZXJtOiBwZXJtdXRlZFJlc2hhcGVkUGFkZGVkUGVybXV0YXRpb259O1xuICBjb25zdCBwYWRkZWRYVCA9XG4gICAgICB0cmFuc3Bvc2Uoe2lucHV0czogdHJhbnNwb3NlSW5wdXRzLCBiYWNrZW5kLCBhdHRyczogdHJhbnNwb3NlQXR0cnN9KTtcblxuICBjb25zdCByZXN1bHRSZXNoYXBlSW5wdXRzOiBSZXNoYXBlSW5wdXRzID0ge3g6IHBhZGRlZFhUfTtcbiAgY29uc3QgcmVzdWx0UmVzaGFwZUF0dHJzOiBSZXNoYXBlQXR0cnMgPSB7c2hhcGU6IGZsYXR0ZW5TaGFwZX07XG4gIGNvbnN0IHJlc3VsdCA9IHJlc2hhcGUoXG4gICAgICB7aW5wdXRzOiByZXN1bHRSZXNoYXBlSW5wdXRzLCBiYWNrZW5kLCBhdHRyczogcmVzdWx0UmVzaGFwZUF0dHJzfSk7XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwYWRkZWRYKTtcbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhwYWRkZWRYUmVzaGFwZWQpO1xuICBiYWNrZW5kLmRpc3Bvc2VJbnRlcm1lZGlhdGVUZW5zb3JJbmZvKHBhZGRlZFhUKTtcblxuICByZXR1cm4gcmVzdWx0O1xufVxuXG5leHBvcnQgY29uc3Qgc3BhY2VUb0JhdGNoTkRDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogU3BhY2VUb0JhdGNoTkQsXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogc3BhY2VUb0JhdGNoTkQgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19
|