/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ResizeNearestNeighborGrad, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
/**
 * CPU kernel computing the gradient of nearest-neighbor resize with respect
 * to the resized `images`.
 *
 * For every element (b, r, c, d) of `images`, sums the `dy` values at
 * (b, dyR, dyC, d) whose nearest source row/column (computed below from
 * `dyR * heightScale` / `dyC * widthScale`) is (r, c). Only a window of dy
 * positions around the approximate location of (r, c) is searched.
 *
 * `images.shape` is destructured as `[batch, height, width, depth]` and
 * `dy.shape` as `[batch, newHeight, newWidth, ...]`, so both are expected to
 * be rank-4, channels-last tensors.
 *
 * NOTE(review): only `alignCorners` is read from `attrs`; if the forward op
 * is also run with other coordinate modes (e.g. half-pixel centers), confirm
 * this gradient matches them.
 *
 * @param args.inputs `images` (the forward-pass input) and `dy` (the
 *     upstream gradient); both must be non-complex.
 * @param args.backend CPU backend used to read `dy`'s values and to wrap the
 *     result.
 * @param args.attrs `alignCorners` selects the corner-aligned scale and
 *     rounding (instead of flooring) of source coordinates.
 * @returns A tensor info with `images.shape` and `images.dtype` whose values
 *     (a Float32Array) are the accumulated gradient for `images`.
 */
