/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import { util } from '@tensorflow/tfjs-core';
|
import { getChannels } from './packing_util';
|
import { getCoordsDataType } from './shader_compiler';
|
/**
 * Shader program description for one pass of a packed argMin / argMax
 * reduction over the innermost dimension of input `A`.
 *
 * The constructor only builds strings and shapes; it exposes:
 *   - `variableNames`: `['A']`, plus `'bestIndicesA'` when `firstPass` is
 *     false.
 *   - `packedInputs` / `packedOutput`: both `true`.
 *   - `outputShape`: `shape` without its last dimension, with
 *     `ceil(inSize / windowSize)` appended when that value is greater than 1.
 *   - `userCode`: the generated GLSL source.
 *
 * The generated `main()` handles four source locations per invocation:
 * R (the output coords), G (last coord + 1), B (second-to-last coord + 1) and
 * A (both + 1). G/B/A values are read as `0.` when `hasNextCol` /
 * `hasNextRow` say the neighbour is outside the output shape.
 */
export class ArgMinMaxPackedProgram {
  /**
   * @param {number[]} shape Shape of input `A`; must have more than 2
   *     dimensions (checked with `util.assert`).
   * @param {number} windowSize Number of innermost elements scanned per
   *     output element; also the multiplier used to compute the window start.
   * @param {string} op `'max'` compares with `greaterThan`; any other value
   *     compares with `lessThan`.
   * @param {boolean} firstPass When false, a second input `bestIndicesA` is
   *     declared and candidate indices are read from it instead of being the
   *     raw window positions.
   */
  constructor(shape, windowSize, op, firstPass) {
    this.variableNames = ['A'];
    this.packedInputs = true;
    this.packedOutput = true;

    util.assert(
        shape.length > 2,
        () => `Packed arg${
            op.charAt(0).toUpperCase() +
            op.slice(1)} supports only inputs with rank above 2.`);

    const inSize = shape[shape.length - 1];
    const outSize = Math.ceil(inSize / windowSize);
    this.outputShape = shape.slice(0, -1);
    if (outSize > 1) {
      this.outputShape.push(outSize);
    }
    if (!firstPass) {
      this.variableNames.push('bestIndicesA');
    }

    const outShape = this.outputShape;
    const rank = outShape.length;
    const dtype = getCoordsDataType(rank);
    const coords = getChannels('coords', rank);

    // When the reduced dimension collapses to a single window it is dropped
    // from the output shape, so the source location needs one extra trailing
    // coordinate (fixed at 0) compared with the output coords.
    const collapsesLastDim = outSize === 1;
    const sourceRank = collapsesLastDim ? rank + 1 : rank;
    const sourceLocDType =
        collapsesLastDim ? getCoordsDataType(sourceRank) : dtype;
    const sourceLocValue =
        collapsesLastDim ? `${sourceLocDType}(${coords.join()}, 0)` : 'coords';

    // `coords` is stepped in place to visit R, G, A, B in that order and is
    // restored to its original value by the final decrement.
    const sourceLocSetup = `
        ${sourceLocDType} sourceLocR = ${sourceLocValue};
        ++${coords[rank - 1]};
        ${sourceLocDType} sourceLocG = ${sourceLocValue};
        ++${coords[rank - 2]};
        ${sourceLocDType} sourceLocA = ${sourceLocValue};
        --${coords[rank - 1]};
        ${sourceLocDType} sourceLocB = ${sourceLocValue};
        --${coords[rank - 2]};`;

    const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
    // Swizzle selecting the last source coordinate, e.g. ".z" for source
    // rank 3.
    const inChannel = `.${channels[sourceRank - 1]}`;
    const intChannels = channels.map((channel) => `int ${channel}`);

    // Argument lists for the four lookups: every coordinate of the source
    // location except the last, followed by the matching lane of `inIdx`.
    const srcRCoords =
        getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r').join();
    const srcGCoords =
        getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g').join();
    const srcBCoords =
        getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b').join();
    const srcACoords =
        getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a').join();

    const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';

    const fetchCandidateIdx = firstPass ? '' : `
          inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords}),
                             getBestIndicesAChannel(${srcGCoords}),
                             getBestIndicesAChannel(${srcBCoords}),
                             getBestIndicesAChannel(${srcACoords})));`;

    const fetchValue = `vec4(
            getAChannel(${srcRCoords}),
            hasNextCol ? getAChannel(${srcGCoords}) : 0.,
            hasNextRow ? getAChannel(${srcBCoords}) : 0.,
            hasNextRow && hasNextCol ? getAChannel(${srcACoords}) : 0.)`;

    const getBestIndicesAChannelSnippet = firstPass ? '' : `
      float getBestIndicesAChannel(${intChannels.join()}) {
        return getChannel(getBestIndicesA(${channels.join()}),
                                          vec2(${channels.slice(-2).join()}));
      }`;

    // In the loop below a candidate replaces the current best only when the
    // comparison holds and the candidate is not NaN.
    this.userCode = `
      float getAChannel(${intChannels.join()}) {
        return getChannel(getA(${channels.join()}),
                               vec2(${channels.slice(-2).join()}));
      }
      ${getBestIndicesAChannelSnippet}
      void main() {
        ${dtype} coords = getOutputCoords();
        bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};
        bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};
        ${sourceLocSetup}
        ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
          sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
        ivec4 inIdx = srcIdx;
        vec4 bestIndex = vec4(inIdx);
        vec4 bestValue = ${fetchValue};

        for (int i = 0; i < ${windowSize}; i++) {
          inIdx = srcIdx;
          ${fetchCandidateIdx}
          vec4 candidate = ${fetchValue};
          bvec4 nan = isnan(candidate);
          bvec4 replace = bvec4(
            vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));

          bestValue = vec4(replace.x ? candidate.x : bestValue.x,
                           replace.y ? candidate.y : bestValue.y,
                           replace.z ? candidate.z : bestValue.z,
                           replace.w ? candidate.w : bestValue.w);
          bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
          srcIdx++;
        }
        setOutput(bestIndex);
      }
    `;
  }
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"argminmax_packed_gpu.js","sourceRoot":"","sources":["../../../../../tfjs-backend-webgl/src/argminmax_packed_gpu.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAG3C,OAAO,EAAC,WAAW,EAAC,MAAM,gBAAgB,CAAC;AAC3C,OAAO,EAAC,iBAAiB,EAAC,MAAM,mBAAmB,CAAC;AAEpD,MAAM,OAAO,sBAAsB;IAOjC,YACI,KAAe,EAAE,UAAkB,EAAE,EAAe,EACpD,SAAkB;QARtB,kBAAa,GAAG,CAAC,GAAG,CAAC,CAAC;QAGtB,iBAAY,GAAG,IAAI,CAAC;QACpB,iBAAY,GAAG,IAAI,CAAC;QAKlB,IAAI,CAAC,MAAM,CACP,KAAK,CAAC,MAAM,GAAG,CAAC,EAChB,GAAG,EAAE,CAAC,aACF,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,WAAW,EAAE;YAC1B,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,0CAA0C,CAAC,CAAC;QAC/D,MAAM,MAAM,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACvC,MAAM,OAAO,GAAG,IAAI,CAAC,IAAI,CAAC,MAAM,GAAG,UAAU,CAAC,CAAC;QAC/C,IAAI,CAAC,WAAW,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,OAAO,GAAG,CAAC,EAAE;YACf,IAAI,CAAC,WAAW,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SAChC;QACD,IAAI,CAAC,SAAS,EAAE;YACd,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;SACzC;QACD,MAAM,QAAQ,GAAG,IAAI,CAAC,WAAW,CAAC;QAClC,MAAM,IAAI,GAAG,QAAQ,CAAC,MAAM,CAAC;QAC7B,MAAM,KAAK,GAAG,iBAAiB,CAAC,IAAI,CAAC,CAAC;QACtC,MAAM,MAAM,GAAG,WAAW,CAAC,QAAQ,EAAE,IAAI,CAAC,CAAC;QAE3C,IAAI,cAAc,CAAC;QACnB,IAAI,UAAU,CAAC;QACf,IAAI,OAAO,KAAK,CAAC,EAAE;YACjB,UAAU,GAAG,IAAI,GAAG,CAAC,CAAC;YACtB,MAAM,cAAc,GAAG,iBAAiB,CAAC,UAAU,CAAC,CAAC;YACrD,cAAc,GAAG;UACb,cAAc,iBAAiB,cAAc,IAAI,MAAM,CAAC,IAAI,EAAE;YAC5D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,cAAc,iBAAiB,cAAc,IAAI,MAAM,CAAC,IAAI,EAAE;YAC5D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,cAAc,iBAAiB,cAAc,IAAI,MAAM,CAAC,IAAI,EAAE;YAC5D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,cAAc,iBAAiB,cAAc,IAAI,MAAM,CAAC,IAAI,EAAE;YAC5D,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,CAAC;SAC3B;aAAM;YACL,UAAU,GAAG,IAAI,CAAC;YAClB,cAAc,GAAG;UACb,KAAK;YACH,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,KAAK;YACH,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,KAAK;YACH,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC;UAClB,KAAK;YACH,MAAM,CAAC,IAAI
,GAAG,CAAC,CAAC,GAAG,CAAC;SAC3B;QACD,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,EAAE,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC;QACrE,MAAM,SAAS,GAAG,GAAG,GAAG,QAAQ,CAAC,UAAU,GAAG,CAAC,CAAC,CAAC,CAAE,wBAAwB;QAC3E,MAAM,WAAW,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QAClD,MAAM,UAAU,GACZ,WAAW,CAAC,YAAY,EAAE,UAAU,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;QAChE,MAAM,UAAU,GACZ,WAAW,CAAC,YAAY,EAAE,UAAU,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;QAChE,MAAM,UAAU,GACZ,WAAW,CAAC,YAAY,EAAE,UAAU,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;QAChE,MAAM,UAAU,GACZ,WAAW,CAAC,YAAY,EAAE,UAAU,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,SAAS,CAAC,CAAC;QAEhE,MAAM,MAAM,GAAG,CAAC,EAAE,KAAK,KAAK,CAAC,CAAC,CAAC,CAAC,aAAa,CAAC,CAAC,CAAC,UAAU,CAAC;QAC3D,MAAM,iBAAiB,GAAG,SAAS,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;sDACO,UAAU,CAAC,IAAI,EAAE;sDACjB,UAAU,CAAC,IAAI,EAAE;sDACjB,UAAU,CAAC,IAAI,EAAE;sDACjB,UAAU,CAAC,IAAI,EAAE,MAAM,CAAC;QAE1E,MAAM,UAAU,GAAG;0BACG,UAAU,CAAC,IAAI,EAAE;uCACJ,UAAU,CAAC,IAAI,EAAE;uCACjB,UAAU,CAAC,IAAI,EAAE;qDACH,UAAU,CAAC,IAAI,EAAE,SAAS,CAAC;QAE5E,MAAM,6BAA6B,GAAG,SAAS,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;qCACtB,WAAW,CAAC,IAAI,EAAE;4CACX,QAAQ,CAAC,IAAI,EAAE;iDACV,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE;QAClE,CAAC;QAEL,IAAI,CAAC,QAAQ,GAAG;0BACM,WAAW,CAAC,IAAI,EAAE;iCACX,QAAQ,CAAC,IAAI,EAAE;sCACV,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,EAAE;;QAEvD,6BAA6B;;UAE3B,KAAK;4BACa,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,MAAM,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,CAAC;4BAC5C,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,MAAM,QAAQ,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,CAAC;UAC9D,cAAc;yCACiB,SAAS,eAAe,SAAS;sBACpD,SAAS,eAAe,SAAS,OAAO,UAAU;;;2BAG7C,UAAU;;8BAEP,UAAU;;YAE5B,iBAAiB;6BACA,UAAU;;;mBAGpB,MAAM;;;;;;;;;;;KAWpB,CAAC;IACJ,CAAC;CACF","sourcesContent":["/**\n * @license\n * Copyright 2019 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {util} from '@tensorflow/tfjs-core';\n\nimport {GPGPUProgram} from './gpgpu_math';\nimport {getChannels} from './packing_util';\nimport {getCoordsDataType} from './shader_compiler';\n\nexport class ArgMinMaxPackedProgram implements GPGPUProgram {\n  variableNames = ['A'];\n  outputShape: number[];\n  userCode: string;\n  packedInputs = true;\n  packedOutput = true;\n\n  constructor(\n      shape: number[], windowSize: number, op: 'max'|'min',\n      firstPass: boolean) {\n    util.assert(\n        shape.length > 2,\n        () => `Packed arg${\n            op.charAt(0).toUpperCase() +\n            op.slice(1)} supports only inputs with rank above 2.`);\n    const inSize = shape[shape.length - 1];\n    const outSize = Math.ceil(inSize / windowSize);\n    this.outputShape = shape.slice(0, -1);\n    if (outSize > 1) {\n      this.outputShape.push(outSize);\n    }\n    if (!firstPass) {\n      this.variableNames.push('bestIndicesA');\n    }\n    const outShape = this.outputShape;\n    const rank = outShape.length;\n    const dtype = getCoordsDataType(rank);\n    const coords = getChannels('coords', rank);\n\n    let sourceLocSetup;\n    let sourceRank;\n    if (outSize === 1) {\n      sourceRank = rank + 1;\n      const sourceLocDType = getCoordsDataType(sourceRank);\n      sourceLocSetup = `\n     
   ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);\n        ++${coords[rank - 1]};\n        ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);\n        ++${coords[rank - 2]};\n        ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);\n        --${coords[rank - 1]};\n        ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);\n        --${coords[rank - 2]};`;\n    } else {\n      sourceRank = rank;\n      sourceLocSetup = `\n        ${dtype} sourceLocR = coords;\n        ++${coords[rank - 1]};\n        ${dtype} sourceLocG = coords;\n        ++${coords[rank - 2]};\n        ${dtype} sourceLocA = coords;\n        --${coords[rank - 1]};\n        ${dtype} sourceLocB = coords;\n        --${coords[rank - 2]};`;\n    }\n    const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);\n    const inChannel = '.' + channels[sourceRank - 1];  // e.g. \".b\" for rank 3.\n    const intChannels = channels.map(x => 'int ' + x);\n    const srcRCoords =\n        getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');\n    const srcGCoords =\n        getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');\n    const srcBCoords =\n        getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');\n    const srcACoords =\n        getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');\n\n    const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';\n    const fetchCandidateIdx = firstPass ? '' : `\n          inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),\n                             getBestIndicesAChannel(${srcGCoords.join()}),\n                             getBestIndicesAChannel(${srcBCoords.join()}),\n                             getBestIndicesAChannel(${srcACoords.join()})));`;\n\n    const fetchValue = `vec4(\n            getAChannel(${srcRCoords.join()}),\n            hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,\n            hasNextRow ? 
getAChannel(${srcBCoords.join()}) : 0.,\n            hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;\n\n    const getBestIndicesAChannelSnippet = firstPass ? '' : `\n      float getBestIndicesAChannel(${intChannels.join()}) {\n        return getChannel(getBestIndicesA(${channels.join()}),\n                                          vec2(${channels.slice(-2).join()}));\n      }`;\n\n    this.userCode = `\n      float getAChannel(${intChannels.join()}) {\n        return getChannel(getA(${channels.join()}),\n                               vec2(${channels.slice(-2).join()}));\n      }\n      ${getBestIndicesAChannelSnippet}\n      void main() {\n        ${dtype} coords = getOutputCoords();\n        bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};\n        bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};\n        ${sourceLocSetup}\n        ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},\n          sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};\n        ivec4 inIdx = srcIdx;\n        vec4 bestIndex = vec4(inIdx);\n        vec4 bestValue = ${fetchValue};\n\n        for (int i = 0; i < ${windowSize}; i++) {\n          inIdx = srcIdx;\n          ${fetchCandidateIdx}\n          vec4 candidate = ${fetchValue};\n          bvec4 nan = isnan(candidate);\n          bvec4 replace = bvec4(\n            vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n          bestValue = vec4(replace.x  ? candidate.x : bestValue.x,\n                           replace.y  ? candidate.y : bestValue.y,\n                           replace.z  ? candidate.z : bestValue.z,\n                           replace.w  ? candidate.w : bestValue.w);\n          bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n          srcIdx++;\n        }\n        setOutput(bestIndex);\n      }\n    `;\n  }\n}\n"]}
|