"use strict";
|
/**
|
* @license
|
* Copyright 2017 Google Inc. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
Object.defineProperty(exports, "__esModule", { value: true });
|
var DepthwiseConv2DProgram = /** @class */ (function () {
    /**
     * Builds the GLSL `activation` function definition that the shader calls
     * on its result, or returns an empty string when no activation is given.
     *
     * With PReLU the generated function takes its input as `a` and reads the
     * slope `b` via getPreluActivationWeightsAtOutCoords(); otherwise it
     * takes its input as `x`. In both cases `activation` is spliced in
     * verbatim as the function body.
     */
    function buildActivationSnippet(activation, hasPreluActivation) {
        if (!activation) {
            return '';
        }
        if (hasPreluActivation) {
            return [
                'float activation(float a) {',
                ' float b = getPreluActivationWeightsAtOutCoords();',
                ' ' + activation,
                ' }'
            ].join('\n');
        }
        return [
            '',
            ' float activation(float x) {',
            ' ' + activation,
            ' }',
            ' '
        ].join('\n');
    }
    /**
     * GPGPU program description for a depthwise 2D convolution: output
     * channel d2 is computed from input channel d2 / channelMul (where
     * channelMul = outChannels / inChannels) convolved with filter slice
     * W(:, :, d1, q). Optionally adds a bias and applies a fused activation.
     *
     * @param convInfo Convolution geometry; this constructor reads outShape,
     *     inHeight, inWidth, padInfo.top/left, strideHeight/Width,
     *     dilationHeight/Width, filterHeight/Width, inChannels and outChannels.
     * @param addBias If true (default false), registers a 'bias' input and
     *     adds getBiasAtOutCoords() to each result.
     * @param activation GLSL body for the fused activation function, or null
     *     (default) for no activation.
     * @param hasPreluActivation If true (default false), registers a
     *     'preluActivationWeights' input and uses the PReLU-style activation
     *     signature. NOTE(review): the input is registered even when
     *     `activation` is null.
     */
    function DepthwiseConv2DProgram(convInfo, addBias, activation, hasPreluActivation) {
        if (typeof addBias === 'undefined') {
            addBias = false;
        }
        if (typeof activation === 'undefined') {
            activation = null;
        }
        if (typeof hasPreluActivation === 'undefined') {
            hasPreluActivation = false;
        }
        // Input order matters: it must match the order callers pass tensors.
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        var inRows = convInfo.inHeight;
        var inCols = convInfo.inWidth;
        var pad = convInfo.padInfo;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var dilationH = convInfo.dilationHeight;
        var dilationW = convInfo.dilationWidth;
        var filterH = convInfo.filterHeight;
        var filterW = convInfo.filterWidth;
        // Number of output channels produced per input channel. It is spliced
        // into integer GLSL arithmetic, so it is presumably always a whole
        // number; the value is not validated here.
        var channelMul = convInfo.outChannels / convInfo.inChannels;
        var activationFn = buildActivationSnippet(activation, hasPreluActivation);
        var activationCall = activation ? 'result = activation(result);' : '';
        var biasAdd = addBias ? 'result += getBiasAtOutCoords();' : '';
        // One array entry per shader source line; out-of-bounds taps (padding
        // region) are skipped rather than sampled.
        this.userCode = [
            '',
            ' ' + activationFn,
            '',
            ' const ivec2 strides = ivec2(' + strideH + ', ' + strideW + ');',
            ' const ivec2 pads = ivec2(' + pad.top + ', ' + pad.left + ');',
            '',
            ' void main() {',
            ' ivec4 coords = getOutputCoords();',
            ' int batch = coords.x;',
            ' ivec2 xRCCorner = coords.yz * strides - pads;',
            ' int d2 = coords.w;',
            ' int d1 = d2 / ' + channelMul + ';',
            ' int q = d2 - d1 * ' + channelMul + ';',
            '',
            ' int xRCorner = xRCCorner.x;',
            ' int xCCorner = xRCCorner.y;',
            '',
            ' // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).',
            ' // ? = to be determined. : = across all values in that axis.',
            ' float dotProd = 0.0;',
            ' // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.',
            ' for (int wR = 0; wR < ' + filterH + '; wR++) {',
            ' int xR = xRCorner + wR * ' + dilationH + ';',
            '',
            ' if (xR < 0 || xR >= ' + inRows + ') {',
            ' continue;',
            ' }',
            '',
            ' for (int wC = 0; wC < ' + filterW + '; wC++) {',
            ' int xC = xCCorner + wC * ' + dilationW + ';',
            '',
            ' if (xC < 0 || xC >= ' + inCols + ') {',
            ' continue;',
            ' }',
            '',
            ' float xVal = getX(batch, xR, xC, d1);',
            ' float wVal = getW(wR, wC, d1, q);',
            ' dotProd += xVal * wVal;',
            ' }',
            ' }',
            '',
            ' float result = dotProd;',
            ' ' + biasAdd,
            ' ' + activationCall,
            ' setOutput(result);',
            ' }',
            ' '
        ].join('\n');
    }
    return DepthwiseConv2DProgram;
}());
|
exports.DepthwiseConv2DProgram = DepthwiseConv2DProgram;
|
//# sourceMappingURL=conv_gpu_depthwise.js.map
|