/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, env, FusedDepthwiseConv2D, util } from '@tensorflow/tfjs-core';
import { DepthwiseConv2DProgram } from '../conv_gpu_depthwise';
import { DepthwiseConvPacked2DProgram } from '../conv_packed_gpu_depthwise';
import { mapActivationToShaderProgram } from '../kernel_utils/kernel_funcs_utils';

/**
 * Kernel function registered for `FusedDepthwiseConv2D` on the 'webgl'
 * backend.
 *
 * Computes the convolution geometry with `backend_util.computeConv2DInfo`
 * (depthwise flag set), picks either the packed or the unpacked depthwise
 * shader program, and executes it through `backend.runWebGLProgram`.
 *
 * @param args Object with three fields:
 *   - `inputs`: `{x, filter, bias, preluActivationWeights}`; `bias` and
 *     `preluActivationWeights` may be null/undefined and are then omitted
 *     from the program inputs.
 *   - `attrs`: `{strides, pad, dilations, dimRoundingMode, activation,
 *     leakyreluAlpha}`; `dilations` falls back to `[1, 1]` when
 *     null/undefined.
 *   - `backend`: object providing `makeTensorInfo`, `runWebGLProgram` and
 *     `disposeIntermediateTensorInfo`.
 * @returns Whatever `backend.runWebGLProgram` returns for the selected
 *     program (requested with output dtype 'float32').
 * @throws Via `util.assert` when
 *     `backend_util.eitherStridesOrDilationsAreOne(strides, dilations)` is
 *     false.
 */
export function fusedDepthwiseConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;

    const $dilations = dilations == null ? [1, 1] : dilations;

    util.assert(
        backend_util.eitherStridesOrDilationsAreOne(strides, $dilations),
        () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
            `1. Got strides ${strides} and dilations '${$dilations}'`);

    const convInfo = backend_util.computeConv2DInfo(
        x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode,
        true /* depthwise */);

    // The packed program is only selected when the env flag is on, the
    // horizontal stride is at most 2 and the channel ratio is exactly 1.
    const shouldPackDepthwiseConv =
        env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;

    const fusedActivation = activation ?
        mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
        null;

    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';

    // Input order matters: x, filter, then the optional bias, PReLU weights
    // and leaky-ReLU alpha, in that order.
    const programInputs = [x, filter];
    if (hasBias) {
        programInputs.push(bias);
    }
    if (hasPreluActivationWeights) {
        programInputs.push(preluActivationWeights);
    }

    // Tensors created here (not owned by the caller) that must be released
    // before returning.
    const intermediates = [];
    if (hasLeakyreluAlpha) {
        const $leakyreluAlpha = backend.makeTensorInfo(
            [], 'float32', util.createScalarValue(leakyreluAlpha, 'float32'));
        programInputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
    }

    try {
        const ProgramCtor = shouldPackDepthwiseConv ?
            DepthwiseConvPacked2DProgram :
            DepthwiseConv2DProgram;
        const program = new ProgramCtor(
            convInfo, hasBias, fusedActivation, hasPreluActivationWeights,
            hasLeakyreluAlpha);

        const customValues = [
            [convInfo.padInfo.top, convInfo.padInfo.left],
            [convInfo.strideHeight, convInfo.strideWidth],
            [convInfo.dilationHeight, convInfo.dilationWidth],
            [convInfo.inHeight, convInfo.inWidth]
        ];

        return backend.runWebGLProgram(
            program, programInputs, 'float32', customValues);
    } finally {
        // Release intermediates even when program construction or execution
        // throws, so the scalar alpha tensor is not leaked on the error path.
        intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    }
}

