/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from './engine';
|
import { env } from './environment';
|
import { getGlobalTensorClass } from './tensor';
|
import { isWebGLData, isWebGPUData } from './types';
|
import { assert, flatten, inferDtype, isTypedArray, toTypedArray } from './util';
|
import { bytesPerElement } from './util_base';
|
/**
 * Infers the shape of a TensorLike value, or of a WebGL / WebGPU data
 * descriptor.
 *
 * - TypedArray: `[length]`, or `[]` when `dtype` is 'string'.
 * - WebGL data: `[height, width * channelCount]`, where the channel count is
 *   the length of `channels` (defaulting to 'RGBA').
 * - WebGPU data: `[buffer.size / bytesPerElement(dtype)]`, assuming 4 bytes
 *   per element when `dtype` is not given.
 * - Any other non-array value: `[]` (a scalar).
 * - Nested arrays: the lengths found by following the first element of each
 *   nesting level. Nested TypedArrays count as a level unless `dtype` is
 *   'string'. When the TENSORLIKE_CHECK_SHAPE_CONSISTENCY flag is enabled,
 *   every element is then validated against that shape.
 *
 * @param val The value whose shape is inferred.
 * @param dtype Optional dtype of the data.
 * @returns The inferred shape.
 */
export function inferShape(val, dtype) {
  if (isTypedArray(val)) {
    return dtype === 'string' ? [] : [val.length];
  }
  if (isWebGLData(val)) {
    const channels = val.channels || 'RGBA';
    return [val.height, val.width * channels.length];
  }
  if (isWebGPUData(val)) {
    const totalBytes = val.buffer.size;
    const elementBytes = dtype == null ? 4 : bytesPerElement(dtype);
    return [totalBytes / elementBytes];
  }
  if (!Array.isArray(val)) {
    return []; // Scalar.
  }

  // Read one dimension per nesting level by always descending into the
  // first element.
  const shape = [];
  for (let level = val;
       Array.isArray(level) || (isTypedArray(level) && dtype !== 'string');
       level = level[0]) {
    shape.push(level.length);
  }

  // `val` is known to be an array at this point, so only the flag matters.
  if (env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
    deepAssertShapeConsistency(val, shape, []);
  }
  return shape;
}
|
/**
 * Recursively checks that a nested array / TypedArray value matches `shape`.
 *
 * At every level:
 * - a primitive is only allowed once `shape` is exhausted;
 * - a container is only allowed while `shape` still has dimensions;
 * - the container's length must equal the leading dimension.
 *
 * Failures go through `assert`, with a lazily built message that names the
 * offending element by its index path (e.g. `arr[1][0]`).
 *
 * NOTE(review): TypedArrays are always treated as containers here, whereas
 * inferShape() does not descend into TypedArrays for dtype 'string'. Confirm
 * that this combination is never reached with encoded string data.
 *
 * @param val The (sub)value being checked.
 * @param shape Expected shape of `val`.
 * @param indices Index path from the root value to `val`; defaults to [].
 */
function deepAssertShapeConsistency(val, shape, indices) {
  const path = indices || [];
  const isContainer = Array.isArray(val) || isTypedArray(val);
  if (isContainer) {
    assert(shape.length > 0,
        () => `Element arr[${path.join('][')}] should be a primitive, ` +
            `but is an array of ${val.length} elements`);
    assert(val.length === shape[0],
        () => `Element arr[${path.join('][')}] should have ${shape[0]} ` +
            `elements, but has ${val.length} elements`);
    const childShape = shape.slice(1);
    for (let k = 0; k < val.length; k++) {
      deepAssertShapeConsistency(val[k], childShape, path.concat(k));
    }
  } else {
    assert(shape.length === 0,
        () => `Element arr[${path.join('][')}] is a primitive, ` +
            `but should be an array/TypedArray of ${shape[0]} elements`);
  }
}
|
/**
 * Checks that a tensor dtype satisfies a dtype requirement.
 *
 * @param expectedDtype The requirement. 'string_or_numeric' accepts any
 *     dtype, 'numeric' accepts any dtype except 'string', and any other value
 *     must equal `actualDType` exactly.
 * @param actualDType The dtype being checked.
 * @param argName Argument name, used only in the error message.
 * @param functionName Calling function name, used only in the error message.
 * @throws Error if `expectedDtype` is null/undefined (and not
 *     'string_or_numeric'), or if `actualDType` does not satisfy it.
 */
