/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { useShapeUniforms } from './gpgpu_math';

/**
 * Program description for a packed matrix multiplication shader.
 *
 * The constructor only assembles metadata and GLSL source text; nothing is
 * compiled or executed here. After construction the instance exposes:
 *  - `variableNames`: `['matrixA', 'matrixB']`, followed by `'bias'`,
 *    `'preluActivationWeights'` and `'leakyreluAlpha'` (in that order) for
 *    each corresponding flag that is set;
 *  - `packedInputs` / `packedOutput`: both `true`;
 *  - `outputShape`: the `outputShape` argument, stored as-is;
 *  - `enableShapeUniforms`: result of `useShapeUniforms(outputShape.length)`;
 *  - `userCode`: the generated GLSL source.
 */
export class MatMulPackedProgram {
  /**
   * @param {number[]} aShape Shape of matrixA; indices 0, 1 and 2 are read
   *     (index 0 is used as the batch size).
   * @param {number[]} bShape Shape of matrixB; only index 0 (batch size) is
   *     read here.
   * @param {number[]} outputShape Shape of the output; stored on the instance
   *     and its length is passed to `useShapeUniforms`.
   * @param {boolean} [transposeA=false] When true the shared dimension is
   *     `aShape[1]` (otherwise `aShape[2]`) and A is sampled/swizzled in the
   *     transposed layout.
   * @param {boolean} [transposeB=false] When true B is sampled/swizzled in the
   *     transposed layout.
   * @param {boolean} [addBias=false] When true, `getBiasAtOutCoords()` is added
   *     to the result and `'bias'` is appended to `variableNames`.
   * @param {?string} [activation=null] GLSL statements spliced in as the body
   *     of a generated `vec4 activation(...)` function. When falsy, no
   *     activation is emitted or applied.
   * @param {boolean} [hasPreluActivation=false] When true,
   *     `'preluActivationWeights'` is appended to `variableNames`; if
   *     `activation` is also set, the weights are exposed to it as `b`.
   * @param {boolean} [hasLeakyreluActivation=false] When true,
   *     `'leakyreluAlpha'` is appended to `variableNames`; if `activation` is
   *     also set (and prelu is not), the alpha is exposed to it as `b`.
   */
  constructor(aShape, bShape, outputShape, transposeA = false, transposeB = false, addBias = false, activation = null, hasPreluActivation = false, hasLeakyreluActivation = false) {
    this.variableNames = ['matrixA', 'matrixB'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = outputShape;
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);

    // Each loop iteration consumes one packed texel, i.e. two entries along
    // the shared dimension, hence the division by two (rounded up).
    const sharedDim = transposeA ? aShape[1] : aShape[2];
    const sharedDimensionPacked = Math.ceil(sharedDim / 2);

    // Sampling coordinates and swizzles depend on whether each operand is
    // read in its transposed layout.
    const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
    const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
    const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
    const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];

    // The activation helper is only emitted when an activation body is given.
    // Prelu takes precedence over leakyrelu if both flags are set. In those
    // two variants the input is named `a` and the extra operand `b`; in the
    // plain variant the input is named `x`.
    let activationSnippet = '';
    let applyActivationSnippet = '';
    if (activation) {
      if (hasPreluActivation) {
        activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
      } else if (hasLeakyreluActivation) {
        activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
      } else {
        activationSnippet = `vec4 activation(vec4 x) {
          ${activation}
        }`;
      }
      applyActivationSnippet = 'result = activation(result);';
    }

    const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';

    // Extra inputs are registered purely from the flags, independent of
    // whether `activation` itself is set.
    if (addBias) {
      this.variableNames.push('bias');
    }
    if (hasPreluActivation) {
      this.variableNames.push('preluActivationWeights');
    }
    if (hasLeakyreluActivation) {
      this.variableNames.push('leakyreluAlpha');
    }

    // When the two batch sizes differ, the operand with the smaller batch is
    // indexed modulo its own batch size (presumably to broadcast it across
    // the larger batch -- confirm against callers).
    let batchASnippet = 'rc.x';
    let batchBSnippet = 'rc.x';
    if (aShape[0] < bShape[0]) {
      batchASnippet = `imod(rc.x, ${aShape[0]})`;
    } else if (bShape[0] < aShape[0]) {
      batchBSnippet = `imod(rc.x, ${bShape[0]})`;
    }