/** Kernel registration record binding `fusedDepthwiseConv2D` to the 'webgl' backend. */
export const fusedDepthwiseConv2DConfig = {
    kernelName: FusedDepthwiseConv2D,
    backendName: 'webgl',
    kernelFunc: fusedDepthwiseConv2D,
};
//# 
sourceMappingURL=data:application/json;base64,{"version":3,"file":"FusedDepthwiseConv2D.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/FusedDepthwiseConv2D.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,GAAG,EAAE,oBAAoB,EAA+F,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGjL,OAAO,EAAC,sBAAsB,EAAC,MAAM,uBAAuB,CAAC;AAC7D,OAAO,EAAC,4BAA4B,EAAC,MAAM,8BAA8B,CAAC;AAC1E,OAAO,EAAC,4BAA4B,EAAC,MAAM,oCAAoC,CAAC;AAEhF,MAAM,UAAU,oBAAoB,CAAC,IAIpC;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,sBAAsB,EAAC,GAAG,MAAM,CAAC;IACzD,MAAM,EAAC,OAAO,EAAE,GAAG,EAAE,SAAS,EAAE,eAAe,EAAE,UAAU,EAAE,cAAc,EAAC,GACxE,KAAK,CAAC;IAEV,MAAM,aAAa,GAAiB,EAAE,CAAC;IAEvC,IAAI,UAAU,GAAG,SAAS,CAAC;IAC3B,IAAI,UAAU,IAAI,IAAI,EAAE;QACtB,UAAU,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;KACrB;IAED,IAAI,CAAC,MAAM,CACP,YAAY,CAAC,8BAA8B,CAAC,OAAO,EAAE,UAAU,CAAC,EAChE,GAAG,EAAE,CAAC,gEAAgE;QAClE,kBAAkB,OAAO,mBAAmB,UAAU,GAAG,CAAC,CAAC;IAEnE,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,CAAC,CAAC,KAAyC,EAC3C,MAAM,CAAC,KAAyC,EAAE,OAAO,EAAE,UAAU,EACrE,GAAG,EAAE,eAAe,EAAE,IAAI,CAAC,eAAe,CAAC,CAAC;IAEhD,MAAM,uBAAuB,GAAG,GAAG,EAAE,CAAC,OAAO,CAAC,0BAA0B,CAAC;QACrE,QAAQ,CAAC,WAAW,IAAI,CAAC;QACzB,QAAQ,CAAC,WAAW,GAAG,QAAQ,CAAC,UAAU,KAAK,CAAC,CAAC;IACrD,MAAM,eAAe,GAAG,UAAU,CAAC,CAAC;QAChC,4BAA4B,CAAC,UAAU,EAAE,uBAAuB,CAAC,CAAC,CAAC;QACnE,IAAI,CAAC;IACT,MAAM,aAAa,GAAiB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;IAEhD,MAAM,OAAO,GAAG,IAAI,IAAI,IAAI,CAAC;IAC7B,MAAM,yBAAyB,GAAG,sBAAsB,IAAI,IAAI,CAAC;IACjE,MAAM,iBAAiB,GAAG,UAAU,KAAK,WAAW,CAAC;IAErD,IAAI,OAAO,EAAE;QACX,aAAa,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;KAC1B;IACD,IAAI,yBAAyB,EAAE;QAC7B,aAAa,CAAC,IAAI,CAAC,sBAAsB,CAAC,CAAC;KAC5C;IACD,IAAI,iBAAiB,EAAE;QACrB,MAAM,eAAe,GAAG,OAAO,CAAC,cAAc,CAC1C,EAAE,EAAE,SAAS,EACb,IAAI,CAAC,iBAAiB,CAAC,cAAsC,EACtC,SAAS,CAAC,CAAC,CAAC;QACvC,aAAa,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;QACpC,aAAa,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;KACrC;IAED,IAAI,OAA4D,CAAC;IACjE,IAAI,uBAAuB,EAAE;QAC3B,OAAO,GAA
G,IAAI,4BAA4B,CACtC,QAAQ,EAAE,OAAO,EAAE,eAAe,EAAE,yBAAyB,EAC7D,iBAAiB,CAAC,CAAC;KACxB;SAAM;QACL,OAAO,GAAG,IAAI,sBAAsB,CAChC,QAAQ,EAAE,OAAO,EAAE,eAAe,EAAE,yBAAyB,EAC7D,iBAAiB,CAAC,CAAC;KACxB;IACD,MAAM,YAAY,GAAG;QACnB,CAAC,QAAQ,CAAC,OAAO,CAAC,GAAG,EAAE,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;QAC7C,CAAC,QAAQ,CAAC,YAAY,EAAE,QAAQ,CAAC,WAAW,CAAC;QAC7C,CAAC,QAAQ,CAAC,cAAc,EAAE,QAAQ,CAAC,aAAa,CAAC;QACjD,CAAC,QAAQ,CAAC,QAAQ,EAAE,QAAQ,CAAC,OAAO,CAAC;KACtC,CAAC;IACF,MAAM,MAAM,GACR,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,aAAa,EAAE,SAAS,EAAE,YAAY,CAAC,CAAC;IAE7E,aAAa,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,6BAA6B,CAAC,CAAC,CAAC,CAAC,CAAC;IAErE,OAAO,MAAM,CAAC;AAChB,CAAC;AAED,MAAM,CAAC,MAAM,0BAA0B,GAAiB;IACtD,UAAU,EAAE,oBAAoB;IAChC,WAAW,EAAE,OAAO;IACpB,UAAU,EAAE,oBAA6C;CAC1D,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, env, FusedDepthwiseConv2D, FusedDepthwiseConv2DAttrs, FusedDepthwiseConv2DInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {DepthwiseConv2DProgram} from '../conv_gpu_depthwise';\nimport {DepthwiseConvPacked2DProgram} from '../conv_packed_gpu_depthwise';\nimport {mapActivationToShaderProgram} from '../kernel_utils/kernel_funcs_utils';\n\nexport function 
fusedDepthwiseConv2D(args: {\n  inputs: FusedDepthwiseConv2DInputs,\n  attrs: FusedDepthwiseConv2DAttrs,\n  backend: MathBackendWebGL\n}) {\n  const {inputs, backend, attrs} = args;\n  const {x, filter, bias, preluActivationWeights} = inputs;\n  const {strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha} =\n      attrs;\n\n  const intermediates: TensorInfo[] = [];\n\n  let $dilations = dilations;\n  if ($dilations == null) {\n    $dilations = [1, 1];\n  }\n\n  util.assert(\n      backend_util.eitherStridesOrDilationsAreOne(strides, $dilations),\n      () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +\n          `1. Got strides ${strides} and dilations '${$dilations}'`);\n\n  const convInfo = backend_util.computeConv2DInfo(\n      x.shape as [number, number, number, number],\n      filter.shape as [number, number, number, number], strides, $dilations,\n      pad, dimRoundingMode, true /* depthwise */);\n\n  const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&\n      convInfo.strideWidth <= 2 &&\n      convInfo.outChannels / convInfo.inChannels === 1;\n  const fusedActivation = activation ?\n      mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :\n      null;\n  const programInputs: TensorInfo[] = [x, filter];\n\n  const hasBias = bias != null;\n  const hasPreluActivationWeights = preluActivationWeights != null;\n  const hasLeakyreluAlpha = activation === 'leakyrelu';\n\n  if (hasBias) {\n    programInputs.push(bias);\n  }\n  if (hasPreluActivationWeights) {\n    programInputs.push(preluActivationWeights);\n  }\n  if (hasLeakyreluAlpha) {\n    const $leakyreluAlpha = backend.makeTensorInfo(\n        [], 'float32',\n        util.createScalarValue(leakyreluAlpha as unknown as 'float32',\n                               'float32'));\n    programInputs.push($leakyreluAlpha);\n    intermediates.push($leakyreluAlpha);\n  }\n\n  let program: 
DepthwiseConv2DProgram|DepthwiseConvPacked2DProgram;\n  if (shouldPackDepthwiseConv) {\n    program = new DepthwiseConvPacked2DProgram(\n        convInfo, hasBias, fusedActivation, hasPreluActivationWeights,\n        hasLeakyreluAlpha);\n  } else {\n    program = new DepthwiseConv2DProgram(\n        convInfo, hasBias, fusedActivation, hasPreluActivationWeights,\n        hasLeakyreluAlpha);\n  }\n  const customValues = [\n    [convInfo.padInfo.top, convInfo.padInfo.left],\n    [convInfo.strideHeight, convInfo.strideWidth],\n    [convInfo.dilationHeight, convInfo.dilationWidth],\n    [convInfo.inHeight, convInfo.inWidth]\n  ];\n  const result =\n      backend.runWebGLProgram(program, programInputs, 'float32', customValues);\n\n  intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));\n\n  return result;\n}\n\nexport const fusedDepthwiseConv2DConfig: KernelConfig = {\n  kernelName: FusedDepthwiseConv2D,\n  backendName: 'webgl',\n  kernelFunc: fusedDepthwiseConv2D as unknown as KernelFunc,\n};\n"]}