/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { useShapeUniforms } from './gpgpu_math';
|
import * as shader_util from './shader_compiler_util';
|
/**
 * Shader program description that reshapes a 3D input `A` into a 3D output.
 *
 * The constructor only assembles GLSL source text (stored in `userCode`) and
 * the metadata fields that accompany it; it performs no GPU work itself.
 *
 * Generated shader, per invocation:
 *  - reads the output coordinate `rc`,
 *  - fills the four components of a `vec4 result`, one for each of the
 *    coordinates (y, z), (y, z + 1), (y + 1, z) and (y + 1, z + 1),
 *  - for each component flattens the output coordinate with `getFlatIndex`,
 *    converts that index to input coordinates with
 *    `inputCoordsFromReshapedOutCoords`, and reads the matching channel of
 *    `getA(...)` via `getChannel`.
 *
 * NOTE(review): `getOutputCoords`, `getA`, `getChannel` and `setOutput` are
 * not defined here; they are presumably injected by the shader compiler.
 */
export class ReshapePackedProgram {
  /**
   * @param {number[]} outputShape Shape of the reshaped output. Indices 1 and
   *     2 are used as the row and column bounds when shape uniforms are off.
   * @param {number[]} inputShape Shape of the input, forwarded to
   *     `getReshapedInputCoords` to build the index-to-coordinate helper.
   */
  constructor(outputShape, inputShape) {
    this.variableNames = ['A'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }];
    this.outputShape = outputShape;
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);

    let mainLoop = ``;
    for (let i = 0; i < 4; i++) {
      // Every component starts from the base output coordinate.
      let thisRC = `thisRC = rc;`;
      if (i % 2 === 1) {
        // Components 1 and 3 sit one column to the right.
        thisRC += `thisRC.z += 1;`;
      }
      if (i > 1) {
        // Components 2 and 3 sit one row down.
        thisRC += `thisRC.y += 1;`;
      }

      // Component 0 is written unconditionally; the other three are guarded
      // so that coordinates past the last row/column keep the initial 0.
      mainLoop += `
        ${thisRC}
        ${i > 0 ? `if(thisRC.y < rows && thisRC.z < cols){` : ''}
          int flatIndex = getFlatIndex(thisRC);

          ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
          vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));

          result[${i}] =
            getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
        ${i > 0 ? '}' : ''}
      `;
    }

    // With shape uniforms enabled the bounds are read from `outShape` at run
    // time; otherwise the concrete sizes are baked into the shader source.
    this.userCode = `
      ${getReshapedInputCoords(inputShape, this.enableShapeUniforms)}
      ${this.enableShapeUniforms ? shader_util.getFlatIndexFrom3DOutput() :
                                   shader_util.getFlatIndexFrom3D(outputShape)}

      void main() {
        ivec3 rc = getOutputCoords();

        vec4 result = vec4(0.);

        ivec3 thisRC;
        int rows = ${this.enableShapeUniforms ? 'outShape[1]' : outputShape[1]};
        int cols = ${this.enableShapeUniforms ? 'outShape[2]' : outputShape[2]};

        ${mainLoop}

        setOutput(result);
      }
    `;
  }
}
|
/**
 * Builds the GLSL source of `inputCoordsFromReshapedOutCoords`, a helper that
 * turns a flat element index into an `ivec3` of input coordinates.
 *
 * @param {number[]} shape Input shape; only used when `enableShapeUniforms`
 *     is falsy, in which case it is handed to
 *     `shader_util.getLogicalCoordinatesFromFlatIndex`.
 * @param {boolean} enableShapeUniforms When truthy, the coordinate snippet is
 *     generated against the `inputShape` uniform by name instead of `shape`.
 * @returns {string} GLSL function definition text.
 */
