/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { backend_util, DepthwiseConv2dNativeBackpropFilter, TensorBuffer } from '@tensorflow/tfjs-core';
|
import { assertNotComplex } from '../cpu_util';
|
/**
 * CPU kernel that computes the gradient of a depthwise 2D convolution with
 * respect to its filter.
 *
 * For each filter tap `(wR, wC)` and each output channel `d2`, the kernel sums
 * `x[b, xR, xC, d1] * dy[b, yR, yC, d2]` over the batch and over every output
 * position `(yR, yC)` whose matching input position `(xR, xC)` lies inside the
 * input, and stores the sum at `dW[wR, wC, d1, dm]`.
 *
 * The input position of a tap honours the dilation reported by `convInfo`
 * (`xR = yR * strideHeight + wR * dilationHeight - topPad`), so the index math
 * agrees with the padding and output size that `computeConv2DInfo` derived
 * from the same `dilations` attribute. A missing dilation field is treated
 * as 1.
 *
 * @param {Object} args
 * @param {Object} args.inputs `{x, dy}`: the forward input and the gradient
 *     flowing back from the convolution output. Both are indexed here as
 *     4-D `(batch, row, col, channel)` tensors.
 * @param {Object} args.backend Backend exposing `data.get(dataId).values` and
 *     `makeTensorInfo(shape, dtype, values)`.
 * @param {Object} args.attrs `{strides, dilations, pad, dimRoundingMode,
 *     filterShape}`, forwarded to `backend_util.computeConv2DInfo`.
 * @returns The value of `backend.makeTensorInfo` for the float32 filter
 *     gradient, whose shape is `convInfo.filterShape`.
 */
export function depthwiseConv2dNativeBackpropFilter(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
    assertNotComplex([x, dy], 'depthwiseConv2dNativeBackpropFilter');
    const convInfo = backend_util.computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    const {
        strideHeight,
        strideWidth,
        filterHeight,
        filterWidth,
        // Default to 1 so a convInfo without dilation fields behaves as before.
        dilationHeight = 1,
        dilationWidth = 1,
    } = convInfo;
    const dW = new TensorBuffer(convInfo.filterShape, 'float32');
    const leftPad = convInfo.padInfo.left;
    const topPad = convInfo.padInfo.top;
    // Number of output channels produced per input channel.
    const chMul = convInfo.outChannels / convInfo.inChannels;
    const xVals = backend.data.get(x.dataId).values;
    const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
    const dyVals = backend.data.get(dy.dataId).values;
    const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
    for (let wR = 0; wR < filterHeight; ++wR) {
        // Distance (in input rows) of this tap above the top edge of the input.
        const rowShift = topPad - wR * dilationHeight;
        // Keep only output rows whose input row satisfies 0 <= xR < inHeight.
        // The upper bound may be fractional; `yR < yRMax` still selects
        // exactly the integer rows with xR < inHeight.
        const yRMin = Math.max(0, Math.ceil(rowShift / strideHeight));
        const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + rowShift) / strideHeight);
        for (let wC = 0; wC < filterWidth; ++wC) {
            const colShift = leftPad - wC * dilationWidth;
            const yCMin = Math.max(0, Math.ceil(colShift / strideWidth));
            const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + colShift) / strideWidth);
            for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                // Split the output channel into (input channel, multiplier index).
                const d1 = Math.trunc(d2 / chMul);
                const dm = d2 % chMul;
                let dotProd = 0;
                for (let b = 0; b < convInfo.batchSize; ++b) {
                    for (let yR = yRMin; yR < yRMax; ++yR) {
                        const xR = yR * strideHeight - rowShift;
                        for (let yC = yCMin; yC < yCMax; ++yC) {
                            const xC = yC * strideWidth - colShift;
                            dotProd += xBuf.get(b, xR, xC, d1) *
                                dyBuf.get(b, yR, yC, d2);
                        }
                    }
                }
                dW.set(dotProd, wR, wC, d1, dm);
            }
        }
    }
    return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
|
/**
 * Kernel configuration object: pairs the `DepthwiseConv2dNativeBackpropFilter`
 * kernel name with the `'cpu'` backend name and the function implementing it.
 */
