/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// NOTE(review): this file is compiled output. The inline source map at the
// bottom was generated for the original emitted layout and will likely no
// longer line up after hand edits; regenerate it from the TypeScript source
// (tfjs-backend-webgl/src/kernels/DepthwiseConv2dNative.ts) when possible.
import { backend_util, DepthwiseConv2dNative, env, util } from '@tensorflow/tfjs-core';
import { DepthwiseConv2DProgram } from '../conv_gpu_depthwise';
import { DepthwiseConvPacked2DProgram } from '../conv_packed_gpu_depthwise';

/**
 * WebGL kernel function for DepthwiseConv2dNative.
 *
 * Builds the convolution info for `x` and `filter`, then picks a shader
 * program. The packed program is used when all of the following hold:
 * the `WEBGL_PACK_DEPTHWISECONV` flag is on, the horizontal stride is at most
 * 2, and the channel multiplier (outChannels / inChannels) is exactly 1.
 * Otherwise the unpacked program is used. The chosen program runs on the
 * backend with a 'float32' output dtype.
 *
 * @param {object} args
 * @param {{x: object, filter: object}} args.inputs - Input and filter tensors;
 *     only their `.shape` is read here (assumed rank 4 - TODO confirm).
 * @param {{strides: *, pad: *, dilations: *, dimRoundingMode: *}} args.attrs -
 *     Convolution attributes. A null/undefined `dilations` is treated as
 *     [1, 1].
 * @param {object} args.backend - WebGL backend exposing `runWebGLProgram`.
 * @returns {*} Whatever `backend.runWebGLProgram` returns (presumably the
 *     output TensorInfo - confirm against the backend).
 * @throws When neither strides nor dilations are 1 (via `util.assert`, which
 *     is expected to throw).
 */
export function depthwiseConv2dNative(args) {
  const { inputs, backend, attrs } = args;
  const { x, filter } = inputs;
  const { strides, pad, dilations, dimRoundingMode } = attrs;

  // `== null` deliberately matches both null and undefined.
  const effectiveDilations = dilations == null ? [1, 1] : dilations;

  util.assert(
      backend_util.eitherStridesOrDilationsAreOne(strides, effectiveDilations),
      () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
          `1. Got strides ${strides} and dilations '${effectiveDilations}'`);

  const convInfo = backend_util.computeConv2DInfo(
      x.shape, filter.shape, strides, effectiveDilations, pad,
      dimRoundingMode, true /* depthwise */);

  // The packed shader is only chosen for stride-width <= 2 with a channel
  // multiplier of exactly one; every other configuration falls back.
  const usePackedProgram = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
      convInfo.strideWidth <= 2 &&
      convInfo.outChannels / convInfo.inChannels === 1;

  const program = usePackedProgram ?
      new DepthwiseConvPacked2DProgram(convInfo) :
      new DepthwiseConv2DProgram(convInfo);

  // Per-run values handed to runWebGLProgram alongside the program; presumably
  // bound to shader uniforms so compiled programs can be reused - confirm.
  const padding = [convInfo.padInfo.top, convInfo.padInfo.left];
  const stride = [convInfo.strideHeight, convInfo.strideWidth];
  const dilation = [convInfo.dilationHeight, convInfo.dilationWidth];
  const inputDims = [convInfo.inHeight, convInfo.inWidth];
  const customValues = [padding, stride, dilation, inputDims];

  return backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
}

/**
 * Kernel config pairing the DepthwiseConv2dNative kernel name with the
 * 'webgl' backend and the kernel function above.
 */
