/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { env } from '../environment';
import * as util from '../util';
import { CompositeArrayBuffer } from './composite_array_buffer';
import { decodeWeights } from './io_utils';
import { monitorPromisesProgress } from './progress';
import { DTYPE_VALUE_SIZE_MAP } from './types';
/**
 * Reads binary weights data from a number of URLs.
 *
 * All requests are created up front, then awaited together; the response
 * bodies are then read as `ArrayBuffer`s. When `loadOptions.onProgress` is
 * set, both phases are awaited through `monitorPromisesProgress`, which is
 * handed the fraction range 0 to 0.5 for the fetch phase and 0.5 to 1 for the
 * body-reading phase.
 *
 * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
 * @param loadOptions Optional settings object. The fields read here are
 *   `requestInit` (forwarded as the second argument of every fetch call),
 *   `fetchFunc` (used instead of `env().platform.fetch` when set) and
 *   `onProgress` (progress callback).
 * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
 *   length as `fetchURLs`.
 */
export async function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) {
    if (loadOptions == null) {
        loadOptions = {};
    }
    const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
        loadOptions.fetchFunc;
    // Create the requests for all of the weights in parallel.
    const requests = fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }));
    const fetchStartFraction = 0;
    const fetchEndFraction = 0.5;
    const responses = loadOptions.onProgress == null ?
        await Promise.all(requests) :
        await monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction);
    const bufferPromises = responses.map(response => response.arrayBuffer());
    const bufferStartFraction = 0.5;
    const bufferEndFraction = 1;
    const buffers = loadOptions.onProgress == null ?
        await Promise.all(bufferPromises) :
        await monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction);
    return buffers;
}
/**
 * Streams the bodies of several URLs, one after another, as a single
 * `ReadableStream`.
 *
 * URLs are fetched lazily and sequentially: the next URL is only requested
 * once the body of the previous one has been read to its end. Each chunk read
 * from a response body is enqueued unchanged.
 *
 * `loadOptions.onProgress`, when set, is called with 0 before the stream is
 * created and with `finishedCount / fetchURLs.length` each time a body has
 * been fully read.
 *
 * Cancelling the returned stream cancels the body currently being read and
 * stops any further URLs from being fetched.
 *
 * @param fetchURLs URLs to fetch, in the order their bytes should appear.
 * @param loadOptions Optional settings object. The fields read here are
 *   `requestInit`, `fetchFunc` (used instead of `env().platform.fetch` when
 *   set) and `onProgress`.
 * @returns A `ReadableStream` of the chunks of all response bodies.
 */
export function streamWeights(fetchURLs, loadOptions) {
    // Same leniency as `loadWeightsAsArrayBuffer`: a missing options object is
    // treated as "no options" instead of crashing on the first property read.
    if (loadOptions == null) {
        loadOptions = {};
    }
    const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :
        loadOptions.fetchFunc;
    let fetchIndex = 0;
    let chunkReader;
    let cancelled = false;
    // Invoked as a method so that `this` inside the callback stays
    // `loadOptions`.
    const reportProgress = (fraction) => {
        if (loadOptions.onProgress != null) {
            loadOptions.onProgress(fraction);
        }
    };
    // Cancels and forgets the reader of the body that is currently being
    // streamed, if there is one.
    const releaseReader = async (reason) => {
        const reader = chunkReader;
        chunkReader = undefined;
        if (reader) {
            await reader.cancel(reason);
        }
    };
    reportProgress(0);
    return new ReadableStream({
        pull: async (controller) => {
            while (!cancelled && fetchIndex < fetchURLs.length) {
                if (!chunkReader) {
                    const fetchURL = fetchURLs[fetchIndex];
                    const response = await fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true });
                    if (response.body == null) {
                        throw new Error(`Cannot stream weights from '${fetchURL}': ` +
                            `the response has no body.`);
                    }
                    chunkReader = response.body.getReader();
                    if (cancelled) {
                        // The stream was cancelled while the request was in
                        // flight; the new reader is released below.
                        break;
                    }
                }
                const { done, value } = await chunkReader.read();
                if (cancelled) {
                    break;
                }
                if (done) {
                    // This body is exhausted; move on to the next URL.
                    fetchIndex++;
                    chunkReader = undefined;
                    reportProgress(fetchIndex / fetchURLs.length);
                    continue;
                }
                controller.enqueue(value);
                return;
            }
            if (cancelled) {
                // The controller must not be touched after cancellation.
                await releaseReader();
                return;
            }
            controller.close();
        },
        cancel: async (reason) => {
            // Without this, the body being read would stay open and a `pull`
            // still in flight would go on to fetch the remaining URLs.
            cancelled = true;
            await releaseReader(reason);
        },
    });
}
/**
 * Reads a weights manifest JSON configuration, fetches the weights and
 * returns them as `Tensor`s.
 *
 * This is `weightsLoaderFactory` wired to `loadWeightsAsArrayBuffer`.
 *
 * @param manifest The weights manifest JSON.
 * @param filePathPrefix The path prefix for filenames given in the manifest.
 *     Defaults to the empty string.
 * @param weightNames The names of the weights to be fetched.
 * @param requestInit Optional options forwarded to every fetch call as
 *     `loadOptions.requestInit`.
 */