export const depthwiseConv2dNativeBackpropFilterConfig = {
    kernelName: DepthwiseConv2dNativeBackpropFilter,
    backendName: 'cpu',
    kernelFunc: depthwiseConv2dNativeBackpropFilter,
};
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXIuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL0RlcHRod2lzZUNvbnYyZE5hdGl2ZUJhY2twcm9wRmlsdGVyLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQUUsbUNBQW1DLEVBQWlILFlBQVksRUFBeUIsTUFBTSx1QkFBdUIsQ0FBQztBQUc3TyxPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFFN0MsTUFBTSxVQUFVLG1DQUFtQyxDQUFDLElBSW5EO0lBQ0MsTUFBTSxFQUFDLE1BQU0sRUFBRSxPQUFPLEVBQUUsS0FBSyxFQUFDLEdBQUcsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sRUFBQyxDQUFDLEVBQUUsRUFBRSxFQUFDLEdBQUcsTUFBTSxDQUFDO0lBQ3ZCLE1BQU0sRUFBQyxPQUFPLEVBQUUsU0FBUyxFQUFFLEdBQUcsRUFBRSxlQUFlLEVBQUUsV0FBVyxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRXRFLGdCQUFnQixDQUFDLENBQUMsQ0FBQyxFQUFFLEVBQUUsQ0FBQyxFQUFFLHFDQUFxQyxDQUFDLENBQUM7SUFFakUsTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLGlCQUFpQixDQUMzQyxDQUFDLENBQUMsS0FBeUMsRUFBRSxXQUFXLEVBQUUsT0FBTyxFQUNqRSxTQUFTLEVBQUUsR0FBRyxFQUFFLGVBQWUsRUFBRSxJQUFJLENBQUMsZUFBZSxDQUFDLENBQUM7SUFFM0QsTUFBTSxFQUFDLFlBQVksRUFBRSxXQUFXLEVBQUUsWUFBWSxFQUFFLFdBQVcsRUFBQyxHQUFHLFFBQVEsQ0FBQztJQUV4RSxNQUFNLEVBQUUsR0FBRyxJQUFJLFlBQVksQ0FBQyxRQUFRLENBQUMsV0FBVyxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBRTdELE1BQU0sT0FBTyxHQUFHLFFBQVEsQ0FBQyxPQUFPLENBQUMsSUFBSSxDQUFDO0lBQ3RDLE1BQU0sTUFBTSxHQUFHLFFBQVEsQ0FBQyxPQUFPLENBQUMsR0FBRyxDQUFDO0lBQ3BDLE1BQU0sS0FBSyxHQUFHLFFBQVEsQ0FBQyxXQUFXLEdBQUcsUUFBUSxDQUFDLFVBQVUsQ0FBQztJQUV6RCxNQUFNLEtBQUssR0FBRyxPQUFPLENBQUMsSUFBSSxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsTUFBTSxDQUFDLENBQUMsTUFBb0IsQ0FBQztJQUM5RCxNQUFNLElBQUksR0FBRyxJQUFJLFlBQVksQ0FBQyxDQUFDLENBQUMsS0FBSyxFQUFFLENBQUMsQ0FBQyxLQUFLLEVBQUUsS0FBSyxDQUFDLENBQUM7SUFDdkQsTUFBTSxNQUFNLEdBQUcsT0FBTyxDQUFDLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLE1BQU0sQ0FBQyxDQUFDLE1BQW9CLENBQUM7SUFDaEUsTUFBTSxLQUFLLEdBQUcsSUFBSSxZQUFZLENBQUMsRUFBRSxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsS0FBSyxFQUFFLE1BQU0sQ0FBQyxDQUFDO0lBQzNELEtBQUssSUFBSSxFQUFFLEdBQUcsQ0FBQyxFQUFFLE
VBQUUsR0FBRyxZQUFZLEVBQUUsRUFBRSxFQUFFLEVBQUU7UUFDeEMsTUFBTSxLQUFLLEdBQUcsSUFBSSxDQUFDLEdBQUcsQ0FBQyxDQUFDLEVBQUUsSUFBSSxDQUFDLElBQUksQ0FBQyxDQUFDLE1BQU0sR0FBRyxFQUFFLENBQUMsR0FBRyxZQUFZLENBQUMsQ0FBQyxDQUFDO1FBQ25FLE1BQU0sS0FBSyxHQUFHLElBQUksQ0FBQyxHQUFHLENBQ2xCLFFBQVEsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxRQUFRLENBQUMsUUFBUSxHQUFHLE1BQU0sR0FBRyxFQUFFLENBQUMsR0FBRyxZQUFZLENBQUMsQ0FBQztRQUUxRSxLQUFLLElBQUksRUFBRSxHQUFHLENBQUMsRUFBRSxFQUFFLEdBQUcsV0FBVyxFQUFFLEVBQUUsRUFBRSxFQUFFO1lBQ3ZDLE1BQU0sS0FBSyxHQUFHLElBQUksQ0FBQyxHQUFHLENBQUMsQ0FBQyxFQUFFLElBQUksQ0FBQyxJQUFJLENBQUMsQ0FBQyxPQUFPLEdBQUcsRUFBRSxDQUFDLEdBQUcsV0FBVyxDQUFDLENBQUMsQ0FBQztZQUNuRSxNQUFNLEtBQUssR0FBRyxJQUFJLENBQUMsR0FBRyxDQUNsQixRQUFRLENBQUMsUUFBUSxFQUFFLENBQUMsUUFBUSxDQUFDLE9BQU8sR0FBRyxPQUFPLEdBQUcsRUFBRSxDQUFDLEdBQUcsV0FBVyxDQUFDLENBQUM7WUFFeEUsS0FBSyxJQUFJLEVBQUUsR0FBRyxDQUFDLEVBQUUsRUFBRSxHQUFHLFFBQVEsQ0FBQyxXQUFXLEVBQUUsRUFBRSxFQUFFLEVBQUU7Z0JBQ2hELE1BQU0sRUFBRSxHQUFHLElBQUksQ0FBQyxLQUFLLENBQUMsRUFBRSxHQUFHLEtBQUssQ0FBQyxDQUFDO2dCQUNsQyxNQUFNLEVBQUUsR0FBRyxFQUFFLEdBQUcsS0FBSyxDQUFDO2dCQUV0QixJQUFJLE9BQU8sR0FBRyxDQUFDLENBQUM7Z0JBQ2hCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxRQUFRLENBQUMsU0FBUyxFQUFFLEVBQUUsQ0FBQyxFQUFFO29CQUMzQyxLQUFLLElBQUksRUFBRSxHQUFHLEtBQUssRUFBRSxFQUFFLEdBQUcsS0FBSyxFQUFFLEVBQUUsRUFBRSxFQUFFO3dCQUNyQyxNQUFNLEVBQUUsR0FBRyxFQUFFLEdBQUcsRUFBRSxHQUFHLFlBQVksR0FBRyxNQUFNLENBQUM7d0JBQzNDLEtBQUssSUFBSSxFQUFFLEdBQUcsS0FBSyxFQUFFLEVBQUUsR0FBRyxLQUFLLEVBQUUsRUFBRSxFQUFFLEVBQUU7NEJBQ3JDLE1BQU0sRUFBRSxHQUFHLEVBQUUsR0FBRyxFQUFFLEdBQUcsV0FBVyxHQUFHLE9BQU8sQ0FBQzs0QkFDM0MsT0FBTyxJQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsQ0FBQyxFQUFFLEVBQUUsRUFBRSxFQUFFLEVBQUUsRUFBRSxDQUFZO2dDQUN6QyxLQUFLLENBQUMsR0FBRyxDQUFDLENBQUMsRUFBRSxFQUFFLEVBQUUsRUFBRSxFQUFFLEVBQUUsQ0FBWSxDQUFDO3lCQUMxQztxQkFDRjtpQkFDRjtnQkFDRCxFQUFFLENBQUMsR0FBRyxDQUFDLE9BQU8sRUFBRSxFQUFFLEVBQUUsRUFBRSxFQUFFLEVBQUUsRUFBRSxFQUFFLENBQUMsQ0FBQzthQUNqQztTQUNGO0tBQ0Y7SUFFRCxPQUFPLE9BQU8sQ0FBQyxjQUFjLENBQUMsRUFBRSxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsS0
FBSyxFQUFFLEVBQUUsQ0FBQyxNQUFNLENBQUMsQ0FBQztBQUMvRCxDQUFDO0FBRUQsTUFBTSxDQUFDLE1BQU0seUNBQXlDLEdBQWlCO0lBQ3JFLFVBQVUsRUFBRSxtQ0FBbUM7SUFDL0MsV0FBVyxFQUFFLEtBQUs7SUFDbEIsVUFBVSxFQUFFLG1DQUE0RDtDQUN6RSxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgRGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXIsIERlcHRod2lzZUNvbnYyZE5hdGl2ZUJhY2twcm9wRmlsdGVyQXR0cnMsIERlcHRod2lzZUNvbnYyZE5hdGl2ZUJhY2twcm9wRmlsdGVySW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckJ1ZmZlciwgVGVuc29ySW5mbywgVHlwZWRBcnJheX0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcblxuaW1wb3J0IHtNYXRoQmFja2VuZENQVX0gZnJvbSAnLi4vYmFja2VuZF9jcHUnO1xuaW1wb3J0IHthc3NlcnROb3RDb21wbGV4fSBmcm9tICcuLi9jcHVfdXRpbCc7XG5cbmV4cG9ydCBmdW5jdGlvbiBkZXB0aHdpc2VDb252MmROYXRpdmVCYWNrcHJvcEZpbHRlcihhcmdzOiB7XG4gIGlucHV0czogRGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXJJbnB1dHMsXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVLFxuICBhdHRyczogRGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXJBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdC
B7eCwgZHl9ID0gaW5wdXRzO1xuICBjb25zdCB7c3RyaWRlcywgZGlsYXRpb25zLCBwYWQsIGRpbVJvdW5kaW5nTW9kZSwgZmlsdGVyU2hhcGV9ID0gYXR0cnM7XG5cbiAgYXNzZXJ0Tm90Q29tcGxleChbeCwgZHldLCAnZGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXInKTtcblxuICBjb25zdCBjb252SW5mbyA9IGJhY2tlbmRfdXRpbC5jb21wdXRlQ29udjJESW5mbyhcbiAgICAgIHguc2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIGZpbHRlclNoYXBlLCBzdHJpZGVzLFxuICAgICAgZGlsYXRpb25zLCBwYWQsIGRpbVJvdW5kaW5nTW9kZSwgdHJ1ZSAvKiBkZXB0aHdpc2UgKi8pO1xuXG4gIGNvbnN0IHtzdHJpZGVIZWlnaHQsIHN0cmlkZVdpZHRoLCBmaWx0ZXJIZWlnaHQsIGZpbHRlcldpZHRofSA9IGNvbnZJbmZvO1xuXG4gIGNvbnN0IGRXID0gbmV3IFRlbnNvckJ1ZmZlcihjb252SW5mby5maWx0ZXJTaGFwZSwgJ2Zsb2F0MzInKTtcblxuICBjb25zdCBsZWZ0UGFkID0gY29udkluZm8ucGFkSW5mby5sZWZ0O1xuICBjb25zdCB0b3BQYWQgPSBjb252SW5mby5wYWRJbmZvLnRvcDtcbiAgY29uc3QgY2hNdWwgPSBjb252SW5mby5vdXRDaGFubmVscyAvIGNvbnZJbmZvLmluQ2hhbm5lbHM7XG5cbiAgY29uc3QgeFZhbHMgPSBiYWNrZW5kLmRhdGEuZ2V0KHguZGF0YUlkKS52YWx1ZXMgYXMgVHlwZWRBcnJheTtcbiAgY29uc3QgeEJ1ZiA9IG5ldyBUZW5zb3JCdWZmZXIoeC5zaGFwZSwgeC5kdHlwZSwgeFZhbHMpO1xuICBjb25zdCBkeVZhbHMgPSBiYWNrZW5kLmRhdGEuZ2V0KGR5LmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXk7XG4gIGNvbnN0IGR5QnVmID0gbmV3IFRlbnNvckJ1ZmZlcihkeS5zaGFwZSwgZHkuZHR5cGUsIGR5VmFscyk7XG4gIGZvciAobGV0IHdSID0gMDsgd1IgPCBmaWx0ZXJIZWlnaHQ7ICsrd1IpIHtcbiAgICBjb25zdCB5Uk1pbiA9IE1hdGgubWF4KDAsIE1hdGguY2VpbCgodG9wUGFkIC0gd1IpIC8gc3RyaWRlSGVpZ2h0KSk7XG4gICAgY29uc3QgeVJNYXggPSBNYXRoLm1pbihcbiAgICAgICAgY29udkluZm8ub3V0SGVpZ2h0LCAoY29udkluZm8uaW5IZWlnaHQgKyB0b3BQYWQgLSB3UikgLyBzdHJpZGVIZWlnaHQpO1xuXG4gICAgZm9yIChsZXQgd0MgPSAwOyB3QyA8IGZpbHRlcldpZHRoOyArK3dDKSB7XG4gICAgICBjb25zdCB5Q01pbiA9IE1hdGgubWF4KDAsIE1hdGguY2VpbCgobGVmdFBhZCAtIHdDKSAvIHN0cmlkZVdpZHRoKSk7XG4gICAgICBjb25zdCB5Q01heCA9IE1hdGgubWluKFxuICAgICAgICAgIGNvbnZJbmZvLm91dFdpZHRoLCAoY29udkluZm8uaW5XaWR0aCArIGxlZnRQYWQgLSB3QykgLyBzdHJpZGVXaWR0aCk7XG5cbiAgICAgIGZvciAobGV0IGQyID0gMDsgZDIgPCBjb252SW5mby5vdXRDaGFubmVsczsgKytkMikge1xuICAgICAgICBjb25zdCBkMSA9IE1hdGgudHJ1bmMoZDIgLyBjaE11bCk7XG4gICAgICAgIGNvbnN0IGRtID0gZDIgJSBjaE
11bDtcblxuICAgICAgICBsZXQgZG90UHJvZCA9IDA7XG4gICAgICAgIGZvciAobGV0IGIgPSAwOyBiIDwgY29udkluZm8uYmF0Y2hTaXplOyArK2IpIHtcbiAgICAgICAgICBmb3IgKGxldCB5UiA9IHlSTWluOyB5UiA8IHlSTWF4OyArK3lSKSB7XG4gICAgICAgICAgICBjb25zdCB4UiA9IHdSICsgeVIgKiBzdHJpZGVIZWlnaHQgLSB0b3BQYWQ7XG4gICAgICAgICAgICBmb3IgKGxldCB5QyA9IHlDTWluOyB5QyA8IHlDTWF4OyArK3lDKSB7XG4gICAgICAgICAgICAgIGNvbnN0IHhDID0gd0MgKyB5QyAqIHN0cmlkZVdpZHRoIC0gbGVmdFBhZDtcbiAgICAgICAgICAgICAgZG90UHJvZCArPSAoeEJ1Zi5nZXQoYiwgeFIsIHhDLCBkMSkgYXMgbnVtYmVyKSAqXG4gICAgICAgICAgICAgICAgICAoZHlCdWYuZ2V0KGIsIHlSLCB5QywgZDIpIGFzIG51bWJlcik7XG4gICAgICAgICAgICB9XG4gICAgICAgICAgfVxuICAgICAgICB9XG4gICAgICAgIGRXLnNldChkb3RQcm9kLCB3Uiwgd0MsIGQxLCBkbSk7XG4gICAgICB9XG4gICAgfVxuICB9XG5cbiAgcmV0dXJuIGJhY2tlbmQubWFrZVRlbnNvckluZm8oZFcuc2hhcGUsIGRXLmR0eXBlLCBkVy52YWx1ZXMpO1xufVxuXG5leHBvcnQgY29uc3QgZGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXJDb25maWc6IEtlcm5lbENvbmZpZyA9IHtcbiAga2VybmVsTmFtZTogRGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXIsXG4gIGJhY2tlbmROYW1lOiAnY3B1JyxcbiAga2VybmVsRnVuYzogZGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXIgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jXG59O1xuIl19
|