/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { complex } from '../ops/complex';
|
import { tensor } from '../ops/tensor';
|
import { sizeFromShape } from '../util';
|
import { DTYPE_VALUE_SIZE_MAP } from './types';
|
import { CompositeArrayBuffer } from './composite_array_buffer';
|
import { backend } from '../globals';
|
import { env } from '../environment';
|
import { getBackend } from '../globals';
|
/**
 * Width in bytes of the length prefix stored before each encoded string.
 * The prefix is a single 32-bit unsigned integer, i.e. one `Uint32Array`
 * element (4 bytes).
 */
const NUM_BYTES_STRING_LENGTH = Uint32Array.BYTES_PER_ELEMENT;
|
/**
 * Encode a map from names to weight values as an ArrayBuffer, along with an
 * `Array` of `WeightsManifestEntry` as specification of the encoded weights.
 *
 * This function does not perform sharding.
 *
 * This function is the reverse of `decodeWeights`.
 *
 * String tensors are serialized element by element as a 4-byte length prefix
 * (written through a `Uint32Array`, so in platform byte order) followed by the
 * element's raw bytes.
 *
 * @param tensors A map ("dict") from names to tensors, or an array of
 *     `{name, tensor}` objects.
 * @param group Group to which the weights belong (optional).
 * @returns A `Promise` of
 *   - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
 *     concatenated.
 *   - An `Array` of `WeightManifestEntry`s, carrying information including
 *     tensor names, `dtype`s and shapes.
 * @throws Error: on unsupported tensor `dtype`.
 */
export async function encodeWeights(tensors, group) {
  // TODO(adarob, cais): Support quantization.
  const specs = [];
  const dataPromises = [];
  const isArrayInput = Array.isArray(tensors);
  const names = isArrayInput ?
      tensors.map(entry => entry.name) :
      Object.keys(tensors);
  for (let i = 0; i < names.length; ++i) {
    const name = names[i];
    const t = isArrayInput ? tensors[i].tensor : tensors[name];
    if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&
        t.dtype !== 'string' && t.dtype !== 'complex64') {
      throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);
    }
    const spec = { name, shape: t.shape, dtype: t.dtype };
    // Calling the async helper directly (rather than wrapping it in
    // `new Promise(async ...)`) means a rejection from `t.bytes()` reaches
    // `Promise.all` below instead of leaving the promise pending forever.
    dataPromises.push(
        t.dtype === 'string' ? encodeStringTensorBytes(t) : t.data());
    if (group != null) {
      spec.group = group;
    }
    specs.push(spec);
  }
  const tensorValues = await Promise.all(dataPromises);
  return { data: concatenateTypedArrays(tensorValues), specs };
}

/**
 * Serialize a string tensor as consecutive (uint32 length, bytes) records.
 *
 * @param t A string tensor exposing `bytes()`, which resolves to an array of
 *     `Uint8Array`s.
 * @returns A `Uint8Array` holding all records back to back.
 */
async function encodeStringTensorBytes(t) {
  const vals = await t.bytes();
  const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) +
      NUM_BYTES_STRING_LENGTH * vals.length;
  const bytes = new Uint8Array(totalNumBytes);
  let offset = 0;
  for (const val of vals) {
    const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
    bytes.set(bytesOfLength, offset);
    offset += NUM_BYTES_STRING_LENGTH;
    bytes.set(val, offset);
    offset += val.length;
  }
  return bytes;
}
|
/**
 * Decode a flat ArrayBuffer (or an array of ArrayBuffers) into named tensors.
 *
 * Sharding is not handled here. This is the inverse of `encodeWeights`.
 *
 * @param weightData A flat ArrayBuffer or an array of ArrayBuffers carrying the
 *     binary values of the tensors concatenated in the order specified in
 *     `specs`.
 * @param specs Specifications of the names, dtypes and shapes of the tensors
 *     whose values are encoded in `weightData`.
 * @return A map from tensor name to tensor value, with the names corresponding
 *     to names in `specs`.
 * @throws Error, if any of the tensors has unsupported dtype.
 */
export function decodeWeights(weightData, specs) {
  // TODO(adarob, cais): Support quantization.
  const buffer = new CompositeArrayBuffer(weightData);
  const decoded = {};
  let cursor = 0;
  for (const spec of specs) {
    const weightStart = cursor;
    // Byte reads relative to the start of the current weight; only needed
    // when the weight's size depends on its contents (strings).
    const sliceFromWeight = (from, to) =>
        buffer.slice(weightStart + from, weightStart + to);
    const weightBytes = getWeightBytelength(spec, sliceFromWeight);
    const weightEnd = weightStart + weightBytes;
    decoded[spec.name] = decodeWeight(spec, buffer.slice(weightStart, weightEnd));
    cursor = weightEnd;
  }
  return decoded;
}
|
/**
 * Compute how many bytes the weight described by `spec` occupies.
 *
 * @param spec Weight manifest entry (`shape`, `dtype`, optional
 *     `quantization`).
 * @param slice Callback `(start, end) => ArrayBuffer` returning the bytes in
 *     `[start, end)` relative to the start of this weight. It is only called
 *     for unquantized string weights, whose size is found by reading each
 *     element's length prefix.
 * @returns The weight's size in bytes.
 */
function getWeightBytelength(spec, slice) {
  const size = sizeFromShape(spec.shape);
  const quantized = 'quantization' in spec;
  if (!quantized && spec.dtype === 'string') {
    // Can not statically determine string length: walk the length prefixes.
    let total = 0;
    for (let i = 0; i < size; i++) {
      const prefix = slice(total, total + NUM_BYTES_STRING_LENGTH);
      total += NUM_BYTES_STRING_LENGTH + new Uint32Array(prefix)[0];
    }
    return total;
  }
  const storedDtype = quantized ? spec.quantization.dtype : spec.dtype;
  return size * DTYPE_VALUE_SIZE_MAP[storedDtype];
}
|
/**
 * Asynchronous counterpart of `getWeightBytelength`, for sources (such as a
 * stream) where fetching bytes must be awaited.
 *
 * @param spec Weight manifest entry (`shape`, `dtype`, optional
 *     `quantization`).
 * @param slice Async callback `(start, end) => Promise<ArrayBuffer>` returning
 *     the bytes in `[start, end)` relative to the start of this weight. Only
 *     called for unquantized string weights.
 * @returns A Promise of the weight's size in bytes.
 */
async function getWeightBytelengthAsync(spec, slice) {
  const size = sizeFromShape(spec.shape);
  if ('quantization' in spec) {
    return size * DTYPE_VALUE_SIZE_MAP[spec.quantization.dtype];
  }
  if (spec.dtype !== 'string') {
    return size * DTYPE_VALUE_SIZE_MAP[spec.dtype];
  }
  // Strings have no fixed width, so read each element's length prefix in turn.
  let total = 0;
  for (let element = 0; element < size; element++) {
    const prefix = await slice(total, total + NUM_BYTES_STRING_LENGTH);
    total += NUM_BYTES_STRING_LENGTH + new Uint32Array(prefix)[0];
  }
  return total;
}
|
/**
 * Decode a single weight from its raw bytes into a tensor.
 *
 * @param spec Weight manifest entry with `name`, `shape`, `dtype` and an
 *     optional `quantization` object (`dtype`, plus `min` and `scale` for
 *     uint8/uint16 quantization).
 * @param byteBuffer ArrayBuffer containing exactly this weight's bytes.
 * @returns The decoded tensor; `complex64` weights yield a complex tensor.
 * @throws Error on an unsupported dtype or quantization dtype, or when
 *     uint8/uint16 quantization metadata lacks `min`/`scale`.
 */
function decodeWeight(spec, byteBuffer) {
  const name = spec.name;
  const dtype = spec.dtype;
  const shape = spec.shape;
  const size = sizeFromShape(shape);
  let values;
  if ('quantization' in spec) {
    values = dequantizeWeightValues(spec, byteBuffer);
  } else if (dtype === 'string') {
    values = decodeStringWeightValues(byteBuffer, size);
  } else if (dtype === 'float32') {
    values = new Float32Array(byteBuffer);
  } else if (dtype === 'int32') {
    values = new Int32Array(byteBuffer);
  } else if (dtype === 'bool') {
    values = new Uint8Array(byteBuffer);
  } else if (dtype === 'complex64') {
    return decodeComplex64Weight(byteBuffer, shape);
  } else {
    throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
  }
  return tensor(values, shape, dtype);
}

