/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { backend_util } from '@tensorflow/tfjs-core';
|
/**
 * Shader program description that concatenates 2D tensors along axis 1
 * (columns). See comments in MathBackendWebGL.concat().
 *
 * After construction the instance exposes:
 *  - `outputShape`:   result of `backend_util.computeOutShape(shapes, 1)`.
 *  - `variableNames`: one name per input, `T0`, `T1`, ...
 *  - `userCode`:      GLSL source of `main()` that selects, per output
 *                     column, which input to sample and at which column.
 */
export class ConcatProgram {
  /**
   * @param {Array<Array<number>>} shapes Shapes of the inputs, in
   *     concatenation order. Only index 1 of each shape (the column count)
   *     is read here. Must contain at least one shape.
   * @throws {Error} If `shapes` is empty.
   */
  constructor(shapes) {
    if (shapes.length === 0) {
      throw new Error('ConcatProgram requires at least one input shape.');
    }

    this.outputShape = backend_util.computeOutShape(shapes, 1 /* axis */);
    this.variableNames = shapes.map((_, i) => `T${i}`);

    // offsets[i] is the exclusive upper bound (in output columns) of input i.
    // The last input needs no bound: it is handled by the trailing `else`.
    const offsets = [];
    let columnsSoFar = 0;
    for (let i = 0; i < shapes.length - 1; i++) {
      columnsSoFar += shapes[i][1];
      offsets.push(columnsSoFar);
    }

    const snippets = [];
    if (offsets.length === 0) {
      // Single input: there is no T1 to fall through to, so copy T0 directly
      // instead of emitting an if/else chain that references a missing sampler.
      snippets.push('setOutput(getT0(yR, yC));');
    } else {
      snippets.push(`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`);
      for (let i = 1; i < offsets.length; i++) {
        // Shift the output column back by the width of all preceding inputs.
        const shift = offsets[i - 1];
        snippets.push(
          `else if (yC < ${offsets[i]}) setOutput(getT${i}(yR, yC-${shift}));`,
        );
      }
      const lastIndex = offsets.length;
      const lastShift = offsets[lastIndex - 1];
      snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
    }

    // NOTE(review): getOutputCoords / setOutput / getT<i> are not defined in
    // this file; presumably they are injected by the shader compiler.
    this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int yR = coords.x;
        int yC = coords.y;

        ${snippets.join('\n        ')}
      }
    `;
  }
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiY29uY2F0X2dwdS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMvY29uY2F0X2dwdS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFFSCxPQUFPLEVBQUMsWUFBWSxFQUFDLE1BQU0sdUJBQXVCLENBQUM7QUFHbkQsTUFBTSxPQUFPLGFBQWE7SUFLeEIsOEVBQThFO0lBQzlFLFlBQVksTUFBK0I7UUFKM0MsZ0JBQVcsR0FBYSxFQUFFLENBQUM7UUFLekIsSUFBSSxDQUFDLFdBQVcsR0FBRyxZQUFZLENBQUMsZUFBZSxDQUFDLE1BQU0sRUFBRSxDQUFDLENBQUMsVUFBVSxDQUFDLENBQUM7UUFDdEUsSUFBSSxDQUFDLGFBQWEsR0FBRyxNQUFNLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxFQUFFLENBQUMsSUFBSSxDQUFDLEVBQUUsQ0FBQyxDQUFDO1FBRW5ELE1BQU0sT0FBTyxHQUFhLElBQUksS0FBSyxDQUFDLE1BQU0sQ0FBQyxNQUFNLEdBQUcsQ0FBQyxDQUFDLENBQUM7UUFDdkQsT0FBTyxDQUFDLENBQUMsQ0FBQyxHQUFHLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUMxQixLQUFLLElBQUksQ0FBQyxHQUFHLENBQUMsRUFBRSxDQUFDLEdBQUcsT0FBTyxDQUFDLE1BQU0sRUFBRSxDQUFDLEVBQUUsRUFBRTtZQUN2QyxPQUFPLENBQUMsQ0FBQyxDQUFDLEdBQUcsT0FBTyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUMsR0FBRyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7U0FDNUM7UUFFRCxNQUFNLFFBQVEsR0FBRyxDQUFDLFlBQVksT0FBTyxDQUFDLENBQUMsQ0FBQyw2QkFBNkIsQ0FBQyxDQUFDO1FBQ3ZFLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxPQUFPLENBQUMsTUFBTSxFQUFFLENBQUMsRUFBRSxFQUFFO1lBQ3ZDLE1BQU0sS0FBSyxHQUFHLE9BQU8sQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUM7WUFDN0IsUUFBUSxDQUFDLElBQUksQ0FDVCxpQkFBaUIsT0FBTyxDQUFDLENBQUMsQ0FBQyxJQUFJO2dCQUMvQixpQkFBaUIsQ0FBQyxXQUFXLEtBQUssS0FBSyxDQUFDLENBQUM7U0FDOUM7UUFDRCxNQUFNLFNBQVMsR0FBRyxPQUFPLENBQUMsTUFBTSxDQUFDO1FBQ2pDLE1BQU0sU0FBUyxHQUFHLE9BQU8sQ0FBQyxPQUFPLENBQUMsTUFBTSxHQUFHLENBQUMsQ0FBQyxDQUFDO1FBQzlDLFFBQVEsQ0FBQyxJQUFJLENBQUMsc0JBQXNCLFNBQVMsV0FBVyxTQUFTLEtBQUssQ0FBQyxDQUFDO1FBRXhFLElBQUksQ0FBQyxRQUFRLEdBQUc7Ozs7OztVQU1WLFFBQVEsQ0FBQyxJQUFJLENBQUMsWUFBWSxDQUFDOztLQUVoQyxDQUFDO0lBQ0osQ0FBQztDQUNGIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZX
NlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbH0gZnJvbSAnQHRlbnNvcmZsb3cvdGZqcy1jb3JlJztcbmltcG9ydCB7R1BHUFVQcm9ncmFtfSBmcm9tICcuL2dwZ3B1X21hdGgnO1xuXG5leHBvcnQgY2xhc3MgQ29uY2F0UHJvZ3JhbSBpbXBsZW1lbnRzIEdQR1BVUHJvZ3JhbSB7XG4gIHZhcmlhYmxlTmFtZXM6IHN0cmluZ1tdO1xuICBvdXRwdXRTaGFwZTogbnVtYmVyW10gPSBbXTtcbiAgdXNlckNvZGU6IHN0cmluZztcblxuICAvLyBDb25jYXRzIDJkIHRlbnNvcnMgYWxvbmcgYXhpcz0xLiBTZWUgY29tbWVudHMgaW4gTWF0aEJhY2tlbmRXZWJHTC5jb25jYXQoKS5cbiAgY29uc3RydWN0b3Ioc2hhcGVzOiBBcnJheTxbbnVtYmVyLCBudW1iZXJdPikge1xuICAgIHRoaXMub3V0cHV0U2hhcGUgPSBiYWNrZW5kX3V0aWwuY29tcHV0ZU91dFNoYXBlKHNoYXBlcywgMSAvKiBheGlzICovKTtcbiAgICB0aGlzLnZhcmlhYmxlTmFtZXMgPSBzaGFwZXMubWFwKChfLCBpKSA9PiBgVCR7aX1gKTtcblxuICAgIGNvbnN0IG9mZnNldHM6IG51bWJlcltdID0gbmV3IEFycmF5KHNoYXBlcy5sZW5ndGggLSAxKTtcbiAgICBvZmZzZXRzWzBdID0gc2hhcGVzWzBdWzFdO1xuICAgIGZvciAobGV0IGkgPSAxOyBpIDwgb2Zmc2V0cy5sZW5ndGg7IGkrKykge1xuICAgICAgb2Zmc2V0c1tpXSA9IG9mZnNldHNbaSAtIDFdICsgc2hhcGVzW2ldWzFdO1xuICAgIH1cblxuICAgIGNvbnN0IHNuaXBwZXRzID0gW2BpZiAoeUMgPCAke29mZnNldHNbMF19KSBzZXRPdXRwdXQoZ2V0VDAoeVIsIHlDKSk7YF07XG4gICAgZm9yIChsZXQgaSA9IDE7IGkgPCBvZmZzZXRzLmxlbmd0aDsgaSsrKSB7XG4gICAgICBjb2
5zdCBzaGlmdCA9IG9mZnNldHNbaSAtIDFdO1xuICAgICAgc25pcHBldHMucHVzaChcbiAgICAgICAgICBgZWxzZSBpZiAoeUMgPCAke29mZnNldHNbaV19KSBgICtcbiAgICAgICAgICBgc2V0T3V0cHV0KGdldFQke2l9KHlSLCB5Qy0ke3NoaWZ0fSkpO2ApO1xuICAgIH1cbiAgICBjb25zdCBsYXN0SW5kZXggPSBvZmZzZXRzLmxlbmd0aDtcbiAgICBjb25zdCBsYXN0U2hpZnQgPSBvZmZzZXRzW29mZnNldHMubGVuZ3RoIC0gMV07XG4gICAgc25pcHBldHMucHVzaChgZWxzZSBzZXRPdXRwdXQoZ2V0VCR7bGFzdEluZGV4fSh5UiwgeUMtJHtsYXN0U2hpZnR9KSk7YCk7XG5cbiAgICB0aGlzLnVzZXJDb2RlID0gYFxuICAgICAgdm9pZCBtYWluKCkge1xuICAgICAgICBpdmVjMiBjb29yZHMgPSBnZXRPdXRwdXRDb29yZHMoKTtcbiAgICAgICAgaW50IHlSID0gY29vcmRzLng7XG4gICAgICAgIGludCB5QyA9IGNvb3Jkcy55O1xuXG4gICAgICAgICR7c25pcHBldHMuam9pbignXFxuICAgICAgICAnKX1cbiAgICAgIH1cbiAgICBgO1xuICB9XG59XG4iXX0=
|