/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { getChannels } from './packing_util';
|
import { getCoordsDataType } from './shader_compiler';
|
/**
 * Packed WebGL program for `mirrorPad`.
 *
 * Each output texel stores up to four values:
 * - the element at the output coordinate;
 * - its neighbour one step further along the last axis;
 * - for rank >= 2, the same two positions one step further along the
 *   second-to-last axis.
 *
 * Every value's output coordinate is mapped back into `x` by mirroring it
 * about the edges of the region that holds the copy of `x`:
 *
 *   before the data (`source < start`): `source = start * 2 - source - offset`
 *   after the data  (`source >= end`):  `source = (end - 1) * 2 - source + offset`
 *
 * `offset` is 0 for `'reflect'`, so the edge element is not repeated. It is 1
 * otherwise (`'symmetric'`), so the edge element is repeated. Afterwards
 * `start` is subtracted to turn the padded coordinate into an index into `x`.
 *
 * Generated shader for
 * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`
 * (whitespace condensed):
 * ```
 *   const int start = int(2);
 *   const int end = int(5);
 *
 *   void main() {
 *     int outputLoc = getOutputCoords();
 *     vec4 result = vec4(0.);
 *
 *     int rc = outputLoc;
 *
 *     int source = rc;
 *     if (source < start) {
 *       source = start * 2 - source - 0;
 *     } else if (source >= end) {
 *       source = (end - 1) * 2 - source + 0;
 *     }
 *     source -= start;
 *
 *     result[0] = getChannel(getX(source), source);
 *     rc += 1;
 *     if(rc < 7) {
 *       int source = rc;
 *       if (source < start) {
 *         source = start * 2 - source - 0;
 *       } else if (source >= end) {
 *         source = (end - 1) * 2 - source + 0;
 *       }
 *       source -= start;
 *
 *       result[1] = getChannel(getX(source), source);
 *     }
 *
 *     setOutput(result);
 *   }
 * ```
 */
