gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { backend_util, GatherV2, util } from '@tensorflow/tfjs-core';
import { assertNotComplex } from '../cpu_util';
import { gatherV2Impl } from './GatherV2_impl';
import { reshape } from './Reshape';
/**
 * CPU kernel function for the GatherV2 op.
 *
 * Validates every index against the size of the gather axis, reshapes `x`
 * and `indices` into the flattened layouts consumed by `gatherV2Impl`, runs
 * the gather, and wraps the result in a new tensor info.
 *
 * @param args Kernel arguments.
 * @param args.inputs `{ x, indices }` - the tensor to gather from and the
 *     tensor holding the indices.
 * @param args.backend Backend used to read tensor data, create buffers,
 *     dispose intermediates, and create the output tensor info.
 * @param args.attrs `{ axis, batchDims }` - `batchDims` is treated as 0 when
 *     it is `null` or `undefined`.
 * @returns The value of `backend.makeTensorInfo(...)` called with the output
 *     shape from `collectGatherOpShapeInfo` and the dtype/values produced by
 *     `gatherV2Impl`.
 */
export function gatherV2(args) {
    const { x, indices } = args.inputs;
    const { axis, batchDims } = args.attrs;
    const cpuBackend = args.backend;
    assertNotComplex([x, indices], 'gatherV2');

    // Reject out-of-bound indices before any intermediate tensor is created.
    const gatherAxis = util.parseAxisParam(axis, x.shape)[0];
    const rawIndices = cpuBackend.data.get(indices.dataId).values;
    const maxIndex = x.shape[gatherAxis] - 1;
    for (const idx of rawIndices) {
        util.assert(idx >= 0 && idx <= maxIndex, () => `GatherV2: the index value ${idx} is not in [0, ${maxIndex}]`);
    }

    // A missing batchDims attribute means "no batch dimensions".
    const effectiveBatchDims = batchDims == null ? 0 : batchDims;
    const indexCount = util.sizeFromShape(indices.shape);
    const { batchSize, outerSize, dimSize, sliceSize, outputShape } = backend_util.segment_util.collectGatherOpShapeInfo(x, indices, gatherAxis, effectiveBatchDims);
    const indicesPerBatch = indexCount / batchSize;

    // x is viewed as [batchSize, outerSize, dimSize, sliceSize].
    const x4d = reshape({
        inputs: { x },
        backend: cpuBackend,
        attrs: { shape: [batchSize, outerSize, dimSize, sliceSize] }
    });
    // indices are viewed as [batchSize, indicesPerBatch].
    const indices2d = reshape({
        inputs: { x: indices },
        backend: cpuBackend,
        attrs: { shape: [batchSize, indicesPerBatch] }
    });

    const indicesBuf = cpuBackend.bufferSync(indices2d);
    const xBuf = cpuBackend.bufferSync(x4d);
    const gathered = gatherV2Impl(xBuf, indicesBuf, [batchSize, outerSize, indicesPerBatch, sliceSize]);

    // The reshaped views are only needed for the gather itself.
    for (const intermediate of [x4d, indices2d]) {
        cpuBackend.disposeIntermediateTensorInfo(intermediate);
    }

    const { dtype, values } = gathered;
    return cpuBackend.makeTensorInfo(outputShape, dtype, values);
}
/**
 * Kernel config that pairs the `GatherV2` kernel name with the `gatherV2`
 * kernel function for the 'cpu' backend.
 */
