gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { util } from '@tensorflow/tfjs-core';
import { getChannels } from './packing_util';
import { getCoordsDataType } from './shader_compiler';
/**
 * Program description for a packed arg-min / arg-max reduction over the last
 * dimension of the input `A`.
 *
 * The constructor only assembles metadata (`variableNames`, `outputShape`,
 * `packedInputs`, `packedOutput`) and the shader source (`userCode`); it does
 * not run anything itself.
 *
 * The generated `main()` handles four lanes (R, G, B, A) per invocation. Each
 * lane scans `windowSize` consecutive entries of the last source dimension,
 * keeps the best value according to `greaterThan` / `lessThan`, never replaces
 * the current best with a NaN candidate, and writes the winning index with
 * `setOutput`.
 */
export class ArgMinMaxPackedProgram {
    /**
     * @param shape Input shape; must have more than two dimensions, otherwise
     *     `util.assert` is handed a failing condition.
     * @param windowSize Number of last-dimension entries folded into one
     *     output entry; also the trip count of the generated loop.
     * @param op `'max'` selects `greaterThan`; any other value selects
     *     `lessThan`.
     * @param firstPass When falsy, a second input `bestIndicesA` is declared
     *     and candidate indices are read from it instead of being used
     *     directly.
     */
    constructor(shape, windowSize, op, firstPass) {
        this.variableNames = firstPass ? ['A'] : ['A', 'bestIndicesA'];
        this.packedInputs = true;
        this.packedOutput = true;
        util.assert(shape.length > 2, () => {
            // Built lazily so the message is only formatted on failure.
            const capitalizedOp = op.charAt(0).toUpperCase() + op.slice(1);
            return `Packed arg${capitalizedOp} supports only inputs with rank above 2.`;
        });

        // Output keeps every leading dimension; the reduced dimension is only
        // kept while more than one window remains.
        const lastDimSize = shape[shape.length - 1];
        const windowCount = Math.ceil(lastDimSize / windowSize);
        const leadingDims = shape.slice(0, -1);
        this.outputShape =
            windowCount > 1 ? [...leadingDims, windowCount] : leadingDims;

        const outShape = this.outputShape;
        const rank = outShape.length;
        const dtype = getCoordsDataType(rank);
        const coords = getChannels('coords', rank);
        // The shader treats the last output coordinate as the column and the
        // one before it as the row (see hasNextCol / hasNextRow below).
        const colCoord = coords[rank - 1];
        const rowCoord = coords[rank - 2];

        // When the reduced dimension collapses completely, the source has one
        // more dimension than the output, so every source location gets a
        // trailing 0 appended; otherwise output coords are used as-is.
        const collapsesLastDim = windowCount === 1;
        const sourceRank = collapsesLastDim ? rank + 1 : rank;
        const locType = collapsesLastDim ? getCoordsDataType(sourceRank) : dtype;
        const locInit =
            collapsesLastDim ? `${locType}(${coords.join()}, 0)` : 'coords';
        // Visits the 2x2 neighbourhood in the order R, G, A, B and restores
        // `coords` to its starting value at the end.
        const sourceLocSetup = `
        ${locType} sourceLocR = ${locInit};
        ++${colCoord};
        ${locType} sourceLocG = ${locInit};
        ++${rowCoord};
        ${locType} sourceLocA = ${locInit};
        --${colCoord};
        ${locType} sourceLocB = ${locInit};
        --${rowCoord};`;

        const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
        // Accessor for the last source dimension, e.g. ".z" when sourceRank is 3.
        const inChannel = `.${channels[sourceRank - 1]}`;
        const intParams = channels.map((name) => `int ${name}`).join();
        const channelArgs = channels.join();
        const lastTwoChannels = channels.slice(-2).join();

        // Argument list for sampling one lane: the leading source coordinates
        // of that lane followed by the lane's running index.
        const laneArgs = (locName, idxLane) =>
            getChannels(locName, sourceRank - 1).concat(idxLane).join();
        const argsR = laneArgs('sourceLocR', 'inIdx.r');
        const argsG = laneArgs('sourceLocG', 'inIdx.g');
        const argsB = laneArgs('sourceLocB', 'inIdx.b');
        const argsA = laneArgs('sourceLocA', 'inIdx.a');

        let compOp = 'lessThan';
        if (op === 'max') {
            compOp = 'greaterThan';
        }

        // Only later passes translate window positions through bestIndicesA.
        let fetchCandidateIdx = '';
        let bestIndicesHelper = '';
        if (!firstPass) {
            fetchCandidateIdx = `
          inIdx = round(vec4(getBestIndicesAChannel(${argsR}),
                             getBestIndicesAChannel(${argsG}),
                             getBestIndicesAChannel(${argsB}),
                             getBestIndicesAChannel(${argsA})));`;
            bestIndicesHelper = `
      float getBestIndicesAChannel(${intParams}) {
        return getChannel(getBestIndicesA(${channelArgs}),
                                          vec2(${lastTwoChannels}));
      }`;
        }

        // Lanes that fall outside the output read 0. instead of sampling.
        const fetchValue = `vec4(
            getAChannel(${argsR}),
            hasNextCol ? getAChannel(${argsG}) : 0.,
            hasNextRow ? getAChannel(${argsB}) : 0.,
            hasNextRow && hasNextCol ? getAChannel(${argsA}) : 0.)`;

        this.userCode = `
      float getAChannel(${intParams}) {
        return getChannel(getA(${channelArgs}),
                               vec2(${lastTwoChannels}));
      }
      ${bestIndicesHelper}
      void main() {
        ${dtype} coords = getOutputCoords();
        bool hasNextCol = ${colCoord} < ${outShape[rank - 1] - 1};
        bool hasNextRow = ${rowCoord} < ${outShape[rank - 2] - 1};
        ${sourceLocSetup}
        ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
          sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
        ivec4 inIdx = srcIdx;
        vec4 bestIndex = vec4(inIdx);
        vec4 bestValue = ${fetchValue};
 
        for (int i = 0; i < ${windowSize}; i++) {
          inIdx = srcIdx;
          ${fetchCandidateIdx}
          vec4 candidate = ${fetchValue};
          bvec4 nan = isnan(candidate);
          bvec4 replace = bvec4(
            vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));
 
          bestValue = vec4(replace.x  ? candidate.x : bestValue.x,
                           replace.y  ? candidate.y : bestValue.y,
                           replace.z  ? candidate.z : bestValue.z,
                           replace.w  ? candidate.w : bestValue.w);
          bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
          srcIdx++;
        }
        setOutput(bestIndex);
      }
    `;
    }
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"argminmax_packed_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/argminmax_packed_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAG3C,OAAO,EAAC,WAAW,EAAC,MAAM,gBAAgB,CAAC;AAC3C,OAAO,EAAC,iBAAiB,EAAC,MAAM,mBAAmB,CAAC;AAEpD,MAAM,OAAO,sBAAsB;IAOjC,YACI,KAAe,EAAE,UAAkB,EAAE,EAAe,EACpD,SAAkB;QARtB,kBAAa,GAAG,CAAC,GAAG,CAAC,CAAC;QAGtB,iBAAY,GAAG,IAAI,CAAC;QACpB,iBAAY,GAAG,IAAI,CAAC;QAKlB,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,MAAM,GAAG,CAAC,EAChB,GAAG,EAAE,CAAC,aACF,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,WAAW,EAAE;YAC1B,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,0CAA0C,CAAC,CAAC;QAC/D,MAAM,MAAM,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACvC,MAAM,OAAO,GAAG,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,UAAU,CAAC,CAAC;QAC/C,IAAI,CAAC,WAAW,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,OAAO,GAAG,CAAC,EAAE;YACf,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SAChC;QACD,IAAI,CAAC,SAAS,EAAE;YACd,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;SACzC;QACD,MAAM,QAAQ,GAAG,IAAI,CAAC,WAAW,CAAC;QAClC,MAAM,IAAI,GAAG,QAAQ,CAAC,MAAM,CAAC;QAC7B,MAAM,KAAK,GAAG,iBAAiB,CAAC,IAAI,CAAC,CAAC;QACtC,MAAM,MAAM,GAAG,WAAW,CAAC,QAAQ,EAAE,IAAI,CAAC,CAAC;QAE3C,IAAI,cAAc,CAAC;QACnB,IAAI,UAAU,CAAC;QACf,IAAI,OAAO,KAAK,CAAC,EAAE;YACjB,UAAU,GAAG,IAAI,GAAG,CAAC,CAAC;YACtB,MAAM,cAAc,GAAG,iBAAiB,CAAC,UAAU,CAAC,CAAC;YACrD,cAAc,GAAG;UACb,cAAc,iBAAiB,cAAc,IAAI,MAAM,CAAC,IAAI,EAAE;YAC5D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,cAAc,iBAAiB,cAAc,IAAI,MAAM,CAAC,IAAI,EAAE;YAC5D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,cAAc,iBAAiB,cAAc,IAAI,MAAM,CAAC,IAAI,EAAE;YAC5D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,cAAc,iBAAiB,cAAc,IAAI,MAAM,CAAC,IAAI,EAAE;YAC5D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,CAAC;SAC3B;aAAM;YACL,UAAU,GAAG,IAAI,CAAC;YAClB,cAAc,GAAG;UACb,KAAK;YACH,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,KAAK;YACH,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,KAAK;YACH,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,KAAK;YACH,MAAM,CAAC,IAAI
,GAAG,CAAC,CAAC,GAAG,CAAC;SAC3B;QACD,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC;QACrE,MAAM,SAAS,GAAG,GAAG,GAAG,QAAQ,CAAC,UAAU,GAAG,CAAC,CAAC,CAAC,CAAE,wBAAwB;QAC3E,MAAM,WAAW,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QAClD,MAAM,UAAU,GACZ,WAAW,CAAC,YAAY,EAAE,UAAU,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;QAChE,MAAM,UAAU,GACZ,WAAW,CAAC,YAAY,EAAE,UAAU,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;QAChE,MAAM,UAAU,GACZ,WAAW,CAAC,YAAY,EAAE,UAAU,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;QAChE,MAAM,UAAU,GACZ,WAAW,CAAC,YAAY,EAAE,UAAU,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;QAEhE,MAAM,MAAM,GAAG,CAAC,EAAE,KAAK,KAAK,CAAC,CAAC,CAAC,CAAC,aAAa,CAAC,CAAC,CAAC,UAAU,CAAC;QAC3D,MAAM,iBAAiB,GAAG,SAAS,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;sDACO,UAAU,CAAC,IAAI,EAAE;sDACjB,UAAU,CAAC,IAAI,EAAE;sDACjB,UAAU,CAAC,IAAI,EAAE;sDACjB,UAAU,CAAC,IAAI,EAAE,MAAM,CAAC;QAE1E,MAAM,UAAU,GAAG;0BACG,UAAU,CAAC,IAAI,EAAE;uCACJ,UAAU,CAAC,IAAI,EAAE;uCACjB,UAAU,CAAC,IAAI,EAAE;qDACH,UAAU,CAAC,IAAI,EAAE,SAAS,CAAC;QAE5E,MAAM,6BAA6B,GAAG,SAAS,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;qCACtB,WAAW,CAAC,IAAI,EAAE;4CACX,QAAQ,CAAC,IAAI,EAAE;iDACV,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE;QAClE,CAAC;QAEL,IAAI,CAAC,QAAQ,GAAG;0BACM,WAAW,CAAC,IAAI,EAAE;iCACX,QAAQ,CAAC,IAAI,EAAE;sCACV,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE;;QAEvD,6BAA6B;;UAE3B,KAAK;4BACa,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,MAAM,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,CAAC;4BAC5C,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,MAAM,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,CAAC;UAC9D,cAAc;yCACiB,SAAS,eAAe,SAAS;sBACpD,SAAS,eAAe,SAAS,OAAO,UAAU;;;2BAG7C,UAAU;;8BAEP,UAAU;;YAE5B,iBAAiB;6BACA,UAAU;;;mBAGpB,MAAM;;;;;;;;;;;KAWpB,CAAC;IACJ,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {util} from '@tensorflow/tfjs-core';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getChannels} from './packing_util';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class ArgMinMaxPackedProgram implements GPGPUProgram {\n  variableNames = ['A'];\n  outputShape: number[];\n  userCode: string;\n  packedInputs = true;\n  packedOutput = true;\n\n  constructor(\n      shape: number[], windowSize: number, op: 'max'|'min',\n      firstPass: boolean) {\n    util.assert(\n        shape.length > 2,\n        () => `Packed arg${\n            op.charAt(0).toUpperCase() +\n            op.slice(1)} supports only inputs with rank above 2.`);\n    const inSize = shape[shape.length - 1];\n    const outSize = Math.ceil(inSize / windowSize);\n    this.outputShape = shape.slice(0, -1);\n    if (outSize > 1) {\n      this.outputShape.push(outSize);\n    }\n    if (!firstPass) {\n      this.variableNames.push('bestIndicesA');\n    }\n    const outShape = this.outputShape;\n    const rank = outShape.length;\n    const dtype = getCoordsDataType(rank);\n    const coords = getChannels('coords', rank);\n\n    let sourceLocSetup;\n    let sourceRank;\n    if (outSize === 1) {\n      sourceRank = rank + 1;\n      const sourceLocDType = getCoordsDataType(sourceRank);\n      sourceLocSetup = `\n     
   ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);\n        ++${coords[rank - 1]};\n        ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);\n        ++${coords[rank - 2]};\n        ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);\n        --${coords[rank - 1]};\n        ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);\n        --${coords[rank - 2]};`;\n    } else {\n      sourceRank = rank;\n      sourceLocSetup = `\n        ${dtype} sourceLocR = coords;\n        ++${coords[rank - 1]};\n        ${dtype} sourceLocG = coords;\n        ++${coords[rank - 2]};\n        ${dtype} sourceLocA = coords;\n        --${coords[rank - 1]};\n        ${dtype} sourceLocB = coords;\n        --${coords[rank - 2]};`;\n    }\n    const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);\n    const inChannel = '.' + channels[sourceRank - 1];  // e.g. \".b\" for rank 3.\n    const intChannels = channels.map(x => 'int ' + x);\n    const srcRCoords =\n        getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');\n    const srcGCoords =\n        getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');\n    const srcBCoords =\n        getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');\n    const srcACoords =\n        getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');\n\n    const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';\n    const fetchCandidateIdx = firstPass ? '' : `\n          inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),\n                             getBestIndicesAChannel(${srcGCoords.join()}),\n                             getBestIndicesAChannel(${srcBCoords.join()}),\n                             getBestIndicesAChannel(${srcACoords.join()})));`;\n\n    const fetchValue = `vec4(\n            getAChannel(${srcRCoords.join()}),\n            hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,\n            hasNextRow ? 
getAChannel(${srcBCoords.join()}) : 0.,\n            hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;\n\n    const getBestIndicesAChannelSnippet = firstPass ? '' : `\n      float getBestIndicesAChannel(${intChannels.join()}) {\n        return getChannel(getBestIndicesA(${channels.join()}),\n                                          vec2(${channels.slice(-2).join()}));\n      }`;\n\n    this.userCode = `\n      float getAChannel(${intChannels.join()}) {\n        return getChannel(getA(${channels.join()}),\n                               vec2(${channels.slice(-2).join()}));\n      }\n      ${getBestIndicesAChannelSnippet}\n      void main() {\n        ${dtype} coords = getOutputCoords();\n        bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};\n        bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};\n        ${sourceLocSetup}\n        ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},\n          sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};\n        ivec4 inIdx = srcIdx;\n        vec4 bestIndex = vec4(inIdx);\n        vec4 bestValue = ${fetchValue};\n\n        for (int i = 0; i < ${windowSize}; i++) {\n          inIdx = srcIdx;\n          ${fetchCandidateIdx}\n          vec4 candidate = ${fetchValue};\n          bvec4 nan = isnan(candidate);\n          bvec4 replace = bvec4(\n            vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n          bestValue = vec4(replace.x  ? candidate.x : bestValue.x,\n                           replace.y  ? candidate.y : bestValue.y,\n                           replace.z  ? candidate.z : bestValue.z,\n                           replace.w  ? candidate.w : bestValue.w);\n          bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n          srcIdx++;\n        }\n        setOutput(bestIndex);\n      }\n    `;\n  }\n}\n"]}