/** Expand a quantized weight's bytes into a Float32Array or Int32Array. */
function dequantizeWeightValues(spec, byteBuffer) {
  const name = spec.name;
  const dtype = spec.dtype;
  const quantization = spec.quantization;
  const isIntQuantized =
      quantization.dtype === 'uint8' || quantization.dtype === 'uint16';
  if (isIntQuantized) {
    if (!('min' in quantization && 'scale' in quantization)) {
      throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} ` +
          `doesn't have corresponding metadata min and scale.`);
    }
  } else if (quantization.dtype === 'float16') {
    if (dtype !== 'float32') {
      throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} ` +
          `which only supports weights of type float32 not ${dtype}.`);
    }
  } else {
    throw new Error(`Weight ${spec.name} has unknown ` +
        `quantization dtype ${quantization.dtype}. ` +
        `Supported quantization dtypes are: ` +
        `'uint8', 'uint16', and 'float16'.`);
  }
  // From here on quantization.dtype is one of 'uint8', 'uint16', 'float16',
  // and 'float16' implies dtype === 'float32'.
  const quantizedArray = (quantization.dtype === 'uint8') ?
      new Uint8Array(byteBuffer) :
      new Uint16Array(byteBuffer);
  if (dtype === 'float32') {
    if (isIntQuantized) {
      const { scale, min } = quantization;
      return Float32Array.from(quantizedArray, v => v * scale + min);
    }
    // TODO: This is inefficient. Make getFloat16Decoder efficient.
    const float16Decode = getFloat16Decoder();
    return float16Decode(quantizedArray);
  }
  if (dtype === 'int32') {
    // Only uint8/uint16 can reach here (float16 requires float32 above).
    const { scale, min } = quantization;
    return Int32Array.from(
        quantizedArray, v => Math.round(v * scale + min));
  }
  throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
}

/** Split `count` (uint32 length, bytes) records into an array of Uint8Arrays. */
function decodeStringWeightValues(byteBuffer, count) {
  const values = [];
  let offset = 0;
  for (let i = 0; i < count; i++) {
    const byteLength = new Uint32Array(
        byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
    offset += NUM_BYTES_STRING_LENGTH;
    values.push(new Uint8Array(byteBuffer.slice(offset, offset + byteLength)));
    offset += byteLength;
  }
  return values;
}

/** Build a complex64 tensor from interleaved (real, imag) float32 pairs. */
function decodeComplex64Weight(byteBuffer, shape) {
  const interleaved = new Float32Array(byteBuffer);
  const real = new Float32Array(interleaved.length / 2);
  const imag = new Float32Array(interleaved.length / 2);
  for (let i = 0; i < real.length; i++) {
    real[i] = interleaved[i * 2];
    imag[i] = interleaved[i * 2 + 1];
  }
  const realTensor = tensor(real, shape, 'float32');
  const imagTensor = tensor(imag, shape, 'float32');
  try {
    return complex(realTensor, imagTensor);
  } finally {
    // Release the temporary part tensors, including when `complex` throws.
    realTensor.dispose();
    imagTensor.dispose();
  }
}
|
async function readToLength(reader, initialData, length) {
|
let data = new Uint8Array(initialData);
|
while (data.byteLength < length) {
|
const { done, value } = await reader.read();
|
if (done && value == null) {
|
const missing = length - data.byteLength;
|
throw new Error(`Reader is done but ${missing} bytes are still expected`);
|
}
|
// TODO: Don't create a new array every loop.
|
const newData = new Uint8Array(data.length + value.byteLength);
|
newData.set(data, 0);
|
newData.set(new Uint8Array(value), data.length);
|
data = newData;
|
}
|
return data.buffer;
|
}
|
/**
 * Decode weights from a stream of bytes laid out as for `decodeWeights`,
 * reading only as much of the stream as each weight needs.
 *
 * On the 'webgpu' backend, when the backend exposes `uploadToGPU`, each
 * decoded tensor whose size is at least `WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD`
 * is passed to it immediately.
 *
 * If decoding fails part-way, the tensors decoded so far are disposed before
 * the error is rethrown. The reader's lock on the stream is released in all
 * cases.
 *
 * @param weightStream A `ReadableStream` yielding the concatenated weight bytes.
 * @param specs Weight manifest entries, in the order they appear in the stream.
 * @returns A Promise of a map from weight name to decoded tensor.
 */
export async function decodeWeightsStream(weightStream, specs) {
  const tensors = {};
  const reader = weightStream.getReader();
  let data = new ArrayBuffer(0);
  try {
    for (const spec of specs) {
      const byteLength = await getWeightBytelengthAsync(spec, async (start, end) => {
        data = await readToLength(reader, data, end);
        return data.slice(start, end);
      });
      data = await readToLength(reader, data, byteLength);
      // Slice the tensor out; keep the remainder for the next weight.
      const tensorData = data.slice(0, byteLength);
      data = data.slice(byteLength);
      const weightTensor = decodeWeight(spec, tensorData);
      tensors[spec.name] = weightTensor;
      // TODO(mattsoulanille): Better way to call uploadToGPU.
      // TODO(mattsoulanille): Make this work for webgl too.
      if (getBackend() === 'webgpu') {
        const b = backend();
        if ('uploadToGPU' in b &&
            sizeFromShape(weightTensor.shape) >= env()
                .get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD')) {
          b.uploadToGPU(weightTensor.dataId);
        }
      }
    }
  } catch (error) {
    // The caller never receives these tensors, so free them before failing.
    for (const name of Object.keys(tensors)) {
      tensors[name].dispose();
    }
    throw error;
  } finally {
    reader.releaseLock();
  }
  return tensors;
}
|
/**
 * Concatenate TypedArrays into an ArrayBuffer.
 *
 * Only the bytes each array actually views are copied, so arrays that are
 * views into a larger `ArrayBuffer` (non-zero `byteOffset` or shorter
 * `byteLength`) are handled correctly.
 *
 * @param xs Array of `Float32Array`, `Int32Array` or `Uint8Array` values.
 * @returns A new `ArrayBuffer` containing the bytes of `xs` in order.
 * @throws Error if `xs` is null/undefined, or if an element is not one of the
 *     supported TypedArray subtypes.
 */
export function concatenateTypedArrays(xs) {
  // TODO(adarob, cais): Support quantization.
  if (xs == null) {
    throw new Error(`Invalid input value: ${JSON.stringify(xs)}`);
  }
  let totalByteLength = 0;
  for (const x of xs) {
    // Validate before touching `x.buffer`, so unsupported inputs fail with a
    // descriptive message instead of a TypeError.
    if (!(x instanceof Float32Array || x instanceof Int32Array ||
          x instanceof Uint8Array)) {
      throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);
    }
    totalByteLength += x.byteLength;
  }
  const y = new Uint8Array(totalByteLength);
  let offset = 0;
  for (const x of xs) {
    // A byte view over exactly the region `x` covers within its buffer.
    y.set(new Uint8Array(x.buffer, x.byteOffset, x.byteLength), offset);
    offset += x.byteLength;
  }
  return y.buffer;
}
|
// Fall back to Node's `Buffer` whenever it exists and any of the browser
// primitives (Blob, atob, btoa) is unavailable.
const useNodeBuffer = typeof Buffer !== 'undefined' &&
    !(typeof Blob !== 'undefined' && typeof atob !== 'undefined' &&
      typeof btoa !== 'undefined');
|
/**
 * Return the number of bytes `str` occupies when UTF-8 encoded.
 *
 * A JavaScript string's `length` counts UTF-16 code units, which generally
 * differs from its encoded byte length.
 *
 * @param str Input string.
 * @returns Byte length.
 */
export function stringByteLength(str) {
  return useNodeBuffer ? Buffer.byteLength(str, 'utf8') : new Blob([str]).size;
}
|
/**
 * Base64-encode the contents of an ArrayBuffer.
 *
 * @param buffer `ArrayBuffer` to be converted.
 * @returns The base64 representation of `buffer`.
 */
export function arrayBufferToBase64String(buffer) {
  if (!useNodeBuffer) {
    // btoa expects a "binary string": one char per byte.
    let binary = '';
    for (const byte of new Uint8Array(buffer)) {
      binary += String.fromCharCode(byte);
    }
    return btoa(binary);
  }
  return Buffer.from(buffer).toString('base64');
}
|
/**
 * Decode a base64 string into a fresh ArrayBuffer.
 *
 * @param str Base64 string.
 * @returns Decoded `ArrayBuffer`.
 */
export function base64StringToArrayBuffer(str) {
  if (!useNodeBuffer) {
    const binary = atob(str);
    const bytes = new Uint8Array(binary.length);
    for (let i = 0; i < binary.length; ++i) {
      bytes[i] = binary.charCodeAt(i);
    }
    return bytes.buffer;
  }
  const decoded = Buffer.from(str, 'base64');
  // `decoded` may be a view into a larger ArrayBuffer; copy out only its bytes.
  const start = decoded.byteOffset;
  return decoded.buffer.slice(start, start + decoded.byteLength);
}
|
/**
 * Join one or more ArrayBuffers, in order, into a single ArrayBuffer.
 *
 * @param buffers An array of ArrayBuffers to concatenate, or a single
 *     ArrayBuffer.
 * @returns The concatenation of `buffers`.
 *
 * @deprecated Use tf.io.CompositeArrayBuffer.join() instead.
 */
export function concatenateArrayBuffers(buffers) {
  const joined = CompositeArrayBuffer.join(buffers);
  return joined;
}
|
/**
 * Return the last '/'-separated component of a path, in the spirit of the
 * Linux `basename` command.
 *
 * Surrounding whitespace and trailing separators are ignored, so
 * `'a/b/c/'` yields `'c'`. A path made only of separators yields `''`.
 *
 * @param path
 */
export function basename(path) {
  const SEPARATOR = '/';
  const trimmed = path.trim();
  let end = trimmed.length;
  while (end > 0 && trimmed[end - 1] === SEPARATOR) {
    end--;
  }
  const withoutTrailing = trimmed.slice(0, end);
  return withoutTrailing.slice(withoutTrailing.lastIndexOf(SEPARATOR) + 1);
}
|
/**
|
* Create `ModelJSON` from `ModelArtifacts`.
|
*
|
* @param artifacts Model artifacts, describing the model and its weights.
|
* @param manifest Weight manifest, describing where the weights of the
|
* `ModelArtifacts` are stored, and some metadata about them.
|
* @returns Object representing the `model.json` file describing the model
|
* artifacts and weights
|
*/
|
export function getModelJSONForModelArtifacts(artifacts, manifest) {
|
const result = {
|
modelTopology: artifacts.modelTopology,
|
format: artifacts.format,
|
generatedBy: artifacts.generatedBy,
|
convertedBy: artifacts.convertedBy,
|
weightsManifest: manifest
|
};
|
if (artifacts.signature != null) {
|
result.signature = artifacts.signature;
|
}
|
if (artifacts.userDefinedMetadata != null) {
|
result.userDefinedMetadata = artifacts.userDefinedMetadata;
|
}
|
if (artifacts.modelInitializer != null) {
|
result.modelInitializer = artifacts.modelInitializer;
|
}
|
if (artifacts.initializerSignature != null) {
|
result.initializerSignature = artifacts.initializerSignature;
|
}
|
if (artifacts.trainingConfig != null) {
|
result.trainingConfig = artifacts.trainingConfig;
|
}
|
return result;
|
}
|
/**
|
* Create `ModelArtifacts` from a JSON file and weights.
|
*
|
* @param modelJSON Object containing the parsed JSON of `model.json`
|
* @param weightSpecs The list of WeightsManifestEntry for the model. Must be
|
* passed if the modelJSON has a weightsManifest.
|
* @param weightData An ArrayBuffer or array of ArrayBuffers of weight data for
|
* the model corresponding to the weights in weightSpecs. Must be passed if
|
* the modelJSON has a weightsManifest.
|
* @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
|
*/
|
export function getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData) {
|
const modelArtifacts = {
|
modelTopology: modelJSON.modelTopology,
|
format: modelJSON.format,
|
generatedBy: modelJSON.generatedBy,
|
convertedBy: modelJSON.convertedBy
|
};
|
if (modelJSON.trainingConfig != null) {
|
modelArtifacts.trainingConfig = modelJSON.trainingConfig;
|
}
|
if (modelJSON.weightsManifest != null) {
|
if (!weightSpecs) {
|
throw new Error('modelJSON has weightsManifest but weightSpecs is null');
|
}
|
if (!weightData) {
|
throw new Error('modelJSON has weightsManifest but weightData is null');
|
}
|
modelArtifacts.weightSpecs = weightSpecs;
|
modelArtifacts.weightData = weightData;
|
}
|
if (modelJSON.signature != null) {
|
modelArtifacts.signature = modelJSON.signature;
|
}
|
if (modelJSON.userDefinedMetadata != null) {
|
modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
|
}
|
if (modelJSON.modelInitializer != null) {
|
modelArtifacts.modelInitializer = modelJSON.modelInitializer;
|
}
|
if (modelJSON.initializerSignature != null) {
|
modelArtifacts.initializerSignature = modelJSON.initializerSignature;
|
}
|
return modelArtifacts;
|
}
|
/**
 * Asynchronously assemble `ModelArtifacts` from a parsed `model.json`, loading
 * weights only when the JSON lists a weights manifest.
 *
 * @param modelJSON Object containing the parsed JSON of `model.json`
 * @param loadWeights Function that takes the JSON file's weights manifest,
 *     reads weights from the listed path(s), and returns a Promise of the
 *     weight manifest entries along with the weights data.
 * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
 */
export async function getModelArtifactsForJSON(modelJSON, loadWeights) {
  const manifest = modelJSON.weightsManifest;
  const [weightSpecs, weightData] =
      manifest != null ? await loadWeights(manifest) : [];
  return getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData);
}
|
/**
 * Summarize the sizes of a JSON-topology model's artifacts.
 *
 * Topology and weight specs are measured as the UTF-8 byte length of their
 * JSON serialization. Missing (null/undefined) parts count as 0 bytes.
 *
 * @param modelArtifacts
 * @returns A ModelArtifactsInfo object.
 * @throws Error if the topology is an ArrayBuffer rather than JSON.
 */
export function getModelArtifactsInfoForJSON(modelArtifacts) {
  const { modelTopology, weightSpecs, weightData } = modelArtifacts;
  if (modelTopology instanceof ArrayBuffer) {
    throw new Error('Expected JSON model topology, received ArrayBuffer.');
  }
  const jsonBytes = (value) =>
      value == null ? 0 : stringByteLength(JSON.stringify(value));
  return {
    dateSaved: new Date(),
    modelTopologyType: 'JSON',
    modelTopologyBytes: jsonBytes(modelTopology),
    weightSpecsBytes: jsonBytes(weightSpecs),
    weightDataBytes: weightData == null ?
        0 :
        new CompositeArrayBuffer(weightData).byteLength,
  };
}
|
/**
 * Flatten the `weights` lists of every group in a WeightsManifestConfig into a
 * single list of WeightsManifestEntry, preserving order.
 *
 * @param weightsManifest The WeightsManifestConfig to extract weights from.
 * @returns A list of WeightsManifestEntry of the weights in the weightsManifest
 */
export function getWeightSpecs(weightsManifest) {
  const flattened = [];
  for (const group of weightsManifest) {
    for (const weight of group.weights) {
      flattened.push(weight);
    }
  }
  return flattened;
}
|
/**
 * Computes the mantissa lookup table for casting Float16 to Float32.
 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
 *
 * Entry 0 is 0. Entries 1..1023 are renormalized: the mantissa is shifted
 * until bit 23 is set and the exponent adjusted accordingly. Entries
 * 1024..2047 are `0x38000000 + ((i - 1024) << 13)`.
 *
 * @returns Uint32Array, 2048 mantissa lookup values.
 */
function computeFloat16MantisaTable() {
  const table = new Uint32Array(2048);
  // table[0] stays 0: Uint32Array is zero-initialised.
  for (let i = 1; i < 1024; i++) {
    // Shift until the implicit leading bit (bit 23) is set, lowering the
    // exponent by one step per shift.
    let mantissa = i << 13;
    let exponent = 0;
    while (!(mantissa & 0x00800000)) {
      mantissa <<= 1;
      exponent -= 0x00800000;
    }
    table[i] = (mantissa & ~0x00800000) | (exponent + 0x38800000);
  }
  for (let i = 1024; i < 2048; i++) {
    table[i] = 0x38000000 + ((i - 1024) << 13);
  }
  return table;
}
|
/**
 * Computes the exponent lookup table for casting Float16 to Float32.
 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
 *
 * Indices 0..31 cover sign bit 0 and indices 32..63 sign bit 1. Entries 31 and
 * 63 hold the bits for +/- infinity exponents (0x47800000 / 0xc7800000).
 *
 * @returns Uint32Array, 64 exponent lookup values.
 */
function computeFloat16ExponentTable() {
  const table = new Uint32Array(64);
  for (let e = 1; e < 31; e++) {
    table[e] = e << 23;
    table[e + 32] = 0x80000000 + (e << 23);
  }
  table[0] = 0;
  table[31] = 0x47800000;
  table[32] = 0x80000000;
  table[63] = 0xc7800000;
  return table;
}
|
/**
 * Computes the offset lookup table for casting Float16 to Float32.
 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
 *
 * Every entry is 1024 except indices 0 and 32, which are 0.
 *
 * @returns Uint32Array, 64 offset values.
 */
function computeFloat16OffsetTable() {
  const table = new Uint32Array(64).fill(1024);
  table[0] = 0;
  table[32] = 0;
  return table;
}
|
/** Float16 decoder built on first use of `getFloat16Decoder` and reused. */
let cachedFloat16Decoder = null;

/**
 * Retrieve a Float16 decoder which will decode a Uint16Array of Float16 bit
 * patterns to a Float32Array.
 *
 * The lookup tables are computed only once. Later calls return the same
 * decoder function, which never mutates those tables.
 *
 * @returns Function (buffer: Uint16Array) => Float32Array which decodes
 *     the Uint16Array of Float16 bytes to a Float32Array.
 */
export function getFloat16Decoder() {
  if (cachedFloat16Decoder == null) {
    // Algorithm is based off of
    // http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
    const mantisaTable = computeFloat16MantisaTable();
    const exponentTable = computeFloat16ExponentTable();
    const offsetTable = computeFloat16OffsetTable();
    cachedFloat16Decoder = (quantizedArray) => {
      const bufferUint32View = new Uint32Array(quantizedArray.length);
      for (let index = 0; index < quantizedArray.length; index++) {
        const float16Bits = quantizedArray[index];
        // Top 6 bits: sign + 5-bit exponent; low 10 bits: mantissa.
        const signAndExponent = float16Bits >> 10;
        bufferUint32View[index] =
            mantisaTable[offsetTable[signAndExponent] + (float16Bits & 0x3ff)] +
            exponentTable[signAndExponent];
      }
      return new Float32Array(bufferUint32View.buffer);
    };
  }
  return cachedFloat16Decoder;
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"io_utils.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/io/io_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,OAAO,EAAC,MAAM,gBAAgB,CAAC;AACvC,OAAO,EAAC,MAAM,EAAC,MAAM,eAAe,CAAC;AAGrC,OAAO,EAAC,aAAa,EAAC,MAAM,SAAS,CAAC;AAEtC,OAAO,EAAC,oBAAoB,EAAsH,MAAM,SAAS,CAAC;AAClK,OAAO,EAAC,oBAAoB,EAAC,MAAM,0BAA0B,CAAC;AAE9D,OAAO,EAAC,OAAO,EAAC,MAAM,YAAY,CAAC;AAEnC,OAAO,EAAC,GAAG,EAAC,MAAM,gBAAgB,CAAC;AACnC,OAAO,EAAC,UAAU,EAAC,MAAM,YAAY,CAAC;AAEtC,8EAA8E;AAC9E,MAAM,uBAAuB,GAAG,CAAC,CAAC;AAElC;;;;;;;;;;;;;;;;GAgBG;AACH,MAAM,CAAC,KAAK,UAAU,aAAa,CAC/B,OAAqC,EAAE,KAAmB;IAE5D,4CAA4C;IAC5C,MAAM,KAAK,GAA2B,EAAE,CAAC;IACzC,MAAM,YAAY,GAA+B,EAAE,CAAC;IAEpD,MAAM,KAAK,GAAa,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC;QAC5C,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC;QACpC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;IAEzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACrC,MAAM,IAAI,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;QACtB,MAAM,CAAC,GAAG,KAAK,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC;QACrE,IAAI,CAAC,CAAC,KAAK,KAAK,SAAS,IAAI,CAAC,CAAC,KAAK,KAAK,OAAO,IAAI,CAAC,CAAC,KAAK,KAAK,MAAM;YAClE,CAAC,CAAC,KAAK,KAAK,QAAQ,IAAI,CAAC,CAAC,KAAK,KAAK,WAAW,EAAE;YACnD,MAAM,IAAI,KAAK,CAAC,gCAAgC,IAAI,MAAM,CAAC,CAAC,KAAK,EAAE,CAAC,CAAC;SACtE;QACD,MAAM,IAAI,GAAyB,EAAC,IAAI,EAAE,KAAK,EAAE,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,CAAC,CAAC,KAAK,EAAC,CAAC;QAC1E,IAAI,CAAC,CAAC,KAAK,KAAK,QAAQ,EAAE;YACxB,MAAM,SAAS,GAAG,IAAI,OAAO,CAAa,KAAK,EAAC,OAAO,EAAC,EAAE;gBACxD,MAAM,IAAI,GAAG,MAAM,CAAC,CAAC,KAAK,EAAkB,CAAC;gBAC7C,MAAM,aAAa,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,CAAC,MAAM,EAAE,CAAC,CAAC;oBACxD,uBAAuB,GAAG,IAAI,CAAC,MAAM,CAAC;gBAC1C,MAAM,KAAK,GAAG,IAAI,UAAU,CAAC,aAAa,CAAC,CAAC;gBAC5C,IAAI,MAAM,GAAG,CAAC,CAAC;gBACf,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EA
AE,EAAE;oBACpC,MAAM,GAAG,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;oBACpB,MAAM,aAAa,GACf,IAAI,UAAU,CAAC,IAAI,WAAW,CAAC,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;oBACzD,KAAK,CAAC,GAAG,CAAC,aAAa,EAAE,MAAM,CAAC,CAAC;oBACjC,MAAM,IAAI,uBAAuB,CAAC;oBAClC,KAAK,CAAC,GAAG,CAAC,GAAG,EAAE,MAAM,CAAC,CAAC;oBACvB,MAAM,IAAI,GAAG,CAAC,MAAM,CAAC;iBACtB;gBACD,OAAO,CAAC,KAAK,CAAC,CAAC;YACjB,CAAC,CAAC,CAAC;YACH,YAAY,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;SAC9B;aAAM;YACL,YAAY,CAAC,IAAI,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,CAAC;SAC7B;QACD,IAAI,KAAK,IAAI,IAAI,EAAE;YACjB,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;SACpB;QACD,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;KAClB;IAED,MAAM,YAAY,GAAG,MAAM,OAAO,CAAC,GAAG,CAAC,YAAY,CAAC,CAAC;IACrD,OAAO,EAAC,IAAI,EAAE,sBAAsB,CAAC,YAAY,CAAC,EAAE,KAAK,EAAC,CAAC;AAC7D,CAAC;AAED;;;;;;;;;;;;;;;GAeG;AACH,MAAM,UAAU,aAAa,CACzB,UAAsB,EACtB,KAA6B;IAC/B,4CAA4C;IAC5C,MAAM,eAAe,GAAG,IAAI,oBAAoB,CAAC,UAAU,CAAC,CAAC;IAC7D,MAAM,GAAG,GAAmB,EAAE,CAAC;IAC/B,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE;QACxB,MAAM,UAAU,GAAG,mBAAmB,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,EAAE;YAC1D,OAAO,eAAe,CAAC,KAAK,CAAC,MAAM,GAAG,KAAK,EAAE,MAAM,GAAG,GAAG,CAAC,CAAC;QAC7D,CAAC,CAAC,CAAC;QACH,GAAG,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,YAAY,CAAC,IAAI,EAAE,eAAe;aAChD,KAAK,CAAC,MAAM,EAAE,MAAM,GAAG,UAAU,CAAC,CAAC,CAAC;QACvC,MAAM,IAAI,UAAU,CAAC;KACtB;IACD,OAAO,GAAG,CAAC;AACb,CAAC;AAED,SAAS,mBAAmB,CAAC,IAA0B,EACrD,KAAkD;IAElD,MAAM,IAAI,GAAG,aAAa,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;IACvC,IAAI,aAAqB,CAAC;IAC1B,IAAI,cAAc,IAAI,IAAI,EAAE;QAC1B,MAAM,YAAY,GAAG,IAAI,CAAC,YAAY,CAAC;QACvC,aAAa,GAAG,oBAAoB,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC;KAC1D;SAAM,IAAI,IAAI,CAAC,KAAK,KAAK,QAAQ,EAAE;QAClC,8CAA8C;QAC9C,IAAI,UAAU,GAAG,CAAC,CAAC;QACnB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,EAAE;YAC7B,UAAU,IAAI,uBAAuB,GAAG,IAAI,WAAW,CACrD,KAAK,CAAC,UAAU,EAAE,UAAU,GAAG,uBAAuB,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SAC/D;QACD,OAAO,UAAU,CAAC;KACnB;SAAM;QACL,aAAa,GAAG,oBAAoB,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;KAClD;IAED,OAAO,IAAI,GAAG,aAAa,CAAC;AAC9B,CAAC;AAED,KAAK
,UAAU,wBAAwB,CACrC,IAA0B,EAC1B,KAA2D;IAG3D,MAAM,IAAI,GAAG,aAAa,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;IACvC,IAAI,aAAqB,CAAC;IAC1B,IAAI,cAAc,IAAI,IAAI,EAAE;QAC1B,MAAM,YAAY,GAAG,IAAI,CAAC,YAAY,CAAC;QACvC,aAAa,GAAG,oBAAoB,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC;KAC1D;SAAM,IAAI,IAAI,CAAC,KAAK,KAAK,QAAQ,EAAE;QAClC,8CAA8C;QAC9C,IAAI,UAAU,GAAG,CAAC,CAAC;QACnB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,EAAE;YAC7B,UAAU,IAAI,uBAAuB,GAAG,IAAI,WAAW,CACrD,MAAM,KAAK,CAAC,UAAU,EAAE,UAAU,GAAG,uBAAuB,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SACrE;QACD,OAAO,UAAU,CAAC;KACnB;SAAM;QACL,aAAa,GAAG,oBAAoB,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;KAClD;IAED,OAAO,IAAI,GAAG,aAAa,CAAC;AAC9B,CAAC;AAED,SAAS,YAAY,CACnB,IAA0B,EAC1B,UAAuB;IAEvB,MAAM,IAAI,GAAG,IAAI,CAAC,IAAI,CAAC;IACvB,MAAM,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC;IACzB,MAAM,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC;IACzB,MAAM,IAAI,GAAG,aAAa,CAAC,KAAK,CAAC,CAAC;IAClC,IAAI,MAA4C,CAAC;IACjD,IAAI,MAAM,GAAG,CAAC,CAAC;IAEf,IAAI,cAAc,IAAI,IAAI,EAAE;QAC1B,MAAM,YAAY,GAAG,IAAI,CAAC,YAAY,CAAC;QACvC,IAAI,YAAY,CAAC,KAAK,KAAK,OAAO,IAAI,YAAY,CAAC,KAAK,KAAK,QAAQ,EAAE;YACrE,IAAI,CAAC,CAAC,KAAK,IAAI,YAAY,IAAI,OAAO,IAAI,YAAY,CAAC,EAAE;gBACvD,MAAM,IAAI,KAAK,CACX,UAAU,IAAI,CAAC,IAAI,sBAAsB,YAAY,CAAC,KAAK,GAAG;oBAC9D,oDAAoD,CAAC,CAAC;aAC3D;SACF;aAAM,IAAI,YAAY,CAAC,KAAK,KAAK,SAAS,EAAE;YAC3C,IAAI,KAAK,KAAK,SAAS,EAAE;gBACvB,MAAM,IAAI,KAAK,CACX,UAAU,IAAI,CAAC,IAAI,sBAAsB,YAAY,CAAC,KAAK,GAAG;oBAC9D,mDAAmD,KAAK,GAAG,CAAC,CAAC;aAClE;SACF;aAAM;YACL,MAAM,IAAI,KAAK,CACX,UAAU,IAAI,CAAC,IAAI,eAAe;gBAClC,sBAAsB,YAAY,CAAC,KAAK,IAAI;gBAC5C,qCAAqC;gBACrC,mCAAmC,CAAC,CAAC;SAC1C;QACD,MAAM,sBAAsB,GAAG,oBAAoB,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC;QACxE,MAAM,cAAc,GAAG,CAAC,YAAY,CAAC,KAAK,KAAK,OAAO,CAAC,CAAC,CAAC;YACvD,IAAI,UAAU,CAAC,UAAU,CAAC,CAAC,CAAC;YAC5B,IAAI,WAAW,CAAC,UAAU,CAAC,CAAC;QAC9B,IAAI,KAAK,KAAK,SAAS,EAAE;YACvB,IAAI,YAAY,CAAC,KAAK,KAAK,OAAO,IAAI,YAAY,CAAC,KAAK,KAAK,QAAQ,EAAE;gBACrE,MAAM,GAAG,IAAI,YAAY,CAAC,cAAc,CAAC,MAAM,CAAC,CAAC;gBACjD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,cAAc,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;oBA
C9C,MAAM,CAAC,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC;oBAC5B,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,YAAY,CAAC,KAAK,GAAG,YAAY,CAAC,GAAG,CAAC;iBACvD;aACF;iBAAM,IAAI,YAAY,CAAC,KAAK,KAAK,SAAS,EAAE;gBAC3C,+DAA+D;gBAC/D,MAAM,aAAa,GAAG,iBAAiB,EAAE,CAAC;gBAC1C,MAAM,GAAG,aAAa,CAAC,cAA6B,CAAC,CAAC;aACvD;iBAAM;gBACL,MAAM,IAAI,KAAK,CACb,iCAAiC,YAAY,CAAC,KAAK,GAAG;oBACtD,0BAA0B,CAAC,CAAC;aAC/B;SACF;aAAM,IAAI,KAAK,KAAK,OAAO,EAAE;YAC5B,IAAI,YAAY,CAAC,KAAK,KAAK,OAAO,IAAI,YAAY,CAAC,KAAK,KAAK,QAAQ,EAAE;gBACrE,MAAM,IAAI,KAAK,CACb,iCAAiC,YAAY,CAAC,KAAK,GAAG;oBACtD,wBAAwB,CAAC,CAAC;aAC7B;YACD,MAAM,GAAG,IAAI,UAAU,CAAC,cAAc,CAAC,MAAM,CAAC,CAAC;YAC/C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,cAAc,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;gBAC9C,MAAM,CAAC,GAAG,cAAc,CAAC,CAAC,CAAC,CAAC;gBAC5B,MAAM,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,YAAY,CAAC,KAAK,GAAG,YAAY,CAAC,GAAG,CAAC,CAAC;aACnE;SACF;aAAM;YACL,MAAM,IAAI,KAAK,CAAC,gCAAgC,IAAI,MAAM,KAAK,EAAE,CAAC,CAAC;SACpE;QACD,MAAM,IAAI,IAAI,GAAG,sBAAsB,CAAC;KACzC;SAAM,IAAI,KAAK,KAAK,QAAQ,EAAE;QAC7B,MAAM,IAAI,GAAG,aAAa,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;QACvC,MAAM,GAAG,EAAE,CAAC;QACZ,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,EAAE;YAC7B,MAAM,UAAU,GAAG,IAAI,WAAW,CAChC,UAAU,CAAC,KAAK,CAAC,MAAM,EAAE,MAAM,GAAG,uBAAuB,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACjE,MAAM,IAAI,uBAAuB,CAAC;YAClC,MAAM,KAAK,GAAG,IAAI,UAAU,CAC1B,UAAU,CAAC,KAAK,CAAC,MAAM,EAAE,MAAM,GAAG,UAAU,CAAC,CAAC,CAAC;YAChD,MAAuB,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC;YACrC,MAAM,IAAI,UAAU,CAAC;SACtB;KACF;SAAM;QACL,MAAM,WAAW,GAAG,oBAAoB,CAAC,KAAK,CAAC,CAAC;QAChD,IAAI,KAAK,KAAK,SAAS,EAAE;YACvB,MAAM,GAAG,IAAI,YAAY,CAAC,UAAU,CAAC,CAAC;SACvC;aAAM,IAAI,KAAK,KAAK,OAAO,EAAE;YAC5B,MAAM,GAAG,IAAI,UAAU,CAAC,UAAU,CAAC,CAAC;SACrC;aAAM,IAAI,KAAK,KAAK,MAAM,EAAE;YAC3B,MAAM,GAAG,IAAI,UAAU,CAAC,UAAU,CAAC,CAAC;SACrC;aAAM,IAAI,KAAK,KAAK,WAAW,EAAE;YAChC,MAAM,GAAG,IAAI,YAAY,CAAC,UAAU,CAAC,CAAC;YACtC,MAAM,IAAI,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YACjD,MAAM,KAAK,GAAG,IAAI,YAAY,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAClD,
KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE;gBACpC,IAAI,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;gBACxB,KAAK,CAAC,CAAC,CAAC,GAAG,MAAM,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;aAC9B;YACD,MAAM,UAAU,GAAG,MAAM,CAAC,IAAI,EAAE,KAAK,EAAE,SAAS,CAAC,CAAC;YAClD,MAAM,WAAW,GAAG,MAAM,CAAC,KAAK,EAAE,KAAK,EAAE,SAAS,CAAC,CAAC;YACpD,MAAM,aAAa,GAAG,OAAO,CAAC,UAAU,EAAE,WAAW,CAAC,CAAC;YACvD,UAAU,CAAC,OAAO,EAAE,CAAC;YACrB,WAAW,CAAC,OAAO,EAAE,CAAC;YACtB,OAAO,aAAa,CAAC;SACtB;aAAM;YACL,MAAM,IAAI,KAAK,CAAC,gCAAgC,IAAI,MAAM,KAAK,EAAE,CAAC,CAAC;SACpE;QACD,MAAM,IAAI,IAAI,GAAG,WAAW,CAAC;KAC9B;IACD,OAAO,MAAM,CAAC,MAAM,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC;AACtC,CAAC;AAED,KAAK,UAAU,YAAY,CAAC,MAAgD,EAChD,WAAwB,EACxB,MAAc;IACxC,IAAI,IAAI,GAAG,IAAI,UAAU,CAAC,WAAW,CAAC,CAAC;IAEvC,OAAO,IAAI,CAAC,UAAU,GAAG,MAAM,EAAE;QAC/B,MAAM,EAAC,IAAI,EAAE,KAAK,EAAC,GAAG,MAAM,MAAM,CAAC,IAAI,EAAE,CAAC;QAC1C,IAAI,IAAI,IAAI,KAAK,IAAI,IAAI,EAAE;YACzB,MAAM,OAAO,GAAI,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC;YAC1C,MAAM,IAAI,KAAK,CAAC,sBAAsB,OAAO,2BAA2B,CAAC,CAAC;SAC3E;QAED,6CAA6C;QAC7C,MAAM,OAAO,GAAG,IAAI,UAAU,CAAC,IAAI,CAAC,MAAM,GAAG,KAAK,CAAC,UAAU,CAAC,CAAC;QAC/D,OAAO,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC;QACrB,OAAO,CAAC,GAAG,CAAC,IAAI,UAAU,CAAC,KAAK,CAAC,EAAE,IAAI,CAAC,MAAM,CAAC,CAAC;QAChD,IAAI,GAAG,OAAO,CAAC;KAChB;IAED,OAAO,IAAI,CAAC,MAAM,CAAC;AACrB,CAAC;AAED,MAAM,CAAC,KAAK,UAAU,mBAAmB,CACvC,YAAyC,EACzC,KAA6B;IAE7B,MAAM,OAAO,GAAmB,EAAE,CAAC;IACnC,MAAM,MAAM,GAAG,YAAY,CAAC,SAAS,EAAE,CAAC;IACxC,IAAI,IAAI,GAAG,IAAI,WAAW,CAAC,CAAC,CAAC,CAAC;IAE9B,KAAK,MAAM,IAAI,IAAI,KAAK,EAAE;QACxB,MAAM,UAAU,GAAG,MAAM,wBAAwB,CAAC,IAAI,EACJ,KAAK,EAAE,KAAK,EAAE,GAAG,EAAE,EAAE;YACrE,IAAI,GAAG,MAAM,YAAY,CAAC,MAAM,EAAE,IAAI,EAAE,GAAG,CAAC,CAAC;YAC7C,OAAO,IAAI,CAAC,KAAK,CAAC,KAAK,EAAE,GAAG,CAAC,CAAC;QAChC,CAAC,CAAC,CAAC;QACH,IAAI,GAAG,MAAM,YAAY,CAAC,MAAM,EAAE,IAAI,EAAE,UAAU,CAAC,CAAC;QAEpD,uBAAuB;QACvB,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC;QAC7C,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QAE9B,MAA
M,YAAY,GAAG,YAAY,CAAC,IAAI,EAAE,UAAU,CAAC,CAAC;QACpD,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,YAAY,CAAC;QAElC,wDAAwD;QACxD,sDAAsD;QACtD,IAAI,UAAU,EAAE,KAAK,QAAQ,EAAE;YAC7B,MAAM,CAAC,GAAG,OAAO,EAAE,CAAC;YAEpB,IAAI,aAAa,IAAI,CAAC;gBACpB,aAAa,CAAC,YAAY,CAAC,KAAK,CAAC,IAAK,GAAG,EAAE;qBACxC,GAAG,CAAC,mCAAmC,CAAY,EAAE;gBACvD,CAAC,CAAC,WAAwC,CAAC,YAAY,CAAC,MAAM,CAAC,CAAC;aAClE;SACF;KACF;IAED,OAAO,OAAO,CAAC;AACjB,CAAC;AAED;;GAEG;AACH,MAAM,UAAU,sBAAsB,CAAC,EAAgB;IACrD,4CAA4C;IAC5C,IAAI,EAAE,KAAK,IAAI,EAAE;QACf,MAAM,IAAI,KAAK,CAAC,wBAAwB,IAAI,CAAC,SAAS,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;KAC/D;IAED,IAAI,eAAe,GAAG,CAAC,CAAC;IAExB,oEAAoE;IACpE,yEAAyE;IACzE,qEAAqE;IACrE,0EAA0E;IAC1E,uEAAuE;IACvE,uEAAuE;IACvE,8CAA8C;IAC9C,MAAM,YAAY,GAAiB,EAAE,CAAC;IACtC,EAAE,CAAC,OAAO,CAAC,CAAC,CAAa,EAAE,EAAE;QAC3B,eAAe,IAAI,CAAC,CAAC,UAAU,CAAC;QAChC,wBAAwB;QACxB,YAAY,CAAC,IAAI,CACb,CAAC,CAAC,UAAU,KAAK,CAAC,CAAC,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACH,IAAK,CAAC,CAAC,WAAmB,CAAC,CAAC,CAAC,CAAC,CAAC;QAC1E,IAAI,CAAC,CAAC,CAAQ,YAAY,YAAY,IAAI,CAAQ,YAAY,UAAU;YAClE,CAAQ,YAAY,UAAU,CAAC,EAAE;YACrC,MAAM,IAAI,KAAK,CAAC,mCAAmC,CAAC,CAAC,WAAW,CAAC,IAAI,EAAE,CAAC,CAAC;SAC1E;QACD,uBAAuB;IACzB,CAAC,CAAC,CAAC;IAEH,MAAM,CAAC,GAAG,IAAI,UAAU,CAAC,eAAe,CAAC,CAAC;IAC1C,IAAI,MAAM,GAAG,CAAC,CAAC;IACf,YAAY,CAAC,OAAO,CAAC,CAAC,CAAa,EAAE,EAAE;QACrC,CAAC,CAAC,GAAG,CAAC,IAAI,UAAU,CAAC,CAAC,CAAC,MAAM,CAAC,EAAE,MAAM,CAAC,CAAC;QACxC,MAAM,IAAI,CAAC,CAAC,UAAU,CAAC;IACzB,CAAC,CAAC,CAAC;IAEH,OAAO,CAAC,CAAC,MAAM,CAAC;AAClB,CAAC;AAED,kDAAkD;AAClD,MAAM,aAAa,GAAG,OAAO,MAAM,KAAK,WAAW;IAC/C,CAAC,OAAO,IAAI,KAAK,WAAW,IAAI,OAAO,IAAI,KAAK,WAAW;QAC1D,OAAO,IAAI,KAAK,WAAW,CAAC,CAAC;AAElC;;;;;;;;GAQG;AACH,MAAM,UAAU,gBAAgB,CAAC,GAAW;IAC1C,IAAI,aAAa,EAAE;QACjB,OAAO,MAAM,CAAC,UAAU,CAAC,GAAG,EAAE,MAAM,CAAC,CAAC;KACvC;IACD,OAAO,IAAI,IAAI,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,IAAI,CAAC;AAC9B,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,yBAAyB,CAAC,MAAmB;IAC3D,IAAI,aAAa,EAAE;QACjB,OAAO,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,QAAQ,CAAC,QAAQ,CAAC,CAAC;KAC/C;IACD,MAAM,GAAG,GAAG,IAA
I,UAAU,CAAC,MAAM,CAAC,CAAC;IACnC,IAAI,CAAC,GAAG,EAAE,CAAC;IACX,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,CAAC,MAAM,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE;QAC1C,CAAC,IAAI,MAAM,CAAC,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;KAClC;IACD,OAAO,IAAI,CAAC,CAAC,CAAC,CAAC;AACjB,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,yBAAyB,CAAC,GAAW;IACnD,IAAI,aAAa,EAAE;QACjB,MAAM,GAAG,GAAG,MAAM,CAAC,IAAI,CAAC,GAAG,EAAE,QAAQ,CAAC,CAAC;QACvC,OAAO,GAAG,CAAC,MAAM,CAAC,KAAK,CAAC,GAAG,CAAC,UAAU,EAAE,GAAG,CAAC,UAAU,GAAG,GAAG,CAAC,UAAU,CAAC,CAAC;KAC1E;IACD,MAAM,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;IACpB,MAAM,MAAM,GAAG,IAAI,UAAU,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;IACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,MAAM,EAAE,EAAE,CAAC,EAAE;QACjC,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;KAClC;IACD,OAAO,MAAM,CAAC,MAAM,CAAC;AACvB,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,uBAAuB,CAAC,OACrB;IACjB,OAAO,oBAAoB,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;AAC5C,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,QAAQ,CAAC,IAAY;IACnC,MAAM,SAAS,GAAG,GAAG,CAAC;IACtB,IAAI,GAAG,IAAI,CAAC,IAAI,EAAE,CAAC;IACnB,OAAO,IAAI,CAAC,QAAQ,CAAC,SAAS,CAAC,EAAE;QAC/B,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;KACvC;IACD,MAAM,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,SAAS,CAAC,CAAC;IACpC,OAAO,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;AACjC,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,UAAU,6BAA6B,CACzC,SAAyB,EAAE,QAA+B;IAC5D,MAAM,MAAM,GAAc;QACxB,aAAa,EAAE,SAAS,CAAC,aAAa;QACtC,MAAM,EAAE,SAAS,CAAC,MAAM;QACxB,WAAW,EAAE,SAAS,CAAC,WAAW;QAClC,WAAW,EAAE,SAAS,CAAC,WAAW;QAClC,eAAe,EAAE,QAAQ;KAC1B,CAAC;IACF,IAAI,SAAS,CAAC,SAAS,IAAI,IAAI,EAAE;QAC/B,MAAM,CAAC,SAAS,GAAG,SAAS,CAAC,SAAS,CAAC;KACxC;IACD,IAAI,SAAS,CAAC,mBAAmB,IAAI,IAAI,EAAE;QACzC,MAAM,CAAC,mBAAmB,GAAG,SAAS,CAAC,mBAAmB,CAAC;KAC5D;IACD,IAAI,SAAS,CAAC,gBAAgB,IAAI,IAAI,EAAE;QACtC,MAAM,CAAC,gBAAgB,GAAG,SAAS,CAAC,gBAAgB,CAAC;KACtD;IACD,IAAI,SAAS,CAAC,oBAAoB,IAAI,IAAI,EAAE;QAC1C,MAAM,CAAC,oBAAoB,GAAG,SAAS,CAAC,oBAAoB,CAAC;KAC9D;IACD,IAAI,SAAS,CAAC,cAAc,IAAI,IAAI,EAAE;QACpC,MAAM,CA
AC,cAAc,GAAG,SAAS,CAAC,cAAc,CAAC;KAClD;IACD,OAAO,MAAM,CAAC;AAChB,CAAC;AAED;;;;;;;;;;GAUG;AACH,MAAM,UAAU,4BAA4B,CACxC,SAAoB,EAAE,WAAoC,EAC1D,UAAuB;IAEzB,MAAM,cAAc,GAAmB;QACrC,aAAa,EAAE,SAAS,CAAC,aAAa;QACtC,MAAM,EAAE,SAAS,CAAC,MAAM;QACxB,WAAW,EAAE,SAAS,CAAC,WAAW;QAClC,WAAW,EAAE,SAAS,CAAC,WAAW;KACnC,CAAC;IAEF,IAAI,SAAS,CAAC,cAAc,IAAI,IAAI,EAAE;QACpC,cAAc,CAAC,cAAc,GAAG,SAAS,CAAC,cAAc,CAAC;KAC1D;IACD,IAAI,SAAS,CAAC,eAAe,IAAI,IAAI,EAAE;QACrC,IAAI,CAAC,WAAW,EAAE;YAChB,MAAM,IAAI,KAAK,CAAC,uDAAuD,CAAC,CAAC;SAC1E;QACD,IAAI,CAAC,UAAU,EAAE;YACf,MAAM,IAAI,KAAK,CAAC,sDAAsD,CAAC,CAAC;SACzE;QACD,cAAc,CAAC,WAAW,GAAG,WAAW,CAAC;QACzC,cAAc,CAAC,UAAU,GAAG,UAAU,CAAC;KACxC;IACD,IAAI,SAAS,CAAC,SAAS,IAAI,IAAI,EAAE;QAC/B,cAAc,CAAC,SAAS,GAAG,SAAS,CAAC,SAAS,CAAC;KAChD;IACD,IAAI,SAAS,CAAC,mBAAmB,IAAI,IAAI,EAAE;QACzC,cAAc,CAAC,mBAAmB,GAAG,SAAS,CAAC,mBAAmB,CAAC;KACpE;IACD,IAAI,SAAS,CAAC,gBAAgB,IAAI,IAAI,EAAE;QACtC,cAAc,CAAC,gBAAgB,GAAG,SAAS,CAAC,gBAAgB,CAAC;KAC9D;IACD,IAAI,SAAS,CAAC,oBAAoB,IAAI,IAAI,EAAE;QAC1C,cAAc,CAAC,oBAAoB,GAAG,SAAS,CAAC,oBAAoB,CAAC;KACtE;IAED,OAAO,cAAc,CAAC;AACxB,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,CAAC,KAAK,UAAU,wBAAwB,CAC1C,SAAoB,EACpB,WAEE;IACJ,IAAI,WAA+C,CAAC;IACpD,IAAI,UAAkC,CAAC;IAEvC,IAAI,SAAS,CAAC,eAAe,IAAI,IAAI,EAAE;QACrC,CAAC,WAAW,EAAE,UAAU,CAAC,GAAG,MAAM,WAAW,CAAC,SAAS,CAAC,eAAe,CAAC,CAAC;KAC1E;IAED,OAAO,4BAA4B,CAAC,SAAS,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;AAC1E,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,4BAA4B,CAAC,cAA8B;IAEzE,IAAI,cAAc,CAAC,aAAa,YAAY,WAAW,EAAE;QACvD,MAAM,IAAI,KAAK,CAAC,qDAAqD,CAAC,CAAC;KACxE;IAED,OAAO;QACL,SAAS,EAAE,IAAI,IAAI,EAAE;QACrB,iBAAiB,EAAE,MAAM;QACzB,kBAAkB,EAAE,cAAc,CAAC,aAAa,IAAI,IAAI,CAAC,CAAC;YACtD,CAAC,CAAC,CAAC;YACH,gBAAgB,CAAC,IAAI,CAAC,SAAS,CAAC,cAAc,CAAC,aAAa,CAAC,CAAC;QAClE,gBAAgB,EAAE,cAAc,CAAC,WAAW,IAAI,IAAI,CAAC,CAAC;YAClD,CAAC,CAAC,CAAC;YACH,gBAAgB,CAAC,IAAI,CAAC,SAAS,CAAC,cAAc,CAAC,WAAW,CAAC,CAAC;QAChE,eAAe,EAAE,cAAc,CAAC,UAAU,IAAI,IAAI,CAAC,CAAC;YAChD,CAAC,CAAC,CAAC;YACH,IAAI,oBAAoB,CAAC,cAAc,CAAC,UAAU,CAAC,CAAC,UAAU;KACnE,CAAC;AACJ,CAAC;A
AED;;;;;;GAMG;AACH,MAAM,UAAU,cAAc,CAAC,eAAsC;IAEnE,MAAM,WAAW,GAA2B,EAAE,CAAC;IAC/C,KAAK,MAAM,KAAK,IAAI,eAAe,EAAE;QACnC,WAAW,CAAC,IAAI,CAAC,GAAG,KAAK,CAAC,OAAO,CAAC,CAAC;KACpC;IACD,OAAO,WAAW,CAAC;AACrB,CAAC;AAED;;;;;GAKG;AACH,SAAS,0BAA0B;IACjC,MAAM,eAAe,GAAG,CAAC,CAAS,EAAU,EAAE;QAC5C,IAAI,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC;QAChB,IAAI,CAAC,GAAG,CAAC,CAAC;QAEV,OAAO,CAAC,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC,EAAE;YAC7B,CAAC,IAAI,UAAU,CAAC;YAChB,CAAC,KAAK,CAAC,CAAC;SACT;QACD,CAAC,IAAI,CAAC,UAAU,CAAC;QACjB,CAAC,IAAI,UAAU,CAAC;QAEhB,OAAO,CAAC,GAAG,CAAC,CAAC;IACf,CAAC,CAAC;IAEF,MAAM,YAAY,GAAG,IAAI,WAAW,CAAC,IAAI,CAAC,CAAC;IAE3C,YAAY,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;IACpB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,EAAE;QAC7B,YAAY,CAAC,CAAC,CAAC,GAAG,eAAe,CAAC,CAAC,CAAC,CAAC;KACtC;IACD,KAAK,IAAI,CAAC,GAAG,IAAI,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,EAAE;QAChC,YAAY,CAAC,CAAC,CAAC,GAAG,UAAU,GAAG,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;KACnD;IAED,OAAO,YAAY,CAAC;AACtB,CAAC;AAED;;;;;GAKG;AACH,SAAS,2BAA2B;IAClC,MAAM,aAAa,GAAG,IAAI,WAAW,CAAC,EAAE,CAAC,CAAC;IAE1C,aAAa,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;IACrB,aAAa,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC;IAC/B,aAAa,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC;IAC/B,aAAa,CAAC,EAAE,CAAC,GAAG,UAAU,CAAC;IAC/B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,EAAE,CAAC,EAAE,EAAE;QAC3B,aAAa,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC;KAC5B;IACD,KAAK,IAAI,CAAC,GAAG,EAAE,EAAE,CAAC,GAAG,EAAE,EAAE,CAAC,EAAE,EAAE;QAC5B,aAAa,CAAC,CAAC,CAAC,GAAG,UAAU,GAAG,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC,CAAC;KAClD;IAED,OAAO,aAAa,CAAC;AACvB,CAAC;AAED;;;;;GAKG;AACH,SAAS,yBAAyB;IAChC,MAAM,WAAW,GAAG,IAAI,WAAW,CAAC,EAAE,CAAC,CAAC;IAExC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,EAAE,CAAC,EAAE,EAAE;QAC3B,WAAW,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC;KACvB;IACD,WAAW,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC;IAErC,OAAO,WAAW,CAAC;AACrB,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,iBAAiB;IAC/B,4BAA4B;IAC5B,6DAA6D;IAE7D,sBAAsB;IACtB,MAAM,YAAY,GAAG,0BAA0B,EAAE,CAAC;IAClD,MAAM,aAAa,GAAG,2BAA2B,EAAE,CAA
C;IACpD,MAAM,WAAW,GAAG,yBAAyB,EAAE,CAAC;IAEhD,OAAO,CAAC,cAA2B,EAAE,EAAE;QACrC,MAAM,MAAM,GAAG,IAAI,WAAW,CAAC,CAAC,GAAG,cAAc,CAAC,MAAM,CAAC,CAAC;QAC1D,MAAM,gBAAgB,GAAG,IAAI,WAAW,CAAC,MAAM,CAAC,CAAC;QACjD,KAAK,IAAI,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,cAAc,CAAC,MAAM,EAAE,KAAK,EAAE,EAAE;YAC1D,MAAM,WAAW,GAAG,cAAc,CAAC,KAAK,CAAC,CAAC;YAC1C,MAAM,WAAW,GACb,YAAY,CAAC,WAAW,CAAC,WAAW,IAAI,EAAE,CAAC,GAAG,CAAC,WAAW,GAAG,KAAK,CAAC,CAAC;gBACpE,aAAa,CAAC,WAAW,IAAI,EAAE,CAAC,CAAC;YACrC,gBAAgB,CAAC,KAAK,CAAC,GAAG,WAAW,CAAC;SACvC;QACD,OAAO,IAAI,YAAY,CAAC,MAAM,CAAC,CAAC;IAClC,CAAC,CAAC;AACJ,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {complex} from '../ops/complex';\nimport {tensor} from '../ops/tensor';\nimport {NamedTensor, NamedTensorMap} from '../tensor_types';\nimport {TypedArray} from '../types';\nimport {sizeFromShape} from '../util';\n\nimport {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightData, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';\nimport {CompositeArrayBuffer} from './composite_array_buffer';\nimport {Tensor} from '../tensor';\nimport {backend} from '../globals';\nimport {DataId} from '../tensor_info';\nimport {env} from '../environment';\nimport {getBackend} from '../globals';\n\n/** Number of bytes reserved for the 
length of the string. (32bit integer). */\nconst NUM_BYTES_STRING_LENGTH = 4;\n\n/**\n * Encode a map from names to weight values as an ArrayBuffer, along with an\n * `Array` of `WeightsManifestEntry` as specification of the encoded weights.\n *\n * This function does not perform sharding.\n *\n * This function is the reverse of `decodeWeights`.\n *\n * @param tensors A map (\"dict\") from names to tensors.\n * @param group Group to which the weights belong (optional).\n * @returns A `Promise` of\n *   - A flat `ArrayBuffer` with all the binary values of the `Tensor`s\n *     concatenated.\n *   - An `Array` of `WeightManifestEntry`s, carrying information including\n *     tensor names, `dtype`s and shapes.\n * @throws Error: on unsupported tensor `dtype`.\n */\nexport async function encodeWeights(\n    tensors: NamedTensorMap|NamedTensor[], group?: WeightGroup):\n    Promise<{data: ArrayBuffer, specs: WeightsManifestEntry[]}> {\n  // TODO(adarob, cais): Support quantization.\n  const specs: WeightsManifestEntry[] = [];\n  const dataPromises: Array<Promise<TypedArray>> = [];\n\n  const names: string[] = Array.isArray(tensors) ?\n      tensors.map(tensor => tensor.name) :\n      Object.keys(tensors);\n\n  for (let i = 0; i < names.length; ++i) {\n    const name = names[i];\n    const t = Array.isArray(tensors) ? 
tensors[i].tensor : tensors[name];\n    if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&\n        t.dtype !== 'string' && t.dtype !== 'complex64') {\n      throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);\n    }\n    const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype};\n    if (t.dtype === 'string') {\n      const utf8bytes = new Promise<TypedArray>(async resolve => {\n        const vals = await t.bytes() as Uint8Array[];\n        const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) +\n            NUM_BYTES_STRING_LENGTH * vals.length;\n        const bytes = new Uint8Array(totalNumBytes);\n        let offset = 0;\n        for (let i = 0; i < vals.length; i++) {\n          const val = vals[i];\n          const bytesOfLength =\n              new Uint8Array(new Uint32Array([val.length]).buffer);\n          bytes.set(bytesOfLength, offset);\n          offset += NUM_BYTES_STRING_LENGTH;\n          bytes.set(val, offset);\n          offset += val.length;\n        }\n        resolve(bytes);\n      });\n      dataPromises.push(utf8bytes);\n    } else {\n      dataPromises.push(t.data());\n    }\n    if (group != null) {\n      spec.group = group;\n    }\n    specs.push(spec);\n  }\n\n  const tensorValues = await Promise.all(dataPromises);\n  return {data: concatenateTypedArrays(tensorValues), specs};\n}\n\n/**\n * Decode flat ArrayBuffer as weights.\n *\n * This function does not handle sharding.\n *\n * This function is the reverse of `encodeWeights`.\n *\n * @param weightData A flat ArrayBuffer or an array of ArrayBuffers carrying the\n *   binary values of the tensors concatenated in the order specified in\n *   `specs`.\n * @param specs Specifications of the names, dtypes and shapes of the tensors\n *   whose value are encoded by `buffer`.\n * @return A map from tensor name to tensor value, with the names corresponding\n *   to names in `specs`.\n * @throws Error, if any of the tensors has 
unsupported dtype.\n */\nexport function decodeWeights(\n    weightData: WeightData,\n    specs: WeightsManifestEntry[]): NamedTensorMap {\n  // TODO(adarob, cais): Support quantization.\n  const compositeBuffer = new CompositeArrayBuffer(weightData);\n  const out: NamedTensorMap = {};\n  let offset = 0;\n  for (const spec of specs) {\n    const byteLength = getWeightBytelength(spec, (start, end) => {\n      return compositeBuffer.slice(offset + start, offset + end);\n    });\n    out[spec.name] = decodeWeight(spec, compositeBuffer\n      .slice(offset, offset + byteLength));\n    offset += byteLength;\n  }\n  return out;\n}\n\nfunction getWeightBytelength(spec: WeightsManifestEntry,\n  slice: (start: number, end: number) => ArrayBuffer): number {\n\n  const size = sizeFromShape(spec.shape);\n  let bytesPerValue: number;\n  if ('quantization' in spec) {\n    const quantization = spec.quantization;\n    bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];\n  } else if (spec.dtype === 'string') {\n    // Can not statically determine string length.\n    let byteLength = 0;\n    for (let i = 0; i < size; i++) {\n      byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(\n        slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];\n    }\n    return byteLength;\n  } else {\n    bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];\n  }\n\n  return size * bytesPerValue;\n}\n\nasync function getWeightBytelengthAsync(\n  spec: WeightsManifestEntry,\n  slice: (start: number, end: number) => Promise<ArrayBuffer>\n): Promise<number> {\n\n  const size = sizeFromShape(spec.shape);\n  let bytesPerValue: number;\n  if ('quantization' in spec) {\n    const quantization = spec.quantization;\n    bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];\n  } else if (spec.dtype === 'string') {\n    // Can not statically determine string length.\n    let byteLength = 0;\n    for (let i = 0; i < size; i++) {\n      byteLength += NUM_BYTES_STRING_LENGTH + new 
Uint32Array(\n        await slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];\n    }\n    return byteLength;\n  } else {\n    bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];\n  }\n\n  return size * bytesPerValue;\n}\n\nfunction decodeWeight(\n  spec: WeightsManifestEntry,\n  byteBuffer: ArrayBuffer): Tensor {\n\n  const name = spec.name;\n  const dtype = spec.dtype;\n  const shape = spec.shape;\n  const size = sizeFromShape(shape);\n  let values: TypedArray | string[] | Uint8Array[];\n  let offset = 0;\n\n  if ('quantization' in spec) {\n    const quantization = spec.quantization;\n    if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {\n      if (!('min' in quantization && 'scale' in quantization)) {\n        throw new Error(\n            `Weight ${spec.name} with quantization ${quantization.dtype} ` +\n            `doesn't have corresponding metadata min and scale.`);\n      }\n    } else if (quantization.dtype === 'float16') {\n      if (dtype !== 'float32') {\n        throw new Error(\n            `Weight ${spec.name} is quantized with ${quantization.dtype} ` +\n            `which only supports weights of type float32 not ${dtype}.`);\n      }\n    } else {\n      throw new Error(\n          `Weight ${spec.name} has unknown ` +\n          `quantization dtype ${quantization.dtype}. 
` +\n          `Supported quantization dtypes are: ` +\n          `'uint8', 'uint16', and 'float16'.`);\n    }\n    const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];\n    const quantizedArray = (quantization.dtype === 'uint8') ?\n      new Uint8Array(byteBuffer) :\n      new Uint16Array(byteBuffer);\n    if (dtype === 'float32') {\n      if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {\n        values = new Float32Array(quantizedArray.length);\n        for (let i = 0; i < quantizedArray.length; i++) {\n          const v = quantizedArray[i];\n          values[i] = v * quantization.scale + quantization.min;\n        }\n      } else if (quantization.dtype === 'float16') {\n        // TODO: This is inefficient. Make getFloat16Decoder efficient.\n        const float16Decode = getFloat16Decoder();\n        values = float16Decode(quantizedArray as Uint16Array);\n      } else {\n        throw new Error(\n          `Unsupported quantization type ${quantization.dtype} ` +\n          `for weight type float32.`);\n      }\n    } else if (dtype === 'int32') {\n      if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {\n        throw new Error(\n          `Unsupported quantization type ${quantization.dtype} ` +\n          `for weight type int32.`);\n      }\n      values = new Int32Array(quantizedArray.length);\n      for (let i = 0; i < quantizedArray.length; i++) {\n        const v = quantizedArray[i];\n        values[i] = Math.round(v * quantization.scale + quantization.min);\n      }\n    } else {\n      throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);\n    }\n    offset += size * quantizationSizeFactor;\n  } else if (dtype === 'string') {\n    const size = sizeFromShape(spec.shape);\n    values = [];\n    for (let i = 0; i < size; i++) {\n      const byteLength = new Uint32Array(\n        byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];\n      offset += 
NUM_BYTES_STRING_LENGTH;\n      const bytes = new Uint8Array(\n        byteBuffer.slice(offset, offset + byteLength));\n      (values as Uint8Array[]).push(bytes);\n      offset += byteLength;\n    }\n  } else {\n    const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];\n    if (dtype === 'float32') {\n      values = new Float32Array(byteBuffer);\n    } else if (dtype === 'int32') {\n      values = new Int32Array(byteBuffer);\n    } else if (dtype === 'bool') {\n      values = new Uint8Array(byteBuffer);\n    } else if (dtype === 'complex64') {\n      values = new Float32Array(byteBuffer);\n      const real = new Float32Array(values.length / 2);\n      const image = new Float32Array(values.length / 2);\n      for (let i = 0; i < real.length; i++) {\n        real[i] = values[i * 2];\n        image[i] = values[i * 2 + 1];\n      }\n      const realTensor = tensor(real, shape, 'float32');\n      const imageTensor = tensor(image, shape, 'float32');\n      const complexTensor = complex(realTensor, imageTensor);\n      realTensor.dispose();\n      imageTensor.dispose();\n      return complexTensor;\n    } else {\n      throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);\n    }\n    offset += size * dtypeFactor;\n  }\n  return tensor(values, shape, dtype);\n}\n\nasync function readToLength(reader: ReadableStreamDefaultReader<ArrayBuffer>,\n                            initialData: ArrayBuffer,\n                            length: number): Promise<ArrayBuffer> {\n  let data = new Uint8Array(initialData);\n\n  while (data.byteLength < length) {\n    const {done, value} = await reader.read();\n    if (done && value == null) {\n      const missing  = length - data.byteLength;\n      throw new Error(`Reader is done but ${missing} bytes are still expected`);\n    }\n\n    // TODO: Don't create a new array every loop.\n    const newData = new Uint8Array(data.length + value.byteLength);\n    newData.set(data, 0);\n    newData.set(new Uint8Array(value), data.length);\n   
 data = newData;\n  }\n\n  return data.buffer;\n}\n\nexport async function decodeWeightsStream(\n  weightStream: ReadableStream<ArrayBuffer>,\n  specs: WeightsManifestEntry[]): Promise<NamedTensorMap> {\n\n  const tensors: NamedTensorMap = {};\n  const reader = weightStream.getReader();\n  let data = new ArrayBuffer(0);\n\n  for (const spec of specs) {\n    const byteLength = await getWeightBytelengthAsync(spec,\n                                                      async (start, end) => {\n      data = await readToLength(reader, data, end);\n      return data.slice(start, end);\n    });\n    data = await readToLength(reader, data, byteLength);\n\n    // Slice the tensor out\n    const tensorData = data.slice(0, byteLength);\n    data = data.slice(byteLength);\n\n    const weightTensor = decodeWeight(spec, tensorData);\n    tensors[spec.name] = weightTensor;\n\n    // TODO(mattsoulanille): Better way to call uploadToGPU.\n    // TODO(mattsoulanille): Make this work for webgl too.\n    if (getBackend() === 'webgpu') {\n      const b = backend();\n\n      if ('uploadToGPU' in b &&\n        sizeFromShape(weightTensor.shape) >= (env()\n          .get('WEBGPU_CPU_HANDOFF_SIZE_THRESHOLD') as number)) {\n        (b.uploadToGPU as (dataId: DataId) => void)(weightTensor.dataId);\n      }\n    }\n  }\n\n  return tensors;\n}\n\n/**\n * Concatenate TypedArrays into an ArrayBuffer.\n */\nexport function concatenateTypedArrays(xs: TypedArray[]): ArrayBuffer {\n  // TODO(adarob, cais): Support quantization.\n  if (xs === null) {\n    throw new Error(`Invalid input value: ${JSON.stringify(xs)}`);\n  }\n\n  let totalByteLength = 0;\n\n  // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer'\n  // can have a different byte length from that of the `TypedArray` itself,\n  // for example, when the `TypedArray` is created from an offset in an\n  // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match\n  // the `TypedArray` in byte length. 
If an element of `xs` does not show\n  // this property, a new `TypedArray` that satisfy this property will be\n  // constructed and pushed into `normalizedXs`.\n  const normalizedXs: TypedArray[] = [];\n  xs.forEach((x: TypedArray) => {\n    totalByteLength += x.byteLength;\n    // tslint:disable:no-any\n    normalizedXs.push(\n        x.byteLength === x.buffer.byteLength ? x :\n                                               new (x.constructor as any)(x));\n    if (!(x as any instanceof Float32Array || x as any instanceof Int32Array ||\n          x as any instanceof Uint8Array)) {\n      throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);\n    }\n    // tslint:enable:no-any\n  });\n\n  const y = new Uint8Array(totalByteLength);\n  let offset = 0;\n  normalizedXs.forEach((x: TypedArray) => {\n    y.set(new Uint8Array(x.buffer), offset);\n    offset += x.byteLength;\n  });\n\n  return y.buffer;\n}\n\n// Use Buffer on Node.js instead of Blob/atob/btoa\nconst useNodeBuffer = typeof Buffer !== 'undefined' &&\n    (typeof Blob === 'undefined' || typeof atob === 'undefined' ||\n     typeof btoa === 'undefined');\n\n/**\n * Calculate the byte length of a JavaScript string.\n *\n * Note that a JavaScript string can contain wide characters, therefore the\n * length of the string is not necessarily equal to the byte length.\n *\n * @param str Input string.\n * @returns Byte length.\n */\nexport function stringByteLength(str: string): number {\n  if (useNodeBuffer) {\n    return Buffer.byteLength(str, 'utf8');\n  }\n  return new Blob([str]).size;\n}\n\n/**\n * Encode an ArrayBuffer as a base64 encoded string.\n *\n * @param buffer `ArrayBuffer` to be converted.\n * @returns A string that base64-encodes `buffer`.\n */\nexport function arrayBufferToBase64String(buffer: ArrayBuffer): string {\n  if (useNodeBuffer) {\n    return Buffer.from(buffer).toString('base64');\n  }\n  const buf = new Uint8Array(buffer);\n  let s = '';\n  for (let i = 0, l = 
buf.length; i < l; i++) {\n    s += String.fromCharCode(buf[i]);\n  }\n  return btoa(s);\n}\n\n/**\n * Decode a base64 string as an ArrayBuffer.\n *\n * @param str Base64 string.\n * @returns Decoded `ArrayBuffer`.\n */\nexport function base64StringToArrayBuffer(str: string): ArrayBuffer {\n  if (useNodeBuffer) {\n    const buf = Buffer.from(str, 'base64');\n    return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);\n  }\n  const s = atob(str);\n  const buffer = new Uint8Array(s.length);\n  for (let i = 0; i < s.length; ++i) {\n    buffer.set([s.charCodeAt(i)], i);\n  }\n  return buffer.buffer;\n}\n\n/**\n * Concatenate a number of ArrayBuffers into one.\n *\n * @param buffers An array of ArrayBuffers to concatenate, or a single\n *     ArrayBuffer.\n * @returns Result of concatenating `buffers` in order.\n *\n * @deprecated Use tf.io.CompositeArrayBuffer.join() instead.\n */\nexport function concatenateArrayBuffers(buffers: ArrayBuffer[]\n      | ArrayBuffer): ArrayBuffer {\n  return CompositeArrayBuffer.join(buffers);\n}\n\n/**\n * Get the basename of a path.\n *\n * Behaves in a way analogous to Linux's basename command.\n *\n * @param path\n */\nexport function basename(path: string): string {\n  const SEPARATOR = '/';\n  path = path.trim();\n  while (path.endsWith(SEPARATOR)) {\n    path = path.slice(0, path.length - 1);\n  }\n  const items = path.split(SEPARATOR);\n  return items[items.length - 1];\n}\n\n/**\n * Create `ModelJSON` from `ModelArtifacts`.\n *\n * @param artifacts Model artifacts, describing the model and its weights.\n * @param manifest Weight manifest, describing where the weights of the\n *     `ModelArtifacts` are stored, and some metadata about them.\n * @returns Object representing the `model.json` file describing the model\n *     artifacts and weights\n */\nexport function getModelJSONForModelArtifacts(\n    artifacts: ModelArtifacts, manifest: WeightsManifestConfig): ModelJSON {\n  const result: ModelJSON = {\n    
modelTopology: artifacts.modelTopology,\n    format: artifacts.format,\n    generatedBy: artifacts.generatedBy,\n    convertedBy: artifacts.convertedBy,\n    weightsManifest: manifest\n  };\n  if (artifacts.signature != null) {\n    result.signature = artifacts.signature;\n  }\n  if (artifacts.userDefinedMetadata != null) {\n    result.userDefinedMetadata = artifacts.userDefinedMetadata;\n  }\n  if (artifacts.modelInitializer != null) {\n    result.modelInitializer = artifacts.modelInitializer;\n  }\n  if (artifacts.initializerSignature != null) {\n    result.initializerSignature = artifacts.initializerSignature;\n  }\n  if (artifacts.trainingConfig != null) {\n    result.trainingConfig = artifacts.trainingConfig;\n  }\n  return result;\n}\n\n/**\n * Create `ModelArtifacts` from a JSON file and weights.\n *\n * @param modelJSON Object containing the parsed JSON of `model.json`\n * @param weightSpecs The list of WeightsManifestEntry for the model. Must be\n *     passed if the modelJSON has a weightsManifest.\n * @param weightData An ArrayBuffer or array of ArrayBuffers of weight data for\n *     the model corresponding to the weights in weightSpecs. 
Must be passed if\n *     the modelJSON has a weightsManifest.\n * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.\n */\nexport function getModelArtifactsForJSONSync(\n    modelJSON: ModelJSON, weightSpecs?: WeightsManifestEntry[],\n    weightData?: WeightData): ModelArtifacts {\n\n  const modelArtifacts: ModelArtifacts = {\n    modelTopology: modelJSON.modelTopology,\n    format: modelJSON.format,\n    generatedBy: modelJSON.generatedBy,\n    convertedBy: modelJSON.convertedBy\n  };\n\n  if (modelJSON.trainingConfig != null) {\n    modelArtifacts.trainingConfig = modelJSON.trainingConfig;\n  }\n  if (modelJSON.weightsManifest != null) {\n    if (!weightSpecs) {\n      throw new Error('modelJSON has weightsManifest but weightSpecs is null');\n    }\n    if (!weightData) {\n      throw new Error('modelJSON has weightsManifest but weightData is null');\n    }\n    modelArtifacts.weightSpecs = weightSpecs;\n    modelArtifacts.weightData = weightData;\n  }\n  if (modelJSON.signature != null) {\n    modelArtifacts.signature = modelJSON.signature;\n  }\n  if (modelJSON.userDefinedMetadata != null) {\n    modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;\n  }\n  if (modelJSON.modelInitializer != null) {\n    modelArtifacts.modelInitializer = modelJSON.modelInitializer;\n  }\n  if (modelJSON.initializerSignature != null) {\n    modelArtifacts.initializerSignature = modelJSON.initializerSignature;\n  }\n\n  return modelArtifacts;\n}\n\n/**\n * Create `ModelArtifacts` from a JSON file.\n *\n * @param modelJSON Object containing the parsed JSON of `model.json`\n * @param loadWeights Function that takes the JSON file's weights manifest,\n *     reads weights from the listed path(s), and returns a Promise of the\n *     weight manifest entries along with the weights data.\n * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.\n */\nexport async function getModelArtifactsForJSON(\n    modelJSON: ModelJSON,\n    
loadWeights: (weightsManifest: WeightsManifestConfig) => Promise<[\n      /* weightSpecs */ WeightsManifestEntry[], WeightData,\n    ]>): Promise<ModelArtifacts> {\n  let weightSpecs: WeightsManifestEntry[] | undefined;\n  let weightData: WeightData | undefined;\n\n  if (modelJSON.weightsManifest != null) {\n    [weightSpecs, weightData] = await loadWeights(modelJSON.weightsManifest);\n  }\n\n  return getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData);\n}\n\n/**\n * Populate ModelArtifactsInfo fields for a model with JSON topology.\n * @param modelArtifacts\n * @returns A ModelArtifactsInfo object.\n */\nexport function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts):\n    ModelArtifactsInfo {\n  if (modelArtifacts.modelTopology instanceof ArrayBuffer) {\n    throw new Error('Expected JSON model topology, received ArrayBuffer.');\n  }\n\n  return {\n    dateSaved: new Date(),\n    modelTopologyType: 'JSON',\n    modelTopologyBytes: modelArtifacts.modelTopology == null ?\n        0 :\n        stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),\n    weightSpecsBytes: modelArtifacts.weightSpecs == null ?\n        0 :\n        stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),\n    weightDataBytes: modelArtifacts.weightData == null ?\n        0 :\n        new CompositeArrayBuffer(modelArtifacts.weightData).byteLength,\n  };\n}\n\n/**\n * Concatenate the weights stored in a WeightsManifestConfig into a list of\n * WeightsManifestEntry\n *\n * @param weightsManifest The WeightsManifestConfig to extract weights from.\n * @returns A list of WeightsManifestEntry of the weights in the weightsManifest\n */\nexport function getWeightSpecs(weightsManifest: WeightsManifestConfig):\n    WeightsManifestEntry[] {\n  const weightSpecs: WeightsManifestEntry[] = [];\n  for (const entry of weightsManifest) {\n    weightSpecs.push(...entry.weights);\n  }\n  return weightSpecs;\n}\n\n/**\n * Computes mantisa table for casting Float16 to 
Float32\n * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf\n *\n * @returns Uint32Array, 2048 mantissa lookup values.\n */\nfunction computeFloat16MantisaTable(): Uint32Array {\n  const convertMantissa = (i: number): number => {\n    let m = i << 13;\n    let e = 0;\n\n    while ((m & 0x00800000) === 0) {\n      e -= 0x00800000;\n      m <<= 1;\n    }\n    m &= ~0x00800000;\n    e += 0x38800000;\n\n    return m | e;\n  };\n\n  const mantisaTable = new Uint32Array(2048);\n\n  mantisaTable[0] = 0;\n  for (let i = 1; i < 1024; i++) {\n    mantisaTable[i] = convertMantissa(i);\n  }\n  for (let i = 1024; i < 2048; i++) {\n    mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);\n  }\n\n  return mantisaTable;\n}\n\n/**\n * Computes exponent table for casting Float16 to Float32\n * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf\n *\n * @returns Uint32Array, 64 exponent lookup values.\n */\nfunction computeFloat16ExponentTable(): Uint32Array {\n  const exponentTable = new Uint32Array(64);\n\n  exponentTable[0] = 0;\n  exponentTable[31] = 0x47800000;\n  exponentTable[32] = 0x80000000;\n  exponentTable[63] = 0xc7800000;\n  for (let i = 1; i < 31; i++) {\n    exponentTable[i] = i << 23;\n  }\n  for (let i = 33; i < 63; i++) {\n    exponentTable[i] = 0x80000000 + ((i - 32) << 23);\n  }\n\n  return exponentTable;\n}\n\n/**\n * Computes offset table for casting Float16 to Float32\n * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf\n *\n * @returns Uint32Array, 6d offset values.\n */\nfunction computeFloat16OffsetTable(): Uint32Array {\n  const offsetTable = new Uint32Array(64);\n\n  for (let i = 0; i < 64; i++) {\n    offsetTable[i] = 1024;\n  }\n  offsetTable[0] = offsetTable[32] = 0;\n\n  return offsetTable;\n}\n\n/**\n * Retrieve a Float16 decoder which will decode a ByteArray of Float16 values\n * to a Float32Array.\n *\n * @returns Function (buffer: Uint16Array) => Float32Array which decodes\n *          the Uint16Array of 
Float16 bytes to a Float32Array.\n */\nexport function getFloat16Decoder(): (buffer: Uint16Array) => Float32Array {\n  // Algorithm is based off of\n  // http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf\n\n  // Cache lookup tables\n  const mantisaTable = computeFloat16MantisaTable();\n  const exponentTable = computeFloat16ExponentTable();\n  const offsetTable = computeFloat16OffsetTable();\n\n  return (quantizedArray: Uint16Array) => {\n    const buffer = new ArrayBuffer(4 * quantizedArray.length);\n    const bufferUint32View = new Uint32Array(buffer);\n    for (let index = 0; index < quantizedArray.length; index++) {\n      const float16Bits = quantizedArray[index];\n      const float32Bits =\n          mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] +\n          exponentTable[float16Bits >> 10];\n      bufferUint32View[index] = float32Bits;\n    }\n    return new Float32Array(buffer);\n  };\n}\n"]}
|