gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { getGlslDifferences } from './glsl_version';
import { useShapeUniforms } from './gpgpu_math';
import * as shader_util from './shader_compiler_util';
/*
This is how the shader encodes a tensor with shape = [2, 3, 5]
(indices are [batch, row, col]).
 
000|001   002|003   004|xxx   020|021   022|023   024|xxx
-------   -------   -------   -------   -------   -------
010|011   012|013   014|xxx   xxx|xxx   xxx|xxx   xxx|xxx
 
100|101   102|103   104|xxx   120|121   122|123   124|xxx
-------   -------   -------   -------   -------   -------
110|111   112|113   114|xxx   xxx|xxx   xxx|xxx   xxx|xxx
 
Single texels contain only values from the same batch, and from adjacent rows
and columns.
 */
export class EncodeMatrixPackedProgram {
    /**
     * Builds the GLSL source (`userCode`) that reads values from the input
     * sampler `A` and writes them into a packed output, where each output
     * texel holds up to four values taken from two adjacent rows and two
     * adjacent columns.
     *
     * For each of the four output channels (channel = rowOffset * 2 +
     * colOffset, offsets in {0, 1}) the generated code:
     *   1. skips the channel when the offset coordinate falls outside
     *      dimension 2 (columns) or dimension 1 (rows) of the output shape;
     *   2. converts the coordinate to a flat index via `getFlatIndex`;
     *   3. uses `flatIndex mod 4` to pick the component of the source texel
     *      and `flatIndex / 4` to locate that texel through the `texShape`
     *      uniform;
     *   4. copies the selected component into `result[channel]`.
     * Channels that are skipped keep the initial value 0.
     *
     * @param {number[]} outputShape Shape of the output; indices 1 and 2 are
     *     used as the row and column bounds, and the whole array is handed to
     *     `shader_util.getFlatIndexFrom3D` when shape uniforms are disabled.
     * @param {boolean} [inputIsUnsignedByte=false] When truthy, the written
     *     value is `floor(result * 255. + 0.5)` instead of `result`.
     */
    constructor(outputShape, inputIsUnsignedByte = false) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
        const glsl = getGlslDifferences();
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        // Expression assigned to the shader output at the end of main().
        const output = inputIsUnsignedByte ? 'floor(result * 255. + 0.5)' : 'result';
        // Bounds are either read from the `outShape` uniform at run time or
        // baked into the shader text as literals; they do not vary per channel.
        const colLimit = this.enableShapeUniforms ? 'outShape[2]' : `${outputShape[2]}`;
        const rowLimit = this.enableShapeUniforms ? 'outShape[1]' : `${outputShape[1]}`;
        // One GLSL snippet per output channel, emitted in channel order 0..3,
        // i.e. (row, col) = (0,0), (0,1), (1,0), (1,1).
        const channelSnippets = [];
        for (let channel = 0; channel < 4; channel++) {
            const row = Math.floor(channel / 2);
            const col = channel % 2;
            channelSnippets.push(`
          localCoords = coords;
          if(localCoords[2] + ${col} < ${colLimit}) {
          localCoords[2] += ${col};
          if (localCoords[1] + ${row} < ${rowLimit}) {
            localCoords[1] += ${row};
 
            flatIndex = getFlatIndex(localCoords);
            offset = imod(flatIndex, 4);
 
            flatIndex = idiv(flatIndex, 4, 1.);
 
            int r = flatIndex / texShape[1];
            int c = imod(flatIndex, texShape[1]);
            vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
            values = ${glsl.texture2D}(A, uv);
 
            if (offset == 0) {
              result[${channel}] = values[0];
            } else if (offset == 1) {
              result[${channel}] = values[1];
            } else if (offset == 2) {
              result[${channel}] = values[2];
            } else {
              result[${channel}] = values[3];
            }
          }
        }
        `);
        }
        const mainLoop = channelSnippets.join('');
        // GLSL definition of getFlatIndex(): uniform-based or shape-specialized.
        const flatIndexSource = this.enableShapeUniforms ?
            shader_util.getFlatIndexFrom3DOutput() :
            shader_util.getFlatIndexFrom3D(outputShape);
        this.userCode = `
        ${flatIndexSource}
 
        void main() {
          ivec3 coords = getOutputCoords();
 
          vec4 result = vec4(0.);
          int flatIndex, r, c, offset;
          ivec3 localCoords;
          vec2 uv;
          vec4 values;
 
          ${mainLoop}
 
          ${glsl.output} = ${output};
        }
    `;
    }
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZW5jb2RlX21hdHJpeF9wYWNrZWRfZ3B1LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9lbmNvZGVfbWF0cml4X3BhY2tlZF9ncHUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLGtCQUFrQixFQUFDLE1BQU0sZ0JBQWdCLENBQUM7QUFDbEQsT0FBTyxFQUFlLGdCQUFnQixFQUFDLE1BQU0sY0FBYyxDQUFDO0FBQzVELE9BQU8sS0FBSyxXQUFXLE1BQU0sd0JBQXdCLENBQUM7QUFFdEQ7Ozs7Ozs7Ozs7Ozs7O0dBY0c7QUFFSCxNQUFNLE9BQU8seUJBQXlCO0lBU3BDLFlBQ0ksV0FBcUMsRUFBRSxtQkFBbUIsR0FBRyxLQUFLO1FBVHRFLGtCQUFhLEdBQUcsQ0FBQyxHQUFHLENBQUMsQ0FBQztRQUd0QixpQkFBWSxHQUFHLEtBQUssQ0FBQztRQUNyQixpQkFBWSxHQUFHLElBQUksQ0FBQztRQUVwQixtQkFBYyxHQUFHLENBQUMsRUFBQyxJQUFJLEVBQUUsVUFBVSxFQUFFLElBQUksRUFBRSxPQUFnQixFQUFFLENBQUMsQ0FBQztRQUk3RCxNQUFNLElBQUksR0FBRyxrQkFBa0IsRUFBRSxDQUFDO1FBQ2xDLElBQUksQ0FBQyxXQUFXLEdBQUcsV0FBVyxDQUFDO1FBQy9CLElBQUksQ0FBQyxtQkFBbUIsR0FBRyxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsV0FBVyxDQUFDLE1BQU0sQ0FBQyxDQUFDO1FBRXJFLElBQUksUUFBUSxHQUFHLEVBQUUsQ0FBQztRQUNsQixJQUFJLE1BQU0sR0FBRyxRQUFRLENBQUM7UUFDdEIsSUFBSSxtQkFBbUIsRUFBRTtZQUN2QixNQUFNLEdBQUcsNEJBQTRCLENBQUM7U0FDdkM7UUFFRCxLQUFLLElBQUksR0FBRyxHQUFHLENBQUMsRUFBRSxHQUFHLElBQUksQ0FBQyxFQUFFLEdBQUcsRUFBRSxFQUFFO1lBQ2pDLEtBQUssSUFBSSxHQUFHLEdBQUcsQ0FBQyxFQUFFLEdBQUcsSUFBSSxDQUFDLEVBQUUsR0FBRyxFQUFFLEVBQUU7Z0JBQ2pDLE1BQU0sT0FBTyxHQUFHLEdBQUcsR0FBRyxDQUFDLEdBQUcsR0FBRyxDQUFDO2dCQUU5QixRQUFRLElBQUk7O2dDQUVZLEdBQUcsTUFDdkIsSUFBSSxDQUFDLG1CQUFtQixDQUFDLENBQUMsQ0FBQyxhQUFhLENBQUMsQ0FBQyxDQUFDLEdBQUcsV0FBVyxDQUFDLENBQUMsQ0FBQyxFQUFFOzhCQUM1QyxHQUFHO2lDQUNBLEdBQUcsTUFDeEIsSUFBSSxDQUFDLG1CQUFtQixDQUFDLENBQUMsQ0FBQyxhQUFhLENBQUMsQ0FBQyxDQUFDLEdBQUcsV0FBVyxDQUFDLENBQUMsQ0FBQyxFQUFFO2dDQUMxQyxHQUFHOzs7Ozs7Ozs7O3VCQVVaLElBQUksQ0FBQyxTQUFTOzs7dUJBR2QsT0FBTzs7dUJBRVAsT0FBTzs7dUJBRVAsT0FBTzs7dUJBRVAsT0FBTzs7OztTQUlyQixDQUFDO2FBQ0g7U0FDRjtRQUVELElBQUksQ0FBQyxRQUFRLEdBQUc7VUFFWixJQUFJLENBQUMsbUJBQW1CLENBQUMsQ0FBQyxDQUFDLFdBQVcsQ0FBQyx3QkFBd0IsRUFBRSxDQU
FDLENBQUM7WUFDeEMsV0FBVyxDQUFDLGtCQUFrQixDQUFDLFdBQVcsQ0FBQzs7Ozs7Ozs7Ozs7WUFXbEUsUUFBUTs7WUFFUixJQUFJLENBQUMsTUFBTSxNQUFNLE1BQU07O0tBRTlCLENBQUM7SUFDSixDQUFDO0NBQ0YiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7Z2V0R2xzbERpZmZlcmVuY2VzfSBmcm9tICcuL2dsc2xfdmVyc2lvbic7XG5pbXBvcnQge0dQR1BVUHJvZ3JhbSwgdXNlU2hhcGVVbmlmb3Jtc30gZnJvbSAnLi9ncGdwdV9tYXRoJztcbmltcG9ydCAqIGFzIHNoYWRlcl91dGlsIGZyb20gJy4vc2hhZGVyX2NvbXBpbGVyX3V0aWwnO1xuXG4vKlxuVGhpcyBpcyBob3cgdGhlIHNoYWRlciBlbmNvZGVzIGEgdGVuc29yIHdpdGggc2hhcGUgPSBbMiwgMywgNV1cbihpbmRpY2VzIGFyZSBbYmF0Y2gsIHJvdywgY29sXSkuXG5cbjAwMHwwMDEgICAwMDJ8MDAzICAgMDA0fHh4eCAgIDAyMHwwMjEgICAwMjJ8MDIzICAgMDI0fHh4eFxuLS0tLS0tLSAgIC0tLS0tLS0gICAtLS0tLS0tICAgLS0tLS0tLSAgIC0tLS0tLS0gICAtLS0tLS0tXG4wMTB8MDExICAgMDEyfDAxMyAgIDAxNHx4eHggICB4eHh8eHh4ICAgeHh4fHh4eCAgIHh4eHx4eHhcblxuMTAwfDEwMSAgIDEwMnwxMDMgICAxMDR8eHh4ICAgMTIwfDEyMSAgIDEyMnwxMjMgICAxMjR8eHh4XG4tLS0tLS0tICAgLS0tLS0tLSAgIC0tLS0tLS0gICAtLS0tLS0tICAgLS0tLS0tLSAgIC0tLS0tLS1cbjExMHwxMTEgICAxMTJ8MTEzICAgMTE0fHh4eCAgIHh4eHx4eHggICB4eHh8eHh4ICAgeHh4fHh4eF
xuXG5TaW5nbGUgdGV4ZWxzIGNvbnRhaW4gb25seSB2YWx1ZXMgZnJvbSB0aGUgc2FtZSBiYXRjaCwgYW5kIGZyb20gYWRqYWNlbnQgcm93c1xuYW5kIGNvbHVtbnMuXG4gKi9cblxuZXhwb3J0IGNsYXNzIEVuY29kZU1hdHJpeFBhY2tlZFByb2dyYW0gaW1wbGVtZW50cyBHUEdQVVByb2dyYW0ge1xuICB2YXJpYWJsZU5hbWVzID0gWydBJ107XG4gIHVzZXJDb2RlOiBzdHJpbmc7XG4gIG91dHB1dFNoYXBlOiBudW1iZXJbXTtcbiAgcGFja2VkSW5wdXRzID0gZmFsc2U7XG4gIHBhY2tlZE91dHB1dCA9IHRydWU7XG4gIGVuYWJsZVNoYXBlVW5pZm9ybXM6IGJvb2xlYW47XG4gIGN1c3RvbVVuaWZvcm1zID0gW3tuYW1lOiAndGV4U2hhcGUnLCB0eXBlOiAnaXZlYzInIGFzIGNvbnN0IH1dO1xuXG4gIGNvbnN0cnVjdG9yKFxuICAgICAgb3V0cHV0U2hhcGU6IFtudW1iZXIsIG51bWJlciwgbnVtYmVyXSwgaW5wdXRJc1Vuc2lnbmVkQnl0ZSA9IGZhbHNlKSB7XG4gICAgY29uc3QgZ2xzbCA9IGdldEdsc2xEaWZmZXJlbmNlcygpO1xuICAgIHRoaXMub3V0cHV0U2hhcGUgPSBvdXRwdXRTaGFwZTtcbiAgICB0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPSB1c2VTaGFwZVVuaWZvcm1zKHRoaXMub3V0cHV0U2hhcGUubGVuZ3RoKTtcblxuICAgIGxldCBtYWluTG9vcCA9ICcnO1xuICAgIGxldCBvdXRwdXQgPSAncmVzdWx0JztcbiAgICBpZiAoaW5wdXRJc1Vuc2lnbmVkQnl0ZSkge1xuICAgICAgb3V0cHV0ID0gJ2Zsb29yKHJlc3VsdCAqIDI1NS4gKyAwLjUpJztcbiAgICB9XG5cbiAgICBmb3IgKGxldCByb3cgPSAwOyByb3cgPD0gMTsgcm93KyspIHtcbiAgICAgIGZvciAobGV0IGNvbCA9IDA7IGNvbCA8PSAxOyBjb2wrKykge1xuICAgICAgICBjb25zdCBjaGFubmVsID0gcm93ICogMiArIGNvbDtcblxuICAgICAgICBtYWluTG9vcCArPSBgXG4gICAgICAgICAgbG9jYWxDb29yZHMgPSBjb29yZHM7XG4gICAgICAgICAgaWYobG9jYWxDb29yZHNbMl0gKyAke2NvbH0gPCAke1xuICAgICAgICAgICAgdGhpcy5lbmFibGVTaGFwZVVuaWZvcm1zID8gJ291dFNoYXBlWzJdJyA6IGAke291dHB1dFNoYXBlWzJdfWB9KSB7XG4gICAgICAgICAgbG9jYWxDb29yZHNbMl0gKz0gJHtjb2x9O1xuICAgICAgICAgIGlmIChsb2NhbENvb3Jkc1sxXSArICR7cm93fSA8ICR7XG4gICAgICAgICAgICB0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPyAnb3V0U2hhcGVbMV0nIDogYCR7b3V0cHV0U2hhcGVbMV19YH0pIHtcbiAgICAgICAgICAgIGxvY2FsQ29vcmRzWzFdICs9ICR7cm93fTtcblxuICAgICAgICAgICAgZmxhdEluZGV4ID0gZ2V0RmxhdEluZGV4KGxvY2FsQ29vcmRzKTtcbiAgICAgICAgICAgIG9mZnNldCA9IGltb2QoZmxhdEluZGV4LCA0KTtcblxuICAgICAgICAgICAgZmxhdEluZGV4ID0gaWRpdihmbGF0SW5kZXgsIDQsIDEuKTtcblxuICAgICAgICAgICAgaW50IHIgPSBmbGF0SW5kZXggLyB0ZXhTaGFwZVsxXTtcbiAgICAgICAgICAgIGludCBjID
0gaW1vZChmbGF0SW5kZXgsIHRleFNoYXBlWzFdKTtcbiAgICAgICAgICAgIHZlYzIgdXYgPSAodmVjMihjLCByKSArIGhhbGZDUikgLyB2ZWMyKHRleFNoYXBlWzFdLCB0ZXhTaGFwZVswXSk7XG4gICAgICAgICAgICB2YWx1ZXMgPSAke2dsc2wudGV4dHVyZTJEfShBLCB1dik7XG5cbiAgICAgICAgICAgIGlmIChvZmZzZXQgPT0gMCkge1xuICAgICAgICAgICAgICByZXN1bHRbJHtjaGFubmVsfV0gPSB2YWx1ZXNbMF07XG4gICAgICAgICAgICB9IGVsc2UgaWYgKG9mZnNldCA9PSAxKSB7XG4gICAgICAgICAgICAgIHJlc3VsdFske2NoYW5uZWx9XSA9IHZhbHVlc1sxXTtcbiAgICAgICAgICAgIH0gZWxzZSBpZiAob2Zmc2V0ID09IDIpIHtcbiAgICAgICAgICAgICAgcmVzdWx0WyR7Y2hhbm5lbH1dID0gdmFsdWVzWzJdO1xuICAgICAgICAgICAgfSBlbHNlIHtcbiAgICAgICAgICAgICAgcmVzdWx0WyR7Y2hhbm5lbH1dID0gdmFsdWVzWzNdO1xuICAgICAgICAgICAgfVxuICAgICAgICAgIH1cbiAgICAgICAgfVxuICAgICAgICBgO1xuICAgICAgfVxuICAgIH1cblxuICAgIHRoaXMudXNlckNvZGUgPSBgXG4gICAgICAgICR7XG4gICAgICAgIHRoaXMuZW5hYmxlU2hhcGVVbmlmb3JtcyA/IHNoYWRlcl91dGlsLmdldEZsYXRJbmRleEZyb20zRE91dHB1dCgpIDpcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgc2hhZGVyX3V0aWwuZ2V0RmxhdEluZGV4RnJvbTNEKG91dHB1dFNoYXBlKX1cblxuICAgICAgICB2b2lkIG1haW4oKSB7XG4gICAgICAgICAgaXZlYzMgY29vcmRzID0gZ2V0T3V0cHV0Q29vcmRzKCk7XG5cbiAgICAgICAgICB2ZWM0IHJlc3VsdCA9IHZlYzQoMC4pO1xuICAgICAgICAgIGludCBmbGF0SW5kZXgsIHIsIGMsIG9mZnNldDtcbiAgICAgICAgICBpdmVjMyBsb2NhbENvb3JkcztcbiAgICAgICAgICB2ZWMyIHV2O1xuICAgICAgICAgIHZlYzQgdmFsdWVzO1xuXG4gICAgICAgICAgJHttYWluTG9vcH1cblxuICAgICAgICAgICR7Z2xzbC5vdXRwdXR9ID0gJHtvdXRwdXR9O1xuICAgICAgICB9XG4gICAgYDtcbiAgfVxufVxuIl19