/**
|
* @license
|
* Copyright 2019 Google LLC All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
|
import {GPGPUProgram} from './gpgpu_math';
|
|
/**
 * Packed WebGL program for local response normalization (LRN) along the
 * depth axis (index 3) of the input `x`.
 *
 * Each output value is
 *   x[b, r, c, d] * pow(bias + alpha * sum, -beta)
 * where `sum` is the sum of squares of x[b, r, c, k] for k in
 * [d - radius, d + radius], restricted to the channel range [0, xShape[3] - 1].
 *
 * Inputs and outputs are packed. Each shader invocation writes one texel whose
 * components correspond to (c, d), (c, d + 1), (c + 1, d) and (c + 1, d + 1);
 * see how `xAtOutputCoords` is assembled in the shader.
 */
export class LRNPackedProgram implements GPGPUProgram {
  variableNames = ['x'];
  outputShape: number[] = [];
  userCode: string;
  packedInputs = true;
  packedOutput = true;

  /**
   * @param xShape Shape of the input `x`; index 3 is the depth (channel) axis.
   *     The output has the same shape.
   * @param radius Half-width of the normalization window along depth.
   * @param bias Additive offset inside the normalizer.
   * @param alpha Scale applied to the windowed sum of squares.
   * @param beta Exponent; the normalizer is raised to `-beta`.
   */
  constructor(
      xShape: number[], radius: number, bias: number, alpha: number,
      beta: number) {
    const rad = radius;
    const maxD = xShape[3] - 1;
    this.outputShape = xShape;

    // optimize pow(bias + alpha * sum, -beta)
    // src: https://github.com/tensorflow/tensorflow/..
    // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
    // tensorflow/core/kernels/mkl_lrn_op.cc#L320
    const basis = `float(${bias}) + float(${alpha}) * sum`;
    let powOperator: string;
    if (beta === 0.5) {
      powOperator = `inversesqrt(${basis})`;
    } else if (beta === 1.0) {
      powOperator = `1.0/(${basis})`;
    } else {
      // Negate on the JS side: interpolating into `float(-${beta})` emitted
      // `float(--0.5)` for a negative beta, which GLSL parses as a decrement
      // of a constant and rejects. No trailing semicolon here, because this
      // expression is spliced into a statement that already ends with one.
      powOperator = `exp(log(${basis}) * float(${-beta}))`;
    }

    this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords.x;
        int r = coords.y;
        int c = coords.z;
        int d = coords.w;

        // NOTE(review): c and d are output coordinates, so as written both
        // flags are always true. The depth loop relies on z.y and z.w being
        // read on every in-range iteration to refill the sliding cache, so do
        // not tighten hasNextCol without reworking that cache.
        bool hasNextCol = d < ${this.outputShape[3]};
        bool hasNextRow = c < ${this.outputShape[2]};

        vec4 sum = vec4(0.);
        vec4 xFragAtOutputCoords = getX(b, r, c, d);

        // Components: x = (c, d), y = (c, d + 1), z = (c + 1, d),
        // w = (c + 1, d + 1).
        vec4 xAtOutputCoords = vec4(
          getChannel(xFragAtOutputCoords, vec2(c, d)),
          hasNextCol ?
            getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,
          hasNextRow ?
            getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,
          (hasNextRow && hasNextCol) ?
            getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0
        );

        // Sliding window over depth. On iteration j, z.xz holds channel d + j
        // (rows c and c + 1), carried over in cache from the previous
        // iteration's z.yw, while z.yw holds channel d + 1 + j. The cache is
        // seeded here with channel d - radius when that channel exists.
        int firstChannel = d - ${rad};
        vec2 cache = vec2(0.);
        if(firstChannel >= 0){
          vec4 firstChannelFrag = getX(b, r, c, firstChannel);
          cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));
          if(hasNextRow){
            cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));
          }
        }

        ivec2 depth = ivec2(d, d + 1);
        for (int j = - ${rad}; j <= ${rad}; j++) {
          ivec2 idx = depth + j;
          bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));
          bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));

          bool depthInRange = aboveLowerBound.x && belowUpperBound.x;
          bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;

          if(depthInRange || depthPlusOneInRange){
            vec4 z = vec4(0.);
            vec4 xFragAtCurrentDepth;
            z.xz = cache.xy;
            if(depthPlusOneInRange && hasNextCol){
              xFragAtCurrentDepth = idx.y != d ?
                getX(b, r, c, idx.y) : xFragAtOutputCoords;
              z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));
              if(hasNextRow){
                z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));
              }
            }
            cache.xy = z.yw;
            sum += z * z;
          }
        }
        vec4 result = xAtOutputCoords * ${powOperator};
        setOutput(result);
      }
    `;
  }
}
|