export async function loadWeights(manifest, filePathPrefix = '', weightNames, requestInit) {
    // TODO(nsthorat): Groups are currently fetched atomically. If you need a
    // single weight from a group, the whole group will be fetched. At a future
    // date, we should support fetching only the individual shards within a
    // group that are needed to reconstruct the requested weight.
    // TODO(cais): Use `decodeWeights` for implementation.
    const fetchWeights = (fetchUrls) => loadWeightsAsArrayBuffer(fetchUrls, { requestInit });
    const loader = weightsLoaderFactory(fetchWeights);
    return loader(manifest, filePathPrefix, weightNames);
}
/**
 * Creates a function, which reads a weights manifest JSON configuration,
 * fetches the weight files using the specified function and returns them as
 * `Tensor`s.
 *
 * ```js
 * // example for creating a nodejs weight loader, which reads the weight files
 * // from disk using fs.readFileSync
 *
 * import * as fs from 'fs'
 *
 * const fetchWeightsFromDisk = (filePaths: string[]) =>
 *   filePaths.map(filePath => fs.readFileSync(filePath).buffer)
 *
 * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
 *
 * const manifest = JSON.parse(
 *   fs.readFileSync('./my_model-weights_manifest').toString()
 * )
 * const weightMap = await loadWeights(manifest, './')
 * ```
 * @param fetchWeightsFunction The function used for fetching the weight files.
 *   It receives the list of paths to fetch and its awaited result is consumed
 *   as one buffer per path, in the same order.
 * @returns Weight loading function. It takes `(manifest, filePathPrefix = '',
 *   weightNames)` and resolves to a map from weight name to decoded tensor.
 *   When `weightNames` is given, only those weights are decoded, and an
 *   `Error` is thrown if any of the names is missing from the manifest.
 */
