gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"use strict";
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
Object.defineProperty(exports, "__esModule", { value: true });
var engine_1 = require("./engine");
var environment_1 = require("./environment");
var tensor_1 = require("./tensor");
var util_1 = require("./util");
function inferShape(val, dtype) {
    var firstElem = val;
    if (util_1.isTypedArray(val)) {
        return dtype === 'string' ? [] : [val.length];
    }
    if (!Array.isArray(val)) {
        return []; // Scalar.
    }
    var shape = [];
    while (Array.isArray(firstElem) ||
        util_1.isTypedArray(firstElem) && dtype !== 'string') {
        shape.push(firstElem.length);
        firstElem = firstElem[0];
    }
    if (Array.isArray(val) &&
        environment_1.env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
        deepAssertShapeConsistency(val, shape, []);
    }
    return shape;
}
exports.inferShape = inferShape;
function deepAssertShapeConsistency(val, shape, indices) {
    indices = indices || [];
    if (!(Array.isArray(val)) && !util_1.isTypedArray(val)) {
        util_1.assert(shape.length === 0, function () { return "Element arr[" + indices.join('][') + "] is a primitive, " +
            ("but should be an array/TypedArray of " + shape[0] + " elements"); });
        return;
    }
    util_1.assert(shape.length > 0, function () { return "Element arr[" + indices.join('][') + "] should be a primitive, " +
        ("but is an array of " + val.length + " elements"); });
    util_1.assert(val.length === shape[0], function () { return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " +
        ("elements, but has " + val.length + " elements"); });
    var subShape = shape.slice(1);
    for (var i = 0; i < val.length; ++i) {
        deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
    }
}
function assertDtype(expectedDtype, actualDType, argName, functionName) {
    if (expectedDtype == null) {
        return;
    }
    if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||
        expectedDtype === 'numeric' && actualDType === 'string') {
        throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " +
            ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor"));
    }
}
function convertToTensor(x, argName, functionName, parseAsDtype) {
    if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; }
    if (x instanceof tensor_1.Tensor) {
        assertDtype(parseAsDtype, x.dtype, argName, functionName);
        return x;
    }
    var inferredDtype = util_1.inferDtype(x);
    // If the user expects a bool/int/float, use that info to update the
    // inferredDtype when it is not a string.
    if (inferredDtype !== 'string' &&
        ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
        inferredDtype = parseAsDtype;
    }
    assertDtype(parseAsDtype, inferredDtype, argName, functionName);
    if ((x == null) ||
        (!util_1.isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
            typeof x !== 'boolean' && typeof x !== 'string')) {
        var type = x == null ? 'null' : x.constructor.name;
        throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " +
            ("Tensor or TensorLike, but got '" + type + "'"));
    }
    var inferredShape = inferShape(x, inferredDtype);
    if (!util_1.isTypedArray(x) && !Array.isArray(x)) {
        x = [x];
    }
    var skipTypedArray = true;
    var values = inferredDtype !== 'string' ?
        util_1.toTypedArray(x, inferredDtype, environment_1.env().getBool('DEBUG')) :
        util_1.flatten(x, [], skipTypedArray);
    return engine_1.ENGINE.makeTensor(values, inferredShape, inferredDtype);
}
exports.convertToTensor = convertToTensor;
function convertToTensorArray(arg, argName, functionName, parseAsDtype) {
    if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; }
    if (!Array.isArray(arg)) {
        throw new Error("Argument " + argName + " passed to " + functionName + " must be a " +
            '`Tensor[]` or `TensorLike[]`');
    }
    var tensors = arg;
    return tensors.map(function (t, i) { return convertToTensor(t, argName + "[" + i + "]", functionName); }, parseAsDtype);
}
exports.convertToTensorArray = convertToTensorArray;
//# sourceMappingURL=tensor_util_env.js.map