/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, Conv2DBackpropInput, TensorBuffer, util } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
/**
 * CPU kernel for Conv2DBackpropInput: the gradient of a 2D convolution with
 * respect to its input.
 *
 * Each element of the result `dx` is the sum, over every `dy` position and
 * output channel that the forward convolution connected to it, of
 * `dy * filter`. The filter is indexed here as
 * [filterHeight, filterWidth, inChannels, outChannels] and read with both
 * spatial axes reversed; dilation is fixed at 1.
 *
 * @param args.inputs `dy` (the incoming gradient) and `filter`.
 * @param args.backend CPU backend used to read tensor data and to wrap the
 *     result.
 * @param args.attrs `inputShape`, `strides`, `pad`, `dataFormat` and
 *     `dimRoundingMode`, forwarded to `backend_util.computeConv2DInfo`.
 * @returns TensorInfo for a float32 tensor of shape `convInfo.inShape`.
 */
export function conv2DBackpropInput(args) {
  const { inputs, backend, attrs } = args;
  const { dy, filter } = inputs;
  const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;

  assertNotComplex([dy, filter], 'conv2dBackpropInput');

  // The outChannels axis of the filter is addressed with stride 1 below, so
  // only the first three strides are needed.
  const [wRowStride, wColStride, wInStride] = util.computeStrides(filter.shape);
  const dyShapeStrides = util.computeStrides(dy.shape);

  const info = backend_util.computeConv2DInfo(
      inputShape, filter.shape, strides, 1 /* dilations */, pad,
      dimRoundingMode, false, backend_util.convertConv2DDataFormat(dataFormat));

  const dx = new TensorBuffer(info.inShape, 'float32');
  const dxData = dx.values;
  const dyData = backend.data.get(dy.dataId).values;
  const wData = backend.data.get(filter.dataId).values;

  const {
    batchSize, inChannels, inHeight, inWidth,
    outChannels, outHeight, outWidth,
    filterHeight, filterWidth, strideHeight, strideWidth,
  } = info;

  // Resolve [batch, row, col, channel] element strides for a 4D tensor in
  // the data format reported by computeConv2DInfo.
  const channelsLast = info.dataFormat === 'channelsLast';
  const axisStrides = (s) => channelsLast ?
      [s[0], s[1], s[2], 1] :
      [s[0], s[2], 1, s[1]];
  const [dxBatchStride, dxRowStride, dxColStride, dxChanStride] =
      axisStrides(dx.strides);
  const [dyBatchStride, dyRowStride, dyColStride, dyChanStride] =
      axisStrides(dyShapeStrides);

  // Shift from an input coordinate to the corner of the reversed filter
  // window that slides over dy.
  const rowShift = filterHeight - 1 - info.padInfo.top;
  const colShift = filterWidth - 1 - info.padInfo.left;

  for (let n = 0; n < batchSize; ++n) {
    const dyBatchBase = n * dyBatchStride;
    const dxBatchBase = n * dxBatchStride;
    for (let ci = 0; ci < inChannels; ++ci) {
      for (let row = 0; row < inHeight; ++row) {
        const rowCorner = row - rowShift;
        // Only dy rows whose tap (outRow * strideHeight - rowCorner) falls in
        // [0, filterHeight) contribute; the end bound may be fractional.
        const outRowStart = Math.max(0, Math.ceil(rowCorner / strideHeight));
        const outRowEnd =
            Math.min(outHeight, (filterHeight + rowCorner) / strideHeight);

        for (let col = 0; col < inWidth; ++col) {
          const colCorner = col - colShift;
          const outColStart = Math.max(0, Math.ceil(colCorner / strideWidth));
          const outColEnd =
              Math.min(outWidth, (filterWidth + colCorner) / strideWidth);

          let sum = 0;
          for (let outRow = outRowStart; outRow < outRowEnd; ++outRow) {
            // Filter row read in reverse order.
            const tapRow =
                filterHeight - 1 - (outRow * strideHeight - rowCorner);
            for (let outCol = outColStart; outCol < outColEnd; ++outCol) {
              const tapCol =
                  filterWidth - 1 - (outCol * strideWidth - colCorner);
              const dyBase =
                  dyBatchBase + outRow * dyRowStride + outCol * dyColStride;
              const wBase =
                  tapRow * wRowStride + tapCol * wColStride + ci * wInStride;
              for (let co = 0; co < outChannels; ++co) {
                sum += dyData[dyBase + co * dyChanStride] * wData[wBase + co];
              }
            }
          }
          dxData[dxBatchBase + row * dxRowStride + col * dxColStride +
              ci * dxChanStride] = sum;
        }
      }
    }
  }

  return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
|
/**
 * Registration record binding the Conv2DBackpropInput kernel name to
 * `conv2DBackpropInput` for the 'cpu' backend.
 */
export const conv2DBackpropInputConfig = {
  kernelName: Conv2DBackpropInput,
  backendName: 'cpu',
  kernelFunc: conv2DBackpropInput,
};
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Conv2DBackpropInput.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/Conv2DBackpropInput.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,mBAAmB,EAAiF,YAAY,EAA0B,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGnM,OAAO,EAAC,gBAAgB,EAAC,MAAM,aAAa,CAAC;AAE7C,MAAM,UAAU,mBAAmB,CAAC,IAInC;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,EAAE,EAAE,MAAM,EAAC,GAAG,MAAM,CAAC;IAC5B,MAAM,EAAC,UAAU,EAAE,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,eAAe,EAAC,GAAG,KAAK,CAAC;IAEtE,gBAAgB,CAAC,CAAC,EAAE,EAAE,MAAM,CAAC,EAAE,qBAAqB,CAAC,CAAC;IAEtD,MAAM,aAAa,GAAG,IAAI,CAAC,cAAc,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;IACxD,MAAM,SAAS,GAAG,IAAI,CAAC,cAAc,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC;IAEhD,IAAI,WAAW,GAAG,YAAY,CAAC,uBAAuB,CAAC,UAAU,CAAC,CAAC;IACnE,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,UAAU,EAAE,MAAM,CAAC,KAAyC,EAAE,OAAO,EACrE,CAAC,CAAC,eAAe,EAAE,GAAG,EAAE,eAAe,EAAE,KAAK,EAAE,WAAW,CAAC,CAAC;IAEjE,MAAM,EAAE,GAAG,IAAI,YAAY,CAAC,QAAQ,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC;IACzD,MAAM,QAAQ,GAAG,EAAE,CAAC,MAAM,CAAC;IAC3B,MAAM,QAAQ,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IAClE,MAAM,SAAS,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IACvE,MAAM,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,GAAG,aAAa,CAAC;IAC5C,MAAM,EACJ,SAAS,EACT,YAAY,EACZ,WAAW,EACX,UAAU,EACV,QAAQ,EACR,OAAO,EACP,WAAW,EACX,SAAS,EACT,QAAQ,EACR,YAAY,EACZ,WAAW,EACZ,GAAG,QAAQ,CAAC;IACb,WAAW,GAAG,QAAQ,CAAC,UAAU,CAAC;IAClC,MAAM,MAAM,GAAG,YAAY,GAAG,CAAC,GAAG,QAAQ,CAAC,OAAO,CAAC,GAAG,CAAC;IACvD,MAAM,OAAO,GAAG,WAAW,GAAG,CAAC,GAAG,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;IAExD,MAAM,cAAc,GAAG,WAAW,KAAK,cAAc,CAAC;IACtD,MAAM,YAAY,GAAG,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IACnC,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAClE,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACtD,MAAM,cAAc,GAAG,cAAc,CA
AC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAC1D,MAAM,YAAY,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;IAClC,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC;IAChE,MAAM,UAAU,GAAG,cAAc,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACrD,MAAM,cAAc,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC;IAEzD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,EAAE,CAAC,EAAE;QAClC,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,UAAU,EAAE,EAAE,EAAE,EAAE;YACtC,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,QAAQ,EAAE,EAAE,EAAE,EAAE;gBACpC,MAAM,QAAQ,GAAG,EAAE,GAAG,MAAM,CAAC;gBAC7B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,QAAQ,GAAG,YAAY,CAAC,CAAC,CAAC;gBAC9D,MAAM,KAAK,GACP,IAAI,CAAC,GAAG,CAAC,SAAS,EAAE,CAAC,YAAY,GAAG,QAAQ,CAAC,GAAG,YAAY,CAAC,CAAC;gBAElE,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,OAAO,EAAE,EAAE,EAAE,EAAE;oBACnC,MAAM,QAAQ,GAAG,EAAE,GAAG,OAAO,CAAC;oBAC9B,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,QAAQ,GAAG,WAAW,CAAC,CAAC,CAAC;oBAC7D,MAAM,KAAK,GACP,IAAI,CAAC,GAAG,CAAC,QAAQ,EAAE,CAAC,WAAW,GAAG,QAAQ,CAAC,GAAG,WAAW,CAAC,CAAC;oBAE/D,IAAI,OAAO,GAAG,CAAC,CAAC;oBAChB,KAAK,IAAI,EAAE,GAAG,KAAK,EAAE,EAAE,GAAG,KAAK,EAAE,EAAE,EAAE,EAAE;wBACrC,MAAM,EAAE,GAAG,EAAE,GAAG,YAAY,GAAG,QAAQ,CAAC;wBAExC,KAAK,IAAI,EAAE,GAAG,KAAK,EAAE,EAAE,GAAG,KAAK,EAAE,EAAE,EAAE,EAAE;4BACrC,MAAM,EAAE,GAAG,EAAE,GAAG,WAAW,GAAG,QAAQ,CAAC;4BACvC,MAAM,QAAQ,GACV,YAAY,GAAG,CAAC,GAAG,UAAU,GAAG,EAAE,GAAG,UAAU,GAAG,EAAE,CAAC;4BACzD,MAAM,SAAS,GAAG,KAAK,GAAG,CAAC,YAAY,GAAG,CAAC,GAAG,EAAE,CAAC;gCAC7C,KAAK,GAAG,CAAC,WAAW,GAAG,CAAC,GAAG,EAAE,CAAC,GAAG,KAAK,GAAG,EAAE,CAAC;4BAEhD,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,WAAW,EAAE,EAAE,EAAE,EAAE;gCACvC,MAAM,KAAK,GAAG,QAAQ,CAAC,QAAQ,GAAG,cAAc,GAAG,EAAE,CAAC,CAAC;gCACvD,MAAM,MAAM,GAAG,SAAS,CAAC,SAAS,GAAG,EAAE,CAAC,CAAC;gCACzC,OAAO,IAAI,KAAK,GAAG,MAAM,CAAC;6BAC3B;yBACF;qBACF;oBACD,MAAM,QAAQ,GAAG,YAAY,GAAG,CAAC,GAAG,UAAU,GAAG,EAAE;wBAC/C,UAAU,GAAG,EAAE,GAAG,cAAc,
GAAG,EAAE,CAAC;oBAC1C,QAAQ,CAAC,QAAQ,CAAC,GAAG,OAAO,CAAC;iBAC9B;aACF;SACF;KACF;IAED,OAAO,OAAO,CAAC,cAAc,CAAC,EAAE,CAAC,KAAK,EAAE,EAAE,CAAC,KAAK,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC;AAC/D,CAAC;AAED,MAAM,CAAC,MAAM,yBAAyB,GAAiB;IACrD,UAAU,EAAE,mBAAmB;IAC/B,WAAW,EAAE,KAAK;IAClB,UAAU,EAAE,mBAA4C;CACzD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, Conv2DBackpropInput, Conv2DBackpropInputAttrs, Conv2DBackpropInputInputs, KernelConfig, KernelFunc, TensorBuffer, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nexport function conv2DBackpropInput(args: {\n  inputs: Conv2DBackpropInputInputs,\n  backend: MathBackendCPU,\n  attrs: Conv2DBackpropInputAttrs\n}): TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {dy, filter} = inputs;\n  const {inputShape, strides, pad, dataFormat, dimRoundingMode} = attrs;\n\n  assertNotComplex([dy, filter], 'conv2dBackpropInput');\n\n  const filterStrides = util.computeStrides(filter.shape);\n  const dyStrides = util.computeStrides(dy.shape);\n\n  let $dataFormat = backend_util.convertConv2DDataFormat(dataFormat);\n  const convInfo = backend_util.computeConv2DInfo(\n      inputShape, filter.shape as [number, number, 
number, number], strides,\n      1 /* dilations */, pad, dimRoundingMode, false, $dataFormat);\n\n  const dx = new TensorBuffer(convInfo.inShape, 'float32');\n  const dxValues = dx.values;\n  const dyValues = backend.data.get(dy.dataId).values as TypedArray;\n  const fltValues = backend.data.get(filter.dataId).values as TypedArray;\n  const [fltS0, fltS1, fltS2] = filterStrides;\n  const {\n    batchSize,\n    filterHeight,\n    filterWidth,\n    inChannels,\n    inHeight,\n    inWidth,\n    outChannels,\n    outHeight,\n    outWidth,\n    strideHeight,\n    strideWidth\n  } = convInfo;\n  $dataFormat = convInfo.dataFormat;\n  const topPad = filterHeight - 1 - convInfo.padInfo.top;\n  const leftPad = filterWidth - 1 - convInfo.padInfo.left;\n\n  const isChannelsLast = $dataFormat === 'channelsLast';\n  const xBatchStride = dx.strides[0];\n  const xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];\n  const xColStride = isChannelsLast ? dx.strides[2] : 1;\n  const xChannelStride = isChannelsLast ? 1 : dx.strides[1];\n  const yBatchStride = dyStrides[0];\n  const yRowStride = isChannelsLast ? dyStrides[1] : dyStrides[2];\n  const yColStride = isChannelsLast ? dyStrides[2] : 1;\n  const yChannelStride = isChannelsLast ? 
1 : dyStrides[1];\n\n  for (let b = 0; b < batchSize; ++b) {\n    for (let d1 = 0; d1 < inChannels; ++d1) {\n      for (let xR = 0; xR < inHeight; ++xR) {\n        const xRCorner = xR - topPad;\n        const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));\n        const yRMax =\n            Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);\n\n        for (let xC = 0; xC < inWidth; ++xC) {\n          const xCCorner = xC - leftPad;\n          const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));\n          const yCMax =\n              Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);\n\n          let dotProd = 0;\n          for (let yR = xRMin; yR < yRMax; ++yR) {\n            const wR = yR * strideHeight - xRCorner;\n\n            for (let yC = xCMin; yC < yCMax; ++yC) {\n              const wC = yC * strideWidth - xCCorner;\n              const dyOffset =\n                  yBatchStride * b + yRowStride * yR + yColStride * yC;\n              const fltOffset = fltS0 * (filterHeight - 1 - wR) +\n                  fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;\n\n              for (let d2 = 0; d2 < outChannels; ++d2) {\n                const pixel = dyValues[dyOffset + yChannelStride * d2];\n                const weight = fltValues[fltOffset + d2];\n                dotProd += pixel * weight;\n              }\n            }\n          }\n          const dxOffset = xBatchStride * b + xRowStride * xR +\n              xColStride * xC + xChannelStride * d1;\n          dxValues[dxOffset] = dotProd;\n        }\n      }\n    }\n  }\n\n  return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);\n}\n\nexport const conv2DBackpropInputConfig: KernelConfig = {\n  kernelName: Conv2DBackpropInput,\n  backendName: 'cpu',\n  kernelFunc: conv2DBackpropInput as unknown as KernelFunc\n};\n"]}
|