/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, env, FusedConv2D, util } from '@tensorflow/tfjs-core';
import { Conv2DProgram } from '../conv_gpu';
import { Conv2DPackedProgram } from '../conv_packed_gpu';
import { mapActivationToShaderProgram } from '../kernel_utils/kernel_funcs_utils';
import { conv2dByMatMul, conv2dWithIm2Row } from './Conv2D_impl';
import { reshape } from './Reshape';

/**
 * WebGL kernel for FusedConv2D: a 2-D convolution with an optional bias and
 * an optional fused activation (which may use PReLU weights or a leaky-ReLU
 * alpha).
 *
 * One of four implementations is selected:
 *  1. 1x1 filter, unit strides, unit dilations and SAME/VALID padding:
 *     `conv2dByMatMul`.
 *  2. `strideWidth <= 2`, channels-last data and the `WEBGL_EXP_CONV` flag:
 *     `Conv2DPackedProgram`.
 *  3. The `WEBGL_CONV_IM2COL` flag: `conv2dWithIm2Row`.
 *  4. Otherwise: `Conv2DProgram`.
 *
 * Every temporary TensorInfo created here (including the raw result before
 * the final reshape) is released before returning, and also if any step
 * throws.
 *
 * @param args.inputs `x`, `filter`, and optional `bias` /
 *     `preluActivationWeights`.
 * @param args.backend The WebGL backend used to run the programs.
 * @param args.attrs strides, pad, dataFormat, dilations, dimRoundingMode,
 *     activation and leakyreluAlpha.
 * @returns The convolution result reshaped to `convInfo.outShape`.
 */
export function fusedConv2d(args) {
  const { inputs, backend, attrs } = args;
  const { x, filter, bias, preluActivationWeights } = inputs;
  const {
    strides,
    pad,
    dataFormat,
    dilations,
    dimRoundingMode,
    activation,
    leakyreluAlpha
  } = attrs;

  const $dataFormat = backend_util.convertConv2DDataFormat(dataFormat);
  const convInfo = backend_util.computeConv2DInfo(
      x.shape, filter.shape, strides, dilations, pad, dimRoundingMode,
      false /* depthwise */, $dataFormat);

  let out;
  // TensorInfos owned by this kernel. They are disposed in the `finally`
  // below so that nothing leaks when one of the dispatch paths throws.
  const intermediates = [];

  const hasBias = bias != null;
  const hasPreluActivationWeights = preluActivationWeights != null;
  const hasLeakyreluAlpha = activation === 'leakyrelu';

  /**
   * Builds the input list for the shader-program paths: `[x, filter]`,
   * followed (when present) by the bias, the PReLU weights and a scalar
   * leaky-ReLU alpha. TensorInfos created here are recorded in
   * `intermediates`.
   */
  const prepareInputs = () => {
    const inputs = [x, filter];

    // If the input is a 1-D tensor, align it with the channels.
    //
    // For fusedConv2d, the inputs (x, W, bias, preluActivationWeights) are
    // supposed to be aligned with the dataFormat. The 4-D tensor inputs or
    // scalar inputs are originally aligned, but the 1-D tensor inputs are
    // supposed to be aligned with the channels (only bias and PReLU activation
    // weights could be a 1-D tensor).
    const alignInputWithDataFormat = (input, dataFormat) => {
      if (dataFormat === 'NCHW' && input.shape.length === 1 &&
          input.shape[0] !== 1) {
        // [C] -> [C, 1, 1] so the values line up with the channel axis.
        const alignedInput = reshape({
          inputs: { x: input },
          backend,
          attrs: { shape: [input.shape[0], 1, 1] }
        });
        intermediates.push(alignedInput);
        return alignedInput;
      }
      return input;
    };

    if (hasBias) {
      inputs.push(alignInputWithDataFormat(bias, dataFormat));
    }

    if (hasPreluActivationWeights) {
      inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat));
    }

    if (hasLeakyreluAlpha) {
      // The alpha is uploaded as a rank-0 float32 tensor so the shader can
      // read it as an ordinary program input.
      const $leakyreluAlpha = backend.makeTensorInfo(
          [], 'float32', util.createScalarValue(leakyreluAlpha, 'float32'));
      inputs.push($leakyreluAlpha);
      intermediates.push($leakyreluAlpha);
    }
    return inputs;
  };

  try {
    if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' ||
         convInfo.padInfo.type === 'VALID')) {
      out = conv2dByMatMul({
        x,
        filter,
        convInfo,
        backend,
        bias,
        activation,
        preluActivationWeights,
        leakyreluAlpha
      });
    } else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast' &&
               env().getBool('WEBGL_EXP_CONV')) {
      const fusedActivation =
          activation ? mapActivationToShaderProgram(activation, true) : null;
      const program = new Conv2DPackedProgram(
          convInfo, hasBias, fusedActivation, hasPreluActivationWeights,
          hasLeakyreluAlpha);
      const customValues = [
        [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth],
        [convInfo.inHeight, convInfo.inWidth]
      ];
      const programInputs = prepareInputs();
      out = backend.runWebGLProgram(
          program, programInputs, 'float32', customValues);
    } else if (env().getBool('WEBGL_CONV_IM2COL')) {
      out = conv2dWithIm2Row({
        x,
        filter,
        convInfo,
        backend,
        bias,
        activation,
        preluActivationWeights,
        leakyreluAlpha
      });
    } else {
      const fusedActivation =
          activation ? mapActivationToShaderProgram(activation, false) : null;
      const program = new Conv2DProgram(
          convInfo, hasBias, fusedActivation, hasPreluActivationWeights,
          hasLeakyreluAlpha);
      const programInputs = prepareInputs();
      out = backend.runWebGLProgram(program, programInputs, 'float32');
    }

    // The return expression is evaluated before `finally` runs, so `out` is
    // still alive while it is being reshaped.
    return reshape(
        { inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
  } finally {
    if (out != null) {
      intermediates.push(out);
    }
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
  }
}

