gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"use strict";
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
Object.defineProperty(exports, "__esModule", { value: true });
var Conv2DProgram = /** @class */ (function () {
    function Conv2DProgram(convInfo, addBias, activation, hasPreluActivationWeights) {
        if (addBias === void 0) { addBias = false; }
        if (activation === void 0) { activation = null; }
        if (hasPreluActivationWeights === void 0) { hasPreluActivationWeights = false; }
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
        var inputDepthVec4Remainder = convInfo.inChannels % 4;
        var isChannelsLast = convInfo.dataFormat === 'channelsLast';
        var rowDim = isChannelsLast ? 1 : 2;
        var colDim = isChannelsLast ? 2 : 3;
        var channelDim = isChannelsLast ? 3 : 1;
        var activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivationWeights) {
                activationSnippet = "float activation(float a) {\n          float b = getPreluActivationWeightsAtOutCoords();\n          " + activation + "\n        }";
            }
            else {
                activationSnippet = "\n          float activation(float x) {\n            " + activation + "\n          }\n        ";
            }
            applyActivationSnippet = "result = activation(result);";
        }
        var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivationWeights) {
            this.variableNames.push('preluActivationWeights');
        }
        this.userCode = "\n      " + activationSnippet + "\n\n      const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n      const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n      void main() {\n        ivec4 coords = getOutputCoords();\n        int batch = coords[0];\n        int d2 = coords[" + channelDim + "];\n\n        ivec2 xRCCorner =\n            ivec2(coords[" + rowDim + "], coords[" + colDim + "]) * strides - pads;\n        int xRCorner = xRCCorner.x;\n        int xCCorner = xRCCorner.y;\n\n        // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n        // ? = to be determined. : = across all values in that axis.\n        float dotProd = 0.0;\n        for (int wR = 0; wR < " + filterHeight + "; wR++) {\n          int xR = xRCorner + wR * " + dilationHeight + ";\n\n          if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n            continue;\n          }\n\n          for (int wC = 0; wC < " + filterWidth + "; wC++) {\n            int xC = xCCorner + wC * " + dilationWidth + ";\n\n            if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n              continue;\n            }\n\n            for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n              vec4 wValues = vec4(\n                getW(wR, wC, d1, d2),\n                getW(wR, wC, d1 + 1, d2),\n                getW(wR, wC, d1 + 2, d2),\n                getW(wR, wC, d1 + 3, d2)\n              );\n\n              if (" + isChannelsLast + ") {\n                vec4 xValues = vec4(\n                  getX(batch, xR, xC, d1),\n                  getX(batch, xR, xC, d1 + 1),\n                  getX(batch, xR, xC, d1 + 2),\n                  getX(batch, xR, xC, d1 + 3)\n                );\n                dotProd += dot(xValues, wValues);\n              } else {\n                vec4 xValues = vec4(\n                  getX(batch, d1, xR, xC),\n                  getX(batch, d1 + 1, xR, xC),\n                  getX(batch, d1 + 2, xR, xC),\n                  getX(batch, d1 + 3, xR, xC)\n                );\n                dotProd += dot(xValues, wValues);\n              }\n            }\n\n            if (" + (inputDepthVec4Remainder === 1) + ") {\n\n              if (" + isChannelsLast + ") {\n                dotProd +=\n                    getX(batch, xR, xC, " + inputDepthNearestVec4 + ") *\n                    getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n              } else {\n                dotProd +=\n                    getX(batch, " + inputDepthNearestVec4 + ", xR, xC) *\n                    getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n              }\n\n            } else if (" + (inputDepthVec4Remainder === 2) + ") {\n              vec2 wValues = vec2(\n                getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n                getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n              );\n\n              if (" + isChannelsLast + ") {\n                vec2 xValues = vec2(\n                  getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n                  getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1)\n                );\n                dotProd += dot(xValues, wValues);\n              } else {\n                vec2 xValues = vec2(\n                  getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n                  getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC)\n                );\n                dotProd += dot(xValues, wValues);\n              }\n\n            } else if (" + (inputDepthVec4Remainder === 3) + ") {\n              vec3 wValues = vec3(\n                getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n                getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n                getW(wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n              );\n\n              if (" + isChannelsLast + ") {\n                vec3 xValues = vec3(\n                  getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n                  getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1),\n                  getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 2)\n                );\n                dotProd += dot(xValues, wValues);\n              } else {\n                vec3 xValues = vec3(\n                  getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n                  getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC),\n                  getX(batch, " + inputDepthNearestVec4 + " + 2, xR, xC)\n                );\n                dotProd += dot(xValues, wValues);\n              }\n\n            }\n          }\n        }\n\n        float result = dotProd;\n        " + addBiasSnippet + "\n        " + applyActivationSnippet + "\n        setOutput(result);\n      }\n    ";
    }
    return Conv2DProgram;
}());
exports.Conv2DProgram = Conv2DProgram;
var Conv3DProgram = /** @class */ (function () {
    function Conv3DProgram(convInfo) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        var padFront = convInfo.padInfo.front;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        var strideDepth = convInfo.strideDepth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationDepth = convInfo.dilationDepth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var filterDepth = convInfo.filterDepth;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
        var inputDepthVec4Remainder = convInfo.inChannels % 4;
        this.userCode = "\n      const ivec3 strides = ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n      const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n      void main() {\n        ivec5 coords = getOutputCoords();\n        int batch = coords.x;\n        int d2 = coords.u;\n\n        ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n        int xFCorner = xFRCCorner.x;\n        int xRCorner = xFRCCorner.y;\n        int xCCorner = xFRCCorner.z;\n\n        // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n        // y(yF, yR, yC, d2). ? = to be determined. : = across all\n        // values in that axis.\n        float dotProd = 0.0;\n        for (int wF = 0; wF < " + filterDepth + "; wF++) {\n          int xF = xFCorner + wF * " + dilationDepth + ";\n\n          if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n            continue;\n          }\n\n          for (int wR = 0; wR < " + filterHeight + "; wR++) {\n            int xR = xRCorner + wR * " + dilationHeight + ";\n\n            if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n              continue;\n            }\n\n            for (int wC = 0; wC < " + filterWidth + "; wC++) {\n              int xC = xCCorner + wC * " + dilationWidth + ";\n\n              if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n                continue;\n              }\n\n              for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n                vec4 xValues = vec4(\n                  getX(batch, xF, xR, xC, d1),\n                  getX(batch, xF, xR, xC, d1 + 1),\n                  getX(batch, xF, xR, xC, d1 + 2),\n                  getX(batch, xF, xR, xC, d1 + 3)\n                );\n                vec4 wValues = vec4(\n                  getW(wF, wR, wC, d1, d2),\n                  getW(wF, wR, wC, d1 + 1, d2),\n                  getW(wF, wR, wC, d1 + 2, d2),\n                  getW(wF, wR, wC, d1 + 3, d2)\n                );\n\n                dotProd += dot(xValues, wValues);\n              }\n\n              if (" + (inputDepthVec4Remainder === 1) + ") {\n                dotProd +=\n                  getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + ") *\n                  getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2);\n              } else if (" + (inputDepthVec4Remainder === 2) + ") {\n                vec2 xValues = vec2(\n                  getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n                  getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1)\n                );\n                vec2 wValues = vec2(\n                  getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n                  getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n                );\n                dotProd += dot(xValues, wValues);\n              } else if (" + (inputDepthVec4Remainder === 3) + ") {\n                vec3 xValues = vec3(\n                  getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n                  getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1),\n                  getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 2)\n                );\n                vec3 wValues = vec3(\n                  getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n                  getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n                  getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n                );\n                dotProd += dot(xValues, wValues);\n              }\n            }\n          }\n        }\n        setOutput(dotProd);\n      }\n    ";
    }
    return Conv3DProgram;
}());
exports.Conv3DProgram = Conv3DProgram;
//# sourceMappingURL=conv_gpu.js.map