/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { getCoordsDataType } from './shader_compiler';

/**
 * Shader program description for a gather along the third coordinate of `A`.
 *
 * The generated `userCode` reads an index with
 * `getIndices(resRC.x, resRC.z)`, fetches `A` at the output coordinates with
 * that index substituted (see `getSourceCoords`), and multiplies the fetched
 * value by 0.0 when the index lies outside `[0, aShape[2])`.
 */
export class GatherProgram {
  /**
   * @param aShape Shape of input `A`; `aShape[2]` is the exclusive upper
   *     bound used for the in-bounds check on the gathered index.
   * @param outputShape Shape of the output; its length selects the GLSL
   *     coordinate type through `getCoordsDataType`.
   */
  constructor(aShape, outputShape) {
    this.variableNames = ['A', 'indices'];
    this.outputShape = outputShape;
    this.rank = outputShape.length;

    const coordsType = getCoordsDataType(this.rank);
    const gatherCoords = getSourceCoords(aShape, 2);
    const gatherAxisSize = aShape[2];

    // Shader statements in order; they are joined with single spaces and
    // wrapped in one leading and one trailing space.
    const shaderStatements = [
      'void main() {',
      `${coordsType} resRC = getOutputCoords();`,
      'int index = int(getIndices(resRC.x, resRC.z));',
      `float inBounds = (index >= 0) && (index < ${gatherAxisSize}) ? 1.0 : 0.0;`,
      `setOutput(inBounds * getA(${gatherCoords}));`,
      '}',
    ];
    this.userCode = ` ${shaderStatements.join(' ')} `;
  }
}

// The input and output are always flattened into rank 4 tensors.
/**
 * Builds the comma-separated argument list used to sample the source tensor
 * in the gather shader.
 *
 * Every dimension maps to the matching output coordinate (`resRC.x`,
 * `resRC.y`, `resRC.z`, `resRC.w`) except the gather axis, which is replaced
 * by the shader-local variable `index`.
 *
 * @param aShape Shape of the source tensor; only its length is used.
 * @param axis Position of the dimension that is gathered. Defaults to 2,
 *     the position that was previously hard-coded regardless of the argument.
 * @returns The coordinates joined with `,`, e.g.
 *     `resRC.x,resRC.y,index,resRC.w` for a rank 4 shape and axis 2.
 */
