/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { getChannels } from './packing_util';
import { getCoordsDataType } from './shader_compiler';

/**
 * Glues GLSL fragments into a single snippet. Fragments are separated by one
 * space and the whole snippet is wrapped in one leading and one trailing
 * space, so a snippet embedded inside another one ends up surrounded by two
 * spaces on each side.
 *
 * @param {...string} fragments GLSL statements or previously built snippets.
 * @returns {string} The assembled snippet.
 */
function joinGlsl(...fragments) {
    return ` ${fragments.join(' ')} `;
}

/**
 * Packed shader program that mirror-pads the input `x`.
 *
 * Every output coordinate is mapped back to a coordinate of `x`: positions
 * before `start` (the leading padding) or at/after `end` (leading padding plus
 * input size) are mirrored around the corresponding edge, and `start` is then
 * subtracted. `'reflect'` mirrors with an offset of 0, any other mode uses an
 * offset of 1.
 *
 * Shader generated for
 * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`:
 * ```
 * const int start = int(2);
 * const int end = int(5);
 *
 * void main() {
 *   int outputLoc = getOutputCoords();
 *   vec4 result = vec4(0.);
 *
 *   int rc = outputLoc;
 *
 *   int source = rc;
 *   if (source < start) {
 *     source = start * 2 - source - 0;
 *   } else if (source >= end) {
 *     source = (end - 1) * 2 - source + 0;
 *   }
 *   source -= start;
 *
 *   result[0] = getChannel(getX(source), source);
 *   rc += 1;
 *   if(rc < 6) {
 *     int source = rc;
 *     if (source < start) {
 *       source = start * 2 - source - 0;
 *     } else if (source >= end) {
 *       source = (end - 1) * 2 - source + 0;
 *     }
 *     source -= start;
 *
 *     result[1] = getChannel(getX(source), source);
 *   }
 *
 *   setOutput(result);
 * }
 * ```
 */
export class MirrorPadPackedProgram {
    /**
     * @param {number[]} xShape Shape of the input `x`.
     * @param {Array<[number, number]>} paddings One `[beforePad, afterPad]`
     *     pair per dimension of `xShape`.
     * @param {'reflect'|'symmetric'} mode Mirroring flavour.
     */
    constructor(xShape, paddings, mode) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        // Each output dimension is beforePad + input size + afterPad.
        this.outputShape =
            paddings.map((pad, axis) => pad[0] + xShape[axis] + pad[1]);

        const rank = xShape.length;
        const dtype = getCoordsDataType(rank);
        // Comma separated constructor arguments for the `start` / `end` constants.
        const start = paddings.map(pad => pad[0]).join(',');
        const end = paddings.map((pad, axis) => pad[0] + xShape[axis]).join(',');
        const coords = getChannels('rc', rank);
        const source = getChannels('source', rank);

        const lastCoord = coords[rank - 1];
        const cLimit = `${lastCoord} < ${this.outputShape[rank - 1]}`;
        let innerDims = 'source';
        if (rank !== 1) {
            innerDims = `vec2(${source.slice(-2).join()})`;
        }
        const offset = mode !== 'reflect' ? 1 : 0;
        // Expression that reads the value of `x` at the mirrored coordinate.
        const readX = `getChannel(getX(${source.join()}), ${innerDims})`;

        // Snippet that derives `source` (a coordinate into `x`) from `rc`.
        let padSetup;
        if (rank === 1) {
            // Scalar coordinate: plain branching.
            padSetup = joinGlsl(
                `${dtype} source = rc;`,
                'if (source < start) {',
                `source = start * 2 - source - ${offset};`,
                '} else if (source >= end) {',
                `source = (end - 1) * 2 - source + ${offset};`,
                '}',
                'source -= start;');
        } else {
            // Vector coordinate: component-wise masks select between the
            // untouched, mirrored-low and mirrored-high values.
            padSetup = joinGlsl(
                `${dtype} source = rc;`,
                `${dtype} lt = ${dtype}(lessThan(source, start));`,
                `${dtype} gte = ${dtype}(greaterThanEqual(source, end));`,
                `${dtype} orig = 1 - (lt + gte);`,
                'source = orig * source +',
                `lt * (start * 2 - source - ${offset}) +`,
                `gte * ((end - 1) * 2 - source + ${offset});`,
                'source -= start;');
        }

        // Fills result[first] at the current `rc`, then result[second] one
        // step further along the last axis when that stays inside the output.
        const channelPair = (first, second) => [
            padSetup,
            `result[${first}] = ${readX};`,
            `${lastCoord} += 1;`,
            `if(${cLimit}) {`,
            padSetup,
            `result[${second}] = ${readX};`,
            '}',
        ];