export function weightsLoaderFactory(fetchWeightsFunction) {
    return async (manifest, filePathPrefix = '', weightNames) => {
        // Collect all the groups, weights, and their relative offsets to be
        // fetched.
        const groupIndicesToFetchMap = manifest.map(() => false);
        const groupWeightsToFetch = {};
        const weightsFound = weightNames != null ? weightNames.map(() => false) : [];
        const allManifestWeightNames = [];
        manifest.forEach((manifestGroupConfig, groupIndex) => {
            // Running byte offset of the current entry within its group.
            let groupOffset = 0;
            manifestGroupConfig.weights.forEach(weightsEntry => {
                // Quantized entries are sized by their quantization dtype.
                const rawDtype = ('quantization' in weightsEntry) ?
                    weightsEntry.quantization.dtype :
                    weightsEntry.dtype;
                const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *
                    util.sizeFromShape(weightsEntry.shape);
                const enqueueWeightsForFetchingFn = () => {
                    groupIndicesToFetchMap[groupIndex] = true;
                    if (groupWeightsToFetch[groupIndex] == null) {
                        groupWeightsToFetch[groupIndex] = [];
                    }
                    groupWeightsToFetch[groupIndex].push({
                        manifestEntry: weightsEntry,
                        groupOffset,
                        sizeBytes: weightsBytes
                    });
                };
                if (weightNames != null) {
                    weightNames.forEach((weightName, weightIndex) => {
                        if (weightName === weightsEntry.name) {
                            enqueueWeightsForFetchingFn();
                            weightsFound[weightIndex] = true;
                        }
                    });
                }
                else {
                    // No filter given: every weight in the manifest is loaded.
                    enqueueWeightsForFetchingFn();
                }
                allManifestWeightNames.push(weightsEntry.name);
                // The offset advances for every entry, requested or not.
                groupOffset += weightsBytes;
            });
        });
        if (!weightsFound.every(found => found)) {
            const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]);
            throw new Error(`Could not find weights in manifest with names: ` +
                `${weightsNotFound.join(', ')}. \n` +
                `Manifest JSON has weights with names: ` +
                `${allManifestWeightNames.join(', ')}.`);
        }
        // Convert the one-hot boolean groupId => shouldFetch map to a list of group
        // IDs.
        const groupIndicesToFetch = groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => {
            if (shouldFetch) {
                accumulator.push(i);
            }
            return accumulator;
        }, []);
        const fetchUrls = [];
        groupIndicesToFetch.forEach(i => {
            manifest[i].paths.forEach(filepath => {
                // NOTE: a '/' is inserted whenever the prefix does not already
                // end with one, so the default empty prefix yields paths that
                // start with '/'.
                const fetchUrl = filePathPrefix +
                    (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;
                fetchUrls.push(fetchUrl);
            });
        });
        const buffers = await fetchWeightsFunction(fetchUrls);
        const weightsTensorMap = {};
        // Index into `buffers` of the first shard of the current group.
        let bufferIndexOffset = 0;
        groupIndicesToFetch.forEach(i => {
            const numBuffers = manifest[i].paths.length;
            // Present the shards of this group as one contiguous buffer.
            const weightsBuffer = new CompositeArrayBuffer(buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers));
            const weightsEntries = groupWeightsToFetch[i];
            weightsEntries.forEach(weightsEntry => {
                const byteBuffer = weightsBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
                const nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
                for (const name in nameToTensorMap) {
                    weightsTensorMap[name] = nameToTensorMap[name];
                }
            });
            bufferIndexOffset += numBuffers;
        });
        return weightsTensorMap;
    };
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"weights_loader.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/io/weights_loader.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,GAAG,EAAC,MAAM,gBAAgB,CAAC;AAGnC,OAAO,KAAK,IAAI,MAAM,SAAS,CAAC;AAChC,OAAO,EAAC,oBAAoB,EAAC,MAAM,0BAA0B,CAAC;AAC9D,OAAO,EAAC,aAAa,EAAC,MAAM,YAAY,CAAC;AACzC,OAAO,EAAC,uBAAuB,EAAC,MAAM,YAAY,CAAC;AACnD,OAAO,EAAC,oBAAoB,EAA2D,MAAM,SAAS,CAAC;AAEvG;;;;;;;;;;GAUG;AACH,MAAM,CAAC,KAAK,UAAU,wBAAwB,CAC5C,SAAmB,EAAE,WAAyB;IAC9C,IAAI,WAAW,IAAI,IAAI,EAAE;QACvB,WAAW,GAAG,EAAE,CAAC;KAClB;IAED,MAAM,SAAS,GAAG,WAAW,CAAC,SAAS,IAAI,IAAI,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC;QACtE,WAAW,CAAC,SAAS,CAAC;IAExB,0DAA0D;IAC1D,MAAM,QAAQ,GAAG,SAAS,CAAC,GAAG,CAC5B,QAAQ,CAAC,EAAE,CACT,SAAS,CAAC,QAAQ,EAAE,WAAW,CAAC,WAAW,EAAE,EAAE,QAAQ,EAAE,IAAI,EAAE,CAAC,CAAC,CAAC;IAEtE,MAAM,kBAAkB,GAAG,CAAC,CAAC;IAC7B,MAAM,gBAAgB,GAAG,GAAG,CAAC;IAE7B,MAAM,SAAS,GAAG,WAAW,CAAC,UAAU,IAAI,IAAI,CAAC,CAAC;QAChD,MAAM,OAAO,CAAC,GAAG,CAAC,QAAQ,CAAC,CAAC,CAAC;QAC7B,MAAM,uBAAuB,CAC3B,QAAQ,EAAE,WAAW,CAAC,UAAU,EAAE,kBAAkB,EACpD,gBAAgB,CAAC,CAAC;IAEtB,MAAM,cAAc,GAAG,SAAS,CAAC,GAAG,CAAC,QAAQ,C
AAC,EAAE,CAAC,QAAQ,CAAC,WAAW,EAAE,CAAC,CAAC;IAEzE,MAAM,mBAAmB,GAAG,GAAG,CAAC;IAChC,MAAM,iBAAiB,GAAG,CAAC,CAAC;IAE5B,MAAM,OAAO,GAAG,WAAW,CAAC,UAAU,IAAI,IAAI,CAAC,CAAC;QAC9C,MAAM,OAAO,CAAC,GAAG,CAAC,cAAc,CAAC,CAAC,CAAC;QACnC,MAAM,uBAAuB,CAC3B,cAAc,EAAE,WAAW,CAAC,UAAU,EAAE,mBAAmB,EAC3D,iBAAiB,CAAC,CAAC;IACvB,OAAO,OAAO,CAAC;AACjB,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,SAAmB,EAAE,WAAwB;;IACzE,MAAM,SAAS,GAAG,WAAW,CAAC,SAAS,IAAI,IAAI,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC;QACtE,WAAW,CAAC,SAAS,CAAC;IAExB,IAAI,UAAU,GAAG,CAAC,CAAC;IACnB,IAAI,WAAgE,CAAC;IACrE,MAAA,WAAW,CAAC,UAAU,4DAAG,CAAC,CAAC,CAAC;IAC5B,OAAO,IAAI,cAAc,CAAa;QACpC,IAAI,EAAE,KAAK,EAAE,UAAU,EAAE,EAAE;;YACzB,OAAO,UAAU,GAAG,SAAS,CAAC,MAAM,EAAE;gBACpC,IAAI,CAAC,WAAW,EAAE;oBAChB,MAAM,IAAI,GAAG,CAAC,MAAM,SAAS,CAAC,SAAS,CAAC,UAAU,CAAC,EACpB,WAAW,CAAC,WAAW,EACvB,EAAC,QAAQ,EAAE,IAAI,EAAC,CAAC,CAAC,CAAC,IAAI,CAAC;oBAEvD,WAAW,GAAG,IAAI,CAAC,SAAS,EAAE,CAAC;iBAChC;gBAED,MAAM,EAAC,IAAI,EAAE,KAAK,EAAC,GAAG,MAAM,WAAW,CAAC,IAAI,EAAE,CAAC;gBAE/C,IAAI,IAAI,EAAE;oBACR,UAAU,EAAE,CAAC;oBACb,WAAW,GAAG,SAAS,CAAC;oBACxB,MAAA,WAAW,CAAC,UAAU,4DAAG,UAAU,GAAG,SAAS,CAAC,MAAM,CAAC,CAAC;oBACxD,SAAS;iBACV;gBACD,UAAU,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;gBAC1B,OAAO;aACR;YACD,UAAU,CAAC,KAAK,EAAE,CAAC;QACrB,CAAC;KACF,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;;GAQG;AACH,MAAM,CAAC,KAAK,UAAU,WAAW,CAC/B,QAA+B,EAAE,cAAc,GAAG,EAAE,EACpD,WAAsB,EACtB,WAAyB;IACzB,yEAAyE;IACzE,2EAA2E;IAC3E,uEAAuE;IACvE,6DAA6D;IAC7D,sDAAsD;IAEtD,MAAM,YAAY,GAAG,CAAC,SAAmB,EAAE,EAAE,CAC3C,wBAAwB,CAAC,SAAS,EAAE,EAAE,WAAW,EAAE,CAAC,CAAC;IACvD,MAAM,WAAW,GAAG,oBAAoB,CAAC,YAAY,CAAC,CAAC;IAEvD,OAAO,WAAW,CAAC,QAAQ,EAAE,cAAc,EAAE,WAAW,CAAC,CAAC;AAC5D,CAAC;AAED;;;;;;;;;;;;;;;;;;;;;;;GAuBG;AACH,MAAM,UAAU,oBAAoB,CAClC,oBAAqE;IAGrE,OAAO,KAAK,EACV,QAA+B,EAAE,cAAc,GAAG,EAAE,EACpD,WAAsB,EAA2B,EAAE;QACnD,oEAAoE;QACpE,WAAW;QACX,MAAM,sBAAsB,GAAG,QAAQ,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC;QACzD,MAAM,mBAAmB,GAKrB,EAAE,CAAC;QACP,MAAM,YAAY,GAChB,WAAW,IAAI,IAAI,CAAC,CAAC,CAAC,WAAW,CAAC,GAAG,CAAC,GAAG,E
AAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;QAC1D,MAAM,sBAAsB,GAAa,EAAE,CAAC;QAC5C,QAAQ,CAAC,OAAO,CAAC,CAAC,mBAAmB,EAAE,UAAU,EAAE,EAAE;YACnD,IAAI,WAAW,GAAG,CAAC,CAAC;YACpB,mBAAmB,CAAC,OAAO,CAAC,OAAO,CAAC,YAAY,CAAC,EAAE;gBACjD,MAAM,QAAQ,GAAG,CAAC,cAAc,IAAI,YAAY,CAAC,CAAC,CAAC;oBACjD,YAAY,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC;oBACjC,YAAY,CAAC,KAAK,CAAC;gBAErB,MAAM,YAAY,GAAG,oBAAoB,CAAC,QAAQ,CAAC;oBACjD,IAAI,CAAC,aAAa,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC;gBAEzC,MAAM,2BAA2B,GAAG,GAAG,EAAE;oBACvC,sBAAsB,CAAC,UAAU,CAAC,GAAG,IAAI,CAAC;oBAC1C,IAAI,mBAAmB,CAAC,UAAU,CAAC,IAAI,IAAI,EAAE;wBAC3C,mBAAmB,CAAC,UAAU,CAAC,GAAG,EAAE,CAAC;qBACtC;oBAED,mBAAmB,CAAC,UAAU,CAAC,CAAC,IAAI,CAAC;wBACnC,aAAa,EAAE,YAAY;wBAC3B,WAAW;wBACX,SAAS,EAAE,YAAY;qBACxB,CAAC,CAAC;gBACL,CAAC,CAAC;gBAEF,IAAI,WAAW,IAAI,IAAI,EAAE;oBACvB,WAAW,CAAC,OAAO,CAAC,CAAC,UAAU,EAAE,WAAW,EAAE,EAAE;wBAC9C,IAAI,UAAU,KAAK,YAAY,CAAC,IAAI,EAAE;4BACpC,2BAA2B,EAAE,CAAC;4BAC9B,YAAY,CAAC,WAAW,CAAC,GAAG,IAAI,CAAC;yBAClC;oBACH,CAAC,CAAC,CAAC;iBACJ;qBAAM;oBACL,2BAA2B,EAAE,CAAC;iBAC/B;gBAED,sBAAsB,CAAC,IAAI,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;gBAC/C,WAAW,IAAI,YAAY,CAAC;YAC9B,CAAC,CAAC,CAAC;QACL,CAAC,CAAC,CAAC;QAEH,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,EAAE;YACvC,MAAM,eAAe,GAAG,WAAW,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,IAAI,KAAK,CACb,iDAAiD;gBACjD,GAAG,eAAe,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM;gBACnC,wCAAwC;gBACxC,GAAG,sBAAsB,CAAC,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;SAC5C;QAED,4EAA4E;QAC5E,OAAO;QACP,MAAM,mBAAmB,GACvB,sBAAsB,CAAC,MAAM,CAAC,CAAC,WAAW,EAAE,WAAW,EAAE,CAAC,EAAE,EAAE;YAC5D,IAAI,WAAW,EAAE;gBACf,WAAW,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aACrB;YACD,OAAO,WAAW,CAAC;QACrB,CAAC,EAAE,EAAE,CAAC,CAAC;QAET,MAAM,SAAS,GAAa,EAAE,CAAC;QAC/B,mBAAmB,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;YAC9B,QAAQ,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,OAAO,CAAC,QAAQ,CAAC,EAAE;gBACnC,MAAM,QAAQ,GAAG,cAAc;oBAC7B,CAAC,CAAC,cAAc,CAAC,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,GAAG,QAAQ,CAAC;gBACxD,SAAS,CAAC,IAAI,C
AAC,QAAQ,CAAC,CAAC;YAC3B,CAAC,CAAC,CAAC;QACL,CAAC,CAAC,CAAC;QACH,MAAM,OAAO,GAAG,MAAM,oBAAoB,CAAC,SAAS,CAAC,CAAC;QAEtD,MAAM,gBAAgB,GAAmB,EAAE,CAAC;QAC5C,IAAI,iBAAiB,GAAG,CAAC,CAAC;QAC1B,mBAAmB,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE;YAC9B,MAAM,UAAU,GAAG,QAAQ,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,CAAC;YAE5C,MAAM,aAAa,GAAG,IAAI,oBAAoB,CAC5C,OAAO,CAAC,KAAK,CAAC,iBAAiB,EAAE,iBAAiB,GAAG,UAAU,CAAC,CAAC,CAAC;YAEpE,MAAM,cAAc,GAAG,mBAAmB,CAAC,CAAC,CAAC,CAAC;YAE9C,cAAc,CAAC,OAAO,CAAC,YAAY,CAAC,EAAE;gBACpC,MAAM,UAAU,GAAG,aAAa,CAAC,KAAK,CACpC,YAAY,CAAC,WAAW,EACxB,YAAY,CAAC,WAAW,GAAG,YAAY,CAAC,SAAS,CAAC,CAAC;gBACrD,MAAM,eAAe,GACnB,aAAa,CAAC,UAAU,EAAE,CAAC,YAAY,CAAC,aAAa,CAAC,CAAC,CAAC;gBAC1D,KAAK,MAAM,IAAI,IAAI,eAAe,EAAE;oBAClC,gBAAgB,CAAC,IAAI,CAAC,GAAG,eAAe,CAAC,IAAI,CAAC,CAAC;iBAChD;YACH,CAAC,CAAC,CAAC;YAEH,iBAAiB,IAAI,UAAU,CAAC;QAClC,CAAC,CAAC,CAAC;QAEH,OAAO,gBAAgB,CAAC;IAC1B,CAAC,CAAC;AACJ,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {env} from '../environment';\n\nimport {NamedTensorMap} from '../tensor_types';\nimport * as util from '../util';\nimport {CompositeArrayBuffer} from './composite_array_buffer';\nimport {decodeWeights} from './io_utils';\nimport {monitorPromisesProgress} from './progress';\nimport {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, 
WeightsManifestEntry} from './types';\n\n/**\n * Reads binary weights data from a number of URLs.\n *\n * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.\n * @param requestOptions RequestInit (options) for the HTTP requests.\n * @param fetchFunc Optional overriding value for the `window.fetch` function.\n * @param onProgress Optional, progress callback function, fired periodically\n *   before the load is completed.\n * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same\n *   length as `fetchURLs`.\n */\nexport async function loadWeightsAsArrayBuffer(\n  fetchURLs: string[], loadOptions?: LoadOptions): Promise<ArrayBuffer[]> {\n  if (loadOptions == null) {\n    loadOptions = {};\n  }\n\n  const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch :\n    loadOptions.fetchFunc;\n\n  // Create the requests for all of the weights in parallel.\n  const requests = fetchURLs.map(\n    fetchURL =>\n      fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }));\n\n  const fetchStartFraction = 0;\n  const fetchEndFraction = 0.5;\n\n  const responses = loadOptions.onProgress == null ?\n    await Promise.all(requests) :\n    await monitorPromisesProgress(\n      requests, loadOptions.onProgress, fetchStartFraction,\n      fetchEndFraction);\n\n  const bufferPromises = responses.map(response => response.arrayBuffer());\n\n  const bufferStartFraction = 0.5;\n  const bufferEndFraction = 1;\n\n  const buffers = loadOptions.onProgress == null ?\n    await Promise.all(bufferPromises) :\n    await monitorPromisesProgress(\n      bufferPromises, loadOptions.onProgress, bufferStartFraction,\n      bufferEndFraction);\n  return buffers;\n}\n\nexport function streamWeights(fetchURLs: string[], loadOptions: LoadOptions): ReadableStream<ArrayBuffer> {\n  const fetchFunc = loadOptions.fetchFunc == null ? 
env().platform.fetch :\n    loadOptions.fetchFunc;\n\n  let fetchIndex = 0;\n  let chunkReader: ReadableStreamDefaultReader<Uint8Array> | undefined;\n  loadOptions.onProgress?.(0);\n  return new ReadableStream<Uint8Array>({\n    pull: async (controller) => {\n      while (fetchIndex < fetchURLs.length) {\n        if (!chunkReader) {\n          const body = (await fetchFunc(fetchURLs[fetchIndex],\n                                         loadOptions.requestInit,\n                                         {isBinary: true})).body;\n\n          chunkReader = body.getReader();\n        }\n\n        const {done, value} = await chunkReader.read();\n\n        if (done) {\n          fetchIndex++;\n          chunkReader = undefined;\n          loadOptions.onProgress?.(fetchIndex / fetchURLs.length);\n          continue;\n        }\n        controller.enqueue(value);\n        return;\n      }\n      controller.close();\n    },\n  });\n}\n\n/**\n * Reads a weights manifest JSON configuration, fetches the weights and\n * returns them as `Tensor`s.\n *\n * @param manifest The weights manifest JSON.\n * @param filePathPrefix The path prefix for filenames given in the manifest.\n *     Defaults to the empty string.\n * @param weightNames The names of the weights to be fetched.\n */\nexport async function loadWeights(\n  manifest: WeightsManifestConfig, filePathPrefix = '',\n  weightNames?: string[],\n  requestInit?: RequestInit): Promise<NamedTensorMap> {\n  // TODO(nsthorat): Groups are currently fetched atomically. If you need a\n  // single weight from a group, the whole group will be fetched. 
At a future\n  // date, we should support fetching only the individual shards within a\n  // group that are needed to reconstruct the requested weight.\n  // TODO(cais): Use `decodeWeights` for implementation.\n\n  const fetchWeights = (fetchUrls: string[]) =>\n    loadWeightsAsArrayBuffer(fetchUrls, { requestInit });\n  const loadWeights = weightsLoaderFactory(fetchWeights);\n\n  return loadWeights(manifest, filePathPrefix, weightNames);\n}\n\n/**\n * Creates a function, which reads a weights manifest JSON configuration,\n * fetches the weight files using the specified function and returns them as\n * `Tensor`s.\n *\n * ```js\n * // example for creating a nodejs weight loader, which reads the weight files\n * // from disk using fs.readFileSync\n *\n * import * as fs from 'fs'\n *\n * const fetchWeightsFromDisk = (filePaths: string[]) =>\n *   filePaths.map(filePath => fs.readFileSync(filePath).buffer)\n *\n * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)\n *\n * const manifest = JSON.parse(\n *   fs.readFileSync('./my_model-weights_manifest').toString()\n * )\n * const weightMap = await loadWeights(manifest, './')\n * ```\n * @param fetchWeightsFunction The function used for fetching the weight files.\n * @returns Weight loading function.\n */\nexport function weightsLoaderFactory(\n  fetchWeightsFunction: (fetchUrls: string[]) => Promise<ArrayBuffer[]>):\n  (manifest: WeightsManifestConfig, filePathPrefix?: string,\n    weightNames?: string[]) => Promise<NamedTensorMap> {\n  return async (\n    manifest: WeightsManifestConfig, filePathPrefix = '',\n    weightNames?: string[]): Promise<NamedTensorMap> => {\n    // Collect all the groups, weights, and their relative offsets to be\n    // fetched.\n    const groupIndicesToFetchMap = manifest.map(() => false);\n    const groupWeightsToFetch: {\n      [group: number]: Array<{\n        manifestEntry: WeightsManifestEntry; groupOffset: number;\n        sizeBytes: number;\n      }>\n    } = {};\n   
 const weightsFound =\n      weightNames != null ? weightNames.map(() => false) : [];\n    const allManifestWeightNames: string[] = [];\n    manifest.forEach((manifestGroupConfig, groupIndex) => {\n      let groupOffset = 0;\n      manifestGroupConfig.weights.forEach(weightsEntry => {\n        const rawDtype = ('quantization' in weightsEntry) ?\n          weightsEntry.quantization.dtype :\n          weightsEntry.dtype;\n\n        const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *\n          util.sizeFromShape(weightsEntry.shape);\n\n        const enqueueWeightsForFetchingFn = () => {\n          groupIndicesToFetchMap[groupIndex] = true;\n          if (groupWeightsToFetch[groupIndex] == null) {\n            groupWeightsToFetch[groupIndex] = [];\n          }\n\n          groupWeightsToFetch[groupIndex].push({\n            manifestEntry: weightsEntry,\n            groupOffset,\n            sizeBytes: weightsBytes\n          });\n        };\n\n        if (weightNames != null) {\n          weightNames.forEach((weightName, weightIndex) => {\n            if (weightName === weightsEntry.name) {\n              enqueueWeightsForFetchingFn();\n              weightsFound[weightIndex] = true;\n            }\n          });\n        } else {\n          enqueueWeightsForFetchingFn();\n        }\n\n        allManifestWeightNames.push(weightsEntry.name);\n        groupOffset += weightsBytes;\n      });\n    });\n\n    if (!weightsFound.every(found => found)) {\n      const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]);\n      throw new Error(\n        `Could not find weights in manifest with names: ` +\n        `${weightsNotFound.join(', ')}. 
\\n` +\n        `Manifest JSON has weights with names: ` +\n        `${allManifestWeightNames.join(', ')}.`);\n    }\n\n    // Convert the one-hot boolean groupId => shouldFetch map to a list of group\n    // IDs.\n    const groupIndicesToFetch =\n      groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => {\n        if (shouldFetch) {\n          accumulator.push(i);\n        }\n        return accumulator;\n      }, []);\n\n    const fetchUrls: string[] = [];\n    groupIndicesToFetch.forEach(i => {\n      manifest[i].paths.forEach(filepath => {\n        const fetchUrl = filePathPrefix +\n          (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;\n        fetchUrls.push(fetchUrl);\n      });\n    });\n    const buffers = await fetchWeightsFunction(fetchUrls);\n\n    const weightsTensorMap: NamedTensorMap = {};\n    let bufferIndexOffset = 0;\n    groupIndicesToFetch.forEach(i => {\n      const numBuffers = manifest[i].paths.length;\n\n      const weightsBuffer = new CompositeArrayBuffer(\n        buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers));\n\n      const weightsEntries = groupWeightsToFetch[i];\n\n      weightsEntries.forEach(weightsEntry => {\n        const byteBuffer = weightsBuffer.slice(\n          weightsEntry.groupOffset,\n          weightsEntry.groupOffset + weightsEntry.sizeBytes);\n        const nameToTensorMap =\n          decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);\n        for (const name in nameToTensorMap) {\n          weightsTensorMap[name] = nameToTensorMap[name];\n        }\n      });\n\n      bufferIndexOffset += numBuffers;\n    });\n\n    return weightsTensorMap;\n  };\n}\n"]}