/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { getChannels } from './packing_util';
import { getCoordsDataType } from './shader_compiler';

/**
 * Joins GLSL statements with single spaces into one source fragment.
 * @param {...string} statements GLSL statements or already-joined fragments.
 * @returns {string} The fragments separated by one space each.
 */
function glslJoin(...statements) {
  return statements.join(' ');
}

/**
 * WebGL program computing `mirrorPad` with packed input and output textures.
 *
 * Every output texel carries up to four values. For rank-1 input these are
 * two consecutive elements. For higher ranks they form a 2x2 patch over the
 * last two dimensions: `result[0]`/`result[1]` on one row and
 * `result[2]`/`result[3]` on the next. Each value is fetched by folding its
 * output coordinate back into the `[start, end)` window occupied by the
 * unpadded input, then subtracting `start`. Lanes that fall past the output's
 * edge are left at 0.
 *
 * Example shader code for
 * `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`, whose
 * output has length 2 + 3 + 2 = 7. The generated source is space-separated;
 * line breaks here are added for readability:
 * ```
 * const int start = int(2);
 * const int end = int(5);
 *
 * void main() {
 *   int outputLoc = getOutputCoords();
 *   vec4 result = vec4(0.);
 *
 *   int rc = outputLoc;
 *
 *   int source = rc;
 *   if (source < start) {
 *     source = start * 2 - source - 0;
 *   } else if (source >= end) {
 *     source = (end - 1) * 2 - source + 0;
 *   }
 *   source -= start;
 *
 *   result[0] = getChannel(getX(source), source);
 *   rc += 1;
 *   if(rc < 7) {
 *     int source = rc;
 *     if (source < start) {
 *       source = start * 2 - source - 0;
 *     } else if (source >= end) {
 *       source = (end - 1) * 2 - source + 0;
 *     }
 *     source -= start;
 *
 *     result[1] = getChannel(getX(source), source);
 *   }
 *
 *   setOutput(result);
 * }
 * ```
 */
export class MirrorPadPackedProgram {
  /**
   * @param {number[]} xShape Shape of the input tensor `x`.
   * @param {Array<[number, number]>} paddings One `[before, after]` pair per
   *     dimension of `x`.
   * @param {string} mode `'reflect'` mirrors about the edge element without
   *     repeating it; any other value (presumably `'symmetric'`) mirrors one
   *     step further out, so the edge element is repeated.
   */
  constructor(xShape, paddings, mode) {
    this.variableNames = ['x'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = paddings.map(
        (pad, axis) => pad[0] /* beforePad */ + xShape[axis] + pad[1] /* afterPad */);

    const rank = xShape.length;
    const dtype = getCoordsDataType(rank);
    // Per-axis bounds of the unpadded input inside the output: [start, end).
    const startList = paddings.map((pad) => pad[0]).join(',');
    const endList = paddings.map((pad, axis) => pad[0] + xShape[axis]).join(',');

    const coords = getChannels('rc', rank);
    const source = getChannels('source', rank);
    const sourceArgs = source.join();
    const colCoord = coords[rank - 1];
    const colInBounds = `${colCoord} < ${this.outputShape[rank - 1]}`;
    const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
    const offset = mode === 'reflect' ? 0 : 1;

    // Folds an out-of-range `source` back into [start, end). Rank 1 uses a
    // scalar branch; higher ranks select the branch component-wise with 0/1
    // integer masks built from lessThan/greaterThanEqual.
    const fold = rank === 1 ?
        glslJoin(
            'if (source < start) {',
            `source = start * 2 - source - ${offset};`,
            '} else if (source >= end) {',
            `source = (end - 1) * 2 - source + ${offset};`,
            '}') :
        glslJoin(
            `${dtype} lt = ${dtype}(lessThan(source, start));`,
            `${dtype} gte = ${dtype}(greaterThanEqual(source, end));`,
            `${dtype} orig = 1 - (lt + gte);`,
            `source = orig * source + lt * (start * 2 - source - ${offset}) + gte * ((end - 1) * 2 - source + ${offset});`);

    // The leading/trailing spaces keep fragments separated when they are
    // spliced next to other statements.
    const padSetup = ` ${glslJoin(`${dtype} source = rc;`, fold, 'source -= start;')} `;

    // Declares `source` for the current `rc` (in its own GLSL scope when
    // nested in an `if`) and stores the matching input value in `result[lane]`.
    const readLane = (lane) => glslJoin(
        padSetup,
        `result[${lane}] = getChannel(getX(${sourceArgs}), ${innerDims});`);

    // Lanes 0 and 1: the texel's own row, current column and the next one.
    const firstRow = ` ${glslJoin(
        `${dtype} rc = outputLoc;`,
        readLane(0),
        `${colCoord} += 1;`,
        `if(${colInBounds}) {`,
        readLane(1),
        '}')} `;

    let mainLoop = firstRow;
    if (rank !== 1) {
      // Lanes 2 and 3: the next row, reset to the texel's starting column.
      const rowCoord = coords[rank - 2];
      mainLoop += `${glslJoin(
          'rc = outputLoc;',
          `${rowCoord} += 1;`,
          `if(${rowCoord} < ${this.outputShape[rank - 2]}) {`,
          readLane(2),
          `${colCoord} += 1;`,
          `if(${colInBounds}) {`,
          readLane(3),
          '}',
          '}')} `;
    }

    this.userCode = ` ${glslJoin(
        `const ${dtype} start = ${dtype}(${startList});`,
        `const ${dtype} end = ${dtype}(${endList});`,
        'void main() {',
        `${dtype} outputLoc = getOutputCoords();`,
        'vec4 result = vec4(0.);',
        mainLoop,
        'setOutput(result);',
        '}')} `;
  }
}
//# sourceMappingURL=data:application/json;base64,