gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var util_1 = require("../util");
/**
 * Validate gather nd inputs.
 *
 * @param tensor The tensor contains the source values.
 * @param indices The tensor contains the indices to slice the source.
 *
 * @returns [resultShape, numUpdates, sliceSize, strides]
 */
function prepareAndValidate(tensor, indices) {
    if (tensor.rank < 1) {
        throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
            (" but the rank was " + tensor.rank + "."));
    }
    if (indices.rank < 1) {
        throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
            (" but the rank was " + indices.rank + "."));
    }
    if (indices.dtype !== 'int32') {
        throw new Error('tf.gatherND() expects the indices to be int32 type,' +
            (" but the dtype was " + indices.dtype + "."));
    }
    if (indices.shape[indices.rank - 1] > tensor.rank) {
        throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
            (indices.shape[indices.rank - 1] + " vs. " + tensor.rank));
    }
    if (tensor.size === 0) {
        throw new Error('Requested more than 0 entries, but input is empty.' +
            (" Input shape: " + tensor.shape + "."));
    }
    var indicesShape = indices.shape;
    var sliceRank = indicesShape[indicesShape.length - 1];
    // The result shape is
    //   indices.shape[:-1] + params.shape[indices.shape[-1]:]
    var nResult = 1;
    for (var i = 0; i < indicesShape.length - 1; ++i) {
        nResult *= indicesShape[i];
    }
    var inputShape = tensor.shape;
    var resultShape = indicesShape.slice();
    resultShape.pop();
    var sliceSize = 1;
    for (var i = sliceRank; i < tensor.rank; ++i) {
        sliceSize *= inputShape[i];
        resultShape.push(inputShape[i]);
    }
    var strides = util_1.computeStrides(tensor.shape).map(function (stride) { return stride / sliceSize; }).concat([1]).slice(0, sliceRank);
    return [resultShape, nResult, sliceSize, strides];
}
exports.prepareAndValidate = prepareAndValidate;
//# sourceMappingURL=gather_nd_util.js.map