/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { env } from '../environment';
|
import * as util from '../util';
|
import { CompositeArrayBuffer } from './composite_array_buffer';
|
import { decodeWeights } from './io_utils';
|
import { monitorPromisesProgress } from './progress';
|
import { DTYPE_VALUE_SIZE_MAP } from './types';
|
/**
|
* Reads binary weights data from a number of URLs.
|
*
|
* @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
|
* @param requestOptions RequestInit (options) for the HTTP requests.
|
* @param fetchFunc Optional overriding value for the `window.fetch` function.
|
* @param onProgress Optional, progress callback function, fired periodically
|
* before the load is completed.
|
* @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
|
* length as `fetchURLs`.
|
*/
|
export async function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) {
|
if (loadOptions == null) {
|
loadOptions = {};
|
}
|
const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
|
loadOptions.fetchFunc;
|
// Create the requests for all of the weights in parallel.
|
const requests = fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }));
|
const fetchStartFraction = 0;
|
const fetchEndFraction = 0.5;
|
const responses = loadOptions.onProgress == null ?
|
await Promise.all(requests) :
|
await monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction);
|
const bufferPromises = responses.map(response => response.arrayBuffer());
|
const bufferStartFraction = 0.5;
|
const bufferEndFraction = 1;
|
const buffers = loadOptions.onProgress == null ?
|
await Promise.all(bufferPromises) :
|
await monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction);
|
return buffers;
|
}
|
export function streamWeights(fetchURLs, loadOptions) {
|
var _a;
|
const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
|
loadOptions.fetchFunc;
|
let fetchIndex = 0;
|
let chunkReader;
|
(_a = loadOptions.onProgress) === null || _a === void 0 ? void 0 : _a.call(loadOptions, 0);
|
return new ReadableStream({
|
pull: async (controller) => {
|
var _a;
|
while (fetchIndex < fetchURLs.length) {
|
if (!chunkReader) {
|
const body = (await fetchFunc(fetchURLs[fetchIndex], loadOptions.requestInit, { isBinary: true })).body;
|
chunkReader = body.getReader();
|
}
|
const { done, value } = await chunkReader.read();
|
if (done) {
|
fetchIndex++;
|
chunkReader = undefined;
|
(_a = loadOptions.onProgress) === null || _a === void 0 ? void 0 : _a.call(loadOptions, fetchIndex / fetchURLs.length);
|
continue;
|
}
|
controller.enqueue(value);
|
return;
|
}
|
controller.close();
|
},
|
});
|
}
|
/**
 * Reads a weights manifest JSON configuration, fetches the weight files with
 * `loadWeightsAsArrayBuffer` and decodes them into `Tensor`s.
 *
 * @param manifest The weights manifest JSON, an array of weight groups.
 * @param filePathPrefix The path prefix for filenames given in the manifest.
 *     Defaults to the empty string.
 * @param weightNames Optional names of the weights to load. When omitted,
 *     every weight in the manifest is loaded.
 * @param requestInit Optional `RequestInit` forwarded to every fetch request.
 * @returns A promise of a map from weight name to decoded `Tensor`.
 */
