gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { useShapeUniforms } from './gpgpu_math';
import * as shader_util from './shader_compiler_util';
/**
 * Shader program description that reads a packed input `A` and writes a
 * packed output laid out according to `outputShape`.
 *
 * The generated `main()` fills the four components of the output `vec4` one
 * at a time. Component `i` corresponds to the output coordinate returned by
 * `getOutputCoords()` with `z` advanced by one when `i` is odd and `y`
 * advanced by one when `i > 1`. Each such coordinate is flattened with
 * `getFlatIndex`, mapped back to input coordinates with
 * `inputCoordsFromReshapedOutCoords`, and the matching channel of `A` is read.
 */
export class ReshapePackedProgram {
    /**
     * @param {number[]} outputShape Shape of the reshaped output; entries 1 and
     *     2 are used as the row/column bounds, and the shader works with
     *     `ivec3` coordinates.
     * @param {number[]} inputShape Shape of the input, forwarded to
     *     `getReshapedInputCoords` to emit the flat-index-to-coordinate helper.
     */
    constructor(outputShape, inputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }];
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        const usesUniforms = this.enableShapeUniforms;
        // One GLSL snippet per component of the output vec4.
        const channelSnippets = [0, 1, 2, 3].map((channel) => {
            const coordSetup = [
                `thisRC = rc;`,
                channel % 2 === 1 ? `thisRC.z += 1;` : ``,
                channel > 1 ? `thisRC.y += 1;` : ``,
            ].join('');
            // Component 0 is the coordinate itself and is read unguarded; the
            // shifted neighbours are only read when inside rows/cols.
            const isShifted = channel > 0;
            const guardOpen = isShifted ? `if(thisRC.y < rows && thisRC.z < cols){` : '';
            const guardClose = isShifted ? '}' : '';
            return `
        ${coordSetup}
        ${guardOpen}
          int flatIndex = getFlatIndex(thisRC);
 
          ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
          vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
 
          result[${channel}] =
            getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
        ${guardClose}
      `;
        });
        const mainLoop = channelSnippets.join('');
        const inputCoordsFn = getReshapedInputCoords(inputShape, usesUniforms);
        const flatIndexFn = usesUniforms ?
            shader_util.getFlatIndexFrom3DOutput() :
            shader_util.getFlatIndexFrom3D(outputShape);
        // With shape uniforms the bounds are read from `outShape` inside the
        // shader; otherwise they are baked in as literals.
        const rowBound = usesUniforms ? 'outShape[1]' : outputShape[1];
        const colBound = usesUniforms ? 'outShape[2]' : outputShape[2];
        this.userCode = `
      ${inputCoordsFn}
      ${flatIndexFn}
 
      void main() {
        ivec3 rc = getOutputCoords();
 
        vec4 result = vec4(0.);
 
        ivec3 thisRC;
        int rows = ${rowBound};
        int cols = ${colBound};
 
        ${mainLoop}
 
        setOutput(result);
      }
    `;
    }
}
/**
 * Emits the GLSL function `inputCoordsFromReshapedOutCoords(int index)`, which
 * returns `ivec3(r, c, d)` computed from a flat index.
 *
 * @param {number[]} shape Input shape, used only when shape uniforms are
 *     disabled (passed to `getLogicalCoordinatesFromFlatIndex`).
 * @param {boolean} enableShapeUniforms When true, the coordinate snippet is
 *     generated against the `inputShape` uniform instead of `shape`.
 * @returns {string} GLSL source for the helper function.
 */
