/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { useShapeUniforms } from './gpgpu_math';
|
import { getChannels } from './packing_util';
|
import { getCoordsDataType } from './shader_compiler';
|
/**
 * GPGPU program that reads the unpacked input `A` (`packedInputs = false`)
 * and writes a packed output (`packedOutput = true`).
 *
 * For rank >= 2, each output texel holds a 2x2 block taken from the two
 * innermost dimensions, in the order (r, c), (r, c + 1), (r + 1, c),
 * (r + 1, c + 1). Elements past the last row/column are written as 0.
 * For rank 1, a texel holds the pair (rc, rc + 1) in its first two channels.
 * For rank 0, it holds the scalar in its first channel.
 *
 * When `enableShapeUniforms` is true, the generated GLSL refers to the
 * `outShape` shader variable instead of baking the dimensions in as
 * literals. NOTE(review): `outShape` is presumably a uniform declared by
 * the shader compiler; confirm.
 */
export class PackProgram {
  /**
   * @param outputShape Logical shape of the tensor being packed.
   */
  constructor(outputShape) {
    this.variableNames = ['A'];
    this.packedInputs = false;
    this.packedOutput = true;
    // NOTE(review): the upstream TypeScript carries a TODO
    // (https://github.com/tensorflow/tfjs/issues/893) saying only 3D tensors
    // are supported for input/output, although the code below is written
    // generically over the rank.
    this.outputShape = outputShape;
    this.rank = outputShape.length;
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);

    if (this.rank === 0) {
      this.userCode = `
        void main() {
          setOutput(vec4(getA(), 0., 0., 0.));
        }
      `;
      return;
    }