function assertDtype(expectedDtype, actualDType, argName, functionName) {
  if (expectedDtype === 'string_or_numeric') {
    return; // No restriction at all.
  }
  if (expectedDtype == null) {
    throw new Error(`Expected dtype cannot be null.`);
  }
  const satisfied = expectedDtype === 'numeric' ?
      actualDType !== 'string' :
      actualDType === expectedDtype;
  if (satisfied) {
    return;
  }
  throw new Error(`Argument '${argName}' passed to '${functionName}' must ` +
      `be ${expectedDtype} tensor, but got ${actualDType} tensor`);
}
|
/**
 * Converts a Tensor-or-TensorLike argument of an op into a Tensor.
 *
 * A value that is already a Tensor is returned unchanged after its dtype is
 * checked. Any other value must be a TypedArray, a (nested) array, a number,
 * a boolean or a string. Such a value is materialized into a new tensor via
 * `ENGINE.makeTensor`.
 *
 * @param x The value to convert.
 * @param argName Argument name, used only in error messages.
 * @param functionName Name of the calling op, used only in error messages.
 * @param parseAsDtype Dtype requirement: a concrete dtype, 'numeric' (any
 *     dtype except 'string', the default), or 'string_or_numeric' (no
 *     restriction). When 'bool', 'int32' or 'float32' is requested for
 *     non-string data, the data is parsed as that dtype instead of the
 *     inferred one.
 * @returns `x` itself if it is already a Tensor, otherwise a new Tensor.
 * @throws Error if `x` is not a Tensor or TensorLike, or if its dtype does
 *     not satisfy `parseAsDtype`.
 */
export function convertToTensor(x, argName, functionName, parseAsDtype = 'numeric') {
  if (x instanceof getGlobalTensorClass()) {
    assertDtype(parseAsDtype, x.dtype, argName, functionName);
    return x;
  }

  // Validate the value's kind before looking at dtypes. Otherwise an invalid
  // value would be reported as a dtype mismatch, computed from a dtype that
  // was inferred for something that is not TensorLike at all.
  if ((x == null) ||
      (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
       typeof x !== 'boolean' && typeof x !== 'string')) {
    let type;
    if (x == null) {
      type = 'null';
    } else {
      // Objects created with Object.create(null) have no `constructor`;
      // fall back to `typeof` rather than crashing with a TypeError here.
      const ctor = x.constructor;
      type = typeof ctor === 'function' ? ctor.name : typeof x;
    }
    throw new Error(`Argument '${argName}' passed to '${functionName}' must be a ` +
        `Tensor or TensorLike, but got '${type}'`);
  }

  let inferredDtype = inferDtype(x);
  // If the user expects a bool/int/float, use that info to update the
  // inferredDtype when it is not a string.
  if (inferredDtype !== 'string' &&
      ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
    inferredDtype = parseAsDtype;
  }
  assertDtype(parseAsDtype, inferredDtype, argName, functionName);

  // The shape is inferred from the original value, so scalars get shape [];
  // only afterwards are scalars wrapped so they can be flattened.
  const inferredShape = inferShape(x, inferredDtype);
  const data = (isTypedArray(x) || Array.isArray(x)) ? x : [x];
  // Passed to flatten() for string data; presumably keeps TypedArrays
  // (e.g. encoded strings) as leaves instead of flattening them.
  const skipTypedArray = true;
  const values = inferredDtype !== 'string' ?
      toTypedArray(data, inferredDtype) :
      flatten(data, [], skipTypedArray);
  return ENGINE.makeTensor(values, inferredShape, inferredDtype);
}
|
/**
 * Converts every element of an array argument with convertToTensor().
 *
 * Element `i` is named `<argName>[i]` in any error message raised while
 * converting it.
 *
 * @param arg Array of Tensors and/or TensorLike values.
 * @param argName Argument name, used only in error messages.
 * @param functionName Name of the calling op, used only in error messages.
 * @param parseAsDtype Dtype requirement forwarded to convertToTensor().
 * @returns One Tensor per element of `arg`.
 * @throws Error if `arg` is not an array, or if any element fails to convert.
 */
