gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var util_1 = require("../util");
/**
 * Check whether updates.shape = indices.shape[:batchDim] +
 * shape[sliceDim:]
 *
 * @param x The input tensor.
 */
function validateUpdateShape(shape, indices, updates) {
    var sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
    var batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
    var shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
        ("shape[sliceDim:], got updates.shape: " + updates.shape) +
        (", indices.shape: " + indices.shape + ", shape: " + shape) +
        (", sliceDim: " + sliceDim + ", and batchDim: " + batchDim + ".");
    if (updates.rank < batchDim) {
        throw new Error(shapeError + (" update.rank < " + batchDim + ". "));
    }
    if (shape.length < sliceDim + (updates.rank - batchDim)) {
        throw new Error(shapeError +
            (" Output shape length < " + (sliceDim + (updates.rank - batchDim))));
    }
    if (updates.rank !== batchDim + shape.length - sliceDim) {
        throw new Error(shapeError + (" update.rank != " + (batchDim + shape.length - sliceDim)));
    }
    for (var d = 0; d < batchDim; ++d) {
        if (updates.shape[d] !== indices.shape[d]) {
            throw new Error(shapeError +
                (" updates.shape[" + d + "] (" + updates.shape[d] + ") != indices.shape[" + d + "] (" + indices.shape[d] + ")."));
        }
    }
    for (var d = 0; d < updates.rank - batchDim; ++d) {
        if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
            throw new Error(shapeError +
                (" updates.shape[" + (d + batchDim) + "] (" + updates.shape[d + batchDim] + ") != shape[" + (d + batchDim) + "] (" + shape[d + batchDim] + ")"));
        }
    }
}
exports.validateUpdateShape = validateUpdateShape;
/**
 * Validate scatter nd inputs.
 *
 * @param update The tensor contains the update values.
 * @param indices The tensor contains the indices for the update values.
 * @param shape The shape of the output tensor.
 */
function validateInput(updates, indices, shape) {
    if (indices.rank < 1) {
        throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
            (" but the rank was " + indices.rank + "."));
    }
    if (updates.rank < 1) {
        throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
            (" but the rank was " + updates.rank + "."));
    }
    if (indices.dtype !== 'int32') {
        throw new Error("The dtype of 'indices' should be int32, but got dtype: " + indices.dtype);
    }
    if (shape.length < 1) {
        throw new Error("Output rank must be greater or equal to 1, but got shape: " + shape);
    }
    if (shape.length === 0) {
        if (indices.size === 0) {
            throw new Error("Indices specified for empty output. indices shape: " + indices.shape);
        }
        if (updates.size === 0) {
            throw new Error("Updates specified for empty output. updates shape: " + updates.shape);
        }
    }
    validateUpdateShape(shape, indices, updates);
}
exports.validateInput = validateInput;
/**
 * Calculate the shape information for the output.
 *
 * @param update The tensor contains the update values.
 * @param indices The tensor contains the indices for the update values.
 * @param shape The shape of the output tensor.
 *
 * @returns ScatterShapeInfo
 */
function calculateShapes(updates, indices, shape) {
    // Calculate the number of dimensions in indices
    var indicesRank = indices.shape.length;
    var sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;
    // Calculate the number of elements that make up each slice of our updated
    // tensor. This allows us to work with flattened tensors and copy over whole
    // slices at a time.
    var totalNd = shape.length;
    var sliceSize = 1;
    for (var i = sliceRank; i < totalNd; ++i) {
        sliceSize *= shape[i];
    }
    var safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;
    var numUpdates = util_1.sizeFromShape(indices.shape) / safeSliceDim;
    var strides = util_1.computeStrides(shape.slice(0, sliceRank)).concat([1]);
    var outputSize = util_1.sizeFromShape(shape);
    return { sliceRank: sliceRank, numUpdates: numUpdates, sliceSize: sliceSize, strides: strides, outputSize: outputSize };
}
exports.calculateShapes = calculateShapes;
//# sourceMappingURL=scatter_nd_util.js.map