gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { env, TopK, util } from '@tensorflow/tfjs-core';
import { topKImplCPU } from '../kernel_utils/shared';
import { MergeProgram, SwapProgram } from '../top_k_gpu';
import { fill } from './Fill';
import { gatherV2 } from './GatherV2';
import { reshape } from './Reshape';
import { slice } from './Slice';
function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) {
    if (tensorInfo !== null) {
        backend.disposeIntermediateTensorInfo(tensorInfo);
    }
}
/**
 * Returns the smallest power of two that is greater than or equal to `num`.
 * Any `num` of 1 or less yields 1.
 *
 * @param num Value to round up.
 * @returns The power of two reached by repeated doubling from 1.
 */
function roundUpToPow2(num) {
    let candidate = 1;
    for (; candidate < num; candidate *= 2) {
        // Keep doubling until `num` is reached or exceeded.
    }
    return candidate;
}
// Based on Algorithm 2 of Bitonic Top K, ref:
// https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
export function topK(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { k, sorted } = attrs;
    // Empirically determined constant used to determine last dim threshold for
    // handing off execution to the CPU.
    const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');
    // Empirically determined constant used to determine k threshold for handing
    // off execution to the CPU.
    const TOPK_K_CPU_HANDOFF_THRESHOLD = env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');
    const xShape = x.shape;
    const lastDim = xShape[xShape.length - 1];
    if (backend.shouldExecuteOnCPU([x]) ||
        lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD ||
        k > TOPK_K_CPU_HANDOFF_THRESHOLD) {
        const xVals = backend.readSync(x.dataId);
        const [allTopKVals, allTopKIndices] = topKImplCPU(xVals, xShape, x.dtype, k, sorted);
        return [
            backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
            backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
        ];
    }
    if (k === 0) {
        xShape[xShape.length - 1] = 0;
        return [
            backend.makeTensorInfo(xShape, x.dtype, []),
            backend.makeTensorInfo(xShape, 'int32', [])
        ];
    }
    if (lastDim === 1 /* firstPass */) {
        return [
            x, fill({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend })
        ];
    }
    // Eagerly unpack x input since it is passed in to all the shaders which
    // require unpacked inputs.
    const xtexData = backend.texData.get(x.dataId);
    const xIsPacked = xtexData !== null && xtexData.isPacked;
    const xUnPacked = xIsPacked ? backend.unpackTensor(x) : x;
    // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
    const xSize = util.sizeFromShape(xShape);
    const batch = xSize / lastDim;
    const x2D = reshape({ inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend });
    if (xIsPacked) {
        disposeIntermediateTensorInfoOrNull(backend, xUnPacked);
    }
    const kPow2 = roundUpToPow2(k);
    const lastDimPow2 = roundUpToPow2(lastDim);
    // Only the indices containing the top K are kept at every step to reduce
    // number of outputs in the GPU algorithms, so once the final set of indices
    // is computed then gather is used to grab the corresponding values
    // from the original input.
    let indices = null;
    // GPU algorithm always takes in an indices input but this input is not used
    // on the first run of a GPU algorithm, therefore if indices is null we simply
    // pass in x2D instead of it but the value will not actually be used
    const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices];
    const runSwap = (dir, inc, shape) => {
        const inputs = getInputs();
        const program = new SwapProgram(shape);
        const fistPass = indices === null ? 1 : 0;
        const customValues = [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];
        const prevIndices = indices;
        indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);
        disposeIntermediateTensorInfoOrNull(backend, prevIndices);
    };
    // Step 1: local sort
    for (let len = 1; len < kPow2; len *= 2) {
        const dir = len * 2;
        for (let inc = len; inc >= 1; inc /= 2) {
            runSwap(dir, inc, [batch, lastDimPow2]);
        }
    }
    // Step 2: merge
    for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {
        const inputs = getInputs();
        const mergeProgram = new MergeProgram([batch, indicesSize / 2]);
        const firstPass = indices === null ? 1 : 0;
        const customValues = [[lastDim], [firstPass], [kPow2]];
        const prevIndices = indices;
        indices =
            backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues);
        disposeIntermediateTensorInfoOrNull(backend, prevIndices);
        // Step 3: rebuild
        const len = kPow2 / 2;
        const dir = len * 2;
        for (let inc = len; inc >= 1; inc /= 2) {
            runSwap(dir, inc, indices.shape);
        }
    }
    // Keep only the requested top K results instead of kPow2
    let prevIndices = indices;
    indices = slice({ inputs: { x: indices }, backend, attrs: { begin: 0, size: [batch, k] } });
    disposeIntermediateTensorInfoOrNull(backend, prevIndices);
    // Gather values on last dimension
    let values = gatherV2({ inputs: { x: x2D, indices }, backend, attrs: { axis: 1, batchDims: 1 } });
    disposeIntermediateTensorInfoOrNull(backend, x2D);
    // Reshape back to the original input shape, except that the last
    // dimension is k.
    const newShape = xShape.slice(0, -1);
    newShape.push(k);
    prevIndices = indices;
    indices = reshape({ inputs: { x: indices }, attrs: { shape: newShape }, backend });
    disposeIntermediateTensorInfoOrNull(backend, prevIndices);
    const prevValues = values;
    values = reshape({ inputs: { x: values }, attrs: { shape: newShape }, backend });
    disposeIntermediateTensorInfoOrNull(backend, prevValues);
    return [values, indices];
}
/**
 * Kernel config that associates the `TopK` kernel name with the `topK`
 * implementation above for the 'webgl' backend.
 */