export const fusedConv2DConfig = {
  kernelName: FusedConv2D,
  backendName: 'webgl',
  kernelFunc: fusedConv2d,
};
//# 
sourceMappingURL=data:application/json;base64,{"version":3,"file":"FusedConv2D.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/FusedConv2D.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,YAAY,EAAE,GAAG,EAAE,WAAW,EAA6E,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGtJ,OAAO,EAAC,aAAa,EAAC,MAAM,aAAa,CAAC;AAC1C,OAAO,EAAC,mBAAmB,EAAC,MAAM,oBAAoB,CAAC;AACvD,OAAO,EAAC,4BAA4B,EAAC,MAAM,oCAAoC,CAAC;AAEhF,OAAO,EAAC,cAAc,EAAE,gBAAgB,EAAC,MAAM,eAAe,CAAC;AAC/D,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAElC,MAAM,UAAU,WAAW,CAAC,IAI3B;IACC,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,sBAAsB,EAAC,GAAG,MAAM,CAAC;IACzD,MAAM,EACJ,OAAO,EACP,GAAG,EACH,UAAU,EACV,SAAS,EACT,eAAe,EACf,UAAU,EACV,cAAc,EACf,GAAG,KAAK,CAAC;IAEV,MAAM,WAAW,GAAG,YAAY,CAAC,uBAAuB,CAAC,UAAU,CAAC,CAAC;IACrE,MAAM,QAAQ,GAAG,YAAY,CAAC,iBAAiB,CAC3C,CAAC,CAAC,KAAyC,EAC3C,MAAM,CAAC,KAAyC,EAAE,OAAO,EAAE,SAAS,EAAE,GAAG,EACzE,eAAe,EAAE,KAAK,CAAC,eAAe,EAAE,WAAW,CAAC,CAAC;IACzD,IAAI,GAAe,CAAC;IACpB,MAAM,aAAa,GAAiB,EAAE,CAAC;IAEvC,MAAM,OAAO,GAAG,IAAI,IAAI,IAAI,CAAC;IAC7B,MAAM,yBAAyB,GAAG,sBAAsB,IAAI,IAAI,CAAC;IACjE,MAAM,iBAAiB,GAAG,UAAU,KAAK,WAAW,CAAC;IAErD,MAAM,aAAa,GAAG,GAAiB,EAAE;QACvC,MAAM,MAAM,GAAiB,CAAC,CAAC,EAAE,MAAM,CAAC,CAAC;QAEzC,4DAA4D;QAC5D,EAAE;QACF,uEAAuE;QACvE,uEAAuE;QACvE,sEAAsE;QACtE,2EAA2E;QAC3E,kCAAkC;QAClC,MAAM,wBAAwB,GAC1B,CAAC,KAAiB,EAAE,UAAyB,EAAc,EAAE;YAC3D,IAAI,UAAU,KAAK,MAAM,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC;gBACjD,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE;gBACxB,MAAM,YAAY,GAAG,OAAO,CAAC;oBAC3B,MAAM,EAAE,EAAC,CAAC,EAAE,KAAK,EAAC;oBAClB,OAAO;oBACP,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAC;iBACvC,CAAC,CAAC;gBACH,aAAa,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC;gBACjC,OAAO,YAAY,CAAC;aACrB;YACD,OAAO,KAAK,CAAC;QACf,CAAC,CAAC;QAEN,IAAI,OAAO,EAAE;YACX,MAAM,CAAC,IAAI,CAAC,wBAAwB,CAAC,IAAI,EAAE,UAAU,CAAC,CAAC,CAAC;SACzD;QAED,IAAI,yBAAyB,EAAE;YAC7B,MAAM,CAAC,IAAI,CAAC,wBAAwB,CAAC,sBAAsB,E
AAE,UAAU,CAAC,CAAC,CAAC;SAC3E;QAED,IAAI,iBAAiB,EAAE;YACrB,MAAM,eAAe,GAAG,OAAO,CAAC,cAAc,CAC1C,EAAE,EAAE,SAAS,EACb,IAAI,CAAC,iBAAiB,CAAC,cAAsC,EAAE,SAAS,CAAC,CAAC,CAAC;YAC/E,MAAM,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAC7B,aAAa,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;SACrC;QACD,OAAO,MAAM,CAAC;IAChB,CAAC,CAAC;IAEF,IAAI,QAAQ,CAAC,YAAY,KAAK,CAAC,IAAI,QAAQ,CAAC,WAAW,KAAK,CAAC;QACzD,QAAQ,CAAC,cAAc,KAAK,CAAC,IAAI,QAAQ,CAAC,aAAa,KAAK,CAAC;QAC7D,QAAQ,CAAC,YAAY,KAAK,CAAC,IAAI,QAAQ,CAAC,WAAW,KAAK,CAAC;QACzD,CAAC,QAAQ,CAAC,OAAO,CAAC,IAAI,KAAK,MAAM,IAAI,QAAQ,CAAC,OAAO,CAAC,IAAI,KAAK,OAAO,CAAC,EAAE;QAC3E,GAAG,GAAG,cAAc,CAAC;YACnB,CAAC;YACD,MAAM;YACN,QAAQ;YACR,OAAO;YACP,IAAI;YACJ,UAAU;YACV,sBAAsB;YACtB,cAAc;SACf,CAAC,CAAC;KACJ;SAAM,IAAI,QAAQ,CAAC,WAAW,IAAI,CAAC,IAAI,WAAW,KAAK,cAAc;WACjE,GAAG,EAAE,CAAC,OAAO,CAAC,gBAAgB,CAAC,EAChC;QACA,MAAM,eAAe,GACjB,UAAU,CAAC,CAAC,CAAC,4BAA4B,CAAC,UAAU,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;QACzE,MAAM,OAAO,GAAG,IAAI,mBAAmB,CACrC,QAAQ,EAAE,OAAO,EAAE,eAAe,EAAE,yBAAyB,EAC7D,iBAAiB,CAAC,CAAC;QACrB,MAAM,YAAY,GAAG;YACnB,CAAC,QAAQ,CAAC,OAAO,CAAC,GAAG,EAAE,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;YAC7C,CAAC,QAAQ,CAAC,YAAY,EAAE,QAAQ,CAAC,WAAW,CAAC;YAC7C,CAAC,QAAQ,CAAC,cAAc,EAAE,QAAQ,CAAC,aAAa,CAAC;YACjD,CAAC,QAAQ,CAAC,QAAQ,EAAE,QAAQ,CAAC,OAAO,CAAC;SACtC,CAAC;QACF,MAAM,MAAM,GAAG,aAAa,EAAE,CAAC;QAC/B,GAAG,GAAG,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,MAAM,EAAE,SAAS,EAAE,YAAY,CAAC,CAAC;KACzE;SAAM,IAAI,GAAG,EAAE,CAAC,OAAO,CAAC,mBAAmB,CAAC,EAAE;QAC7C,GAAG,GAAG,gBAAgB,CAAC;YACrB,CAAC;YACD,MAAM;YACN,QAAQ;YACR,OAAO;YACP,IAAI;YACJ,UAAU;YACV,sBAAsB;YACtB,cAAc;SACf,CAAC,CAAC;KACJ;SAAM;QACL,MAAM,eAAe,GACjB,UAAU,CAAC,CAAC,CAAC,4BAA4B,CAAC,UAAU,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;QACxE,MAAM,OAAO,GAAG,IAAI,aAAa,CAC7B,QAAQ,EAAE,OAAO,EAAE,eAAe,EAAE,yBAAyB,EAC7D,iBAAiB,CAAC,CAAC;QAEvB,MAAM,MAAM,GAAG,aAAa,EAAE,CAAC;QAC/B,GAAG,GAAG,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,MAAM,EAAE,SAAS,CAAC,CAAC;KAC3D;IAED,MAAM,WAAW,GACb,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,GAAG,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,QAAQ,CAAC,
QAAQ,EAAC,EAAC,CAAC,CAAC;IAE5E,aAAa,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;IACxB,aAAa,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,OAAO,CAAC,6BAA6B,CAAC,CAAC,CAAC,CAAC,CAAC;IAErE,OAAO,WAAW,CAAC;AACrB,CAAC;AAED,MAAM,CAAC,MAAM,iBAAiB,GAAiB;IAC7C,UAAU,EAAE,WAAW;IACvB,WAAW,EAAE,OAAO;IACpB,UAAU,EAAE,WAAoC;CACjD,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, env, FusedConv2D, FusedConv2DAttrs, FusedConv2DInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {Conv2DProgram} from '../conv_gpu';\nimport {Conv2DPackedProgram} from '../conv_packed_gpu';\nimport {mapActivationToShaderProgram} from '../kernel_utils/kernel_funcs_utils';\n\nimport {conv2dByMatMul, conv2dWithIm2Row} from './Conv2D_impl';\nimport {reshape} from './Reshape';\n\nexport function fusedConv2d(args: {\n  inputs: FusedConv2DInputs,\n  attrs: FusedConv2DAttrs,\n  backend: MathBackendWebGL\n}) {\n  const {inputs, backend, attrs} = args;\n  const {x, filter, bias, preluActivationWeights} = inputs;\n  const {\n    strides,\n    pad,\n    dataFormat,\n    dilations,\n    dimRoundingMode,\n    activation,\n    leakyreluAlpha\n  } = attrs;\n\n  const $dataFormat = backend_util.convertConv2DDataFormat(dataFormat);\n  const convInfo 
= backend_util.computeConv2DInfo(\n      x.shape as [number, number, number, number],\n      filter.shape as [number, number, number, number], strides, dilations, pad,\n      dimRoundingMode, false /* depthwise */, $dataFormat);\n  let out: TensorInfo;\n  const intermediates: TensorInfo[] = [];\n\n  const hasBias = bias != null;\n  const hasPreluActivationWeights = preluActivationWeights != null;\n  const hasLeakyreluAlpha = activation === 'leakyrelu';\n\n  const prepareInputs = (): TensorInfo[] => {\n    const inputs: TensorInfo[] = [x, filter];\n\n    // If the input is a 1-D tensor, align it with the channels.\n    //\n    // For fusedConv2d, the inputs (x, W, bias, preluActivationWeights) are\n    // supposed to be aligned with the dataFormat. The 4-D tensor inputs or\n    // scalar inputs are originally aligned, but the 1-D tensor inputs are\n    // supposed to be aligned with the channels (only bias and PReLU activation\n    // weights could be a 1-D tensor).\n    const alignInputWithDataFormat =\n        (input: TensorInfo, dataFormat: 'NHWC'|'NCHW'): TensorInfo => {\n          if (dataFormat === 'NCHW' && input.shape.length === 1 &&\n              input.shape[0] !== 1) {\n            const alignedInput = reshape({\n              inputs: {x: input},\n              backend,\n              attrs: {shape: [input.shape[0], 1, 1]}\n            });\n            intermediates.push(alignedInput);\n            return alignedInput;\n          }\n          return input;\n        };\n\n    if (hasBias) {\n      inputs.push(alignInputWithDataFormat(bias, dataFormat));\n    }\n\n    if (hasPreluActivationWeights) {\n      inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat));\n    }\n\n    if (hasLeakyreluAlpha) {\n      const $leakyreluAlpha = backend.makeTensorInfo(\n          [], 'float32',\n          util.createScalarValue(leakyreluAlpha as unknown as 'float32', 'float32'));\n      inputs.push($leakyreluAlpha);\n      
intermediates.push($leakyreluAlpha);\n    }\n    return inputs;\n  };\n\n  if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&\n      convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&\n      convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&\n      (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {\n    out = conv2dByMatMul({\n      x,\n      filter,\n      convInfo,\n      backend,\n      bias,\n      activation,\n      preluActivationWeights,\n      leakyreluAlpha\n    });\n  } else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast'\n    && env().getBool('WEBGL_EXP_CONV')\n    ) {\n      const fusedActivation =\n          activation ? mapActivationToShaderProgram(activation, true) : null;\n    const program = new Conv2DPackedProgram(\n      convInfo, hasBias, fusedActivation, hasPreluActivationWeights,\n      hasLeakyreluAlpha);\n    const customValues = [\n      [convInfo.padInfo.top, convInfo.padInfo.left],\n      [convInfo.strideHeight, convInfo.strideWidth],\n      [convInfo.dilationHeight, convInfo.dilationWidth],\n      [convInfo.inHeight, convInfo.inWidth]\n    ];\n    const inputs = prepareInputs();\n    out = backend.runWebGLProgram(program, inputs, 'float32', customValues);\n  } else if (env().getBool('WEBGL_CONV_IM2COL')) {\n    out = conv2dWithIm2Row({\n      x,\n      filter,\n      convInfo,\n      backend,\n      bias,\n      activation,\n      preluActivationWeights,\n      leakyreluAlpha\n    });\n  } else {\n    const fusedActivation =\n        activation ? 
mapActivationToShaderProgram(activation, false) : null;\n    const program = new Conv2DProgram(\n        convInfo, hasBias, fusedActivation, hasPreluActivationWeights,\n        hasLeakyreluAlpha);\n\n    const inputs = prepareInputs();\n    out = backend.runWebGLProgram(program, inputs, 'float32');\n  }\n\n  const outReshaped =\n      reshape({inputs: {x: out}, backend, attrs: {shape: convInfo.outShape}});\n\n  intermediates.push(out);\n  intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));\n\n  return outReshaped;\n}\n\nexport const fusedConv2DConfig: KernelConfig = {\n  kernelName: FusedConv2D,\n  backendName: 'webgl',\n  kernelFunc: fusedConv2d as unknown as KernelFunc,\n};\n"]}