export const gatherV2Config = {
    kernelName: GatherV2,
    backendName: 'cpu',
    kernelFunc: gatherV2,
};
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiR2F0aGVyVjIuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWJhY2tlbmQtY3B1L3NyYy9rZXJuZWxzL0dhdGhlclYyLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUVILE9BQU8sRUFBQyxZQUFZLEVBQUUsUUFBUSxFQUFtRixJQUFJLEVBQUMsTUFBTSx1QkFBdUIsQ0FBQztBQUdwSixPQUFPLEVBQUMsZ0JBQWdCLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDN0MsT0FBTyxFQUFDLFlBQVksRUFBQyxNQUFNLGlCQUFpQixDQUFDO0FBQzdDLE9BQU8sRUFBQyxPQUFPLEVBQUMsTUFBTSxXQUFXLENBQUM7QUFFbEMsTUFBTSxVQUFVLFFBQVEsQ0FBQyxJQUl4QjtJQUNDLE1BQU0sRUFBQyxNQUFNLEVBQUUsT0FBTyxFQUFFLEtBQUssRUFBQyxHQUFHLElBQUksQ0FBQztJQUN0QyxNQUFNLEVBQUMsQ0FBQyxFQUFFLE9BQU8sRUFBQyxHQUFHLE1BQU0sQ0FBQztJQUM1QixNQUFNLEVBQUMsSUFBSSxFQUFFLFNBQVMsRUFBQyxHQUFHLEtBQUssQ0FBQztJQUVoQyxnQkFBZ0IsQ0FBQyxDQUFDLENBQUMsRUFBRSxPQUFPLENBQUMsRUFBRSxVQUFVLENBQUMsQ0FBQztJQUUzQyw4Q0FBOEM7SUFDOUMsTUFBTSxVQUFVLEdBQUcsSUFBSSxDQUFDLGNBQWMsQ0FBQyxJQUFJLEVBQUUsQ0FBQyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ3pELE1BQU0sV0FBVyxHQUFHLE9BQU8sQ0FBQyxJQUFJLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxNQUFNLENBQUMsQ0FBQyxNQUFvQixDQUFDO0lBQzFFLE1BQU0sT0FBTyxHQUFHLENBQUMsQ0FBQyxLQUFLLENBQUMsVUFBVSxDQUFDLENBQUM7SUFDcEMsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFdBQVcsQ0FBQyxNQUFNLEVBQUUsRUFBRSxDQUFDLEVBQUU7UUFDM0MsTUFBTSxLQUFLLEdBQUcsV0FBVyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzdCLElBQUksQ0FBQyxNQUFNLENBQ1AsS0FBSyxJQUFJLE9BQU8sR0FBRyxDQUFDLElBQUksS0FBSyxJQUFJLENBQUMsRUFDbEMsR0FBRyxFQUFFLENBQ0QsNkJBQTZCLEtBQUssa0JBQWtCLE9BQU8sR0FBRyxDQUFDLEdBQUcsQ0FBQyxDQUFDO0tBQzdFO0lBRUQsSUFBSSxVQUFVLEdBQUcsU0FBUyxDQUFDO0lBRTNCLElBQUksU0FBUyxJQUFJLElBQUksRUFBRTtRQUNyQixVQUFVLEdBQUcsQ0FBQyxDQUFDO0tBQ2hCO0lBRUQsTUFBTSxXQUFXLEdBQUcsSUFBSSxDQUFDLGFBQWEsQ0FBQyxPQUFPLENBQUMsS0FBSyxDQUFDLENBQUM7SUFFdEQsTUFBTSxTQUFTLEdBQUcsWUFBWSxDQUFDLFlBQVksQ0FBQyx3QkFBd0IsQ0FDaEUsQ0FBQyxFQUFFLE9BQU8sRUFBRSxVQUFVLEVBQUUsVUFBVSxDQUFDLENBQUM7SUFFeEMsTUFBTSxRQUFRLEdBQUcsT0FBTyxDQUFDO1FBQ3ZCLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBQztRQUNYLE
9BQU87UUFDUCxLQUFLLEVBQUU7WUFDTCxLQUFLLEVBQUU7Z0JBQ0wsU0FBUyxDQUFDLFNBQVMsRUFBRSxTQUFTLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxPQUFPO2dCQUMzRCxTQUFTLENBQUMsU0FBUzthQUNwQjtTQUNGO0tBQ0YsQ0FBQyxDQUFDO0lBRUgsTUFBTSxZQUFZLEdBQUcsT0FBTyxDQUFDO1FBQzNCLE1BQU0sRUFBRSxFQUFDLENBQUMsRUFBRSxPQUFPLEVBQUM7UUFDcEIsT0FBTztRQUNQLEtBQUssRUFBRSxFQUFDLEtBQUssRUFBRSxDQUFDLFNBQVMsQ0FBQyxTQUFTLEVBQUUsV0FBVyxHQUFHLFNBQVMsQ0FBQyxTQUFTLENBQUMsRUFBQztLQUN6RSxDQUFDLENBQUM7SUFFSCxNQUFNLGtCQUFrQixHQUFHO1FBQ3pCLFNBQVMsQ0FBQyxTQUFTLEVBQUUsU0FBUyxDQUFDLFNBQVMsRUFBRSxXQUFXLEdBQUcsU0FBUyxDQUFDLFNBQVM7UUFDM0UsU0FBUyxDQUFDLFNBQVM7S0FDcEIsQ0FBQztJQUVGLE1BQU0sVUFBVSxHQUFHLE9BQU8sQ0FBQyxVQUFVLENBQUMsWUFBWSxDQUFDLENBQUM7SUFDcEQsTUFBTSxJQUFJLEdBQUcsT0FBTyxDQUFDLFVBQVUsQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUMxQyxNQUFNLE1BQU0sR0FBRyxZQUFZLENBQUMsSUFBSSxFQUFFLFVBQVUsRUFBRSxrQkFBa0IsQ0FBQyxDQUFDO0lBRWxFLE9BQU8sQ0FBQyw2QkFBNkIsQ0FBQyxRQUFRLENBQUMsQ0FBQztJQUNoRCxPQUFPLENBQUMsNkJBQTZCLENBQUMsWUFBWSxDQUFDLENBQUM7SUFFcEQsT0FBTyxPQUFPLENBQUMsY0FBYyxDQUN6QixTQUFTLENBQUMsV0FBVyxFQUFFLE1BQU0sQ0FBQyxLQUFLLEVBQUUsTUFBTSxDQUFDLE1BQU0sQ0FBQyxDQUFDO0FBQzFELENBQUM7QUFFRCxNQUFNLENBQUMsTUFBTSxjQUFjLEdBQWlCO0lBQzFDLFVBQVUsRUFBRSxRQUFRO0lBQ3BCLFdBQVcsRUFBRSxLQUFLO0lBQ2xCLFVBQVUsRUFBRSxRQUFpQztDQUM5QyxDQUFDIiwic291cmNlc0NvbnRlbnQiOlsiLyoqXG4gKiBAbGljZW5zZVxuICogQ29weXJpZ2h0IDIwMjAgR29vZ2xlIExMQy4gQWxsIFJpZ2h0cyBSZXNlcnZlZC5cbiAqIExpY2Vuc2VkIHVuZGVyIHRoZSBBcGFjaGUgTGljZW5zZSwgVmVyc2lvbiAyLjAgKHRoZSBcIkxpY2Vuc2VcIik7XG4gKiB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuXG4gKiBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXRcbiAqXG4gKiBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjBcbiAqXG4gKiBVbmxlc3MgcmVxdWlyZWQgYnkgYXBwbGljYWJsZSBsYXcgb3IgYWdyZWVkIHRvIGluIHdyaXRpbmcsIHNvZnR3YXJlXG4gKiBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiBcIkFTIElTXCIgQkFTSVMsXG4gKiBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC5cbiAqIFNlZSB0aGUgTGljZW
5zZSBmb3IgdGhlIHNwZWNpZmljIGxhbmd1YWdlIGdvdmVybmluZyBwZXJtaXNzaW9ucyBhbmRcbiAqIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLlxuICogPT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT1cbiAqL1xuXG5pbXBvcnQge2JhY2tlbmRfdXRpbCwgR2F0aGVyVjIsIEdhdGhlclYyQXR0cnMsIEdhdGhlclYySW5wdXRzLCBLZXJuZWxDb25maWcsIEtlcm5lbEZ1bmMsIFRlbnNvckluZm8sIFR5cGVkQXJyYXksIHV0aWx9IGZyb20gJ0B0ZW5zb3JmbG93L3RmanMtY29yZSc7XG5cbmltcG9ydCB7TWF0aEJhY2tlbmRDUFV9IGZyb20gJy4uL2JhY2tlbmRfY3B1JztcbmltcG9ydCB7YXNzZXJ0Tm90Q29tcGxleH0gZnJvbSAnLi4vY3B1X3V0aWwnO1xuaW1wb3J0IHtnYXRoZXJWMkltcGx9IGZyb20gJy4vR2F0aGVyVjJfaW1wbCc7XG5pbXBvcnQge3Jlc2hhcGV9IGZyb20gJy4vUmVzaGFwZSc7XG5cbmV4cG9ydCBmdW5jdGlvbiBnYXRoZXJWMihhcmdzOiB7XG4gIGlucHV0czogR2F0aGVyVjJJbnB1dHMsXG4gIGJhY2tlbmQ6IE1hdGhCYWNrZW5kQ1BVLFxuICBhdHRyczogR2F0aGVyVjJBdHRyc1xufSk6IFRlbnNvckluZm8ge1xuICBjb25zdCB7aW5wdXRzLCBiYWNrZW5kLCBhdHRyc30gPSBhcmdzO1xuICBjb25zdCB7eCwgaW5kaWNlc30gPSBpbnB1dHM7XG4gIGNvbnN0IHtheGlzLCBiYXRjaERpbXN9ID0gYXR0cnM7XG5cbiAgYXNzZXJ0Tm90Q29tcGxleChbeCwgaW5kaWNlc10sICdnYXRoZXJWMicpO1xuXG4gIC8vIFRocm93IGVycm9yIHdoZW4gYW55IGluZGV4IGlzIG91dCBvZiBib3VuZC5cbiAgY29uc3QgcGFyc2VkQXhpcyA9IHV0aWwucGFyc2VBeGlzUGFyYW0oYXhpcywgeC5zaGFwZSlbMF07XG4gIGNvbnN0IGluZGljZXNWYWxzID0gYmFja2VuZC5kYXRhLmdldChpbmRpY2VzLmRhdGFJZCkudmFsdWVzIGFzIFR5cGVkQXJyYXk7XG4gIGNvbnN0IGF4aXNEaW0gPSB4LnNoYXBlW3BhcnNlZEF4aXNdO1xuICBmb3IgKGxldCBpID0gMDsgaSA8IGluZGljZXNWYWxzLmxlbmd0aDsgKytpKSB7XG4gICAgY29uc3QgaW5kZXggPSBpbmRpY2VzVmFsc1tpXTtcbiAgICB1dGlsLmFzc2VydChcbiAgICAgICAgaW5kZXggPD0gYXhpc0RpbSAtIDEgJiYgaW5kZXggPj0gMCxcbiAgICAgICAgKCkgPT5cbiAgICAgICAgICAgIGBHYXRoZXJWMjogdGhlIGluZGV4IHZhbHVlICR7aW5kZXh9IGlzIG5vdCBpbiBbMCwgJHtheGlzRGltIC0gMX1dYCk7XG4gIH1cblxuICBsZXQgJGJhdGNoRGltcyA9IGJhdGNoRGltcztcblxuICBpZiAoYmF0Y2hEaW1zID09IG51bGwpIHtcbiAgICAkYmF0Y2hEaW1zID0gMDtcbiAgfVxuXG4gIGNvbnN0IGluZGljZXNTaXplID0gdXRpbC5zaXplRnJvbVNoYXBlKGluZGljZXMuc2hhcGUpO1xuXG4gIGNvbnN0IHNoYXBlSW5mbyA9IGJhY2tlbmRfdXRpbC5zZWdtZW50X3V0aWwuY29sbGVjdEdhdGhlck9wU2hhcG
VJbmZvKFxuICAgICAgeCwgaW5kaWNlcywgcGFyc2VkQXhpcywgJGJhdGNoRGltcyk7XG5cbiAgY29uc3QgZmxhdHRlblggPSByZXNoYXBlKHtcbiAgICBpbnB1dHM6IHt4fSxcbiAgICBiYWNrZW5kLFxuICAgIGF0dHJzOiB7XG4gICAgICBzaGFwZTogW1xuICAgICAgICBzaGFwZUluZm8uYmF0Y2hTaXplLCBzaGFwZUluZm8ub3V0ZXJTaXplLCBzaGFwZUluZm8uZGltU2l6ZSxcbiAgICAgICAgc2hhcGVJbmZvLnNsaWNlU2l6ZVxuICAgICAgXVxuICAgIH1cbiAgfSk7XG5cbiAgY29uc3QgZmxhdHRlbkluZGV4ID0gcmVzaGFwZSh7XG4gICAgaW5wdXRzOiB7eDogaW5kaWNlc30sXG4gICAgYmFja2VuZCxcbiAgICBhdHRyczoge3NoYXBlOiBbc2hhcGVJbmZvLmJhdGNoU2l6ZSwgaW5kaWNlc1NpemUgLyBzaGFwZUluZm8uYmF0Y2hTaXplXX1cbiAgfSk7XG5cbiAgY29uc3QgZmxhdHRlbk91dHB1dFNoYXBlID0gW1xuICAgIHNoYXBlSW5mby5iYXRjaFNpemUsIHNoYXBlSW5mby5vdXRlclNpemUsIGluZGljZXNTaXplIC8gc2hhcGVJbmZvLmJhdGNoU2l6ZSxcbiAgICBzaGFwZUluZm8uc2xpY2VTaXplXG4gIF07XG5cbiAgY29uc3QgaW5kaWNlc0J1ZiA9IGJhY2tlbmQuYnVmZmVyU3luYyhmbGF0dGVuSW5kZXgpO1xuICBjb25zdCB4QnVmID0gYmFja2VuZC5idWZmZXJTeW5jKGZsYXR0ZW5YKTtcbiAgY29uc3Qgb3V0QnVmID0gZ2F0aGVyVjJJbXBsKHhCdWYsIGluZGljZXNCdWYsIGZsYXR0ZW5PdXRwdXRTaGFwZSk7XG5cbiAgYmFja2VuZC5kaXNwb3NlSW50ZXJtZWRpYXRlVGVuc29ySW5mbyhmbGF0dGVuWCk7XG4gIGJhY2tlbmQuZGlzcG9zZUludGVybWVkaWF0ZVRlbnNvckluZm8oZmxhdHRlbkluZGV4KTtcblxuICByZXR1cm4gYmFja2VuZC5tYWtlVGVuc29ySW5mbyhcbiAgICAgIHNoYXBlSW5mby5vdXRwdXRTaGFwZSwgb3V0QnVmLmR0eXBlLCBvdXRCdWYudmFsdWVzKTtcbn1cblxuZXhwb3J0IGNvbnN0IGdhdGhlclYyQ29uZmlnOiBLZXJuZWxDb25maWcgPSB7XG4gIGtlcm5lbE5hbWU6IEdhdGhlclYyLFxuICBiYWNrZW5kTmFtZTogJ2NwdScsXG4gIGtlcm5lbEZ1bmM6IGdhdGhlclYyIGFzIHVua25vd24gYXMgS2VybmVsRnVuY1xufTtcbiJdfQ==