"use strict";
|
/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
Object.defineProperty(exports, "__esModule", { value: true });
|
var util = require("../../util");
|
var DepthwiseConvPacked2DProgram = /** @class */ (function () {
    /** Name of the GLSL vec4 that holds the raw input texel for filter row `r`, texel slot `c`. */
    function texelVar(r, c) {
        return 'xTexelR' + r + 'C' + c;
    }
    /** Name of the GLSL vec4 that holds the gathered input values for filter tap (`r`, `c`). */
    function inputVar(r, c) {
        return 'xR' + r + 'C' + c;
    }
    /** Name of the GLSL vec4 that holds the weights for filter tap (`r`, `c`). */
    function weightVar(r, c) {
        return 'wR' + r + 'C' + c;
    }
    /** Returns a copy of `lines` with every line indented by two spaces. */
    function indent(lines) {
        return lines.map(function (line) { return '  ' + line; });
    }
    /** GLSL condition: the current input row `xR` lies inside the input. */
    function rowGuard(xNumRows) {
        return 'xR >= 0 && xR < ' + xNumRows;
    }
    /** GLSL condition: the column expression `colExpr` lies inside the input. */
    function colGuard(colExpr, xNumCols) {
        return colExpr + ' >= 0 && ' + colExpr + ' < ' + xNumCols;
    }
    /**
     * GLSL lines that zero the `.zw` half of `target` when the column after
     * `colExpr` falls outside the input. The surrounding code treats `.xy` as
     * column `colExpr` and `.zw` as the following column (inferred from how the
     * channels are combined below).
     */
    function clearUnusedChannels(target, colExpr, xNumCols) {
        return [
            '// Need to manually clear unused channels in case',
            "// we're reading from recycled texture.",
            'if(' + colExpr + ' + 1 >= ' + xNumCols + ') {',
            '  ' + target + '.zw = vec2(0.);',
            '}'
        ];
    }
    /**
     * GLSL lines that sample `getX(batch, xR, colExpr, d1)` into `target` when
     * `guard` holds, always clearing the out-of-range `.zw` half afterwards.
     * When `resetWhenOutside` is true, `target` is set to zero if `guard` fails;
     * otherwise it keeps its previous value.
     */
    function loadTexel(target, colExpr, guard, xNumCols, resetWhenOutside) {
        var lines = [
            'if(' + guard + ') {',
            '  ' + target + ' = getX(batch, xR, ' + colExpr + ', d1);'
        ];
        lines = lines.concat(indent(clearUnusedChannels(target, colExpr, xNumCols)));
        if (resetWhenOutside) {
            lines.push('} else {', '  ' + target + ' = vec4(0.);');
        }
        lines.push('}');
        return lines;
    }
    /**
     * Emits the input gathering for filter taps `c` and (if present) `c + 1` of
     * row `r` when the horizontal stride is 1.
     */
    function gatherUnitStride(r, c, cfg) {
        var cols = cfg.xNumCols;
        var texel = texelVar(r, c);
        var offsetGuard = rowGuard(cfg.xNumRows) + ' && ' + colGuard('xCOffset', cols);
        var lines = [];
        if (cfg.padLeft % 2 === 1) {
            // If padding is odd, the outer texels have to be composed:
            // adding 1 moves to the texel boundary, subtracting 2 reaches the
            // previous texel.
            // TODO: Ensure vec4 previous does not result in redundant sample,
            // and avoid setting xTexelRC's that exceed the boundary in the
            // first place rather than resetting them to vec4(0)).
            lines.push('xCOffset = xC + 1;');
            lines = lines.concat(loadTexel(texel, 'xCOffset', offsetGuard, cols, true));
            lines.push('xCOffset = xC + 1 - 2;', 'if(' + offsetGuard + ') {', '  vec4 previous = getX(batch, xR, xCOffset, d1);');
            lines = lines.concat(indent(clearUnusedChannels('previous', 'xCOffset', cols)));
            lines.push('  ' + inputVar(r, c) + ' = vec4(previous.zw, ' + texel + '.xy);', '} else {', '  ' + inputVar(r, c) + ' = vec4(0, 0, ' + texel + '.xy);', '}');
        }
        else {
            // Padding is even, so xRC corresponds to a single texel.
            lines = lines.concat(loadTexel(texel, 'xC', rowGuard(cfg.xNumRows) + ' && ' + colGuard('xC', cols), cols, true));
            lines.push(inputVar(r, c) + ' = ' + texel + ';');
        }
        if (c + 1 >= cfg.filterWidth) {
            return lines;
        }
        // If dilation is even, the second entry should match the first (either
        // both are composed or both are single samples). But if dilation is
        // odd, then the second entry should be the opposite of the first.
        var nextTexel = texelVar(r, c + 2);
        var nextTexelOffset = cfg.padLeft % 2 === 0 ?
            util.nearestLargerEven(cfg.dilationWidth) :
            cfg.dilationWidth;
        var isComposed = (cfg.dilationWidth % 2 === 0 && cfg.padLeft % 2 === 1) ||
            (cfg.dilationWidth % 2 !== 0 && cfg.padLeft % 2 !== 1);
        if (isComposed) {
            lines.push('xCOffset = xC + ' + (cfg.padLeft % 2) + ' + ' + nextTexelOffset + ';');
            lines = lines.concat(loadTexel(nextTexel, 'xCOffset', offsetGuard, cols, false));
            // If dilation > 1 then the xRC's will not be able to share any
            // values, so each xRC will require two unique calls to getX.
            if (cfg.dilationWidth > 1) {
                lines.push('xCOffset -= 2;');
                lines = lines.concat(loadTexel(texel, 'xCOffset', offsetGuard, cols, true));
            }
            lines.push(inputVar(r, c + 1) + ' = vec4(' + texel + '.zw, ' + nextTexel + '.xy);');
        }
        else {
            lines.push('xCOffset = xC + ' + nextTexelOffset + ';');
            lines = lines.concat(loadTexel(nextTexel, 'xCOffset', offsetGuard, cols, false));
            lines.push(inputVar(r, c + 1) + ' = ' + nextTexel + ';');
        }
        return lines;
    }
    /**
     * Emits the input gathering for filter taps `c` and (if present) `c + 1` of
     * row `r` when the horizontal stride is greater than 1.
     */
    function gatherStrided(r, c, cfg) {
        var cols = cfg.xNumCols;
        var texel = texelVar(r, c);
        var nextTexel = texelVar(r, c + 2);
        var offsetGuard = colGuard('xCOffset', cols);
        var body = [];
        // Depending on whether padLeft is even or odd, we want either the xy
        // or zw channels from X texels for xR${r}C${c}. If padLeft is even,
        // xR${r}C${c + 1} is simply the zw channels of texels we've already
        // sampled. But if padLeft is odd, xR${r}C${c + 1}.zw will need to come
        // from the xy channels of a new texel, hence `vec4 final` below.
        if (cfg.padLeft % 2 === 1) {
            body.push('xCOffset = xC + 1 - ' + cfg.strideWidth + ';');
            body = body.concat(loadTexel(texel, 'xCOffset', offsetGuard, cols, true));
            body = body.concat(loadTexel(nextTexel, 'xC + 1', colGuard('xC + 1', cols), cols, true));
            body.push(inputVar(r, c) + ' = vec4(' + texel + '.zw, ' + nextTexel + '.zw);');
            if (c + 1 < cfg.filterWidth) {
                body.push('vec4 final = vec4(0.);', 'xCOffset = xC + 1 + ' + cfg.strideWidth + ';');
                body = body.concat(loadTexel('final', 'xCOffset', offsetGuard, cols, false));
                body.push(inputVar(r, c + 1) + ' = vec4(' + nextTexel + '.xy, final.xy);');
            }
        }
        else {
            body = body.concat(loadTexel(texel, 'xC', colGuard('xC', cols), cols, true));
            body.push('xCOffset = xC + ' + cfg.strideWidth + ';');
            body = body.concat(loadTexel(nextTexel, 'xCOffset', offsetGuard, cols, true));
            body.push(inputVar(r, c) + ' = vec4(' + texel + '.xy, ' + nextTexel + '.xy);');
            if (c + 1 < cfg.filterWidth) {
                body.push(inputVar(r, c + 1) + ' = vec4(' + texel + '.zw, ' + nextTexel + '.zw);');
            }
        }
        // Rows outside the input leave the xR vec4s at their zero initial value.
        return ['if(' + rowGuard(cfg.xNumRows) + ') {'].concat(indent(body), ['}']);
    }
    /** Emits the weight loads for filter taps `c` and (if present) `c + 1` of row `r`. */
    function loadWeights(r, c, filterWidth) {
        var lines = [];
        var lastCol = Math.min(c + 1, filterWidth - 1);
        for (var col = c; col <= lastCol; col++) {
            var wTexel = 'wTexelR' + r + 'C' + col;
            lines.push('vec4 ' + wTexel + ' = getW(' + r + ', ' + col + ', d1, q);', weightVar(r, col) + ' = vec4(' + wTexel + '.xz, ' + wTexel + '.xz);');
        }
        return lines;
    }
    /**
     * Builds the GLSL source (`this.userCode`) of a packed depthwise 2D
     * convolution, together with the metadata fields `variableNames`,
     * `packedInputs`, `packedOutput` and `outputShape`.
     *
     * The generated code gathers, for every filter tap, the input values needed
     * by one output texel into a vec4 and multiplies it with the matching
     * weight vec4, trying to reuse every texel sampled through `getX`.
     *
     * Relies on `util.nearestLargerEven` from the enclosing module (only
     * reached when strideWidth is 1, padLeft is even and filterWidth > 1).
     *
     * @param {Object} convInfo Convolution description. Fields read:
     *     outShape, inHeight, inWidth, padInfo.top, padInfo.left, strideHeight,
     *     strideWidth, dilationHeight, dilationWidth, filterHeight, filterWidth.
     * @param {boolean} [addBias=false] When true, adds `getBiasAtOutCoords()`
     *     to the result and registers the `bias` input.
     * @param {?string} [activation=null] GLSL statements used as the body of a
     *     `vec4 activation(...)` function applied to the result.
     * @param {boolean} [hasPreluActivation=false] When true, registers the
     *     `preluActivationWeights` input and exposes it to `activation` as `b`.
     */
    function DepthwiseConvPacked2DProgram(convInfo, addBias, activation, hasPreluActivation) {
        if (addBias === void 0) { addBias = false; }
        if (activation === void 0) { activation = null; }
        if (hasPreluActivation === void 0) { hasPreluActivation = false; }
        this.variableNames = ['x', 'W'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = convInfo.outShape;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationHeight = convInfo.dilationHeight;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        var cfg = {
            xNumRows: convInfo.inHeight,
            xNumCols: convInfo.inWidth,
            padLeft: padLeft,
            strideWidth: strideWidth,
            dilationWidth: convInfo.dilationWidth,
            filterWidth: filterWidth
        };
        var r;
        var c;
        var mainLoop = ['int xR; int xC; int xCOffset;'];
        // The strided path always writes texel slot c + 2, so at least the
        // slots 0 and 2 must exist even for a filter that is one column wide.
        var texelSlots = Math.max(filterWidth, 2);
        for (r = 0; r < filterHeight; r++) {
            for (c = 0; c < texelSlots; c++) {
                mainLoop.push('vec4 ' + texelVar(r, c * 2) + ' = vec4(0.);');
            }
            for (c = 0; c < filterWidth; c++) {
                mainLoop.push('vec4 ' + weightVar(r, c) + ' = vec4(0.);', 'vec4 ' + inputVar(r, c) + ' = vec4(0.);');
            }
        }
        // Each iteration handles the filter taps c and c + 1 of one row.
        for (r = 0; r < filterHeight; r++) {
            for (c = 0; c < filterWidth; c += 2) {
                mainLoop.push('', 'xR = xRCorner + ' + r * dilationHeight + ';', 'xC = xCCorner + ' + c * cfg.dilationWidth + ';');
                mainLoop = mainLoop.concat(strideWidth === 1 ? gatherUnitStride(r, c, cfg) : gatherStrided(r, c, cfg), loadWeights(r, c, filterWidth));
            }
        }
        mainLoop.push('');
        for (r = 0; r < filterHeight; r++) {
            for (c = 0; c < filterWidth; c++) {
                mainLoop.push('dotProd += ' + inputVar(r, c) + ' * ' + weightVar(r, c) + ';');
            }
        }
        var activationSnippet = '';
        var applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
            }
            else {
                activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }";
            }
            applyActivationSnippet = 'result = activation(result);';
        }
        var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        this.userCode = [
            '',
            activationSnippet,
            '',
            'const ivec2 strides = ivec2(' + strideHeight + ', ' + strideWidth + ');',
            'const ivec2 pads = ivec2(' + padTop + ', ' + padLeft + ');',
            '',
            'void main() {',
            '  ivec4 coords = getOutputCoords();',
            '  int batch = coords.x;',
            '  ivec2 xRCCorner = coords.yz * strides - pads;',
            '  int d2 = coords.w;',
            '  int d1 = d2;',
            '  int q = 0;',
            '  int xRCorner = xRCCorner.x;',
            '  int xCCorner = xRCCorner.y;',
            '',
            '  vec4 dotProd = vec4(0.);',
            '',
            indent(mainLoop).join('\n'),
            '',
            '  vec4 result = dotProd;',
            '  ' + addBiasSnippet,
            '  ' + applyActivationSnippet,
            '  setOutput(result);',
            '}',
            ''
        ].join('\n');
    }
    return DepthwiseConvPacked2DProgram;
}());
|
// Expose the program constructor as a named property of this module's CommonJS exports object.
exports.DepthwiseConvPacked2DProgram = DepthwiseConvPacked2DProgram;
|
//# sourceMappingURL=conv_packed_gpu_depthwise.js.map
|