export class MirrorPadPackedProgram {
  /**
   * @param {number[]} xShape Shape of the input tensor `x`.
   * @param {Array<[number, number]>} paddings One `[before, after]` pair per
   *     axis of `x`.
   * @param {'reflect'|'symmetric'} mode Mirroring mode. Any value other than
   *     `'reflect'` is handled like `'symmetric'`.
   */
  constructor(xShape, paddings, mode) {
    this.variableNames = ['x'];
    this.packedInputs = true;
    this.packedOutput = true;
    // Every axis grows by its before- and after-padding.
    this.outputShape =
        paddings.map((pad, axis) => pad[0] + xShape[axis] + pad[1]);

    const rank = xShape.length;
    const dtype = getCoordsDataType(rank);
    const lastAxis = rank - 1;
    const rowAxis = rank - 2;

    // Per-axis padded coordinates where the copy of `x` begins (inclusive)
    // and ends (exclusive). They are baked into the shader as constants.
    const startList = paddings.map((pad) => pad[0]).join(',');
    const endList =
        paddings.map((pad, axis) => pad[0] + xShape[axis]).join(',');

    const rc = getChannels('rc', rank);
    const src = getChannels('source', rank);
    const colInBounds = `${rc[lastAxis]} < ${this.outputShape[lastAxis]}`;
    // Tells getChannel() which component of the fetched texel holds `source`.
    const channelSelector =
        rank === 1 ? 'source' : `vec2(${src.slice(-2).join()})`;
    const texelRead = `getChannel(getX(${src.join()}), ${channelSelector})`;
    // 'reflect' skips the edge element when mirroring; 'symmetric' repeats it.
    const edgeShift = mode === 'reflect' ? 0 : 1;

    // GLSL that declares `source` and maps the current `rc` into `x`.
    let padSetup;
    if (rank === 1) {
      padSetup = `
        ${dtype} source = rc;
        if (source < start) {
          source = start * 2 - source - ${edgeShift};
        } else if (source >= end) {
          source = (end - 1) * 2 - source + ${edgeShift};
        }
        source -= start;
      `;
    } else {
      // Vector form. lt/gte are per-axis 0/1 masks, so the "before",
      // "inside" and "after" cases are blended without branching.
      padSetup = `
        ${dtype} source = rc;
        ${dtype} lt = ${dtype}(lessThan(source, start));
        ${dtype} gte = ${dtype}(greaterThanEqual(source, end));
        ${dtype} orig = 1 - (lt + gte);
        source = orig * source +
                lt * (start * 2 - source - ${edgeShift}) +
                gte * ((end - 1) * 2 - source + ${edgeShift});
        source -= start;
      `;
    }

    // result[0] and result[1]: the output coordinate and its neighbour along
    // the last axis (the neighbour only when it is inside the output).
    const firstPair = `
        ${dtype} rc = outputLoc;
        ${padSetup}
        result[0] = ${texelRead};
        ${rc[lastAxis]} += 1;
        if(${colInBounds}) {
          ${padSetup}
          result[1] = ${texelRead};
        }`;

    let mainLoop;
    if (rank === 1) {
      mainLoop = `${firstPair}
      `;
    } else {
      // result[2] and result[3]: the same two positions one step further
      // along the second-to-last axis.
      mainLoop = `${firstPair}
        rc = outputLoc;
        ${rc[rowAxis]} += 1;
        if(${rc[rowAxis]} < ${this.outputShape[rowAxis]}) {
          ${padSetup}
          result[2] = ${texelRead};
          ${rc[lastAxis]} += 1;
          if(${colInBounds}) {
            ${padSetup}
            result[3] = ${texelRead};
          }
        }
      `;
    }

    this.userCode = `
      const ${dtype} start = ${dtype}(${startList});
      const ${dtype} end = ${dtype}(${endList});

      void main() {
        ${dtype} outputLoc = getOutputCoords();
        vec4 result = vec4(0.);
        ${mainLoop}
        setOutput(result);
      }
    `;
  }
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"mirror_pad_packed_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/mirror_pad_packed_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,OAAO,EAAC,WAAW,EAAC,MAAM,gBAAgB,CAAC;AAC3C,OAAO,EAAC,iBAAiB,EAAC,MAAM,mBAAmB,CAAC;AAEpD;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAsCG;AACH,MAAM,OAAO,sBAAsB;IAOjC,YACI,MAAgB,EAAE,QAAiC,EACnD,IAA2B;QAR/B,kBAAa,GAAG,CAAC,GAAG,CAAC,CAAC;QACtB,iBAAY,GAAG,IAAI,CAAC;QACpB,iBAAY,GAAG,IAAI,CAAC;QAOlB,IAAI,CAAC,WAAW,GAAG,QAAQ,CAAC,GAAG,CAC3B,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,eAAe,GAAG,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,cAAc,CAAC,CAAC;QACtE,MAAM,IAAI,GAAG,MAAM,CAAC,MAAM,CAAC;QAC3B,MAAM,KAAK,GAAG,iBAAiB,CAAC,IAAI,CAAC,CAAC;QAEtC,MAAM,KAAK,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAChD,MAAM,GAAG,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAC/D,MAAM,MAAM,GAAG,WAAW,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;QACvC,MAAM,MAAM,GAAG,WAAW,CAAC,QAAQ,EAAE,IAAI,CAAC,CAAC;QAC3C,MAAM,MAAM,GAAG,GAAG,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,MAAM,IAAI,CAAC,WAAW,CAAC,IAAI,GAAG,CAAC,CAAC,EAAE,CAAC;QACrE,MAAM,SAAS,GACX,IAAI,KAAK,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,QAAQ,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE,GAAG,CAAC;QAC/D,MAAM,MAAM,GAAG,IAAI,KAAK,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAE1C,IAAI,QAAQ,GAAG,EAAE,CAAC;QAClB,IAAI,IAAI,KAAK,CAAC,EAAE;YACd,MAAM,QAAQ,GAAG;UACb,KAAK;;0CAE2B,MAAM;;8CAEF,MAAM;;;OAG7C,CAAC;YACF,QAAQ,GAAG;UACP,KAAK;UACL,QAAQ;sCACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;UACxD,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;aACb,MAAM;YACP,QAAQ;wCACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;;OAE7D,CAAC;SACH;aAAM;YACL,MAAM,QAAQ,GAAG;UACb,KAAK;UACL,KAAK,SAAS,KAAK;UACnB,KAAK,UAAU,KAAK;UACpB,KAAK;;6CAE8B,MAAM;kDACD,MAAM;;OAEjD,CAAC;YAEF,QAAQ,GAAG;UACP,KAAK;UACL,QAAQ;sCACoB,MAAM,CAAC,IAAI
,EAAE,MAAM,SAAS;UACxD,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;aACb,MAAM;YACP,QAAQ;wCACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;;;UAG1D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;aACb,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,MAAM,IAAI,CAAC,WAAW,CAAC,IAAI,GAAG,CAAC,CAAC;YACjD,QAAQ;wCACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;YACxD,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;eACb,MAAM;cACP,QAAQ;0CACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;;;OAG/D,CAAC;SACH;QAED,IAAI,CAAC,QAAQ,GAAG;cACN,KAAK,YAAY,KAAK,IAAI,KAAK;cAC/B,KAAK,UAAU,KAAK,IAAI,GAAG;;;UAG/B,KAAK;;UAEL,QAAQ;;;KAGb,CAAC;IACJ,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getChannels} from './packing_util';\nimport {getCoordsDataType} from './shader_compiler';\n\n/**\n * Example shader code for\n * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`\n * ```\n *    const int start = int(2);\n *    const int end = int(5);\n *\n *    void main() {\n *       int outputLoc = getOutputCoords();\n *       vec4 result = vec4(0.);\n *\n *       int rc = outputLoc;\n *\n *       int source = rc;\n *       if (source < start) {\n *         source = start * 2 - source - 0;\n *       } else if (source >= end) {\n *         source = (end - 1) * 2 - source + 0;\n *       }\n *       source -= start;\n *\n *       result[0] = 
getChannel(getX(source), source);\n *       rc += 1;\n *       if(rc < 6) {\n *          int source = rc;\n *          if (source < start) {\n *            source = start * 2 - source - 0;\n *          } else if (source >= end) {\n *            source = (end - 1) * 2 - source + 0;\n *          }\n *          source -= start;\n *\n *         result[1] = getChannel(getX(source), source);\n *       }\n *\n *       setOutput(result);\n *     }\n * ```\n */\nexport class MirrorPadPackedProgram implements GPGPUProgram {\n  variableNames = ['x'];\n  packedInputs = true;\n  packedOutput = true;\n  outputShape: number[];\n  userCode: string;\n\n  constructor(\n      xShape: number[], paddings: Array<[number, number]>,\n      mode: 'reflect'|'symmetric') {\n    this.outputShape = paddings.map(\n        (p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);\n    const rank = xShape.length;\n    const dtype = getCoordsDataType(rank);\n\n    const start = paddings.map(p => p[0]).join(',');\n    const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');\n    const coords = getChannels('rc', rank);\n    const source = getChannels('source', rank);\n    const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;\n    const innerDims =\n        rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;\n    const offset = mode === 'reflect' ? 
0 : 1;\n\n    let mainLoop = '';\n    if (rank === 1) {\n      const padSetup = `\n        ${dtype} source = rc;\n        if (source < start) {\n          source = start * 2 - source - ${offset};\n        } else if (source >= end) {\n          source = (end - 1) * 2 - source + ${offset};\n        }\n        source -= start;\n      `;\n      mainLoop = `\n        ${dtype} rc = outputLoc;\n        ${padSetup}\n        result[0] = getChannel(getX(${source.join()}), ${innerDims});\n        ${coords[rank - 1]} += 1;\n        if(${cLimit}) {\n          ${padSetup}\n          result[1] = getChannel(getX(${source.join()}), ${innerDims});\n        }\n      `;\n    } else {\n      const padSetup = `\n        ${dtype} source = rc;\n        ${dtype} lt = ${dtype}(lessThan(source, start));\n        ${dtype} gte = ${dtype}(greaterThanEqual(source, end));\n        ${dtype} orig = 1 - (lt + gte);\n        source = orig * source +\n                lt * (start * 2 - source - ${offset}) +\n                gte * ((end - 1) * 2 - source + ${offset});\n        source -= start;\n      `;\n\n      mainLoop = `\n        ${dtype} rc = outputLoc;\n        ${padSetup}\n        result[0] = getChannel(getX(${source.join()}), ${innerDims});\n        ${coords[rank - 1]} += 1;\n        if(${cLimit}) {\n          ${padSetup}\n          result[1] = getChannel(getX(${source.join()}), ${innerDims});\n        }\n        rc = outputLoc;\n        ${coords[rank - 2]} += 1;\n        if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {\n          ${padSetup}\n          result[2] = getChannel(getX(${source.join()}), ${innerDims});\n          ${coords[rank - 1]} += 1;\n          if(${cLimit}) {\n            ${padSetup}\n            result[3] = getChannel(getX(${source.join()}), ${innerDims});\n          }\n        }\n      `;\n    }\n\n    this.userCode = `\n      const ${dtype} start = ${dtype}(${start});\n      const ${dtype} end = ${dtype}(${end});\n\n      void main() {\n        ${dtype} outputLoc = 
getOutputCoords();\n        vec4 result = vec4(0.);\n        ${mainLoop}\n        setOutput(result);\n      }\n    `;\n  }\n}\n"]}
|