gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var shader_compiler_1 = require("./shader_compiler");
var GatherNDProgram = /** @class */ (function () {
    function GatherNDProgram(sliceDim, strides, shape) {
        this.sliceDim = sliceDim;
        this.strides = strides;
        this.variableNames = ['x', 'indices'];
        this.outputShape = shape;
        var stridesType = shader_compiler_1.getCoordsDataType(strides.length);
        var dtype = shader_compiler_1.getCoordsDataType(shape.length);
        var strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';
        this.userCode = "\n        " + stridesType + " strides = " + stridesType + "(" + this.strides + ");\n         void main() {\n          " + dtype + " coords = getOutputCoords();\n          int flattenIndex = 0;\n          for (int j = 0; j < " + this.sliceDim + "; j++) {\n            int index = round(getIndices(coords[0], j));\n            flattenIndex += index * " + strideString + ";\n          }\n          setOutput(getX(flattenIndex, coords[1]));\n        }\n      ";
    }
    return GatherNDProgram;
}());
exports.GatherNDProgram = GatherNDProgram;
//# sourceMappingURL=gather_nd_gpu.js.map