/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../../engine';
import { customGrad } from '../../gradients';
import { _FusedMatMul } from '../../kernel_names';
import { makeTypesMatch } from '../../tensor_util';
import { convertToTensor } from '../../tensor_util_env';
import * as util from '../../util';
import { add } from '../add';
import * as broadcast_util from '../broadcast_util';
import { applyActivation, getFusedBiasGradient, getFusedDyActivation, shouldFuse } from '../fused_util';
import { matMul as unfusedMatMul } from '../mat_mul';
import { op } from '../operation';
import { reshape } from '../reshape';
/**
 * Matrix product of `a` and `b` with an optional bias add and activation
 * applied in the same op.
 *
 * ```js
 * const a = tf.tensor2d([-1, -2], [1, 2]);
 * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const bias = tf.tensor2d([1, 2], [1, 2]);
 *
 * tf.fused.matMul({a, b, bias, activation: 'relu'}).print();
 * ```
 *
 * When `shouldFuse` reports that fusing is not possible for the current
 * gradient depth / activation, the result is composed from the unfused
 * `matMul`, `add` and `applyActivation` ops instead. Otherwise both operands
 * are reshaped to rank 3 (leading dims flattened into one batch dim), the
 * `_FusedMatMul` kernel is run, and the result is reshaped to the broadcast
 * batch dims followed by the two outer dims.
 *
 * @param obj An object with the following properties:
 * - `a` First matrix in dot product operation.
 * - `b` Second matrix in dot product operation.
 * - `transposeA` If true, `a` is transposed before multiplication.
 * - `transposeB` If true, `b` is transposed before multiplication.
 * - `bias` Matrix to be added to the result.
 * - `activation` Name of activation kernel (defaults to `linear`).
 * - `preluActivationWeights` Tensor of prelu weights.
 * - `leakyreluAlpha` Alpha of leakyrelu (defaults to 0.2).
 */
