/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ResizeNearestNeighborGrad, util } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
/**
 * CPU kernel for the gradient of a nearest-neighbor resize with respect to
 * its input images.
 *
 * `images.shape` is read as [batch, height, width, depth] and `dy.shape` as
 * [batch, height, width, ...]. For every element of `images`, the kernel sums
 * the entries of `dy` (same batch and depth) whose nearest source pixel,
 * computed with `Math.round` when `alignCorners` is set and `Math.floor`
 * otherwise, is that element.
 *
 * @param args Kernel arguments.
 * @param args.inputs Object holding `images` (the forward-pass input) and
 *     `dy` (the upstream gradient); both need `shape`, and `dy` needs a
 *     `dataId` that `backend.data.get` can resolve.
 * @param args.backend Backend providing `data.get(dataId).values` and
 *     `makeTensorInfo(shape, dtype, values)`.
 * @param args.attrs Object holding the boolean `alignCorners`.
 * @returns The value of `backend.makeTensorInfo(images.shape, images.dtype,
 *     values)`, where `values` is a Float32Array with one entry per element
 *     of `images`.
 */
export function resizeNearestNeighborGrad(args) {
  const { inputs, backend, attrs } = args;
  const { images, dy } = inputs;
  const { alignCorners } = attrs;

  assertNotComplex([dy, images], 'resizeNearestNeighborGrad');

  const imagesStrides = util.computeStrides(images.shape);
  const dyStrides = util.computeStrides(dy.shape);
  const [batch, xHeight, xWidth, depth] = images.shape;
  const [, yHeight, yWidth] = dy.shape;

  const output = new Float32Array(batch * xHeight * xWidth * depth);
  const dyValues = backend.data.get(dy.dataId).values;

  // In the backwards pass, we want to find the pixels that were generated
  // for each pixel in the input image during the forward pass.
  const effectiveXSize = [
    (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
    (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
  ];
  const effectiveYSize = [
    (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
    (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
  ];

  const heightScale = effectiveXSize[0] / effectiveYSize[0];
  const widthScale = effectiveXSize[1] / effectiveYSize[1];

  const invHeightScale = 1 / heightScale;
  const invWidthScale = 1 / widthScale;

  // This defines the size of the window of values around a particular
  // index in dy that we want to search for contributions to dx.
  const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
  const winWidth = (Math.ceil(invWidthScale) * 2) + 2;

  // Loop over the output (dx) space.
  for (let b = 0; b < batch; b++) {
    const batchOffset = b * imagesStrides[0];
    // dy generally has a different spatial size than images, so its batch
    // offset must come from its own strides; reusing the images offset reads
    // the wrong region of dy for every batch after the first.
    const dyBatchOffset = b * dyStrides[0];

    for (let r = 0; r < xHeight; r++) {
      const rowOffset = batchOffset + r * imagesStrides[1];

      // Compute bounds for where in dy we will look.
      const startRLerp = Math.floor(r * invHeightScale);
      const startDyR = Math.floor(startRLerp - (winHeight / 2));

      for (let c = 0; c < xWidth; c++) {
        const colOffset = rowOffset + c * imagesStrides[2];

        // Compute bounds for where in dy we will look.
        const startCLerp = Math.floor(c * invWidthScale);
        const startDyC = Math.floor(startCLerp - (winWidth / 2));

        for (let d = 0; d < depth; d++) {
          let accum = 0;

          // Loop over the candidate window in dy.
          for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
            const dyR = dyRIndex + startDyR;
            // Guard against the window exceeding the bounds of dy.
            if (dyR < 0 || dyR >= yHeight) {
              continue;
            }

            const dyROffset = dyBatchOffset + dyR * dyStrides[1];
            const sourceFracRow = dyR * heightScale;
            const sourceNearestRow = Math.min(
                xHeight - 1,
                alignCorners ? Math.round(sourceFracRow) :
                               Math.floor(sourceFracRow));
            // Only dy rows whose nearest source row is r contribute.
            if (r !== sourceNearestRow) {
              continue;
            }

            for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
              const dyC = dyCIndex + startDyC;
              // Guard against the window exceeding the bounds of dy.
              if (dyC < 0 || dyC >= yWidth) {
                continue;
              }

              const dyCOffset = dyROffset + dyC * dyStrides[2];
              const sourceFracCol = dyC * widthScale;
              const sourceNearestCol = Math.min(
                  xWidth - 1,
                  alignCorners ? Math.round(sourceFracCol) :
                                 Math.floor(sourceFracCol));

              if (c === sourceNearestCol) {
                accum += dyValues[dyCOffset + d];
              }
            }
          }
          output[colOffset + d] = accum;
        }
      }
    }
  }

  return backend.makeTensorInfo(images.shape, images.dtype, output);
}
|
/**
 * Kernel configuration object pairing the `ResizeNearestNeighborGrad` kernel
 * name with the `resizeNearestNeighborGrad` implementation under the 'cpu'
 * backend name.
 */
export const resizeNearestNeighborGradConfig = {
  kernelName: ResizeNearestNeighborGrad,
  backendName: 'cpu',
  kernelFunc: resizeNearestNeighborGrad,
};
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"ResizeNearestNeighborGrad.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/ResizeNearestNeighborGrad.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAA2B,yBAAyB,EAA2F,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGzL,OAAO,EAAC,gBAAgB,EAAC,MAAM,aAAa,CAAC;AAE7C,MAAM,UAAU,yBAAyB,CAAC,IAIzC;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,MAAM,EAAE,EAAE,EAAC,GAAG,MAAM,CAAC;IAC5B,MAAM,EAAC,YAAY,EAAC,GAAG,KAAK,CAAC;IAE7B,gBAAgB,CAAC,CAAC,EAAE,EAAE,MAAM,CAAC,EAAE,2BAA2B,CAAC,CAAC;IAE5D,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;IACxD,MAAM,SAAS,GAAG,IAAI,CAAC,cAAc,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC;IAChD,MAAM,CAAC,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,KAAK,CAAC,GAAG,MAAM,CAAC,KAAK,CAAC;IACrD,MAAM,CAAC,EAAE,OAAO,EAAE,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC;IAErC,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,KAAK,GAAG,OAAO,GAAG,MAAM,GAAG,KAAK,CAAC,CAAC;IAClE,MAAM,QAAQ,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IAElE,wEAAwE;IACxE,qDAAqD;IAErD,MAAM,cAAc,GAAqB;QACvC,CAAC,YAAY,IAAI,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC,OAAO;QACrD,CAAC,YAAY,IAAI,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM;KACnD,CAAC;IAEF,MAAM,cAAc,GAAqB;QACvC,CAAC,YAAY,IAAI,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC,OAAO;QACrD,CAAC,YAAY,IAAI,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM;KACnD,CAAC;IAEF,MAAM,WAAW,GAAG,cAAc,CAAC,CAAC,CAAC,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC;IAC1D,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC;IAEzD,MAAM,cAAc,GAAG,CAAC,GAAG,WAAW,CAAC;IACvC,MAAM,aAAa,GAAG,CAAC,GAAG,UAAU,CAAC;IAErC,oEAAoE;IACpE,8DAA8D;IAC9D,MAAM,SAAS,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;IACtD,MAAM,QAAQ,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;IAEpD,8BAA8B;IAC9B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,EAAE,CAAC
,EAAE,EAAE;QAC9B,MAAM,WAAW,GAAG,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;QACzC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,EAAE,CAAC,EAAE,EAAE;YAChC,MAAM,SAAS,GAAG,WAAW,GAAG,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;YAErD,8CAA8C;YAC9C,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,cAAc,CAAC,CAAC;YAClD,MAAM,QAAQ,GAAG,IAAI,CAAC,KAAK,CAAC,UAAU,GAAG,CAAC,SAAS,GAAG,CAAC,CAAC,CAAC,CAAC;YAC1D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE;gBAC/B,MAAM,SAAS,GAAG,SAAS,GAAG,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;gBAEnD,8CAA8C;gBAC9C,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,aAAa,CAAC,CAAC;gBACjD,MAAM,QAAQ,GAAG,IAAI,CAAC,KAAK,CAAC,UAAU,GAAG,CAAC,QAAQ,GAAG,CAAC,CAAC,CAAC,CAAC;gBAEzD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,EAAE,CAAC,EAAE,EAAE;oBAC9B,IAAI,KAAK,GAAG,CAAC,CAAC;oBACd,eAAe;oBAEf,KAAK,IAAI,QAAQ,GAAG,CAAC,EAAE,QAAQ,GAAG,SAAS,EAAE,QAAQ,EAAE,EAAE;wBACvD,MAAM,GAAG,GAAG,QAAQ,GAAG,QAAQ,CAAC;wBAChC,sDAAsD;wBACtD,IAAI,GAAG,GAAG,CAAC,IAAI,GAAG,IAAI,OAAO,EAAE;4BAC7B,SAAS;yBACV;wBAED,MAAM,SAAS,GAAG,WAAW,GAAG,GAAG,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;wBACnD,MAAM,aAAa,GAAG,GAAG,GAAG,WAAW,CAAC;wBACxC,MAAM,gBAAgB,GAAG,IAAI,CAAC,GAAG,CAC7B,OAAO,GAAG,CAAC,EACX,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,CAAC;4BAC3B,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,CAAC;wBAC9C,IAAI,CAAC,KAAK,gBAAgB,EAAE;4BAC1B,SAAS;yBACV;wBACD,KAAK,IAAI,QAAQ,GAAG,CAAC,EAAE,QAAQ,GAAG,QAAQ,EAAE,QAAQ,EAAE,EAAE;4BACtD,MAAM,GAAG,GAAG,QAAQ,GAAG,QAAQ,CAAC;4BAChC,sDAAsD;4BACtD,IAAI,GAAG,GAAG,CAAC,IAAI,GAAG,IAAI,MAAM,EAAE;gCAC5B,SAAS;6BACV;4BAED,MAAM,SAAS,GAAG,SAAS,GAAG,GAAG,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;4BACjD,MAAM,aAAa,GAAG,GAAG,GAAG,UAAU,CAAC;4BACvC,MAAM,gBAAgB,GAAG,IAAI,CAAC,GAAG,CAC7B,MAAM,GAAG,CAAC,EACV,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,CAAC;gCAC3B,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,CAAC;4BAE9C,IAAI,CAAC,KAAK,gBAAgB,EAAE;gCAC1B,KAAK,IAAI,QAAQ,CAAC,SAAS,GAAG,CAAC,CAAC,CAAC;6BAClC;yBACF;qBACF;oBACD,MAAM,CAAC,SAAS,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC;iBAC/B;aACF;SACF;KACF;IAED,OAAO,OAAO,CAAC,cAAc,C
AAC,MAAM,CAAC,KAAK,EAAE,MAAM,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC;AACpE,CAAC;AAED,MAAM,CAAC,MAAM,+BAA+B,GAAiB;IAC3D,UAAU,EAAE,yBAAyB;IACrC,WAAW,EAAE,KAAK;IAClB,UAAU,EAAE,yBAAkD;CAC/D,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {KernelConfig, KernelFunc, ResizeNearestNeighborGrad, ResizeNearestNeighborGradAttrs, ResizeNearestNeighborGradInputs, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nexport function resizeNearestNeighborGrad(args: {\n  inputs: ResizeNearestNeighborGradInputs,\n  backend: MathBackendCPU,\n  attrs: ResizeNearestNeighborGradAttrs\n}): TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {images, dy} = inputs;\n  const {alignCorners} = attrs;\n\n  assertNotComplex([dy, images], 'resizeNearestNeighborGrad');\n\n  const imagesStrides = util.computeStrides(images.shape);\n  const dyStrides = util.computeStrides(dy.shape);\n  const [batch, xHeight, xWidth, depth] = images.shape;\n  const [, yHeight, yWidth] = dy.shape;\n\n  const output = new Float32Array(batch * xHeight * xWidth * depth);\n  const dyValues = backend.data.get(dy.dataId).values as TypedArray;\n\n  // In the backwards pass, we want to find the pixels that were generated\n 
 // for each pixel in the input image the forward pass\n\n  const effectiveXSize: [number, number] = [\n    (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,\n    (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth\n  ];\n\n  const effectiveYSize: [number, number] = [\n    (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,\n    (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth\n  ];\n\n  const heightScale = effectiveXSize[0] / effectiveYSize[0];\n  const widthScale = effectiveXSize[1] / effectiveYSize[1];\n\n  const invHeightScale = 1 / heightScale;\n  const invWidthScale = 1 / widthScale;\n\n  // This defines the size of the window of values around a particular\n  // index in dy that we want to search for contributions to dx.\n  const winHeight = (Math.ceil(invHeightScale) * 2) + 2;\n  const winWidth = (Math.ceil(invWidthScale) * 2) + 2;\n\n  // Loop over the output space.\n  for (let b = 0; b < batch; b++) {\n    const batchOffset = b * imagesStrides[0];\n    for (let r = 0; r < xHeight; r++) {\n      const rowOffset = batchOffset + r * imagesStrides[1];\n\n      // Compute bounds for where in dy we will look\n      const startRLerp = Math.floor(r * invHeightScale);\n      const startDyR = Math.floor(startRLerp - (winHeight / 2));\n      for (let c = 0; c < xWidth; c++) {\n        const colOffset = rowOffset + c * imagesStrides[2];\n\n        // Compute bounds for where in dy we will look\n        const startCLerp = Math.floor(c * invWidthScale);\n        const startDyC = Math.floor(startCLerp - (winWidth / 2));\n\n        for (let d = 0; d < depth; d++) {\n          let accum = 0;\n          // loop over dy\n\n          for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {\n            const dyR = dyRIndex + startDyR;\n            // Guard against the window exceeding the bounds of dy\n            if (dyR < 0 || dyR >= yHeight) {\n              continue;\n            }\n\n            const dyROffset = batchOffset + dyR * dyStrides[1];\n            
const sourceFracRow = dyR * heightScale;\n            const sourceNearestRow = Math.min(\n                xHeight - 1,\n                alignCorners ? Math.round(sourceFracRow) :\n                               Math.floor(sourceFracRow));\n            if (r !== sourceNearestRow) {\n              continue;\n            }\n            for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {\n              const dyC = dyCIndex + startDyC;\n              // Guard against the window exceeding the bounds of dy\n              if (dyC < 0 || dyC >= yWidth) {\n                continue;\n              }\n\n              const dyCOffset = dyROffset + dyC * dyStrides[2];\n              const sourceFracCol = dyC * widthScale;\n              const sourceNearestCol = Math.min(\n                  xWidth - 1,\n                  alignCorners ? Math.round(sourceFracCol) :\n                                 Math.floor(sourceFracCol));\n\n              if (c === sourceNearestCol) {\n                accum += dyValues[dyCOffset + d];\n              }\n            }\n          }\n          output[colOffset + d] = accum;\n        }\n      }\n    }\n  }\n\n  return backend.makeTensorInfo(images.shape, images.dtype, output);\n}\n\nexport const resizeNearestNeighborGradConfig: KernelConfig = {\n  kernelName: ResizeNearestNeighborGrad,\n  backendName: 'cpu',\n  kernelFunc: resizeNearestNeighborGrad as unknown as KernelFunc\n};\n"]}
|