/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, Conv2DBackpropFilter, TensorBuffer } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
/**
 * CPU kernel for the gradient of a 2D convolution with respect to its filter.
 *
 * For each filter tap (row, col) and channel pair (inCh, outCh), the result
 * is the sum, over the batch and over every output position whose receptive
 * field places that tap inside the unpadded input, of
 * x[b, inRow, inCol, inCh] * dy[b, outRow, outCol, outCh]. For
 * 'channelsFirst' data the channel index is read from axis 1 instead.
 * Dilation is fixed at 1.
 *
 * @param args.inputs `x` (forward-pass input) and `dy` (incoming gradient).
 * @param args.backend backend whose `data` store holds the input values.
 * @param args.attrs `strides`, `pad`, `dataFormat`, `dimRoundingMode` and
 *     `filterShape`, forwarded to `backend_util.computeConv2DInfo`.
 * @returns tensor info for a 'float32' buffer shaped `convInfo.filterShape`.
 */
export function conv2DBackpropFilter(args) {
  const { inputs, backend, attrs } = args;
  const { x, dy } = inputs;
  const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;

  assertNotComplex([x, dy], 'conv2dBackpropFilter');

  const convInfo = backend_util.computeConv2DInfo(
      x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode,
      false /* depthwise */, backend_util.convertConv2DDataFormat(dataFormat));

  const {
    strideHeight, strideWidth, filterHeight, filterWidth,
    inHeight, inWidth, outHeight, outWidth,
    inChannels, outChannels, batchSize,
  } = convInfo;
  const { top: padTop, left: padLeft } = convInfo.padInfo;
  const isChannelsLast = convInfo.dataFormat === 'channelsLast';

  const grad = new TensorBuffer(convInfo.filterShape, 'float32');

  // Read both value arrays before wrapping them, matching the original order.
  const xValues = backend.data.get(x.dataId).values;
  const dyValues = backend.data.get(dy.dataId).values;
  const xBuf = new TensorBuffer(x.shape, x.dtype, xValues);
  const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyValues);

  // Layout-specific element readers, chosen once so the innermost loop does
  // not re-test the data format on every multiply-accumulate.
  const readX = isChannelsLast ?
      (b, row, col, ch) => xBuf.get(b, row, col, ch) :
      (b, row, col, ch) => xBuf.get(b, ch, row, col);
  const readDy = isChannelsLast ?
      (b, row, col, ch) => dyBuf.get(b, row, col, ch) :
      (b, row, col, ch) => dyBuf.get(b, ch, row, col);

  // Half-open range [lo, hi) of output indices along one axis for which the
  // filter tap `tap` lands inside the input. `lo` keeps the input index >= 0;
  // `hi` may be fractional and is only used as a strict upper bound, which
  // keeps the input index < inSize.
  const outputRange = (tap, padBefore, stride, inSize, outSize) => [
    Math.max(0, Math.ceil((padBefore - tap) / stride)),
    Math.min(outSize, (inSize + padBefore - tap) / stride),
  ];

  for (let fRow = 0; fRow < filterHeight; ++fRow) {
    const [rowLo, rowHi] =
        outputRange(fRow, padTop, strideHeight, inHeight, outHeight);

    for (let fCol = 0; fCol < filterWidth; ++fCol) {
      const [colLo, colHi] =
          outputRange(fCol, padLeft, strideWidth, inWidth, outWidth);

      for (let inCh = 0; inCh < inChannels; ++inCh) {
        for (let outCh = 0; outCh < outChannels; ++outCh) {
          let sum = 0;
          for (let b = 0; b < batchSize; ++b) {
            for (let outRow = rowLo; outRow < rowHi; ++outRow) {
              const inRow = fRow + outRow * strideHeight - padTop;
              for (let outCol = colLo; outCol < colHi; ++outCol) {
                const inCol = fCol + outCol * strideWidth - padLeft;
                sum += readX(b, inRow, inCol, inCh) *
                    readDy(b, outRow, outCol, outCh);
              }
            }
          }
          grad.set(sum, fRow, fCol, inCh, outCh);
        }
      }
    }
  }

  return backend.makeTensorInfo(grad.shape, grad.dtype, grad.values);
}
|
/**
 * Kernel registration entry binding the `Conv2DBackpropFilter` kernel name to
 * this file's CPU implementation (`conv2DBackpropFilter`) on the 'cpu'
 * backend.
 */
