gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Shader program description that scales every value of a rank-4 input `x`
 * by `pow(bias + alpha * sum, -beta)`, where `sum` accumulates the squares of
 * the input over the depth window `[d - radius, d + radius]`, clipped to
 * `[0, xShape[3] - 1]`.
 *
 * The instance only carries data (`variableNames`, `outputShape`,
 * `packedInputs`, `packedOutput`, `userCode`); the GLSL in `userCode` relies
 * on `getOutputCoords`, `getX`, `getChannel` and `setOutput`, which are
 * provided outside this file.
 */
export class LRNPackedProgram {
    /**
     * @param {number[]} xShape Shape of the input; indices 2 and 3 are read
     *     as the column and depth extents. Also used as the output shape.
     * @param {number} radius Half-width of the depth window; inlined into the
     *     shader as an int literal.
     * @param {number} bias Constant term of the normalisation basis.
     * @param {number} alpha Multiplier applied to the sum of squares.
     * @param {number} beta Exponent; the basis is raised to `-beta`.
     */
    constructor(xShape, radius, bias, alpha, beta) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = xShape;
        const rad = radius;
        // Highest valid depth index, used to clip the summation window.
        const maxD = xShape[3] - 1;
        // optimize pow(bias + alpha * sum, -beta)
        // src: https://github.com/tensorflow/tensorflow/..
        // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
        // tensorflow/core/kernels/mkl_lrn_op.cc#L320
        const basis = `float(${bias}) + float(${alpha}) * sum`;
        let powOperator;
        if (beta === 0.5) {
            powOperator = `inversesqrt(${basis})`;
        }
        else if (beta === 1.0) {
            powOperator = `1.0/(${basis})`;
        }
        else {
            // Negate in JavaScript rather than prefixing '-' in the template:
            // a negative beta would otherwise render as `float(--0.75)`, which
            // GLSL parses as a decrement of a literal and rejects. The
            // fragment is a pure expression (no trailing ';') like the other
            // two branches; the statement terminator lives in the shader body.
            powOperator = `exp(log(${basis}) * float(${-beta}))`;
        }
        // NOTE(review): hasNextCol / hasNextRow below compare d / c (not
        // d + 1 / c + 1) against the extents - presumably always true for
        // in-range output coords; confirm against the packing helpers.
        this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords.x;
        int r = coords.y;
        int c = coords.z;
        int d = coords.w;
 
        bool hasNextCol = d < ${this.outputShape[3]};
        bool hasNextRow = c < ${this.outputShape[2]};
 
        vec4 sum = vec4(0.);
        vec4 xFragAtOutputCoords = getX(b, r, c, d);
 
        vec4 xAtOutputCoords = vec4(
          getChannel(xFragAtOutputCoords, vec2(c, d)),
          hasNextCol ?
            getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,
          hasNextRow ?
            getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,
          (hasNextRow && hasNextCol) ?
            getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0
        );
 