        let mainLoop = '';
        if (rank === 1) {
            mainLoop = joinGlsl(
                `${dtype} rc = outputLoc;`,
                ...channelPair(0, 1));
        } else {
            const rowCoord = coords[rank - 2];
            mainLoop = joinGlsl(
                `${dtype} rc = outputLoc;`,
                ...channelPair(0, 1),
                // Restart from the texel origin and advance the
                // second-to-last axis for the remaining two channels.
                'rc = outputLoc;',
                `${rowCoord} += 1;`,
                `if(${rowCoord} < ${this.outputShape[rank - 2]}) {`,
                ...channelPair(2, 3),
                '}');
        }

        this.userCode = joinGlsl(
            `const ${dtype} start = ${dtype}(${start});`,
            `const ${dtype} end = ${dtype}(${end});`,
            'void main() {',
            `${dtype} outputLoc = getOutputCoords();`,
            'vec4 result = vec4(0.);',
            mainLoop,
            'setOutput(result);',
            '}');
    }
}
//# 
sourceMappingURL=data:application/json;base64,{"version":3,"file":"mirror_pad_packed_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/mirror_pad_packed_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAGH,OAAO,EAAC,WAAW,EAAC,MAAM,gBAAgB,CAAC;AAC3C,OAAO,EAAC,iBAAiB,EAAC,MAAM,mBAAmB,CAAC;AAEpD;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAsCG;AACH,MAAM,OAAO,sBAAsB;IAOjC,YACI,MAAgB,EAAE,QAAiC,EACnD,IAA2B;QAR/B,kBAAa,GAAG,CAAC,GAAG,CAAC,CAAC;QACtB,iBAAY,GAAG,IAAI,CAAC;QACpB,iBAAY,GAAG,IAAI,CAAC;QAOlB,IAAI,CAAC,WAAW,GAAG,QAAQ,CAAC,GAAG,CAC3B,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,eAAe,GAAG,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,cAAc,CAAC,CAAC;QACtE,MAAM,IAAI,GAAG,MAAM,CAAC,MAAM,CAAC;QAC3B,MAAM,KAAK,GAAG,iBAAiB,CAAC,IAAI,CAAC,CAAC;QAEtC,MAAM,KAAK,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAChD,MAAM,GAAG,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;QAC/D,MAAM,MAAM,GAAG,WAAW,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;QACvC,MAAM,MAAM,GAAG,WAAW,CAAC,QAAQ,EAAE,IAAI,CAAC,CAAC;QAC3C,MAAM,MAAM,GAAG,GAAG,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,MAAM,IAAI,CAAC,WAAW,CAAC,IAAI,GAAG,CAAC,CAAC,EAAE,CAAC;QACrE,MAAM,SAAS,GACX,IAAI,KAAK,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,QAAQ,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE,GAAG,CAAC;QAC/D,MAAM,MAAM,GAAG,IAAI,KAAK,SAAS,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAE1C,IAAI,QAAQ,GAAG,EAAE,CAAC;QAClB,IAAI,IAAI,KAAK,CAAC,EAAE;YACd,MAAM,QAAQ,GAAG;UACb,KAAK;;0CAE2B,MAAM;;8CAEF,MAAM;;;OAG7C,CAAC;YACF,QAAQ,GAAG;UACP,KAAK;UACL,QAAQ;sCACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;UACxD,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;aACb,MAAM;YACP,QAAQ;wCACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;;OAE7D,CAAC;SACH;aAAM;YACL,MAAM,QAAQ,GAAG;UACb,KAAK;UACL,KAAK,SAAS,KAAK;UACnB,KAAK,UAAU,KAAK;UACpB,KAAK;;6CAE8B,MAAM;kDACD,MAAM;;OAEjD,CAAC;YAEF,QAAQ,GAAG;UACP,KAAK;UACL,QAAQ;sCACoB,MAAM,CAAC,IAAI,EAA
E,MAAM,SAAS;UACxD,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;aACb,MAAM;YACP,QAAQ;wCACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;;;UAG1D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;aACb,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,MAAM,IAAI,CAAC,WAAW,CAAC,IAAI,GAAG,CAAC,CAAC;YACjD,QAAQ;wCACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;YACxD,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;eACb,MAAM;cACP,QAAQ;0CACoB,MAAM,CAAC,IAAI,EAAE,MAAM,SAAS;;;OAG/D,CAAC;SACH;QAED,IAAI,CAAC,QAAQ,GAAG;cACN,KAAK,YAAY,KAAK,IAAI,KAAK;cAC/B,KAAK,UAAU,KAAK,IAAI,GAAG;;;UAG/B,KAAK;;UAEL,QAAQ;;;KAGb,CAAC;IACJ,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getChannels} from './packing_util';\nimport {getCoordsDataType} from './shader_compiler';\n\n/**\n * Example shader code for\n * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`\n * ```\n *    const int start = int(2);\n *    const int end = int(5);\n *\n *    void main() {\n *       int outputLoc = getOutputCoords();\n *       vec4 result = vec4(0.);\n *\n *       int rc = outputLoc;\n *\n *       int source = rc;\n *       if (source < start) {\n *         source = start * 2 - source - 0;\n *       } else if (source >= end) {\n *         source = (end - 1) * 2 - source + 0;\n *       }\n *       source -= start;\n *\n *       result[0] = 
getChannel(getX(source), source);\n *       rc += 1;\n *       if(rc < 6) {\n *          int source = rc;\n *          if (source < start) {\n *            source = start * 2 - source - 0;\n *          } else if (source >= end) {\n *            source = (end - 1) * 2 - source + 0;\n *          }\n *          source -= start;\n *\n *         result[1] = getChannel(getX(source), source);\n *       }\n *\n *       setOutput(result);\n *     }\n * ```\n */\nexport class MirrorPadPackedProgram implements GPGPUProgram {\n  variableNames = ['x'];\n  packedInputs = true;\n  packedOutput = true;\n  outputShape: number[];\n  userCode: string;\n\n  constructor(\n      xShape: number[], paddings: Array<[number, number]>,\n      mode: 'reflect'|'symmetric') {\n    this.outputShape = paddings.map(\n        (p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);\n    const rank = xShape.length;\n    const dtype = getCoordsDataType(rank);\n\n    const start = paddings.map(p => p[0]).join(',');\n    const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');\n    const coords = getChannels('rc', rank);\n    const source = getChannels('source', rank);\n    const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;\n    const innerDims =\n        rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;\n    const offset = mode === 'reflect' ? 
0 : 1;\n\n    let mainLoop = '';\n    if (rank === 1) {\n      const padSetup = `\n        ${dtype} source = rc;\n        if (source < start) {\n          source = start * 2 - source - ${offset};\n        } else if (source >= end) {\n          source = (end - 1) * 2 - source + ${offset};\n        }\n        source -= start;\n      `;\n      mainLoop = `\n        ${dtype} rc = outputLoc;\n        ${padSetup}\n        result[0] = getChannel(getX(${source.join()}), ${innerDims});\n        ${coords[rank - 1]} += 1;\n        if(${cLimit}) {\n          ${padSetup}\n          result[1] = getChannel(getX(${source.join()}), ${innerDims});\n        }\n      `;\n    } else {\n      const padSetup = `\n        ${dtype} source = rc;\n        ${dtype} lt = ${dtype}(lessThan(source, start));\n        ${dtype} gte = ${dtype}(greaterThanEqual(source, end));\n        ${dtype} orig = 1 - (lt + gte);\n        source = orig * source +\n                lt * (start * 2 - source - ${offset}) +\n                gte * ((end - 1) * 2 - source + ${offset});\n        source -= start;\n      `;\n\n      mainLoop = `\n        ${dtype} rc = outputLoc;\n        ${padSetup}\n        result[0] = getChannel(getX(${source.join()}), ${innerDims});\n        ${coords[rank - 1]} += 1;\n        if(${cLimit}) {\n          ${padSetup}\n          result[1] = getChannel(getX(${source.join()}), ${innerDims});\n        }\n        rc = outputLoc;\n        ${coords[rank - 2]} += 1;\n        if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {\n          ${padSetup}\n          result[2] = getChannel(getX(${source.join()}), ${innerDims});\n          ${coords[rank - 1]} += 1;\n          if(${cLimit}) {\n            ${padSetup}\n            result[3] = getChannel(getX(${source.join()}), ${innerDims});\n          }\n        }\n      `;\n    }\n\n    this.userCode = `\n      const ${dtype} start = ${dtype}(${start});\n      const ${dtype} end = ${dtype}(${end});\n\n      void main() {\n        ${dtype} outputLoc = 
getOutputCoords();\n        vec4 result = vec4(0.);\n        ${mainLoop}\n        setOutput(result);\n      }\n    `;\n  }\n}\n"]}