    const channels = getChannels('rc', this.rank);
    const dtype = getCoordsDataType(this.rank);
    const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);
    const setup = this.getSetup(channels);
    const output = this.getOutput(channels);

    this.userCode = `
        void main() {
          ${dtype} rc = getOutputCoords();

          if(${outOfBoundsCondition}) {
            setOutput(vec4(0));
          } else {
            ${setup}

            setOutput(vec4(${output}));
          }
        }
      `;
  }

  /**
   * Builds the four GLSL argument lists passed to `getA` for one 2x2 block.
   * The lists are returned in the order (r, c), (r, cp1), (rp1, c),
   * (rp1, cp1). Outer (non-innermost) coordinates are prepended unchanged.
   *
   * @param dims GLSL expressions naming each output coordinate component.
   * @returns Four comma-separated coordinate strings.
   */
  getSourceCoordsArr(dims) {
    // The outer-dimension prefix is the same for all four corners,
    // so build it once.
    let prefix = '';
    for (let d = 2; d < this.rank; d++) {
      prefix = `${dims[dims.length - 1 - d]},` + prefix;
    }

    const coords = [];
    for (const rowName of ['r', 'rp1']) {
      for (const colName of ['c', 'cp1']) {
        coords.push(`${prefix}${rowName}, ${colName}`);
      }
    }
    return coords;
  }

  /**
   * Builds the GLSL condition that is true when the current output texel
   * starts outside the logical output shape. Only the two innermost
   * dimensions are checked (or the single dimension for rank 1).
   *
   * @param dims GLSL expressions naming each output coordinate component.
   * @returns A GLSL boolean expression.
   */
  getOutOfBoundsCondition(dims) {
    if (this.rank === 1) {
      const size = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
      // `rc == size` is already one past the last element. Use `>=`, as the
      // rank > 1 branch does. With `>`, such a texel would fall through to
      // `getA(size)`, an out-of-range read. Whether getOutputCoords can
      // produce rc == size depends on texture padding not visible here.
      return `rc >= ${size}`;
    }

    const checks = [];
    for (let i = this.rank - 2; i < this.rank; i++) {
      const limit =
          this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i];
      checks.push(`${dims[i]} >= ${limit}`);
    }
    return checks.join('||');
  }

  /**
   * Builds GLSL that declares the block's row/column indices (`r`, `c`,
   * `rp1`, `cp1`). It also declares `rEdge` / `cEdge` flags, which are set
   * when the next row/column lies past the output shape. Rank 1 needs no
   * setup.
   *
   * @param dims GLSL expressions naming each output coordinate component.
   * @returns GLSL statements (empty for rank 1).
   */
  getSetup(dims) {
    if (this.rank === 1) {
      return '';
    }

    const [rowCoord, colCoord] = dims.slice(-2);
    const rowAxis = this.rank - 2;
    const colAxis = this.rank - 1;
    // Emit a literal index (e.g. `outShape[2]`), as getOutOfBoundsCondition
    // does, rather than an arithmetic expression like `outShape[3 - 1]`.
    const numRows = this.enableShapeUniforms ? `outShape[${rowAxis}]` :
                                               this.outputShape[rowAxis];
    const numCols = this.enableShapeUniforms ? `outShape[${colAxis}]` :
                                               this.outputShape[colAxis];

    return `
      int r = ${rowCoord};
      int c = ${colCoord};
      int rp1 = r + 1;
      int cp1 = c + 1;

      bool cEdge = cp1 >= ${numCols};
      bool rEdge = rp1 >= ${numRows};
    `;
  }

  /**
   * Builds the comma-separated channel values passed to `vec4(...)`.
   * Neighbours that fall past the output edge become 0.
   *
   * @param dims GLSL expressions naming each output coordinate component.
   * @returns The GLSL argument list for the packed texel.
   */
  getOutput(dims) {
    if (this.rank === 1) {
      const size = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
      return `getA(rc), (rc + 1 >= ${size} ? 0. : getA(rc + 1)), 0, 0`;
    }

    // Only the rank >= 2 path needs the 2x2 source coordinates.
    const [topLeft, topRight, bottomLeft, bottomRight] =
        this.getSourceCoordsArr(dims);
    return `getA(${topLeft}),
            cEdge ? 0. : getA(${topRight}),
            rEdge ? 0. : getA(${bottomLeft}),
            rEdge || cEdge ? 0. : getA(${bottomRight})`;
  }
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"pack_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/pack_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAe,gBAAgB,EAAC,MAAM,cAAc,CAAC;AAC5D,OAAO,EAAC,WAAW,EAAC,MAAM,gBAAgB,CAAC;AAC3C,OAAO,EAAC,iBAAiB,EAAC,MAAM,mBAAmB,CAAC;AAEpD,MAAM,OAAO,WAAW;IAStB,YACI,WACY;QAVhB,kBAAa,GAAG,CAAC,GAAG,CAAC,CAAC;QAGtB,iBAAY,GAAG,KAAK,CAAC;QACrB,iBAAY,GAAG,IAAI,CAAC;QAOC,kCAAkC;QACrD,IAAI,CAAC,WAAW,GAAG,WAAW,CAAC;QAC/B,IAAI,CAAC,IAAI,GAAG,WAAW,CAAC,MAAM,CAAC;QAC/B,IAAI,CAAC,mBAAmB,GAAG,gBAAgB,CAAC,IAAI,CAAC,WAAW,CAAC,MAAM,CAAC,CAAC;QAErE,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,IAAI,CAAC,QAAQ,GAAG;;;;OAIf,CAAC;SACH;aAAM;YACL,MAAM,QAAQ,GAAG,WAAW,CAAC,IAAI,EAAE,IAAI,CAAC,IAAI,CAAC,CAAC;YAC9C,MAAM,KAAK,GAAG,iBAAiB,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;YAC3C,MAAM,oBAAoB,GAAG,IAAI,CAAC,uBAAuB,CAAC,QAAQ,CAAC,CAAC;YACpE,MAAM,KAAK,GAAG,IAAI,CAAC,QAAQ,CAAC,QAAQ,CAAC,CAAC;YACtC,MAAM,MAAM,GAAG,IAAI,CAAC,SAAS,CAAC,QAAQ,CAAC,CAAC;YAExC,IAAI,CAAC,QAAQ,GAAG;;YAEV,KAAK;;eAEF,oBAAoB;;;cAGrB,KAAK;;6BAEU,MAAM;;;OAG5B,CAAC;SACH;IACH,CAAC;IAEO,kBAAkB,CAAC,IAAc;QACvC,MAAM,MAAM,GAAG,EAAE,CAAC;QAElB,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,IAAI,CAAC,EAAE,GAAG,EAAE,EAAE;YACjC,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,IAAI,CAAC,EAAE,GAAG,EAAE,EAAE;gBACjC,IAAI,KAAK,GAAG,GAAG,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,KAAK,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC;gBAErE,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE;oBAClC,KAAK,GAAG,GAAG,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,CAAC,GAAG,CAAC,CAAC,GAAG,GAAG,KAAK,CAAC;iBACjD;gBAED,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;aACpB;SACF;QACD,OAAO,MAAM,CAAC;IAChB,CAAC;IAEO,uBAAuB,CAAC,IAAc;QAC5C,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,OAAO,QACH,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC;SACnE;QAED,IAAI,IAAI,GAAG,EAAE,CAAC;QACd,KAAK,IAAI,CAAC,GAAG,IAAI,CAAC,IAAI,GAAG,CAAC,EAAE,C
AAC,GAAG,IAAI,CAAC,IAAI,EAAE,CAAC,EAAE,EAAE;YAC9C,IAAI,IAAI,GAAG,IAAI,CAAC,CAAC,CAAC,OACd,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,EAAE,CAAC;YACxE,IAAI,CAAC,GAAG,IAAI,CAAC,IAAI,GAAG,CAAC,EAAE;gBACrB,IAAI,IAAI,IAAI,CAAC;aACd;SACF;QAED,OAAO,IAAI,CAAC;IACd,CAAC;IAEO,QAAQ,CAAC,IAAc;QAC7B,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,OAAO,EAAE,CAAC;SACX;QAED,MAAM,SAAS,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;QACjC,MAAM,GAAG,GAAG,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,YAAY,IAAI,CAAC,IAAI,OAAO,CAAC,CAAC;YAC9B,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC;QACvE,MAAM,GAAG,GAAG,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,YAAY,IAAI,CAAC,IAAI,OAAO,CAAC,CAAC;YAC9B,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,IAAI,GAAG,CAAC,CAAC,CAAC;QAEvE,OAAO;gBACK,SAAS,CAAC,CAAC,CAAC;gBACZ,SAAS,CAAC,CAAC,CAAC;;;;4BAIA,GAAG;4BACH,GAAG;KAC1B,CAAC;IACJ,CAAC;IAEO,SAAS,CAAC,IAAc;QAC9B,MAAM,YAAY,GAAG,IAAI,CAAC,kBAAkB,CAAC,IAAI,CAAC,CAAC;QACnD,IAAI,IAAI,CAAC,IAAI,KAAK,CAAC,EAAE;YACnB,MAAM,QAAQ,GACV,IAAI,CAAC,mBAAmB,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;YAChE,OAAO,wBAAwB,QAAQ,6BAA6B,CAAC;SACtE;QAED,OAAO,QAAQ,YAAY,CAAC,CAAC,CAAC;gCACF,YAAY,CAAC,CAAC,CAAC;gCACf,YAAY,CAAC,CAAC,CAAC;yCACN,YAAY,CAAC,CAAC,CAAC,GAAG,CAAC;IAC1D,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram, useShapeUniforms} from './gpgpu_math';\nimport {getChannels} from './packing_util';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class PackProgram implements GPGPUProgram {\n  variableNames = ['A'];\n  outputShape: number[];\n  userCode: string;\n  packedInputs = false;\n  packedOutput = true;\n  enableShapeUniforms: boolean;\n  rank: number;\n\n  constructor(\n      outputShape:\n          number[]) {  // TODO(https://github.com/tensorflow/tfjs/issues/893):\n                       // Only input / output 3D tensors.\n    this.outputShape = outputShape;\n    this.rank = outputShape.length;\n    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);\n\n    if (this.rank === 0) {\n      this.userCode = `\n        void main() {\n          setOutput(vec4(getA(), 0., 0., 0.));\n        }\n      `;\n    } else {\n      const channels = getChannels('rc', this.rank);\n      const dtype = getCoordsDataType(this.rank);\n      const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);\n      const setup = this.getSetup(channels);\n      const output = this.getOutput(channels);\n\n      this.userCode = `\n        void main() {\n          ${dtype} rc = getOutputCoords();\n\n          if(${outOfBoundsCondition}) {\n            setOutput(vec4(0));\n    
      } else {\n            ${setup}\n\n            setOutput(vec4(${output}));\n          }\n        }\n      `;\n    }\n  }\n\n  private getSourceCoordsArr(dims: string[]): string[] {\n    const coords = [];\n\n    for (let row = 0; row <= 1; row++) {\n      for (let col = 0; col <= 1; col++) {\n        let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;\n\n        for (let d = 2; d < this.rank; d++) {\n          coord = `${dims[dims.length - 1 - d]},` + coord;\n        }\n\n        coords.push(coord);\n      }\n    }\n    return coords;\n  }\n\n  private getOutOfBoundsCondition(dims: string[]): string {\n    if (this.rank === 1) {\n      return `rc > ${\n          this.enableShapeUniforms ? 'outShape' : this.outputShape[0]}`;\n    }\n\n    let cond = '';\n    for (let i = this.rank - 2; i < this.rank; i++) {\n      cond += `${dims[i]} >= ${\n          this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i]}`;\n      if (i < this.rank - 1) {\n        cond += '||';\n      }\n    }\n\n    return cond;\n  }\n\n  private getSetup(dims: string[]): string {\n    if (this.rank === 1) {\n      return '';\n    }\n\n    const innerDims = dims.slice(-2);\n    const col = this.enableShapeUniforms ? `outShape[${this.rank} - 1]` :\n                                           this.outputShape[this.rank - 1];\n    const row = this.enableShapeUniforms ? `outShape[${this.rank} - 2]` :\n                                           this.outputShape[this.rank - 2];\n\n    return `\n      int r = ${innerDims[0]};\n      int c = ${innerDims[1]};\n      int rp1 = r + 1;\n      int cp1 = c + 1;\n\n      bool cEdge = cp1 >= ${col};\n      bool rEdge = rp1 >= ${row};\n    `;\n  }\n\n  private getOutput(dims: string[]): string {\n    const sourceCoords = this.getSourceCoordsArr(dims);\n    if (this.rank === 1) {\n      const outShape =\n          this.enableShapeUniforms ? 'outShape' : this.outputShape[0];\n      return `getA(rc), (rc + 1 >= ${outShape} ? 0. 
: getA(rc + 1)), 0, 0`;\n    }\n\n    return `getA(${sourceCoords[0]}),\n            cEdge ? 0. : getA(${sourceCoords[1]}),\n            rEdge ? 0. : getA(${sourceCoords[2]}),\n            rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;\n  }\n}\n"]}
|