gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util } from '@tensorflow/tfjs-core';
export class ConcatProgram {
    // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat().
    constructor(shapes) {
        this.outputShape = [];
        this.outputShape = backend_util.computeOutShape(shapes, 1 /* axis */);
        this.variableNames = shapes.map((_, i) => `T${i}`);
        const offsets = new Array(shapes.length - 1);
        offsets[0] = shapes[0][1];
        for (let i = 1; i < offsets.length; i++) {
            offsets[i] = offsets[i - 1] + shapes[i][1];
        }
        const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`];
        for (let i = 1; i < offsets.length; i++) {
            const shift = offsets[i - 1];
            snippets.push(`else if (yC < ${offsets[i]}) ` +
                `setOutput(getT${i}(yR, yC-${shift}));`);
        }
        const lastIndex = offsets.length;
        const lastShift = offsets[offsets.length - 1];
        snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int yR = coords.x;
        int yC = coords.y;
 
        ${snippets.join('\n        ')}
      }
    `;
    }
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29uY2F0X2dwdS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMvY29uY2F0X2dwdS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHbkQsTUFBTSxPQUFPLGFBQWE7SUFLeEIsOEVBQThFO0lBQzlFLFlBQVksTUFBK0I7UUFKM0MsZ0JBQVcsR0FBYSxFQUFFLENBQUM7UUFLekIsSUFBSSxDQUFDLFdBQVcsR0FBRyxZQUFZLENBQUMsZUFBZSxDQUFDLE1BQU0sRUFBRSxDQUFDLENBQUMsVUFBVSxDQUFDLENBQUM7UUFDdEUsSUFBSSxDQUFDLGFBQWEsR0FBRyxNQUFNLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsSUFBSSxDQUFDLEVBQUUsQ0FBQyxDQUFDO1FBRW5ELE1BQU0sT0FBTyxHQUFhLElBQUksS0FBSyxDQUFDLE1BQU0sQ0FBQyxNQUFNLEdBQUcsQ0FBQyxDQUFDLENBQUM7UUFDdkQsT0FBTyxDQUFDLENBQUMsQ0FBQyxHQUFHLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUMxQixLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsT0FBTyxDQUFDLE1BQU0sRUFBRSxDQUFDLEVBQUUsRUFBRTtZQUN2QyxPQUFPLENBQUMsQ0FBQyxDQUFDLEdBQUcsT0FBTyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsR0FBRyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7U0FDNUM7UUFFRCxNQUFNLFFBQVEsR0FBRyxDQUFDLFlBQVksT0FBTyxDQUFDLENBQUMsQ0FBQyw2QkFBNkIsQ0FBQyxDQUFDO1FBQ3ZFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxPQUFPLENBQUMsTUFBTSxFQUFFLENBQUMsRUFBRSxFQUFFO1lBQ3ZDLE1BQU0sS0FBSyxHQUFHLE9BQU8sQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7WUFDN0IsUUFBUSxDQUFDLElBQUksQ0FDVCxpQkFBaUIsT0FBTyxDQUFDLENBQUMsQ0FBQyxJQUFJO2dCQUMvQixpQkFBaUIsQ0FBQyxXQUFXLEtBQUssS0FBSyxDQUFDLENBQUM7U0FDOUM7UUFDRCxNQUFNLFNBQVMsR0FBRyxPQUFPLENBQUMsTUFBTSxDQUFDO1FBQ2pDLE1BQU0sU0FBUyxHQUFHLE9BQU8sQ0FBQyxPQUFPLENBQUMsTUFBTSxHQUFHLENBQUMsQ0FBQyxDQUFDO1FBQzlDLFFBQVEsQ0FBQyxJQUFJLENBQUMsc0JBQXNCLFNBQVMsV0FBVyxTQUFTLEtBQUssQ0FBQyxDQUFDO1FBRXhFLElBQUksQ0FBQyxRQUFRLEdBQUc7Ozs7OztVQU1WLFFBQVEsQ0FBQyxJQUFJLENBQUMsWUFBWSxDQUFDOztLQUVoQyxDQUFDO0lBQ0osQ0FBQztDQUNGIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcbmltcG9ydCB7R1BHUFVQcm9ncmFtfSBmcm9tICcuL2dwZ3B1X21hdGgnO1xuXG5leHBvcnQgY2xhc3MgQ29uY2F0UHJvZ3JhbSBpbXBsZW1lbnRzIEdQR1BVUHJvZ3JhbSB7XG4gIHZhcmlhYmxlTmFtZXM6IHN0cmluZ1tdO1xuICBvdXRwdXRTaGFwZTogbnVtYmVyW10gPSBbXTtcbiAgdXNlckNvZGU6IHN0cmluZztcblxuICAvLyBDb25jYXRzIDJkIHRlbnNvcnMgYWxvbmcgYXhpcz0xLiBTZWUgY29tbWVudHMgaW4gTWF0aEJhY2tlbmRXZWJHTC5jb25jYXQoKS5cbiAgY29uc3RydWN0b3Ioc2hhcGVzOiBBcnJheTxbbnVtYmVyLCBudW1iZXJdPikge1xuICAgIHRoaXMub3V0cHV0U2hhcGUgPSBiYWNrZW5kX3V0aWwuY29tcHV0ZU91dFNoYXBlKHNoYXBlcywgMSAvKiBheGlzICovKTtcbiAgICB0aGlzLnZhcmlhYmxlTmFtZXMgPSBzaGFwZXMubWFwKChfLCBpKSA9PiBgVCR7aX1gKTtcblxuICAgIGNvbnN0IG9mZnNldHM6IG51bWJlcltdID0gbmV3IEFycmF5KHNoYXBlcy5sZW5ndGggLSAxKTtcbiAgICBvZmZzZXRzWzBdID0gc2hhcGVzWzBdWzFdO1xuICAgIGZvciAobGV0IGkgPSAxOyBpIDwgb2Zmc2V0cy5sZW5ndGg7IGkrKykge1xuICAgICAgb2Zmc2V0c1tpXSA9IG9mZnNldHNbaSAtIDFdICsgc2hhcGVzW2ldWzFdO1xuICAgIH1cblxuICAgIGNvbnN0IHNuaXBwZXRzID0gW2BpZiAoeUMgPCAke29mZnNldHNbMF19KSBzZXRPdXRwdXQoZ2V0VDAoeVIsIHlDKSk7YF07XG4gICAgZm9yIChsZXQgaSA9IDE7IGkgPCBvZmZzZXRzLmxlbmd0aDsgaSsrKSB7XG4gICAgICBjb25zdCBzaGlmdCA9IG9mZnNldHNbaSAtIDFdO1xuICAgICAgc25pcHBldHMucHVzaChcbiAgICAgICAgICBgZWxzZSBpZiAoeUMgPCAke29mZnNldHNbaV19KSBgICtcbiAgICAgICAgICBgc2V0T3V0cHV0KGdldFQke2l9KHlSLCB5Qy0ke3NoaWZ0fSkpO2ApO1xuICAgIH1cbiAgICBjb25zdCBsYXN0SW5kZXggPSBvZmZzZXRzLmxlbmd0aDtcbiAgICBjb25zdCBsYXN0U2hpZnQgPSBvZmZzZXRzW29mZnNldHMubGVuZ3RoIC0gMV07XG4gICAgc25pcHBldHMucHVzaChgZWxzZSBzZXRPdXRwdXQoZ2V0VCR7bGFzdEluZGV4fSh5UiwgeUMtJHtsYXN0U2hpZnR9KSk7YCk7XG5cbiAgICB0aGlzLnVzZXJDb2RlID0gYFxuICAgICAgdm9pZCBtYWluKCkge1xuICAgICAgICBpdmVjMiBjb29yZHMgPSBnZXRPdXRwdXRDb29yZHMoKTtcbiAgICAgICAgaW50IHlSID0gY29vcmRzLng7XG4gICAgICAgIGludCB5QyA9IGNvb3Jkcy55O1xuXG4gICAgICAgICR7c25pcHBldHMuam9pbignXFxuICAgICAgICAnKX1cbiAgICAgIH1cbiAgICBgO1xuICB9XG59XG4iXX0=