/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { util } from '@tensorflow/tfjs-core';
import { Im2ColPackedProgram } from '../im2col_packed_gpu';
import { mapActivationToShaderProgram } from '../kernel_utils/kernel_funcs_utils';
import { MatMulPackedProgram } from '../mulmat_packed_gpu';
import * as webgl_util from '../webgl_util';
import { batchMatMulImpl, MATMUL_SHARED_DIM_THRESHOLD } from './BatchMatMul_impl';
import { identity } from './Identity';
import { reshape } from './Reshape';

// Both conv2dByMatMul and conv2dWithIm2Row fuse height and width into one
// dimension to compute batchMatMul, so bias and activation weights are also
// supposed to fuse the two dimensions into one.
//
// This function computes the target shape for fusing height and width
// dimensions. Returning null means the shape is already compatible.
//
// Even though the bias is not supposed to be a 3-D or a 4-D (including
// batch) tensor and PReLU activation weights is not supposed to be a 4-D
// tensor, we still need to support them, because we haven't disabled
// them for NHWC format.
// https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/conv2d.ts#L181-L196
/**
 * @param shape Shape of the bias or PReLU weights tensor.
 * @param isChannelsLast Whether the last dimension is the channel dimension
 *     (otherwise the third-from-last dimension is treated as the channel).
 * @returns The shape with the two spatial dimensions merged into one, a
 *     `[channels, 1]` column shape for a 1-D channels-first input with more
 *     than one element, or `null` when no reshape is required.
 */
function getShapeForBatchMatMul(shape, isChannelsLast) {
  const length = shape.length;
  if (length >= 3) {
    return isChannelsLast ?
      [
        ...shape.slice(0, -3) /* batch */,
        shape[length - 3] * shape[length - 2] /* height * width */,
        shape[length - 1] /* channel */
      ] :
      [
        ...shape.slice(0, -3) /* batch */,
        shape[length - 3] /* channel */,
        shape[length - 2] * shape[length - 1] /* height * width */
      ];
  } else if (!isChannelsLast && length === 1 && shape[0] > 1) {
    return [shape[0], 1];
  } else {
    return null;
  }
}

// For 1x1 kernels that iterate through every point in the input, convolution
// can be expressed as matrix multiplication (without need for memory
// remapping).
/**
 * @param config Convolution arguments.
 * @param config.x Input tensor info (`dataId`, `shape`, `dtype`).
 * @param config.filter Filter tensor info; it is reshaped to
 *     `[1, inChannels, outChannels]`.
 * @param config.convInfo Convolution description; this function reads
 *     `inChannels`, `outChannels`, `dataFormat`, `batchSize`, `outHeight`,
 *     `outWidth` and `outShape`.
 * @param config.backend Backend providing `texData` and
 *     `disposeIntermediateTensorInfo`.
 * @param config.bias Optional bias, forwarded to `batchMatMulImpl`.
 * @param config.preluActivationWeights Optional PReLU weights, forwarded to
 *     `batchMatMulImpl`.
 * @param config.leakyreluAlpha Leaky-ReLU alpha, forwarded to
 *     `batchMatMulImpl`.
 * @param config.activation Optional fused activation, forwarded to
 *     `batchMatMulImpl`.
 * @returns The convolution result, shaped as `convInfo.outShape`.
 */