        int firstChannel = d - ${rad};
        vec2 cache = vec2(0.);
        if(firstChannel >= 0){
          vec4 firstChannelFrag = getX(b, r, c, firstChannel);
          cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));
            if(hasNextRow){
              cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));
            }
        }
 
        ivec2 depth = ivec2(d, d + 1);
        for (int j = - ${rad}; j <= ${rad}; j++) {
          ivec2 idx = depth + j;
          bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));
          bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));
 
          bool depthInRange = aboveLowerBound.x && belowUpperBound.x;
          bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;
 
          if(depthInRange || depthPlusOneInRange){
            vec4 z = vec4(0.);
            vec4 xFragAtCurrentDepth;
            z.xz = cache.xy;
            if(depthPlusOneInRange && hasNextCol){
              xFragAtCurrentDepth = idx.y != d ?
                getX(b, r, c, idx.y) : xFragAtOutputCoords;
              z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));
              if(hasNextRow){
                z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));
              }
            }
            cache.xy = z.yw;
            sum += z * z;
          }
        }
        vec4 result = xAtOutputCoords * ${powOperator};
        setOutput(result);
      }
    `;
    }
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibHJuX3BhY2tlZF9ncHUuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtd2ViZ2wvc3JjL2xybl9wYWNrZWRfZ3B1LnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUlILE1BQU0sT0FBTyxnQkFBZ0I7SUFPM0IsWUFDSSxNQUFnQixFQUFFLE1BQWMsRUFBRSxJQUFZLEVBQUUsS0FBYSxFQUM3RCxJQUFZO1FBUmhCLGtCQUFhLEdBQUcsQ0FBQyxHQUFHLENBQUMsQ0FBQztRQUN0QixnQkFBVyxHQUFhLEVBQUUsQ0FBQztRQUUzQixpQkFBWSxHQUFHLElBQUksQ0FBQztRQUNwQixpQkFBWSxHQUFHLElBQUksQ0FBQztRQUtsQixNQUFNLEdBQUcsR0FBRyxNQUFNLENBQUM7UUFDbkIsTUFBTSxJQUFJLEdBQUcsTUFBTSxDQUFDLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQztRQUMzQixJQUFJLENBQUMsV0FBVyxHQUFHLE1BQU0sQ0FBQztRQUUxQiwwQ0FBMEM7UUFDMUMsbURBQW1EO1FBQ25ELG1EQUFtRDtRQUNuRCw2Q0FBNkM7UUFDN0MsSUFBSSxXQUFXLENBQUM7UUFDaEIsTUFBTSxLQUFLLEdBQUcsU0FBUyxJQUFJLGFBQWEsS0FBSyxTQUFTLENBQUM7UUFDdkQsSUFBSSxJQUFJLEtBQUssR0FBRyxFQUFFO1lBQ2hCLFdBQVcsR0FBRyxlQUFlLEtBQUssR0FBRyxDQUFDO1NBQ3ZDO2FBQU0sSUFBSSxJQUFJLEtBQUssR0FBRyxFQUFFO1lBQ3ZCLFdBQVcsR0FBRyxRQUFRLEtBQUssR0FBRyxDQUFDO1NBQ2hDO2FBQU07WUFDTCxXQUFXLEdBQUcsV0FBVyxLQUFLLGNBQWMsSUFBSSxLQUFLLENBQUM7U0FDdkQ7UUFFRCxJQUFJLENBQUMsUUFBUSxHQUFHOzs7Ozs7OztnQ0FRWSxJQUFJLENBQUMsV0FBVyxDQUFDLENBQUMsQ0FBQztnQ0FDbkIsSUFBSSxDQUFDLFdBQVcsQ0FBQyxDQUFDLENBQUM7Ozs7Ozs7Ozs7Ozs7OztpQ0FlbEIsR0FBRzs7Ozs7Ozs7Ozs7eUJBV1gsR0FBRyxVQUFVLEdBQUc7Ozs2REFHb0IsSUFBSTs7Ozs7Ozs7Ozs7Ozs7Ozs7Ozs7OzBDQXFCdkIsV0FBVzs7O0tBR2hELENBQUM7SUFDSixDQUFDO0NBQ0YiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOSBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3
RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7R1BHUFVQcm9ncmFtfSBmcm9tICcuL2dwZ3B1X21hdGgnO1xuXG5leHBvcnQgY2xhc3MgTFJOUGFja2VkUHJvZ3JhbSBpbXBsZW1lbnRzIEdQR1BVUHJvZ3JhbSB7XG4gIHZhcmlhYmxlTmFtZXMgPSBbJ3gnXTtcbiAgb3V0cHV0U2hhcGU6IG51bWJlcltdID0gW107XG4gIHVzZXJDb2RlOiBzdHJpbmc7XG4gIHBhY2tlZElucHV0cyA9IHRydWU7XG4gIHBhY2tlZE91dHB1dCA9IHRydWU7XG5cbiAgY29uc3RydWN0b3IoXG4gICAgICB4U2hhcGU6IG51bWJlcltdLCByYWRpdXM6IG51bWJlciwgYmlhczogbnVtYmVyLCBhbHBoYTogbnVtYmVyLFxuICAgICAgYmV0YTogbnVtYmVyKSB7XG4gICAgY29uc3QgcmFkID0gcmFkaXVzO1xuICAgIGNvbnN0IG1heEQgPSB4U2hhcGVbM10gLSAxO1xuICAgIHRoaXMub3V0cHV0U2hhcGUgPSB4U2hhcGU7XG5cbiAgICAvLyBvcHRpbWl6ZSBwb3coYmlhcyArIGFscGhhICogc3VtLCAtYmV0YSlcbiAgICAvLyBzcmM6IGh0dHBzOi8vZ2l0aHViLmNvbS90ZW5zb3JmbG93L3RlbnNvcmZsb3cvLi5cbiAgICAvLyBibG9iLzI2MDMzYTE2NDRhOWM0YTVmYmUzMTcwYWIyZTg2NGI2YTRjY2Q0Y2EvLi5cbiAgICAvLyB0ZW5zb3JmbG93L2NvcmUva2VybmVscy9ta2xfbHJuX29wLmNjI0wzMjBcbiAgICBsZXQgcG93T3BlcmF0b3I7XG4gICAgY29uc3QgYmFzaXMgPSBgZmxvYXQoJHtiaWFzfSkgKyBmbG9hdCgke2FscGhhfSkgKiBzdW1gO1xuICAgIGlmIChiZXRhID09PSAwLjUpIHtcbiAgICAgIHBvd09wZXJhdG9yID0gYGludmVyc2VzcXJ0KCR7YmFzaXN9KWA7XG4gICAgfSBlbHNlIGlmIChiZXRhID09PSAxLjApIHtcbiAgICAgIHBvd09wZXJhdG9yID0gYDEuMC8oJHtiYXNpc30pYDtcbiAgICB9IGVsc2Uge1xuICAgICAgcG93T3BlcmF0b3IgPSBgZXhwKGxvZygke2Jhc2lzfSkgKiBmbG9hdCgtJHtiZXRhfSkpO2A7XG4gICAgfVxuXG4gICAgdGhpcy51c2VyQ29kZSA9IGBcbiAgICAgIHZvaWQgbWFpbigpIHtcbiAgICAgICAgaXZlYzQgY29vcmRzID0gZ2V0T3V0cHV0Q29vcmRzKCk7XG4gICAgICAgIGludCBiID0gY29vcmRzLng7XG4gICAgICAgIGludCByID0gY29vcmRzLnk7XG4gICAgICAgIGludCBjID0gY29vcmRzLno7XG4gICAgICAgIGludCBkID0gY29vcmRzLnc7XG5cbiAgICAgICAgYm
9vbCBoYXNOZXh0Q29sID0gZCA8ICR7dGhpcy5vdXRwdXRTaGFwZVszXX07XG4gICAgICAgIGJvb2wgaGFzTmV4dFJvdyA9IGMgPCAke3RoaXMub3V0cHV0U2hhcGVbMl19O1xuXG4gICAgICAgIHZlYzQgc3VtID0gdmVjNCgwLik7XG4gICAgICAgIHZlYzQgeEZyYWdBdE91dHB1dENvb3JkcyA9IGdldFgoYiwgciwgYywgZCk7XG5cbiAgICAgICAgdmVjNCB4QXRPdXRwdXRDb29yZHMgPSB2ZWM0KFxuICAgICAgICAgIGdldENoYW5uZWwoeEZyYWdBdE91dHB1dENvb3JkcywgdmVjMihjLCBkKSksXG4gICAgICAgICAgaGFzTmV4dENvbCA/XG4gICAgICAgICAgICBnZXRDaGFubmVsKHhGcmFnQXRPdXRwdXRDb29yZHMsIHZlYzIoYywgZCArIDEpKSA6IDAuMCxcbiAgICAgICAgICBoYXNOZXh0Um93ID9cbiAgICAgICAgICAgIGdldENoYW5uZWwoeEZyYWdBdE91dHB1dENvb3JkcyAsIHZlYzIoYyArIDEsIGQpKSA6IDAuMCxcbiAgICAgICAgICAoaGFzTmV4dFJvdyAmJiBoYXNOZXh0Q29sKSA/XG4gICAgICAgICAgICBnZXRDaGFubmVsKHhGcmFnQXRPdXRwdXRDb29yZHMsIHZlYzIoYyArIDEsIGQgKyAxKSkgOiAwLjBcbiAgICAgICAgKTtcblxuICAgICAgICBpbnQgZmlyc3RDaGFubmVsID0gZCAtICR7cmFkfTtcbiAgICAgICAgdmVjMiBjYWNoZSA9IHZlYzIoMC4pO1xuICAgICAgICBpZihmaXJzdENoYW5uZWwgPj0gMCl7XG4gICAgICAgICAgdmVjNCBmaXJzdENoYW5uZWxGcmFnID0gZ2V0WChiLCByLCBjLCBmaXJzdENoYW5uZWwpO1xuICAgICAgICAgIGNhY2hlLnggPSBnZXRDaGFubmVsKGZpcnN0Q2hhbm5lbEZyYWcsIHZlYzIoYywgZmlyc3RDaGFubmVsKSk7XG4gICAgICAgICAgICBpZihoYXNOZXh0Um93KXtcbiAgICAgICAgICAgICAgY2FjaGUueSA9IGdldENoYW5uZWwoZmlyc3RDaGFubmVsRnJhZywgdmVjMihjICsgMSwgZmlyc3RDaGFubmVsKSk7XG4gICAgICAgICAgICB9XG4gICAgICAgIH1cblxuICAgICAgICBpdmVjMiBkZXB0aCA9IGl2ZWMyKGQsIGQgKyAxKTtcbiAgICAgICAgZm9yIChpbnQgaiA9IC0gJHtyYWR9OyBqIDw9ICR7cmFkfTsgaisrKSB7XG4gICAgICAgICAgaXZlYzIgaWR4ID0gZGVwdGggKyBqO1xuICAgICAgICAgIGJ2ZWMyIGFib3ZlTG93ZXJCb3VuZCA9IGdyZWF0ZXJUaGFuRXF1YWwoaWR4LCBpdmVjMigwKSk7XG4gICAgICAgICAgYnZlYzIgYmVsb3dVcHBlckJvdW5kID0gbGVzc1RoYW5FcXVhbChpZHgsIGl2ZWMyKCR7bWF4RH0pKTtcblxuICAgICAgICAgIGJvb2wgZGVwdGhJblJhbmdlID0gYWJvdmVMb3dlckJvdW5kLnggJiYgYmVsb3dVcHBlckJvdW5kLng7XG4gICAgICAgICAgYm9vbCBkZXB0aFBsdXNPbmVJblJhbmdlID0gYWJvdmVMb3dlckJvdW5kLnkgJiYgYmVsb3dVcHBlckJvdW5kLnk7XG5cbiAgICAgICAgICBpZihkZXB0aEluUmFuZ2UgfHwgZGVwdGhQbHVzT25lSW5SYW5nZSl7XG4gICAgICAgICAgICB2ZWM0IHogPSB2ZWM0KDAuKTtcbiAgICAgICAgICAgIHZlYzQgeEZyYWdBdE
N1cnJlbnREZXB0aDtcbiAgICAgICAgICAgIHoueHogPSBjYWNoZS54eTtcbiAgICAgICAgICAgIGlmKGRlcHRoUGx1c09uZUluUmFuZ2UgJiYgaGFzTmV4dENvbCl7XG4gICAgICAgICAgICAgIHhGcmFnQXRDdXJyZW50RGVwdGggPSBpZHgueSAhPSBkID9cbiAgICAgICAgICAgICAgICBnZXRYKGIsIHIsIGMsIGlkeC55KSA6IHhGcmFnQXRPdXRwdXRDb29yZHM7XG4gICAgICAgICAgICAgIHoueSA9IGdldENoYW5uZWwoeEZyYWdBdEN1cnJlbnREZXB0aCwgdmVjMihjLCBpZHgueSkpO1xuICAgICAgICAgICAgICBpZihoYXNOZXh0Um93KXtcbiAgICAgICAgICAgICAgICB6LncgPSBnZXRDaGFubmVsKHhGcmFnQXRDdXJyZW50RGVwdGgsIHZlYzIoYyArIDEsIGlkeC55KSk7XG4gICAgICAgICAgICAgIH1cbiAgICAgICAgICAgIH1cbiAgICAgICAgICAgIGNhY2hlLnh5ID0gei55dztcbiAgICAgICAgICAgIHN1bSArPSB6ICogejtcbiAgICAgICAgICB9XG4gICAgICAgIH1cbiAgICAgICAgdmVjNCByZXN1bHQgPSB4QXRPdXRwdXRDb29yZHMgKiAke3Bvd09wZXJhdG9yfTtcbiAgICAgICAgc2V0T3V0cHV0KHJlc3VsdCk7XG4gICAgICB9XG4gICAgYDtcbiAgfVxufVxuIl19