/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { DepthwiseConv2dNative } from '../kernel_names';
import * as conv_util from '../ops/conv_util';
import { depthwiseConv2dNativeBackpropFilter } from '../ops/depthwise_conv2d_native_backprop_filter';
import { depthwiseConv2dNativeBackpropInput } from '../ops/depthwise_conv2d_native_backprop_input';
import * as util from '../util';

/**
 * Gradient registration for the `DepthwiseConv2dNative` kernel.
 *
 * The forward inputs `x` and `filter` are saved so that `gradFunc` can build
 * the two backprop ops from them.
 */
export const depthwiseConv2dNativeGradConfig = {
  kernelName: DepthwiseConv2dNative,
  inputsToSave: ['x', 'filter'],

  /**
   * Validates the saved tensors and attributes, then returns lazy gradient
   * thunks for `x` and `filter`.
   *
   * @param dy Upstream gradient flowing into this kernel.
   * @param saved Saved forward tensors, destructured as `[x, filter]`
   *     (presumably in `inputsToSave` order — confirm against the registry).
   * @param attrs Kernel attributes; `dilations`, `strides`, `pad` and
   *     `dimRoundingMode` are read.
   * @returns An object with `x` and `filter` keys, each a zero-argument
   *     function that computes that input's gradient when called.
   * @throws Via `util.assert` when dilations are not all 1, when either
   *     saved tensor is not rank 4, when the channel counts disagree, or via
   *     `conv_util.checkPadOnDimRoundingMode` for an invalid pad/rounding combo.
   */
  gradFunc: (dy, saved, attrs) => {
    const { dilations, strides, pad, dimRoundingMode } = attrs;

    // An absent `dilations` attribute means no dilation in either spatial dim.
    let dilationRates = dilations;
    if (dilationRates == null) {
      dilationRates = [1, 1];
    }

    // Both spellings below appear in the emitted messages; keep them distinct.
    const nativePrefix = 'Error in gradient of depthwiseConv2dNative: ';
    const opPrefix = 'Error in gradient of depthwiseConv2d: ';

    // The assertion order is significant: it decides which error surfaces first.
    util.assert(
      conv_util.tupleValuesAreOne(dilationRates),
      () => `${nativePrefix}dilation rates greater than 1 are not yet ` +
        `supported. Got dilations '${dilationRates}'`);

    const [input, kernel] = saved;

    util.assert(
      input.rank === 4,
      () => `${nativePrefix}input must be rank 4, but got rank ${input.rank}.`);
    util.assert(
      kernel.rank === 4,
      () => `${nativePrefix}filter must be rank 4, but got rank ${kernel.rank}.`);

    const inputChannels = input.shape[3];
    const filterInChannels = kernel.shape[2];
    util.assert(
      inputChannels === filterInChannels,
      () => `${opPrefix}number of input channels (${inputChannels}) must ` +
        `match the inChannels dimension in filter ${filterInChannels}.`);

    // NOTE(review): with the all-ones dilation check above this looks like it
    // can never fail; kept verbatim to preserve behavior.
    util.assert(
      conv_util.eitherStridesOrDilationsAreOne(strides, dilationRates),
      () => `${opPrefix}Either strides or dilations must be 1. ` +
        `Got strides ${strides} and dilations '${dilationRates}'.`);

    conv_util.checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);

    // Gradients are returned as thunks so each is only computed on demand.
    const gradX = () => depthwiseConv2dNativeBackpropInput(
      input.shape, dy, kernel, strides, pad, dilationRates, dimRoundingMode);
    const gradFilter = () => depthwiseConv2dNativeBackpropFilter(
      input, dy, kernel.shape, strides, pad, dilationRates, dimRoundingMode);

    return { x: gradX, filter: gradFilter };
  },
};
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRGVwdGh3aXNlQ29udjJkTmF0aXZlX2dyYWQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL2dyYWRpZW50cy9EZXB0aHdpc2VDb252MmROYXRpdmVfZ3JhZC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFDSCxPQUFPLEVBQUMscUJBQXFCLEVBQTZCLE1BQU0saUJBQWlCLENBQUM7QUFFbEYsT0FBTyxLQUFLLFNBQVMsTUFBTSxrQkFBa0IsQ0FBQztBQUM5QyxPQUFPLEVBQUMsbUNBQW1DLEVBQUMsTUFBTSxnREFBZ0QsQ0FBQztBQUNuRyxPQUFPLEVBQUMsa0NBQWtDLEVBQUMsTUFBTSwrQ0FBK0MsQ0FBQztBQUVqRyxPQUFPLEtBQUssSUFBSSxNQUFNLFNBQVMsQ0FBQztBQUVoQyxNQUFNLENBQUMsTUFBTSwrQkFBK0IsR0FBZTtJQUN6RCxVQUFVLEVBQUUscUJBQXFCO0lBQ2pDLFlBQVksRUFBRSxDQUFDLEdBQUcsRUFBRSxRQUFRLENBQUM7SUFDN0IsUUFBUSxFQUFFLENBQUMsRUFBWSxFQUFFLEtBQWUsRUFBRSxLQUFtQixFQUFFLEVBQUU7UUFDL0QsTUFBTSxFQUFDLFNBQVMsRUFBRSxPQUFPLEVBQUUsR0FBRyxFQUFFLGVBQWUsRUFBQyxHQUM1QyxLQUE4QyxDQUFDO1FBQ25ELE1BQU0sVUFBVSxHQUFHLFNBQVMsSUFBSSxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBbUIsQ0FBQyxDQUFDLENBQUMsU0FBUyxDQUFDO1FBRTVFLElBQUksQ0FBQyxNQUFNLENBQ1AsU0FBUyxDQUFDLGlCQUFpQixDQUFDLFVBQVUsQ0FBQyxFQUN2QyxHQUFHLEVBQUUsQ0FBQyw2REFBNkQ7WUFDL0Qsc0RBQXNEO1lBQ3RELElBQUksVUFBVSxHQUFHLENBQUMsQ0FBQztRQUUzQixNQUFNLENBQUMsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxHQUFHLEtBQTZCLENBQUM7UUFFbEQsSUFBSSxDQUFDLE1BQU0sQ0FDUCxDQUFDLENBQUMsSUFBSSxLQUFLLENBQUMsRUFDWixHQUFHLEVBQUUsQ0FBQyw0REFBNEQ7WUFDOUQsd0JBQXdCLENBQUMsQ0FBQyxJQUFJLEdBQUcsQ0FBQyxDQUFDO1FBQzNDLElBQUksQ0FBQyxNQUFNLENBQ1AsTUFBTSxDQUFDLElBQUksS0FBSyxDQUFDLEVBQ2pCLEdBQUcsRUFBRSxDQUFDLDZEQUE2RDtZQUMvRCx3QkFBd0IsTUFBTSxDQUFDLElBQUksR0FBRyxDQUFDLENBQUM7UUFDaEQsSUFBSSxDQUFDLE1BQU0sQ0FDUCxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQyxLQUFLLE1BQU0sQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEVBQzlCLEdBQUcsRUFBRSxDQUFDLHdEQUF3RDtZQUMxRCxhQUFhLENBQUMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLHdDQUF3QztZQUMvRCxhQUFhLE1BQU0sQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDO1FBRXpDLElBQUksQ0FBQyxNQUFNLENBQ1AsU0FBUyxDQUFDLDhCQUE4QixDQUFDLE9BQU8sRUFBRSxVQUFVLENBQUMsRUFDN0QsR0FBRy
xFQUFFLENBQUMsMERBQTBEO1lBQzVELHFDQUFxQyxPQUFPLGlCQUFpQjtZQUM3RCxJQUFJLFVBQVUsSUFBSSxDQUFDLENBQUM7UUFFNUIsU0FBUyxDQUFDLHlCQUF5QixDQUMvQixpQkFBaUIsRUFBRSxHQUFHLEVBQUUsZUFBZSxDQUFDLENBQUM7UUFFN0MsT0FBTztZQUNMLENBQUMsRUFBRSxHQUFHLEVBQUUsQ0FBQyxrQ0FBa0MsQ0FDdkMsQ0FBQyxDQUFDLEtBQUssRUFBRSxFQUFFLEVBQUUsTUFBTSxFQUFFLE9BQU8sRUFBRSxHQUFHLEVBQUUsVUFBVSxFQUFFLGVBQWUsQ0FBQztZQUNuRSxNQUFNLEVBQUUsR0FBRyxFQUFFLENBQUMsbUNBQW1DLENBQzdDLENBQUMsRUFBRSxFQUFFLEVBQUUsTUFBTSxDQUFDLEtBQUssRUFBRSxPQUFPLEVBQUUsR0FBRyxFQUFFLFVBQVUsRUFBRSxlQUFlLENBQUM7U0FDcEUsQ0FBQztJQUNKLENBQUM7Q0FDRixDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuaW1wb3J0IHtEZXB0aHdpc2VDb252MmROYXRpdmUsIERlcHRod2lzZUNvbnYyZE5hdGl2ZUF0dHJzfSBmcm9tICcuLi9rZXJuZWxfbmFtZXMnO1xuaW1wb3J0IHtHcmFkQ29uZmlnLCBOYW1lZEF0dHJNYXB9IGZyb20gJy4uL2tlcm5lbF9yZWdpc3RyeSc7XG5pbXBvcnQgKiBhcyBjb252X3V0aWwgZnJvbSAnLi4vb3BzL2NvbnZfdXRpbCc7XG5pbXBvcnQge2RlcHRod2lzZUNvbnYyZE5hdGl2ZUJhY2twcm9wRmlsdGVyfSBmcm9tICcuLi9vcHMvZGVwdGh3aXNlX2NvbnYyZF9uYXRpdmVfYmFja3Byb3BfZmlsdGVyJztcbmltcG9ydCB7ZGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3
Byb3BJbnB1dH0gZnJvbSAnLi4vb3BzL2RlcHRod2lzZV9jb252MmRfbmF0aXZlX2JhY2twcm9wX2lucHV0JztcbmltcG9ydCB7VGVuc29yLCBUZW5zb3I0RH0gZnJvbSAnLi4vdGVuc29yJztcbmltcG9ydCAqIGFzIHV0aWwgZnJvbSAnLi4vdXRpbCc7XG5cbmV4cG9ydCBjb25zdCBkZXB0aHdpc2VDb252MmROYXRpdmVHcmFkQ29uZmlnOiBHcmFkQ29uZmlnID0ge1xuICBrZXJuZWxOYW1lOiBEZXB0aHdpc2VDb252MmROYXRpdmUsXG4gIGlucHV0c1RvU2F2ZTogWyd4JywgJ2ZpbHRlciddLFxuICBncmFkRnVuYzogKGR5OiBUZW5zb3I0RCwgc2F2ZWQ6IFRlbnNvcltdLCBhdHRyczogTmFtZWRBdHRyTWFwKSA9PiB7XG4gICAgY29uc3Qge2RpbGF0aW9ucywgc3RyaWRlcywgcGFkLCBkaW1Sb3VuZGluZ01vZGV9ID1cbiAgICAgICAgYXR0cnMgYXMgdW5rbm93biBhcyBEZXB0aHdpc2VDb252MmROYXRpdmVBdHRycztcbiAgICBjb25zdCAkZGlsYXRpb25zID0gZGlsYXRpb25zID09IG51bGwgPyBbMSwgMV0gYXNbbnVtYmVyLG51bWJlcl0gOiBkaWxhdGlvbnM7XG5cbiAgICB1dGlsLmFzc2VydChcbiAgICAgICAgY29udl91dGlsLnR1cGxlVmFsdWVzQXJlT25lKCRkaWxhdGlvbnMpLFxuICAgICAgICAoKSA9PiAnRXJyb3IgaW4gZ3JhZGllbnQgb2YgZGVwdGh3aXNlQ29udjJkTmF0aXZlOiBkaWxhdGlvbiByYXRlcyAnICtcbiAgICAgICAgICAgIGBncmVhdGVyIHRoYW4gMSBhcmUgbm90IHlldCBzdXBwb3J0ZWQuIEdvdCBkaWxhdGlvbnMgYCArXG4gICAgICAgICAgICBgJyR7JGRpbGF0aW9uc30nYCk7XG5cbiAgICBjb25zdCBbeCwgZmlsdGVyXSA9IHNhdmVkIGFzIFtUZW5zb3I0RCwgVGVuc29yNERdO1xuXG4gICAgdXRpbC5hc3NlcnQoXG4gICAgICAgIHgucmFuayA9PT0gNCxcbiAgICAgICAgKCkgPT4gYEVycm9yIGluIGdyYWRpZW50IG9mIGRlcHRod2lzZUNvbnYyZE5hdGl2ZTogaW5wdXQgbXVzdCBiZSBgICtcbiAgICAgICAgICAgIGByYW5rIDQsIGJ1dCBnb3QgcmFuayAke3gucmFua30uYCk7XG4gICAgdXRpbC5hc3NlcnQoXG4gICAgICAgIGZpbHRlci5yYW5rID09PSA0LFxuICAgICAgICAoKSA9PiBgRXJyb3IgaW4gZ3JhZGllbnQgb2YgZGVwdGh3aXNlQ29udjJkTmF0aXZlOiBmaWx0ZXIgbXVzdCBiZSBgICtcbiAgICAgICAgICAgIGByYW5rIDQsIGJ1dCBnb3QgcmFuayAke2ZpbHRlci5yYW5rfS5gKTtcbiAgICB1dGlsLmFzc2VydChcbiAgICAgICAgeC5zaGFwZVszXSA9PT0gZmlsdGVyLnNoYXBlWzJdLFxuICAgICAgICAoKSA9PiBgRXJyb3IgaW4gZ3JhZGllbnQgb2YgZGVwdGh3aXNlQ29udjJkOiBudW1iZXIgb2YgaW5wdXQgYCArXG4gICAgICAgICAgICBgY2hhbm5lbHMgKCR7eC5zaGFwZVszXX0pIG11c3QgbWF0Y2ggdGhlIGluQ2hhbm5lbHMgZGltZW5zaW9uIGAgK1xuICAgICAgICAgICAgYGluIGZpbHRlciAke2ZpbHRlci5zaGFwZVsyXX0uYCk7XG5cbiAgICB1dGlsLmFzc2VydChcbiAgICAgICAgY29udl91dG
lsLmVpdGhlclN0cmlkZXNPckRpbGF0aW9uc0FyZU9uZShzdHJpZGVzLCAkZGlsYXRpb25zKSxcbiAgICAgICAgKCkgPT4gJ0Vycm9yIGluIGdyYWRpZW50IG9mIGRlcHRod2lzZUNvbnYyZDogRWl0aGVyIHN0cmlkZXMgb3IgJyArXG4gICAgICAgICAgICBgZGlsYXRpb25zIG11c3QgYmUgIDEuIEdvdCBzdHJpZGVzICR7c3RyaWRlc30gYW5kIGRpbGF0aW9ucyBgICtcbiAgICAgICAgICAgIGAnJHskZGlsYXRpb25zfScuYCk7XG5cbiAgICBjb252X3V0aWwuY2hlY2tQYWRPbkRpbVJvdW5kaW5nTW9kZShcbiAgICAgICAgJ2RlcHRod2lzZUNvbnYyZCcsIHBhZCwgZGltUm91bmRpbmdNb2RlKTtcblxuICAgIHJldHVybiB7XG4gICAgICB4OiAoKSA9PiBkZXB0aHdpc2VDb252MmROYXRpdmVCYWNrcHJvcElucHV0KFxuICAgICAgICAgIHguc2hhcGUsIGR5LCBmaWx0ZXIsIHN0cmlkZXMsIHBhZCwgJGRpbGF0aW9ucywgZGltUm91bmRpbmdNb2RlKSxcbiAgICAgIGZpbHRlcjogKCkgPT4gZGVwdGh3aXNlQ29udjJkTmF0aXZlQmFja3Byb3BGaWx0ZXIoXG4gICAgICAgICAgeCwgZHksIGZpbHRlci5zaGFwZSwgc3RyaWRlcywgcGFkLCAkZGlsYXRpb25zLCBkaW1Sb3VuZGluZ01vZGUpLFxuICAgIH07XG4gIH1cbn07XG4iXX0=