/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { getGlslDifferences } from './glsl_version';
|
import { useShapeUniforms } from './gpgpu_math';
|
import * as shader_util from './shader_compiler_util';
|
/*
|
This is how the shader encodes a tensor with shape = [2, 3, 5]
|
(indices are [batch, row, col]).
|
|
000|001 002|003 004|xxx 020|021 022|023 024|xxx
|
------- ------- ------- ------- ------- -------
|
010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
|
|
100|101 102|103 104|xxx 120|121 122|123 124|xxx
|
------- ------- ------- ------- ------- -------
|
110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
|
|
Single texels contain only values from the same batch, and from adjacent rows
|
and columns.
|
*/
|
/**
 * GPGPU program description that reads the unpacked input texture `A`
 * (`packedInputs = false`) and produces a packed output
 * (`packedOutput = true`): the four channels of each output texel hold a
 * 2x2 (row x col) patch of the logical tensor, with
 * `channel = row * 2 + col`.
 *
 * The constructor performs no GPU work; it only assembles the GLSL source
 * in `userCode` and records the metadata fields set below.
 */
export class EncodeMatrixPackedProgram {
  /**
   * @param {number[]} outputShape Logical output shape. The shader reads
   *     its coordinates as an `ivec3` and uses entries 1 and 2 as the row
   *     and column bounds.
   * @param {boolean} [inputIsUnsignedByte=false] When true, the value
   *     written out is `floor(result * 255. + 0.5)` instead of `result`.
   */
  constructor(outputShape, inputIsUnsignedByte = false) {
    this.variableNames = ['A'];
    this.packedInputs = false;
    this.packedOutput = true;
    this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
    const glsl = getGlslDifferences();
    this.outputShape = outputShape;
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);

    // The bounds are identical for all four channels, so build them once.
    // With shape uniforms enabled the shader reads them from `outShape`;
    // otherwise the concrete sizes are baked into the source.
    const rowBound =
        this.enableShapeUniforms ? 'outShape[1]' : `${outputShape[1]}`;
    const colBound =
        this.enableShapeUniforms ? 'outShape[2]' : `${outputShape[2]}`;

    const output =
        inputIsUnsignedByte ? 'floor(result * 255. + 0.5)' : 'result';

    // Emit one guarded fetch per channel of the output texel. Channels
    // whose coordinates fall outside the bounds keep the initial 0. value
    // from `vec4 result = vec4(0.)`.
    let mainLoop = '';
    for (let row = 0; row <= 1; row++) {
      for (let col = 0; col <= 1; col++) {
        const channel = row * 2 + col;

        // `r`, `c` and `uv` are assigned here rather than re-declared:
        // main() already declares them, and a second declaration inside
        // the nested block would merely shadow those and leave them unused.
        //
        // NOTE(review): the offset if/else chain uses constant component
        // indices, presumably because some GLSL targets restrict dynamic
        // vector indexing - confirm before collapsing to values[offset].
        mainLoop += `
          localCoords = coords;
          if(localCoords[2] + ${col} < ${colBound}) {
          localCoords[2] += ${col};
          if (localCoords[1] + ${row} < ${rowBound}) {
            localCoords[1] += ${row};

            flatIndex = getFlatIndex(localCoords);
            offset = imod(flatIndex, 4);

            flatIndex = idiv(flatIndex, 4, 1.);

            r = flatIndex / texShape[1];
            c = imod(flatIndex, texShape[1]);
            uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
            values = ${glsl.texture2D}(A, uv);

            if (offset == 0) {
              result[${channel}] = values[0];
            } else if (offset == 1) {
              result[${channel}] = values[1];
            } else if (offset == 2) {
              result[${channel}] = values[2];
            } else {
              result[${channel}] = values[3];
            }
          }
        }
        `;
      }
    }

    // Helper source spliced in ahead of main(); presumably this is what
    // defines the getFlatIndex() called above - confirm in
    // shader_compiler_util.
    const flatIndexSnippet = this.enableShapeUniforms ?
        shader_util.getFlatIndexFrom3DOutput() :
        shader_util.getFlatIndexFrom3D(outputShape);

