/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { useShapeUniforms } from './gpgpu_math';

/**
 * Builds the GLSL source (`userCode`) for a depthwise 2D convolution, with an
 * optional bias add and an optional activation applied to the result.
 *
 * Instance fields populated by the constructor:
 *  - `variableNames`: starts as `['x', 'W']`; `'bias'`,
 *    `'preluActivationWeights'` and `'leakyreluAlpha'` are appended (in that
 *    order) when the corresponding flags are set.
 *  - `customUniforms`: name/type declarations for the `pads`, `strides`,
 *    `dilations` and `inDims` uniforms that the shader references.
 *  - `outputShape`: taken from `convInfo.outShape`.
 *  - `enableShapeUniforms`: result of `useShapeUniforms(outputShape.length)`.
 *  - `userCode`: the generated shader source.
 */
export class DepthwiseConv2DProgram {
  /**
   * @param {object} convInfo Convolution description. This constructor reads
   *     `outShape`, `filterHeight`, `filterWidth`, `outChannels` and
   *     `inChannels` from it.
   * @param {boolean} [addBias=false] When true, `'bias'` is added to
   *     `variableNames` and the shader adds `getBiasAtOutCoords()` to the
   *     result before the activation.
   * @param {?string} [activation=null] GLSL statements used as the body of a
   *     generated `float activation(float ...)` function. When falsy, no
   *     activation function is emitted or applied.
   * @param {boolean} [hasPreluActivation=false] When true,
   *     `'preluActivationWeights'` is added to `variableNames`; if
   *     `activation` is also set, the generated function takes `a` and loads
   *     `b` from `getPreluActivationWeightsAtOutCoords()`.
   * @param {boolean} [hasLeakyReluAlpha=false] When true, `'leakyreluAlpha'`
   *     is added to `variableNames`; if `activation` is set and
   *     `hasPreluActivation` is false, the generated function takes `a` and
   *     loads `b` from `getLeakyreluAlphaAtOutCoords()`.
   */
  constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
    this.variableNames = ['x', 'W'];
    this.customUniforms = [
      { name: 'pads', type: 'ivec2' },
      { name: 'strides', type: 'ivec2' },
      { name: 'dilations', type: 'ivec2' },
      { name: 'inDims', type: 'ivec2' },
    ];
    this.outputShape = convInfo.outShape;
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);

    const { filterHeight, filterWidth } = convInfo;
    // Interpolated below as a GLSL int literal (`d2 / ${channelMul}`).
    // NOTE(review): assumes outChannels is an exact multiple of inChannels —
    // confirm that callers guarantee this.
    const channelMul = convInfo.outChannels / convInfo.inChannels;

    // Every shader template below keeps one statement per line: GLSL `//`
    // comments run to the end of the line, so putting them on the same line
    // as code would comment that code out.
    let activationSnippet = '';
    let applyActivationSnippet = '';
    if (activation) {
      if (hasPreluActivation) {
        activationSnippet = `float activation(float a) {
          float b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
      } else if (hasLeakyReluAlpha) {
        activationSnippet = `float activation(float a) {
          float b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
      } else {
        // Unary form: the activation body sees its argument as `x`.
        activationSnippet = `
          float activation(float x) {
            ${activation}
          }
        `;
      }
      applyActivationSnippet = `result = activation(result);`;
    }

    const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';

    // The extra inputs are registered independently of `activation`; the push
    // order (bias, prelu weights, leaky-relu alpha) is kept stable.
    if (addBias) {
      this.variableNames.push('bias');
    }
    if (hasPreluActivation) {
      this.variableNames.push('preluActivationWeights');
    }
    if (hasLeakyReluAlpha) {
      this.variableNames.push('leakyreluAlpha');
    }

    this.userCode = `
      ${activationSnippet}

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords.x;
        ivec2 xRCCorner = coords.yz * strides - pads;
        int d2 = coords.w;
        int d1 = d2 / ${channelMul};
        int q = d2 - d1 * ${channelMul};

        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).
        // ? = to be determined. : = across all values in that axis.
        float dotProd = 0.0;
        // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          int xR = xRCorner + wR * dilations[0];

          if (xR < 0 || xR >= inDims[0]) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int xC = xCCorner + wC * dilations[1];

            if (xC < 0 || xC >= inDims[1]) {
              continue;
            }

            float xVal = getX(batch, xR, xC, d1);
            float wVal = getW(wR, wC, d1, q);
            dotProd += xVal * wVal;
          }
        }

        float result = dotProd;
        ${addBiasSnippet}
        ${applyActivationSnippet}
        setOutput(result);
      }
    `;
  }
}
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29udl9ncHVfZGVwdGh3aXNlLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9jb252X2dwdV9kZXB0aHdpc2UudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBR0gsT0FBTyxFQUFlLGdCQUFnQixFQUFDLE1BQU0sY0FBYyxDQUFDO0FBRTVELE1BQU0sT0FBTyxzQkFBc0I7SUFZakMsWUFDSSxRQUFpQyxFQUFFLE9BQU8sR0FBRyxLQUFLLEVBQ2xELGFBQXFCLElBQUksRUFBRSxrQkFBa0IsR0FBRyxLQUFLLEVBQ3JELGlCQUFpQixHQUFHLEtBQUs7UUFkN0Isa0JBQWEsR0FBRyxDQUFDLEdBQUcsRUFBRSxHQUFHLENBQUMsQ0FBQztRQUkzQixtQkFBYyxHQUFHO1lBQ2YsRUFBQyxJQUFJLEVBQUUsTUFBTSxFQUFFLElBQUksRUFBRSxPQUFnQixFQUFFO1lBQ3ZDLEVBQUMsSUFBSSxFQUFFLFNBQVMsRUFBRSxJQUFJLEVBQUUsT0FBZ0IsRUFBRTtZQUMxQyxFQUFDLElBQUksRUFBRSxXQUFXLEVBQUUsSUFBSSxFQUFFLE9BQWdCLEVBQUU7WUFDNUMsRUFBQyxJQUFJLEVBQUUsUUFBUSxFQUFFLElBQUksRUFBRSxPQUFnQixFQUFFO1NBQzFDLENBQUM7UUFNQSxJQUFJLENBQUMsV0FBVyxHQUFHLFFBQVEsQ0FBQyxRQUFRLENBQUM7UUFDckMsSUFBSSxDQUFDLG1CQUFtQixHQUFHLGdCQUFnQixDQUFDLElBQUksQ0FBQyxXQUFXLENBQUMsTUFBTSxDQUFDLENBQUM7UUFFckUsTUFBTSxZQUFZLEdBQUcsUUFBUSxDQUFDLFlBQVksQ0FBQztRQUMzQyxNQUFNLFdBQVcsR0FBRyxRQUFRLENBQUMsV0FBVyxDQUFDO1FBQ3pDLE1BQU0sVUFBVSxHQUFHLFFBQVEsQ0FBQyxXQUFXLEdBQUcsUUFBUSxDQUFDLFVBQVUsQ0FBQztRQUU5RCxJQUFJLGlCQUFpQixHQUFHLEVBQUUsRUFBRSxzQkFBc0IsR0FBRyxFQUFFLENBQUM7UUFDeEQsSUFBSSxVQUFVLEVBQUU7WUFDZCxJQUFJLGtCQUFrQixFQUFFO2dCQUN0QixpQkFBaUIsR0FBRzs7WUFFaEIsVUFBVTtVQUNaLENBQUM7YUFDSjtpQkFBTSxJQUFJLGlCQUFpQixFQUFFO2dCQUM1QixpQkFBaUIsR0FBRzs7WUFFaEIsVUFBVTtVQUNaLENBQUM7YUFDSjtpQkFBTTtnQkFDTCxpQkFBaUIsR0FBRzs7Y0FFZCxVQUFVOztTQUVmLENBQUM7YUFDSDtZQUVELHNCQUFzQixHQUFHLDhCQUE4QixDQUFDO1NBQ3pEO1FBRUQsTUFBTSxjQUFjLEdBQUcsT0FBTyxDQUFDLENBQUMsQ0FBQyxpQ0FBaUMsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDO1FBQ3hFLElBQUksT0FBTyxFQUFFO1lBQ1gsSUFBSSxDQUFDLGFBQWEsQ0FBQyxJQUFJLENBQUMsTUFBTSxDQUFDLENBQUM7U0FDakM7UUFFRCxJQUFJLGtCQUFrQixFQUFFO1lBQ3RCLElBQUksQ0FBQyxhQUFhLENBQUMsSUFBSSxDQUFDLHdCQUF3QixDQUFDLENBQUM7U0FDbkQ7UUFDRCxJQUFJLGlCQUFpQixFQUFFO1lBQ3JCLElBQUksQ0FBQyxhQUFhLENBQU
MsSUFBSSxDQUFDLGdCQUFnQixDQUFDLENBQUM7U0FDM0M7UUFFRCxJQUFJLENBQUMsUUFBUSxHQUFHO1FBQ1osaUJBQWlCOzs7Ozs7O3dCQU9ELFVBQVU7NEJBQ04sVUFBVTs7Ozs7Ozs7O2dDQVNOLFlBQVk7Ozs7Ozs7a0NBT1YsV0FBVzs7Ozs7Ozs7Ozs7Ozs7VUFjbkMsY0FBYztVQUNkLHNCQUFzQjs7O0tBRzNCLENBQUM7SUFDSixDQUFDO0NBQ0YiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxNyBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7YmFja2VuZF91dGlsfSBmcm9tICdAdGVuc29yZmxvdy90ZmpzLWNvcmUnO1xuaW1wb3J0IHtHUEdQVVByb2dyYW0sIHVzZVNoYXBlVW5pZm9ybXN9IGZyb20gJy4vZ3BncHVfbWF0aCc7XG5cbmV4cG9ydCBjbGFzcyBEZXB0aHdpc2VDb252MkRQcm9ncmFtIGltcGxlbWVudHMgR1BHUFVQcm9ncmFtIHtcbiAgdmFyaWFibGVOYW1lcyA9IFsneCcsICdXJ107XG4gIG91dHB1dFNoYXBlOiBudW1iZXJbXTtcbiAgdXNlckNvZGU6IHN0cmluZztcbiAgZW5hYmxlU2hhcGVVbmlmb3JtczogYm9vbGVhbjtcbiAgY3VzdG9tVW5pZm9ybXMgPSBbXG4gICAge25hbWU6ICdwYWRzJywgdHlwZTogJ2l2ZWMyJyBhcyBjb25zdCB9LFxuICAgIHtuYW1lOiAnc3RyaWRlcycsIHR5cGU6ICdpdmVjMicgYXMgY29uc3QgfSxcbiAgICB7bmFtZTogJ2RpbGF0aW9ucycsIHR5cGU6ICdpdmVjMicgYXMgY29uc3QgfSxcbiAgICB7bmFtZTogJ2luRGltcycsIHR5cGU6ICdpdmVjMicgYXMgY29uc3QgfSxcbiAgXTtcblxuICBjb25zdHJ1Y3RvcihcbiAgICAgIGNvbnZJbmZvOiBiYWNrZW5kX3
V0aWwuQ29udjJESW5mbywgYWRkQmlhcyA9IGZhbHNlLFxuICAgICAgYWN0aXZhdGlvbjogc3RyaW5nID0gbnVsbCwgaGFzUHJlbHVBY3RpdmF0aW9uID0gZmFsc2UsXG4gICAgICBoYXNMZWFreVJlbHVBbHBoYSA9IGZhbHNlKSB7XG4gICAgdGhpcy5vdXRwdXRTaGFwZSA9IGNvbnZJbmZvLm91dFNoYXBlO1xuICAgIHRoaXMuZW5hYmxlU2hhcGVVbmlmb3JtcyA9IHVzZVNoYXBlVW5pZm9ybXModGhpcy5vdXRwdXRTaGFwZS5sZW5ndGgpO1xuXG4gICAgY29uc3QgZmlsdGVySGVpZ2h0ID0gY29udkluZm8uZmlsdGVySGVpZ2h0O1xuICAgIGNvbnN0IGZpbHRlcldpZHRoID0gY29udkluZm8uZmlsdGVyV2lkdGg7XG4gICAgY29uc3QgY2hhbm5lbE11bCA9IGNvbnZJbmZvLm91dENoYW5uZWxzIC8gY29udkluZm8uaW5DaGFubmVscztcblxuICAgIGxldCBhY3RpdmF0aW9uU25pcHBldCA9ICcnLCBhcHBseUFjdGl2YXRpb25TbmlwcGV0ID0gJyc7XG4gICAgaWYgKGFjdGl2YXRpb24pIHtcbiAgICAgIGlmIChoYXNQcmVsdUFjdGl2YXRpb24pIHtcbiAgICAgICAgYWN0aXZhdGlvblNuaXBwZXQgPSBgZmxvYXQgYWN0aXZhdGlvbihmbG9hdCBhKSB7XG4gICAgICAgICAgZmxvYXQgYiA9IGdldFByZWx1QWN0aXZhdGlvbldlaWdodHNBdE91dENvb3JkcygpO1xuICAgICAgICAgICR7YWN0aXZhdGlvbn1cbiAgICAgICAgfWA7XG4gICAgICB9IGVsc2UgaWYgKGhhc0xlYWt5UmVsdUFscGhhKSB7XG4gICAgICAgIGFjdGl2YXRpb25TbmlwcGV0ID0gYGZsb2F0IGFjdGl2YXRpb24oZmxvYXQgYSkge1xuICAgICAgICAgIGZsb2F0IGIgPSBnZXRMZWFreXJlbHVBbHBoYUF0T3V0Q29vcmRzKCk7XG4gICAgICAgICAgJHthY3RpdmF0aW9ufVxuICAgICAgICB9YDtcbiAgICAgIH0gZWxzZSB7XG4gICAgICAgIGFjdGl2YXRpb25TbmlwcGV0ID0gYFxuICAgICAgICAgIGZsb2F0IGFjdGl2YXRpb24oZmxvYXQgeCkge1xuICAgICAgICAgICAgJHthY3RpdmF0aW9ufVxuICAgICAgICAgIH1cbiAgICAgICAgYDtcbiAgICAgIH1cblxuICAgICAgYXBwbHlBY3RpdmF0aW9uU25pcHBldCA9IGByZXN1bHQgPSBhY3RpdmF0aW9uKHJlc3VsdCk7YDtcbiAgICB9XG5cbiAgICBjb25zdCBhZGRCaWFzU25pcHBldCA9IGFkZEJpYXMgPyAncmVzdWx0ICs9IGdldEJpYXNBdE91dENvb3JkcygpOycgOiAnJztcbiAgICBpZiAoYWRkQmlhcykge1xuICAgICAgdGhpcy52YXJpYWJsZU5hbWVzLnB1c2goJ2JpYXMnKTtcbiAgICB9XG5cbiAgICBpZiAoaGFzUHJlbHVBY3RpdmF0aW9uKSB7XG4gICAgICB0aGlzLnZhcmlhYmxlTmFtZXMucHVzaCgncHJlbHVBY3RpdmF0aW9uV2VpZ2h0cycpO1xuICAgIH1cbiAgICBpZiAoaGFzTGVha3lSZWx1QWxwaGEpIHtcbiAgICAgIHRoaXMudmFyaWFibGVOYW1lcy5wdXNoKCdsZWFreXJlbHVBbHBoYScpO1xuICAgIH1cblxuICAgIHRoaXMudXNlckNvZGUgPSBgXG4gICAgICAke2FjdGl2YXRpb25TbmlwcGV0fVxuXG4gICAgICB2b2lkIG
1haW4oKSB7XG4gICAgICAgIGl2ZWM0IGNvb3JkcyA9IGdldE91dHB1dENvb3JkcygpO1xuICAgICAgICBpbnQgYmF0Y2ggPSBjb29yZHMueDtcbiAgICAgICAgaXZlYzIgeFJDQ29ybmVyID0gY29vcmRzLnl6ICogc3RyaWRlcyAtIHBhZHM7XG4gICAgICAgIGludCBkMiA9IGNvb3Jkcy53O1xuICAgICAgICBpbnQgZDEgPSBkMiAvICR7Y2hhbm5lbE11bH07XG4gICAgICAgIGludCBxID0gZDIgLSBkMSAqICR7Y2hhbm5lbE11bH07XG5cbiAgICAgICAgaW50IHhSQ29ybmVyID0geFJDQ29ybmVyLng7XG4gICAgICAgIGludCB4Q0Nvcm5lciA9IHhSQ0Nvcm5lci55O1xuXG4gICAgICAgIC8vIENvbnZvbHZlIHgoPywgPywgZDEpIHdpdGggdyg6LCA6LCBkMSwgcSkgdG8gZ2V0IHkoeVIsIHlDLCBkMikuXG4gICAgICAgIC8vID8gPSB0byBiZSBkZXRlcm1pbmVkLiA6ID0gYWNyb3NzIGFsbCB2YWx1ZXMgaW4gdGhhdCBheGlzLlxuICAgICAgICBmbG9hdCBkb3RQcm9kID0gMC4wO1xuICAgICAgICAvLyBUTyBETyhkc21pbGtvdik6IEZsYXR0ZW4gdGhlIHR3byBmb3IgbG9vcHMgYW5kIHZlYzQgdGhlIG9wZXJhdGlvbnMuXG4gICAgICAgIGZvciAoaW50IHdSID0gMDsgd1IgPCAke2ZpbHRlckhlaWdodH07IHdSKyspIHtcbiAgICAgICAgICBpbnQgeFIgPSB4UkNvcm5lciArIHdSICogZGlsYXRpb25zWzBdO1xuXG4gICAgICAgICAgaWYgKHhSIDwgMCB8fCB4UiA+PSBpbkRpbXNbMF0pIHtcbiAgICAgICAgICAgIGNvbnRpbnVlO1xuICAgICAgICAgIH1cblxuICAgICAgICAgIGZvciAoaW50IHdDID0gMDsgd0MgPCAke2ZpbHRlcldpZHRofTsgd0MrKykge1xuICAgICAgICAgICAgaW50IHhDID0geENDb3JuZXIgKyB3QyAqIGRpbGF0aW9uc1sxXTtcblxuICAgICAgICAgICAgaWYgKHhDIDwgMCB8fCB4QyA+PSBpbkRpbXNbMV0pIHtcbiAgICAgICAgICAgICAgY29udGludWU7XG4gICAgICAgICAgICB9XG5cbiAgICAgICAgICAgIGZsb2F0IHhWYWwgPSBnZXRYKGJhdGNoLCB4UiwgeEMsIGQxKTtcbiAgICAgICAgICAgIGZsb2F0IHdWYWwgPSBnZXRXKHdSLCB3QywgZDEsIHEpO1xuICAgICAgICAgICAgZG90UHJvZCArPSB4VmFsICogd1ZhbDtcbiAgICAgICAgICB9XG4gICAgICAgIH1cblxuICAgICAgICBmbG9hdCByZXN1bHQgPSBkb3RQcm9kO1xuICAgICAgICAke2FkZEJpYXNTbmlwcGV0fVxuICAgICAgICAke2FwcGx5QWN0aXZhdGlvblNuaXBwZXR9XG4gICAgICAgIHNldE91dHB1dChyZXN1bHQpO1xuICAgICAgfVxuICAgIGA7XG4gIH1cbn1cbiJdfQ==