export const topKConfig = {
    kernelName: TopK,
    backendName: 'webgl',
    kernelFunc: topK
};
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"TopK.js","sourceRoot":"","sources":["../../../../../../tfjs-backend-webgl/src/kernels/TopK.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,GAAG,EAAyD,IAAI,EAAqC,IAAI,EAAC,MAAM,uBAAuB,CAAC;AAGhJ,OAAO,EAAC,WAAW,EAAC,MAAM,wBAAwB,CAAC;AACnD,OAAO,EAAC,YAAY,EAAE,WAAW,EAAC,MAAM,cAAc,CAAC;AACvD,OAAO,EAAC,IAAI,EAAC,MAAM,QAAQ,CAAC;AAC5B,OAAO,EAAC,QAAQ,EAAC,MAAM,YAAY,CAAC;AACpC,OAAO,EAAC,OAAO,EAAC,MAAM,WAAW,CAAC;AAClC,OAAO,EAAC,KAAK,EAAC,MAAM,SAAS,CAAC;AAE9B,SAAS,mCAAmC,CACxC,OAAyB,EAAE,UAAsB;IACnD,IAAI,UAAU,KAAK,IAAI,EAAE;QACvB,OAAO,CAAC,6BAA6B,CAAC,UAAU,CAAC,CAAC;KACnD;AACH,CAAC;AAED,SAAS,aAAa,CAAC,GAAW;IAChC,IAAI,IAAI,GAAG,CAAC,CAAC;IACb,OAAO,IAAI,GAAG,GAAG,EAAE;QACjB,IAAI,IAAI,CAAC,CAAC;KACX;IACD,OAAO,IAAI,CAAC;AACd,CAAC;AAED,8CAA8C;AAC9C,6DAA6D;AAC7D,MAAM,UAAU,IAAI,CAChB,IAAuE;IAEzE,MAAM,EAAC,MAAM,EAAE,OAAO,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC;IACtC,MAAM,EAAC,CAAC,EAAC,GAAG,MAAM,CAAC;IACnB,MAAM,EAAC,CAAC,EAAE,MAAM,EAAC,GAAG,KAAK,CAAC;IAE1B,2EAA2E;IAC3E,oCAAoC;IACpC,MAAM,wCAAwC,GAC1C,GAAG,EAAE,CAAC,SAAS,CAAC,0CAA0C,CAAC,CAAC;IAEhE,4EAA4E;IAC5E,4BAA4B;IAC5B,MAAM,4BAA4B,GAC9B,GAAG,EAAE,CAAC,SAAS,CAAC,8BAA8B,CAAC,CAAC;IAEpD,MAAM,MAAM,GAAG,CAAC,CAAC,KAAK,CAAC;IACvB,MAAM,OAAO,GAAG,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;IAE1C,IAAI,OAAO,CAAC,kBAAkB,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/B,OAAO,GAAG,wCAAwC;QAClD,CAAC,GAAG,4BAA4B,EAAE;QACpC,MAAM,KAAK,GAAG,OAAO,CAAC,QAAQ,CAAC,CAAC,CAAC,MAAM,CAAe,CAAC;QACvD,MAAM,CAAC,WAAW,EAAE,cAAc,CAAC,GAC/B,WAAW,CAAC,KAAK,EAAE,MAAM,EAAE,CAAC,CAAC,KAAwB,EAAE,CAAC,EAAE,MAAM,CAAC,CAAC;QAEtE,OAAO;YACL,OAAO,CAAC,cAAc,CAClB,WAAW,CAAC,KAAK,EAAE,WAAW,CAAC,KAAK,EAAE,WAAW,CAAC,MAAM,CAAC;YAC7D,OAAO,CAAC,cAAc,CAClB,cAAc,CAAC,KAAK,EAAE,cAAc,CAAC,KAAK,EAAE,cAAc,CAAC,MAAM,CAAC;SACvE,CAAC;KACH;IAED,IAAI,CAAC,KAAK,CAAC,EAAE;QACX,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;QAC9B,OAAO;YACL,OAAO,CAAC,cAAc,CAAC,MAAM,EAAE,CAAC,CAAC,KAAK,EAAE,EAAE,CAAC;YAC3C,OAAO,CAAC,cAAc,CAAC,MAAM
,EAAE,OAAO,EAAE,EAAE,CAAC;SAC5C,CAAC;KACH;IAED,IAAI,OAAO,KAAK,CAAC,CAAC,eAAe,EAAE;QACjC,OAAO;YACL,CAAC,EAAE,IAAI,CAAC,EAAC,KAAK,EAAE,EAAC,KAAK,EAAE,MAAM,EAAE,KAAK,EAAE,OAAO,EAAE,KAAK,EAAE,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC;SACrE,CAAC;KACH;IAED,wEAAwE;IACxE,2BAA2B;IAC3B,MAAM,QAAQ,GAAG,OAAO,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;IAC/C,MAAM,SAAS,GAAG,QAAQ,KAAK,IAAI,IAAI,QAAQ,CAAC,QAAQ,CAAC;IACzD,MAAM,SAAS,GAAG,SAAS,CAAC,CAAC,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAE1D,4EAA4E;IAC5E,MAAM,KAAK,GAAG,IAAI,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;IACzC,MAAM,KAAK,GAAG,KAAK,GAAG,OAAO,CAAC;IAC9B,MAAM,GAAG,GAAG,OAAO,CACf,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,SAAS,EAAC,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,KAAK,EAAE,OAAO,CAAC,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;IAEzE,IAAI,SAAS,EAAE;QACb,mCAAmC,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC;KACzD;IAED,MAAM,KAAK,GAAG,aAAa,CAAC,CAAC,CAAC,CAAC;IAC/B,MAAM,WAAW,GAAG,aAAa,CAAC,OAAO,CAAC,CAAC;IAE3C,yEAAyE;IACzE,4EAA4E;IAC5E,mEAAmE;IACnE,2BAA2B;IAC3B,IAAI,OAAO,GAAe,IAAI,CAAC;IAE/B,4EAA4E;IAC5E,8EAA8E;IAC9E,oEAAoE;IACpE,MAAM,SAAS,GAAG,GAAG,EAAE,CAAC,OAAO,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,EAAE,OAAO,CAAC,CAAC;IAEvE,MAAM,OAAO,GAAG,CAAC,GAAW,EAAE,GAAW,EAAE,KAAe,EAAE,EAAE;QAC5D,MAAM,MAAM,GAAG,SAAS,EAAE,CAAC;QAC3B,MAAM,OAAO,GAAG,IAAI,WAAW,CAAC,KAAK,CAAC,CAAC;QACvC,MAAM,QAAQ,GAAG,OAAO,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC1C,MAAM,YAAY,GACd,CAAC,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,EAAE,CAAC,MAAM,CAAC,iBAAiB,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC;QACtE,MAAM,WAAW,GAAG,OAAO,CAAC;QAC5B,OAAO,GAAG,OAAO,CAAC,eAAe,CAAC,OAAO,EAAE,MAAM,EAAE,OAAO,EAAE,YAAY,CAAC,CAAC;QAC1E,mCAAmC,CAAC,OAAO,EAAE,WAAW,CAAC,CAAC;IAC5D,CAAC,CAAC;IAEF,qBAAqB;IACrB,KAAK,IAAI,GAAG,GAAG,CAAC,EAAE,GAAG,GAAG,KAAK,EAAE,GAAG,IAAI,CAAC,EAAE;QACvC,MAAM,GAAG,GAAG,GAAG,GAAG,CAAC,CAAC;QACpB,KAAK,IAAI,GAAG,GAAG,GAAG,EAAE,GAAG,IAAI,CAAC,EAAE,GAAG,IAAI,CAAC,EAAE;YACtC,OAAO,CAAC,GAAG,EAAE,GAAG,EAAE,CAAC,KAAK,EAAE,WAAW,CAAC,CAAC,CAAC;SACzC
;KACF;IAED,gBAAgB;IAChB,KAAK,IAAI,WAAW,GAAG,WAAW,EAAE,WAAW,GAAG,KAAK,EAAE,WAAW,IAAI,CAAC,EAAE;QACzE,MAAM,MAAM,GAAG,SAAS,EAAE,CAAC;QAC3B,MAAM,YAAY,GAAG,IAAI,YAAY,CAAC,CAAC,KAAK,EAAE,WAAW,GAAG,CAAC,CAAC,CAAC,CAAC;QAChE,MAAM,SAAS,GAAG,OAAO,KAAK,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,YAAY,GAAG,CAAC,CAAC,OAAO,CAAC,EAAE,CAAC,SAAS,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC;QACvD,MAAM,WAAW,GAAG,OAAO,CAAC;QAC5B,OAAO;YACH,OAAO,CAAC,eAAe,CAAC,YAAY,EAAE,MAAM,EAAE,OAAO,EAAE,YAAY,CAAC,CAAC;QACzE,mCAAmC,CAAC,OAAO,EAAE,WAAW,CAAC,CAAC;QAE1D,kBAAkB;QAClB,MAAM,GAAG,GAAG,KAAK,GAAG,CAAC,CAAC;QACtB,MAAM,GAAG,GAAG,GAAG,GAAG,CAAC,CAAC;QACpB,KAAK,IAAI,GAAG,GAAG,GAAG,EAAE,GAAG,IAAI,CAAC,EAAE,GAAG,IAAI,CAAC,EAAE;YACtC,OAAO,CAAC,GAAG,EAAE,GAAG,EAAE,OAAO,CAAC,KAAK,CAAC,CAAC;SAClC;KACF;IAED,yDAAyD;IACzD,IAAI,WAAW,GAAG,OAAO,CAAC;IAC1B,OAAO,GAAG,KAAK,CACX,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,OAAO,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,CAAC,EAAE,IAAI,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,EAAC,EAAC,CAAC,CAAC;IAC1E,mCAAmC,CAAC,OAAO,EAAE,WAAW,CAAC,CAAC;IAE1D,kCAAkC;IAClC,IAAI,MAAM,GAAG,QAAQ,CACjB,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,GAAG,EAAE,OAAO,EAAC,EAAE,OAAO,EAAE,KAAK,EAAE,EAAC,IAAI,EAAE,CAAC,EAAE,SAAS,EAAE,CAAC,EAAC,EAAC,CAAC,CAAC;IAC1E,mCAAmC,CAAC,OAAO,EAAE,GAAG,CAAC,CAAC;IAElD,iEAAiE;IACjE,kBAAkB;IAClB,MAAM,QAAQ,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACrC,QAAQ,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAEjB,WAAW,GAAG,OAAO,CAAC;IACtB,OAAO,GAAG,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,OAAO,EAAC,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,QAAQ,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;IAC7E,mCAAmC,CAAC,OAAO,EAAE,WAAW,CAAC,CAAC;IAE1D,MAAM,UAAU,GAAG,MAAM,CAAC;IAC1B,MAAM,GAAG,OAAO,CAAC,EAAC,MAAM,EAAE,EAAC,CAAC,EAAE,MAAM,EAAC,EAAE,KAAK,EAAE,EAAC,KAAK,EAAE,QAAQ,EAAC,EAAE,OAAO,EAAC,CAAC,CAAC;IAC3E,mCAAmC,CAAC,OAAO,EAAE,UAAU,CAAC,CAAC;IAEzD,OAAO,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;AAC3B,CAAC;AAED,MAAM,CAAC,MAAM,UAAU,GAAiB;IACtC,UAAU,EAAE,IAAI;IAChB,WAAW,EAAE,OAAO;IACpB,UAAU,EAAE,IAA6B;CAC1C,CAAC","sourcesContent":["/**\n * 
@license\n * Copyright 2020 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env, KernelConfig, KernelFunc, NumericDataType, TensorInfo, TopK, TopKAttrs, TopKInputs, TypedArray, util} from '@tensorflow/tfjs-core';\n\nimport {MathBackendWebGL} from '../backend_webgl';\nimport {topKImplCPU} from '../kernel_utils/shared';\nimport {MergeProgram, SwapProgram} from '../top_k_gpu';\nimport {fill} from './Fill';\nimport {gatherV2} from './GatherV2';\nimport {reshape} from './Reshape';\nimport {slice} from './Slice';\n\nfunction disposeIntermediateTensorInfoOrNull(\n    backend: MathBackendWebGL, tensorInfo: TensorInfo) {\n  if (tensorInfo !== null) {\n    backend.disposeIntermediateTensorInfo(tensorInfo);\n  }\n}\n\nfunction roundUpToPow2(num: number) {\n  let pow2 = 1;\n  while (pow2 < num) {\n    pow2 *= 2;\n  }\n  return pow2;\n}\n\n// Based on Algorithm 2 of Bitonic Top K, ref:\n// https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf\nexport function topK(\n    args: {inputs: TopKInputs, backend: MathBackendWebGL, attrs: TopKAttrs}):\n    TensorInfo[] {\n  const {inputs, backend, attrs} = args;\n  const {x} = inputs;\n  const {k, sorted} = attrs;\n\n  // Empirically determined constant used to determine last dim threshold for\n  // handing off execution to the CPU.\n  const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD =\n      
env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');\n\n  // Empirically determined constant used to determine k threshold for handing\n  // off execution to the CPU.\n  const TOPK_K_CPU_HANDOFF_THRESHOLD =\n      env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');\n\n  const xShape = x.shape;\n  const lastDim = xShape[xShape.length - 1];\n\n  if (backend.shouldExecuteOnCPU([x]) ||\n      lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD ||\n      k > TOPK_K_CPU_HANDOFF_THRESHOLD) {\n    const xVals = backend.readSync(x.dataId) as TypedArray;\n    const [allTopKVals, allTopKIndices] =\n        topKImplCPU(xVals, xShape, x.dtype as NumericDataType, k, sorted);\n\n    return [\n      backend.makeTensorInfo(\n          allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),\n      backend.makeTensorInfo(\n          allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)\n    ];\n  }\n\n  if (k === 0) {\n    xShape[xShape.length - 1] = 0;\n    return [\n      backend.makeTensorInfo(xShape, x.dtype, []),\n      backend.makeTensorInfo(xShape, 'int32', [])\n    ];\n  }\n\n  if (lastDim === 1 /* firstPass */) {\n    return [\n      x, fill({attrs: {shape: xShape, dtype: 'int32', value: 0}, backend})\n    ];\n  }\n\n  // Eagerly unpack x input since it is passed in to all the shaders which\n  // require unpacked inputs.\n  const xtexData = backend.texData.get(x.dataId);\n  const xIsPacked = xtexData !== null && xtexData.isPacked;\n  const xUnPacked = xIsPacked ? 
backend.unpackTensor(x) : x;\n\n  // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.\n  const xSize = util.sizeFromShape(xShape);\n  const batch = xSize / lastDim;\n  const x2D = reshape(\n      {inputs: {x: xUnPacked}, attrs: {shape: [batch, lastDim]}, backend});\n\n  if (xIsPacked) {\n    disposeIntermediateTensorInfoOrNull(backend, xUnPacked);\n  }\n\n  const kPow2 = roundUpToPow2(k);\n  const lastDimPow2 = roundUpToPow2(lastDim);\n\n  // Only the indices containing the top K are kept at every step to reduce\n  // number of outputs in the GPU algorithms, so once the final set of indices\n  // is computed then gather is used to grab the corresponding values\n  // from the original input.\n  let indices: TensorInfo = null;\n\n  // GPU algorithm always takes in an indices input but this input is not used\n  // on the first run of a GPU algorithm, therefore if indices is null we simply\n  // pass in x2D instead of it but the value will not actually be used\n  const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices];\n\n  const runSwap = (dir: number, inc: number, shape: number[]) => {\n    const inputs = getInputs();\n    const program = new SwapProgram(shape);\n    const fistPass = indices === null ? 
1 : 0;\n    const customValues =\n        [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];\n    const prevIndices = indices;\n    indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);\n    disposeIntermediateTensorInfoOrNull(backend, prevIndices);\n  };\n\n  // Step 1: local sort\n  for (let len = 1; len < kPow2; len *= 2) {\n    const dir = len * 2;\n    for (let inc = len; inc >= 1; inc /= 2) {\n      runSwap(dir, inc, [batch, lastDimPow2]);\n    }\n  }\n\n  // Step 2: merge\n  for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {\n    const inputs = getInputs();\n    const mergeProgram = new MergeProgram([batch, indicesSize / 2]);\n    const firstPass = indices === null ? 1 : 0;\n    const customValues = [[lastDim], [firstPass], [kPow2]];\n    const prevIndices = indices;\n    indices =\n        backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues);\n    disposeIntermediateTensorInfoOrNull(backend, prevIndices);\n\n    // Step 3: rebuild\n    const len = kPow2 / 2;\n    const dir = len * 2;\n    for (let inc = len; inc >= 1; inc /= 2) {\n      runSwap(dir, inc, indices.shape);\n    }\n  }\n\n  // Keep only the requested top K results instead of kPow2\n  let prevIndices = indices;\n  indices = slice(\n      {inputs: {x: indices}, backend, attrs: {begin: 0, size: [batch, k]}});\n  disposeIntermediateTensorInfoOrNull(backend, prevIndices);\n\n  // Gather values on last dimension\n  let values = gatherV2(\n      {inputs: {x: x2D, indices}, backend, attrs: {axis: 1, batchDims: 1}});\n  disposeIntermediateTensorInfoOrNull(backend, x2D);\n\n  // Reshape back to the original input shape, except that the last\n  // dimension is k.\n  const newShape = xShape.slice(0, -1);\n  newShape.push(k);\n\n  prevIndices = indices;\n  indices = reshape({inputs: {x: indices}, attrs: {shape: newShape}, backend});\n  disposeIntermediateTensorInfoOrNull(backend, prevIndices);\n\n  const prevValues = values;\n  
values = reshape({inputs: {x: values}, attrs: {shape: newShape}, backend});\n  disposeIntermediateTensorInfoOrNull(backend, prevValues);\n\n  return [values, indices];\n}\n\nexport const topKConfig: KernelConfig = {\n  kernelName: TopK,\n  backendName: 'webgl',\n  kernelFunc: topK as unknown as KernelFunc\n};\n"]}