export async function loadWeights(manifest, filePathPrefix = '', weightNames, requestInit) {
  // TODO(nsthorat): Groups are currently fetched atomically. If you need a
  // single weight from a group, the whole group will be fetched. At a future
  // date, we should support fetching only the individual shards within a
  // group that are needed to reconstruct the requested weight.
  // TODO(cais): Use `decodeWeights` for implementation.
  const loader = weightsLoaderFactory(
    (urls) => loadWeightsAsArrayBuffer(urls, { requestInit }));
  return loader(manifest, filePathPrefix, weightNames);
}
|
/**
|
* Creates a function, which reads a weights manifest JSON configuration,
|
* fetches the weight files using the specified function and returns them as
|
* `Tensor`s.
|
*
|
* ```js
|
* // example for creating a nodejs weight loader, which reads the weight files
|
* // from disk using fs.readFileSync
|
*
|
* import * as fs from 'fs'
|
*
|
* const fetchWeightsFromDisk = (filePaths: string[]) =>
|
* filePaths.map(filePath => fs.readFileSync(filePath).buffer)
|
*
|
* const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
|
*
|
* const manifest = JSON.parse(
|
* fs.readFileSync('./my_model-weights_manifest').toString()
|
* )
|
* const weightMap = await loadWeights(manifest, './')
|
* ```
|
* @param fetchWeightsFunction The function used for fetching the weight files.
|
* @returns Weight loading function.
|
*/
|
export function weightsLoaderFactory(fetchWeightsFunction) {
|
return async (manifest, filePathPrefix = '', weightNames) => {
|
// Collect all the groups, weights, and their relative offsets to be
|
// fetched.
|
const groupIndicesToFetchMap = manifest.map(() => false);
|
const groupWeightsToFetch = {};
|
const weightsFound = weightNames != null ? weightNames.map(() => false) : [];
|
const allManifestWeightNames = [];
|
manifest.forEach((manifestGroupConfig, groupIndex) => {
|
let groupOffset = 0;
|
manifestGroupConfig.weights.forEach(weightsEntry => {
|
const rawDtype = ('quantization' in weightsEntry) ?
|
weightsEntry.quantization.dtype :
|
weightsEntry.dtype;
|
const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *
|
util.sizeFromShape(weightsEntry.shape);
|
const enqueueWeightsForFetchingFn = () => {
|
groupIndicesToFetchMap[groupIndex] = true;
|
if (groupWeightsToFetch[groupIndex] == null) {
|
groupWeightsToFetch[groupIndex] = [];
|
}
|
groupWeightsToFetch[groupIndex].push({
|
manifestEntry: weightsEntry,
|
groupOffset,
|
sizeBytes: weightsBytes
|
});
|
};
|
if (weightNames != null) {
|
weightNames.forEach((weightName, weightIndex) => {
|
if (weightName === weightsEntry.name) {
|
enqueueWeightsForFetchingFn();
|
weightsFound[weightIndex] = true;
|
}
|
});
|
}
|
else {
|
enqueueWeightsForFetchingFn();
|
}
|
allManifestWeightNames.push(weightsEntry.name);
|
groupOffset += weightsBytes;
|
});
|
});
|
if (!weightsFound.every(found => found)) {
|
const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]);
|
throw new Error(`Could not find weights in manifest with names: ` +
|
`${weightsNotFound.join(', ')}. \n` +
|
`Manifest JSON has weights with names: ` +
|
`${allManifestWeightNames.join(', ')}.`);
|
}
|
// Convert the one-hot boolean groupId => shouldFetch map to a list of group
|
// IDs.
|
const groupIndicesToFetch = groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => {
|
if (shouldFetch) {
|
accumulator.push(i);
|
}
|
return accumulator;
|
}, []);
|
const fetchUrls = [];
|
groupIndicesToFetch.forEach(i => {
|
manifest[i].paths.forEach(filepath => {
|
const fetchUrl = filePathPrefix +
|
(!filePathPrefix.endsWith('/') ? '/' : '') + filepath;
|
fetchUrls.push(fetchUrl);
|
});
|
});
|
const buffers = await fetchWeightsFunction(fetchUrls);
|
const weightsTensorMap = {};
|
let bufferIndexOffset = 0;
|
groupIndicesToFetch.forEach(i => {
|
const numBuffers = manifest[i].paths.length;
|
const weightsBuffer = new CompositeArrayBuffer(buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers));
|
const weightsEntries = groupWeightsToFetch[i];
|
weightsEntries.forEach(weightsEntry => {
|
const byteBuffer = weightsBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
|
const nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
|
for (const name in nameToTensorMap) {
|
weightsTensorMap[name] = nameToTensorMap[name];
|
}
|
});
|
bufferIndexOffset += numBuffers;
|
});
|
return weightsTensorMap;
|
};
|
}
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"weights_loader.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/io/weights_loader.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,GAAG,EAAC,MAAM,gBAAgB,CAAC;AAGnC,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAChC,OAAO,EAAC,oBAAoB,EAAC,MAAM,0BAA0B,CAAC;AAC9D,OAAO,EAAC,aAAa,EAAC,MAAM,YAAY,CAAC;AACzC,OAAO,EAAC,uBAAuB,EAAC,MAAM,YAAY,CAAC;AACnD,OAAO,EAAC,oBAAoB,EAA2D,MAAM,SAAS,CAAC;AAEvG;;;;;;;;;;GAUG;AACH,MAAM,CAAC,KAAK,UAAU,wBAAwB,CAC5C,SAAmB,EAAE,WAAyB;IAC9C,IAAI,WAAW,IAAI,IAAI,EAAE;QACvB,WAAW,GAAG,EAAE,CAAC;KAClB;IAED,MAAM,SAAS,GAAG,WAAW,CAAC,SAAS,IAAI,IAAI,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC;QACtE,WAAW,CAAC,SAAS,CAAC;IAExB,0DAA0D;IAC1D,MAAM,QAAQ,GAAG,SAAS,CAAC,GAAG,CAC5B,QAAQ,CAAC,EAAE,CACT,SAAS,CAAC,QAAQ,EAAE,WAAW,CAAC,WAAW,EAAE,EAAE,QAAQ,EAAE,IAAI,EAAE,CAAC,CAAC,CAAC;IAEtE,MAAM,kBAAkB,GAAG,CAAC,CAAC;IAC7B,MAAM,gBAAgB,GAAG,GAAG,CAAC;IAE7B,MAAM,SAAS,GAAG,WAAW,CAAC,UAAU,IAAI,IAAI,CAAC,CAAC;QAChD,MAAM,OAAO,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC;QAC7B,MAAM,uBAAuB,CAC3B,QAAQ,EAAE,WAAW,CAAC,UAAU,EAAE,kBAAkB,EACpD,gBAAgB,CAAC,CAAC;IAEtB,MAAM,cAAc,GAAG,SAAS,CAAC,GAAG,CAAC,QAAQ,CAAC,EAAE,CAAC,QAAQ,CAAC,WAAW,EAAE,CAAC,CAAC;IAEzE,MAAM,mBAAmB,GAAG,GAAG,CAAC;IAChC,MAAM,iBAAiB,GAAG,CAAC,CAAC;IAE5B,MAAM,OAAO,GAAG,WAAW,CAAC,UAAU,IAAI,IAAI,CAAC,CAAC;QAC9C,MAAM,OAAO,CAAC,GAAG,CAAC,cAAc,CAAC,CAAC,CAAC;QACnC,MAAM,uBAAuB,CAC3B,cAAc,EAAE,WAAW,CAAC,UAAU,EAAE,mBAAmB,EAC3D,iBAAiB,CAAC,CAAC;IACvB,OAAO,OAAO,CAAC;AACjB,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,SAAmB,EAAE,WAAwB;;IACzE,MAAM,SAAS,GAAG,WAAW,CAAC,SAAS,IAAI,IAAI,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC;QACtE,WAAW,CAAC,SAAS,CAAC;IAExB,IAAI,UAAU,GAAG,CAAC,CAAC;IACnB,IAAI,WAAgE,CAAC;IACrE,MAAA,WAAW,CAAC,UAAU,4DAAG,CAAC,CAAC,CAAC;IAC5B,OAAO,IAAI,cAAc,CAAa;QACpC,IAAI,EAAE,KAAK,EAAE,UAAU,EAAE,EAAE;;YACzB,OAAO,UAAU,GAAG,SAAS,CAAC,MAAM,EAAE;gBACpC,IAAI,CAAC,WAAW,EAAE;oBAChB,MAAM,IAAI,GAAG,CAAC,MAAM,SAAS,CAAC,SAAS,CAAC,UAAU,CAAC,EACpB
,WAAW,CAAC,WAAW,EACvB,EAAC,QAAQ,EAAE,IAAI,EAAC,CAAC,CAAC,CAAC,IAAI,CAAC;oBAEvD,WAAW,GAAG,IAAI,CAAC,SAAS,EAAE,CAAC;iBAChC;gBAED,MAAM,EAAC,IAAI,EAAE,KAAK,EAAC,GAAG,MAAM,WAAW,CAAC,IAAI,EAAE,CAAC;gBAE/C,IAAI,IAAI,EAAE;oBACR,UAAU,EAAE,CAAC;oBACb,WAAW,GAAG,SAAS,CAAC;oBACxB,MAAA,WAAW,CAAC,UAAU,4DAAG,UAAU,GAAG,SAAS,CAAC,MAAM,CAAC,CAAC;oBACxD,SAAS;iBACV;gBACD,UAAU,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;gBAC1B,OAAO;aACR;YACD,UAAU,CAAC,KAAK,EAAE,CAAC;QACrB,CAAC;KACF,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,CAAC,KAAK,UAAU,WAAW,CAC/B,QAA+B,EAAE,cAAc,GAAG,EAAE,EACpD,WAAsB,EACtB,WAAyB;IACzB,yEAAyE;IACzE,2EAA2E;IAC3E,uEAAuE;IACvE,6DAA6D;IAC7D,sDAAsD;IAEtD,MAAM,YAAY,GAAG,CAAC,SAAmB,EAAE,EAAE,CAC3C,wBAAwB,CAAC,SAAS,EAAE,EAAE,WAAW,EAAE,CAAC,CAAC;IACvD,MAAM,WAAW,GAAG,oBAAoB,CAAC,YAAY,CAAC,CAAC;IAEvD,OAAO,WAAW,CAAC,QAAQ,EAAE,cAAc,EAAE,WAAW,CAAC,CAAC;AAC5D,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;GAuBG;AACH,MAAM,UAAU,oBAAoB,CAClC,oBAAqE;IAGrE,OAAO,KAAK,EACV,QAA+B,EAAE,cAAc,GAAG,EAAE,EACpD,WAAsB,EAA2B,EAAE;QACnD,oEAAoE;QACpE,WAAW;QACX,MAAM,sBAAsB,GAAG,QAAQ,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC;QACzD,MAAM,mBAAmB,GAKrB,EAAE,CAAC;QACP,MAAM,YAAY,GAChB,WAAW,IAAI,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;QAC1D,MAAM,sBAAsB,GAAa,EAAE,CAAC;QAC5C,QAAQ,CAAC,OAAO,CAAC,CAAC,mBAAmB,EAAE,UAAU,EAAE,EAAE;YACnD,IAAI,WAAW,GAAG,CAAC,CAAC;YACpB,mBAAmB,CAAC,OAAO,CAAC,OAAO,CAAC,YAAY,CAAC,EAAE;gBACjD,MAAM,QAAQ,GAAG,CAAC,cAAc,IAAI,YAAY,CAAC,CAAC,CAAC;oBACjD,YAAY,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC;oBACjC,YAAY,CAAC,KAAK,CAAC;gBAErB,MAAM,YAAY,GAAG,oBAAoB,CAAC,QAAQ,CAAC;oBACjD,IAAI,CAAC,aAAa,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC;gBAEzC,MAAM,2BAA2B,GAAG,GAAG,EAAE;oBACvC,sBAAsB,CAAC,UAAU,CAAC,GAAG,IAAI,CAAC;oBAC1C,IAAI,mBAAmB,CAAC,UAAU,CAAC,IAAI,IAAI,EAAE;wBAC3C,mBAAmB,CAAC,UAAU,CAAC,GAAG,EAAE,CAAC;qBACtC;oBAED,mBAAmB,CAAC,UAAU,CAAC,CAAC,IAAI,CAAC;wBACnC,aAAa,EAAE,YAAY;wBAC3B,WAAW;wBACX,SAAS,EAAE,YAAY;qBACxB,CAAC,CAAC;gBACL,CAAC,CAAC;gBAEF,IAAI,WAAW,IAAI,IAAI,EAAE;oBACvB,WAAW,CAAC,OAAO,CAA
C,CAAC,UAAU,EAAE,WAAW,EAAE,EAAE;wBAC9C,IAAI,UAAU,KAAK,YAAY,CAAC,IAAI,EAAE;4BACpC,2BAA2B,EAAE,CAAC;4BAC9B,YAAY,CAAC,WAAW,CAAC,GAAG,IAAI,CAAC;yBAClC;oBACH,CAAC,CAAC,CAAC;iBACJ;qBAAM;oBACL,2BAA2B,EAAE,CAAC;iBAC/B;gBAED,sBAAsB,CAAC,IAAI,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;gBAC/C,WAAW,IAAI,YAAY,CAAC;YAC9B,CAAC,CAAC,CAAC;QACL,CAAC,CAAC,CAAC;QAEH,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE;YACvC,MAAM,eAAe,GAAG,WAAW,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,IAAI,KAAK,CACb,iDAAiD;gBACjD,GAAG,eAAe,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM;gBACnC,wCAAwC;gBACxC,GAAG,sBAAsB,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;SAC5C;QAED,4EAA4E;QAC5E,OAAO;QACP,MAAM,mBAAmB,GACvB,sBAAsB,CAAC,MAAM,CAAC,CAAC,WAAW,EAAE,WAAW,EAAE,CAAC,EAAE,EAAE;YAC5D,IAAI,WAAW,EAAE;gBACf,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aACrB;YACD,OAAO,WAAW,CAAC;QACrB,CAAC,EAAE,EAAE,CAAC,CAAC;QAET,MAAM,SAAS,GAAa,EAAE,CAAC;QAC/B,mBAAmB,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;YAC9B,QAAQ,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,QAAQ,CAAC,EAAE;gBACnC,MAAM,QAAQ,GAAG,cAAc;oBAC7B,CAAC,CAAC,cAAc,CAAC,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,GAAG,QAAQ,CAAC;gBACxD,SAAS,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;YAC3B,CAAC,CAAC,CAAC;QACL,CAAC,CAAC,CAAC;QACH,MAAM,OAAO,GAAG,MAAM,oBAAoB,CAAC,SAAS,CAAC,CAAC;QAEtD,MAAM,gBAAgB,GAAmB,EAAE,CAAC;QAC5C,IAAI,iBAAiB,GAAG,CAAC,CAAC;QAC1B,mBAAmB,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;YAC9B,MAAM,UAAU,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,CAAC;YAE5C,MAAM,aAAa,GAAG,IAAI,oBAAoB,CAC5C,OAAO,CAAC,KAAK,CAAC,iBAAiB,EAAE,iBAAiB,GAAG,UAAU,CAAC,CAAC,CAAC;YAEpE,MAAM,cAAc,GAAG,mBAAmB,CAAC,CAAC,CAAC,CAAC;YAE9C,cAAc,CAAC,OAAO,CAAC,YAAY,CAAC,EAAE;gBACpC,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CACpC,YAAY,CAAC,WAAW,EACxB,YAAY,CAAC,WAAW,GAAG,YAAY,CAAC,SAAS,CAAC,CAAC;gBACrD,MAAM,eAAe,GACnB,aAAa,CAAC,UAAU,EAAE,CAAC,YAAY,CAAC,aAAa,CAAC,CAAC,CAAC;gBAC1D,KAAK,MAAM,IAAI,IAAI,eAAe,EAAE;oBAClC,gBAAgB,CAAC,IAAI,CAAC,GAAG,eAAe,CAAC,IAAI,CAAC,CAAC;iBAChD;YACH,CAAC,CAAC,CAAC;YAEH,iBAAiB,IAAI
,UAAU,CAAC;QAClC,CAAC,CAAC,CAAC;QAEH,OAAO,gBAAgB,CAAC;IAC1B,CAAC,CAAC;AACJ,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../environment';\n\nimport {NamedTensorMap} from '../tensor_types';\nimport * as util from '../util';\nimport {CompositeArrayBuffer} from './composite_array_buffer';\nimport {decodeWeights} from './io_utils';\nimport {monitorPromisesProgress} from './progress';\nimport {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifestEntry} from './types';\n\n/**\n * Reads binary weights data from a number of URLs.\n *\n * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.\n * @param requestOptions RequestInit (options) for the HTTP requests.\n * @param fetchFunc Optional overriding value for the `window.fetch` function.\n * @param onProgress Optional, progress callback function, fired periodically\n *   before the load is completed.\n * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same\n *   length as `fetchURLs`.\n */\nexport async function loadWeightsAsArrayBuffer(\n  fetchURLs: string[], loadOptions?: LoadOptions): Promise<ArrayBuffer[]> {\n  if (loadOptions == null) {\n    loadOptions = {};\n  }\n\n  const fetchFunc = loadOptions.fetchFunc == null ? 
env().platform.fetch :\n    loadOptions.fetchFunc;\n\n  // Create the requests for all of the weights in parallel.\n  const requests = fetchURLs.map(\n    fetchURL =>\n      fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }));\n\n  const fetchStartFraction = 0;\n  const fetchEndFraction = 0.5;\n\n  const responses = loadOptions.onProgress == null ?\n    await Promise.all(requests) :\n    await monitorPromisesProgress(\n      requests, loadOptions.onProgress, fetchStartFraction,\n      fetchEndFraction);\n\n  const bufferPromises = responses.map(response => response.arrayBuffer());\n\n  const bufferStartFraction = 0.5;\n  const bufferEndFraction = 1;\n\n  const buffers = loadOptions.onProgress == null ?\n    await Promise.all(bufferPromises) :\n    await monitorPromisesProgress(\n      bufferPromises, loadOptions.onProgress, bufferStartFraction,\n      bufferEndFraction);\n  return buffers;\n}\n\nexport function streamWeights(fetchURLs: string[], loadOptions: LoadOptions): ReadableStream<ArrayBuffer> {\n  const fetchFunc = loadOptions.fetchFunc == null ? 
env().platform.fetch :\n    loadOptions.fetchFunc;\n\n  let fetchIndex = 0;\n  let chunkReader: ReadableStreamDefaultReader<Uint8Array> | undefined;\n  loadOptions.onProgress?.(0);\n  return new ReadableStream<Uint8Array>({\n    pull: async (controller) => {\n      while (fetchIndex < fetchURLs.length) {\n        if (!chunkReader) {\n          const body = (await fetchFunc(fetchURLs[fetchIndex],\n                                         loadOptions.requestInit,\n                                         {isBinary: true})).body;\n\n          chunkReader = body.getReader();\n        }\n\n        const {done, value} = await chunkReader.read();\n\n        if (done) {\n          fetchIndex++;\n          chunkReader = undefined;\n          loadOptions.onProgress?.(fetchIndex / fetchURLs.length);\n          continue;\n        }\n        controller.enqueue(value);\n        return;\n      }\n      controller.close();\n    },\n  });\n}\n\n/**\n * Reads a weights manifest JSON configuration, fetches the weights and\n * returns them as `Tensor`s.\n *\n * @param manifest The weights manifest JSON.\n * @param filePathPrefix The path prefix for filenames given in the manifest.\n *     Defaults to the empty string.\n * @param weightNames The names of the weights to be fetched.\n */\nexport async function loadWeights(\n  manifest: WeightsManifestConfig, filePathPrefix = '',\n  weightNames?: string[],\n  requestInit?: RequestInit): Promise<NamedTensorMap> {\n  // TODO(nsthorat): Groups are currently fetched atomically. If you need a\n  // single weight from a group, the whole group will be fetched. 
At a future\n  // date, we should support fetching only the individual shards within a\n  // group that are needed to reconstruct the requested weight.\n  // TODO(cais): Use `decodeWeights` for implementation.\n\n  const fetchWeights = (fetchUrls: string[]) =>\n    loadWeightsAsArrayBuffer(fetchUrls, { requestInit });\n  const loadWeights = weightsLoaderFactory(fetchWeights);\n\n  return loadWeights(manifest, filePathPrefix, weightNames);\n}\n\n/**\n * Creates a function, which reads a weights manifest JSON configuration,\n * fetches the weight files using the specified function and returns them as\n * `Tensor`s.\n *\n * ```js\n * // example for creating a nodejs weight loader, which reads the weight files\n * // from disk using fs.readFileSync\n *\n * import * as fs from 'fs'\n *\n * const fetchWeightsFromDisk = (filePaths: string[]) =>\n *   filePaths.map(filePath => fs.readFileSync(filePath).buffer)\n *\n * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)\n *\n * const manifest = JSON.parse(\n *   fs.readFileSync('./my_model-weights_manifest').toString()\n * )\n * const weightMap = await loadWeights(manifest, './')\n * ```\n * @param fetchWeightsFunction The function used for fetching the weight files.\n * @returns Weight loading function.\n */\nexport function weightsLoaderFactory(\n  fetchWeightsFunction: (fetchUrls: string[]) => Promise<ArrayBuffer[]>):\n  (manifest: WeightsManifestConfig, filePathPrefix?: string,\n    weightNames?: string[]) => Promise<NamedTensorMap> {\n  return async (\n    manifest: WeightsManifestConfig, filePathPrefix = '',\n    weightNames?: string[]): Promise<NamedTensorMap> => {\n    // Collect all the groups, weights, and their relative offsets to be\n    // fetched.\n    const groupIndicesToFetchMap = manifest.map(() => false);\n    const groupWeightsToFetch: {\n      [group: number]: Array<{\n        manifestEntry: WeightsManifestEntry; groupOffset: number;\n        sizeBytes: number;\n      }>\n    } = {};\n   
 const weightsFound =\n      weightNames != null ? weightNames.map(() => false) : [];\n    const allManifestWeightNames: string[] = [];\n    manifest.forEach((manifestGroupConfig, groupIndex) => {\n      let groupOffset = 0;\n      manifestGroupConfig.weights.forEach(weightsEntry => {\n        const rawDtype = ('quantization' in weightsEntry) ?\n          weightsEntry.quantization.dtype :\n          weightsEntry.dtype;\n\n        const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *\n          util.sizeFromShape(weightsEntry.shape);\n\n        const enqueueWeightsForFetchingFn = () => {\n          groupIndicesToFetchMap[groupIndex] = true;\n          if (groupWeightsToFetch[groupIndex] == null) {\n            groupWeightsToFetch[groupIndex] = [];\n          }\n\n          groupWeightsToFetch[groupIndex].push({\n            manifestEntry: weightsEntry,\n            groupOffset,\n            sizeBytes: weightsBytes\n          });\n        };\n\n        if (weightNames != null) {\n          weightNames.forEach((weightName, weightIndex) => {\n            if (weightName === weightsEntry.name) {\n              enqueueWeightsForFetchingFn();\n              weightsFound[weightIndex] = true;\n            }\n          });\n        } else {\n          enqueueWeightsForFetchingFn();\n        }\n\n        allManifestWeightNames.push(weightsEntry.name);\n        groupOffset += weightsBytes;\n      });\n    });\n\n    if (!weightsFound.every(found => found)) {\n      const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]);\n      throw new Error(\n        `Could not find weights in manifest with names: ` +\n        `${weightsNotFound.join(', ')}. 
\\n` +\n        `Manifest JSON has weights with names: ` +\n        `${allManifestWeightNames.join(', ')}.`);\n    }\n\n    // Convert the one-hot boolean groupId => shouldFetch map to a list of group\n    // IDs.\n    const groupIndicesToFetch =\n      groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => {\n        if (shouldFetch) {\n          accumulator.push(i);\n        }\n        return accumulator;\n      }, []);\n\n    const fetchUrls: string[] = [];\n    groupIndicesToFetch.forEach(i => {\n      manifest[i].paths.forEach(filepath => {\n        const fetchUrl = filePathPrefix +\n          (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;\n        fetchUrls.push(fetchUrl);\n      });\n    });\n    const buffers = await fetchWeightsFunction(fetchUrls);\n\n    const weightsTensorMap: NamedTensorMap = {};\n    let bufferIndexOffset = 0;\n    groupIndicesToFetch.forEach(i => {\n      const numBuffers = manifest[i].paths.length;\n\n      const weightsBuffer = new CompositeArrayBuffer(\n        buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers));\n\n      const weightsEntries = groupWeightsToFetch[i];\n\n      weightsEntries.forEach(weightsEntry => {\n        const byteBuffer = weightsBuffer.slice(\n          weightsEntry.groupOffset,\n          weightsEntry.groupOffset + weightsEntry.sizeBytes);\n        const nameToTensorMap =\n          decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);\n        for (const name in nameToTensorMap) {\n          weightsTensorMap[name] = nameToTensorMap[name];\n        }\n      });\n\n      bufferIndexOffset += numBuffers;\n    });\n\n    return weightsTensorMap;\n  };\n}\n"]}
|