function fusedMatMul_({ a, b, transposeA = false, transposeB = false, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha = 0.2, }) {
    // Fallback: build the result out of the individual (unfused) ops.
    const canFuse = shouldFuse(ENGINE.state.gradientDepth, activation);
    if (canFuse === false) {
        const product = unfusedMatMul(a, b, transposeA, transposeB);
        const preActivation = bias != null ? add(product, bias) : product;
        return applyActivation(preActivation, activation, preluActivationWeights, leakyreluAlpha);
    }

    const [$a, $b] = makeTypesMatch(
        convertToTensor(a, 'a', 'fused matMul'),
        convertToTensor(b, 'b', 'fused matMul'));

    // Last two dims of each operand; which one is "inner" (contracted) and
    // which is "outer" (kept) depends on the transpose flags.
    const secondToLastA = $a.shape[$a.rank - 2];
    const lastA = $a.shape[$a.rank - 1];
    const secondToLastB = $b.shape[$b.rank - 2];
    const lastB = $b.shape[$b.rank - 1];
    const innerShapeA = transposeA ? secondToLastA : lastA;
    const innerShapeB = transposeB ? lastB : secondToLastB;
    const outerShapeA = transposeA ? lastA : secondToLastA;
    const outerShapeB = transposeB ? secondToLastB : lastB;

    // Everything in front of the last two dims is treated as batch dims.
    const batchShapeA = $a.shape.slice(0, -2);
    const batchShapeB = $b.shape.slice(0, -2);
    const batchDimA = util.sizeFromShape(batchShapeA);
    const batchDimB = util.sizeFromShape(batchShapeB);

    util.assert(innerShapeA === innerShapeB, () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
        `${$b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);

    // Final output shape: broadcast batch dims followed by the outer dims.
    const outShape = broadcast_util
        .assertAndGetBroadcastShape(batchShapeA, batchShapeB)
        .concat([outerShapeA, outerShapeB]);

    // The kernel is fed rank-3 operands.
    const aShape3D = transposeA ?
        [batchDimA, innerShapeA, outerShapeA] :
        [batchDimA, outerShapeA, innerShapeA];
    const a3D = reshape($a, aShape3D);
    const bShape3D = transposeB ?
        [batchDimB, outerShapeB, innerShapeB] :
        [batchDimB, innerShapeB, outerShapeB];
    const b3D = reshape($b, bShape3D);

    let $bias;
    if (bias != null) {
        const biasTensor = convertToTensor(bias, 'bias', 'fused matMul');
        [$bias] = makeTypesMatch(biasTensor, $a);
        // Called only for its validation of bias against the output shape;
        // the returned shape is not used.
        broadcast_util.assertAndGetBroadcastShape(outShape, $bias.shape);
    }

    const $preluActivationWeights = preluActivationWeights != null ?
        convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul') :
        undefined;

    // Gradient of the fused op with respect to a, b and (if given) bias.
    const grad = (dy, saved) => {
        const [savedA, savedB, y, savedBias] = saved;
        // dy has the shape of the reshaped op output, so bring it back to the
        // shape of the raw kernel result before applying the activation
        // gradient.
        const dyActivation = getFusedDyActivation(reshape(dy, y.shape), y, activation);

        let aDer;
        let bDer;
        if (transposeA) {
            if (transposeB) {
                aDer = unfusedMatMul(savedB, dyActivation, true, true);
                bDer = unfusedMatMul(dyActivation, savedA, true, true);
            }
            else {
                aDer = unfusedMatMul(savedB, dyActivation, false, true);
                bDer = unfusedMatMul(savedA, dyActivation, false, false);
            }
        }
        else if (transposeB) {
            aDer = unfusedMatMul(dyActivation, savedB, false, false);
            bDer = unfusedMatMul(dyActivation, savedA, true, false);
        }
        else {
            aDer = unfusedMatMul(dyActivation, savedB, false, true);
            bDer = unfusedMatMul(savedA, dyActivation, true, false);
        }

        if (bias == null) {
            return [aDer, bDer];
        }
        return [aDer, bDer, getFusedBiasGradient(savedBias, dyActivation)];
    };

    const inputs = {
        a: a3D,
        b: b3D,
        bias: $bias,
        preluActivationWeights: $preluActivationWeights
    };
    const attrs = { transposeA, transposeB, activation, leakyreluAlpha };

    // Shared forward pass: run the kernel, save what the gradient needs
    // (operands, raw kernel result, then any extra tensors) and return the
    // result reshaped to the final output shape.
    const forward = (save, lhs, rhs, extraSaved) => {
        const res = ENGINE.runKernel(_FusedMatMul, inputs, attrs);
        save([lhs, rhs, res, ...extraSaved]);
        return { value: reshape(res, outShape), gradFunc: grad };
    };

    // The number of op inputs, and therefore the number of gradients
    // returned, depends on whether a bias was supplied.
    if (bias == null) {
        const customOp = customGrad((lhs, rhs, save) => forward(save, lhs, rhs, []));
        return customOp(a3D, b3D);
    }
    const customOpWithBias = customGrad((lhs, rhs, biasArg, save) => forward(save, lhs, rhs, [biasArg]));
    return customOpWithBias(a3D, b3D, $bias);
}
export const matMul = /* @__PURE__ */ op({ fusedMatMul_ });
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"mat_mul.js","sourceRoot":"","sources":["../../../../../../../tfjs-core/src/ops/fused/mat_mul.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAC,MAAM,cAAc,CAAC;AACpC,OAAO,EAAC,UAAU,EAAC,MAAM,iBAAiB,CAAC;AAC3C,OAAO,EAAC,YAAY,EAAwC,MAAM,oBAAoB,CAAC;AAIvF,OAAO,EAAC,cAAc,EAAC,MAAM,mBAAmB,CAAC;AACjD,OAAO,EAAC,eAAe,EAAC,MAAM,uBAAuB,CAAC;AAEtD,OAAO,KAAK,IAAI,MAAM,YAAY,CAAC;AAEnC,OAAO,EAAC,GAAG,EAAC,MAAM,QAAQ,CAAC;AAC3B,OAAO,KAAK,cAAc,MAAM,mBAAmB,CAAC;AAEpD,OAAO,EAAC,eAAe,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,UAAU,EAAC,MAAM,eAAe,CAAC;AACtG,OAAO,EAAC,MAAM,IAAI,aAAa,EAAC,MAAM,YAAY,CAAC;AACnD,OAAO,EAAC,EAAE,EAAC,MAAM,cAAc,CAAC;AAChC,OAAO,EAAC,OAAO,EAAC,MAAM,YAAY,CAAC;AAEnC;;;;;;;;;;;;;;;;;;;;GAoBG;AACH,SAAS,YAAY,CAAC,EACpB,CAAC,EACD,CAAC,EACD,UAAU,GAAG,KAAK,EAClB,UAAU,GAAG,KAAK,EAClB,IAAI,EACJ,UAAU,GAAG,QAAQ,EACrB,sBAAsB,EACtB,cAAc,GAAG,GAAG,GAUrB;IACG,IAAI,UAAU,CAAC,MAAM,CAAC,KAAK,CAAC,aAAa,EAAE,UAAU,CAAC,KAAK,KAAK,EAAE;QAChE,IAAI,MAAM,GAAG,aAAa,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,UAAU,CAAC,CAAC;QACzD,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,MAAM,GAAG,GAAG,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC;SAC5B;QAED,OAAO,eAAe,CACX,MAAM,EAAE,UAAU,EAAE,sBAAsB,EAAE,cAAc,CAAC,CAAC;KACxE;IAED,IAAI,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,GAAG,EAAE,cAAc,CAAC,CA
AC;IACjD,IAAI,EAAE,GAAG,eAAe,CAAC,CAAC,EAAE,GAAG,EAAE,cAAc,CAAC,CAAC;IACjD,CAAC,EAAE,EAAE,EAAE,CAAC,GAAG,cAAc,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC;IAElC,MAAM,WAAW,GACb,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC;IAC/D,MAAM,WAAW,GACb,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC;IAE/D,MAAM,WAAW,GACb,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC;IAC/D,MAAM,WAAW,GACb,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC;IAE/D,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzC,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzC,MAAM,SAAS,GAAG,IAAI,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;IACjD,MAAM,SAAS,GAAG,IAAI,CAAC,aAAa,CAAC,UAAU,CAAC,CAAC;IAEjD,IAAI,CAAC,MAAM,CACP,WAAW,KAAK,WAAW,EAC3B,GAAG,EAAE,CAAC,wCAAwC,WAAW,SAAS;QAC9D,GAAG,WAAW,4BAA4B,EAAE,CAAC,KAAK,OAAO;QACzD,GAAG,EAAE,CAAC,KAAK,mBAAmB,UAAU,EAAE;QAC1C,mBAAmB,UAAU,cAAc,CAAC,CAAC;IAErD,MAAM,iBAAiB,GAAG,cAAc,CAAC,0BAA0B,CAC/D,EAAE,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IAClD,MAAM,QAAQ,GAAG,iBAAiB,CAAC,MAAM,CAAC,CAAC,WAAW,EAAE,WAAW,CAAC,CAAC,CAAC;IAEtE,MAAM,GAAG,GAAa,UAAU,CAAC,CAAC;QAC9B,OAAO,CAAC,EAAE,EAAE,CAAC,SAAS,EAAE,WAAW,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC;QACpD,OAAO,CAAC,EAAE,EAAE,CAAC,SAAS,EAAE,WAAW,EAAE,WAAW,CAAC,CAAC,CAAC;IACvD,MAAM,GAAG,GAAa,UAAU,CAAC,CAAC;QAC9B,OAAO,CAAC,EAAE,EAAE,CAAC,SAAS,EAAE,WAAW,EAAE,WAAW,CAAC,CAAC,CAAC,CAAC;QACpD,OAAO,CAAC,EAAE,EAAE,CAAC,SAAS,EAAE,WAAW,EAAE,WAAW,CAAC,CAAC,CAAC;IAEvD,IAAI,KAAa,CAAC;IAClB,IAAI,IAAI,IAAI,IAAI,EAAE;QAChB,KAAK,GAAG,eAAe,CAAC,IAAI,EAAE,MAAM,EAAE,cAAc,CAAC,CAAC;QACtD,CAAC,KAAK,
CAAC,GAAG,cAAc,CAAC,KAAK,EAAE,EAAE,CAAC,CAAC;QAEpC,cAAc,CAAC,0BAA0B,CAAC,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,CAAC;KAClE;IAED,IAAI,uBAA+B,CAAC;IACpC,IAAI,sBAAsB,IAAI,IAAI,EAAE;QAClC,uBAAuB,GAAG,eAAe,CACrC,sBAAsB,EAAE,eAAe,EAAE,cAAc,CAAC,CAAC;KAC9D;IAED,MAAM,IAAI,GAAG,CAAC,EAAY,EAAE,KAAe,EAAE,EAAE;QAC7C,MAAM,CAAC,GAAG,EAAE,GAAG,EAAE,CAAC,EAAE,KAAK,CAAC,GAAG,KAAK,CAAC;QACnC,yDAAyD;QACzD,0EAA0E;QAC1E,gBAAgB;QAChB,MAAM,YAAY,GACd,oBAAoB,CAAC,OAAO,CAAC,EAAE,EAAE,CAAC,CAAC,KAAK,CAAC,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC;QAC9D,IAAI,IAAY,CAAC;QACjB,IAAI,IAAY,CAAC;QAEjB,IAAI,CAAC,UAAU,IAAI,CAAC,UAAU,EAAE;YAC9B,IAAI,GAAG,aAAa,CAAC,YAAY,EAAE,GAAG,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;YACrD,IAAI,GAAG,aAAa,CAAC,GAAG,EAAE,YAAY,EAAE,IAAI,EAAE,KAAK,CAAC,CAAC;SACtD;aAAM,IAAI,CAAC,UAAU,IAAI,UAAU,EAAE;YACpC,IAAI,GAAG,aAAa,CAAC,YAAY,EAAE,GAAG,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC;YACtD,IAAI,GAAG,aAAa,CAAC,YAAY,EAAE,GAAG,EAAE,IAAI,EAAE,KAAK,CAAC,CAAC;SACtD;aAAM,IAAI,UAAU,IAAI,CAAC,UAAU,EAAE;YACpC,IAAI,GAAG,aAAa,CAAC,GAAG,EAAE,YAAY,EAAE,KAAK,EAAE,IAAI,CAAC,CAAC;YACrD,IAAI,GAAG,aAAa,CAAC,GAAG,EAAE,YAAY,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC;SACvD;aAAM;YACL,IAAI,GAAG,aAAa,CAAC,GAAG,EAAE,YAAY,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC;YACpD,IAAI,GAAG,aAAa,CAAC,YAAY,EAAE,GAAG,EAAE,IAAI,EAAE,IAAI,CAAC,CAAC;SACrD;QAED,IAAI,IAAI,IAAI,IAAI,EAAE;YAChB,MAAM,OAAO,GAAG,oBAAoB,CAAC,KAAK,EAAE,YAAY,CAAC,CAAC;YAC1D,OAAO,CAAC,IAAI,EAAE,IAAI,EAAE,OAAO,CAAC,CAAC;SAC9B;aAAM;YACL,OAAO,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;SACrB;IACH,CAAC,CAAC;IAEF,MAAM,MAAM,GAAuB;QACjC,CAAC,EAAE,GAAG;QACN,CAAC,EAAE,GAAG;QACN,IAAI,EAAE,KAAK;QACX,sBAAsB,EAAE,uBAAuB;KAChD,CAAC;IACF,MAAM,KAAK,GACP,EAAC,UAAU,EAAE,UAAU,EAAE,UAAU,EAAE,cAAc,EAAC,CAAC;IAEzD,yEAAyE;IACzE,oEAAoE;IACpE,IAAI,IAAI,IAAI,IAAI,EAAE;QAChB,MAAM,QAAQ,GACV,UAAU,CAAC,CAAC,GAAa,EAAE,GAAa,EAAE,IAAkB,EAAE,EAAE;YAC9D,MAAM,GAAG;YACL,0DAA0D;YAC1D,MAAM,CAAC,SAAS,CACZ,YAAY,EAAE,MAAmC,EACjD,KAAgC,CAAW,CAAC;YAEpD,IAAI,CAAC,CAAC,GAAG,EAAE,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;YAEtB,OAAO,EAAC,KAAK,EAAE,OAAO,CAAC,GAAG,EAAE,QAAQ,CAAC,EAAE,QAAQ,E
AAE,IAAI,EAAC,CAAC;QACzD,CAAC,CAAC,CAAC;QACP,OAAO,QAAQ,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC;KAC3B;SAAM;QACL,MAAM,gBAAgB,GAAG,UAAU,CAC/B,CAAC,GAAa,EAAE,GAAa,EAAE,KAAa,EAAE,IAAkB,EAAE,EAAE;YAClE,MAAM,GAAG;YACL,0DAA0D;YAC1D,MAAM,CAAC,SAAS,CACZ,YAAY,EAAE,MAAmC,EACjD,KAAgC,CAAW,CAAC;YAEpD,IAAI,CAAC,CAAC,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC,CAAC;YAE7B,OAAO,EAAC,KAAK,EAAE,OAAO,CAAC,GAAG,EAAE,QAAQ,CAAC,EAAE,QAAQ,EAAE,IAAI,EAAC,CAAC;QACzD,CAAC,CAAC,CAAC;QAEP,OAAO,gBAAgB,CAAC,GAAG,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC;KAC1C;AACH,CAAC;AAED,MAAM,CAAC,MAAM,MAAM,GAAG,eAAe,CAAC,EAAE,CAAC,EAAC,YAAY,EAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from '../../engine';\nimport {customGrad} from '../../gradients';\nimport {_FusedMatMul, _FusedMatMulAttrs, _FusedMatMulInputs} from '../../kernel_names';\nimport {NamedAttrMap} from '../../kernel_registry';\nimport {Tensor, Tensor3D} from '../../tensor';\nimport {GradSaveFunc, NamedTensorMap} from '../../tensor_types';\nimport {makeTypesMatch} from '../../tensor_util';\nimport {convertToTensor} from '../../tensor_util_env';\nimport {TensorLike} from '../../types';\nimport * as util from '../../util';\n\nimport {add} from '../add';\nimport * as broadcast_util from '../broadcast_util';\nimport {Activation} from 
'../fused_types';\nimport {applyActivation, getFusedBiasGradient, getFusedDyActivation, shouldFuse} from '../fused_util';\nimport {matMul as unfusedMatMul} from '../mat_mul';\nimport {op} from '../operation';\nimport {reshape} from '../reshape';\n\n/**\n * Computes the dot product of two matrices with optional activation and bias.\n *\n * ```js\n * const a = tf.tensor2d([-1, -2], [1, 2]);\n * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);\n * const bias = tf.tensor2d([1, 2], [1, 2]);\n *\n * tf.fused.matMul({a, b, bias, activation: 'relu'}).print();\n * ```\n *\n * @param obj An object with the following properties:\n * - `a` First matrix in dot product operation.\n * - `b` Second matrix in dot product operation.\n * - `transposeA` If true, `a` is transposed before multiplication.\n * - `transposeB` If true, `b` is transposed before multiplication.\n * - `bias` Matrix to be added to the result.\n * - `activation` Name of activation kernel (defaults to `linear`).\n * - `preluActivationWeights` Tensor of prelu weights.\n * - `leakyreluAlpha` Alpha of leakyrelu.\n */\nfunction fusedMatMul_({\n  a,\n  b,\n  transposeA = false,\n  transposeB = false,\n  bias,\n  activation = 'linear',\n  preluActivationWeights,\n  leakyreluAlpha = 0.2,\n}: {\n  a: Tensor|TensorLike,\n  b: Tensor|TensorLike,\n  transposeA?: boolean,\n  transposeB?: boolean,\n  bias?: Tensor|TensorLike,\n  activation?: Activation,\n  preluActivationWeights?: Tensor\n  leakyreluAlpha?: number\n}): Tensor {\n    if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {\n      let result = unfusedMatMul(a, b, transposeA, transposeB);\n      if (bias != null) {\n        result = add(result, bias);\n      }\n\n      return applyActivation(\n                 result, activation, preluActivationWeights, leakyreluAlpha);\n    }\n\n    let $a = convertToTensor(a, 'a', 'fused matMul');\n    let $b = convertToTensor(b, 'b', 'fused matMul');\n    [$a, $b] = makeTypesMatch($a, $b);\n\n    const innerShapeA =\n  
      transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];\n    const innerShapeB =\n        transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];\n\n    const outerShapeA =\n        transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];\n    const outerShapeB =\n        transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];\n\n    const outerDimsA = $a.shape.slice(0, -2);\n    const outerDimsB = $b.shape.slice(0, -2);\n    const batchDimA = util.sizeFromShape(outerDimsA);\n    const batchDimB = util.sizeFromShape(outerDimsB);\n\n    util.assert(\n        innerShapeA === innerShapeB,\n        () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +\n            `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +\n            `${$b.shape} and transposeA=${transposeA}` +\n            ` and transposeB=${transposeB} must match.`);\n\n    const outShapeOuterDims = broadcast_util.assertAndGetBroadcastShape(\n        $a.shape.slice(0, -2), $b.shape.slice(0, -2));\n    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);\n\n    const a3D: Tensor3D = transposeA ?\n        reshape($a, [batchDimA, innerShapeA, outerShapeA]) :\n        reshape($a, [batchDimA, outerShapeA, innerShapeA]);\n    const b3D: Tensor3D = transposeB ?\n        reshape($b, [batchDimB, outerShapeB, innerShapeB]) :\n        reshape($b, [batchDimB, innerShapeB, outerShapeB]);\n\n    let $bias: Tensor;\n    if (bias != null) {\n      $bias = convertToTensor(bias, 'bias', 'fused matMul');\n      [$bias] = makeTypesMatch($bias, $a);\n\n      broadcast_util.assertAndGetBroadcastShape(outShape, $bias.shape);\n    }\n\n    let $preluActivationWeights: Tensor;\n    if (preluActivationWeights != null) {\n      $preluActivationWeights = convertToTensor(\n          preluActivationWeights, 'prelu weights', 'fused matMul');\n    }\n\n    const grad = (dy: Tensor3D, saved: Tensor[]) => {\n      const [a3D, b3D, y, $bias] = saved;\n      // we reshape dy 
because the result of the forward is not\n      // necessarily going to be a 3d tensor due to a reshape done at the end of\n      // the customOp.\n      const dyActivation =\n          getFusedDyActivation(reshape(dy, y.shape), y, activation);\n      let aDer: Tensor;\n      let bDer: Tensor;\n\n      if (!transposeA && !transposeB) {\n        aDer = unfusedMatMul(dyActivation, b3D, false, true);\n        bDer = unfusedMatMul(a3D, dyActivation, true, false);\n      } else if (!transposeA && transposeB) {\n        aDer = unfusedMatMul(dyActivation, b3D, false, false);\n        bDer = unfusedMatMul(dyActivation, a3D, true, false);\n      } else if (transposeA && !transposeB) {\n        aDer = unfusedMatMul(b3D, dyActivation, false, true);\n        bDer = unfusedMatMul(a3D, dyActivation, false, false);\n      } else {\n        aDer = unfusedMatMul(b3D, dyActivation, true, true);\n        bDer = unfusedMatMul(dyActivation, a3D, true, true);\n      }\n\n      if (bias != null) {\n        const biasDer = getFusedBiasGradient($bias, dyActivation);\n        return [aDer, bDer, biasDer];\n      } else {\n        return [aDer, bDer];\n      }\n    };\n\n    const inputs: _FusedMatMulInputs = {\n      a: a3D,\n      b: b3D,\n      bias: $bias,\n      preluActivationWeights: $preluActivationWeights\n    };\n    const attrs: _FusedMatMulAttrs =\n        {transposeA, transposeB, activation, leakyreluAlpha};\n\n    // Depending on the the params passed in we will have different number of\n    // inputs and thus a a different number of elements in the gradient.\n    if (bias == null) {\n      const customOp =\n          customGrad((a3D: Tensor3D, b3D: Tensor3D, save: GradSaveFunc) => {\n            const res =\n                // tslint:disable-next-line: no-unnecessary-type-assertion\n                ENGINE.runKernel(\n                    _FusedMatMul, inputs as unknown as NamedTensorMap,\n                    attrs as unknown as NamedAttrMap) as Tensor;\n\n            save([a3D, 
b3D, res]);\n\n            return {value: reshape(res, outShape), gradFunc: grad};\n          });\n      return customOp(a3D, b3D);\n    } else {\n      const customOpWithBias = customGrad(\n          (a3D: Tensor3D, b3D: Tensor3D, $bias: Tensor, save: GradSaveFunc) => {\n            const res =\n                // tslint:disable-next-line: no-unnecessary-type-assertion\n                ENGINE.runKernel(\n                    _FusedMatMul, inputs as unknown as NamedTensorMap,\n                    attrs as unknown as NamedAttrMap) as Tensor;\n\n            save([a3D, b3D, res, $bias]);\n\n            return {value: reshape(res, outShape), gradFunc: grad};\n          });\n\n      return customOpWithBias(a3D, b3D, $bias);\n    }\n  }\n\n  export const matMul = /* @__PURE__ */ op({fusedMatMul_});\n"]}