function getReshapedInputCoords(shape, enableShapeUniforms) {
    const coordNames = ['r', 'c', 'd'];
    let indexToCoords;
    if (enableShapeUniforms) {
        indexToCoords = shader_util.getLogicalCoordinatesFromFlatIndexByUniform(coordNames, 'inputShape');
    }
    else {
        indexToCoords = shader_util.getLogicalCoordinatesFromFlatIndex(coordNames, shape);
    }
    return `
    ivec3 inputCoordsFromReshapedOutCoords(int index) {
      ${indexToCoords}
      return ivec3(r, c, d);
    }
  `;
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoicmVzaGFwZV9wYWNrZWRfZ3B1LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9yZXNoYXBlX3BhY2tlZF9ncHUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFlLGdCQUFnQixFQUFDLE1BQU0sY0FBYyxDQUFDO0FBQzVELE9BQU8sS0FBSyxXQUFXLE1BQU0sd0JBQXdCLENBQUM7QUFFdEQsTUFBTSxPQUFPLG9CQUFvQjtJQVMvQixZQUFZLFdBQXFDLEVBQUUsVUFFbEQ7UUFWRCxrQkFBYSxHQUFHLENBQUMsR0FBRyxDQUFDLENBQUM7UUFDdEIsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFDcEIsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFJcEIsbUJBQWMsR0FBRyxDQUFDLEVBQUMsSUFBSSxFQUFFLFlBQVksRUFBRSxJQUFJLEVBQUUsT0FBZ0IsRUFBRSxDQUFDLENBQUM7UUFLL0QsSUFBSSxDQUFDLFdBQVcsR0FBRyxXQUFXLENBQUM7UUFDL0IsSUFBSSxDQUFDLG1CQUFtQixHQUFHLGdCQUFnQixDQUFDLElBQUksQ0FBQyxXQUFXLENBQUMsTUFBTSxDQUFDLENBQUM7UUFFckUsSUFBSSxRQUFRLEdBQUcsRUFBRSxDQUFDO1FBQ2xCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxFQUFFLEVBQUU7WUFDMUIsSUFBSSxNQUFNLEdBQUcsY0FBYyxDQUFDO1lBQzVCLElBQUksQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLEVBQUU7Z0JBQ2YsTUFBTSxJQUFJLGdCQUFnQixDQUFDO2FBQzVCO1lBQ0QsSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFO2dCQUNULE1BQU0sSUFBSSxnQkFBZ0IsQ0FBQzthQUM1QjtZQUVELFFBQVEsSUFBSTtVQUNSLE1BQU07VUFDTixDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsQ0FBQyx5Q0FBeUMsQ0FBQyxDQUFDLENBQUMsRUFBRTs7Ozs7O21CQU03QyxDQUFDOztVQUVWLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRTtPQUNuQixDQUFDO1NBQ0g7UUFFRCxJQUFJLENBQUMsUUFBUSxHQUFHO1FBQ1osc0JBQXNCLENBQUMsVUFBVSxFQUFFLElBQUksQ0FBQyxtQkFBbUIsQ0FBQztRQUU1RCxJQUFJLENBQUMsbUJBQW1CLENBQUMsQ0FBQyxDQUFDLFdBQVcsQ0FBQyx3QkFBd0IsRUFBRSxDQUFDLENBQUM7WUFDeEMsV0FBVyxDQUFDLGtCQUFrQixDQUFDLFdBQVcsQ0FBQzs7Ozs7Ozs7cUJBUXpELElBQUksQ0FBQyxtQkFBbUIsQ0FBQyxDQUFDLENBQUMsYUFBYSxDQUFDLENBQUMsQ0FBQyxXQUFXLENBQUMsQ0FBQyxDQUFDO3FCQUN6RCxJQUFJLENBQUMsbUJBQW1CLENBQUMsQ0FBQyxDQUFDLGFBQWEsQ0FBQyxDQUFDLENBQUMsV0FBVyxDQUFDLENBQUMsQ0FBQzs7VUFFcEUsUUFBUTs7OztLQUliLENBQUM7SUFDSixDQUFDO0NBQ0Y7QUFFRCxTQUFTLHNCQUFzQixDQUMzQixLQUErQixFQU
FFLG1CQUE0QjtJQUMvRCxNQUFNLHNCQUFzQixHQUFHLG1CQUFtQixDQUFDLENBQUM7UUFDaEQsV0FBVyxDQUFDLDJDQUEyQyxDQUNuRCxDQUFDLEdBQUcsRUFBRSxHQUFHLEVBQUUsR0FBRyxDQUFDLEVBQUUsWUFBWSxDQUFDLENBQUMsQ0FBQztRQUNwQyxXQUFXLENBQUMsa0NBQWtDLENBQUMsQ0FBQyxHQUFHLEVBQUUsR0FBRyxFQUFFLEdBQUcsQ0FBQyxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBRTNFLE9BQU87O1FBRUQsc0JBQXNCOzs7R0FHM0IsQ0FBQztBQUNKLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7R1BHUFVQcm9ncmFtLCB1c2VTaGFwZVVuaWZvcm1zfSBmcm9tICcuL2dwZ3B1X21hdGgnO1xuaW1wb3J0ICogYXMgc2hhZGVyX3V0aWwgZnJvbSAnLi9zaGFkZXJfY29tcGlsZXJfdXRpbCc7XG5cbmV4cG9ydCBjbGFzcyBSZXNoYXBlUGFja2VkUHJvZ3JhbSBpbXBsZW1lbnRzIEdQR1BVUHJvZ3JhbSB7XG4gIHZhcmlhYmxlTmFtZXMgPSBbJ0EnXTtcbiAgcGFja2VkSW5wdXRzID0gdHJ1ZTtcbiAgcGFja2VkT3V0cHV0ID0gdHJ1ZTtcbiAgb3V0cHV0U2hhcGU6IG51bWJlcltdO1xuICB1c2VyQ29kZTogc3RyaW5nO1xuICBlbmFibGVTaGFwZVVuaWZvcm1zOiBib29sZWFuO1xuICBjdXN0b21Vbmlmb3JtcyA9IFt7bmFtZTogJ2lucHV0U2hhcGUnLCB0eXBlOiAnaXZlYzMnIGFzIGNvbnN0IH1dO1xuXG4gIGNvbnN0cnVjdG9yKG91dHB1dFNoYXBlOiBbbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIGlucHV0U2hhcGU6IFtcbiAgICBudW1iZXIsIG
51bWJlciwgbnVtYmVyXG4gIF0pIHtcbiAgICB0aGlzLm91dHB1dFNoYXBlID0gb3V0cHV0U2hhcGU7XG4gICAgdGhpcy5lbmFibGVTaGFwZVVuaWZvcm1zID0gdXNlU2hhcGVVbmlmb3Jtcyh0aGlzLm91dHB1dFNoYXBlLmxlbmd0aCk7XG5cbiAgICBsZXQgbWFpbkxvb3AgPSBgYDtcbiAgICBmb3IgKGxldCBpID0gMDsgaSA8IDQ7IGkrKykge1xuICAgICAgbGV0IHRoaXNSQyA9IGB0aGlzUkMgPSByYztgO1xuICAgICAgaWYgKGkgJSAyID09PSAxKSB7XG4gICAgICAgIHRoaXNSQyArPSBgdGhpc1JDLnogKz0gMTtgO1xuICAgICAgfVxuICAgICAgaWYgKGkgPiAxKSB7XG4gICAgICAgIHRoaXNSQyArPSBgdGhpc1JDLnkgKz0gMTtgO1xuICAgICAgfVxuXG4gICAgICBtYWluTG9vcCArPSBgXG4gICAgICAgICR7dGhpc1JDfVxuICAgICAgICAke2kgPiAwID8gYGlmKHRoaXNSQy55IDwgcm93cyAmJiB0aGlzUkMueiA8IGNvbHMpe2AgOiAnJ31cbiAgICAgICAgICBpbnQgZmxhdEluZGV4ID0gZ2V0RmxhdEluZGV4KHRoaXNSQyk7XG5cbiAgICAgICAgICBpdmVjMyBpbnB1dFJDID0gaW5wdXRDb29yZHNGcm9tUmVzaGFwZWRPdXRDb29yZHMoZmxhdEluZGV4KTtcbiAgICAgICAgICB2ZWMyIGlucHV0UkNJbm5lckRpbXMgPSB2ZWMyKGZsb2F0KGlucHV0UkMueSksZmxvYXQoaW5wdXRSQy56KSk7XG5cbiAgICAgICAgICByZXN1bHRbJHtpfV0gPVxuICAgICAgICAgICAgZ2V0Q2hhbm5lbChnZXRBKGlucHV0UkMueCwgaW5wdXRSQy55LCBpbnB1dFJDLnopLCBpbnB1dFJDSW5uZXJEaW1zKTtcbiAgICAgICAgJHtpID4gMCA/ICd9JyA6ICcnfVxuICAgICAgYDtcbiAgICB9XG5cbiAgICB0aGlzLnVzZXJDb2RlID0gYFxuICAgICAgJHtnZXRSZXNoYXBlZElucHV0Q29vcmRzKGlucHV0U2hhcGUsIHRoaXMuZW5hYmxlU2hhcGVVbmlmb3Jtcyl9XG4gICAgICAke1xuICAgICAgICB0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPyBzaGFkZXJfdXRpbC5nZXRGbGF0SW5kZXhGcm9tM0RPdXRwdXQoKSA6XG4gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHNoYWRlcl91dGlsLmdldEZsYXRJbmRleEZyb20zRChvdXRwdXRTaGFwZSl9XG5cbiAgICAgIHZvaWQgbWFpbigpIHtcbiAgICAgICAgaXZlYzMgcmMgPSBnZXRPdXRwdXRDb29yZHMoKTtcblxuICAgICAgICB2ZWM0IHJlc3VsdCA9IHZlYzQoMC4pO1xuXG4gICAgICAgIGl2ZWMzIHRoaXNSQztcbiAgICAgICAgaW50IHJvd3MgPSAke3RoaXMuZW5hYmxlU2hhcGVVbmlmb3JtcyA/ICdvdXRTaGFwZVsxXScgOiBvdXRwdXRTaGFwZVsxXX07XG4gICAgICAgIGludCBjb2xzID0gJHt0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPyAnb3V0U2hhcGVbMl0nIDogb3V0cHV0U2hhcGVbMl19O1xuXG4gICAgICAgICR7bWFpbkxvb3B9XG5cbiAgICAgICAgc2V0T3V0cHV0KHJlc3VsdCk7XG4gICAgICB9XG4gICAgYDtcbiAgfVxufVxuXG5mdW5jdGlvbiBnZXRSZXNoYXBlZElucHV0Q29vcmRzKFxuIC
AgIHNoYXBlOiBbbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIGVuYWJsZVNoYXBlVW5pZm9ybXM6IGJvb2xlYW4pOiBzdHJpbmcge1xuICBjb25zdCBjb29yZHNGcm9tSW5kZXhTbmlwcGV0ID0gZW5hYmxlU2hhcGVVbmlmb3JtcyA/XG4gICAgICBzaGFkZXJfdXRpbC5nZXRMb2dpY2FsQ29vcmRpbmF0ZXNGcm9tRmxhdEluZGV4QnlVbmlmb3JtKFxuICAgICAgICAgIFsncicsICdjJywgJ2QnXSwgJ2lucHV0U2hhcGUnKSA6XG4gICAgICBzaGFkZXJfdXRpbC5nZXRMb2dpY2FsQ29vcmRpbmF0ZXNGcm9tRmxhdEluZGV4KFsncicsICdjJywgJ2QnXSwgc2hhcGUpO1xuXG4gIHJldHVybiBgXG4gICAgaXZlYzMgaW5wdXRDb29yZHNGcm9tUmVzaGFwZWRPdXRDb29yZHMoaW50IGluZGV4KSB7XG4gICAgICAke2Nvb3Jkc0Zyb21JbmRleFNuaXBwZXR9XG4gICAgICByZXR1cm4gaXZlYzMociwgYywgZCk7XG4gICAgfVxuICBgO1xufVxuIl19