/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, env, FusedDepthwiseConv2D, util } from '@tensorflow/tfjs-core';
import { DepthwiseConv2DProgram } from '../conv_gpu_depthwise';
import { DepthwiseConvPacked2DProgram } from '../conv_packed_gpu_depthwise';
import { mapActivationToShaderProgram } from '../kernel_utils/kernel_funcs_utils';

/**
 * WebGL kernel function for the FusedDepthwiseConv2D kernel.
 *
 * Builds a depthwise convolution shader program (packed or unpacked), feeds it
 * `x`, `filter` and whichever optional inputs are present, and runs it on the
 * given backend.
 *
 * @param args Kernel arguments.
 * @param args.inputs Holds `x`, `filter` and the optional `bias` and
 *     `preluActivationWeights`.
 * @param args.backend Backend used to create the temporary scalar, run the
 *     program and dispose intermediates.
 * @param args.attrs Holds `strides`, `pad`, `dilations`, `dimRoundingMode`,
 *     `activation` and `leakyreluAlpha`. A null/undefined `dilations` is
 *     treated as `[1, 1]`.
 * @returns Whatever `backend.runWebGLProgram` returns for the chosen program
 *     (requested with output dtype 'float32').
 * @throws Via `util.assert` when
 *     `backend_util.eitherStridesOrDilationsAreOne(strides, dilations)` is
 *     falsy.
 */
export function fusedDepthwiseConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;

    // Tensors created only for this call; they are released in the `finally`
    // block below regardless of whether the program ran successfully.
    const intermediates = [];

    const $dilations = dilations == null ? [1, 1] : dilations;

    util.assert(
        backend_util.eitherStridesOrDilationsAreOne(strides, $dilations),
        () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
            `1. Got strides ${strides} and dilations '${$dilations}'`);

    const convInfo = backend_util.computeConv2DInfo(
        x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode,
        true /* depthwise */);

    // The packed program is used only when the env flag is on, the horizontal
    // stride is at most 2 and the output has as many channels as the input.
    const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;

    const fusedActivation = activation ?
        mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
        null;

    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';

    // Optional inputs are appended in a fixed order: bias, PReLU weights,
    // leaky-ReLU alpha. The same three flags are passed to the program
    // constructor below.
    const programInputs = [x, filter];
    if (hasBias) {
        programInputs.push(bias);
    }
    if (hasPreluActivationWeights) {
        programInputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
        // The alpha attribute is handed to the program as a scalar float32
        // tensor, which this function owns and must dispose.
        const $leakyreluAlpha = backend.makeTensorInfo(
            [], 'float32', util.createScalarValue(leakyreluAlpha, 'float32'));
        programInputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
    }

    try {
        const Program = shouldPackDepthwiseConv ?
            DepthwiseConvPacked2DProgram :
            DepthwiseConv2DProgram;
        const program = new Program(
            convInfo, hasBias, fusedActivation, hasPreluActivationWeights,
            hasLeakyreluAlpha);

        const customValues = [
            [convInfo.padInfo.top, convInfo.padInfo.left],
            [convInfo.strideHeight, convInfo.strideWidth],
            [convInfo.dilationHeight, convInfo.dilationWidth],
            [convInfo.inHeight, convInfo.inWidth]
        ];

        return backend.runWebGLProgram(
            program, programInputs, 'float32', customValues);
    } finally {
        // Previously this ran only after a successful runWebGLProgram call, so
        // a throw from program construction or execution leaked the temporary
        // alpha tensor.
        intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    }
}

