"use strict";
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
Object.defineProperty(exports, "__esModule", { value: true });
var engine_1 = require("./engine");
var environment_1 = require("./environment");
var tensor_1 = require("./tensor");
var util_1 = require("./util");

/**
 * Infers the shape of a tensor-like value.
 *
 * The shape is derived by following the FIRST element along every nesting
 * level; sibling elements are only validated when the
 * 'TENSORLIKE_CHECK_SHAPE_CONSISTENCY' environment flag is enabled.
 *
 * @param val A primitive, a (nested) array, or a TypedArray.
 * @param dtype The dtype the value will be interpreted as. When it is
 *     'string', a TypedArray is treated as a single element rather than as
 *     a dimension.
 * @returns The inferred shape; `[]` for a scalar.
 * @throws If consistency checking is enabled and the nested arrays are
 *     ragged (via `util.assert`).
 */
function inferShape(val, dtype) {
    var firstElem = val;
    if (util_1.isTypedArray(val)) {
        return dtype === 'string' ? [] : [val.length];
    }
    if (!Array.isArray(val)) {
        return []; // Scalar.
    }
    var shape = [];
    // Descend through the first element of each level, recording its length.
    while (Array.isArray(firstElem) ||
        (util_1.isTypedArray(firstElem) && dtype !== 'string')) {
        shape.push(firstElem.length);
        firstElem = firstElem[0];
    }
    if (Array.isArray(val) &&
        environment_1.env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
        deepAssertShapeConsistency(val, shape, []);
    }
    return shape;
}
exports.inferShape = inferShape;

/**
 * Recursively asserts that `val` has exactly the given `shape`.
 *
 * @param val The (sub-)value being checked.
 * @param shape The shape expected for `val`; `[]` means a primitive.
 * @param indices Index path from the root to `val`; used only to build
 *     error messages. Defaults to `[]` when falsy.
 * @throws Via `util.assert` when an element is a primitive where an array
 *     is expected, an array where a primitive is expected, or an array of
 *     the wrong length.
 */
function deepAssertShapeConsistency(val, shape, indices) {
    indices = indices || [];
    if (!(Array.isArray(val)) && !util_1.isTypedArray(val)) {
        util_1.assert(shape.length === 0, function () {
            return "Element arr[" + indices.join('][') + "] is a primitive, " +
                ("but should be an array/TypedArray of " + shape[0] + " elements");
        });
        return;
    }
    util_1.assert(shape.length > 0, function () {
        return "Element arr[" + indices.join('][') + "] should be a primitive, " +
            ("but is an array of " + val.length + " elements");
    });
    util_1.assert(val.length === shape[0], function () {
        return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " +
            ("elements, but has " + val.length + " elements");
    });
    var subShape = shape.slice(1);
    for (var i = 0; i < val.length; ++i) {
        deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
    }
}

/**
 * Throws when `actualDType` is not acceptable for `expectedDtype`.
 *
 * A null/undefined `expectedDtype` accepts anything. 'numeric' accepts every
 * dtype except 'string'. Any other expected dtype must match exactly.
 *
 * @param expectedDtype The dtype required by the caller, 'numeric', or null.
 * @param actualDType The dtype of the value that was supplied.
 * @param argName Argument name, used in the error message.
 * @param functionName Calling op name, used in the error message.
 * @throws {Error} When the dtypes are incompatible.
 */
function assertDtype(expectedDtype, actualDType, argName, functionName) {
    if (expectedDtype == null) {
        return;
    }
    if ((expectedDtype !== 'numeric' && expectedDtype !== actualDType) ||
        (expectedDtype === 'numeric' && actualDType === 'string')) {
        throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " +
            ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor"));
    }
}

/**
 * Converts a Tensor or tensor-like value into a Tensor.
 *
 * An existing Tensor is dtype-checked and returned unchanged. Any other
 * value has its dtype and shape inferred and is turned into a new tensor
 * through `ENGINE.makeTensor`.
 *
 * @param x A Tensor, primitive (number/boolean/string), array or TypedArray.
 * @param argName Argument name, used in error messages.
 * @param functionName Calling op name, used in error messages.
 * @param parseAsDtype Required dtype; defaults to 'numeric'.
 * @returns The resulting Tensor.
 * @throws {Error} When the dtype does not satisfy `parseAsDtype`, or when
 *     `x` is null/undefined or of an unsupported type.
 */
function convertToTensor(x, argName, functionName, parseAsDtype) {
    if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; }
    if (x instanceof tensor_1.Tensor) {
        assertDtype(parseAsDtype, x.dtype, argName, functionName);
        return x;
    }
    var inferredDtype = util_1.inferDtype(x);
    // If the user expects a bool/int/float, use that info to update the
    // inferredDtype when it is not a string.
    if (inferredDtype !== 'string' &&
        ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
        inferredDtype = parseAsDtype;
    }
    assertDtype(parseAsDtype, inferredDtype, argName, functionName);
    if ((x == null) ||
        (!util_1.isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
            typeof x !== 'boolean' && typeof x !== 'string')) {
        var type = x == null ? 'null' : x.constructor.name;
        throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " +
            ("Tensor or TensorLike, but got '" + type + "'"));
    }
    var inferredShape = inferShape(x, inferredDtype);
    // Wrap a lone primitive so the flattening/typed-array helpers get a list.
    if (!util_1.isTypedArray(x) && !Array.isArray(x)) {
        x = [x];
    }
    var skipTypedArray = true;
    var values = inferredDtype !== 'string' ?
        util_1.toTypedArray(x, inferredDtype, environment_1.env().getBool('DEBUG')) :
        util_1.flatten(x, [], skipTypedArray);
    return engine_1.ENGINE.makeTensor(values, inferredShape, inferredDtype);
}
exports.convertToTensor = convertToTensor;

/**
 * Converts an array of Tensors / tensor-like values into an array of Tensors.
 *
 * Each element is converted with `convertToTensor`, using
 * `argName + "[" + i + "]"` as the element's argument name.
 *
 * @param arg The array to convert.
 * @param argName Argument name, used in error messages.
 * @param functionName Calling op name, used in error messages.
 * @param parseAsDtype Required dtype for every element; defaults to
 *     'numeric'.
 * @returns An array with one Tensor per input element.
 * @throws {Error} When `arg` is not an array, or when an element fails
 *     conversion.
 */
function convertToTensorArray(arg, argName, functionName, parseAsDtype) {
    if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; }
    if (!Array.isArray(arg)) {
        throw new Error("Argument " + argName + " passed to " + functionName + " must be a " +
            '`Tensor[]` or `TensorLike[]`');
    }
    var tensors = arg;
    // parseAsDtype must be forwarded to convertToTensor. Passing it as the
    // second argument of Array.prototype.map would make it the callback's
    // `this` value and the requested dtype would be silently ignored.
    return tensors.map(function (t, i) {
        return convertToTensor(t, argName + "[" + i + "]", functionName, parseAsDtype);
    });
}
exports.convertToTensorArray = convertToTensorArray;
//# sourceMappingURL=tensor_util_env.js.map