"use strict"; /** * @license * Copyright 2019 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ Object.defineProperty(exports, "__esModule", { value: true }); var util_1 = require("../../util"); var packing_util_1 = require("../packing_util"); var shader_compiler_1 = require("./shader_compiler"); var ArgMinMaxPackedProgram = /** @class */ (function () { function ArgMinMaxPackedProgram(shape, windowSize, op, firstPass) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; util_1.assert(shape.length > 2, function () { return "Packed arg" + (op.charAt(0).toUpperCase() + op.slice(1)) + " supports only inputs with rank above 2."; }); var inSize = shape[shape.length - 1]; var outSize = Math.ceil(inSize / windowSize); this.outputShape = shape.slice(0, -1); if (outSize > 1) { this.outputShape.push(outSize); } if (!firstPass) { this.variableNames.push('bestIndicesA'); } var outShape = this.outputShape; var rank = outShape.length; var dtype = shader_compiler_1.getCoordsDataType(rank); var coords = packing_util_1.getChannels('coords', rank); var sourceLocSetup; var sourceRank; if (outSize === 1) { sourceRank = rank + 1; var sourceLocDType = shader_compiler_1.getCoordsDataType(sourceRank); sourceLocSetup = "\n " + sourceLocDType + " sourceLocR = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocG = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 2] + ";\n " + sourceLocDType + " sourceLocA = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocB = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 2] + ";"; } else { sourceRank = rank; sourceLocSetup = "\n " + dtype + " sourceLocR = coords;\n ++" + coords[rank - 1] + ";\n " + dtype + " sourceLocG = coords;\n ++" + coords[rank - 2] + ";\n " + dtype + " sourceLocA = coords;\n --" + coords[rank - 1] + ";\n " + dtype + " sourceLocB = coords;\n --" + coords[rank - 2] + ";"; } var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank); var inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3. var intChannels = channels.map(function (x) { return 'int ' + x; }); var srcRCoords = packing_util_1.getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r'); var srcGCoords = packing_util_1.getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g'); var srcBCoords = packing_util_1.getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b'); var srcACoords = packing_util_1.getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a'); var compOp = (op === 'max') ? 'greaterThan' : 'lessThan'; var fetchCandidateIdx = firstPass ? '' : "\n inIdx = round(vec4(getBestIndicesAChannel(" + srcRCoords.join() + "),\n getBestIndicesAChannel(" + srcGCoords.join() + "),\n getBestIndicesAChannel(" + srcBCoords.join() + "),\n getBestIndicesAChannel(" + srcACoords.join() + ")));"; var fetchValue = "vec4(\n getAChannel(" + srcRCoords.join() + "),\n hasNextCol ? getAChannel(" + srcGCoords.join() + ") : 0.,\n hasNextRow ? getAChannel(" + srcBCoords.join() + ") : 0.,\n hasNextRow && hasNextCol ? getAChannel(" + srcACoords.join() + ") : 0.)"; var getBestIndicesAChannelSnippet = firstPass ? '' : "\n float getBestIndicesAChannel(" + intChannels.join() + ") {\n return getChannel(getBestIndicesA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }"; this.userCode = "\n float getAChannel(" + intChannels.join() + ") {\n return getChannel(getA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }\n " + getBestIndicesAChannelSnippet + "\n void main() {\n " + dtype + " coords = getOutputCoords();\n bool hasNextCol = " + coords[rank - 1] + " < " + (outShape[rank - 1] - 1) + ";\n bool hasNextRow = " + coords[rank - 2] + " < " + (outShape[rank - 2] - 1) + ";\n " + sourceLocSetup + "\n ivec4 srcIdx = ivec4(sourceLocR" + inChannel + ", sourceLocG" + inChannel + ",\n sourceLocB" + inChannel + ", sourceLocA" + inChannel + ") * " + windowSize + ";\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = " + fetchValue + ";\n\n for (int i = 0; i < " + windowSize + "; i++) {\n inIdx = srcIdx;\n " + fetchCandidateIdx + "\n vec4 candidate = " + fetchValue + ";\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(" + compOp + "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n "; } return ArgMinMaxPackedProgram; }()); exports.ArgMinMaxPackedProgram = ArgMinMaxPackedProgram; //# sourceMappingURL=argminmax_packed_gpu.js.map