    // NOTE: the line breaks inside this template are significant. GLSL `//`
    // comments extend to the end of the line, so each comment below must be
    // terminated by a newline or it would comment out the shader code that
    // follows it.
    this.userCode = `
      ${activationSnippet}
      // Don't use uniform for sharedDimensionPacked for performance.
      const float sharedDimension = ${sharedDimensionPacked}.0;

      vec4 dot2x2ARowBCol(ivec3 rc) {
        vec4 result = vec4(0);
        int batchA = ${batchASnippet};
        int batchB = ${batchBSnippet};
        for (int i = 0; i < ${sharedDimensionPacked}; i++) {
          vec4 a = getMatrixA(batchA, ${aSample});
          vec4 b = getMatrixB(batchB, ${bSample});

          // These swizzled products need to be separately added.
          // See: https://github.com/tensorflow/tfjs/issues/1735
          result += (${aSwizzle[0]} * ${bSwizzle[0]});
          result += (${aSwizzle[1]} * ${bSwizzle[1]});
        }
        return result;
      }

      void main() {
        ivec3 rc = getOutputCoords();
        vec4 result = dot2x2ARowBCol(rc);

        ${addBiasSnippet}

        ${applyActivationSnippet}

        setOutput(result);
      }
    `;
  }
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibXVsbWF0X3BhY2tlZF9ncHUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL211bG1hdF9wYWNrZWRfZ3B1LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBZSxnQkFBZ0IsRUFBQyxNQUFNLGNBQWMsQ0FBQztBQUU1RCxNQUFNLE9BQU8sbUJBQW1CO0lBUTlCLFlBQ0ksTUFBZ0MsRUFBRSxNQUFnQyxFQUNsRSxXQUFxQyxFQUFFLFVBQVUsR0FBRyxLQUFLLEVBQ3pELFVBQVUsR0FBRyxLQUFLLEVBQUUsT0FBTyxHQUFHLEtBQUssRUFBRSxhQUFxQixJQUFJLEVBQzlELGtCQUFrQixHQUFHLEtBQUssRUFBRSxzQkFBc0IsR0FBRyxLQUFLO1FBWDlELGtCQUFhLEdBQUcsQ0FBQyxTQUFTLEVBQUUsU0FBUyxDQUFDLENBQUM7UUFDdkMsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFDcEIsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFVbEIsSUFBSSxDQUFDLFdBQVcsR0FBRyxXQUFXLENBQUM7UUFDL0IsSUFBSSxDQUFDLG1CQUFtQixHQUFHLGdCQUFnQixDQUFDLElBQUksQ0FBQyxXQUFXLENBQUMsTUFBTSxDQUFDLENBQUM7UUFFckUsTUFBTSxTQUFTLEdBQUcsVUFBVSxDQUFDLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUNyRCxNQUFNLHFCQUFxQixHQUFHLElBQUksQ0FBQyxJQUFJLENBQUMsU0FBUyxHQUFHLENBQUMsQ0FBQyxDQUFDO1FBRXZELE1BQU0sT0FBTyxHQUFHLFVBQVUsQ0FBQyxDQUFDLENBQUMsYUFBYSxDQUFDLENBQUMsQ0FBQyxhQUFhLENBQUM7UUFDM0QsTUFBTSxPQUFPLEdBQUcsVUFBVSxDQUFDLENBQUMsQ0FBQyxhQUFhLENBQUMsQ0FBQyxDQUFDLGFBQWEsQ0FBQztRQUMzRCxNQUFNLFFBQVEsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUMsUUFBUSxFQUFFLFFBQVEsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLFFBQVEsRUFBRSxRQUFRLENBQUMsQ0FBQztRQUMxRSxNQUFNLFFBQVEsR0FBRyxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUMsUUFBUSxFQUFFLFFBQVEsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLFFBQVEsRUFBRSxRQUFRLENBQUMsQ0FBQztRQUUxRSxJQUFJLGlCQUFpQixHQUFHLEVBQUUsRUFBRSxzQkFBc0IsR0FBRyxFQUFFLENBQUM7UUFDeEQsSUFBSSxVQUFVLEVBQUU7
WUFDZCxJQUFJLGtCQUFrQixFQUFFO2dCQUN0QixpQkFBaUIsR0FBRzs7WUFFaEIsVUFBVTtVQUNaLENBQUM7YUFDSjtpQkFBTSxJQUFJLHNCQUFzQixFQUFFO2dCQUNqQyxpQkFBaUIsR0FBRzs7WUFFaEIsVUFBVTtVQUNaLENBQUM7YUFDSjtpQkFBTTtnQkFDTCxpQkFBaUIsR0FBRztZQUNoQixVQUFVO1VBQ1osQ0FBQzthQUNKO1lBRUQsc0JBQXNCLEdBQUcsOEJBQThCLENBQUM7U0FDekQ7UUFFRCxNQUFNLGNBQWMsR0FBRyxPQUFPLENBQUMsQ0FBQyxDQUFDLGlDQUFpQyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUM7UUFDeEUsSUFBSSxPQUFPLEVBQUU7WUFDWCxJQUFJLENBQUMsYUFBYSxDQUFDLElBQUksQ0FBQyxNQUFNLENBQUMsQ0FBQztTQUNqQztRQUVELElBQUksa0JBQWtCLEVBQUU7WUFDdEIsSUFBSSxDQUFDLGFBQWEsQ0FBQyxJQUFJLENBQUMsd0JBQXdCLENBQUMsQ0FBQztTQUNuRDtRQUVELElBQUksc0JBQXNCLEVBQUU7WUFDMUIsSUFBSSxDQUFDLGFBQWEsQ0FBQyxJQUFJLENBQUMsZ0JBQWdCLENBQUMsQ0FBQztTQUMzQztRQUVELElBQUksYUFBYSxHQUFHLE1BQU0sQ0FBQztRQUMzQixJQUFJLGFBQWEsR0FBRyxNQUFNLENBQUM7UUFDM0IsSUFBSSxNQUFNLENBQUMsQ0FBQyxDQUFDLEdBQUcsTUFBTSxDQUFDLENBQUMsQ0FBQyxFQUFFO1lBQ3pCLGFBQWEsR0FBRyxjQUFjLE1BQU0sQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDO1NBQzVDO2FBQU0sSUFBSSxNQUFNLENBQUMsQ0FBQyxDQUFDLEdBQUcsTUFBTSxDQUFDLENBQUMsQ0FBQyxFQUFFO1lBQ2hDLGFBQWEsR0FBRyxjQUFjLE1BQU0sQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDO1NBQzVDO1FBRUQsSUFBSSxDQUFDLFFBQVEsR0FBRztRQUNaLGlCQUFpQjs7c0NBRWEscUJBQXFCOzs7O3VCQUlwQyxhQUFhO3VCQUNiLGFBQWE7OEJBQ04scUJBQXFCO3dDQUNYLE9BQU87d0NBQ1AsT0FBTzs7Ozt1QkFJeEIsUUFBUSxDQUFDLENBQUMsQ0FBQyxNQUFNLFFBQVEsQ0FBQyxDQUFDLENBQUM7dUJBQzVCLFFBQVEsQ0FBQyxDQUFDLENBQUMsTUFBTSxRQUFRLENBQUMsQ0FBQyxDQUFDOzs7Ozs7Ozs7VUFTekMsY0FBYzs7VUFFZCxzQkFBc0I7Ozs7S0FJM0IsQ0FBQztJQUNKLENBQUM7Q0FDRiIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE4IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBz
b2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtHUEdQVVByb2dyYW0sIHVzZVNoYXBlVW5pZm9ybXN9IGZyb20gJy4vZ3BncHVfbWF0aCc7XG5cbmV4cG9ydCBjbGFzcyBNYXRNdWxQYWNrZWRQcm9ncmFtIGltcGxlbWVudHMgR1BHUFVQcm9ncmFtIHtcbiAgdmFyaWFibGVOYW1lcyA9IFsnbWF0cml4QScsICdtYXRyaXhCJ107XG4gIHBhY2tlZElucHV0cyA9IHRydWU7XG4gIHBhY2tlZE91dHB1dCA9IHRydWU7XG4gIG91dHB1dFNoYXBlOiBudW1iZXJbXTtcbiAgdXNlckNvZGU6IHN0cmluZztcbiAgZW5hYmxlU2hhcGVVbmlmb3JtczogYm9vbGVhbjtcblxuICBjb25zdHJ1Y3RvcihcbiAgICAgIGFTaGFwZTogW251bWJlciwgbnVtYmVyLCBudW1iZXJdLCBiU2hhcGU6IFtudW1iZXIsIG51bWJlciwgbnVtYmVyXSxcbiAgICAgIG91dHB1dFNoYXBlOiBbbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIHRyYW5zcG9zZUEgPSBmYWxzZSxcbiAgICAgIHRyYW5zcG9zZUIgPSBmYWxzZSwgYWRkQmlhcyA9IGZhbHNlLCBhY3RpdmF0aW9uOiBzdHJpbmcgPSBudWxsLFxuICAgICAgaGFzUHJlbHVBY3RpdmF0aW9uID0gZmFsc2UsIGhhc0xlYWt5cmVsdUFjdGl2YXRpb24gPSBmYWxzZSkge1xuICAgIHRoaXMub3V0cHV0U2hhcGUgPSBvdXRwdXRTaGFwZTtcbiAgICB0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPSB1c2VTaGFwZVVuaWZvcm1zKHRoaXMub3V0cHV0U2hhcGUubGVuZ3RoKTtcblxuICAgIGNvbnN0IHNoYXJlZERpbSA9IHRyYW5zcG9zZUEgPyBhU2hhcGVbMV0gOiBhU2hhcGVbMl07XG4gICAgY29uc3Qgc2hhcmVkRGltZW5zaW9uUGFja2VkID0gTWF0aC5jZWlsKHNoYXJlZERpbSAvIDIpO1xuXG4gICAgY29uc3QgYVNhbXBsZSA9IHRyYW5zcG9zZUEgPyAnaSAqIDIsIHJjLnknIDogJ3JjLnksIGkgKiAyJztcbiAgICBjb25zdCBiU2FtcGxlID0gdHJhbnNwb3NlQiA/ICdyYy56LCBpICogMicgOiAnaSAqIDIsIHJjLnonO1xuICAgIGNvbnN0IGFTd2l6emxlID0gdHJhbnNwb3NlQSA/IFsnYS54eHl5JywgJ2Euenp3dyddIDogWydhLnh4enonLCAnYS55eXd3J107XG4gICAgY29uc3QgYlN3aXp6bGUgPSB0cmFuc3Bvc2VCID8gWydiLnh6eHonLCAnYi55d3l3J10gOiBbJ2IueHl4eScsICdiLnp3encnXTtcblxuICAgIGxldCBhY3RpdmF0aW9uU25pcHBl
dCA9ICcnLCBhcHBseUFjdGl2YXRpb25TbmlwcGV0ID0gJyc7XG4gICAgaWYgKGFjdGl2YXRpb24pIHtcbiAgICAgIGlmIChoYXNQcmVsdUFjdGl2YXRpb24pIHtcbiAgICAgICAgYWN0aXZhdGlvblNuaXBwZXQgPSBgdmVjNCBhY3RpdmF0aW9uKHZlYzQgYSkge1xuICAgICAgICAgIHZlYzQgYiA9IGdldFByZWx1QWN0aXZhdGlvbldlaWdodHNBdE91dENvb3JkcygpO1xuICAgICAgICAgICR7YWN0aXZhdGlvbn1cbiAgICAgICAgfWA7XG4gICAgICB9IGVsc2UgaWYgKGhhc0xlYWt5cmVsdUFjdGl2YXRpb24pIHtcbiAgICAgICAgYWN0aXZhdGlvblNuaXBwZXQgPSBgdmVjNCBhY3RpdmF0aW9uKHZlYzQgYSkge1xuICAgICAgICAgIHZlYzQgYiA9IGdldExlYWt5cmVsdUFscGhhQXRPdXRDb29yZHMoKTtcbiAgICAgICAgICAke2FjdGl2YXRpb259XG4gICAgICAgIH1gO1xuICAgICAgfSBlbHNlIHtcbiAgICAgICAgYWN0aXZhdGlvblNuaXBwZXQgPSBgdmVjNCBhY3RpdmF0aW9uKHZlYzQgeCkge1xuICAgICAgICAgICR7YWN0aXZhdGlvbn1cbiAgICAgICAgfWA7XG4gICAgICB9XG5cbiAgICAgIGFwcGx5QWN0aXZhdGlvblNuaXBwZXQgPSBgcmVzdWx0ID0gYWN0aXZhdGlvbihyZXN1bHQpO2A7XG4gICAgfVxuXG4gICAgY29uc3QgYWRkQmlhc1NuaXBwZXQgPSBhZGRCaWFzID8gJ3Jlc3VsdCArPSBnZXRCaWFzQXRPdXRDb29yZHMoKTsnIDogJyc7XG4gICAgaWYgKGFkZEJpYXMpIHtcbiAgICAgIHRoaXMudmFyaWFibGVOYW1lcy5wdXNoKCdiaWFzJyk7XG4gICAgfVxuXG4gICAgaWYgKGhhc1ByZWx1QWN0aXZhdGlvbikge1xuICAgICAgdGhpcy52YXJpYWJsZU5hbWVzLnB1c2goJ3ByZWx1QWN0aXZhdGlvbldlaWdodHMnKTtcbiAgICB9XG5cbiAgICBpZiAoaGFzTGVha3lyZWx1QWN0aXZhdGlvbikge1xuICAgICAgdGhpcy52YXJpYWJsZU5hbWVzLnB1c2goJ2xlYWt5cmVsdUFscGhhJyk7XG4gICAgfVxuXG4gICAgbGV0IGJhdGNoQVNuaXBwZXQgPSAncmMueCc7XG4gICAgbGV0IGJhdGNoQlNuaXBwZXQgPSAncmMueCc7XG4gICAgaWYgKGFTaGFwZVswXSA8IGJTaGFwZVswXSkge1xuICAgICAgYmF0Y2hBU25pcHBldCA9IGBpbW9kKHJjLngsICR7YVNoYXBlWzBdfSlgO1xuICAgIH0gZWxzZSBpZiAoYlNoYXBlWzBdIDwgYVNoYXBlWzBdKSB7XG4gICAgICBiYXRjaEJTbmlwcGV0ID0gYGltb2QocmMueCwgJHtiU2hhcGVbMF19KWA7XG4gICAgfVxuXG4gICAgdGhpcy51c2VyQ29kZSA9IGBcbiAgICAgICR7YWN0aXZhdGlvblNuaXBwZXR9XG4gICAgICAvLyBEb24ndCB1c2UgdW5pZm9ybSBmb3Igc2hhcmVkRGltZW5zaW9uUGFja2VkIGZvciBwZXJmb3JtYW5jZS5cbiAgICAgIGNvbnN0IGZsb2F0IHNoYXJlZERpbWVuc2lvbiA9ICR7c2hhcmVkRGltZW5zaW9uUGFja2VkfS4wO1xuXG4gICAgICB2ZWM0IGRvdDJ4MkFSb3dCQ29sKGl2ZWMzIHJjKSB7XG4gICAgICAgIHZlYzQgcmVzdWx0ID0gdmVjNCgwKTtcbiAgICAgICAgaW50IGJhdGNo
QSA9ICR7YmF0Y2hBU25pcHBldH07XG4gICAgICAgIGludCBiYXRjaEIgPSAke2JhdGNoQlNuaXBwZXR9O1xuICAgICAgICBmb3IgKGludCBpID0gMDsgaSA8ICR7c2hhcmVkRGltZW5zaW9uUGFja2VkfTsgaSsrKSB7XG4gICAgICAgICAgdmVjNCBhID0gZ2V0TWF0cml4QShiYXRjaEEsICR7YVNhbXBsZX0pO1xuICAgICAgICAgIHZlYzQgYiA9IGdldE1hdHJpeEIoYmF0Y2hCLCAke2JTYW1wbGV9KTtcblxuICAgICAgICAgIC8vIFRoZXNlIHN3aXp6bGVkIHByb2R1Y3RzIG5lZWQgdG8gYmUgc2VwYXJhdGVseSBhZGRlZC5cbiAgICAgICAgICAvLyBTZWU6IGh0dHBzOi8vZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RmanMvaXNzdWVzLzE3MzVcbiAgICAgICAgICByZXN1bHQgKz0gKCR7YVN3aXp6bGVbMF19ICogJHtiU3dpenpsZVswXX0pO1xuICAgICAgICAgIHJlc3VsdCArPSAoJHthU3dpenpsZVsxXX0gKiAke2JTd2l6emxlWzFdfSk7XG4gICAgICAgIH1cbiAgICAgICAgcmV0dXJuIHJlc3VsdDtcbiAgICAgIH1cblxuICAgICAgdm9pZCBtYWluKCkge1xuICAgICAgICBpdmVjMyByYyA9IGdldE91dHB1dENvb3JkcygpO1xuICAgICAgICB2ZWM0IHJlc3VsdCA9IGRvdDJ4MkFSb3dCQ29sKHJjKTtcblxuICAgICAgICAke2FkZEJpYXNTbmlwcGV0fVxuXG4gICAgICAgICR7YXBwbHlBY3RpdmF0aW9uU25pcHBldH1cblxuICAgICAgICBzZXRPdXRwdXQocmVzdWx0KTtcbiAgICAgIH1cbiAgICBgO1xuICB9XG59XG4iXX0=