function getSourceCoords(aShape, axis = 2) {
  const outputCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
  return Array.from(
    { length: aShape.length },
    // The template literal keeps the original stringification of each entry.
    (_, dim) => (dim === axis ? 'index' : `${outputCoords[dim]}`),
  ).join();
}
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZ2F0aGVyX2dwdS5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uL3RmanMtYmFja2VuZC13ZWJnbC9zcmMvZ2F0aGVyX2dwdS50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFBQTs7Ozs7Ozs7Ozs7Ozs7O0dBZUc7QUFHSCxPQUFPLEVBQUMsaUJBQWlCLEVBQUMsTUFBTSxtQkFBbUIsQ0FBQztBQUlwRCxNQUFNLE9BQU8sYUFBYTtJQU14QixZQUFZLE1BQW1CLEVBQUUsV0FBd0I7UUFMekQsa0JBQWEsR0FBRyxDQUFDLEdBQUcsRUFBRSxTQUFTLENBQUMsQ0FBQztRQU0vQixJQUFJLENBQUMsV0FBVyxHQUFHLFdBQVcsQ0FBQztRQUMvQixJQUFJLENBQUMsSUFBSSxHQUFHLFdBQVcsQ0FBQyxNQUFNLENBQUM7UUFDL0IsTUFBTSxLQUFLLEdBQUcsaUJBQWlCLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxDQUFDO1FBQzNDLE1BQU0sWUFBWSxHQUFHLGVBQWUsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxDQUFDLENBQUM7UUFFaEQsSUFBSSxDQUFDLFFBQVEsR0FBRzs7VUFFVixLQUFLOztvREFFcUMsTUFBTSxDQUFDLENBQUMsQ0FBQztvQ0FDekIsWUFBWTs7S0FFM0MsQ0FBQztJQUNKLENBQUM7Q0FDRjtBQUVELGlFQUFpRTtBQUNqRSxTQUFTLGVBQWUsQ0FBQyxNQUFtQixFQUFFLElBQVk7SUFDeEQsTUFBTSxhQUFhLEdBQUcsQ0FBQyxTQUFTLEVBQUUsU0FBUyxFQUFFLFNBQVMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUVuRSxNQUFNLFlBQVksR0FBRyxFQUFFLENBQUM7SUFDeEIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLE1BQU0sQ0FBQyxNQUFNLEVBQUUsQ0FBQyxFQUFFLEVBQUU7UUFDdEMsSUFBSSxDQUFDLEtBQUssQ0FBQyxFQUFFO1lBQ1gsWUFBWSxDQUFDLElBQUksQ0FBQyxPQUFPLENBQUMsQ0FBQztTQUM1QjthQUFNO1lBQ0wsWUFBWSxDQUFDLElBQUksQ0FBQyxHQUFHLGFBQWEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUM7U0FDMUM7S0FDRjtJQUNELE9BQU8sWUFBWSxDQUFDLElBQUksRUFBRSxDQUFDO0FBQzdCLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxNyBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZX
B0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7R1BHUFVQcm9ncmFtfSBmcm9tICcuL2dwZ3B1X21hdGgnO1xuaW1wb3J0IHtnZXRDb29yZHNEYXRhVHlwZX0gZnJvbSAnLi9zaGFkZXJfY29tcGlsZXInO1xuXG5leHBvcnQgdHlwZSBHYXRoZXJTaGFwZSA9IFtudW1iZXIsIG51bWJlciwgbnVtYmVyLCBudW1iZXJdO1xuXG5leHBvcnQgY2xhc3MgR2F0aGVyUHJvZ3JhbSBpbXBsZW1lbnRzIEdQR1BVUHJvZ3JhbSB7XG4gIHZhcmlhYmxlTmFtZXMgPSBbJ0EnLCAnaW5kaWNlcyddO1xuICBvdXRwdXRTaGFwZTogbnVtYmVyW107XG4gIHVzZXJDb2RlOiBzdHJpbmc7XG4gIHJhbms6IG51bWJlcjtcblxuICBjb25zdHJ1Y3RvcihhU2hhcGU6IEdhdGhlclNoYXBlLCBvdXRwdXRTaGFwZTogR2F0aGVyU2hhcGUpIHtcbiAgICB0aGlzLm91dHB1dFNoYXBlID0gb3V0cHV0U2hhcGU7XG4gICAgdGhpcy5yYW5rID0gb3V0cHV0U2hhcGUubGVuZ3RoO1xuICAgIGNvbnN0IGR0eXBlID0gZ2V0Q29vcmRzRGF0YVR5cGUodGhpcy5yYW5rKTtcbiAgICBjb25zdCBzb3VyY2VDb29yZHMgPSBnZXRTb3VyY2VDb29yZHMoYVNoYXBlLCAyKTtcblxuICAgIHRoaXMudXNlckNvZGUgPSBgXG4gICAgICB2b2lkIG1haW4oKSB7XG4gICAgICAgICR7ZHR5cGV9IHJlc1JDID0gZ2V0T3V0cHV0Q29vcmRzKCk7XG4gICAgICAgIGludCBpbmRleCA9IGludChnZXRJbmRpY2VzKHJlc1JDLngsIHJlc1JDLnopKTtcbiAgICAgICAgZmxvYXQgaW5Cb3VuZHMgPSAoaW5kZXggPj0gMCkgJiYgKGluZGV4IDwgJHthU2hhcGVbMl19KSA/IDEuMCA6IDAuMDtcbiAgICAgICAgc2V0T3V0cHV0KGluQm91bmRzICogZ2V0QSgke3NvdXJjZUNvb3Jkc30pKTtcbiAgICAgIH1cbiAgICBgO1xuICB9XG59XG5cbi8vIFRoZSBpbnB1dCBhbmQgb3V0cHV0IGFyZSBhbHdheXMgZmxhdHRlbmVkIGludG8gcmFuayA0IHRlbnNvcnMuXG5mdW5jdGlvbiBnZXRTb3
VyY2VDb29yZHMoYVNoYXBlOiBHYXRoZXJTaGFwZSwgYXhpczogbnVtYmVyKTogc3RyaW5nIHtcbiAgY29uc3QgY3VycmVudENvb3JkcyA9IFsncmVzUkMueCcsICdyZXNSQy55JywgJ3Jlc1JDLnonLCAncmVzUkMudyddO1xuXG4gIGNvbnN0IHNvdXJjZUNvb3JkcyA9IFtdO1xuICBmb3IgKGxldCBpID0gMDsgaSA8IGFTaGFwZS5sZW5ndGg7IGkrKykge1xuICAgIGlmIChpID09PSAyKSB7XG4gICAgICBzb3VyY2VDb29yZHMucHVzaCgnaW5kZXgnKTtcbiAgICB9IGVsc2Uge1xuICAgICAgc291cmNlQ29vcmRzLnB1c2goYCR7Y3VycmVudENvb3Jkc1tpXX1gKTtcbiAgICB9XG4gIH1cbiAgcmV0dXJuIHNvdXJjZUNvb3Jkcy5qb2luKCk7XG59XG4iXX0=