export const depthwiseConv2dNativeConfig = {
  kernelName: DepthwiseConv2dNative,
  backendName: 'webgl',
  kernelFunc: depthwiseConv2dNative,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRGVwdGh3aXNlQ29udjJkTmF0aXZlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9rZXJuZWxzL0RlcHRod2lzZUNvbnYyZE5hdGl2ZS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFFLHFCQUFxQixFQUEyRCxHQUFHLEVBQTRCLElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBR3hLLE9BQU8sRUFBQyxzQkFBc0IsRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBQzdELE9BQU8sRUFBQyw0QkFBNEIsRUFBQyxNQUFNLDhCQUE4QixDQUFDO0FBRTFFLE1BQU0sVUFBVSxxQkFBcUIsQ0FBQyxJQUlyQztJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUMzQixNQUFNLEVBQUMsT0FBTyxFQUFFLEdBQUcsRUFBRSxTQUFTLEVBQUUsZUFBZSxFQUFDLEdBQUcsS0FBSyxDQUFDO0lBRXpELElBQUksVUFBVSxHQUFHLFNBQVMsQ0FBQztJQUMzQixJQUFJLFVBQVUsSUFBSSxJQUFJLEVBQUU7UUFDdEIsVUFBVSxHQUFHLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO0tBQ3JCO0lBRUQsSUFBSSxDQUFDLE1BQU0sQ0FDUCxZQUFZLENBQUMsOEJBQThCLENBQUMsT0FBTyxFQUFFLFVBQVUsQ0FBQyxFQUNoRSxHQUFHLEVBQUUsQ0FBQyxnRUFBZ0U7UUFDbEUsa0JBQWtCLE9
BQU8sbUJBQW1CLFVBQVUsR0FBRyxDQUFDLENBQUM7SUFFbkUsTUFBTSxRQUFRLEdBQUcsWUFBWSxDQUFDLGlCQUFpQixDQUMzQyxDQUFDLENBQUMsS0FBeUMsRUFDM0MsTUFBTSxDQUFDLEtBQXlDLEVBQUUsT0FBTyxFQUFFLFVBQVUsRUFDckUsR0FBRyxFQUFFLGVBQWUsRUFBRSxJQUFJLENBQUMsZUFBZSxDQUFDLENBQUM7SUFFaEQsSUFBSSxPQUE0RCxDQUFDO0lBQ2pFLElBQUksR0FBRyxFQUFFLENBQUMsT0FBTyxDQUFDLDBCQUEwQixDQUFDLElBQUksUUFBUSxDQUFDLFdBQVcsSUFBSSxDQUFDO1FBQ3RFLFFBQVEsQ0FBQyxXQUFXLEdBQUcsUUFBUSxDQUFDLFVBQVUsS0FBSyxDQUFDLEVBQUU7UUFDcEQsT0FBTyxHQUFHLElBQUksNEJBQTRCLENBQUMsUUFBUSxDQUFDLENBQUM7S0FDdEQ7U0FBTTtRQUNMLE9BQU8sR0FBRyxJQUFJLHNCQUFzQixDQUFDLFFBQVEsQ0FBQyxDQUFDO0tBQ2hEO0lBQ0QsTUFBTSxZQUFZLEdBQUc7UUFDbkIsQ0FBQyxRQUFRLENBQUMsT0FBTyxDQUFDLEdBQUcsRUFBRSxRQUFRLENBQUMsT0FBTyxDQUFDLElBQUksQ0FBQztRQUM3QyxDQUFDLFFBQVEsQ0FBQyxZQUFZLEVBQUUsUUFBUSxDQUFDLFdBQVcsQ0FBQztRQUM3QyxDQUFDLFFBQVEsQ0FBQyxjQUFjLEVBQUUsUUFBUSxDQUFDLGFBQWEsQ0FBQztRQUNqRCxDQUFDLFFBQVEsQ0FBQyxRQUFRLEVBQUUsUUFBUSxDQUFDLE9BQU8sQ0FBQztLQUN0QyxDQUFDO0lBQ0YsT0FBTyxPQUFPLENBQUMsZUFBZSxDQUFDLE9BQU8sRUFBRSxDQUFDLENBQUMsRUFBRSxNQUFNLENBQUMsRUFBRSxTQUFTLEVBQUUsWUFBWSxDQUFDLENBQUM7QUFDaEYsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLDJCQUEyQixHQUFpQjtJQUN2RCxVQUFVLEVBQUUscUJBQXFCO0lBQ2pDLFdBQVcsRUFBRSxPQUFPO0lBQ3BCLFVBQVUsRUFBRSxxQkFBOEM7Q0FDM0QsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSB
zcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtiYWNrZW5kX3V0aWwsIERlcHRod2lzZUNvbnYyZE5hdGl2ZSwgRGVwdGh3aXNlQ29udjJkTmF0aXZlQXR0cnMsIERlcHRod2lzZUNvbnYyZE5hdGl2ZUlucHV0cywgZW52LCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRXZWJHTH0gZnJvbSAnLi4vYmFja2VuZF93ZWJnbCc7XG5pbXBvcnQge0RlcHRod2lzZUNvbnYyRFByb2dyYW19IGZyb20gJy4uL2NvbnZfZ3B1X2RlcHRod2lzZSc7XG5pbXBvcnQge0RlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW19IGZyb20gJy4uL2NvbnZfcGFja2VkX2dwdV9kZXB0aHdpc2UnO1xuXG5leHBvcnQgZnVuY3Rpb24gZGVwdGh3aXNlQ29udjJkTmF0aXZlKGFyZ3M6IHtcbiAgaW5wdXRzOiBEZXB0aHdpc2VDb252MmROYXRpdmVJbnB1dHMsXG4gIGF0dHJzOiBEZXB0aHdpc2VDb252MmROYXRpdmVBdHRycyxcbiAgYmFja2VuZDogTWF0aEJhY2tlbmRXZWJHTFxufSkge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eCwgZmlsdGVyfSA9IGlucHV0cztcbiAgY29uc3Qge3N0cmlkZXMsIHBhZCwgZGlsYXRpb25zLCBkaW1Sb3VuZGluZ01vZGV9ID0gYXR0cnM7XG5cbiAgbGV0ICRkaWxhdGlvbnMgPSBkaWxhdGlvbnM7XG4gIGlmICgkZGlsYXRpb25zID09IG51bGwpIHtcbiAgICAkZGlsYXRpb25zID0gWzEsIDFdO1xuICB9XG5cbiAgdXRpbC5hc3NlcnQoXG4gICAgICBiYWNrZW5kX3V0aWwuZWl0aGVyU3RyaWRlc09yRGlsYXRpb25zQXJlT25lKHN0cmlkZXMsICRkaWxhdGlvbnMpLFxuICAgICAgKCkgPT4gJ0Vycm9yIGluIGRlcHRod2lzZUNvbnYyZDogRWl0aGVyIHN0cmlkZXMgb3IgZGlsYXRpb25zIG11c3QgYmUgJyArXG4gICAgICAgICAgYDEuIEdvdCBzdHJpZGVzICR7c3RyaWRlc30gYW5kIGRpbGF0aW9ucyAnJHskZGlsYXRpb25zfSdgKTtcblxuICBjb25zdCBjb252SW5mbyA9IGJhY2tlbmRfdXRpbC5jb21wdXRlQ29udjJESW5mbyhcbiAgICAgIHguc2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sXG4gICAgICBmaWx0ZXIuc2hhcGUgYXMgW251bWJlciwgbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIHN0cmlkZXMsICRkaWxhdGlvbnMsXG4gICAgICBwYWQsIGRpbVJvdW5kaW5nTW9kZSwgdHJ1ZSAvKiBkZXB0aHdpc2UgKi8pO1xuXG4gIGxldCBwcm9ncmFtOiBEZXB0aHdpc2VDb252MkRQcm9ncmFtfERlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW07XG4gIGlmIChlbnYoKS5nZXRCb29sKCdXRUJHTF9QQUN
LX0RFUFRIV0lTRUNPTlYnKSAmJiBjb252SW5mby5zdHJpZGVXaWR0aCA8PSAyICYmXG4gICAgICBjb252SW5mby5vdXRDaGFubmVscyAvIGNvbnZJbmZvLmluQ2hhbm5lbHMgPT09IDEpIHtcbiAgICBwcm9ncmFtID0gbmV3IERlcHRod2lzZUNvbnZQYWNrZWQyRFByb2dyYW0oY29udkluZm8pO1xuICB9IGVsc2Uge1xuICAgIHByb2dyYW0gPSBuZXcgRGVwdGh3aXNlQ29udjJEUHJvZ3JhbShjb252SW5mbyk7XG4gIH1cbiAgY29uc3QgY3VzdG9tVmFsdWVzID0gW1xuICAgIFtjb252SW5mby5wYWRJbmZvLnRvcCwgY29udkluZm8ucGFkSW5mby5sZWZ0XSxcbiAgICBbY29udkluZm8uc3RyaWRlSGVpZ2h0LCBjb252SW5mby5zdHJpZGVXaWR0aF0sXG4gICAgW2NvbnZJbmZvLmRpbGF0aW9uSGVpZ2h0LCBjb252SW5mby5kaWxhdGlvbldpZHRoXSxcbiAgICBbY29udkluZm8uaW5IZWlnaHQsIGNvbnZJbmZvLmluV2lkdGhdXG4gIF07XG4gIHJldHVybiBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShwcm9ncmFtLCBbeCwgZmlsdGVyXSwgJ2Zsb2F0MzInLCBjdXN0b21WYWx1ZXMpO1xufVxuXG5leHBvcnQgY29uc3QgZGVwdGh3aXNlQ29udjJkTmF0aXZlQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IERlcHRod2lzZUNvbnYyZE5hdGl2ZSxcbiAgYmFja2VuZE5hbWU6ICd3ZWJnbCcsXG4gIGtlcm5lbEZ1bmM6IGRlcHRod2lzZUNvbnYyZE5hdGl2ZSBhcyB1bmtub3duIGFzIEtlcm5lbEZ1bmMsXG59O1xuIl19