/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Gets the new shape of the input Tensor after the first reshape step.
 *
 * When `batchToSpace` is true the result is:
 *
 *   [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),
 *    inputShape[1], ..., inputShape[N-1]]
 *
 * When `batchToSpace` is false (space-to-batch direction) the result is:
 *
 *   [batch, inputShape[1] / blockShape[0], blockShape[0], ...,
 *    inputShape[M] / blockShape[M-1], blockShape[M-1],
 *    inputShape[M+1], ..., inputShape[N-1]]
 *
 * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 *
 * @param inputShape Shape of the input; `inputShape[0]` is the batch entry.
 * @param blockShape Block size for each of the M spatial dimensions.
 * @param prod Divisor applied to the batch entry (the product of
 *     `blockShape` in the formula above). Only read when `batchToSpace` is
 *     true.
 * @param batchToSpace Selects which of the two layouts above is produced.
 * @returns A newly allocated shape array; the inputs are not modified.
 */
export function getReshaped(
    inputShape: number[], blockShape: number[], prod: number,
    batchToSpace = true): number[] {
  if (batchToSpace) {
    // Block dimensions first, then the shrunken batch, then everything else.
    return [...blockShape, inputShape[0] / prod, ...inputShape.slice(1)];
  }

  const spatialLength = blockShape.length;
  const reshaped: number[] = [inputShape[0]];
  for (let i = 0; i < spatialLength; ++i) {
    // Each spatial dimension is split into (number of blocks, block size).
    reshaped.push(inputShape[i + 1] / blockShape[i], blockShape[i]);
  }
  // Dimensions past the spatial ones are carried over unchanged.
  reshaped.push(...inputShape.slice(spatialLength + 1));
  return reshaped;
}

/**
 * Gets the permutation that will transpose the dimensions of the reshaped
 * tensor.
 *
 * When `batchToSpace` is true the permutation yields the shape:
 *
 *   [batch / prod(block_shape), inputShape[1], blockShape[0], ...,
 *    inputShape[M], blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
 *
 * When `batchToSpace` is false the even axes below `2 * blockShapeRank + 1`
 * (excluding axis 0) are moved to the front, followed by axis 0, followed by
 * the remaining axes in their original order.
 *
 * See step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 *
 * @param reshapedRank Number of axes in the reshaped tensor.
 * @param blockShapeRank Number of block (spatial) dimensions, M.
 * @param batchToSpace Selects which of the two permutations is produced.
 * @returns The axis permutation as an array of axis indices.
 */
export function getPermuted(
    reshapedRank: number, blockShapeRank: number,
    batchToSpace = true): number[] {
  const permuted: number[] = [];
  if (batchToSpace) {
    // Axis `blockShapeRank` holds the reduced batch; it goes first.
    permuted.push(blockShapeRank);
    for (let i = blockShapeRank + 1; i < reshapedRank; ++i) {
      permuted.push(i);
      if (i <= 2 * blockShapeRank) {
        // Pair each spatial axis with its matching block axis.
        permuted.push(i - (blockShapeRank + 1));
      }
    }
    return permuted;
  }

  const permutedBeforeBatch: number[] = [];
  const permutedAfterBatch: number[] = [];
  for (let i = 1; i < reshapedRank; ++i) {
    // Odd axes and every axis past the interleaved spatial section stay
    // behind axis 0; the even axes in the interleaved section move in front.
    if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {
      permutedAfterBatch.push(i);
    } else {
      permutedBeforeBatch.push(i);
    }
  }
  permuted.push(...permutedBeforeBatch);
  permuted.push(0);
  permuted.push(...permutedAfterBatch);
  return permuted;
}

/**
 * Gets the shape of the reshaped and permuted input Tensor before any
 * cropping is applied.
 *
 * When `batchToSpace` is true the new shape will be:
 *
 *   [batch / prod(blockShape), inputShape[1] * blockShape[0], ...,
 *    inputShape[M] * blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
 *
 * When `batchToSpace` is false the batch entry is multiplied by `prod` and
 * each of the M spatial entries is divided by its block size instead.
 *
 * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 *
 * @param inputShape Shape of the input; `inputShape[0]` is the batch entry.
 * @param blockShape Block size for each of the M spatial dimensions.
 * @param prod Factor applied to the batch entry (the product of `blockShape`
 *     in the formula above).
 * @param batchToSpace Selects the direction of the transformation.
 * @returns A newly allocated shape array with `inputShape.length` entries
 *     (one entry if `inputShape` is empty).
 */
export function getReshapedPermuted(
    inputShape: number[], blockShape: number[], prod: number,
    batchToSpace = true): number[] {
  const reshapedPermuted: number[] =
      [batchToSpace ? inputShape[0] / prod : inputShape[0] * prod];

  for (let i = 1; i < inputShape.length; ++i) {
    if (i > blockShape.length) {
      // Past the spatial dimensions: carried over unchanged.
      reshapedPermuted.push(inputShape[i]);
    } else if (batchToSpace) {
      reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
    } else {
      reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
    }
  }

  return reshapedPermuted;
}

/**
 * Converts the crops argument into the beginning coordinates of a slice
 * operation.
 *
 * The first coordinate is always 0; it is followed by the first element
 * (`crops[i][0]`) of each of the first `blockShape` crop pairs.
 *
 * @param crops Crop pairs, indexed as `crops[i][0]` and `crops[i][1]`.
 * @param blockShape The number of crop pairs to read. Despite the name this
 *     is a count, not a shape array.
 * @returns An array of `blockShape + 1` begin coordinates.
 */
export function getSliceBeginCoords(
    crops: number[][], blockShape: number): number[] {
  const sliceBeginCoords: number[] = [0];
  for (let i = 0; i < blockShape; ++i) {
    sliceBeginCoords.push(crops[i][0]);
  }
  return sliceBeginCoords;
}

/**
 * Converts the crops argument into the size of a slice operation. When
 * combined with getSliceBeginCoords this function allows the reshaped and
 * permuted Tensor to be cropped to its final output shape of:
 *
 *   [batch / prod(blockShape),
 *    inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,
 *    inputShape[M] * blockShape[M-1] - crops[M-1,0] - crops[M-1,1],
 *    inputShape[M+1], ..., inputShape[N-1]]
 *
 * The returned array holds only the leading entry of `uncroppedShape` and the
 * cropped spatial entries; trailing dimensions are not appended here.
 *
 * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
 *
 * @param uncroppedShape Shape of the tensor before cropping.
 * @param crops Crop pairs; both elements of `crops[i]` are subtracted from
 *     `uncroppedShape[i + 1]`.
 * @param blockShape The number of spatial dimensions to crop. Despite the
 *     name this is a count, not a shape array.
 * @returns A newly allocated array of slice sizes.
 */
export function getSliceSize(
    uncroppedShape: number[], crops: number[][], blockShape: number): number[] {
  // Start from the leading (batch) entry, which is never cropped.
  const sliceSize = uncroppedShape.slice(0, 1);
  for (let i = 0; i < blockShape; ++i) {
    sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
  }
  return sliceSize;
}