export function conv2dByMatMul({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
  // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
  // result from 2D to 4D.
  const xShape = x.shape;
  const xTexData = backend.texData.get(x.dataId);
  const sharedMatMulDim = convInfo.inChannels;
  const outerShapeX = xShape[0] * xShape[1] * xShape[2];
  const outerShapeFilter = convInfo.outChannels;
  const isChannelsLast = convInfo.dataFormat === 'channelsLast';
  const transposeA = false;
  const transposeB = false;
  let out;
  // Temporaries created here; all of them are disposed before returning.
  const intermediates = [];

  if (preluActivationWeights != null) {
    const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
    if (targetShape != null) {
      preluActivationWeights = reshape({
        inputs: { x: preluActivationWeights },
        backend,
        attrs: { shape: targetShape }
      });
      intermediates.push(preluActivationWeights);
    }
  }

  if (bias != null) {
    const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
    if (targetShape != null) {
      bias = reshape({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
      intermediates.push(bias);
    }
  }

  // TODO: Once reduction ops are packed, batchMatMul will always be packed
  // and we can remove this condition.
  const batchMatMulWillBeUnpacked =
    (outerShapeX === 1 || outerShapeFilter === 1) &&
    sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;

  // The algorithm in the if condition assumes (1) the output will be packed,
  // (2) x is packed, (3) x isChannelsLast, (4) x's packed texture is already
  // on GPU, (5) col is odd, (6) the width, height and inChannels are the same
  // for xTexData.shape and xShape.
  const canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked &&
    isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 &&
    util.arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));

  if (canOptimize) {
    // We avoid expensive packed 2x2 reshape by padding col count to next,
    // even number. When col is odd, the result of packed batchMatMul is
    // the same (has the same texture layout and values in the texture) as
    // it is for next even col. We make the odd-cols tensor to look like
    // even-cols tensor before the operation and, after the batchMatMul,
    // fix the even-cols result to have odd number of cols.
    const targetShape = xShape[0] * xShape[1] * (xShape[2] + 1);
    const xReshaped = {
      dataId: x.dataId,
      shape: [1, targetShape, convInfo.inChannels],
      dtype: x.dtype
    };
    // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
    // Decrementing col count, after batchMatMul->...->compileProgram leads to
    // invalid col count within the reference in GPGPUBinary.inShapeInfos.
    // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos
    // in compileProgram method, but that would affect compilation of all
    // programs - instead, provide a copy here, with even col count, before
    // calling batchMatMul->...->compileProgram and after that, the original
    // xTexData.shape is restored.
    const originalXTexDataShape = xTexData.shape;
    xTexData.shape = xTexData.shape.slice();
    xTexData.shape[xTexData.shape.length - 2]++;

    let pointwiseConv;
    let pointwiseConvTexData;
    try {
      util.assert(webgl_util.isReshapeFree(xTexData.shape, xReshaped.shape), () => `packed reshape ${xTexData.shape} to ${xReshaped.shape} isn't free`);
      const filterReshaped = reshape({
        inputs: { x: filter },
        backend,
        attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
      });
      intermediates.push(filterReshaped);
      pointwiseConv = batchMatMulImpl({
        a: xReshaped,
        b: filterReshaped,
        backend,
        transposeA,
        transposeB,
        bias,
        activation,
        preluActivationWeights,
        leakyreluAlpha
      });
      pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
      util.assert(pointwiseConvTexData.isPacked, () => 'batchMatMul result is expected to be packed');
    } finally {
      // Restore the input shape to original. Doing this in `finally` keeps
      // the caller's input metadata intact even when one of the steps above
      // throws; otherwise x would be left with the padded column count.
      xTexData.shape = originalXTexDataShape;
    }

    // Set the output shape - there is no need for expensive reshape as data
    // layout is already correct.
    pointwiseConvTexData.shape = convInfo.outShape;
    out = identity({ inputs: { x: pointwiseConv }, backend });
    out.shape = convInfo.outShape;
    intermediates.push(pointwiseConv);
  } else {
    const numCols = convInfo.outHeight * convInfo.outWidth;
    const xReshaped = reshape({
      inputs: { x },
      backend,
      attrs: {
        shape: isChannelsLast ?
          [convInfo.batchSize, numCols, convInfo.inChannels] :
          [convInfo.batchSize, convInfo.inChannels, numCols]
      }
    });
    const filterReshaped = reshape({
      inputs: { x: filter },
      backend,
      attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
    });
    // For channels-first data the operands are swapped and the filter is
    // transposed instead.
    const result = batchMatMulImpl({
      a: isChannelsLast ? xReshaped : filterReshaped,
      b: isChannelsLast ? filterReshaped : xReshaped,
      transposeA: !isChannelsLast,
      transposeB,
      backend,
      bias,
      activation,
      preluActivationWeights,
      leakyreluAlpha
    });
    out = reshape({ inputs: { x: result }, backend, attrs: { shape: convInfo.outShape } });
    intermediates.push(xReshaped);
    intermediates.push(filterReshaped);
    intermediates.push(result);
  }

  for (const i of intermediates) {
    backend.disposeIntermediateTensorInfo(i);
  }
  return out;
}

// Implements the im2row algorithm as outlined in "High Performance
// Convolutional Neural Networks for Document Processing" (Suvisoft, 2006)
/**
 * @param config Convolution arguments.
 * @param config.x Input tensor info, passed to the im2col program.
 * @param config.filter Filter tensor info; it is reshaped to
 *     `[1, sharedDim, filterSize / sharedDim]`.
 * @param config.convInfo Convolution description (filter size, channels,
 *     output size, padding, strides, dilations, data format, output shape).
 * @param config.backend Backend providing `runWebGLProgram`,
 *     `makeTensorInfo` and `disposeIntermediateTensorInfo`.
 * @param config.bias Optional bias, passed as an extra matmul input.
 * @param config.preluActivationWeights Optional PReLU weights, passed as an
 *     extra matmul input.
 * @param config.leakyreluAlpha Alpha uploaded as a scalar input when
 *     `activation` is `'leakyrelu'`.
 * @param config.activation Optional fused activation name.
 * @returns The convolution result, reshaped to `convInfo.outShape`.
 */
export function conv2dWithIm2Row({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
  // Rearranges conv2d input so each block to be convolved over forms the
  // column of a new matrix with shape [filterWidth * filterHeight *
  // inChannels, outHeight * outWidth]. The filter is also rearranged so each
  // output channel forms a row of a new matrix with shape [outChannels,
  // filterWidth * filterHeight * inChannels]. The convolution is then
  // computed by multiplying these matrices and reshaping the result.
  const { filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat } = convInfo;
  const isChannelsLast = dataFormat === 'channelsLast';
  const sharedDim = filterWidth * filterHeight * inChannels;
  const numCols = outHeight * outWidth;
  const x2ColShape = [convInfo.batchSize, sharedDim, numCols];
  const transposeA = true;
  const transposeB = false;
  // Temporaries created here; all of them are disposed before returning.
  const intermediates = [];

  if (preluActivationWeights != null) {
    const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
    if (targetShape != null) {
      preluActivationWeights = reshape({
        inputs: { x: preluActivationWeights },
        backend,
        attrs: { shape: targetShape }
      });
      intermediates.push(preluActivationWeights);
    }
  }

  if (bias != null) {
    const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
    if (targetShape != null) {
      bias = reshape({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
      intermediates.push(bias);
    }
  }

  const w2Row = reshape({
    inputs: { x: filter },
    backend,
    attrs: { shape: [1, sharedDim, util.sizeFromShape(filter.shape) / sharedDim] }
  });
  intermediates.push(w2Row);

  const im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);
  // Extra values handed to runWebGLProgram alongside the im2col program.
  const customValues = [
    x.shape,
    [convInfo.padInfo.top, convInfo.padInfo.left],
    [convInfo.strideHeight, convInfo.strideWidth],
    [convInfo.dilationHeight, convInfo.dilationWidth],
    [convInfo.inChannels],
    [convInfo.filterWidth * convInfo.inChannels],
    [convInfo.outWidth]
  ];
  const im2Col = backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues);
  const im2ColReshaped = reshape({ inputs: { x: im2Col }, backend, attrs: { shape: x2ColShape } });
  intermediates.push(im2Col);
  intermediates.push(im2ColReshaped);

  const hasBias = bias != null;
  const hasPreluActivationWeights = preluActivationWeights != null;
  const hasLeakyreluAlpha = activation === 'leakyrelu';
  const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
  const matmulProgram = new MatMulPackedProgram(
    isChannelsLast ? im2ColReshaped.shape : w2Row.shape,
    isChannelsLast ? w2Row.shape : im2ColReshaped.shape,
    isChannelsLast ?
      [convInfo.batchSize, numCols, convInfo.outChannels] :
      [convInfo.batchSize, convInfo.outChannels, numCols],
    transposeA, transposeB, hasBias, fusedActivation,
    hasPreluActivationWeights, hasLeakyreluAlpha);

  // Optional inputs are appended in the order bias, PReLU weights, alpha,
  // guarded by the same flags that were passed to the program above.
  const inputs = isChannelsLast ? [im2ColReshaped, w2Row] : [w2Row, im2ColReshaped];
  if (hasBias) {
    inputs.push(bias);
  }
  if (hasPreluActivationWeights) {
    inputs.push(preluActivationWeights);
  }
  if (hasLeakyreluAlpha) {
    const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', util.createScalarValue(leakyreluAlpha, 'float32'));
    inputs.push($leakyreluAlpha);
    intermediates.push($leakyreluAlpha);
  }

  const product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');
  const out = reshape({ inputs: { x: product }, backend, attrs: { shape: convInfo.outShape } });
  intermediates.push(product);

  for (const i of intermediates) {
    backend.disposeIntermediateTensorInfo(i);
  }
  return out;
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"Conv2D_impl.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/Conv2D_impl.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAA2B,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAKrE,OAAO,EAAC,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AACzD,OAAO,EAAC,4BAA4B,EAAC,MAAM,oCAAoC,CAAC;AAChF,OAAO,EAAC,mBAAmB,EAAC,MAAM,sBAAsB,CAAC;AACzD,OAAO,KAAK,UAAU,MAAM,eAAe,CAAC;AAE5C,OAAO,EAAC,eAAe,EAAE,2BAA2B,EAAC,MAAM,oBAAoB,CAAC;AAChF,OAAO,EAAC,QAAQ,EAAC,MAAM,YAAY,CAAC;AACpC,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAalC,0EAA0E;AAC1E,4EAA4E;AAC5E,gDAAgD;AAChD,EAAE;AACF,sEAAsE;AACtE,oEAAoE;AACpE,EAAE;AACF,uEAAuE;AACvE,0EAA0E;AAC1E,qEAAqE;AACrE,wBAAwB;AACxB,+HAA+H;AAC/H,SAAS,sBAAsB,CAC3B,KAAe,EAAE,cAAuB;IAC1C,MAAM,MAAM,GAAG,KAAK,CAAC,MAAM,CAAC;IAC5B,IAAI,MAAM,IAAI,CAAC,EAAE;QACf,OAAO,cAAc,CAAC,CAAC;YACnB;gBACE,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,WAAW;gBACjC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC,MAAM,GAAG,CAAC
,CAAC,CAAC,oBAAoB;gBAC1D,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,aAAa;aAChC,CAAC,CAAC;YACH;gBACE,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,WAAW,EAAE,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,aAAa;gBAClE,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,oBAAoB;aAC3D,CAAC;KACP;SAAM,IAAI,CAAC,cAAc,IAAI,MAAM,KAAK,CAAC,IAAI,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE;QAC1D,OAAO,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;KACtB;SAAM;QACL,OAAO,IAAI,CAAC;KACb;AACH,CAAC;AAED,6EAA6E;AAC7E,qEAAqE;AACrE,cAAc;AACd,MAAM,UAAU,cAAc,CAAC,EAC7B,CAAC,EACD,MAAM,EACN,QAAQ,EACR,OAAO,EACP,IAAI,GAAG,IAAI,EACX,sBAAsB,GAAG,IAAI,EAC7B,cAAc,GAAG,CAAC,EAClB,UAAU,GAAG,IAAI,EACJ;IACb,wEAAwE;IACxE,wBAAwB;IACxB,MAAM,MAAM,GAAG,CAAC,CAAC,KAAK,CAAC;IACvB,MAAM,QAAQ,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;IAC/C,MAAM,eAAe,GAAG,QAAQ,CAAC,UAAU,CAAC;IAC5C,MAAM,WAAW,GAAG,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;IACtD,MAAM,gBAAgB,GAAG,QAAQ,CAAC,WAAW,CAAC;IAC9C,MAAM,cAAc,GAAG,QAAQ,CAAC,UAAU,KAAK,cAAc,CAAC;IAC9D,MAAM,UAAU,GAAG,KAAK,CAAC;IACzB,MAAM,UAAU,GAAG,KAAK,CAAC;IAEzB,IAAI,GAAe,CAAC;IACpB,MAAM,aAAa,GAAiB,EAAE,CAAC;IAEvC,IAAI,sBAAsB,IAAI,IAAI,EAAE;QAClC,MAAM,WAAW,GACb,sBAAsB,CAAC,sBAAsB,CAAC,KAAK,EAAE,cAAc,CAAC,CAAC;QACzE,IAAI,WAAW,IAAI,IAAI,EAAE;YACvB,sBAAsB,GAAG,OAAO,CAAC;gBAC/B,MAAM,EAAE,EAAC,CAAC,EAAE,sBAAsB,EAAC;gBACnC,OAAO;gBACP,KAAK,EAAE,EAAC,KAAK,EAAE,WAAW,EAAC;aAC5B,CAAC,CAAC;YACH,aAAa,CAAC,IAAI,CAAC,sBAAsB,CAAC,CAAC;SAC5C;KACF;IAED,IAAI,IAAI,IAAI,IAAI,EAAE;QAChB,MAAM,WAAW,GAAG,sBAAsB,CAAC,IAAI,CAAC,KAAK,EAAE,cAAc,CAAC,CAAC;QACvE,IAAI,WAAW,IAAI,IAAI,EAAE;YACvB,IAAI,GAAG,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,IAAI,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,WAAW,EAAC,EAAC,CAAC,CAAC;YAC1E,aAAa,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SAC1B;KACF;IAED,yEAAyE;IACzE,oCAAoC;IACpC,MAAM,yBAAyB,GAC3B,CAAC,WAAW,KAAK,CAAC,IAAI,gBAAgB,KAAK,CAAC,CAAC;QAC7C,eAAe,GAAG,2BAA2B,CAAC;IAElD,2EAA2E;IAC3E,4EAA4E;IAC5E,4EAA4E;IAC5E,iCAAiC;IACjC,MAAM,WAAW,GAAG,CAAC,
yBAAyB,IAAI,QAAQ,CAAC,QAAQ;QAC/D,cAAc,IAAI,QAAQ,CAAC,OAAO,IAAI,IAAI,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC;QACjE,IAAI,CAAC,WAAW,CAAC,QAAQ,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAEjE,IAAI,WAAW,EAAE;QACf,sEAAsE;QACtE,oEAAoE;QACpE,0EAA0E;QAC1E,oEAAoE;QACpE,oEAAoE;QACpE,uDAAuD;QACvD,MAAM,WAAW,GAAG,MAAM,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;QAC5D,MAAM,SAAS,GAAe;YAC5B,MAAM,EAAE,CAAC,CAAC,MAAM;YAChB,KAAK,EAAE,CAAC,CAAC,EAAE,WAAW,EAAE,QAAQ,CAAC,UAAU,CAAC;YAC5C,KAAK,EAAE,CAAC,CAAC,KAAK;SACf,CAAC;QACF,gEAAgE;QAChE,0EAA0E;QAC1E,sEAAsE;QACtE,yEAAyE;QACzE,qEAAqE;QACrE,uEAAuE;QACvE,wEAAwE;QACxE,8BAA8B;QAC9B,MAAM,qBAAqB,GAAG,QAAQ,CAAC,KAAK,CAAC;QAC7C,QAAQ,CAAC,KAAK,GAAG,QAAQ,CAAC,KAAK,CAAC,KAAK,EAAE,CAAC;QACxC,QAAQ,CAAC,KAAK,CAAC,QAAQ,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,EAAE,CAAC;QAC5C,IAAI,CAAC,MAAM,CACP,UAAU,CAAC,aAAa,CAAC,QAAQ,CAAC,KAAK,EAAE,SAAS,CAAC,KAAK,CAAC,EACzD,GAAG,EAAE,CAAC,kBAAkB,QAAQ,CAAC,KAAK,OAClC,SAAS,CAAC,KAAK,aAAa,CAAC,CAAC;QACtC,MAAM,cAAc,GAAG,OAAO,CAAC;YAC7B,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC;YACnB,OAAO;YACP,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,CAAC,EAAE,QAAQ,CAAC,UAAU,EAAE,QAAQ,CAAC,WAAW,CAAC,EAAC;SAC/D,CAAC,CAAC;QACH,aAAa,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;QACnC,MAAM,aAAa,GAAG,eAAe,CAAC;YACpC,CAAC,EAAE,SAAS;YACZ,CAAC,EAAE,cAAc;YACjB,OAAO;YACP,UAAU;YACV,UAAU;YACV,IAAI;YACJ,UAAU;YACV,sBAAsB;YACtB,cAAc;SACf,CAAC,CAAC;QAEH,MAAM,oBAAoB,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;QACvE,IAAI,CAAC,MAAM,CACP,oBAAoB,CAAC,QAAQ,EAC7B,GAAG,EAAE,CAAC,6CAA6C,CAAC,CAAC;QACzD,uCAAuC;QACvC,QAAQ,CAAC,KAAK,GAAG,qBAAqB,CAAC;QACvC,wEAAwE;QACxE,6BAA6B;QAC7B,oBAAoB,CAAC,KAAK,GAAG,QAAQ,CAAC,QAAQ,CAAC;QAE/C,GAAG,GAAG,QAAQ,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,aAAa,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;QACtD,GAAG,CAAC,KAAK,GAAG,QAAQ,CAAC,QAAQ,CAAC;QAE9B,aAAa,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC;KACnC;SAAM;QACL,MAAM,OAAO,GAAG,QAAQ,CAAC,SAAS,GAAG,QAAQ,CAAC,QAAQ,CAAC;QACvD,MAAM,SAAS,GAAG,OAAO,CA
AC;YACxB,MAAM,EAAE,EAAC,CAAC,EAAC;YACX,OAAO;YACP,KAAK,EAAE;gBACL,KAAK,EAAE,cAAc,CAAC,CAAC;oBACnB,CAAC,QAAQ,CAAC,SAAS,EAAE,OAAO,EAAE,QAAQ,CAAC,UAAU,CAAC,CAAC,CAAC;oBACpD,CAAC,QAAQ,CAAC,SAAS,EAAE,QAAQ,CAAC,UAAU,EAAE,OAAO,CAAC;aACvD;SACF,CAAC,CAAC;QACH,MAAM,cAAc,GAAG,OAAO,CAAC;YAC7B,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC;YACnB,OAAO;YACP,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,CAAC,EAAE,QAAQ,CAAC,UAAU,EAAE,QAAQ,CAAC,WAAW,CAAC,EAAC;SAC/D,CAAC,CAAC;QACH,MAAM,MAAM,GAAG,eAAe,CAAC;YAC7B,CAAC,EAAE,cAAc,CAAC,CAAC,CAAC,SAAS,CAAC,CAAC,CAAC,cAAc;YAC9C,CAAC,EAAE,cAAc,CAAC,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,SAAS;YAC9C,UAAU,EAAE,CAAC,cAAc;YAC3B,UAAU;YACV,OAAO;YACP,IAAI;YACJ,UAAU;YACV,sBAAsB;YACtB,cAAc;SACf,CAAC,CAAC;QAEH,GAAG,GAAG,OAAO,CACT,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,QAAQ,CAAC,QAAQ,EAAC,EAAC,CAAC,CAAC;QAEvE,aAAa,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;QAC9B,aAAa,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;QACnC,aAAa,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;KAC5B;IAED,KAAK,MAAM,CAAC,IAAI,aAAa,EAAE;QAC7B,OAAO,CAAC,6BAA6B,CAAC,CAAC,CAAC,CAAC;KAC1C;IAED,OAAO,GAAG,CAAC;AACb,CAAC;AAED,mEAAmE;AACnE,0EAA0E;AAC1E,MAAM,UAAU,gBAAgB,CAAC,EAC/B,CAAC,EACD,MAAM,EACN,QAAQ,EACR,OAAO,EACP,IAAI,GAAG,IAAI,EACX,sBAAsB,GAAG,IAAI,EAC7B,cAAc,GAAG,CAAC,EAClB,UAAU,GAAG,IAAI,EACJ;IACb,uEAAuE;IACvE,kEAAkE;IAClE,2EAA2E;IAC3E,sEAAsE;IACtE,oEAAoE;IACpE,mEAAmE;IACnE,MAAM,EACJ,WAAW,EACX,YAAY,EACZ,UAAU,EACV,QAAQ,EACR,SAAS,EACT,UAAU,EACX,GAAG,QAAQ,CAAC;IAEb,MAAM,cAAc,GAAG,UAAU,KAAK,cAAc,CAAC;IAErD,MAAM,SAAS,GAAG,WAAW,GAAG,YAAY,GAAG,UAAU,CAAC;IAC1D,MAAM,OAAO,GAAG,SAAS,GAAG,QAAQ,CAAC;IACrC,MAAM,UAAU,GAAG,CAAC,QAAQ,CAAC,SAAS,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC;IAC5D,MAAM,UAAU,GAAG,IAAI,CAAC;IACxB,MAAM,UAAU,GAAG,KAAK,CAAC;IAEzB,MAAM,aAAa,GAAiB,EAAE,CAAC;IAEvC,IAAI,sBAAsB,IAAI,IAAI,EAAE;QAClC,MAAM,WAAW,GACb,sBAAsB,CAAC,sBAAsB,CAAC,KAAK,EAAE,cAAc,CAAC,CAAC;QACzE,IAAI,WAAW,IAAI,IAAI,EAAE;YACvB,sBAAsB,GAAG,OAAO,CAAC;gBAC/B,MAAM,EAAE,EAAC,CAAC,EAAE,sBAAsB,EAAC;gBACnC,OAAO;gBACP,KAAK,EAAE,EAAC,KAAK,EAAE,WAAW,EAAC;aAC5B,CAAC,CAAC;YACH
,aAAa,CAAC,IAAI,CAAC,sBAAsB,CAAC,CAAC;SAC5C;KACF;IAED,IAAI,IAAI,IAAI,IAAI,EAAE;QAChB,MAAM,WAAW,GAAG,sBAAsB,CAAC,IAAI,CAAC,KAAK,EAAE,cAAc,CAAC,CAAC;QACvE,IAAI,WAAW,IAAI,IAAI,EAAE;YACvB,IAAI,GAAG,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,IAAI,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,WAAW,EAAC,EAAC,CAAC,CAAC;YAC1E,aAAa,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;SAC1B;KACF;IAED,MAAM,KAAK,GAAG,OAAO,CAAC;QACpB,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC;QACnB,OAAO;QACP,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,CAAC,EAAE,SAAS,EAAE,IAAI,CAAC,aAAa,CAAC,MAAM,CAAC,KAAK,CAAC,GAAG,SAAS,CAAC,EAAC;KAC7E,CAAC,CAAC;IACH,aAAa,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;IAE1B,MAAM,aAAa,GAAG,IAAI,mBAAmB,CAAC,UAAU,EAAE,QAAQ,CAAC,CAAC;IACpE,MAAM,YAAY,GAAG;QACnB,CAAC,CAAC,KAAK,EAAE,CAAC,QAAQ,CAAC,OAAO,CAAC,GAAG,EAAE,QAAQ,CAAC,OAAO,CAAC,IAAI,CAAC;QACtD,CAAC,QAAQ,CAAC,YAAY,EAAE,QAAQ,CAAC,WAAW,CAAC;QAC7C,CAAC,QAAQ,CAAC,cAAc,EAAE,QAAQ,CAAC,aAAa,CAAC,EAAE,CAAC,QAAQ,CAAC,UAAU,CAAC;QACxE,CAAC,QAAQ,CAAC,WAAW,GAAG,QAAQ,CAAC,UAAU,CAAC,EAAE,CAAC,QAAQ,CAAC,QAAQ,CAAC;KAClE,CAAC;IACF,MAAM,MAAM,GACR,OAAO,CAAC,eAAe,CAAC,aAAa,EAAE,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,YAAY,CAAC,CAAC;IACzE,MAAM,cAAc,GAChB,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,UAAU,EAAC,EAAC,CAAC,CAAC;IAExE,aAAa,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;IAC3B,aAAa,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;IAEnC,MAAM,OAAO,GAAG,IAAI,IAAI,IAAI,CAAC;IAC7B,MAAM,yBAAyB,GAAG,sBAAsB,IAAI,IAAI,CAAC;IACjE,MAAM,iBAAiB,GAAG,UAAU,KAAK,WAAW,CAAC;IACrD,MAAM,eAAe,GACjB,UAAU,CAAC,CAAC,CAAC,4BAA4B,CAAC,UAAU,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;IACvE,MAAM,aAAa,GAAG,IAAI,mBAAmB,CACzC,cAAc,CAAC,CAAC,CAAC,cAAc,CAAC,KAAiC,CAAC,CAAC;QAClD,KAAK,CAAC,KAAiC,EACxD,cAAc,CAAC,CAAC,CAAC,KAAK,CAAC,KAAiC,CAAC,CAAC;QACzC,cAAc,CAAC,KAAiC,EACjE,cAAc,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC,SAAS,EAAE,OAAO,EAAE,QAAQ,CAAC,WAAW,CAAC,CAAC,CAAC;QACrD,CAAC,QAAQ,CAAC,SAAS,EAAE,QAAQ,CAAC,WAAW,EAAE,OAAO,CAAC,EACpE,UAAU,EAAE,UAAU,EAAE,OAAO,EAAE,eAAe,EAChD,yBAAyB,EAAE,iBAAiB,CAAC,CAAC;IAClD,MAAM,MAAM,GACR,cAAc,CAAC,CAAC,
CAAC,CAAC,cAAc,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,EAAE,cAAc,CAAC,CAAC;IACvE,IAAI,IAAI,EAAE;QACR,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;KACnB;IACD,IAAI,yBAAyB,EAAE;QAC7B,MAAM,CAAC,IAAI,CAAC,sBAAsB,CAAC,CAAC;KACrC;IACD,IAAI,iBAAiB,EAAE;QACrB,MAAM,eAAe,GAAG,OAAO,CAAC,cAAc,CAC1C,EAAE,EAAE,SAAS,EACb,IAAI,CAAC,iBAAiB,CAAC,cAAsC,EACtC,SAAS,CAAC,CAAC,CAAC;QACvC,MAAM,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;QAC7B,aAAa,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;KACrC;IACD,MAAM,OAAO,GAAG,OAAO,CAAC,eAAe,CAAC,aAAa,EAAE,MAAM,EAAE,SAAS,CAAC,CAAC;IAC1E,MAAM,GAAG,GAAG,OAAO,CACf,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,OAAO,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,QAAQ,CAAC,QAAQ,EAAC,EAAC,CAAC,CAAC;IAExE,aAAa,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;IAC5B,KAAK,MAAM,CAAC,IAAI,aAAa,EAAE;QAC7B,OAAO,CAAC,6BAA6B,CAAC,CAAC,CAAC,CAAC;KAC1C;IAED,OAAO,GAAG,CAAC;AACb,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';\n\n// import {assertAndGetBroadcastShape} from\n// '../../../tfjs-core/src/ops/broadcast_util';\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {Im2ColPackedProgram} from '../im2col_packed_gpu';\nimport {mapActivationToShaderProgram} from '../kernel_utils/kernel_funcs_utils';\nimport {MatMulPackedProgram} from '../mulmat_packed_gpu';\nimport * as 
webgl_util from '../webgl_util';\n\nimport {batchMatMulImpl, MATMUL_SHARED_DIM_THRESHOLD} from './BatchMatMul_impl';\nimport {identity} from './Identity';\nimport {reshape} from './Reshape';\n\ntype Conv2DConfig = {\n  x: TensorInfo,\n  filter: TensorInfo,\n  convInfo: backend_util.Conv2DInfo,\n  backend: MathBackendWebGL,\n  bias?: TensorInfo,\n  preluActivationWeights?: TensorInfo,\n  leakyreluAlpha?: number,\n  activation?: backend_util.Activation\n};\n\n// Both conv2dByMatMul and conv2dWithIm2Row fuse height and width into one\n// dimension to compute batchMatMul, so bias and activation weights are also\n// supposed to fuse the two dimensions into one.\n//\n// This function computes the target shape for fusing height and width\n// dimensions. Returning null means the shape is already compatible.\n//\n// Even though the bias is not supposed to be a 3-D or a 4-D (including\n// batch) tensor and PReLU activiation weights is not supposed to be a 4-D\n// tensor, we still need to support them, because we haven't disabled\n// them for NHWC format.\n// https://github.com/tensorflow/tfjs/blob/b53bd47e880367ae57493f0ea628abaf08db2d5d/tfjs-core/src/ops/fused/conv2d.ts#L181-L196\nfunction getShapeForBatchMatMul(\n    shape: number[], isChannelsLast: boolean): number[] {\n  const length = shape.length;\n  if (length >= 3) {\n    return isChannelsLast ?\n        [\n          ...shape.slice(0, -3) /* batch */,\n          shape[length - 3] * shape[length - 2] /* height * width */,\n          shape[length - 1] /* channel */\n        ] :\n        [\n          ...shape.slice(0, -3) /* batch */, shape[length - 3] /* channel */,\n          shape[length - 2] * shape[length - 1] /* height * width */\n        ];\n  } else if (!isChannelsLast && length === 1 && shape[0] > 1) {\n    return [shape[0], 1];\n  } else {\n    return null;\n  }\n}\n\n// For 1x1 kernels that iterate through every point in the input, convolution\n// can be expressed as matrix multiplication (without need for 
memory\n// remapping).\nexport function conv2dByMatMul({\n  x,\n  filter,\n  convInfo,\n  backend,\n  bias = null,\n  preluActivationWeights = null,\n  leakyreluAlpha = 0,\n  activation = null\n}: Conv2DConfig) {\n  // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the\n  // result from 2D to 4D.\n  const xShape = x.shape;\n  const xTexData = backend.texData.get(x.dataId);\n  const sharedMatMulDim = convInfo.inChannels;\n  const outerShapeX = xShape[0] * xShape[1] * xShape[2];\n  const outerShapeFilter = convInfo.outChannels;\n  const isChannelsLast = convInfo.dataFormat === 'channelsLast';\n  const transposeA = false;\n  const transposeB = false;\n\n  let out: TensorInfo;\n  const intermediates: TensorInfo[] = [];\n\n  if (preluActivationWeights != null) {\n    const targetShape =\n        getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);\n    if (targetShape != null) {\n      preluActivationWeights = reshape({\n        inputs: {x: preluActivationWeights},\n        backend,\n        attrs: {shape: targetShape}\n      });\n      intermediates.push(preluActivationWeights);\n    }\n  }\n\n  if (bias != null) {\n    const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);\n    if (targetShape != null) {\n      bias = reshape({inputs: {x: bias}, backend, attrs: {shape: targetShape}});\n      intermediates.push(bias);\n    }\n  }\n\n  // TODO: Once reduction ops are packed, batchMatMul will always be packed\n  // and we can remove this condition.\n  const batchMatMulWillBeUnpacked =\n      (outerShapeX === 1 || outerShapeFilter === 1) &&\n      sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;\n\n  // The algorithm in the if condition assumes (1) the output will be packed,\n  // (2) x is packed, (3) x isChannelsLast, (4)  x's packed texture is already\n  // on GPU, (5) col is odd, (6) the width, height and inChannels are the same\n  // for xTexData.shape and xShape.\n  const canOptimize = !batchMatMulWillBeUnpacked && 
xTexData.isPacked &&\n      isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 &&\n      util.arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));\n\n  if (canOptimize) {\n    // We avoid expensive packed 2x2 reshape by padding col count to next,\n    // even number. When col is odd, the result of packed batchMatMul is\n    // the same (has the same texture layout and and values in the texture) as\n    // it is for next even col. We make the odd-cols tensor to look like\n    // even-cols tensor before the operation and, after the batchMatMul,\n    // fix the even-cols result to have odd number of cols.\n    const targetShape = xShape[0] * xShape[1] * (xShape[2] + 1);\n    const xReshaped: TensorInfo = {\n      dataId: x.dataId,\n      shape: [1, targetShape, convInfo.inChannels],\n      dtype: x.dtype\n    };\n    // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.\n    // Decrementing col count, after batchMatMul->...->compileProgram leads to\n    // invalid col count within the reference in GPGPUBinary.inShapeInfos.\n    // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos\n    // in compileProgram method, but that would affect compilation of all\n    // programs - instead, provide a copy here, with even col count, before\n    // calling batchMatMul->...->compileProgram and after that, the original\n    // xTexData.shape is restored.\n    const originalXTexDataShape = xTexData.shape;\n    xTexData.shape = xTexData.shape.slice();\n    xTexData.shape[xTexData.shape.length - 2]++;\n    util.assert(\n        webgl_util.isReshapeFree(xTexData.shape, xReshaped.shape),\n        () => `packed reshape ${xTexData.shape} to ${\n            xReshaped.shape} isn't free`);\n    const filterReshaped = reshape({\n      inputs: {x: filter},\n      backend,\n      attrs: {shape: [1, convInfo.inChannels, convInfo.outChannels]}\n    });\n    intermediates.push(filterReshaped);\n    const pointwiseConv = batchMatMulImpl({\n      a: 
xReshaped,\n      b: filterReshaped,\n      backend,\n      transposeA,\n      transposeB,\n      bias,\n      activation,\n      preluActivationWeights,\n      leakyreluAlpha\n    });\n\n    const pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);\n    util.assert(\n        pointwiseConvTexData.isPacked,\n        () => 'batchMatMul result is expected to be packed');\n    // Restore the input shape to original.\n    xTexData.shape = originalXTexDataShape;\n    // Set the output shape - there is no need for expensive reshape as data\n    // layout is already correct.\n    pointwiseConvTexData.shape = convInfo.outShape;\n\n    out = identity({inputs: {x: pointwiseConv}, backend});\n    out.shape = convInfo.outShape;\n\n    intermediates.push(pointwiseConv);\n  } else {\n    const numCols = convInfo.outHeight * convInfo.outWidth;\n    const xReshaped = reshape({\n      inputs: {x},\n      backend,\n      attrs: {\n        shape: isChannelsLast ?\n            [convInfo.batchSize, numCols, convInfo.inChannels] :\n            [convInfo.batchSize, convInfo.inChannels, numCols]\n      }\n    });\n    const filterReshaped = reshape({\n      inputs: {x: filter},\n      backend,\n      attrs: {shape: [1, convInfo.inChannels, convInfo.outChannels]}\n    });\n    const result = batchMatMulImpl({\n      a: isChannelsLast ? xReshaped : filterReshaped,\n      b: isChannelsLast ? 
filterReshaped : xReshaped,\n      transposeA: !isChannelsLast,\n      transposeB,\n      backend,\n      bias,\n      activation,\n      preluActivationWeights,\n      leakyreluAlpha\n    });\n\n    out = reshape(\n        {inputs: {x: result}, backend, attrs: {shape: convInfo.outShape}});\n\n    intermediates.push(xReshaped);\n    intermediates.push(filterReshaped);\n    intermediates.push(result);\n  }\n\n  for (const i of intermediates) {\n    backend.disposeIntermediateTensorInfo(i);\n  }\n\n  return out;\n}\n\n// Implements the im2row algorithm as outlined in \"High Performance\n// Convolutional Neural Networks for Document Processing\" (Suvisoft, 2006)\nexport function conv2dWithIm2Row({\n  x,\n  filter,\n  convInfo,\n  backend,\n  bias = null,\n  preluActivationWeights = null,\n  leakyreluAlpha = 0,\n  activation = null\n}: Conv2DConfig) {\n  // Rearranges conv2d input so each block to be convolved over forms the\n  // column of a new matrix with shape [filterWidth * filterHeight *\n  // inChannels, outHeight * outWidth]. The filter is also rearranged so each\n  // output channel forms a row of a new matrix with shape [outChannels,\n  // filterWidth * filterHeight * inChannels]. 
The convolution is then\n  // computed by multiplying these matrices and reshaping the result.\n  const {\n    filterWidth,\n    filterHeight,\n    inChannels,\n    outWidth,\n    outHeight,\n    dataFormat\n  } = convInfo;\n\n  const isChannelsLast = dataFormat === 'channelsLast';\n\n  const sharedDim = filterWidth * filterHeight * inChannels;\n  const numCols = outHeight * outWidth;\n  const x2ColShape = [convInfo.batchSize, sharedDim, numCols];\n  const transposeA = true;\n  const transposeB = false;\n\n  const intermediates: TensorInfo[] = [];\n\n  if (preluActivationWeights != null) {\n    const targetShape =\n        getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);\n    if (targetShape != null) {\n      preluActivationWeights = reshape({\n        inputs: {x: preluActivationWeights},\n        backend,\n        attrs: {shape: targetShape}\n      });\n      intermediates.push(preluActivationWeights);\n    }\n  }\n\n  if (bias != null) {\n    const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);\n    if (targetShape != null) {\n      bias = reshape({inputs: {x: bias}, backend, attrs: {shape: targetShape}});\n      intermediates.push(bias);\n    }\n  }\n\n  const w2Row = reshape({\n    inputs: {x: filter},\n    backend,\n    attrs: {shape: [1, sharedDim, util.sizeFromShape(filter.shape) / sharedDim]}\n  });\n  intermediates.push(w2Row);\n\n  const im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);\n  const customValues = [\n    x.shape, [convInfo.padInfo.top, convInfo.padInfo.left],\n    [convInfo.strideHeight, convInfo.strideWidth],\n    [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels],\n    [convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth]\n  ];\n  const im2Col =\n      backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues);\n  const im2ColReshaped =\n      reshape({inputs: {x: im2Col}, backend, attrs: {shape: x2ColShape}});\n\n  intermediates.push(im2Col);\n 
 intermediates.push(im2ColReshaped);\n\n  const hasBias = bias != null;\n  const hasPreluActivationWeights = preluActivationWeights != null;\n  const hasLeakyreluAlpha = activation === 'leakyrelu';\n  const fusedActivation =\n      activation ? mapActivationToShaderProgram(activation, true) : null;\n  const matmulProgram = new MatMulPackedProgram(\n      isChannelsLast ? im2ColReshaped.shape as [number, number, number] :\n                       w2Row.shape as [number, number, number],\n      isChannelsLast ? w2Row.shape as [number, number, number] :\n                       im2ColReshaped.shape as [number, number, number],\n      isChannelsLast ? [convInfo.batchSize, numCols, convInfo.outChannels] :\n                       [convInfo.batchSize, convInfo.outChannels, numCols],\n      transposeA, transposeB, hasBias, fusedActivation,\n      hasPreluActivationWeights, hasLeakyreluAlpha);\n  const inputs: TensorInfo[] =\n      isChannelsLast ? [im2ColReshaped, w2Row] : [w2Row, im2ColReshaped];\n  if (bias) {\n    inputs.push(bias);\n  }\n  if (hasPreluActivationWeights) {\n    inputs.push(preluActivationWeights);\n  }\n  if (hasLeakyreluAlpha) {\n    const $leakyreluAlpha = backend.makeTensorInfo(\n        [], 'float32',\n        util.createScalarValue(leakyreluAlpha as unknown as 'float32',\n                               'float32'));\n    inputs.push($leakyreluAlpha);\n    intermediates.push($leakyreluAlpha);\n  }\n  const product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');\n  const out = reshape(\n      {inputs: {x: product}, backend, attrs: {shape: convInfo.outShape}});\n\n  intermediates.push(product);\n  for (const i of intermediates) {\n    backend.disposeIntermediateTensorInfo(i);\n  }\n\n  return out;\n}\n"]}