"use strict";
|
/**
|
* @license
|
* Copyright 2019 Google LLC All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
// Define the `__esModule` flag (value `true`) on this module's exports object.
Object.defineProperty(exports, "__esModule", { value: true });
|
var LRNPackedProgram = /** @class */ (function () {
    /**
     * Builds the GLSL expression equivalent to `pow(bias + alpha * sum, -beta)`.
     *
     * Two exponents get a cheaper closed form (same optimisation as the
     * TensorFlow MKL LRN kernel):
     *   - beta === 0.5 -> inversesqrt(basis)
     *   - beta === 1.0 -> 1.0/(basis)
     * Every other exponent falls back to exp(log(basis) * -beta).
     *
     * src: https://github.com/tensorflow/tensorflow/..
     * blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
     * tensorflow/core/kernels/mkl_lrn_op.cc#L320
     *
     * @param {number} bias  Additive term of the basis.
     * @param {number} alpha Scale applied to the sum of squares.
     * @param {number} beta  Exponent; the emitted expression raises to -beta.
     * @returns {string} GLSL expression WITHOUT a trailing semicolon.
     */
    function buildPowExpression(bias, alpha, beta) {
        var basis = "float(" + bias + ") + float(" + alpha + ") * sum";
        if (beta === 0.5) {
            return "inversesqrt(" + basis + ")";
        }
        if (beta === 1.0) {
            return "1.0/(" + basis + ")";
        }
        // Negate in JavaScript rather than emitting a literal "-" in front of
        // the number: a negative beta would otherwise produce `float(--0.5)`,
        // which GLSL parses as a decrement operator and rejects.
        return "exp(log(" + basis + ") * float(" + (-beta) + "))";
    }
    /**
     * Packed shader program that scales every input value by
     * `pow(bias + alpha * sum, -beta)`, where `sum` accumulates the squares of
     * the values read along the last axis at offsets -radius..radius
     * (clamped to the valid index range 0..xShape[3] - 1).
     *
     * @param {number[]} xShape Shape of input `x`; indices 2 and 3 are read,
     *     and the array itself becomes `outputShape`.
     * @param {number} radius Half-width of the window along the last axis.
     * @param {number} bias   See buildPowExpression.
     * @param {number} alpha  See buildPowExpression.
     * @param {number} beta   See buildPowExpression.
     */
    function LRNPackedProgram(xShape, radius, bias, alpha, beta) {
        this.variableNames = ['x'];
        this.outputShape = xShape;
        this.packedInputs = true;
        this.packedOutput = true;
        var rad = radius;
        // Largest valid index along the last axis.
        var maxD = xShape[3] - 1;
        var powOperator = buildPowExpression(bias, alpha, beta);
        // The shader template below is kept verbatim; only the interpolated
        // values (shape bounds, radius, maxD, powOperator) vary per instance.
        this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < " + this.outputShape[3] + ";\n bool hasNextRow = c < " + this.outputShape[2] + ";\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - " + rad + ";\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - " + rad + "; j <= " + rad + "; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(" + maxD + "));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * " + powOperator + ";\n setOutput(result);\n }\n ";
    }
    return LRNPackedProgram;
}());
|
// Expose the LRNPackedProgram constructor on this module's exports object.
exports.LRNPackedProgram = LRNPackedProgram;
|
//# sourceMappingURL=lrn_packed_gpu.js.map
|