export function convertToTensorArray(arg, argName, functionName, parseAsDtype = 'numeric') {
  if (Array.isArray(arg)) {
    return arg.map((element, index) =>
        convertToTensor(element, `${argName}[${index}]`, functionName, parseAsDtype));
  }
  throw new Error(`Argument ${argName} passed to ${functionName} must be a ` +
      '`Tensor[]` or `TensorLike[]`');
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"tensor_util_env.js","sourceRoot":"","sources":["../../../../../tfjs-core/src/tensor_util_env.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAC,MAAM,UAAU,CAAC;AAChC,OAAO,EAAC,GAAG,EAAC,MAAM,eAAe,CAAC;AAClC,OAAO,EAAC,oBAAoB,EAAS,MAAM,UAAU,CAAC;AACtD,OAAO,EAAW,WAAW,EAAE,YAAY,EAAoC,MAAM,SAAS,CAAC;AAC/F,OAAO,EAAC,MAAM,EAAE,OAAO,EAAE,UAAU,EAAE,YAAY,EAAE,YAAY,EAAC,MAAM,QAAQ,CAAC;AAC/E,OAAO,EAAC,eAAe,EAAC,MAAM,aAAa,CAAC;AAE5C,MAAM,UAAU,UAAU,CACtB,GAAoC,EAAE,KAAgB;IACxD,IAAI,SAAS,GAAe,GAAG,CAAC;IAEhC,IAAI,YAAY,CAAC,GAAG,CAAC,EAAE;QACrB,OAAO,KAAK,KAAK,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC;KAC/C;IAED,IAAI,WAAW,CAAC,GAAG,CAAC,EAAE;QACpB,MAAM,YAAY,GAAG,GAAG,CAAC,QAAQ,IAAI,MAAM,CAAC;QAC5C,OAAO,CAAC,GAAG,CAAC,MAAM,EAAE,GAAG,CAAC,KAAK,GAAG,YAAY,CAAC,MAAM,CAAC,CAAC;KACtD;SAAM,IAAI,YAAY,CAAC,GAAG,CAAC,EAAE;QAC5B,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,IAAI,GAAG,CAAC,KAAK,IAAI,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,eAAe,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;KACzE;IACD,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE;QACvB,OAAO,EAAE,CAAC,CAAE,UAAU;KACvB;IACD,MAAM,KAAK,GAAa,EAAE,CAAC;IAE3B,OAAO,KAAK,CAAC,OAAO,CAAC,SAAS,CAAC;QACxB,YAAY,CAAC,SAAS,CAAC,IAAI,KAAK,KAAK,QAAQ,EAAE;QACpD,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC;QAC7B,SAAS,GAAG,SAAS,CAAC,CAAC,CAAC,CAAC;KAC1B;IACD,IAAI,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC;QAClB,GAAG,EAAE,CAAC,OAAO,CAAC,oCAAoC,CAAC,EAAE;QACvD,0BAA0B,CAAC,GAAG,EAAE,KAAK,EAAE,EAAE,CAAC,CAAC;KAC5C;IAED,OAAO,KAAK,CAAC;AACf,CAAC;AAED,SAAS,0BAA0B,CAC/B,GAAe,EAAE,KAAe,EAAE,OAAiB;IACrD,OAAO,GAAG,OAAO,IAAI,EAAE,CAAC;IACxB,IAAI,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,YAAY,CAAC,GAAG,CAAC,EAAE;QAC/C,MAAM,CACF,KAAK,CAAC,MAAM,KAAK,CAAC,EAClB,GAAG,EAAE,CAAC,eAAe,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,oBAAoB;YACvD,wCAAwC,KAAK,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC;QACrE,OAAO;KACR;IACD,MAAM,CACF,KAAK,CAAC,MAAM,GAAG,CAAC,EAChB,GAAG,EAAE,CAAC,eAAe,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,2BAA
2B;QAC9D,sBAAsB,GAAG,CAAC,MAAM,WAAW,CAAC,CAAC;IACrD,MAAM,CACF,GAAG,CAAC,MAAM,KAAK,KAAK,CAAC,CAAC,CAAC,EACvB,GAAG,EAAE,CAAC,eAAe,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,iBAAiB,KAAK,CAAC,CAAC,CAAC,GAAG;QAC/D,qBAAqB,GAAG,CAAC,MAAM,WAAW,CAAC,CAAC;IACpD,MAAM,QAAQ,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAChC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACnC,0BAA0B,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,QAAQ,EAAE,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;KACjE;AACH,CAAC;AAED,SAAS,WAAW,CAChB,aAAqD,EACrD,WAAqB,EAAE,OAAe,EAAE,YAAoB;IAC9D,IAAI,aAAa,KAAK,mBAAmB,EAAE;QACzC,OAAO;KACR;IACD,IAAI,aAAa,IAAI,IAAI,EAAE;QACzB,MAAM,IAAI,KAAK,CAAC,gCAAgC,CAAC,CAAC;KACnD;IACD,IAAI,aAAa,KAAK,SAAS,IAAI,aAAa,KAAK,WAAW;QAC5D,aAAa,KAAK,SAAS,IAAI,WAAW,KAAK,QAAQ,EAAE;QAC3D,MAAM,IAAI,KAAK,CACX,aAAa,OAAO,gBAAgB,YAAY,SAAS;YACzD,MAAM,aAAa,oBAAoB,WAAW,SAAS,CAAC,CAAC;KAClE;AACH,CAAC;AAED,MAAM,UAAU,eAAe,CAC3B,CAAe,EAAE,OAAe,EAAE,YAAoB,EACtD,eAAuD,SAAS;IAClE,IAAI,CAAC,YAAY,oBAAoB,EAAE,EAAE;QACvC,WAAW,CAAC,YAAY,EAAE,CAAC,CAAC,KAAK,EAAE,OAAO,EAAE,YAAY,CAAC,CAAC;QAC1D,OAAO,CAAC,CAAC;KACV;IACD,IAAI,aAAa,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;IAClC,oEAAoE;IACpE,yCAAyC;IACzC,IAAI,aAAa,KAAK,QAAQ;QAC1B,CAAC,MAAM,EAAE,OAAO,EAAE,SAAS,CAAC,CAAC,OAAO,CAAC,YAAY,CAAC,IAAI,CAAC,EAAE;QAC3D,aAAa,GAAG,YAAwB,CAAC;KAC1C;IACD,WAAW,CAAC,YAAY,EAAE,aAAa,EAAE,OAAO,EAAE,YAAY,CAAC,CAAC;IAEhE,IAAI,CAAC,CAAC,IAAI,IAAI,CAAC;QACX,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,IAAI,OAAO,CAAC,KAAK,QAAQ;YAC9D,OAAO,CAAC,KAAK,SAAS,IAAI,OAAO,CAAC,KAAK,QAAQ,CAAC,EAAE;QACrD,MAAM,IAAI,GAAG,CAAC,IAAI,IAAI,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAE,CAAQ,CAAC,WAAW,CAAC,IAAI,CAAC;QAC7D,MAAM,IAAI,KAAK,CACX,aAAa,OAAO,gBAAgB,YAAY,cAAc;YAC9D,kCAAkC,IAAI,GAAG,CAAC,CAAC;KAChD;IACD,MAAM,aAAa,GAAG,UAAU,CAAC,CAAC,EAAE,aAAa,CAAC,CAAC;IACnD,IAAI,CAAC,YAAY,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;QACzC,CAAC,GAAG,CAAC,CAAC,CAAa,CAAC;KACrB;IACD,MAAM,cAAc,GAAG,IAAI,CAAC;IAC5B,MAAM,MAAM,GAAG,aAAa,KAAK,QAAQ,CAAC,CAAC;QACvC,YAAY,
CAAC,CAAC,EAAE,aAAyB,CAAC,CAAC,CAAC;QAC5C,OAAO,CAAC,CAAa,EAAE,EAAE,EAAE,cAAc,CAAa,CAAC;IAC3D,OAAO,MAAM,CAAC,UAAU,CAAC,MAAM,EAAE,aAAa,EAAE,aAAa,CAAM,CAAC;AACtE,CAAC;AAED,MAAM,UAAU,oBAAoB,CAChC,GAAwB,EAAE,OAAe,EAAE,YAAoB,EAC/D,eAAuD,SAAS;IAClE,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,GAAG,CAAC,EAAE;QACvB,MAAM,IAAI,KAAK,CACX,YAAY,OAAO,cAAc,YAAY,aAAa;YAC1D,8BAA8B,CAAC,CAAC;KACrC;IACD,MAAM,OAAO,GAAG,GAAU,CAAC;IAC3B,OAAO,OAAO,CAAC,GAAG,CACd,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CACL,eAAe,CAAC,CAAC,EAAE,GAAG,OAAO,IAAI,CAAC,GAAG,EAAE,YAAY,EAAE,YAAY,CAAC,CAAC,CAAC;AAC9E,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {ENGINE} from './engine';\nimport {env} from './environment';\nimport {getGlobalTensorClass, Tensor} from './tensor';\nimport {DataType, isWebGLData, isWebGPUData, TensorLike, WebGLData, WebGPUData} from './types';\nimport {assert, flatten, inferDtype, isTypedArray, toTypedArray} from './util';\nimport {bytesPerElement} from './util_base';\n\nexport function inferShape(\n    val: TensorLike|WebGLData|WebGPUData, dtype?: DataType): number[] {\n  let firstElem: typeof val = val;\n\n  if (isTypedArray(val)) {\n    return dtype === 'string' ? 
[] : [val.length];\n  }\n\n  if (isWebGLData(val)) {\n    const usedChannels = val.channels || 'RGBA';\n    return [val.height, val.width * usedChannels.length];\n  } else if (isWebGPUData(val)) {\n    return [val.buffer.size / (dtype == null ? 4 : bytesPerElement(dtype))];\n  }\n  if (!Array.isArray(val)) {\n    return [];  // Scalar.\n  }\n  const shape: number[] = [];\n\n  while (Array.isArray(firstElem) ||\n         isTypedArray(firstElem) && dtype !== 'string') {\n    shape.push(firstElem.length);\n    firstElem = firstElem[0];\n  }\n  if (Array.isArray(val) &&\n      env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {\n    deepAssertShapeConsistency(val, shape, []);\n  }\n\n  return shape;\n}\n\nfunction deepAssertShapeConsistency(\n    val: TensorLike, shape: number[], indices: number[]) {\n  indices = indices || [];\n  if (!(Array.isArray(val)) && !isTypedArray(val)) {\n    assert(\n        shape.length === 0,\n        () => `Element arr[${indices.join('][')}] is a primitive, ` +\n            `but should be an array/TypedArray of ${shape[0]} elements`);\n    return;\n  }\n  assert(\n      shape.length > 0,\n      () => `Element arr[${indices.join('][')}] should be a primitive, ` +\n          `but is an array of ${val.length} elements`);\n  assert(\n      val.length === shape[0],\n      () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` +\n          `elements, but has ${val.length} elements`);\n  const subShape = shape.slice(1);\n  for (let i = 0; i < val.length; ++i) {\n    deepAssertShapeConsistency(val[i], subShape, indices.concat(i));\n  }\n}\n\nfunction assertDtype(\n    expectedDtype: DataType|'numeric'|'string_or_numeric',\n    actualDType: DataType, argName: string, functionName: string) {\n  if (expectedDtype === 'string_or_numeric') {\n    return;\n  }\n  if (expectedDtype == null) {\n    throw new Error(`Expected dtype cannot be null.`);\n  }\n  if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||\n      
expectedDtype === 'numeric' && actualDType === 'string') {\n    throw new Error(\n        `Argument '${argName}' passed to '${functionName}' must ` +\n        `be ${expectedDtype} tensor, but got ${actualDType} tensor`);\n  }\n}\n\nexport function convertToTensor<T extends Tensor>(\n    x: T|TensorLike, argName: string, functionName: string,\n    parseAsDtype: DataType|'numeric'|'string_or_numeric' = 'numeric'): T {\n  if (x instanceof getGlobalTensorClass()) {\n    assertDtype(parseAsDtype, x.dtype, argName, functionName);\n    return x;\n  }\n  let inferredDtype = inferDtype(x);\n  // If the user expects a bool/int/float, use that info to update the\n  // inferredDtype when it is not a string.\n  if (inferredDtype !== 'string' &&\n      ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {\n    inferredDtype = parseAsDtype as DataType;\n  }\n  assertDtype(parseAsDtype, inferredDtype, argName, functionName);\n\n  if ((x == null) ||\n      (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&\n       typeof x !== 'boolean' && typeof x !== 'string')) {\n    const type = x == null ? 
'null' : (x as {}).constructor.name;\n    throw new Error(\n        `Argument '${argName}' passed to '${functionName}' must be a ` +\n        `Tensor or TensorLike, but got '${type}'`);\n  }\n  const inferredShape = inferShape(x, inferredDtype);\n  if (!isTypedArray(x) && !Array.isArray(x)) {\n    x = [x] as number[];\n  }\n  const skipTypedArray = true;\n  const values = inferredDtype !== 'string' ?\n      toTypedArray(x, inferredDtype as DataType) :\n      flatten(x as string[], [], skipTypedArray) as string[];\n  return ENGINE.makeTensor(values, inferredShape, inferredDtype) as T;\n}\n\nexport function convertToTensorArray<T extends Tensor>(\n    arg: Array<T|TensorLike>, argName: string, functionName: string,\n    parseAsDtype: DataType|'numeric'|'string_or_numeric' = 'numeric'): T[] {\n  if (!Array.isArray(arg)) {\n    throw new Error(\n        `Argument ${argName} passed to ${functionName} must be a ` +\n        '`Tensor[]` or `TensorLike[]`');\n  }\n  const tensors = arg as T[];\n  return tensors.map(\n      (t, i) =>\n          convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype));\n}\n"]}
|