export const conv2DBackpropFilterConfig = {
  kernelName: Conv2DBackpropFilter,
  backendName: 'cpu',
  kernelFunc: conv2DBackpropFilter,
};
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Conv2DBackpropFilter.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-cpu/src/kernels/Conv2DBackpropFilter.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,oBAAoB,EAAmF,YAAY,EAAyB,MAAM,uBAAuB,CAAC;AAGhM,OAAO,EAAC,gBAAgB,EAAC,MAAM,aAAa,CAAC;AAE7C,MAAM,UAAU,oBAAoB,CAAC,IAIpC;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,EAAE,EAAC,GAAG,MAAM,CAAC;IACvB,MAAM,EAAC,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,eAAe,EAAE,WAAW,EAAC,GAAG,KAAK,CAAC;IAEvE,gBAAgB,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,sBAAsB,CAAC,CAAC;IAElD,MAAM,WAAW,GAAG,YAAY,CAAC,uBAAuB,CAAC,UAAU,CAAC,CAAC;IACrE,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,CAAC,CAAC,KAAyC,EAAE,WAAW,EAAE,OAAO,EACjE,CAAC,CAAC,eAAe,EAAE,GAAG,EAAE,eAAe,EAAE,KAAK,CAAC,eAAe,EAC9D,WAAW,CAAC,CAAC;IAEjB,MAAM,EAAC,YAAY,EAAE,WAAW,EAAE,YAAY,EAAE,WAAW,EAAC,GAAG,QAAQ,CAAC;IACxE,MAAM,cAAc,GAAG,QAAQ,CAAC,UAAU,KAAK,cAAc,CAAC;IAC9D,MAAM,EAAE,GAAG,IAAI,YAAY,CAAC,QAAQ,CAAC,WAAW,EAAE,SAAS,CAAC,CAAC;IAE7D,MAAM,OAAO,GAAG,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;IACtC,MAAM,MAAM,GAAG,QAAQ,CAAC,OAAO,CAAC,GAAG,CAAC;IACpC,MAAM,KAAK,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IAC9D,MAAM,MAAM,GAAG,OAAO,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,MAAoB,CAAC;IAEhE,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC;IACvD,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,EAAE,CAAC,KAAK,EAAE,EAAE,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC;IAE3D,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,YAAY,EAAE,EAAE,EAAE,EAAE;QACxC,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,MAAM,GAAG,EAAE,CAAC,GAAG,YAAY,CAAC,CAAC,CAAC;QACnE,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAClB,QAAQ,CAAC,SAAS,EAAE,CAAC,QAAQ,CAAC,QAAQ,GAAG,MAAM,GAAG,EAAE,CAAC,GAAG,YAAY,CAAC,CAAC;QAE1E,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,WAAW,EAAE,EAAE,EAAE,EAAE;YACvC,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC,OAAO,GAAG,EAAE,CAAC
,GAAG,WAAW,CAAC,CAAC,CAAC;YACnE,MAAM,KAAK,GAAG,IAAI,CAAC,GAAG,CAClB,QAAQ,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC,OAAO,GAAG,OAAO,GAAG,EAAE,CAAC,GAAG,WAAW,CAAC,CAAC;YAExE,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,QAAQ,CAAC,UAAU,EAAE,EAAE,EAAE,EAAE;gBAC/C,KAAK,IAAI,EAAE,GAAG,CAAC,EAAE,EAAE,GAAG,QAAQ,CAAC,WAAW,EAAE,EAAE,EAAE,EAAE;oBAChD,IAAI,OAAO,GAAG,CAAC,CAAC;oBAChB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,CAAC,SAAS,EAAE,EAAE,CAAC,EAAE;wBAC3C,KAAK,IAAI,EAAE,GAAG,KAAK,EAAE,EAAE,GAAG,KAAK,EAAE,EAAE,EAAE,EAAE;4BACrC,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,YAAY,GAAG,MAAM,CAAC;4BAC3C,KAAK,IAAI,EAAE,GAAG,KAAK,EAAE,EAAE,GAAG,KAAK,EAAE,EAAE,EAAE,EAAE;gCACrC,MAAM,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,WAAW,GAAG,OAAO,CAAC;gCAC3C,IAAI,cAAc,EAAE;oCAClB,OAAO,IAAK,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAY;wCACzC,KAAK,CAAC,GAAG,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAY,CAAC;iCAC1C;qCAAM;oCACL,OAAO,IAAK,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAY;wCACzC,KAAK,CAAC,GAAG,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAY,CAAC;iCAC1C;6BACF;yBACF;qBACF;oBACD,EAAE,CAAC,GAAG,CAAC,OAAO,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC;iBACjC;aACF;SACF;KACF;IAED,OAAO,OAAO,CAAC,cAAc,CAAC,EAAE,CAAC,KAAK,EAAE,EAAE,CAAC,KAAK,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC;AAC/D,CAAC;AAED,MAAM,CAAC,MAAM,0BAA0B,GAAiB;IACtD,UAAU,EAAE,oBAAoB;IAChC,WAAW,EAAE,KAAK;IAClB,UAAU,EAAE,oBAA6C;CAC1D,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, Conv2DBackpropFilter, Conv2DBackpropFilterAttrs, Conv2DBackpropFilterInputs, KernelConfig, KernelFunc, TensorBuffer, TensorInfo, TypedArray} from '@tensorflow/tfjs-core';\n\nimport {MathBackendCPU} from '../backend_cpu';\nimport {assertNotComplex} from '../cpu_util';\n\nexport function conv2DBackpropFilter(args: {\n  inputs: Conv2DBackpropFilterInputs,\n  backend: MathBackendCPU,\n  attrs: Conv2DBackpropFilterAttrs\n}): TensorInfo {\n  const {inputs, backend, attrs} = args;\n  const {x, dy} = inputs;\n  const {strides, pad, dataFormat, dimRoundingMode, filterShape} = attrs;\n\n  assertNotComplex([x, dy], 'conv2dBackpropFilter');\n\n  const $dataFormat = backend_util.convertConv2DDataFormat(dataFormat);\n  const convInfo = backend_util.computeConv2DInfo(\n      x.shape as [number, number, number, number], filterShape, strides,\n      1 /* dilations */, pad, dimRoundingMode, false /* depthwise */,\n      $dataFormat);\n\n  const {strideHeight, strideWidth, filterHeight, filterWidth} = convInfo;\n  const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n  const dW = new TensorBuffer(convInfo.filterShape, 'float32');\n\n  const leftPad = convInfo.padInfo.left;\n  const topPad = convInfo.padInfo.top;\n  const xVals = backend.data.get(x.dataId).values as TypedArray;\n  const 
dyVals = backend.data.get(dy.dataId).values as TypedArray;\n\n  const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);\n  const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);\n\n  for (let wR = 0; wR < filterHeight; ++wR) {\n    const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));\n    const yRMax = Math.min(\n        convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);\n\n    for (let wC = 0; wC < filterWidth; ++wC) {\n      const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));\n      const yCMax = Math.min(\n          convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);\n\n      for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {\n        for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {\n          let dotProd = 0;\n          for (let b = 0; b < convInfo.batchSize; ++b) {\n            for (let yR = yRMin; yR < yRMax; ++yR) {\n              const xR = wR + yR * strideHeight - topPad;\n              for (let yC = yCMin; yC < yCMax; ++yC) {\n                const xC = wC + yC * strideWidth - leftPad;\n                if (isChannelsLast) {\n                  dotProd += (xBuf.get(b, xR, xC, d1) as number) *\n                      (dyBuf.get(b, yR, yC, d2) as number);\n                } else {\n                  dotProd += (xBuf.get(b, d1, xR, xC) as number) *\n                      (dyBuf.get(b, d2, yR, yC) as number);\n                }\n              }\n            }\n          }\n          dW.set(dotProd, wR, wC, d1, d2);\n        }\n      }\n    }\n  }\n\n  return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);\n}\n\nexport const conv2DBackpropFilterConfig: KernelConfig = {\n  kernelName: Conv2DBackpropFilter,\n  backendName: 'cpu',\n  kernelFunc: conv2DBackpropFilter as unknown as KernelFunc\n};\n"]}
|