function getReshapedInputCoords(shape, enableShapeUniforms) {
  // NOTE(review): the snippet is expected to declare the ints `r`, `c` and
  // `d` used in the return statement below — confirm against
  // shader_compiler_util.
  const coordsFromIndexSnippet = enableShapeUniforms ?
      shader_util.getLogicalCoordinatesFromFlatIndexByUniform(
          ['r', 'c', 'd'], 'inputShape') :
      shader_util.getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);

  return `
    ivec3 inputCoordsFromReshapedOutCoords(int index) {
      ${coordsFromIndexSnippet}
      return ivec3(r, c, d);
    }
  `;
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoicmVzaGFwZV9wYWNrZWRfZ3B1LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9yZXNoYXBlX3BhY2tlZF9ncHUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFlLGdCQUFnQixFQUFDLE1BQU0sY0FBYyxDQUFDO0FBQzVELE9BQU8sS0FBSyxXQUFXLE1BQU0sd0JBQXdCLENBQUM7QUFFdEQsTUFBTSxPQUFPLG9CQUFvQjtJQVMvQixZQUFZLFdBQXFDLEVBQUUsVUFFbEQ7UUFWRCxrQkFBYSxHQUFHLENBQUMsR0FBRyxDQUFDLENBQUM7UUFDdEIsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFDcEIsaUJBQVksR0FBRyxJQUFJLENBQUM7UUFJcEIsbUJBQWMsR0FBRyxDQUFDLEVBQUMsSUFBSSxFQUFFLFlBQVksRUFBRSxJQUFJLEVBQUUsT0FBZ0IsRUFBRSxDQUFDLENBQUM7UUFLL0QsSUFBSSxDQUFDLFdBQVcsR0FBRyxXQUFXLENBQUM7UUFDL0IsSUFBSSxDQUFDLG1CQUFtQixHQUFHLGdCQUFnQixDQUFDLElBQUksQ0FBQyxXQUFXLENBQUMsTUFBTSxDQUFDLENBQUM7UUFFckUsSUFBSSxRQUFRLEdBQUcsRUFBRSxDQUFDO1FBQ2xCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxFQUFFLEVBQUU7WUFDMUIsSUFBSSxNQUFNLEdBQUcsY0FBYyxDQUFDO1lBQzVCLElBQUksQ0FBQyxHQUFHLENBQUMsS0FBSyxDQUFDLEVBQUU7Z0JBQ2YsTUFBTSxJQUFJLGdCQUFnQixDQUFDO2FBQzVCO1lBQ0QsSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFO2dCQUNULE1BQU0sSUFBSSxnQkFBZ0IsQ0FBQzthQUM1QjtZQUVELFFBQVEsSUFBSTtVQUNSLE1BQU07VUFDTixDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsQ0FBQyx5Q0FBeUMsQ0FBQyxDQUFDLENBQUMsRUFBRTs7Ozs7O21CQU03QyxDQUFDOztVQUVWLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRTtPQUNuQixDQUFDO1NBQ0g7UUFFRCxJQUFJLENBQUMsUUFBUSxHQUFHO1FBQ1osc0JBQXNCLENBQUMsVUFBVSxFQUFFLElBQUksQ0FBQyxtQkFBbUIsQ0FBQztRQUU1RCxJQUFJLENBQUMsbUJBQW1CLENBQUMsQ0FBQyxDQUFDLFdBQVcsQ0FBQyx3QkFBd0IsRUFBRSxDQUFDLENBQUM7WUFDeEMsV0FBVyxDQUFDLGtCQUFrQixDQUFDLFdBQVcsQ0FBQzs7Ozs7Ozs7cUJBUXpELElBQUksQ0FBQyxtQkFBbUIsQ0FBQyxDQUFDLENBQUMsYUFBYSxDQUFDLENBQUMsQ0FBQyxXQUFXLENBQUMsQ0FBQyxDQUFDO3FCQUN6RCxJQUFJLENBQUMsbUJBQW1CLENBQUMsQ0FBQyxDQUFDLGFBQWEsQ0FBQyxDQUFDLENBQUMsV0FBVyxDQUFDLENBQUMsQ0FBQzs7VUFFcEUsUUFBUTs7OztLQUliLENBQUM7SUFDSixDQUFDO0NBQ0Y7QUFFRCxTQUFTLHNCQUFzQixDQUMzQixLQUErQixFQU
FFLG1CQUE0QjtJQUMvRCxNQUFNLHNCQUFzQixHQUFHLG1CQUFtQixDQUFDLENBQUM7UUFDaEQsV0FBVyxDQUFDLDJDQUEyQyxDQUNuRCxDQUFDLEdBQUcsRUFBRSxHQUFHLEVBQUUsR0FBRyxDQUFDLEVBQUUsWUFBWSxDQUFDLENBQUMsQ0FBQztRQUNwQyxXQUFXLENBQUMsa0NBQWtDLENBQUMsQ0FBQyxHQUFHLEVBQUUsR0FBRyxFQUFFLEdBQUcsQ0FBQyxFQUFFLEtBQUssQ0FBQyxDQUFDO0lBRTNFLE9BQU87O1FBRUQsc0JBQXNCOzs7R0FHM0IsQ0FBQztBQUNKLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7R1BHUFVQcm9ncmFtLCB1c2VTaGFwZVVuaWZvcm1zfSBmcm9tICcuL2dwZ3B1X21hdGgnO1xuaW1wb3J0ICogYXMgc2hhZGVyX3V0aWwgZnJvbSAnLi9zaGFkZXJfY29tcGlsZXJfdXRpbCc7XG5cbmV4cG9ydCBjbGFzcyBSZXNoYXBlUGFja2VkUHJvZ3JhbSBpbXBsZW1lbnRzIEdQR1BVUHJvZ3JhbSB7XG4gIHZhcmlhYmxlTmFtZXMgPSBbJ0EnXTtcbiAgcGFja2VkSW5wdXRzID0gdHJ1ZTtcbiAgcGFja2VkT3V0cHV0ID0gdHJ1ZTtcbiAgb3V0cHV0U2hhcGU6IG51bWJlcltdO1xuICB1c2VyQ29kZTogc3RyaW5nO1xuICBlbmFibGVTaGFwZVVuaWZvcm1zOiBib29sZWFuO1xuICBjdXN0b21Vbmlmb3JtcyA9IFt7bmFtZTogJ2lucHV0U2hhcGUnLCB0eXBlOiAnaXZlYzMnIGFzIGNvbnN0IH1dO1xuXG4gIGNvbnN0cnVjdG9yKG91dHB1dFNoYXBlOiBbbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIGlucHV0U2hhcGU6IFtcbiAgICBudW1iZXIsIG
51bWJlciwgbnVtYmVyXG4gIF0pIHtcbiAgICB0aGlzLm91dHB1dFNoYXBlID0gb3V0cHV0U2hhcGU7XG4gICAgdGhpcy5lbmFibGVTaGFwZVVuaWZvcm1zID0gdXNlU2hhcGVVbmlmb3Jtcyh0aGlzLm91dHB1dFNoYXBlLmxlbmd0aCk7XG5cbiAgICBsZXQgbWFpbkxvb3AgPSBgYDtcbiAgICBmb3IgKGxldCBpID0gMDsgaSA8IDQ7IGkrKykge1xuICAgICAgbGV0IHRoaXNSQyA9IGB0aGlzUkMgPSByYztgO1xuICAgICAgaWYgKGkgJSAyID09PSAxKSB7XG4gICAgICAgIHRoaXNSQyArPSBgdGhpc1JDLnogKz0gMTtgO1xuICAgICAgfVxuICAgICAgaWYgKGkgPiAxKSB7XG4gICAgICAgIHRoaXNSQyArPSBgdGhpc1JDLnkgKz0gMTtgO1xuICAgICAgfVxuXG4gICAgICBtYWluTG9vcCArPSBgXG4gICAgICAgICR7dGhpc1JDfVxuICAgICAgICAke2kgPiAwID8gYGlmKHRoaXNSQy55IDwgcm93cyAmJiB0aGlzUkMueiA8IGNvbHMpe2AgOiAnJ31cbiAgICAgICAgICBpbnQgZmxhdEluZGV4ID0gZ2V0RmxhdEluZGV4KHRoaXNSQyk7XG5cbiAgICAgICAgICBpdmVjMyBpbnB1dFJDID0gaW5wdXRDb29yZHNGcm9tUmVzaGFwZWRPdXRDb29yZHMoZmxhdEluZGV4KTtcbiAgICAgICAgICB2ZWMyIGlucHV0UkNJbm5lckRpbXMgPSB2ZWMyKGZsb2F0KGlucHV0UkMueSksZmxvYXQoaW5wdXRSQy56KSk7XG5cbiAgICAgICAgICByZXN1bHRbJHtpfV0gPVxuICAgICAgICAgICAgZ2V0Q2hhbm5lbChnZXRBKGlucHV0UkMueCwgaW5wdXRSQy55LCBpbnB1dFJDLnopLCBpbnB1dFJDSW5uZXJEaW1zKTtcbiAgICAgICAgJHtpID4gMCA/ICd9JyA6ICcnfVxuICAgICAgYDtcbiAgICB9XG5cbiAgICB0aGlzLnVzZXJDb2RlID0gYFxuICAgICAgJHtnZXRSZXNoYXBlZElucHV0Q29vcmRzKGlucHV0U2hhcGUsIHRoaXMuZW5hYmxlU2hhcGVVbmlmb3Jtcyl9XG4gICAgICAke1xuICAgICAgICB0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPyBzaGFkZXJfdXRpbC5nZXRGbGF0SW5kZXhGcm9tM0RPdXRwdXQoKSA6XG4gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHNoYWRlcl91dGlsLmdldEZsYXRJbmRleEZyb20zRChvdXRwdXRTaGFwZSl9XG5cbiAgICAgIHZvaWQgbWFpbigpIHtcbiAgICAgICAgaXZlYzMgcmMgPSBnZXRPdXRwdXRDb29yZHMoKTtcblxuICAgICAgICB2ZWM0IHJlc3VsdCA9IHZlYzQoMC4pO1xuXG4gICAgICAgIGl2ZWMzIHRoaXNSQztcbiAgICAgICAgaW50IHJvd3MgPSAke3RoaXMuZW5hYmxlU2hhcGVVbmlmb3JtcyA/ICdvdXRTaGFwZVsxXScgOiBvdXRwdXRTaGFwZVsxXX07XG4gICAgICAgIGludCBjb2xzID0gJHt0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPyAnb3V0U2hhcGVbMl0nIDogb3V0cHV0U2hhcGVbMl19O1xuXG4gICAgICAgICR7bWFpbkxvb3B9XG5cbiAgICAgICAgc2V0T3V0cHV0KHJlc3VsdCk7XG4gICAgICB9XG4gICAgYDtcbiAgfVxufVxuXG5mdW5jdGlvbiBnZXRSZXNoYXBlZElucHV0Q29vcmRzKFxuIC
AgIHNoYXBlOiBbbnVtYmVyLCBudW1iZXIsIG51bWJlcl0sIGVuYWJsZVNoYXBlVW5pZm9ybXM6IGJvb2xlYW4pOiBzdHJpbmcge1xuICBjb25zdCBjb29yZHNGcm9tSW5kZXhTbmlwcGV0ID0gZW5hYmxlU2hhcGVVbmlmb3JtcyA/XG4gICAgICBzaGFkZXJfdXRpbC5nZXRMb2dpY2FsQ29vcmRpbmF0ZXNGcm9tRmxhdEluZGV4QnlVbmlmb3JtKFxuICAgICAgICAgIFsncicsICdjJywgJ2QnXSwgJ2lucHV0U2hhcGUnKSA6XG4gICAgICBzaGFkZXJfdXRpbC5nZXRMb2dpY2FsQ29vcmRpbmF0ZXNGcm9tRmxhdEluZGV4KFsncicsICdjJywgJ2QnXSwgc2hhcGUpO1xuXG4gIHJldHVybiBgXG4gICAgaXZlYzMgaW5wdXRDb29yZHNGcm9tUmVzaGFwZWRPdXRDb29yZHMoaW50IGluZGV4KSB7XG4gICAgICAke2Nvb3Jkc0Zyb21JbmRleFNuaXBwZXR9XG4gICAgICByZXR1cm4gaXZlYzMociwgYywgZCk7XG4gICAgfVxuICBgO1xufVxuIl19
|