export function resizeNearestNeighborGrad(args) {
  const { inputs, backend, attrs } = args;
  const { images, dy } = inputs;
  const { alignCorners } = attrs;
  assertNotComplex([dy, images], 'resizeNearestNeighborGrad');
  const imagesStrides = util.computeStrides(images.shape);
  const dyStrides = util.computeStrides(dy.shape);
  const [batch, xHeight, xWidth, depth] = images.shape;
  const [, yHeight, yWidth] = dy.shape;
  const output = new Float32Array(batch * xHeight * xWidth * depth);
  const dyValues = backend.data.get(dy.dataId).values;
  // In the backwards pass, we want to find the pixels that were generated
  // for each pixel in the input image in the forward pass.
  //
  // With alignCorners (and more than one dy pixel along an axis) the scale
  // along that axis is computed from (size - 1) instead of size.
  const effectiveXSize = [
    (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
    (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
  ];
  const effectiveYSize = [
    (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
    (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
  ];
  // Multiplying a dy row/column index by these gives the fractional images
  // row/column used to pick its nearest source pixel below.
  const heightScale = effectiveXSize[0] / effectiveYSize[0];
  const widthScale = effectiveXSize[1] / effectiveYSize[1];
  const invHeightScale = 1 / heightScale;
  const invWidthScale = 1 / widthScale;
  // This defines the size of the window of values around a particular
  // index in dy that we want to search for contributions to dx.
  const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
  const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
  // Loop over the output space.
  for (let b = 0; b < batch; b++) {
    const batchOffset = b * imagesStrides[0];
    // dy generally has different spatial dimensions from images, so its
    // batch slice must be located with dy's own batch stride. Using
    // imagesStrides[0] here reads the wrong dy values (or past the end of
    // dyValues, producing NaN) for every batch after the first.
    const dyBatchOffset = b * dyStrides[0];
    for (let r = 0; r < xHeight; r++) {
      const rowOffset = batchOffset + r * imagesStrides[1];
      // Compute bounds for where in dy we will look
      const startRLerp = Math.floor(r * invHeightScale);
      const startDyR = Math.floor(startRLerp - (winHeight / 2));
      for (let c = 0; c < xWidth; c++) {
        const colOffset = rowOffset + c * imagesStrides[2];
        // Compute bounds for where in dy we will look
        const startCLerp = Math.floor(c * invWidthScale);
        const startDyC = Math.floor(startCLerp - (winWidth / 2));
        for (let d = 0; d < depth; d++) {
          let accum = 0;
          // loop over dy
          for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
            const dyR = dyRIndex + startDyR;
            // Guard against the window exceeding the bounds of dy
            if (dyR < 0 || dyR >= yHeight) {
              continue;
            }
            const dyROffset = dyBatchOffset + dyR * dyStrides[1];
            const sourceFracRow = dyR * heightScale;
            // Nearest images row for this dy row: rounded with alignCorners,
            // floored otherwise, clamped to the last row.
            const sourceNearestRow = Math.min(xHeight - 1, alignCorners ?
              Math.round(sourceFracRow) :
              Math.floor(sourceFracRow));
            if (r !== sourceNearestRow) {
              continue;
            }
            for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
              const dyC = dyCIndex + startDyC;
              // Guard against the window exceeding the bounds of dy
              if (dyC < 0 || dyC >= yWidth) {
                continue;
              }
              const dyCOffset = dyROffset + dyC * dyStrides[2];
              const sourceFracCol = dyC * widthScale;
              // Nearest images column, using the same rule as for rows.
              const sourceNearestCol = Math.min(xWidth - 1, alignCorners ?
                Math.round(sourceFracCol) :
                Math.floor(sourceFracCol));
              if (c === sourceNearestCol) {
                accum += dyValues[dyCOffset + d];
              }
            }
          }
          output[colOffset + d] = accum;
        }
      }
    }
  }
  return backend.makeTensorInfo(images.shape, images.dtype, output);
}
/** Registers `resizeNearestNeighborGrad` for the 'cpu' backend. */
export const resizeNearestNeighborGradConfig = {
  kernelName: ResizeNearestNeighborGrad,
  backendName: 'cpu',
  kernelFunc: resizeNearestNeighborGrad
};
// NOTE(review): this file is compiled output of ResizeNearestNeighborGrad.ts
// (named in the source map below). The dy batch-offset fix must also be
// applied to that TypeScript source. The mappings below were generated for
// the unedited file and no longer line up with this code.
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"ResizeNearestNeighborGrad.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/ResizeNearestNeighborGrad.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAA2B,yBAAyB,EAA2F,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGzL,OAAO,EAAC,gBAAgB,EAAC,MAAM,aAAa,CAAC;AAE7C,MAAM,UAAU,yBAAyB,CAAC,IAIzC;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,MAAM,EAAE,EAAE,EAAC,GAAG,MAAM,CAAC;IAC5B,MAAM,EAAC,YAAY,EAAC,GAAG,KAAK,CAAC;IAE7B,gBAAgB,CAAC,CAAC,EAAE,EAAE,MAAM,CAAC,EAAE,2BAA2B,CAAC,CAAC;IAE5D,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;IACxD,MAAM,SAAS,GAAG,IAAI,CAAC,cAAc,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC;IAChD,MAAM,CAAC,KAAK,EAAE,OAAO,EAAE,MAAM,EAAE,KAAK,CAAC,GAAG,MAAM,CAAC,KAAK,CAAC;IACrD,MAAM,CAAC,EAAE,OAAO,EAAE,MAAM,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC;IAErC,MAAM,MAAM,GAAG,IAAI,YAAY,CAAC,KAAK,GAAG,OAAO,GAAG,MAAM,GAAG,KAAK,CAAC,CAAC;IAClE,MAAM,QAAQ,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IAElE,wEAAwE;IACxE,qDAAqD;IAErD,MAAM,cAAc,GAAqB;QACvC,CAAC,YAAY,IAAI,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC,OAAO;QACrD,CAAC,YAAY,IAAI,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM;KACnD,CAAC;IAEF,MAAM,cAAc,GAAqB;QACvC,CAAC,YAAY,IAAI,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,GAAG,CAAC,CAAC,CAAC,CAAC,OAAO;QACrD,CAAC,YAAY,IAAI,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC,MAAM;KACnD,CAAC;IAEF,MAAM,WAAW,GAAG,cAAc,CAAC,CAAC,CAAC,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC;IAC1D,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,
GAAG,cAAc,CAAC,CAAC,CAAC,CAAC;IAEzD,MAAM,cAAc,GAAG,CAAC,GAAG,WAAW,CAAC;IACvC,MAAM,aAAa,GAAG,CAAC,GAAG,UAAU,CAAC;IAErC,oEAAoE;IACpE,8DAA8D;IAC9D,MAAM,SAAS,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;IACtD,MAAM,QAAQ,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,aAAa,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;IAEpD,8BAA8B;IAC9B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,EAAE,CAAC,EAAE,EAAE;QAC9B,MAAM,WAAW,GAAG,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;QACzC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,EAAE,CAAC,EAAE,EAAE;YAChC,MAAM,SAAS,GAAG,WAAW,GAAG,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;YAErD,8CAA8C;YAC9C,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,cAAc,CAAC,CAAC;YAClD,MAAM,QAAQ,GAAG,IAAI,CAAC,KAAK,CAAC,UAAU,GAAG,CAAC,SAAS,GAAG,CAAC,CAAC,CAAC,CAAC;YAC1D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE;gBAC/B,MAAM,SAAS,GAAG,SAAS,GAAG,CAAC,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;gBAEnD,8CAA8C;gBAC9C,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,aAAa,CAAC,CAAC;gBACjD,MAAM,QAAQ,GAAG,IAAI,CAAC,KAAK,CAAC,UAAU,GAAG,CAAC,QAAQ,GAAG,CAAC,CAAC,CAAC,CAAC;gBAEzD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,EAAE,CAAC,EAAE,EAAE;oBAC9B,IAAI,KAAK,GAAG,CAAC,CAAC;oBACd,eAAe;oBAEf,KAAK,IAAI,QAAQ,GAAG,CAAC,EAAE,QAAQ,GAAG,SAAS,EAAE,QAAQ,EAAE,EAAE;wBACvD,MAAM,GAAG,GAAG,QAAQ,GAAG,QAAQ,CAAC;wBAChC,sDAAsD;wBACtD,IAAI,GAAG,GAAG,CAAC,IAAI,GAAG,IAAI,OAAO,EAAE;4BAC7B,SAAS;yBACV;wBAED,MAAM,SAAS,GAAG,WAAW,GAAG,GAAG,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;wBACnD,MAAM,aAAa,GAAG,GAAG,GAAG,WAAW,CAAC;wBACxC,MAAM,gBAAgB,GAAG,IAAI,CAAC,GAAG,CAC7B,OAAO,GAAG,CAAC,EACX,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,CAAC;4BAC3B,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,CAAC;wBAC9C,IAAI,CAAC,KAAK,gBAAgB,EAAE;4BAC1B,SAAS;yBACV;wBACD,KAAK,IAAI,QAAQ,GAAG,CAAC,EAAE,QAAQ,GAAG,QAAQ,EAAE,QAAQ,EAAE,EAAE;4BACtD,MAAM,GAAG,GAAG,QAAQ,GAAG,QAAQ,CAAC;4BAChC,sDAAsD;4BACtD,IAAI,GAAG,GAAG,CAAC,IAAI,GAAG,IAAI,MAAM,EAAE;gCAC5B,SAAS;6BACV;4BAED,MAAM,SAAS,GAAG,SAAS,GAAG,GAAG,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;4BACjD,MAAM,aAAa,GAAG,GAAG,GAAG,UAA
U,CAAC;4BACvC,MAAM,gBAAgB,GAAG,IAAI,CAAC,GAAG,CAC7B,MAAM,GAAG,CAAC,EACV,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,CAAC;gCAC3B,IAAI,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,CAAC;4BAE9C,IAAI,CAAC,KAAK,gBAAgB,EAAE;gCAC1B,KAAK,IAAI,QAAQ,CAAC,SAAS,GAAG,CAAC,CAAC,CAAC;6BAClC;yBACF;qBACF;oBACD,MAAM,CAAC,SAAS,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC;iBAC/B;aACF;SACF;KACF;IAED,OAAO,OAAO,CAAC,cAAc,CAAC,MAAM,CAAC,KAAK,EAAE,MAAM,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC;AACpE,CAAC;AAED,MAAM,CAAC,MAAM,+BAA+B,GAAiB;IAC3D,UAAU,EAAE,yBAAyB;IACrC,WAAW,EAAE,KAAK;IAClB,UAAU,EAAE,yBAAkD;CAC/D,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {KernelConfig, KernelFunc, ResizeNearestNeighborGrad, ResizeNearestNeighborGradAttrs, ResizeNearestNeighborGradInputs, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nexport function resizeNearestNeighborGrad(args: {\n  inputs: ResizeNearestNeighborGradInputs,\n  backend: MathBackendCPU,\n  attrs: ResizeNearestNeighborGradAttrs\n}): TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {images, dy} = inputs;\n  const {alignCorners} = attrs;\n\n  assertNotComplex([dy, images], 'resizeNearestNeighborGrad');\n\n  const imagesStrides = 
util.computeStrides(images.shape);\n  const dyStrides = util.computeStrides(dy.shape);\n  const [batch, xHeight, xWidth, depth] = images.shape;\n  const [, yHeight, yWidth] = dy.shape;\n\n  const output = new Float32Array(batch * xHeight * xWidth * depth);\n  const dyValues = backend.data.get(dy.dataId).values as TypedArray;\n\n  // In the backwards pass, we want to find the pixels that were generated\n  // for each pixel in the input image the forward pass\n\n  const effectiveXSize: [number, number] = [\n    (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,\n    (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth\n  ];\n\n  const effectiveYSize: [number, number] = [\n    (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,\n    (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth\n  ];\n\n  const heightScale = effectiveXSize[0] / effectiveYSize[0];\n  const widthScale = effectiveXSize[1] / effectiveYSize[1];\n\n  const invHeightScale = 1 / heightScale;\n  const invWidthScale = 1 / widthScale;\n\n  // This defines the size of the window of values around a particular\n  // index in dy that we want to search for contributions to dx.\n  const winHeight = (Math.ceil(invHeightScale) * 2) + 2;\n  const winWidth = (Math.ceil(invWidthScale) * 2) + 2;\n\n  // Loop over the output space.\n  for (let b = 0; b < batch; b++) {\n    const batchOffset = b * imagesStrides[0];\n    for (let r = 0; r < xHeight; r++) {\n      const rowOffset = batchOffset + r * imagesStrides[1];\n\n      // Compute bounds for where in dy we will look\n      const startRLerp = Math.floor(r * invHeightScale);\n      const startDyR = Math.floor(startRLerp - (winHeight / 2));\n      for (let c = 0; c < xWidth; c++) {\n        const colOffset = rowOffset + c * imagesStrides[2];\n\n        // Compute bounds for where in dy we will look\n        const startCLerp = Math.floor(c * invWidthScale);\n        const startDyC = Math.floor(startCLerp - (winWidth / 2));\n\n        for (let d = 0; d < depth; d++) 
{\n          let accum = 0;\n          // loop over dy\n\n          for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {\n            const dyR = dyRIndex + startDyR;\n            // Guard against the window exceeding the bounds of dy\n            if (dyR < 0 || dyR >= yHeight) {\n              continue;\n            }\n\n            const dyROffset = batchOffset + dyR * dyStrides[1];\n            const sourceFracRow = dyR * heightScale;\n            const sourceNearestRow = Math.min(\n                xHeight - 1,\n                alignCorners ? Math.round(sourceFracRow) :\n                               Math.floor(sourceFracRow));\n            if (r !== sourceNearestRow) {\n              continue;\n            }\n            for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {\n              const dyC = dyCIndex + startDyC;\n              // Guard against the window exceeding the bounds of dy\n              if (dyC < 0 || dyC >= yWidth) {\n                continue;\n              }\n\n              const dyCOffset = dyROffset + dyC * dyStrides[2];\n              const sourceFracCol = dyC * widthScale;\n              const sourceNearestCol = Math.min(\n                  xWidth - 1,\n                  alignCorners ? Math.round(sourceFracCol) :\n                                 Math.floor(sourceFracCol));\n\n              if (c === sourceNearestCol) {\n                accum += dyValues[dyCOffset + d];\n              }\n            }\n          }\n          output[colOffset + d] = accum;\n        }\n      }\n    }\n  }\n\n  return backend.makeTensorInfo(images.shape, images.dtype, output);\n}\n\nexport const resizeNearestNeighborGradConfig: KernelConfig = {\n  kernelName: ResizeNearestNeighborGrad,\n  backendName: 'cpu',\n  kernelFunc: resizeNearestNeighborGrad as unknown as KernelFunc\n};\n"]}