gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { useShapeUniforms } from './gpgpu_math';
import { getChannels } from './packing_util';
import { getCoordsDataType } from './shader_compiler';
export class PackProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        // Only input / output 3D tensors.
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        if (this.rank === 0) {
            this.userCode = `
        void main() {
          setOutput(vec4(getA(), 0., 0., 0.));
        }
      `;
        }
        else {
            const channels = getChannels('rc', this.rank);
            const dtype = getCoordsDataType(this.rank);
            const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);
            const setup = this.getSetup(channels);
            const output = this.getOutput(channels);
            this.userCode = `
        void main() {
          ${dtype} rc = getOutputCoords();
 
          if(${outOfBoundsCondition}) {
            setOutput(vec4(0));
          } else {
            ${setup}
 
            setOutput(vec4(${output}));
          }
        }
      `;
        }
    }
    getSourceCoordsArr(dims) {
        const coords = [];
        for (let row = 0; row <= 1; row++) {
            for (let col = 0; col <= 1; col++) {
                let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;
                for (let d = 2; d < this.rank; d++) {
                    coord = `${dims[dims.length - 1 - d]},` + coord;
                }
                coords.push(coord);
            }
        }
        return coords;
    }
    getOutOfBoundsCondition(dims) {
        if (this.rank === 1) {
            return `rc > ${this.enableShapeUniforms ? 'outShape' : this.outputShape[0]}`;
        }
        let cond = '';
        for (let i = this.rank - 2; i < this.rank; i++) {
            cond += `${dims[i]} >= ${this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i]}`;
            if (i < this.rank - 1) {
                cond += '||';
            }
        }
        return cond;
    }
    getSetup(dims) {
        if (this.rank === 1) {
            return '';
        }
        const innerDims = dims.slice(-2);
        const col = this.enableShapeUniforms ? `outShape[${this.rank} - 1]` :
            this.outputShape[this.rank - 1];
        const row = this.enableShapeUniforms ? `outShape[${this.rank} - 2]` :
            this.outputShape[this.rank - 2];
        return `
      int r = ${innerDims[0]};
      int c = ${innerDims[1]};
      int rp1 = r + 1;
      int cp1 = c + 1;
 
      bool cEdge = cp1 >= ${col};
      bool rEdge = rp1 >= ${row};
    `;
    }
    getOutput(dims) {
        const sourceCoords = this.getSourceCoordsArr(dims);
        if (this.rank === 1) {
            const outShape = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
            return `getA(rc), (rc + 1 >= ${outShape} ? 0. : getA(rc + 1)), 0, 0`;
        }
        return `getA(${sourceCoords[0]}),
            cEdge ? 0. : getA(${sourceCoords[1]}),
            rEdge ? 0. : getA(${sourceCoords[2]}),
            rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;
    }
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"pack_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/pack_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAe,gBAAgB,EAAC,MAAM,cAAc,CAAC;AAC5D,OAAO,EAAC,WAAW,EAAC,MAAM,gBAAgB,CAAC;AAC3C,OAAO,EAAC,iBAAiB,EAAC,MAAM,mBAAmB,CAAC;AAEpD,MAAM,OAAO,WAAW;IAStB,YACI,WACY;QAVhB,kBAAa,GAAG,CAAC,GAAG,CAAC,CAAC;QAGtB,iBAAY,GAAG,KAAK,CAAC;QACrB,iBAAY,GAAG,IAAI,CAAC;QAOC,kCAAkC;QACrD,IAAI,CAAC,WAAW,GAAG,WAAW,CAAC;QAC/B,IAAI,CAAC,IAAI,GAAG,WAAW,CAAC,MAAM,CAAC;QAC/B,IAAI,CAAC,mBAAmB,GAAG,gBAAgB,CAAC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,CAAC;QAErE,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,IAAI,CAAC,QAAQ,GAAG;;;;OAIf,CAAC;SACH;aAAM;YACL,MAAM,QAAQ,GAAG,WAAW,CAAC,IAAI,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC;YAC9C,MAAM,KAAK,GAAG,iBAAiB,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;YAC3C,MAAM,oBAAoB,GAAG,IAAI,CAAC,uBAAuB,CAAC,QAAQ,CAAC,CAAC;YACpE,MAAM,KAAK,GAAG,IAAI,CAAC,QAAQ,CAAC,QAAQ,CAAC,CAAC;YACtC,MAAM,MAAM,GAAG,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAC,CAAC;YAExC,IAAI,CAAC,QAAQ,GAAG;;YAEV,KAAK;;eAEF,oBAAoB;;;cAGrB,KAAK;;6BAEU,MAAM;;;OAG5B,CAAC;SACH;IACH,CAAC;IAEO,kBAAkB,CAAC,IAAc;QACvC,MAAM,MAAM,GAAG,EAAE,CAAC;QAElB,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,IAAI,CAAC,EAAE,GAAG,EAAE,EAAE;YACjC,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,IAAI,CAAC,EAAE,GAAG,EAAE,EAAE;gBACjC,IAAI,KAAK,GAAG,GAAG,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,KAAK,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC;gBAErE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE;oBAClC,KAAK,GAAG,GAAG,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,GAAG,KAAK,CAAC;iBACjD;gBAED,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;aACpB;SACF;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;IAEO,uBAAuB,CAAC,IAAc;QAC5C,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,OAAO,QACH,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC;SACnE;QAED,IAAI,IAAI,GAAG,EAAE,CAAC;QACd,KAAK,IAAI,CAAC,GAAG,IAAI,CAAC,IAAI,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE;YAC9C,IAAI,IAAI,GAAG,IAAI,CAAC,CAAC,CAAC,OACd,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC;YACxE,IAAI,CAAC,GAAG,IAAI,CAAC,IAAI,GAAG,CAAC,EAAE;gBACrB,IAAI,IAAI,IAAI,CAAC;aACd;SACF;QAED,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,QAAQ,CAAC,IAAc;QAC7B,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,OAAO,EAAE,CAAC;SACX;QAED,MAAM,SAAS,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;QACjC,MAAM,GAAG,GAAG,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,YAAY,IAAI,CAAC,IAAI,OAAO,CAAC,CAAC;YAC9B,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC;QACvE,MAAM,GAAG,GAAG,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,YAAY,IAAI,CAAC,IAAI,OAAO,CAAC,CAAC;YAC9B,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC;QAEvE,OAAO;gBACK,SAAS,CAAC,CAAC,CAAC;gBACZ,SAAS,CAAC,CAAC,CAAC;;;;4BAIA,GAAG;4BACH,GAAG;KAC1B,CAAC;IACJ,CAAC;IAEO,SAAS,CAAC,IAAc;QAC9B,MAAM,YAAY,GAAG,IAAI,CAAC,kBAAkB,CAAC,IAAI,CAAC,CAAC;QACnD,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,MAAM,QAAQ,GACV,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;YAChE,OAAO,wBAAwB,QAAQ,6BAA6B,CAAC;SACtE;QAED,OAAO,QAAQ,YAAY,CAAC,CAAC,CAAC;gCACF,YAAY,CAAC,CAAC,CAAC;gCACf,YAAY,CAAC,CAAC,CAAC;yCACN,YAAY,CAAC,CAAC,CAAC,GAAG,CAAC;IAC1D,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram, useShapeUniforms} from './gpgpu_math';\nimport {getChannels} from './packing_util';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class PackProgram implements GPGPUProgram {\n  variableNames = ['A'];\n  outputShape: number[];\n  userCode: string;\n  packedInputs = false;\n  packedOutput = true;\n  enableShapeUniforms: boolean;\n  rank: number;\n\n  constructor(\n      outputShape:\n          number[]) {  // TODO(https://github.com/tensorflow/tfjs/issues/893):\n                       // Only input / output 3D tensors.\n    this.outputShape = outputShape;\n    this.rank = outputShape.length;\n    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);\n\n    if (this.rank === 0) {\n      this.userCode = `\n        void main() {\n          setOutput(vec4(getA(), 0., 0., 0.));\n        }\n      `;\n    } else {\n      const channels = getChannels('rc', this.rank);\n      const dtype = getCoordsDataType(this.rank);\n      const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);\n      const setup = this.getSetup(channels);\n      const output = this.getOutput(channels);\n\n      this.userCode = `\n        void main() {\n          ${dtype} rc = getOutputCoords();\n\n          if(${outOfBoundsCondition}) {\n            setOutput(vec4(0));\n          } else {\n            ${setup}\n\n            setOutput(vec4(${output}));\n          }\n        }\n      `;\n    }\n  }\n\n  private getSourceCoordsArr(dims: string[]): string[] {\n    const coords = [];\n\n    for (let row = 0; row <= 1; row++) {\n      for (let col = 0; col <= 1; col++) {\n        let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;\n\n        for (let d = 2; d < this.rank; d++) {\n          coord = `${dims[dims.length - 1 - d]},` + coord;\n        }\n\n        coords.push(coord);\n      }\n    }\n    return coords;\n  }\n\n  private getOutOfBoundsCondition(dims: string[]): string {\n    if (this.rank === 1) {\n      return `rc > ${\n          this.enableShapeUniforms ? 'outShape' : this.outputShape[0]}`;\n    }\n\n    let cond = '';\n    for (let i = this.rank - 2; i < this.rank; i++) {\n      cond += `${dims[i]} >= ${\n          this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i]}`;\n      if (i < this.rank - 1) {\n        cond += '||';\n      }\n    }\n\n    return cond;\n  }\n\n  private getSetup(dims: string[]): string {\n    if (this.rank === 1) {\n      return '';\n    }\n\n    const innerDims = dims.slice(-2);\n    const col = this.enableShapeUniforms ? `outShape[${this.rank} - 1]` :\n                                           this.outputShape[this.rank - 1];\n    const row = this.enableShapeUniforms ? `outShape[${this.rank} - 2]` :\n                                           this.outputShape[this.rank - 2];\n\n    return `\n      int r = ${innerDims[0]};\n      int c = ${innerDims[1]};\n      int rp1 = r + 1;\n      int cp1 = c + 1;\n\n      bool cEdge = cp1 >= ${col};\n      bool rEdge = rp1 >= ${row};\n    `;\n  }\n\n  private getOutput(dims: string[]): string {\n    const sourceCoords = this.getSourceCoordsArr(dims);\n    if (this.rank === 1) {\n      const outShape =\n          this.enableShapeUniforms ? 'outShape' : this.outputShape[0];\n      return `getA(rc), (rc + 1 >= ${outShape} ? 0. : getA(rc + 1)), 0, 0`;\n    }\n\n    return `getA(${sourceCoords[0]}),\n            cEdge ? 0. : getA(${sourceCoords[1]}),\n            rEdge ? 0. : getA(${sourceCoords[2]}),\n            rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;\n  }\n}\n"]}