"use strict"; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); }; var __generator = (this && this.__generator) || function (thisArg, body) { var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; function verb(n) { return function (v) { return step([n, v]); }; } function step(op) { if (f) throw new TypeError("Generator is already executing."); while (_) try { if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; if (y = 0, t) op = [op[0] & 2, t.value]; switch (op[0]) { case 0: case 1: t = op; break; case 4: _.label++; return { value: op[1], done: false }; case 5: _.label++; y = op[1]; op = [0]; continue; case 7: op = _.ops.pop(); _.trys.pop(); continue; default: if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } if (t[2]) _.ops.pop(); _.trys.pop(); continue; } op = body.call(thisArg, _); } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; } }; Object.defineProperty(exports, "__esModule", { value: true }); var environment_1 = require("../environment"); var util = require("../util"); var io_utils_1 = require("./io_utils"); var progress_1 = require("./progress"); var types_1 = require("./types"); /** * Reads binary weights data from a number of URLs. * * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls. * @param requestOptions RequestInit (options) for the HTTP requests. * @param fetchFunc Optional overriding value for the `window.fetch` function. * @param onProgress Optional, progress callback function, fired periodically * before the load is completed. * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same * length as `fetchURLs`. */ function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) { return __awaiter(this, void 0, void 0, function () { var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, _a, bufferPromises, bufferStartFraction, bufferEndFraction, buffers, _b; return __generator(this, function (_c) { switch (_c.label) { case 0: if (loadOptions == null) { loadOptions = {}; } fetchFunc = loadOptions.fetchFunc == null ? environment_1.env().platform.fetch : loadOptions.fetchFunc; requests = fetchURLs.map(function (fetchURL) { return fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }); }); fetchStartFraction = 0; fetchEndFraction = 0.5; if (!(loadOptions.onProgress == null)) return [3 /*break*/, 2]; return [4 /*yield*/, Promise.all(requests)]; case 1: _a = _c.sent(); return [3 /*break*/, 4]; case 2: return [4 /*yield*/, progress_1.monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction)]; case 3: _a = _c.sent(); _c.label = 4; case 4: responses = _a; bufferPromises = responses.map(function (response) { return response.arrayBuffer(); }); bufferStartFraction = 0.5; bufferEndFraction = 1; if (!(loadOptions.onProgress == null)) return [3 /*break*/, 6]; return [4 /*yield*/, Promise.all(bufferPromises)]; case 5: _b = _c.sent(); return [3 /*break*/, 8]; case 6: return [4 /*yield*/, progress_1.monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction)]; case 7: _b = _c.sent(); _c.label = 8; case 8: buffers = _b; return [2 /*return*/, buffers]; } }); }); } exports.loadWeightsAsArrayBuffer = loadWeightsAsArrayBuffer; /** * Reads a weights manifest JSON configuration, fetches the weights and * returns them as `Tensor`s. * * @param manifest The weights manifest JSON. * @param filePathPrefix The path prefix for filenames given in the manifest. * Defaults to the empty string. * @param weightNames The names of the weights to be fetched. */ function loadWeights(manifest, filePathPrefix, weightNames, requestInit) { if (filePathPrefix === void 0) { filePathPrefix = ''; } return __awaiter(this, void 0, void 0, function () { var fetchWeights, loadWeights; return __generator(this, function (_a) { fetchWeights = function (fetchUrls) { return loadWeightsAsArrayBuffer(fetchUrls, { requestInit: requestInit }); }; loadWeights = weightsLoaderFactory(fetchWeights); return [2 /*return*/, loadWeights(manifest, filePathPrefix, weightNames)]; }); }); } exports.loadWeights = loadWeights; /** * Creates a function, which reads a weights manifest JSON configuration, * fetches the weight files using the specified function and returns them as * `Tensor`s. * * ```js * // example for creating a nodejs weight loader, which reads the weight files * // from disk using fs.readFileSync * * import * as fs from 'fs' * * const fetchWeightsFromDisk = (filePaths: string[]) => * filePaths.map(filePath => fs.readFileSync(filePath).buffer) * * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk) * * const manifest = JSON.parse( * fs.readFileSync('./my_model-weights_manifest').toString() * ) * const weightMap = await loadWeights(manifest, './') * ``` * @param fetchWeightsFunction The function used for fetching the weight files. * @returns Weight loading function. */ function weightsLoaderFactory(fetchWeightsFunction) { var _this = this; return function (manifest, filePathPrefix, weightNames) { if (filePathPrefix === void 0) { filePathPrefix = ''; } return __awaiter(_this, void 0, void 0, function () { var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset; return __generator(this, function (_a) { switch (_a.label) { case 0: groupIndicesToFetchMap = manifest.map(function () { return false; }); groupWeightsToFetch = {}; weightsFound = weightNames != null ? weightNames.map(function () { return false; }) : []; allManifestWeightNames = []; manifest.forEach(function (manifestGroupConfig, groupIndex) { var groupOffset = 0; manifestGroupConfig.weights.forEach(function (weightsEntry) { var rawDtype = ('quantization' in weightsEntry) ? weightsEntry.quantization.dtype : weightsEntry.dtype; var weightsBytes = types_1.DTYPE_VALUE_SIZE_MAP[rawDtype] * util.sizeFromShape(weightsEntry.shape); var enqueueWeightsForFetchingFn = function () { groupIndicesToFetchMap[groupIndex] = true; if (groupWeightsToFetch[groupIndex] == null) { groupWeightsToFetch[groupIndex] = []; } groupWeightsToFetch[groupIndex].push({ manifestEntry: weightsEntry, groupOffset: groupOffset, sizeBytes: weightsBytes }); }; if (weightNames != null) { weightNames.forEach(function (weightName, weightIndex) { if (weightName === weightsEntry.name) { enqueueWeightsForFetchingFn(); weightsFound[weightIndex] = true; } }); } else { enqueueWeightsForFetchingFn(); } allManifestWeightNames.push(weightsEntry.name); groupOffset += weightsBytes; }); }); if (!weightsFound.every(function (found) { return found; })) { weightsNotFound = weightNames.filter(function (_, i) { return !weightsFound[i]; }); throw new Error("Could not find weights in manifest with names: " + (weightsNotFound.join(', ') + ". \n") + "Manifest JSON has weights with names: " + (allManifestWeightNames.join(', ') + ".")); } groupIndicesToFetch = groupIndicesToFetchMap.reduce(function (accumulator, shouldFetch, i) { if (shouldFetch) { accumulator.push(i); } return accumulator; }, []); fetchUrls = []; groupIndicesToFetch.forEach(function (i) { manifest[i].paths.forEach(function (filepath) { var fetchUrl = filePathPrefix + (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; fetchUrls.push(fetchUrl); }); }); return [4 /*yield*/, fetchWeightsFunction(fetchUrls)]; case 1: buffers = _a.sent(); weightsTensorMap = {}; bufferIndexOffset = 0; groupIndicesToFetch.forEach(function (i) { var numBuffers = manifest[i].paths.length; var groupBytes = 0; for (var i_1 = 0; i_1 < numBuffers; i_1++) { groupBytes += buffers[bufferIndexOffset + i_1].byteLength; } // Create a buffer for the whole group. var groupBuffer = new ArrayBuffer(groupBytes); var groupByteBuffer = new Uint8Array(groupBuffer); var groupBufferOffset = 0; for (var i_2 = 0; i_2 < numBuffers; i_2++) { var buffer = new Uint8Array(buffers[bufferIndexOffset + i_2]); groupByteBuffer.set(buffer, groupBufferOffset); groupBufferOffset += buffer.byteLength; } var weightsEntries = groupWeightsToFetch[i]; weightsEntries.forEach(function (weightsEntry) { var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); var nameToTensorMap = io_utils_1.decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); for (var name_1 in nameToTensorMap) { weightsTensorMap[name_1] = nameToTensorMap[name_1]; } }); bufferIndexOffset += numBuffers; }); return [2 /*return*/, weightsTensorMap]; } }); }); }; } exports.weightsLoaderFactory = weightsLoaderFactory; //# sourceMappingURL=weights_loader.js.map