/** Kernel registration entry binding FusedDepthwiseConv2D to the 'webgl' backend. */
export const fusedDepthwiseConv2DConfig = {
    kernelName: FusedDepthwiseConv2D,
    backendName: 'webgl',
    kernelFunc: fusedDepthwiseConv2D,
};
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiRnVzZWREZXB0aHdpc2VDb252MkQuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2tlcm5lbHMvRnVzZWREZXB0aHdpc2VDb252MkQudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLFlBQVksRUFBRSxHQUFHLEVBQUUsb0JBQW9CLEVBQStGLElBQUksRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBR2pMLE9BQU8sRUFBQyxzQkFBc0IsRUFBQyxNQUFNLHVCQUF1QixDQUFDO0FBQzdELE9BQU8sRUFBQyw0QkFBNEIsRUFBQyxNQUFNLDhCQUE4QixDQUFDO0FBQzFFLE9BQU8sRUFBQyw0QkFBNEIsRUFBQyxNQUFNLG9DQUFvQyxDQUFDO0FBRWhGLE1BQU0sVUFBVSxvQkFBb0IsQ0FBQyxJQUlwQztJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFFLE1BQU0sRUFBRSxJQUFJLEVBQUUsc0JBQXNCLEVBQUMsR0FBRyxNQUFNLENBQUM7SUFDekQsTUFBTSxFQUFDLE9BQU8sRUFBRSxHQUFHLEVBQUUsU0FBUyxFQUFFLGVBQWUsRUFBRSxVQUFVLEVBQUUsY0FBYyxFQUFDLEdBQ3hFLEtBQUssQ0FBQztJQUVWLE1BQU0sYUFBYSxHQUFpQixFQUFFLENBQUM7SUFFdkMsSUFBSSxVQUFVLEdBQUcsU0FBUyxDQUFDO0lBQzNCLElBQUksVUFBVSxJQUFJLElBQUksRUFBRTtRQUN0QixVQUFVLEdBQUcsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUM7S0FDckI7SUFFRCxJQUFJLENBQUMsTUFBTSxDQUNQLFlBQVksQ0FBQyw4QkFBOEIsQ0FBQyxPQUFPLEVBQUUsVUFBVSxDQUFDLEVBQ2hFLEdBQUcsRUFBRSxDQUFDLGdFQUFnRTtRQUNsRSxrQkFBa0IsT0FBTyxtQkFBbUIsVUFBVSxHQUFHLENBQUMsQ0FBQztJQUVuRSxNQUFNLFFBQVEsR0FBRyxZQUFZLENBQUMsaUJBQWlCLENBQzNDLENBQUMsQ0FBQyxLQUF5QyxFQUMzQyxNQUFNLENBQUMsS0FBeUMsRUFBRSxPQUFPLEVBQUUsVUFBVSxFQUNyRSxHQUFHLEVBQUUsZUFBZSxFQUFFLElBQUksQ0FBQyxlQUFlLENBQUMsQ0FBQztJQUVoRCxNQUFNLHVCQUF1QixHQUFHLEdBQUcsRUFBRSxDQUFDLE9BQU8sQ0FBQywwQkFBMEIsQ0FBQztRQUNyRSxRQUFRLENBQUMsV0FBVyxJQUFJLENBQUM7UUFDekIsUUFBUSxDQUFDLFdBQVcsR0FBRyxRQUFRLENBQUMsVUFBVSxLQUFLLENBQUMsQ0FBQztJQUNyRCxNQUFNLGVBQWUsR0FBRyxVQUFVLENBQUMsQ0FBQztRQUNoQyw0QkFBNEIsQ0FBQyxVQUFVLEVBQUUsdUJBQXVCLENBQUMsQ0FBQyxDQUFDO1FBQ25FLElBQUksQ0FBQztJQUNULE1BQU0sYUFBYSxHQUFpQixDQUFDLENBQUMsRUFBRSxNQUFNLENBQUMsQ0FBQztJQUVoRCxNQUFNLE9BQU8sR0FBRyxJQUFJLElBQUksSUFBSSxDQUFDO0lBQzdCLE1BQU0seUJBQXlCLEdBQU
csc0JBQXNCLElBQUksSUFBSSxDQUFDO0lBQ2pFLE1BQU0saUJBQWlCLEdBQUcsVUFBVSxLQUFLLFdBQVcsQ0FBQztJQUVyRCxJQUFJLE9BQU8sRUFBRTtRQUNYLGFBQWEsQ0FBQyxJQUFJLENBQUMsSUFBSSxDQUFDLENBQUM7S0FDMUI7SUFDRCxJQUFJLHlCQUF5QixFQUFFO1FBQzdCLGFBQWEsQ0FBQyxJQUFJLENBQUMsc0JBQXNCLENBQUMsQ0FBQztLQUM1QztJQUNELElBQUksaUJBQWlCLEVBQUU7UUFDckIsTUFBTSxlQUFlLEdBQUcsT0FBTyxDQUFDLGNBQWMsQ0FDMUMsRUFBRSxFQUFFLFNBQVMsRUFDYixJQUFJLENBQUMsaUJBQWlCLENBQUMsY0FBc0MsRUFDdEMsU0FBUyxDQUFDLENBQUMsQ0FBQztRQUN2QyxhQUFhLENBQUMsSUFBSSxDQUFDLGVBQWUsQ0FBQyxDQUFDO1FBQ3BDLGFBQWEsQ0FBQyxJQUFJLENBQUMsZUFBZSxDQUFDLENBQUM7S0FDckM7SUFFRCxJQUFJLE9BQTRELENBQUM7SUFDakUsSUFBSSx1QkFBdUIsRUFBRTtRQUMzQixPQUFPLEdBQUcsSUFBSSw0QkFBNEIsQ0FDdEMsUUFBUSxFQUFFLE9BQU8sRUFBRSxlQUFlLEVBQUUseUJBQXlCLEVBQzdELGlCQUFpQixDQUFDLENBQUM7S0FDeEI7U0FBTTtRQUNMLE9BQU8sR0FBRyxJQUFJLHNCQUFzQixDQUNoQyxRQUFRLEVBQUUsT0FBTyxFQUFFLGVBQWUsRUFBRSx5QkFBeUIsRUFDN0QsaUJBQWlCLENBQUMsQ0FBQztLQUN4QjtJQUNELE1BQU0sWUFBWSxHQUFHO1FBQ25CLENBQUMsUUFBUSxDQUFDLE9BQU8sQ0FBQyxHQUFHLEVBQUUsUUFBUSxDQUFDLE9BQU8sQ0FBQyxJQUFJLENBQUM7UUFDN0MsQ0FBQyxRQUFRLENBQUMsWUFBWSxFQUFFLFFBQVEsQ0FBQyxXQUFXLENBQUM7UUFDN0MsQ0FBQyxRQUFRLENBQUMsY0FBYyxFQUFFLFFBQVEsQ0FBQyxhQUFhLENBQUM7UUFDakQsQ0FBQyxRQUFRLENBQUMsUUFBUSxFQUFFLFFBQVEsQ0FBQyxPQUFPLENBQUM7S0FDdEMsQ0FBQztJQUNGLE1BQU0sTUFBTSxHQUNSLE9BQU8sQ0FBQyxlQUFlLENBQUMsT0FBTyxFQUFFLGFBQWEsRUFBRSxTQUFTLEVBQUUsWUFBWSxDQUFDLENBQUM7SUFFN0UsYUFBYSxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBRXJFLE9BQU8sTUFBTSxDQUFDO0FBQ2hCLENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSwwQkFBMEIsR0FBaUI7SUFDdEQsVUFBVSxFQUFFLG9CQUFvQjtJQUNoQyxXQUFXLEVBQUUsT0FBTztJQUNwQixVQUFVLEVBQUUsb0JBQTZDO0NBQzFELENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY2
9weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7YmFja2VuZF91dGlsLCBlbnYsIEZ1c2VkRGVwdGh3aXNlQ29udjJELCBGdXNlZERlcHRod2lzZUNvbnYyREF0dHJzLCBGdXNlZERlcHRod2lzZUNvbnYyRElucHV0cywgS2VybmVsQ29uZmlnLCBLZXJuZWxGdW5jLCBUZW5zb3JJbmZvLCB1dGlsfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuXG5pbXBvcnQge01hdGhCYWNrZW5kV2ViR0x9IGZyb20gJy4uL2JhY2tlbmRfd2ViZ2wnO1xuaW1wb3J0IHtEZXB0aHdpc2VDb252MkRQcm9ncmFtfSBmcm9tICcuLi9jb252X2dwdV9kZXB0aHdpc2UnO1xuaW1wb3J0IHtEZXB0aHdpc2VDb252UGFja2VkMkRQcm9ncmFtfSBmcm9tICcuLi9jb252X3BhY2tlZF9ncHVfZGVwdGh3aXNlJztcbmltcG9ydCB7bWFwQWN0aXZhdGlvblRvU2hhZGVyUHJvZ3JhbX0gZnJvbSAnLi4va2VybmVsX3V0aWxzL2tlcm5lbF9mdW5jc191dGlscyc7XG5cbmV4cG9ydCBmdW5jdGlvbiBmdXNlZERlcHRod2lzZUNvbnYyRChhcmdzOiB7XG4gIGlucHV0czogRnVzZWREZXB0aHdpc2VDb252MkRJbnB1dHMsXG4gIGF0dHJzOiBGdXNlZERlcHRod2lzZUNvbnYyREF0dHJzLFxuICBiYWNrZW5kOiBNYXRoQmFja2VuZFdlYkdMXG59KSB7XG4gIGNvbnN0IHtpbnB1dHMsIGJhY2tlbmQsIGF0dHJzfSA9IGFyZ3M7XG4gIGNvbnN0IHt4LCBmaWx0ZXIsIGJpYXMsIHByZWx1QWN0aXZhdGlvbldlaWdodHN9ID0gaW5wdXRzO1xuICBjb25zdCB7c3RyaWRlcywgcGFkLCBkaWxhdGlvbnMsIGRpbVJvdW5kaW5nTW9kZSwgYWN0aXZhdGlvbiwgbGVha3lyZWx1QWxwaGF9ID1cbiAgICAgIGF0dHJzO1xuXG4gIGNvbnN0IGludGVybWVkaWF0ZXM6IFRlbnNvckluZm9bXSA9IFtdO1xuXG4gIGxldCAkZGlsYXRpb25zID0gZGlsYXRpb25zO1xuICBpZiAoJGRpbGF0aW9ucyA9PSBudWxsKSB7XG4gICAgJGRpbGF0aW9ucyA9IFsxLCAxXTtcbiAgfVxuXG4gIHV0aWwuYXNzZXJ0KFxuICAgICAgYmFja2VuZF91dGlsLmVpdGhlclN0cmlkZXNPckRpbG
F0aW9uc0FyZU9uZShzdHJpZGVzLCAkZGlsYXRpb25zKSxcbiAgICAgICgpID0+ICdFcnJvciBpbiBkZXB0aHdpc2VDb252MmQ6IEVpdGhlciBzdHJpZGVzIG9yIGRpbGF0aW9ucyBtdXN0IGJlICcgK1xuICAgICAgICAgIGAxLiBHb3Qgc3RyaWRlcyAke3N0cmlkZXN9IGFuZCBkaWxhdGlvbnMgJyR7JGRpbGF0aW9uc30nYCk7XG5cbiAgY29uc3QgY29udkluZm8gPSBiYWNrZW5kX3V0aWwuY29tcHV0ZUNvbnYyREluZm8oXG4gICAgICB4LnNoYXBlIGFzIFtudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXJdLFxuICAgICAgZmlsdGVyLnNoYXBlIGFzIFtudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXJdLCBzdHJpZGVzLCAkZGlsYXRpb25zLFxuICAgICAgcGFkLCBkaW1Sb3VuZGluZ01vZGUsIHRydWUgLyogZGVwdGh3aXNlICovKTtcblxuICBjb25zdCBzaG91bGRQYWNrRGVwdGh3aXNlQ29udiA9IGVudigpLmdldEJvb2woJ1dFQkdMX1BBQ0tfREVQVEhXSVNFQ09OVicpICYmXG4gICAgICBjb252SW5mby5zdHJpZGVXaWR0aCA8PSAyICYmXG4gICAgICBjb252SW5mby5vdXRDaGFubmVscyAvIGNvbnZJbmZvLmluQ2hhbm5lbHMgPT09IDE7XG4gIGNvbnN0IGZ1c2VkQWN0aXZhdGlvbiA9IGFjdGl2YXRpb24gP1xuICAgICAgbWFwQWN0aXZhdGlvblRvU2hhZGVyUHJvZ3JhbShhY3RpdmF0aW9uLCBzaG91bGRQYWNrRGVwdGh3aXNlQ29udikgOlxuICAgICAgbnVsbDtcbiAgY29uc3QgcHJvZ3JhbUlucHV0czogVGVuc29ySW5mb1tdID0gW3gsIGZpbHRlcl07XG5cbiAgY29uc3QgaGFzQmlhcyA9IGJpYXMgIT0gbnVsbDtcbiAgY29uc3QgaGFzUHJlbHVBY3RpdmF0aW9uV2VpZ2h0cyA9IHByZWx1QWN0aXZhdGlvbldlaWdodHMgIT0gbnVsbDtcbiAgY29uc3QgaGFzTGVha3lyZWx1QWxwaGEgPSBhY3RpdmF0aW9uID09PSAnbGVha3lyZWx1JztcblxuICBpZiAoaGFzQmlhcykge1xuICAgIHByb2dyYW1JbnB1dHMucHVzaChiaWFzKTtcbiAgfVxuICBpZiAoaGFzUHJlbHVBY3RpdmF0aW9uV2VpZ2h0cykge1xuICAgIHByb2dyYW1JbnB1dHMucHVzaChwcmVsdUFjdGl2YXRpb25XZWlnaHRzKTtcbiAgfVxuICBpZiAoaGFzTGVha3lyZWx1QWxwaGEpIHtcbiAgICBjb25zdCAkbGVha3lyZWx1QWxwaGEgPSBiYWNrZW5kLm1ha2VUZW5zb3JJbmZvKFxuICAgICAgICBbXSwgJ2Zsb2F0MzInLFxuICAgICAgICB1dGlsLmNyZWF0ZVNjYWxhclZhbHVlKGxlYWt5cmVsdUFscGhhIGFzIHVua25vd24gYXMgJ2Zsb2F0MzInLFxuICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICdmbG9hdDMyJykpO1xuICAgIHByb2dyYW1JbnB1dHMucHVzaCgkbGVha3lyZWx1QWxwaGEpO1xuICAgIGludGVybWVkaWF0ZXMucHVzaCgkbGVha3lyZWx1QWxwaGEpO1xuICB9XG5cbiAgbGV0IHByb2dyYW06IERlcHRod2lzZUNvbnYyRFByb2dyYW18RGVwdGh3aXNlQ29udlBhY2tlZDJEUHJvZ3JhbTtcbiAgaWYgKHNob3VsZFBhY2tEZXB0aHdpc2VDb252KSB7XG
4gICAgcHJvZ3JhbSA9IG5ldyBEZXB0aHdpc2VDb252UGFja2VkMkRQcm9ncmFtKFxuICAgICAgICBjb252SW5mbywgaGFzQmlhcywgZnVzZWRBY3RpdmF0aW9uLCBoYXNQcmVsdUFjdGl2YXRpb25XZWlnaHRzLFxuICAgICAgICBoYXNMZWFreXJlbHVBbHBoYSk7XG4gIH0gZWxzZSB7XG4gICAgcHJvZ3JhbSA9IG5ldyBEZXB0aHdpc2VDb252MkRQcm9ncmFtKFxuICAgICAgICBjb252SW5mbywgaGFzQmlhcywgZnVzZWRBY3RpdmF0aW9uLCBoYXNQcmVsdUFjdGl2YXRpb25XZWlnaHRzLFxuICAgICAgICBoYXNMZWFreXJlbHVBbHBoYSk7XG4gIH1cbiAgY29uc3QgY3VzdG9tVmFsdWVzID0gW1xuICAgIFtjb252SW5mby5wYWRJbmZvLnRvcCwgY29udkluZm8ucGFkSW5mby5sZWZ0XSxcbiAgICBbY29udkluZm8uc3RyaWRlSGVpZ2h0LCBjb252SW5mby5zdHJpZGVXaWR0aF0sXG4gICAgW2NvbnZJbmZvLmRpbGF0aW9uSGVpZ2h0LCBjb252SW5mby5kaWxhdGlvbldpZHRoXSxcbiAgICBbY29udkluZm8uaW5IZWlnaHQsIGNvbnZJbmZvLmluV2lkdGhdXG4gIF07XG4gIGNvbnN0IHJlc3VsdCA9XG4gICAgICBiYWNrZW5kLnJ1bldlYkdMUHJvZ3JhbShwcm9ncmFtLCBwcm9ncmFtSW5wdXRzLCAnZmxvYXQzMicsIGN1c3RvbVZhbHVlcyk7XG5cbiAgaW50ZXJtZWRpYXRlcy5mb3JFYWNoKHQgPT4gYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyh0KSk7XG5cbiAgcmV0dXJuIHJlc3VsdDtcbn1cblxuZXhwb3J0IGNvbnN0IGZ1c2VkRGVwdGh3aXNlQ29udjJEQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEZ1c2VkRGVwdGh3aXNlQ29udjJELFxuICBiYWNrZW5kTmFtZTogJ3dlYmdsJyxcbiAga2VybmVsRnVuYzogZnVzZWREZXB0aHdpc2VDb252MkQgYXMgdW5rbm93biBhcyBLZXJuZWxGdW5jLFxufTtcbiJdfQ==