    this.userCode = `
        ${flatIndexSnippet}

        void main() {
          ivec3 coords = getOutputCoords();

          vec4 result = vec4(0.);
          int flatIndex, r, c, offset;
          ivec3 localCoords;
          vec2 uv;
          vec4 values;

          ${mainLoop}

          ${glsl.output} = ${output};
        }
    `;
  }
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiZW5jb2RlX21hdHJpeF9wYWNrZWRfZ3B1LmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vdGZqcy1iYWNrZW5kLXdlYmdsL3NyYy9lbmNvZGVfbWF0cml4X3BhY2tlZF9ncHUudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLGtCQUFrQixFQUFDLE1BQU0sZ0JBQWdCLENBQUM7QUFDbEQsT0FBTyxFQUFlLGdCQUFnQixFQUFDLE1BQU0sY0FBYyxDQUFDO0FBQzVELE9BQU8sS0FBSyxXQUFXLE1BQU0sd0JBQXdCLENBQUM7QUFFdEQ7Ozs7Ozs7Ozs7Ozs7O0dBY0c7QUFFSCxNQUFNLE9BQU8seUJBQXlCO0lBU3BDLFlBQ0ksV0FBcUMsRUFBRSxtQkFBbUIsR0FBRyxLQUFLO1FBVHRFLGtCQUFhLEdBQUcsQ0FBQyxHQUFHLENBQUMsQ0FBQztRQUd0QixpQkFBWSxHQUFHLEtBQUssQ0FBQztRQUNyQixpQkFBWSxHQUFHLElBQUksQ0FBQztRQUVwQixtQkFBYyxHQUFHLENBQUMsRUFBQyxJQUFJLEVBQUUsVUFBVSxFQUFFLElBQUksRUFBRSxPQUFnQixFQUFFLENBQUMsQ0FBQztRQUk3RCxNQUFNLElBQUksR0FBRyxrQkFBa0IsRUFBRSxDQUFDO1FBQ2xDLElBQUksQ0FBQyxXQUFXLEdBQUcsV0FBVyxDQUFDO1FBQy9CLElBQUksQ0FBQyxtQkFBbUIsR0FBRyxnQkFBZ0IsQ0FBQyxJQUFJLENBQUMsV0FBVyxDQUFDLE1BQU0sQ0FBQyxDQUFDO1FBRXJFLElBQUksUUFBUSxHQUFHLEVBQUUsQ0FBQztRQUNsQixJQUFJLE1BQU0sR0FBRyxRQUFRLENBQUM7UUFDdEIsSUFBSSxtQkFBbUIsRUFBRTtZQUN2QixNQUFNLEdBQUcsNEJBQTRCLENBQUM7U0FDdkM7UUFFRCxLQUFLLElBQUksR0FBRyxHQUFHLENBQUMsRUFBRSxHQUFHLElBQUksQ0FBQyxFQUFFLEdBQUcsRUFBRSxFQUFFO1lBQ2pDLEtBQUssSUFBSSxHQUFHLEdBQUcsQ0FBQyxFQUFFLEdBQUcsSUFBSSxDQUFDLEVBQUUsR0FBRyxFQUFFLEVBQUU7Z0JBQ2pDLE1BQU0sT0FBTyxHQUFHLEdBQUcsR0FBRyxDQUFDLEdBQUcsR0FBRyxDQUFDO2dCQUU5QixRQUFRLElBQUk7O2dDQUVZLEdBQUcsTUFDdkIsSUFBSSxDQUFDLG1CQUFtQixDQUFDLENBQUMsQ0FBQyxhQUFhLENBQUMsQ0FBQyxDQUFDLEdBQUcsV0FBVyxDQUFDLENBQUMsQ0FBQyxFQUFFOzhCQUM1QyxHQUFHO2lDQUNBLEdBQUcsTUFDeEIsSUFBSSxDQUFDLG1CQUFtQixDQUFDLENBQUMsQ0FBQyxhQUFhLENBQUMsQ0FBQyxDQUFDLEdBQUcsV0FBVyxDQUFDLENBQUMsQ0FBQyxFQUFFO2dDQUMxQyxHQUFHOzs7Ozs7Ozs7O3VCQVVaLElBQUksQ0FBQyxTQUFTOzs7dUJBR2QsT0FBTzs7dUJBRVAsT0FBTzs7dUJBRVAsT0FBTzs7dUJBRVAsT0FBTzs7OztTQUlyQixDQUFDO2FBQ0g7U0FDRjtRQUVELElBQUksQ0FBQyxRQUFRLEdBQUc7VUFFWixJQUFJLENBQUMsbUJBQW1CLENBQUMsQ0FBQyxDQUFDLFdBQVcsQ0FBQyx3QkFBd0IsRUFBRSxDQU
FDLENBQUM7WUFDeEMsV0FBVyxDQUFDLGtCQUFrQixDQUFDLFdBQVcsQ0FBQzs7Ozs7Ozs7Ozs7WUFXbEUsUUFBUTs7WUFFUixJQUFJLENBQUMsTUFBTSxNQUFNLE1BQU07O0tBRTlCLENBQUM7SUFDSixDQUFDO0NBQ0YiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAxOCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCB7Z2V0R2xzbERpZmZlcmVuY2VzfSBmcm9tICcuL2dsc2xfdmVyc2lvbic7XG5pbXBvcnQge0dQR1BVUHJvZ3JhbSwgdXNlU2hhcGVVbmlmb3Jtc30gZnJvbSAnLi9ncGdwdV9tYXRoJztcbmltcG9ydCAqIGFzIHNoYWRlcl91dGlsIGZyb20gJy4vc2hhZGVyX2NvbXBpbGVyX3V0aWwnO1xuXG4vKlxuVGhpcyBpcyBob3cgdGhlIHNoYWRlciBlbmNvZGVzIGEgdGVuc29yIHdpdGggc2hhcGUgPSBbMiwgMywgNV1cbihpbmRpY2VzIGFyZSBbYmF0Y2gsIHJvdywgY29sXSkuXG5cbjAwMHwwMDEgICAwMDJ8MDAzICAgMDA0fHh4eCAgIDAyMHwwMjEgICAwMjJ8MDIzICAgMDI0fHh4eFxuLS0tLS0tLSAgIC0tLS0tLS0gICAtLS0tLS0tICAgLS0tLS0tLSAgIC0tLS0tLS0gICAtLS0tLS0tXG4wMTB8MDExICAgMDEyfDAxMyAgIDAxNHx4eHggICB4eHh8eHh4ICAgeHh4fHh4eCAgIHh4eHx4eHhcblxuMTAwfDEwMSAgIDEwMnwxMDMgICAxMDR8eHh4ICAgMTIwfDEyMSAgIDEyMnwxMjMgICAxMjR8eHh4XG4tLS0tLS0tICAgLS0tLS0tLSAgIC0tLS0tLS0gICAtLS0tLS0tICAgLS0tLS0tLSAgIC0tLS0tLS1cbjExMHwxMTEgICAxMTJ8MTEzICAgMTE0fHh4eCAgIHh4eHx4eHggICB4eHh8eHh4ICAgeHh4fHh4eF
xuXG5TaW5nbGUgdGV4ZWxzIGNvbnRhaW4gb25seSB2YWx1ZXMgZnJvbSB0aGUgc2FtZSBiYXRjaCwgYW5kIGZyb20gYWRqYWNlbnQgcm93c1xuYW5kIGNvbHVtbnMuXG4gKi9cblxuZXhwb3J0IGNsYXNzIEVuY29kZU1hdHJpeFBhY2tlZFByb2dyYW0gaW1wbGVtZW50cyBHUEdQVVByb2dyYW0ge1xuICB2YXJpYWJsZU5hbWVzID0gWydBJ107XG4gIHVzZXJDb2RlOiBzdHJpbmc7XG4gIG91dHB1dFNoYXBlOiBudW1iZXJbXTtcbiAgcGFja2VkSW5wdXRzID0gZmFsc2U7XG4gIHBhY2tlZE91dHB1dCA9IHRydWU7XG4gIGVuYWJsZVNoYXBlVW5pZm9ybXM6IGJvb2xlYW47XG4gIGN1c3RvbVVuaWZvcm1zID0gW3tuYW1lOiAndGV4U2hhcGUnLCB0eXBlOiAnaXZlYzInIGFzIGNvbnN0IH1dO1xuXG4gIGNvbnN0cnVjdG9yKFxuICAgICAgb3V0cHV0U2hhcGU6IFtudW1iZXIsIG51bWJlciwgbnVtYmVyXSwgaW5wdXRJc1Vuc2lnbmVkQnl0ZSA9IGZhbHNlKSB7XG4gICAgY29uc3QgZ2xzbCA9IGdldEdsc2xEaWZmZXJlbmNlcygpO1xuICAgIHRoaXMub3V0cHV0U2hhcGUgPSBvdXRwdXRTaGFwZTtcbiAgICB0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPSB1c2VTaGFwZVVuaWZvcm1zKHRoaXMub3V0cHV0U2hhcGUubGVuZ3RoKTtcblxuICAgIGxldCBtYWluTG9vcCA9ICcnO1xuICAgIGxldCBvdXRwdXQgPSAncmVzdWx0JztcbiAgICBpZiAoaW5wdXRJc1Vuc2lnbmVkQnl0ZSkge1xuICAgICAgb3V0cHV0ID0gJ2Zsb29yKHJlc3VsdCAqIDI1NS4gKyAwLjUpJztcbiAgICB9XG5cbiAgICBmb3IgKGxldCByb3cgPSAwOyByb3cgPD0gMTsgcm93KyspIHtcbiAgICAgIGZvciAobGV0IGNvbCA9IDA7IGNvbCA8PSAxOyBjb2wrKykge1xuICAgICAgICBjb25zdCBjaGFubmVsID0gcm93ICogMiArIGNvbDtcblxuICAgICAgICBtYWluTG9vcCArPSBgXG4gICAgICAgICAgbG9jYWxDb29yZHMgPSBjb29yZHM7XG4gICAgICAgICAgaWYobG9jYWxDb29yZHNbMl0gKyAke2NvbH0gPCAke1xuICAgICAgICAgICAgdGhpcy5lbmFibGVTaGFwZVVuaWZvcm1zID8gJ291dFNoYXBlWzJdJyA6IGAke291dHB1dFNoYXBlWzJdfWB9KSB7XG4gICAgICAgICAgbG9jYWxDb29yZHNbMl0gKz0gJHtjb2x9O1xuICAgICAgICAgIGlmIChsb2NhbENvb3Jkc1sxXSArICR7cm93fSA8ICR7XG4gICAgICAgICAgICB0aGlzLmVuYWJsZVNoYXBlVW5pZm9ybXMgPyAnb3V0U2hhcGVbMV0nIDogYCR7b3V0cHV0U2hhcGVbMV19YH0pIHtcbiAgICAgICAgICAgIGxvY2FsQ29vcmRzWzFdICs9ICR7cm93fTtcblxuICAgICAgICAgICAgZmxhdEluZGV4ID0gZ2V0RmxhdEluZGV4KGxvY2FsQ29vcmRzKTtcbiAgICAgICAgICAgIG9mZnNldCA9IGltb2QoZmxhdEluZGV4LCA0KTtcblxuICAgICAgICAgICAgZmxhdEluZGV4ID0gaWRpdihmbGF0SW5kZXgsIDQsIDEuKTtcblxuICAgICAgICAgICAgaW50IHIgPSBmbGF0SW5kZXggLyB0ZXhTaGFwZVsxXTtcbiAgICAgICAgICAgIGludCBjID
0gaW1vZChmbGF0SW5kZXgsIHRleFNoYXBlWzFdKTtcbiAgICAgICAgICAgIHZlYzIgdXYgPSAodmVjMihjLCByKSArIGhhbGZDUikgLyB2ZWMyKHRleFNoYXBlWzFdLCB0ZXhTaGFwZVswXSk7XG4gICAgICAgICAgICB2YWx1ZXMgPSAke2dsc2wudGV4dHVyZTJEfShBLCB1dik7XG5cbiAgICAgICAgICAgIGlmIChvZmZzZXQgPT0gMCkge1xuICAgICAgICAgICAgICByZXN1bHRbJHtjaGFubmVsfV0gPSB2YWx1ZXNbMF07XG4gICAgICAgICAgICB9IGVsc2UgaWYgKG9mZnNldCA9PSAxKSB7XG4gICAgICAgICAgICAgIHJlc3VsdFske2NoYW5uZWx9XSA9IHZhbHVlc1sxXTtcbiAgICAgICAgICAgIH0gZWxzZSBpZiAob2Zmc2V0ID09IDIpIHtcbiAgICAgICAgICAgICAgcmVzdWx0WyR7Y2hhbm5lbH1dID0gdmFsdWVzWzJdO1xuICAgICAgICAgICAgfSBlbHNlIHtcbiAgICAgICAgICAgICAgcmVzdWx0WyR7Y2hhbm5lbH1dID0gdmFsdWVzWzNdO1xuICAgICAgICAgICAgfVxuICAgICAgICAgIH1cbiAgICAgICAgfVxuICAgICAgICBgO1xuICAgICAgfVxuICAgIH1cblxuICAgIHRoaXMudXNlckNvZGUgPSBgXG4gICAgICAgICR7XG4gICAgICAgIHRoaXMuZW5hYmxlU2hhcGVVbmlmb3JtcyA/IHNoYWRlcl91dGlsLmdldEZsYXRJbmRleEZyb20zRE91dHB1dCgpIDpcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgc2hhZGVyX3V0aWwuZ2V0RmxhdEluZGV4RnJvbTNEKG91dHB1dFNoYXBlKX1cblxuICAgICAgICB2b2lkIG1haW4oKSB7XG4gICAgICAgICAgaXZlYzMgY29vcmRzID0gZ2V0T3V0cHV0Q29vcmRzKCk7XG5cbiAgICAgICAgICB2ZWM0IHJlc3VsdCA9IHZlYzQoMC4pO1xuICAgICAgICAgIGludCBmbGF0SW5kZXgsIHIsIGMsIG9mZnNldDtcbiAgICAgICAgICBpdmVjMyBsb2NhbENvb3JkcztcbiAgICAgICAgICB2ZWMyIHV2O1xuICAgICAgICAgIHZlYzQgdmFsdWVzO1xuXG4gICAgICAgICAgJHttYWluTG9vcH1cblxuICAgICAgICAgICR7Z2xzbC5vdXRwdXR9ID0gJHtvdXRwdXR9O1xuICAgICAgICB9XG4gICAgYDtcbiAgfVxufVxuIl19
|