/** * @license * Copyright 2023 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ 'use strict'; var tfc = require('@tensorflow/tfjs-core'); function _interopNamespaceDefault(e) { var n = Object.create(null); if (e) { Object.keys(e).forEach(function (k) { if (k !== 'default') { var d = Object.getOwnPropertyDescriptor(e, k); Object.defineProperty(n, k, d.get ? d : { enumerable: true, get: function () { return e[k]; } }); } }); } n.default = e; return n; } function _mergeNamespaces(n, m) { m.forEach(function (e) { e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) { if (k !== 'default' && !(k in n)) { var d = Object.getOwnPropertyDescriptor(e, k); Object.defineProperty(n, k, d.get ? d : { enumerable: true, get: function () { return e[k]; } }); } }); }); return n; } var tfc__namespace = /*#__PURE__*/_interopNamespaceDefault(tfc); /****************************************************************************** Copyright (c) Microsoft Corporation. Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted. THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. ***************************************************************************** */ /* global Reflect, Promise */ var extendStatics = function (d, b) { extendStatics = Object.setPrototypeOf || ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || function (d, b) { for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p]; }; return extendStatics(d, b); }; function __extends(d, b) { if (typeof b !== "function" && b !== null) throw new TypeError("Class extends value " + String(b) + " is not a constructor or null"); extendStatics(d, b); function __() { this.constructor = d; } d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); } function __awaiter(thisArg, _arguments, P, generator) { function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } return new (P || (P = Promise))(function (resolve, reject) { function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } function step(result) { result.done ? 
resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } step((generator = generator.apply(thisArg, _arguments || [])).next()); }); } function __generator(thisArg, body) { var _ = { label: 0, sent: function () { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function () { return this; }), g; function verb(n) { return function (v) { return step([n, v]); }; } function step(op) { if (f) throw new TypeError("Generator is already executing."); while (_) try { if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; if (y = 0, t) op = [op[0] & 2, t.value]; switch (op[0]) { case 0: case 1: t = op; break; case 4: _.label++; return { value: op[1], done: false }; case 5: _.label++; y = op[1]; op = [0]; continue; case 7: op = _.ops.pop(); _.trys.pop(); continue; default: if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } if (t[2]) _.ops.pop(); _.trys.pop(); continue; } op = body.call(thisArg, _); } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; } } function __values(o) { var s = typeof Symbol === "function" && Symbol.iterator, m = s && o[s], i = 0; if (m) return m.call(o); if (o && typeof o.length === "number") return { next: function () { if (o && i >= o.length) o = void 0; return { value: o && o[i++], done: !o }; } }; throw new TypeError(s ? "Object is not iterable." : "Symbol.iterator is not defined."); } function __read(o, n) { var m = typeof Symbol === "function" && o[Symbol.iterator]; if (!m) return o; var i = m.call(o), r, ar = [], e; try { while ((n === void 0 || n-- > 0) && !(r = i.next()).done) ar.push(r.value); } catch (error) { e = { error: error }; } finally { try { if (r && !r.done && (m = i["return"])) m.call(i); } finally { if (e) throw e.error; } } return ar; } function __spreadArray(to, from, pack) { if (pack || arguments.length === 2) for (var i = 0, l = from.length, ar; i < l; i++) { if (ar || !(i in from)) { if (!ar) ar = Array.prototype.slice.call(from, 0, i); ar[i] = from[i]; } } return to.concat(ar || Array.prototype.slice.call(from)); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Explicit error types. * * See the following link for more information about why the code includes * calls to setPrototypeOf: * * https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work */ // tslint:enable /** * Equivalent of Python's AttributeError. */ var AttributeError = /** @class */ (function (_super) { __extends(AttributeError, _super); function AttributeError(message) { var _this = _super.call(this, message) || this; // Set the prototype explicitly. 
Object.setPrototypeOf(_this, AttributeError.prototype); return _this; } return AttributeError; }(Error)); /** * Equivalent of Python's RuntimeError. */ var RuntimeError = /** @class */ (function (_super) { __extends(RuntimeError, _super); function RuntimeError(message) { var _this = _super.call(this, message) || this; // Set the prototype explicitly. Object.setPrototypeOf(_this, RuntimeError.prototype); return _this; } return RuntimeError; }(Error)); /** * Equivalent of Python's ValueError. */ var ValueError = /** @class */ (function (_super) { __extends(ValueError, _super); function ValueError(message) { var _this = _super.call(this, message) || this; // Set the prototype explicitly. Object.setPrototypeOf(_this, ValueError.prototype); return _this; } return ValueError; }(Error)); /** * Equivalent of Python's NotImplementedError. */ var NotImplementedError = /** @class */ (function (_super) { __extends(NotImplementedError, _super); function NotImplementedError(message) { var _this = _super.call(this, message) || this; // Set the prototype explicitly. Object.setPrototypeOf(_this, NotImplementedError.prototype); return _this; } return NotImplementedError; }(Error)); /** * Equivalent of Python's AssertionError. */ var AssertionError = /** @class */ (function (_super) { __extends(AssertionError, _super); function AssertionError(message) { var _this = _super.call(this, message) || this; // Set the prototype explicitly. Object.setPrototypeOf(_this, AssertionError.prototype); return _this; } return AssertionError; }(Error)); /** * Equivalent of Python's IndexError. */ /** @class */ ((function (_super) { __extends(IndexError, _super); function IndexError(message) { var _this = _super.call(this, message) || this; // Set the prototype explicitly. Object.setPrototypeOf(_this, IndexError.prototype); return _this; } return IndexError; })(Error)); /** * @license * Copyright 2022 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * LruCache: A mapping from string keys to values of type T. If the number of * entries exceeds `maxEntries`, the LruCache deletes the least recently * used entry. */ var LruCache = /** @class */ (function () { function LruCache(maxEntries) { this.maxEntries = maxEntries || 100; this.cache = new Map(); } /** * Get the entry for the key and mark it as recently used. */ LruCache.prototype.get = function (key) { var entry; if (this.cache.has(key)) { entry = this.cache.get(key); this.cache.delete(key); this.cache.set(key, entry); } return entry; }; /** * Put the entry into the cache. If the key already exists, mark it as * recently used. */ LruCache.prototype.put = function (key, value) { if (this.cache.has(key)) { this.cache.delete(key); } else if (this.cache.size >= this.maxEntries) { var keyToDelete = this.cache.keys().next().value; this.cache.delete(keyToDelete); } this.cache.set(key, value); }; /** * Get the MaxEntries of the cache. */ LruCache.prototype.getMaxEntries = function () { return this.maxEntries; }; /** * Set the MaxEntries of the cache. If maxEntries is decreased, evict the * excess entries from the cache.
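 *
 * A usage sketch (hypothetical keys; assumes the class as defined above):
 *
 *   const cache = new LruCache(3);
 *   cache.put('a', 1); cache.put('b', 2); cache.put('c', 3);
 *   cache.get('a');         // marks 'a' as most recently used
 *   cache.setMaxEntries(2); // evicts 'b', the least recently used entry
 *   cache.get('b');         // undefined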
*/ LruCache.prototype.setMaxEntries = function (maxEntries) { if (maxEntries < 0) { throw new Error("The maxEntries of LRU caches must be at least 0, but got ".concat(maxEntries, ".")); } if (this.maxEntries > maxEntries) { for (var i = 0; i < this.maxEntries - maxEntries; i++) { var keyToDelete = this.cache.keys().next().value; this.cache.delete(keyToDelete); } } this.maxEntries = maxEntries; }; return LruCache; }()); // tslint:enable /** * If `value` is an Array, equivalent to Python's `value * numValues`. * If `value` is not an Array, equivalent to Python's `[value] * numValues` */ // tslint:disable-next-line:no-any function pyListRepeat(value, numValues) { if (Array.isArray(value)) { // tslint:disable-next-line:no-any var newArray = []; for (var i = 0; i < numValues; i++) { newArray = newArray.concat(value); } return newArray; } else { var newArray = new Array(numValues); newArray.fill(value); return newArray; } } function assert$1(val, message) { if (!val) { throw new AssertionError(message); } } /** * Count the number of elements of the `array` that are equal to `reference`. */ function count(array, reference) { var e_1, _a; var counter = 0; try { for (var array_1 = __values(array), array_1_1 = array_1.next(); !array_1_1.done; array_1_1 = array_1.next()) { var item = array_1_1.value; if (item === reference) { counter++; } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (array_1_1 && !array_1_1.done && (_a = array_1.return)) _a.call(array_1); } finally { if (e_1) throw e_1.error; } } return counter; } /** * If an array is of length 1, just return the first element. Otherwise, return * the full array. * @param xs */ function singletonOrArray(xs) { if (xs.length === 1) { return xs[0]; } return xs; } /** * Normalizes a list/tensor into a list. * * If a tensor is passed, we return * a list of size 1 containing the tensor. * * @param x target object to be normalized. */ // tslint:disable-next-line:no-any function toList(x) { if (Array.isArray(x)) { return x; } return [x]; } /** * Converts a string to snake-case. * @param name */ function toSnakeCase(name) { var intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2'); var insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase(); /* If the class is private the name starts with "_" which is not secure for creating scopes. We prefix the name with "private" in this case. */ if (insecure[0] !== '_') { return insecure; } return 'private' + insecure; } function toCamelCase(identifier) { // quick return for empty string or single character strings if (identifier.length <= 1) { return identifier; } // Check for the underscore indicating snake_case if (identifier.indexOf('_') === -1) { return identifier; } return identifier.replace(/[_]+(\w|$)/g, function (m, p1) { return p1.toUpperCase(); }); } // tslint:disable-next-line:no-any var _GLOBAL_CUSTOM_OBJECTS = {}; function serializeKerasObject(instance) { if (instance === null || instance === undefined) { return null; } var dict = {}; dict['className'] = instance.getClassName(); dict['config'] = instance.getConfig(); return dict; } /** * Replace ndarray-style scalar objects in serialization objects with numbers. * * Background: In some versions of tf.keras, certain scalar values in the HDF5 * model save file can be serialized as: `{'type': 'ndarray', 'value': num}`, * where `num` is a plain number. This method converts such serialization * to a `number`. * * @param config The keras-format serialization object to be processed * (in place).
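 *
 * For example (a hypothetical config), `{units: {type: 'ndarray', value: 32}}`
 * is rewritten in place to `{units: 32}`.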
*/ function convertNDArrayScalarsInConfig(config) { var e_3, _a; if (config == null || typeof config !== 'object') { return; } else if (Array.isArray(config)) { config.forEach(function (configItem) { return convertNDArrayScalarsInConfig(configItem); }); } else { var fields = Object.keys(config); try { for (var fields_1 = __values(fields), fields_1_1 = fields_1.next(); !fields_1_1.done; fields_1_1 = fields_1.next()) { var field = fields_1_1.value; var value = config[field]; if (value != null && typeof value === 'object') { if (!Array.isArray(value) && value['type'] === 'ndarray' && typeof value['value'] === 'number') { config[field] = value['value']; } else { convertNDArrayScalarsInConfig(value); } } } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (fields_1_1 && !fields_1_1.done && (_a = fields_1.return)) _a.call(fields_1); } finally { if (e_3) throw e_3.error; } } } } /** * Deserialize a saved Keras Object * @param identifier either a string ID or a saved Keras dictionary * @param moduleObjects a list of Python class names to object constructors * @param customObjects a list of Python class names to object constructors * @param printableModuleName debug text for the object being reconstituted * @param fastWeightInit Optional flag to use fast weight initialization * during deserialization. This is applicable to cases in which * the initialization will be immediately overwritten by loaded weight * values. Default: `false`. * @returns a TensorFlow.js Layers object */ // tslint:disable:no-any function deserializeKerasObject(identifier, moduleObjects, customObjects, printableModuleName, fastWeightInit) { var _a, _b, _c, e_4, _d, e_5, _e, e_6, _f, e_7, _g; if (moduleObjects === void 0) { moduleObjects = {}; } if (customObjects === void 0) { customObjects = {}; } if (printableModuleName === void 0) { printableModuleName = 'object'; } if (fastWeightInit === void 0) { fastWeightInit = false; } // tslint:enable if (typeof identifier === 'string') { var functionName = identifier; var fn = void 0; if (functionName in customObjects) { fn = customObjects[functionName]; } else if (functionName in _GLOBAL_CUSTOM_OBJECTS) { fn = _GLOBAL_CUSTOM_OBJECTS[functionName]; } else { fn = moduleObjects[functionName]; if (fn == null) { throw new ValueError("Unknown ".concat(printableModuleName, ": ").concat(identifier, ". ") + "This may be due to one of the following reasons:\n" + "1. The ".concat(printableModuleName, " is defined in Python, in which ") + "case it needs to be ported to TensorFlow.js or your JavaScript " + "code.\n" + "2. The custom ".concat(printableModuleName, " is defined in JavaScript, ") + "but is not registered properly with " + "tf.serialization.registerClass()."); // TODO(cais): Add link to tutorial page on custom layers. } } return fn; } else { // In this case we are dealing with a Keras config dictionary. 
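// Such a dictionary has the shape {className, config}, e.g. (a hypothetical
// example) {className: 'Dense', config: {units: 10, activation: 'relu'}}.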
var config = identifier; if (config['className'] == null || config['config'] == null) { throw new ValueError("".concat(printableModuleName, ": Improper config format: ") + "".concat(JSON.stringify(config), ".\n") + "'className' and 'config' must be set."); } var className = config['className']; var cls = void 0, fromConfig = void 0; if (className in customObjects) { _a = __read(customObjects[className], 2), cls = _a[0], fromConfig = _a[1]; } else if (className in _GLOBAL_CUSTOM_OBJECTS) { _b = __read(_GLOBAL_CUSTOM_OBJECTS[className], 2), cls = _b[0], fromConfig = _b[1]; } else if (className in moduleObjects) { _c = __read(moduleObjects[className], 2), cls = _c[0], fromConfig = _c[1]; } if (cls == null) { throw new ValueError("Unknown ".concat(printableModuleName, ": ").concat(className, ". ") + "This may be due to one of the following reasons:\n" + "1. The ".concat(printableModuleName, " is defined in Python, in which ") + "case it needs to be ported to TensorFlow.js or your JavaScript " + "code.\n" + "2. The custom ".concat(printableModuleName, " is defined in JavaScript, ") + "but is not registered properly with " + "tf.serialization.registerClass()."); // TODO(cais): Add link to tutorial page on custom layers. } if (fromConfig != null) { // Porting notes: Instead of checking to see whether fromConfig accepts // customObjects, we create a customObjects dictionary and tack it on to // config['config'] as config['config'].customObjects. Objects can use it, // if they want. // tslint:disable-next-line:no-any var customObjectsCombined = {}; try { for (var _h = __values(Object.keys(_GLOBAL_CUSTOM_OBJECTS)), _j = _h.next(); !_j.done; _j = _h.next()) { var key = _j.value; customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key]; } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (_j && !_j.done && (_d = _h.return)) _d.call(_h); } finally { if (e_4) throw e_4.error; } } try { for (var _k = __values(Object.keys(customObjects)), _l = _k.next(); !_l.done; _l = _k.next()) { var key = _l.value; customObjectsCombined[key] = customObjects[key]; } } catch (e_5_1) { e_5 = { error: e_5_1 }; } finally { try { if (_l && !_l.done && (_e = _k.return)) _e.call(_k); } finally { if (e_5) throw e_5.error; } } // Add the customObjects to config var nestedConfig = config['config']; nestedConfig['customObjects'] = customObjectsCombined; var backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS); try { for (var _m = __values(Object.keys(customObjects)), _o = _m.next(); !_o.done; _o = _m.next()) { var key = _o.value; _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key]; } } catch (e_6_1) { e_6 = { error: e_6_1 }; } finally { try { if (_o && !_o.done && (_f = _m.return)) _f.call(_m); } finally { if (e_6) throw e_6.error; } } convertNDArrayScalarsInConfig(config['config']); var returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit); _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects); return returnObj; } else { // Then `cls` may be a function returning a class. // In this case by convention `config` holds // the kwargs of the function.
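// Temporarily merge `customObjects` into the global registry while the
// constructor runs, then restore the saved copy afterwards.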
var backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS); try { for (var _p = __values(Object.keys(customObjects)), _q = _p.next(); !_q.done; _q = _p.next()) { var key = _q.value; _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key]; } } catch (e_7_1) { e_7 = { error: e_7_1 }; } finally { try { if (_q && !_q.done && (_g = _p.return)) _g.call(_p); } finally { if (e_7) throw e_7.error; } } // In python this is **config['config'], for tfjs-layers we require // classes that use this fall-through construction method to take // a config interface that mimics the expansion of named parameters. var returnObj = new cls(config['config']); _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects); return returnObj; } } } /** * Compares two numbers for sorting. * @param a * @param b */ function numberCompare(a, b) { return (a < b) ? -1 : ((a > b) ? 1 : 0); } /** * Comparison of two numbers for reverse sorting. * @param a * @param b */ function reverseNumberCompare(a, b) { return -1 * numberCompare(a, b); } /** * Get the unique elements of an array. * @param xs Array. * @returns An Array consisting of the unique elements in `xs`. */ function unique(xs) { var e_8, _a; if (xs == null) { return xs; } var out = []; try { // TODO(cais): Maybe improve performance by sorting. for (var xs_1 = __values(xs), xs_1_1 = xs_1.next(); !xs_1_1.done; xs_1_1 = xs_1.next()) { var x = xs_1_1.value; if (out.indexOf(x) === -1) { out.push(x); } } } catch (e_8_1) { e_8 = { error: e_8_1 }; } finally { try { if (xs_1_1 && !xs_1_1.done && (_a = xs_1.return)) _a.call(xs_1); } finally { if (e_8) throw e_8.error; } } return out; } /** * Determine if an Object is empty (i.e., does not have own properties). * @param obj Object * @returns Whether the Object is empty. * @throws ValueError: If object is `null` or `undefined`. */ function isObjectEmpty(obj) { if (obj == null) { throw new ValueError("Invalid value in obj: ".concat(JSON.stringify(obj))); } for (var key in obj) { if (obj.hasOwnProperty(key)) { return false; } } return true; } /** * Helper function used to build type union/enum run-time checkers. * @param values The list of allowed values. * @param label A string name for the type * @param value The value to test. * @throws ValueError: If the value is not in values nor `undefined`/`null`. */ function checkStringTypeUnionValue(values, label, value) { if (value == null) { return; } if (values.indexOf(value) < 0) { throw new ValueError("".concat(value, " is not a valid ").concat(label, ". Valid values are ").concat(values, " or null/undefined.")); } } /** * Helper function for verifying the types of inputs. * * Ensures that the elements of `x` are all of type `expectedType`. * Also verifies that the length of `x` is within bounds. * * @param x Object to test. * @param expectedType The string expected type of all of the elements in the * Array. * @param minLength Return false if x.length is less than this. * @param maxLength Return false if x.length is greater than this. * @returns true if and only if `x` is an `Array` with * length >= `minLength` and <= `maxLength`. 
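 *
 * For example, `checkArrayTypeAndLength([1, 2], 'number', 1, 4)` returns
 * `true`, while `checkArrayTypeAndLength([1, 'a'], 'number')` returns `false`.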
*/ // tslint:disable:no-any function checkArrayTypeAndLength(x, expectedType, minLength, maxLength) { if (minLength === void 0) { minLength = 0; } if (maxLength === void 0) { maxLength = Infinity; } assert$1(minLength >= 0); assert$1(maxLength >= minLength); return (Array.isArray(x) && x.length >= minLength && x.length <= maxLength && x.every(function (e) { return typeof e === expectedType; })); } // tslint:enable:no-any /** * Assert that a value or an array of values is a positive integer. * * @param value The value being asserted on. May be a single number or an array * of numbers. * @param name Name of the value, used to make the error message. */ function assertPositiveInteger(value, name) { if (Array.isArray(value)) { tfc.util.assert(value.length > 0, function () { return "".concat(name, " is unexpectedly an empty array."); }); value.forEach(function (v, i) { return assertPositiveInteger(v, "element ".concat(i + 1, " of ").concat(name)); }); } else { tfc.util.assert(Number.isInteger(value) && value > 0, function () { return "Expected ".concat(name, " to be a positive integer, but got ") + "".concat(formatAsFriendlyString(value), "."); }); } } /** * Format a value into a display-friendly, human-readable fashion. * * - `null` is formatted as `'null'` * - Strings are formatted with a flanking pair of quotes. * - Arrays are formatted with a flanking pair of square brackets. * * @param value The value to display. * @return Formatted string. */ // tslint:disable-next-line:no-any function formatAsFriendlyString(value) { if (value === null) { return 'null'; } else if (Array.isArray(value)) { return '[' + value.map(function (v) { return formatAsFriendlyString(v); }).join(',') + ']'; } else if (typeof value === 'string') { return "\"".concat(value, "\""); } else { return "".concat(value); } } /** * Returns a function `f2` (decorator) which wraps the original function * `f`. `f2` guarantees that `f` can be called at most once * every `waitMs` ms. If `f2` is called more often, it will return * the last returned result of `f`. * * @param f The original function `f` to wrap. * @param waitMs The time between two consecutive calls to `f` in ms. */ function debounce(f, waitMs, nowFunc) { var lastTime = nowFunc != null ? nowFunc() : tfc.util.now(); var lastResult; var f2 = function () { var args = []; for (var _i = 0; _i < arguments.length; _i++) { args[_i] = arguments[_i]; } var now = nowFunc != null ? nowFunc() : tfc.util.now(); if (now - lastTime < waitMs) { return lastResult; } lastTime = now; lastResult = f.apply(void 0, __spreadArray([], __read(args), false)); return lastResult; }; return f2; } /** * Returns the fusable activation given a layers identifier. * * @param activationName The layers identifier string. * @return The name of the fusable activation. */ function mapActivationToFusedKernel(activationName) { if (activationName === 'relu') { return 'relu'; } if (activationName === 'linear') { return 'linear'; } if (activationName === 'elu') { return 'elu'; } return null; } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Utilities related to persistent state in the backend. */ /** * An ID to track `tf.SymbolicTensor`s and derived classes. * Required in different places in engine/topology.ts to identify unique * tensors. */
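// For example, successive calls to getNextUniqueTensorId() yield 0, 1, 2, ...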
*/ var _nextUniqueTensorId = 0; function getNextUniqueTensorId() { return _nextUniqueTensorId++; } var _uidPrefixes = {}; /** * Provides a unique UID given a string prefix. * * @param prefix */ function getUid(prefix) { if (prefix === void 0) { prefix = ''; } if (!(prefix in _uidPrefixes)) { _uidPrefixes[prefix] = 0; } _uidPrefixes[prefix] += 1; return prefix + _uidPrefixes[prefix].toString(); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ var VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast']; var VALID_INTERPOLATION_FORMAT_VALUES = ['nearest', 'bilinear']; var VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal']; var VALID_POOL_MODE_VALUES = ['max', 'avg']; var VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave']; /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ // A map from the requested scoped name of a Tensor to the number of Tensors // wanting that name so far. This allows enforcing name uniqueness by appending // an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc. var nameMap = new Map(); function checkDataFormat(value) { checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value); } function checkInterpolationFormat(value) { checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value); } function checkPaddingMode(value) { checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value); } function checkPoolMode(value) { checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value); } var _nameScopeStack = []; var _nameScopeDivider = '/'; /** * Enter namescope, which can be nested. */ function nameScope(name, fn) { _nameScopeStack.push(name); try { var val = fn(); _nameScopeStack.pop(); return val; } catch (e) { _nameScopeStack.pop(); throw e; } } /** * Get the current namescope as a flat, concatenated string. */ function currentNameScopePrefix() { if (_nameScopeStack.length === 0) { return ''; } else { return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider; } } /** * Get the name a Tensor (or Variable) would have if not uniqueified. * @param tensorName * @return Scoped name string. */ function getScopedTensorName(tensorName) { if (!isValidTensorName(tensorName)) { throw new Error('Not a valid tensor name: \'' + tensorName + '\''); } return currentNameScopePrefix() + tensorName; } /** * Get unique names for Tensors and Variables. * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by * `getScopedTensorName()`. * @return A unique version of the given fully scoped name. * If this is the first time that the scoped name is seen in this session, * then the given `scopedName` is returned unaltered. If the same name is * seen again (producing a collision), an incrementing suffix is added to the * end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc. 
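 *
 * For example, three successive calls with 'scope/x' return 'scope/x',
 * 'scope/x_1' and 'scope/x_2'.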
*/ function getUniqueTensorName(scopedName) { if (!isValidTensorName(scopedName)) { throw new Error('Not a valid tensor name: \'' + scopedName + '\''); } if (!nameMap.has(scopedName)) { nameMap.set(scopedName, 0); } var index = nameMap.get(scopedName); nameMap.set(scopedName, nameMap.get(scopedName) + 1); if (index > 0) { var result = "".concat(scopedName, "_").concat(index); // Mark the composed name as used in case someone wants // to call getUniqueTensorName("name_1"). nameMap.set(result, 1); return result; } else { return scopedName; } } var tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/); /** * Determine whether a string is a valid tensor name. * @param name * @returns A Boolean indicating whether `name` is a valid tensor name. */ function isValidTensorName(name) { return !!name.match(tensorNameRegex); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Determine if a number is an integer. */ function isInteger(x) { return x === parseInt(x.toString(), 10); } /** * Calculate the product of an array of numbers. * @param array The array to calculate the product over. * @param begin Beginning index, inclusive. * @param end Ending index, exclusive. * @return The product. */ function arrayProd(array, begin, end) { if (begin == null) { begin = 0; } if (end == null) { end = array.length; } var prod = 1; for (var i = begin; i < end; ++i) { prod *= array[i]; } return prod; } /** * Compute minimum value. * @param array * @return minimum value. */ function min(array) { // same behavior as tf.min() if (array.length === 0) { return Number.NaN; } var min = Number.POSITIVE_INFINITY; for (var i = 0; i < array.length; i++) { var value = array[i]; if (value < min) { min = value; } } return min; } /** * Compute maximum value. * @param array * @return maximum value */ function max(array) { // same behavior as tf.max() if (array.length === 0) { return Number.NaN; } var max = Number.NEGATIVE_INFINITY; for (var i = 0; i < array.length; i++) { var value = array[i]; if (value > max) { max = value; } } return max; } /** * Generate an array of integers in [begin, end). * @param begin Beginning integer, inclusive. * @param end Ending integer, exclusive. * @returns Range array. * @throws ValueError, iff `end` < `begin`. */ function range(begin, end) { if (end < begin) { throw new ValueError("end (".concat(end, ") < begin (").concat(begin, ") is forbidden.")); } var out = []; for (var i = begin; i < end; ++i) { out.push(i); } return out; } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ var _epsilon; /** * Returns the value of the fuzz factor used in numeric expressions. */ function epsilon() { if (_epsilon == null) { _epsilon = tfc.backend().epsilon(); } return _epsilon; } /** * Returns the default image data format convention. */ function imageDataFormat() { return 'channelsLast'; } /** * Casts a tensor to a different dtype and returns it. * @param x Input tensor. * @param dtype String: 'float32'|'int32'|'bool'. * @returns Tensor of the specified `dtype`. 
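 *
 * For example, casting a float32 tensor holding `[1.5, 2.5]` to 'int32'
 * truncates the values to `[1, 2]`.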
*/ function cast$1(x, dtype) { return tfc__namespace.cast(x, dtype); } /** * Adds a 1-sized dimension at index "axis". * @param x Input tensor. * @param axis Position where to add the new axis. * @returns Result of the dimension expansion. */ function expandDims$1(x, axis) { if (axis === void 0) { axis = -1; } var outShape = x.shape.slice(); if (axis < 0) { axis = outShape.length + axis + 1; } outShape.splice(axis, 0, 1); return tfc__namespace.reshape(x, outShape); } /** * Repeats a 2D tensor. * * If `x` has shape `[samples, dim]` and `n` is 2, for example, the output * will have shape `[samples, 2, dim]`. * * @param x Input tensor. * @param n Integer, number of times to repeat. * @returns The result of the repeat operation. * @throws ValueError: If input tensor is not 2D. */ function repeat(x, n) { return tfc.tidy(function () { if (x.shape.length !== 2) { throw new ValueError("repeat() expects a rank-2 tensor, but received a " + "rank-".concat(x.shape.length, " tensor.")); } var y = expandDims$1(x, 1); return tile$1(y, [1, n, 1]); }); } /** * Flatten a Tensor into 1D. * @param x Input tensor. * @return The result of flattening `x`. */ function flatten$2(x) { var newShape = [arrayProd(x.shape)]; return tfc__namespace.reshape(x, newShape); } /** * Turn an nD tensor into a 2D tensor with the same 0th dimension. * In other words, it flattens each data sample of a batch. * * @param x The tensor to flatten. The rank of this tensor is required to be 2 * or higher. * @return The result of the flattening. */ function batchFlatten(x) { if (x.rank <= 1) { throw new ValueError("batchFlatten requires a minimum rank of 2. Got rank: ".concat(x.rank, ".")); } var newShape = [x.shape[0], arrayProd(x.shape, 1)]; return tfc__namespace.reshape(x, newShape); } /** * Do slicing along the first axis. * @param array input `tf.Tensor`. * @param start starting index, inclusive. * @param size size of the slice along the first axis. * @returns result of the slicing. * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`. */ function sliceAlongFirstAxis(array, start, size) { return tfc.tidy(function () { switch (array.rank) { case 1: return tfc__namespace.slice1d(array, start, size); case 2: return tfc__namespace.slice2d(array, [start, 0], [size, array.shape[1]]); case 3: return tfc__namespace.slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]); case 4: return tfc__namespace.slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]); case 5: return tfc__namespace.slice(array, [start, 0, 0, 0, 0], [ size, array.shape[1], array.shape[2], array.shape[3], array.shape[4] ]); case 6: return tfc__namespace.slice(array, [start, 0, 0, 0, 0, 0], [ size, array.shape[1], array.shape[2], array.shape[3], array.shape[4], array.shape[5] ]); default: throw new ValueError("sliceAlongFirstAxis() received an unsupported tensor rank: " + "".concat(array.rank)); } }); } /** * Do slicing along the last axis. * @param array input `tf.Tensor`. * @param start starting index, inclusive. * @param size size of the slice along the last axis. * @returns result of the slicing. * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
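 *
 * For example, for a `[2, 3]` tensor, `sliceAlongLastAxis(x, 1, 2)` returns a
 * `[2, 2]` tensor holding the last two columns.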
*/ function sliceAlongLastAxis(array, start, size) { return tfc.tidy(function () { switch (array.rank) { case 1: return tfc__namespace.slice1d(array, start, size); case 2: return tfc__namespace.slice2d(array, [0, start], [array.shape[0], size]); case 3: return tfc__namespace.slice3d(array, [0, 0, start], [array.shape[0], array.shape[1], size]); case 4: return tfc__namespace.slice4d(array, [0, 0, 0, start], [array.shape[0], array.shape[1], array.shape[2], size]); default: throw new ValueError("sliceAlongLastAxis() received an unsupported tensor rank: " + "".concat(array.rank)); } }); } /** * Do slicing along the specified axis. * @param array input `tf.Tensor`. * @param start starting index, inclusive. * @param size size of the slice along the chosen axis. * @param axis the axis along which to slice; 1-based (axis 1 is the first axis). * @returns result of the slicing. * @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`. */ function sliceAlongAxis(array, start, size, axis) { return tfc.tidy(function () { switch (array.rank) { case 1: return tfc__namespace.slice1d(array, start, size); case 2: switch (axis) { case 1: return sliceAlongFirstAxis(array, start, size); case 2: return sliceAlongLastAxis(array, start, size); default: throw new ValueError("The axis is not within the rank of the tensor " + "".concat(axis)); } case 3: switch (axis) { case 1: return sliceAlongFirstAxis(array, start, size); case 2: return tfc__namespace.slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]); case 3: return sliceAlongLastAxis(array, start, size); default: throw new ValueError("The axis is not within the rank of the tensor " + "".concat(axis)); } case 4: switch (axis) { case 1: return sliceAlongFirstAxis(array, start, size); case 2: return tfc__namespace.slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]); case 3: return tfc__namespace.slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]); case 4: return sliceAlongLastAxis(array, start, size); default: throw new ValueError("The axis is not within the rank of the tensor " + "".concat(axis)); } default: throw new ValueError("sliceAlongAxis() received an unsupported tensor rank: " + "".concat(array.rank)); } }); } /** * Concatenates a list of tensors alongside the specified axis. * @param tensors `Array` of tensors to concatenate. * @param axis Concatenation axis. * @returns The result of the concatenation. */ function concatenate$1(tensors, axis) { if (axis === void 0) { axis = -1; } var rank; if (axis < 0) { rank = tensors[0].rank; if (rank !== 0) { axis = rank; } else { axis = 0; } } if (axis === tensors[0].rank) { // Porting Note: This is necessary because tfc.concat() requires axis to be // in the interval [-rank, rank). axis = -1; } // Porting Note: Sparse concat is not supported yet. return tfc__namespace.concat(tensors, axis); } /** * Concatenate two arrays along the first dimension. * @param a The 1st `tf.Tensor` to concatenate. * @param b The 2nd `tf.Tensor` to concatenate. * @returns Result of the concatenation. * @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`.
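 *
 * For example, concatenating tensors of shapes `[2, 3]` and `[4, 3]` yields a
 * `[6, 3]` tensor.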
*/ function concatAlongFirstAxis(a, b) { switch (a.rank) { case 1: return tfc__namespace.concat1d([a, b]); case 2: return tfc__namespace.concat2d([a, b], 0); case 3: return tfc__namespace.concat3d([a, b], 0); case 4: return tfc__namespace.concat4d([a, b], 0); default: throw new ValueError("concatAlongFirstAxis() received an unsupported " + "tensor rank: ".concat(a.rank)); } } /** * Creates a tensor by tiling `x` by `n`. * @param x A tensor. * @param n An Array of integers or a single integer. If an Array, the length * must be the same as the number of dimensions in `x`. If a single integer, * it will be treated as an Array of length 1. */ function tile$1(x, n) { if (!Array.isArray(n)) { n = [n]; } if (x.rank !== n.length) { throw new ValueError("The length of input n (".concat(n.length, ") does not match ") + "the number of dimensions in input x (".concat(x.rank, ")")); } return tfc__namespace.tile(x, n); } /* Creation of random tensors. */ /** * Get a tensor with normal distribution of values. * * @param shape Shape of the tensor. * @param mean mean value of the normal distribution. * @param stddev standard deviation of the normal distribution. * @param dtype * @param seed * @return The normal tensor. */ function randomNormal$1(shape, mean, stddev, dtype, seed) { if (mean === void 0) { mean = 0.0; } if (stddev === void 0) { stddev = 1.0; } return tfc__namespace.randomNormal(shape, mean, stddev, dtype, seed); } /* Linear Algebra */ /** * Multiply two tensors and returns the result as a tensor. * * For 2D tensors, this is equivalent to matrix multiplication (matMul). * For tensors of higher ranks, it follows the Theano behavior, * (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation: * * For N dimensions it is a sum product over the last axis of x and the * second-to-last of y: * * @param a A tensor of at least rank 2. * @param b A tensor of at least rank 2. * @param activation (optional) A string identifying the activation * function. * @return Result of the dot operation. */ function dot$1(a, b, activation, bias) { if ((a.rank < 2) || (b.rank < 2)) { throw new NotImplementedError("dot requires both inputs to be rank >= 2" + " but got x shape = ".concat(a.shape, " and y shape = ").concat(b.shape)); } if (b.rank >= 3) { var xLastDim = a.shape.slice(-1)[0]; var ySecondLastDim = b.shape.slice(-2)[0]; if (xLastDim !== ySecondLastDim) { throw new NotImplementedError("If rank y >= 3, then the second last dim" + " of y must equal the last dim of x but got x shape = ".concat(a.shape, " and ") + " y shape = ".concat(b.shape)); } } // Handle basic 2D x 2D case. if ((a.rank === 2) && (b.rank === 2)) { var transposeA = false; var transposeB = false; // tfc.fused.matMul only fuses certain activation functions. Unsupported // activation functions are treated as 'linear' activations, which is // equivalent to a no-op. return tfc__namespace.fused.matMul({ a: a, b: b, transposeA: transposeA, transposeB: transposeB, bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null, activation: activation }); } else { // Reshape x into the analogous 2D Tensor. var aFirstDims = a.shape.slice(); // Holds all but the last dim of x. var aLastDim = aFirstDims.pop(); a = tfc__namespace.reshape(a, [-1, aLastDim]); // Reshape y into the analogous 2D Tensor, and keep track of the // required dimensions to reproduce the output shape. 
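// Shapes only, as a sketch: y of shape [4, 3, 5] is transposed below so that
// its second-to-last axis leads, giving [3, 4, 5], then reshaped to [3, 20]
// for the 2D matMul; the output is reshaped back to the full shape at the end.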
var bShape = b.shape.slice(); var bLastDim = bShape.pop(); var ySecondLastDim = bShape.pop(); var yOtherDims = __spreadArray(__spreadArray([], __read(bShape), false), [bLastDim], false); // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1] // where r is the rank of y. var perm = Array.from({ length: b.rank }, function (_, i) { if (i === 0) { return b.rank - 2; } else if (i <= b.rank - 2) { return i - 1; } return i; }); b = tfc__namespace.reshape(tfc__namespace.transpose(b, perm), [ySecondLastDim, -1]); // Multiply x and y as 2D Tensors, and then reshape back to original. var outputShape = __spreadArray(__spreadArray([], __read(aFirstDims), false), __read(yOtherDims), false); var transposeA = false; var transposeB = false; return tfc__namespace.reshape(tfc__namespace.fused.matMul({ a: a, b: b, transposeA: transposeA, transposeB: transposeB, bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null, activation: activation }), outputShape); } } /* Elementary math functions. */ /** * Retrieves the elements of indices `indices` in the tensor `reference`. * @param reference A tensor. * @param indices An integer tensor of indices or an `Array` of integers. * @param axis Axis along which to perform the gather operation. * @returns The result of the gathering as a tensor. */ function gather$1(reference, indices, axis) { return tfc.tidy(function () { if (Array.isArray(indices)) { indices = tfc.tensor1d(indices, 'int32'); } else { indices = tfc__namespace.cast(indices, 'int32'); } return tfc__namespace.gather(reference, indices, axis); }); } /** * Element-wise square. * @param x Input tensor. * @return element-wise x^2 */ function square$1(x) { return tfc__namespace.mul(x, x); } /** * Reshapes bias tensor according to rank of x. */ function reshapeBias(xRank, bias, dataFormat) { var biasShape = bias.shape; if (bias.rank !== 1 && bias.rank !== xRank) { throw new ValueError("Unexpected bias dimensions: ".concat(bias.rank) + "; expected it to be 1 or ".concat(xRank)); } if (xRank === 5) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return tfc__namespace.reshape(bias, [1, biasShape[0], 1, 1, 1]); } else { return tfc__namespace.reshape(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return tfc__namespace.reshape(bias, [1, 1, 1, 1, biasShape[0]]); } else { return tfc__namespace.reshape(bias, [1].concat(biasShape)); } } } else if (xRank === 4) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return tfc__namespace.reshape(bias, [1, biasShape[0], 1, 1]); } else { return tfc__namespace.reshape(bias, [1, biasShape[2], biasShape[0], biasShape[1]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return tfc__namespace.reshape(bias, [1, 1, 1, biasShape[0]]); } else { return tfc__namespace.reshape(bias, [1].concat(biasShape)); } } } else if (xRank === 3) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return tfc__namespace.reshape(bias, [1, biasShape[0], 1]); } else { return tfc__namespace.reshape(bias, [1, biasShape[1], biasShape[0]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return tfc__namespace.reshape(bias, [1, 1, biasShape[0]]); } else { return tfc__namespace.reshape(bias, [1].concat(biasShape)); } } } else if (xRank < 3) { return bias; } throw new ValueError("Unsupported input rank by biasAdd: ".concat(bias.rank)); } /* Neural-network operations. 
*/ /** * Add a bias to a tensor. * * @param x The tensor to add the bias to. * @param bias The bias to add to `x`. Must be 1D or the same rank as `x`. * @return Result of the bias adding. * @throws ValueError: If the rank of `bias` is incorrect. */ function biasAdd(x, bias, dataFormat) { return tfc.tidy(function () { if (dataFormat == null) { dataFormat = imageDataFormat(); } checkDataFormat(dataFormat); return tfc__namespace.add(x, reshapeBias(x.rank, bias, dataFormat)); }); } /** * Exponential linear unit (ELU). * @param x A tensor or variable to compute the activation function for. * @param alpha: A scalar, a scaling factor for the negative section. * @return Output of the ELU operation. */ function elu$1(x, alpha) { if (alpha === void 0) { alpha = 1; } // TODO(cais): Add support for alpha values other than 1. if (alpha !== 1) { throw new NotImplementedError("Support for alpha values other than 1 (".concat(alpha, ") is not implemented ") + "yet."); } return tfc__namespace.elu(x); } /** * Softsign of a tensor. * * Defined as x / (abs(x) + 1), element-wise. * * @param x: Input. * @returns Output. */ function softsign(x) { return tfc.tidy(function () { return tfc__namespace.div(x, tfc__namespace.add(tfc__namespace.abs(x), 1)); }); } /** * Sets entries in `x` to zero at random, while scaling the entire tensor. * * @param x input tensor. * @param level fraction of the entries in the tensor that will be set to 0. * @param noiseShape shape of randomly generated keep/drop flags, must be * broadcastable to the shape of `x`. Optional. * @param seed random seed to ensure determinism. Optional. * @returns Result of the dropout operation. */ function dropout$1(x, level, noiseShape, seed) { return tfc.tidy(function () { return tfc__namespace.dropout(x, level, noiseShape, seed); }); } /** * Element-wise, segment-wise linear approximation of sigmoid. * * Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`. * In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`. * * @param x Input tensor. * @returns Output tensor. */ function hardSigmoid(x) { return tfc.tidy(function () { var y = tfc__namespace.add(.5, tfc__namespace.mul(.2, x)); return tfc__namespace.clipByValue(y, 0, 1); }); } /** * Invoke `x` in the training phase, and `alt` otherwise. * * Porting Note: We do not create placeholder tensors for the `training` * boolean flag here, because there is no such thing in the TF.js imperative * backend. * * @param x The function to invoke iff `training` is `true`. * @param alt The function to invoke iff `training` is `false`. * @param training Boolean flag for whether training phase is active. * @returns The return value of `x()` if `training` is `true`, or the return * value of `alt()` if `training` is `false`. */ function inTrainPhase(x, alt, training) { if (training === void 0) { training = false; } return training ? x() : alt(); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ var VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg']; var VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal']; function checkFanMode(value) { checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value); } function checkDistribution(value) { checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value); } /** * Initializer base class. 
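 *
 * Subclasses implement `apply(shape, dtype)` to produce initial weight
 * values; for example (a sketch), `new Zeros().apply([2, 2])` returns a
 * 2x2 tensor of zeros.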
* * @doc { * heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'} */ var Initializer = /** @class */ (function (_super) { __extends(Initializer, _super); function Initializer() { return _super !== null && _super.apply(this, arguments) || this; } Initializer.prototype.fromConfigUsesCustomObjects = function () { return false; }; Initializer.prototype.getConfig = function () { return {}; }; return Initializer; }(tfc.serialization.Serializable)); var Zeros = /** @class */ (function (_super) { __extends(Zeros, _super); function Zeros() { return _super !== null && _super.apply(this, arguments) || this; } Zeros.prototype.apply = function (shape, dtype) { return tfc.zeros(shape, dtype); }; return Zeros; }(Initializer)); /** @nocollapse */ Zeros.className = 'Zeros'; tfc.serialization.registerClass(Zeros); var Ones = /** @class */ (function (_super) { __extends(Ones, _super); function Ones() { return _super !== null && _super.apply(this, arguments) || this; } Ones.prototype.apply = function (shape, dtype) { return tfc.ones(shape, dtype); }; return Ones; }(Initializer)); /** @nocollapse */ Ones.className = 'Ones'; tfc.serialization.registerClass(Ones); var Constant = /** @class */ (function (_super) { __extends(Constant, _super); function Constant(args) { var _this = _super.call(this) || this; if (typeof args !== 'object') { throw new ValueError("Expected argument of type ConstantConfig but got ".concat(args)); } if (args.value === undefined) { throw new ValueError("config must have value set but got ".concat(args)); } _this.value = args.value; return _this; } Constant.prototype.apply = function (shape, dtype) { var _this = this; return tfc.tidy(function () { return tfc.mul(tfc.scalar(_this.value), tfc.ones(shape, dtype)); }); }; Constant.prototype.getConfig = function () { return { value: this.value, }; }; return Constant; }(Initializer)); /** @nocollapse */ Constant.className = 'Constant'; tfc.serialization.registerClass(Constant); var RandomUniform = /** @class */ (function (_super) { __extends(RandomUniform, _super); function RandomUniform(args) { var _this = _super.call(this) || this; _this.DEFAULT_MINVAL = -0.05; _this.DEFAULT_MAXVAL = 0.05; _this.minval = args.minval || _this.DEFAULT_MINVAL; _this.maxval = args.maxval || _this.DEFAULT_MAXVAL; _this.seed = args.seed; return _this; } RandomUniform.prototype.apply = function (shape, dtype) { return tfc.randomUniform(shape, this.minval, this.maxval, dtype, this.seed); }; RandomUniform.prototype.getConfig = function () { return { minval: this.minval, maxval: this.maxval, seed: this.seed }; }; return RandomUniform; }(Initializer)); /** @nocollapse */ RandomUniform.className = 'RandomUniform'; tfc.serialization.registerClass(RandomUniform); var RandomNormal = /** @class */ (function (_super) { __extends(RandomNormal, _super); function RandomNormal(args) { var _this = _super.call(this) || this; _this.DEFAULT_MEAN = 0.; _this.DEFAULT_STDDEV = 0.05; _this.mean = args.mean || _this.DEFAULT_MEAN; _this.stddev = args.stddev || _this.DEFAULT_STDDEV; _this.seed = args.seed; return _this; } RandomNormal.prototype.apply = function (shape, dtype) { dtype = dtype || 'float32'; if (dtype !== 'float32' && dtype !== 'int32') { throw new NotImplementedError("randomNormal does not support dType ".concat(dtype, ".")); } return randomNormal$1(shape, this.mean, this.stddev, dtype, this.seed); }; RandomNormal.prototype.getConfig = function () { return { mean: this.mean, stddev: this.stddev, seed: this.seed }; }; return RandomNormal; }(Initializer)); 
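// Usage sketch (values are random; the seed argument is optional):
//   new RandomNormal({mean: 0, stddev: 0.05, seed: 42}).apply([3, 3])
// draws a 3x3 float32 tensor from a normal distribution N(0, 0.05^2).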
/** @nocollapse */ RandomNormal.className = 'RandomNormal'; tfc.serialization.registerClass(RandomNormal); var TruncatedNormal = /** @class */ (function (_super) { __extends(TruncatedNormal, _super); function TruncatedNormal(args) { var _this = _super.call(this) || this; _this.DEFAULT_MEAN = 0.; _this.DEFAULT_STDDEV = 0.05; _this.mean = args.mean || _this.DEFAULT_MEAN; _this.stddev = args.stddev || _this.DEFAULT_STDDEV; _this.seed = args.seed; return _this; } TruncatedNormal.prototype.apply = function (shape, dtype) { dtype = dtype || 'float32'; if (dtype !== 'float32' && dtype !== 'int32') { throw new NotImplementedError("truncatedNormal does not support dType ".concat(dtype, ".")); } return tfc.truncatedNormal(shape, this.mean, this.stddev, dtype, this.seed); }; TruncatedNormal.prototype.getConfig = function () { return { mean: this.mean, stddev: this.stddev, seed: this.seed }; }; return TruncatedNormal; }(Initializer)); /** @nocollapse */ TruncatedNormal.className = 'TruncatedNormal'; tfc.serialization.registerClass(TruncatedNormal); var Identity$1 = /** @class */ (function (_super) { __extends(Identity, _super); function Identity(args) { var _this = _super.call(this) || this; _this.gain = args.gain != null ? args.gain : 1.0; return _this; } Identity.prototype.apply = function (shape, dtype) { var _this = this; return tfc.tidy(function () { if (shape.length !== 2 || shape[0] !== shape[1]) { throw new ValueError('Identity matrix initializer can only be used for' + ' 2D square matrices.'); } else { return tfc.mul(_this.gain, tfc.eye(shape[0])); } }); }; Identity.prototype.getConfig = function () { return { gain: this.gain }; }; return Identity; }(Initializer)); /** @nocollapse */ Identity$1.className = 'Identity'; tfc.serialization.registerClass(Identity$1); /** * Computes the number of input and output units for a weight shape. * @param shape Shape of weight. * @param dataFormat data format to use for convolution kernels. * Note that all kernels in Keras are standardized on the * CHANNEL_LAST ordering (even when inputs are set to CHANNEL_FIRST). * @return An length-2 array: fanIn, fanOut. */ function computeFans(shape, dataFormat) { if (dataFormat === void 0) { dataFormat = 'channelsLast'; } var fanIn; var fanOut; checkDataFormat(dataFormat); if (shape.length === 2) { fanIn = shape[0]; fanOut = shape[1]; } else if ([3, 4, 5].indexOf(shape.length) !== -1) { if (dataFormat === 'channelsFirst') { var receptiveFieldSize = arrayProd(shape, 2); fanIn = shape[1] * receptiveFieldSize; fanOut = shape[0] * receptiveFieldSize; } else if (dataFormat === 'channelsLast') { var receptiveFieldSize = arrayProd(shape, 0, shape.length - 2); fanIn = shape[shape.length - 2] * receptiveFieldSize; fanOut = shape[shape.length - 1] * receptiveFieldSize; } } else { var shapeProd = arrayProd(shape); fanIn = Math.sqrt(shapeProd); fanOut = Math.sqrt(shapeProd); } return [fanIn, fanOut]; } var VarianceScaling = /** @class */ (function (_super) { __extends(VarianceScaling, _super); /** * Constructor of VarianceScaling. * @throws ValueError for invalid value in scale. */ function VarianceScaling(args) { var _this = _super.call(this) || this; if (args.scale < 0.0) { throw new ValueError("scale must be a positive float. Got: ".concat(args.scale)); } _this.scale = args.scale == null ? 1.0 : args.scale; _this.mode = args.mode == null ? 'fanIn' : args.mode; checkFanMode(_this.mode); _this.distribution = args.distribution == null ? 
'normal' : args.distribution; checkDistribution(_this.distribution); _this.seed = args.seed; return _this; } VarianceScaling.prototype.apply = function (shape, dtype) { var fans = computeFans(shape); var fanIn = fans[0]; var fanOut = fans[1]; var scale = this.scale; if (this.mode === 'fanIn') { scale /= Math.max(1, fanIn); } else if (this.mode === 'fanOut') { scale /= Math.max(1, fanOut); } else { scale /= Math.max(1, (fanIn + fanOut) / 2); } if (this.distribution === 'normal') { var stddev = Math.sqrt(scale); dtype = dtype || 'float32'; if (dtype !== 'float32' && dtype !== 'int32') { throw new NotImplementedError("".concat(this.getClassName(), " does not support dType ").concat(dtype, ".")); } return tfc.truncatedNormal(shape, 0, stddev, dtype, this.seed); } else { var limit = Math.sqrt(3 * scale); return tfc.randomUniform(shape, -limit, limit, dtype, this.seed); } }; VarianceScaling.prototype.getConfig = function () { return { scale: this.scale, mode: this.mode, distribution: this.distribution, seed: this.seed }; }; return VarianceScaling; }(Initializer)); /** @nocollapse */ VarianceScaling.className = 'VarianceScaling'; tfc.serialization.registerClass(VarianceScaling); var GlorotUniform = /** @class */ (function (_super) { __extends(GlorotUniform, _super); /** * Constructor of GlorotUniform * @param scale * @param mode * @param distribution * @param seed */ function GlorotUniform(args) { return _super.call(this, { scale: 1.0, mode: 'fanAvg', distribution: 'uniform', seed: args == null ? null : args.seed }) || this; } GlorotUniform.prototype.getClassName = function () { // In Python Keras, GlorotUniform is not a class, but a helper method // that creates a VarianceScaling object. Use 'VarianceScaling' as // class name to be compatible with that. return VarianceScaling.className; }; return GlorotUniform; }(VarianceScaling)); /** @nocollapse */ GlorotUniform.className = 'GlorotUniform'; tfc.serialization.registerClass(GlorotUniform); var GlorotNormal = /** @class */ (function (_super) { __extends(GlorotNormal, _super); /** * Constructor of GlorotNormal. * @param scale * @param mode * @param distribution * @param seed */ function GlorotNormal(args) { return _super.call(this, { scale: 1.0, mode: 'fanAvg', distribution: 'normal', seed: args == null ? null : args.seed }) || this; } GlorotNormal.prototype.getClassName = function () { // In Python Keras, GlorotNormal is not a class, but a helper method // that creates a VarianceScaling object. Use 'VarianceScaling' as // class name to be compatible with that. return VarianceScaling.className; }; return GlorotNormal; }(VarianceScaling)); /** @nocollapse */ GlorotNormal.className = 'GlorotNormal'; tfc.serialization.registerClass(GlorotNormal); var HeNormal = /** @class */ (function (_super) { __extends(HeNormal, _super); function HeNormal(args) { return _super.call(this, { scale: 2.0, mode: 'fanIn', distribution: 'normal', seed: args == null ? null : args.seed }) || this; } HeNormal.prototype.getClassName = function () { // In Python Keras, HeNormal is not a class, but a helper method // that creates a VarianceScaling object. Use 'VarianceScaling' as // class name to be compatible with that. 
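// (He initialization corresponds to VarianceScaling with scale 2.0 and mode
// 'fanIn'; in the 'normal' case the stddev works out to sqrt(2 / fanIn).)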
return VarianceScaling.className; }; return HeNormal; }(VarianceScaling)); /** @nocollapse */ HeNormal.className = 'HeNormal'; tfc.serialization.registerClass(HeNormal); var HeUniform = /** @class */ (function (_super) { __extends(HeUniform, _super); function HeUniform(args) { return _super.call(this, { scale: 2.0, mode: 'fanIn', distribution: 'uniform', seed: args == null ? null : args.seed }) || this; } HeUniform.prototype.getClassName = function () { // In Python Keras, HeUniform is not a class, but a helper method // that creates a VarianceScaling object. Use 'VarianceScaling' as // class name to be compatible with that. return VarianceScaling.className; }; return HeUniform; }(VarianceScaling)); /** @nocollapse */ HeUniform.className = 'HeUniform'; tfc.serialization.registerClass(HeUniform); var LeCunNormal = /** @class */ (function (_super) { __extends(LeCunNormal, _super); function LeCunNormal(args) { return _super.call(this, { scale: 1.0, mode: 'fanIn', distribution: 'normal', seed: args == null ? null : args.seed }) || this; } LeCunNormal.prototype.getClassName = function () { // In Python Keras, LeCunNormal is not a class, but a helper method // that creates a VarianceScaling object. Use 'VarianceScaling' as // class name to be compatible with that. return VarianceScaling.className; }; return LeCunNormal; }(VarianceScaling)); /** @nocollapse */ LeCunNormal.className = 'LeCunNormal'; tfc.serialization.registerClass(LeCunNormal); var LeCunUniform = /** @class */ (function (_super) { __extends(LeCunUniform, _super); function LeCunUniform(args) { return _super.call(this, { scale: 1.0, mode: 'fanIn', distribution: 'uniform', seed: args == null ? null : args.seed }) || this; } LeCunUniform.prototype.getClassName = function () { // In Python Keras, LeCunUniform is not a class, but a helper method // that creates a VarianceScaling object. Use 'VarianceScaling' as // class name to be compatible with that. return VarianceScaling.className; }; return LeCunUniform; }(VarianceScaling)); /** @nocollapse */ LeCunUniform.className = 'LeCunUniform'; tfc.serialization.registerClass(LeCunUniform); var Orthogonal = /** @class */ (function (_super) { __extends(Orthogonal, _super); function Orthogonal(args) { var _this = _super.call(this) || this; _this.DEFAULT_GAIN = 1; _this.ELEMENTS_WARN_SLOW = 2000; _this.gain = args.gain == null ? 
_this.DEFAULT_GAIN : args.gain; _this.seed = args.seed; return _this; } Orthogonal.prototype.apply = function (shape, dtype) { var _this = this; return tfc.tidy(function () { if (shape.length < 2) { throw new NotImplementedError('Shape must be at least 2D.'); } if (dtype !== 'int32' && dtype !== 'float32' && dtype !== undefined) { throw new TypeError("Unsupported data type ".concat(dtype, ".")); } // Flatten the input shape, with the last dimension keeping its // original size, so this also works for conv2d kernels. var numRows = tfc.util.sizeFromShape(shape.slice(0, -1)); var numCols = shape[shape.length - 1]; var numElements = numRows * numCols; if (numElements > _this.ELEMENTS_WARN_SLOW) { console.warn("Orthogonal initializer is being called on a matrix with more " + "than ".concat(_this.ELEMENTS_WARN_SLOW, " (").concat(numElements, ") elements: ") + "Slowness may result."); } var flatShape = [Math.max(numCols, numRows), Math.min(numCols, numRows)]; // Generate a random matrix var randNormalMat = randomNormal$1(flatShape, 0, 1, dtype, _this.seed); // Compute QR factorization var qr = tfc.linalg.qr(randNormalMat, false); var qMat = qr[0]; var rMat = qr[1]; // Make Q uniform var diag = rMat.flatten().stridedSlice([0], [Math.min(numCols, numRows) * Math.min(numCols, numRows)], [Math.min(numCols, numRows) + 1]); qMat = tfc.mul(qMat, diag.sign()); if (numRows < numCols) { qMat = qMat.transpose(); } return tfc.mul(tfc.scalar(_this.gain), qMat.reshape(shape)); }); }; Orthogonal.prototype.getConfig = function () { return { gain: this.gain, seed: this.seed, }; }; return Orthogonal; }(Initializer)); /** @nocollapse */ Orthogonal.className = 'Orthogonal'; tfc.serialization.registerClass(Orthogonal); // Maps the JavaScript-like identifier keys to the corresponding registry // symbols. var INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = { 'constant': 'Constant', 'glorotNormal': 'GlorotNormal', 'glorotUniform': 'GlorotUniform', 'heNormal': 'HeNormal', 'heUniform': 'HeUniform', 'identity': 'Identity', 'leCunNormal': 'LeCunNormal', 'leCunUniform': 'LeCunUniform', 'ones': 'Ones', 'orthogonal': 'Orthogonal', 'randomNormal': 'RandomNormal', 'randomUniform': 'RandomUniform', 'truncatedNormal': 'TruncatedNormal', 'varianceScaling': 'VarianceScaling', 'zeros': 'Zeros' }; function deserializeInitializer(config, customObjects) { if (customObjects === void 0) { customObjects = {}; } return deserializeKerasObject(config, tfc.serialization.SerializationMap.getMap().classNameMap, customObjects, 'initializer'); } function serializeInitializer(initializer) { return serializeKerasObject(initializer); } function getInitializer(identifier) { if (typeof identifier === 'string') { var className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ? INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier; /* We have six 'helper' classes for common initializers that all get serialized as 'VarianceScaling' and shouldn't go through the deserializeInitializer pathway. 
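For example, `getInitializer('heNormal')` resolves the identifier to the class name 'HeNormal' via the map above and returns a `HeNormal` instance directly; serializing that instance with `serializeInitializer` then yields a config whose `className` is 'VarianceScaling', because `HeNormal.prototype.getClassName` returns `VarianceScaling.className`.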
*/ if (className === 'GlorotNormal') { return new GlorotNormal(); } else if (className === 'GlorotUniform') { return new GlorotUniform(); } else if (className === 'HeNormal') { return new HeNormal(); } else if (className === 'HeUniform') { return new HeUniform(); } else if (className === 'LeCunNormal') { return new LeCunNormal(); } else if (className === 'LeCunUniform') { return new LeCunUniform(); } else { var config = {}; config['className'] = className; config['config'] = {}; return deserializeInitializer(config); } } else if (identifier instanceof Initializer) { return identifier; } else { return deserializeInitializer(identifier); } } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ // tslint:enable /** * Determine whether the input is an Array of Shapes. */ function isArrayOfShapes(x) { return Array.isArray(x) && Array.isArray(x[0]); } /** * Special case of normalizing shapes to lists. * * @param x A shape or list of shapes to normalize into a list of Shapes. * @return A list of Shapes. */ function normalizeShapeList(x) { if (x.length === 0) { return []; } if (!Array.isArray(x[0])) { return [x]; } return x; } /** * Helper function to obtain exactly one Tensor. * @param xs: A single `tf.Tensor` or an `Array` of `tf.Tensor`s. * @return A single `tf.Tensor`. If `xs` is an `Array`, return the first one. * @throws ValueError: If `xs` is an `Array` and its length is not 1. */ function getExactlyOneTensor(xs) { var x; if (Array.isArray(xs)) { if (xs.length !== 1) { throw new ValueError("Expected Tensor length to be 1; got ".concat(xs.length)); } x = xs[0]; } else { x = xs; } return x; } /** * Helper function to obtain exactly one instance of Shape. * * @param shapes Input single `Shape` or Array of `Shape`s. * @returns If input is a single `Shape`, return it unchanged. If the input is * an `Array` containing exactly one instance of `Shape`, return the instance. * Otherwise, throw a `ValueError`. * @throws ValueError: If input is an `Array` of `Shape`s, and its length is not * 1. */ function getExactlyOneShape(shapes) { if (Array.isArray(shapes) && Array.isArray(shapes[0])) { if (shapes.length === 1) { return shapes[0]; } else { throw new ValueError("Expected exactly 1 Shape; got ".concat(shapes.length)); } } else { return shapes; } } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Count the elements in an Array of LayerVariables. * * @param weights: The LayerVariables of which the constituent numbers are to * be counted. 
* @returns A count of the elements in all the LayerVariables. */ function countParamsInWeights(weights) { var e_1, _a; var count = 0; try { for (var weights_1 = __values(weights), weights_1_1 = weights_1.next(); !weights_1_1.done; weights_1_1 = weights_1.next()) { var weight = weights_1_1.value; if (weight.shape.length === 0) { count += 1; } else { count += weight.shape.reduce(function (a, b) { return a * b; }); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (weights_1_1 && !weights_1_1.done && (_a = weights_1.return)) _a.call(weights_1); } finally { if (e_1) throw e_1.error; } } return count; } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ var DEFAULT_VARIABLE_NAME_PREFIX = 'Variable'; /** * A `tf.layers.LayerVariable` is similar to a `tf.Tensor` in that it has a * dtype and shape, but its value is mutable. The value is itself represented * as a `tf.Tensor`, and can be read with the `read()` method and updated with * the `write()` method. */ var LayerVariable = /** @class */ (function () { /** * Construct a Variable from a `tf.Tensor`. * * If not explicitly named, the Variable will be given a name with the * prefix 'Variable'. Variable names are unique. In the case of name * collision, suffixes like '_1' will be added to the name. * * @param val Initial value of the Variable. * @param name Name of the variable. If `null` or `undefined` is provided, it * will default to a name with the prefix 'Variable'. * @param constraint Optional projection function to be applied to the * variable after optimizer updates. */ function LayerVariable(val, dtype, name, trainable, constraint) { if (dtype === void 0) { dtype = 'float32'; } if (name === void 0) { name = DEFAULT_VARIABLE_NAME_PREFIX; } if (trainable === void 0) { trainable = true; } if (constraint === void 0) { constraint = null; } this.dtype = dtype == null ? 'float32' : dtype; this.shape = val.shape; this.id = getNextUniqueTensorId(); name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name; this.originalName = getScopedTensorName(name); this.name = getUniqueTensorName(this.originalName); this.trainable_ = trainable; this.constraint = constraint; this.val = tfc__namespace.variable(val, this.trainable_, this.name, this.dtype); } /** * Get a snapshot of the Variable's value. * * The returned value is a snapshot of the Variable's value at the time of * the invocation. Future mutations in the value of the tensor will only * be reflected by future calls to this method. */ LayerVariable.prototype.read = function () { this.assertNotDisposed(); return this.val; }; /** * Update the value of the Variable. * * @param newVal: The new value to update to. Must be consistent with the * dtype and shape of the Variable. * @return This Variable. */ LayerVariable.prototype.write = function (newVal) { // TODO(cais): Once TF.js Core supports Tensor.dtype, check dtype match. this.assertNotDisposed(); checkShapesMatch(this.val, newVal); // Skip updating if this is the exact same tensor. if (this.val.id !== newVal.id) { this.val.assign(newVal); if (this.constraint != null) { this.val.assign(this.constraint.apply(this.val)); } } return this; }; /** * Dispose this LayerVariable instance from memory. 
*/ LayerVariable.prototype.dispose = function () { this.assertNotDisposed(); this.val.dispose(); }; LayerVariable.prototype.assertNotDisposed = function () { if (this.val.isDisposed) { throw new Error("LayersVariable ".concat(this.name, " is already disposed.")); } }; Object.defineProperty(LayerVariable.prototype, "trainable", { get: function () { return this.trainable_; }, set: function (trainable) { this.trainable_ = trainable; this.val.trainable = trainable; }, enumerable: false, configurable: true }); return LayerVariable; }()); function checkShapesMatch(x, y) { if (x.shape.toString() !== y.shape.toString()) { throw new Error('Shape mismatch: ' + JSON.stringify(x.shape) + ' vs. ' + JSON.stringify(y.shape)); } } /** * Get the values of an array of Variables. * * @param xs An `Array` of `Variable`s to get the values of. * @return The values of the inputs, as an `Array` of `tf.Tensor`s. */ function batchGetValue(xs) { return xs.map(function (x) { return x.read(); }); } /** * Update the value of multiple Variables at once. * * @param variablesAndValues An `Array`, each element is of type * [Variable, Tensor]. The first item is the * `Variable` of which the value is to be updated. The second item * carries the new value. */ function batchSetValue(variablesAndValues) { variablesAndValues.forEach(function (variableAndValue) { var variable = variableAndValue[0]; variable.write(variableAndValue[1]); }); } /** * Specifies the ndim, dtype and shape of every input to a layer. * * Every layer should expose (if appropriate) an `inputSpec` attribute: * a list of instances of InputSpec (one per input tensor). * * A null entry in a shape is compatible with any dimension; * a null shape is compatible with any shape. */ var InputSpec = /** @class */ (function () { function InputSpec(args) { this.dtype = args.dtype; this.shape = args.shape; /* TODO(michaelterry): Could throw error if ndim and shape are both defined (then backport). */ if (args.shape != null) { this.ndim = args.shape.length; } else { this.ndim = args.ndim; } this.maxNDim = args.maxNDim; this.minNDim = args.minNDim; this.axes = args.axes || {}; } return InputSpec; }()); /** * `tf.SymbolicTensor` is a placeholder for a Tensor without any concrete value. * * They are most often encountered when building a graph of `Layer`s for a * `tf.LayersModel`, when the input data's shape, but not its values, is known. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ var SymbolicTensor = /** @class */ (function () { /** * * @param dtype * @param shape * @param sourceLayer The Layer that produced this symbolic tensor. * @param inputs The inputs passed to sourceLayer's __call__() method. * @param nodeIndex * @param tensorIndex * @param callArgs The keyword arguments passed to the __call__() method. * @param name * @param outputTensorIndex The index of this tensor in the list of outputs * returned by apply(). */ function SymbolicTensor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) { this.dtype = dtype; this.shape = shape; this.sourceLayer = sourceLayer; this.inputs = inputs; this.callArgs = callArgs; this.outputTensorIndex = outputTensorIndex; this.id = getNextUniqueTensorId(); if (name != null) { this.originalName = getScopedTensorName(name); this.name = getUniqueTensorName(this.originalName); } this.rank = shape.length; } return SymbolicTensor; }()); var _nextNodeID = 0; /** * A `Node` describes the connectivity between two layers. * * Each time a layer is connected to some new input, * a node is added to `layer.inboundNodes`. 
* * Each time the output of a layer is used by another layer, * a node is added to `layer.outboundNodes`. * * `nodeIndices` and `tensorIndices` are basically fine-grained coordinates * describing the origin of the `inputTensors`, verifying the following: * * `inputTensors[i] == * inboundLayers[i].inboundNodes[nodeIndices[i]].outputTensors[ * tensorIndices[i]]` * * A node from layer A to layer B is added to: * A.outboundNodes * B.inboundNodes */ var Node = /** @class */ (function () { function Node(args, // TODO(michaelterry): Define actual type for this. callArgs) { var e_1, _a; this.callArgs = callArgs; this.id = _nextNodeID++; /* Layer instance (NOT a list). this is the layer that takes a list of input tensors and turns them into a list of output tensors. the current node will be added to the inboundNodes of outboundLayer. */ this.outboundLayer = args.outboundLayer; /* The following 3 properties describe where the input tensors come from: which layers, and for each layer, which node and which tensor output of each node. */ // List of layer instances. this.inboundLayers = args.inboundLayers; // List of integers, 1:1 mapping with inboundLayers. this.nodeIndices = args.nodeIndices; // List of integers, 1:1 mapping with inboundLayers. this.tensorIndices = args.tensorIndices; /* Following 2 properties: tensor inputs and outputs of outboundLayer. */ // List of tensors. 1:1 mapping with inboundLayers. this.inputTensors = args.inputTensors; // List of tensors, created by outboundLayer.call(). this.outputTensors = args.outputTensors; /* Following 2 properties: input and output masks. List of tensors, 1:1 mapping with inputTensor. */ this.inputMasks = args.inputMasks; // List of tensors, created by outboundLayer.computeMask(). this.outputMasks = args.outputMasks; // Following 2 properties: input and output shapes. // List of shape tuples, shapes of inputTensors. this.inputShapes = args.inputShapes; // List of shape tuples, shapes of outputTensors. this.outputShapes = args.outputShapes; try { // Add nodes to all layers involved. for (var _b = __values(args.inboundLayers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; if (layer != null) { layer.outboundNodes.push(this); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_1) throw e_1.error; } } args.outboundLayer.inboundNodes.push(this); } Node.prototype.getConfig = function () { var e_2, _a; var inboundNames = []; try { for (var _b = __values(this.inboundLayers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; if (layer != null) { inboundNames.push(layer.name); } else { inboundNames.push(null); } } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_2) throw e_2.error; } } return { outboundLayer: this.outboundLayer ? this.outboundLayer.name : null, inboundLayers: inboundNames, nodeIndices: this.nodeIndices, tensorIndices: this.tensorIndices }; }; return Node; }()); var _nextLayerID = 0; /** * A layer is a grouping of operations and weights that can be composed to * create a `tf.LayersModel`. * * Layers are constructed by using the functions under the * [tf.layers](#Layers-Basic) namespace. 
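 *
 * A minimal usage sketch (assuming the usual `tf` global from
 * `@tensorflow/tfjs`, as in the other examples in this file):
 *
 * ```js
 * // Create a dense layer; its weights are built lazily on the first call
 * // to apply().
 * const layer = tf.layers.dense({units: 3});
 * // Apply the layer to a concrete tensor of shape [2, 4].
 * const output = layer.apply(tf.ones([2, 4]));
 * // The output shape is [2, 3].
 * console.log(JSON.stringify(output.shape));
 * ```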
* * @doc {heading: 'Layers', subheading: 'Classes', namespace: 'layers'} */ var Layer = /** @class */ (function (_super) { __extends(Layer, _super); function Layer(args) { if (args === void 0) { args = {}; } var _this = _super.call(this) || this; _this._callHook = null; _this._addedWeightNames = []; // Porting Notes: PyKeras does not have this property in this base Layer // class. Instead it lets Layer subclasses set it dynamically and checks the // value with `hasattr`. In tfjs-layers, we let this be a member of this // base class. _this._stateful = false; _this.id = _nextLayerID++; _this.activityRegularizer = null; _this.inputSpec = null; _this.supportsMasking = false; // These properties will be set upon call of this.build() _this._trainableWeights = []; _this._nonTrainableWeights = []; _this._losses = []; _this._updates = []; _this._built = false; /* These lists will be filled via successive calls to this.addInboundNode(). */ _this.inboundNodes = []; _this.outboundNodes = []; var name = args.name; if (!name) { var prefix = _this.getClassName(); name = toSnakeCase(prefix) + '_' + getUid(prefix); } _this.name = name; _this.trainable_ = args.trainable == null ? true : args.trainable; if (args.inputShape != null || args.batchInputShape != null) { /* In this case we will later create an input layer to insert before the current layer */ var batchInputShape = void 0; if (args.batchInputShape != null) { batchInputShape = args.batchInputShape; } else if (args.inputShape != null) { var batchSize = null; if (args.batchSize != null) { batchSize = args.batchSize; } batchInputShape = [batchSize].concat(args.inputShape); } _this.batchInputShape = batchInputShape; // Set dtype. var dtype = args.dtype; if (dtype == null) { dtype = args.inputDType; } if (dtype == null) { dtype = 'float32'; } _this.dtype = dtype; } if (args.weights != null) { _this.initialWeights = args.weights; } else { _this.initialWeights = null; } // The value of `_refCount` is initialized to null. When the layer is used // in a symbolic way for the first time, it will be set to 1. _this._refCount = null; _this.fastWeightInitDuringBuild = false; return _this; } /** * Converts a layer and its index to a unique (immutable type) name. * This function is used internally with `this.containerNodes`. * @param layer The layer. * @param nodeIndex The layer's position (e.g. via enumerate) in a list of * nodes. * * @returns The unique name. */ Layer.nodeKey = function (layer, nodeIndex) { return layer.name + '_ib-' + nodeIndex.toString(); }; /** * Returns this.inboundNode at index nodeIndex. * * Porting note: This is a replacement for _get_node_attribute_at_index() * @param nodeIndex * @param attrName The name of the requested attribute, for use in error * messages. */ Layer.prototype.getNodeAtIndex = function (nodeIndex, attrName) { if (this.inboundNodes.length === 0) { throw new RuntimeError('The layer has never been called ' + "and thus has no defined ".concat(attrName, ".")); } if (this.inboundNodes.length <= nodeIndex) { throw new ValueError("Asked to get ".concat(attrName, " at node ").concat(nodeIndex, ", ") + "but the layer has only ".concat(this.inboundNodes.length, " inbound nodes.")); } return this.inboundNodes[nodeIndex]; }; /** * Retrieves the input tensor(s) of a layer at a given node. * * @param nodeIndex Integer, index of the node from which to retrieve the * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer * was called. * * @return A tensor (or list of tensors if the layer has multiple inputs). 
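 *
 * An illustrative sketch (a layer called on two symbolic inputs has two
 * inbound nodes):
 *
 * ```js
 * const dense = tf.layers.dense({units: 1});
 * const x0 = tf.input({shape: [4]});
 * const x1 = tf.input({shape: [4]});
 * dense.apply(x0);
 * dense.apply(x1);
 * // Node 1 corresponds to the second call, so its input is x1.
 * console.log(dense.getInputAt(1) === x1);  // true
 * ```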
*/ Layer.prototype.getInputAt = function (nodeIndex) { return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors); }; /** * Retrieves the output tensor(s) of a layer at a given node. * * @param nodeIndex Integer, index of the node from which to retrieve the * attribute. E.g. `nodeIndex=0` will correspond to the first time the layer * was called. * * @return A tensor (or list of tensors if the layer has multiple outputs). */ Layer.prototype.getOutputAt = function (nodeIndex) { return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors); }; Object.defineProperty(Layer.prototype, "input", { // Properties /** * Retrieves the input tensor(s) of a layer. * * Only applicable if the layer has exactly one inbound node, * i.e. if it is connected to one incoming layer. * * @return Input tensor or list of input tensors. * * @exception AttributeError if the layer is connected to more than one * incoming layer. */ get: function () { if (this.inboundNodes.length > 1) { throw new AttributeError("Layer ".concat(this.name) + ' has multiple inbound nodes, ' + 'hence the notion of "layer input" ' + 'is ill-defined. ' + 'Use `getInputAt(nodeIndex)` instead.'); } else if (this.inboundNodes.length === 0) { throw new AttributeError("Layer ".concat(this.name) + ' is not connected, no input to return.'); } return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors); }, enumerable: false, configurable: true }); Object.defineProperty(Layer.prototype, "output", { /** * Retrieves the output tensor(s) of a layer. * * Only applicable if the layer has exactly one inbound node, * i.e. if it is connected to one incoming layer. * * @return Output tensor or list of output tensors. * * @exception AttributeError if the layer is connected to more than one * incoming layer. */ get: function () { if (this.inboundNodes.length === 0) { throw new AttributeError("Layer ".concat(this.name) + ' has no inbound nodes.'); } if (this.inboundNodes.length > 1) { throw new AttributeError("Layer ".concat(this.name) + ' has multiple inbound nodes, ' + 'hence the notion of "layer output" ' + 'is ill-defined. ' + 'Use `getOutputAt(nodeIndex)` instead.'); } return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors); }, enumerable: false, configurable: true }); Object.defineProperty(Layer.prototype, "losses", { get: function () { return this._losses; }, enumerable: false, configurable: true }); /** * Retrieves the Layer's current loss values. * * Used for regularizers during training. */ Layer.prototype.calculateLosses = function () { // Porting Note: This is an augmentation to Layer.loss in PyKeras. // In PyKeras, Layer.loss returns symbolic tensors. Here concrete // Tensor (specifically Scalar) values are returned. This is due to the // imperative backend. 
return this.losses.map(function (lossFn) { return lossFn(); }); }; Object.defineProperty(Layer.prototype, "updates", { get: function () { return this._updates; }, enumerable: false, configurable: true }); Object.defineProperty(Layer.prototype, "built", { get: function () { return this._built; }, set: function (built) { this._built = built; }, enumerable: false, configurable: true }); Object.defineProperty(Layer.prototype, "trainable", { get: function () { return this.trainable_; }, set: function (trainable) { this._trainableWeights.forEach(function (w) { return w.trainable = trainable; }); this.trainable_ = trainable; }, enumerable: false, configurable: true }); Object.defineProperty(Layer.prototype, "trainableWeights", { get: function () { if (this.trainable_) { return this._trainableWeights.filter(function (w) { return w.trainable; }); } else { return []; } }, set: function (weights) { this._trainableWeights = weights; }, enumerable: false, configurable: true }); Object.defineProperty(Layer.prototype, "nonTrainableWeights", { get: function () { if (this.trainable) { return this._trainableWeights.filter(function (w) { return !w.trainable; }) .concat(this._nonTrainableWeights); } else { return this._trainableWeights.concat(this._nonTrainableWeights); } }, set: function (weights) { this._nonTrainableWeights = weights; }, enumerable: false, configurable: true }); Object.defineProperty(Layer.prototype, "weights", { /** * The concatenation of the lists trainableWeights and nonTrainableWeights * (in this order). */ get: function () { return this.trainableWeights.concat(this.nonTrainableWeights); }, enumerable: false, configurable: true }); Object.defineProperty(Layer.prototype, "stateful", { get: function () { return this._stateful; }, enumerable: false, configurable: true }); /** * Reset the states of the layer. * * This method of the base Layer class is essentially a no-op. * Subclasses that are stateful (e.g., stateful RNNs) should override this * method. */ Layer.prototype.resetStates = function () { if (!this.stateful) { throw new Error('Cannot call the resetStates() method of a non-stateful Layer ' + 'object.'); } }; /** * Checks compatibility between the layer and provided inputs. * * This checks that the tensor(s) `input` * verify the input assumptions of the layer * (if any). If not, exceptions are raised. * * @param inputs Input tensor or list of input tensors. * * @exception ValueError in case of mismatch between * the provided inputs and the expectations of the layer. */ Layer.prototype.assertInputCompatibility = function (inputs) { var inputsList = toList(inputs); if (this.inputSpec == null || this.inputSpec.length === 0) { return; } var inputSpec = toList(this.inputSpec); if (inputsList.length !== inputSpec.length) { throw new ValueError("Layer ".concat(this.name, " expects ").concat(inputSpec.length, " inputs, ") + "but it received ".concat(inputsList.length, " input tensors. ") + "Input received: ".concat(inputs)); } for (var inputIndex = 0; inputIndex < inputsList.length; inputIndex++) { var x = inputsList[inputIndex]; var spec = inputSpec[inputIndex]; if (spec == null) { continue; } // Check ndim. 
var ndim = x.rank; if (spec.ndim != null) { if (ndim !== spec.ndim) { throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ").concat(this.name, ": ") + "expected ndim=".concat(spec.ndim, ", found ndim=").concat(ndim)); } } if (spec.maxNDim != null) { if (ndim > spec.maxNDim) { throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ").concat(this.name) + ": expected max_ndim=".concat(spec.maxNDim, ", found ndim=").concat(ndim)); } } if (spec.minNDim != null) { if (ndim < spec.minNDim) { throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ").concat(this.name) + ": expected min_ndim=".concat(spec.minNDim, ", found ndim=").concat(ndim, ".")); } } // Check dtype. if (spec.dtype != null) { if (x.dtype !== spec.dtype) { throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ").concat(this.name, " ") + ": expected dtype=".concat(spec.dtype, ", found dtype=").concat(x.dtype, ".")); } } // Check specific shape axes. if (spec.axes) { var xShape = x.shape; for (var key in spec.axes) { var axis = Number(key); var value = spec.axes[key]; // Perform Python-style slicing in case axis < 0; // TODO(cais): Use https://github.com/alvivi/typescript-underscore to // ensure type safety through Underscore calls. var xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis]; if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) { throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ") + "".concat(this.name, ": expected axis ").concat(axis, " of input shape to ") + "have value ".concat(value, " but got shape ").concat(xShape, ".")); } } } // Check shape. if (spec.shape != null) { for (var i = 0; i < spec.shape.length; ++i) { var specDim = spec.shape[i]; var dim = x.shape[i]; if (specDim != null && dim != null) { if (specDim !== dim) { throw new ValueError("Input ".concat(inputIndex, " is incompatible with layer ") + "".concat(this.name, ": expected shape=").concat(spec.shape, ", ") + "found shape=".concat(x.shape, ".")); } } } } } }; /** * This is where the layer's logic lives. * * @param inputs Input tensor, or list/tuple of input tensors. * @param kwargs Additional keyword arguments. * * @return A tensor or list/tuple of tensors. */ Layer.prototype.call = function (inputs, kwargs) { return inputs; }; Layer.prototype.invokeCallHook = function (inputs, kwargs) { if (this._callHook != null) { this._callHook(inputs, kwargs); } }; /** * Set call hook. * This is currently used for testing only. * @param callHook */ Layer.prototype.setCallHook = function (callHook) { this._callHook = callHook; }; /** * Clear call hook. * This is currently used for testing only. */ Layer.prototype.clearCallHook = function () { this._callHook = null; }; /** * Builds or executes a `Layer`'s logic. * * When called with `tf.Tensor`(s), execute the `Layer`'s computation and * return Tensor(s). For example: * * ```js * const denseLayer = tf.layers.dense({ * units: 1, * kernelInitializer: 'zeros', * useBias: false * }); * * // Invoke the layer's apply() method with a `tf.Tensor` (with concrete * // numeric values). * const input = tf.ones([2, 2]); * const output = denseLayer.apply(input); * * // The output's value is expected to be [[0], [0]], due to the fact that * // the dense layer has a kernel initialized to all-zeros and does not have * // a bias. * output.print(); * ``` * * When called with `tf.SymbolicTensor`(s), this will prepare the layer for * future execution. 
This entails internal book-keeping on shapes of * expected Tensors, wiring layers together, and initializing weights. * * Calling `apply` with `tf.SymbolicTensor`s is typically used during the * building of non-`tf.Sequential` models. For example: * * ```js * const flattenLayer = tf.layers.flatten(); * const denseLayer = tf.layers.dense({units: 1}); * * // Use tf.input() to obtain a SymbolicTensor as input to apply(). * const input = tf.input({shape: [2, 2]}); * const output1 = flattenLayer.apply(input); * * // output1.shape is [null, 4]. The first dimension is the undetermined * // batch size. The second dimension comes from flattening the [2, 2] * // shape. * console.log(JSON.stringify(output1.shape)); * * // The output SymbolicTensor of the flatten layer can be used to call * // the apply() of the dense layer: * const output2 = denseLayer.apply(output1); * * // output2.shape is [null, 1]. The first dimension is the undetermined * // batch size. The second dimension matches the number of units of the * // dense layer. * console.log(JSON.stringify(output2.shape)); * * // The input and output can be used to construct a model that consists * // of the flatten and dense layers. * const model = tf.model({inputs: input, outputs: output2}); * ``` * * @param inputs a `tf.Tensor` or `tf.SymbolicTensor` or an Array of them. * @param kwargs Additional keyword arguments to be passed to `call()`. * * @return Output of the layer's `call` method. * * @exception ValueError error in case the layer is missing shape information * for its `build` call. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ // Porting Note: This is a replacement for __call__() in Python. Layer.prototype.apply = function (inputs, kwargs) { var _this = this; kwargs = kwargs || {}; this.assertNotDisposed(); // Ensure inputs are all the same type. var inputsList = toList(inputs); var allAreSymbolic = checkAllSymbolic(inputs); var noneAreSymbolic = checkNoneSymbolic(inputs); if (allAreSymbolic === noneAreSymbolic) { throw new ValueError('Arguments to apply() must be all ' + 'SymbolicTensors or all Tensors'); } // TODO(michaelterry): nameScope() may not be necessary. return nameScope(this.name, function () { var e_3, _a, e_4, _b; // Handle lazy building (weight creation, input spec locking). if (!_this.built) { /* Throw exceptions in case the input is not compatible with the inputSpec specified in the layer constructor. */ _this.assertInputCompatibility(inputs); // Collect input shapes to build layer. var inputShapes = []; try { for (var _c = __values(toList(inputs)), _d = _c.next(); !_d.done; _d = _c.next()) { var xElem = _d.value; inputShapes.push(xElem.shape); } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_3) throw e_3.error; } } _this.build(singletonOrArray(inputShapes)); _this.built = true; // Load weights that were specified at layer instantiation. if (_this.initialWeights) { _this.setWeights(_this.initialWeights); } if (_this._refCount === null && noneAreSymbolic) { // The first use of this layer is a non-symbolic call, set ref count // to 1 so the Layer can be properly disposed if its dispose() method // is called. _this._refCount = 1; } } /* Throw exceptions in case the input is not compatible with the inputSpec set at build time. */ _this.assertInputCompatibility(inputs); // Handle mask propagation. // TODO(michaelterry): Mask propagation not currently implemented. 
// Actually call the layer, collecting output(s), mask(s), and shape(s). if (noneAreSymbolic) { var output = _this.call(inputs, kwargs); // Apply masks to the output tensors if the layer supports it. if (_this.supportsMasking) { // TODO(mattsoulanille): pass the input tensors' masks to computeMask _this.setMaskMetadata(inputs, output); } // If the layer returns tensors from its inputs, unmodified, // we copy them to avoid loss of tensor metadata. var outputList = toList(output); var outputListCopy = []; try { // TODO(michaelterry): This copying may not be necessary given our eager // backend. for (var outputList_1 = __values(outputList), outputList_1_1 = outputList_1.next(); !outputList_1_1.done; outputList_1_1 = outputList_1.next()) { var x = outputList_1_1.value; if (inputsList.indexOf(x) !== -1) { x = x.clone(); } outputListCopy.push(x); } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (outputList_1_1 && !outputList_1_1.done && (_b = outputList_1.return)) _b.call(outputList_1); } finally { if (e_4) throw e_4.error; } } output = singletonOrArray(outputListCopy); if (_this.activityRegularizer != null) { throw new NotImplementedError('Layer invocation in the presence of activity ' + 'regularizer(s) is not supported yet.'); } // TODO(michaelterry): Call addInboundNode()? return output; } else { var inputShape = collectInputShape(inputs); var outputShape = _this.computeOutputShape(inputShape); var output = void 0; var outputDType_1 = guessOutputDType(inputs); _this.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] : inputShape); if (outputShape != null && outputShape.length > 0 && Array.isArray(outputShape[0])) { // We have multiple output shapes. Create multiple output tensors. output = outputShape .map(function (shape, index) { return new SymbolicTensor(outputDType_1, shape, _this, toList(inputs), kwargs, _this.name, index); }); } else { output = new SymbolicTensor(outputDType_1, outputShape, _this, toList(inputs), kwargs, _this.name); } /* Add an inbound node to the layer, so that it keeps track of the call and of all new variables created during the call. This also updates the layer history of the output tensor(s). If the input tensor(s) had no previous history, this does nothing. */ _this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs); _this._refCount++; if (_this.activityRegularizer != null) { throw new NotImplementedError('Layer invocation in the presence of activity ' + 'regularizer(s) is not supported yet.'); } return output; } }); }; /** * Check compatibility between input shape and this layer's batchInputShape. * * Print warning if any incompatibility is found. * * @param inputShape Input shape to be checked. 
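 *
 * For example (illustrative), a layer constructed with `inputShape: [3]`
 * has a `batchInputShape` of `[null, 3]`; feeding it a tensor of shape
 * `[2, 4]` triggers the dimension-mismatch warning below.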
*/ Layer.prototype.warnOnIncompatibleInputShape = function (inputShape) { if (this.batchInputShape == null) { return; } else if (inputShape.length !== this.batchInputShape.length) { console.warn("The rank of the input tensor provided (shape: " + "".concat(JSON.stringify(inputShape), ") does not match that of the ") + "batchInputShape (".concat(JSON.stringify(this.batchInputShape), ") ") + "of the layer ".concat(this.name)); } else { var dimMismatch_1 = false; this.batchInputShape.forEach(function (dimension, i) { if (dimension != null && inputShape[i] != null && inputShape[i] !== dimension) { dimMismatch_1 = true; } }); if (dimMismatch_1) { console.warn("The shape of the input tensor " + "(".concat(JSON.stringify(inputShape), ") does not ") + "match the expectation of layer ".concat(this.name, ": ") + "".concat(JSON.stringify(this.batchInputShape))); } } }; Object.defineProperty(Layer.prototype, "outputShape", { /** * Retrieves the output shape(s) of a layer. * * Only applicable if the layer has only one inbound node, or if all inbound * nodes have the same output shape. * * @returns Output shape or shapes. * @throws AttributeError: if the layer is connected to more than one incoming * node. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ get: function () { var e_5, _a; if (this.inboundNodes == null || this.inboundNodes.length === 0) { throw new AttributeError("The layer ".concat(this.name, " has never been called and thus has no ") + "defined output shape."); } var allOutputShapes = []; try { for (var _b = __values(this.inboundNodes), _c = _b.next(); !_c.done; _c = _b.next()) { var node = _c.value; var shapeString = JSON.stringify(node.outputShapes); if (allOutputShapes.indexOf(shapeString) === -1) { allOutputShapes.push(shapeString); } } } catch (e_5_1) { e_5 = { error: e_5_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_5) throw e_5.error; } } if (allOutputShapes.length === 1) { var outputShapes = this.inboundNodes[0].outputShapes; if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) && outputShapes.length === 1) { return outputShapes[0]; } else { return outputShapes; } } else { throw new AttributeError("The layer ".concat(this.name, " has multiple inbound nodes with different ") + "output shapes. Hence the notion of \"output shape\" is ill-defined " + "for the layer."); // TODO(cais): Implement getOutputShapeAt(). } }, enumerable: false, configurable: true }); /** * Counts the total number of numbers (e.g., float32, int32) in the * weights. * * @returns An integer count. * @throws RuntimeError: If the layer is not built yet (in which case its * weights are not defined yet). * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.countParams = function () { if (!this.built) { throw new RuntimeError("You tried to call countParams() on ".concat(this.name, ", ") + "but the layer is not built yet. Build it first by calling " + "build(batchInputShape)."); } return countParamsInWeights(this.weights); }; /** * Creates the layer weights. * * Must be implemented on all layers that have weights. * * Called when apply() is called to construct the weights. * * @param inputShape A `Shape` or array of `Shape` (unused). * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.build = function (inputShape) { this.built = true; }; /** * Returns the current values of the weights of the layer. * * @param trainableOnly Whether to get the values of only trainable weights. 
* @returns Weight values as an `Array` of `tf.Tensor`s. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.getWeights = function (trainableOnly) { if (trainableOnly === void 0) { trainableOnly = false; } return batchGetValue(trainableOnly ? this.trainableWeights : this.weights); }; /** * Sets the weights of the layer, from Tensors. * * @param weights A list of Tensors. The number of tensors and their shapes * must match the weights of the layer (i.e. * it should match the output of `getWeights`). * * @exception ValueError If the provided weights list does not match the * layer's specifications. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.setWeights = function (weights) { var _this = this; tfc.tidy(function () { var params = _this.weights; if (params.length !== weights.length) { // TODO(cais): Restore the following and use `providedWeights`, instead // of `weights` in the error message, once the deeplearn.js bug is // fixed: https://github.com/PAIR-code/deeplearnjs/issues/498 const // providedWeights = JSON.stringify(weights).slice(0, 50); throw new ValueError("You called setWeights(weights) on layer \"".concat(_this.name, "\" ") + "with a weight list of length ".concat(weights.length, ", ") + "but the layer was expecting ".concat(params.length, " weights. ") + "Provided weights: ".concat(weights, "...")); } if (params.length === 0) { return; } var weightValueTuples = []; var paramValues = batchGetValue(params); for (var i = 0; i < paramValues.length; ++i) { var pv = paramValues[i]; var p = params[i]; var w = weights[i]; if (!tfc.util.arraysEqual(pv.shape, w.shape)) { throw new ValueError("Layer weight shape ".concat(pv.shape, " ") + "not compatible with provided weight shape ".concat(w.shape)); } weightValueTuples.push([p, w]); } batchSetValue(weightValueTuples); }); }; /** * Adds a weight variable to the layer. * * @param name Name of the new weight variable. * @param shape The shape of the weight. * @param dtype The dtype of the weight. * @param initializer An initializer instance. * @param regularizer A regularizer instance. * @param trainable Whether the weight should be trained via backprop or not * (assuming that the layer itself is also trainable). * @param constraint An optional constraint to apply to the weight. * @return The created weight variable. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.addWeight = function (name, shape, dtype, initializer, regularizer, trainable, constraint, getInitializerFunc) { // Reject duplicate weight names. if (this._addedWeightNames.indexOf(name) !== -1) { throw new ValueError("Duplicate weight name ".concat(name, " for layer ").concat(this.name)); } this._addedWeightNames.push(name); if (dtype == null) { dtype = 'float32'; } if (this.fastWeightInitDuringBuild) { initializer = getInitializerFunc != null ? getInitializerFunc() : getInitializer('zeros'); } var initValue = initializer.apply(shape, dtype); var weight = new LayerVariable(initValue, dtype, name, trainable, constraint); initValue.dispose(); // Request backend not to dispose the weights of the model on scope() exit. if (regularizer != null) { this.addLoss(function () { return regularizer.apply(weight.read()); }); } if (trainable == null) { trainable = true; } if (trainable) { this._trainableWeights.push(weight); } else { this._nonTrainableWeights.push(weight); } return weight; }; /** * Set the fast-weight-initialization flag. 
* * In cases where the initialized weight values will be immediately * overwritten by loaded weight values during model loading, setting * the flag to `true` saves unnecessary calls to potentially expensive * initializers and speeds up the loading process. * * @param value Target value of the flag. */ Layer.prototype.setFastWeightInitDuringBuild = function (value) { this.fastWeightInitDuringBuild = value; }; /** * Add losses to the layer. * * The loss may potentially be conditional on some input tensors, * for instance activity losses are conditional on the layer's inputs. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.addLoss = function (losses) { var _a; if (losses == null || Array.isArray(losses) && losses.length === 0) { return; } // Update this.losses losses = toList(losses); if (this._losses !== undefined && this._losses !== null) { (_a = this.losses).push.apply(_a, __spreadArray([], __read(losses), false)); } }; /** * Computes the output shape of the layer. * * Assumes that the layer will be built to match the input shape provided. * * @param inputShape A shape (tuple of integers) or a list of shape tuples * (one per output tensor of the layer). Shape tuples can include null for * free dimensions, instead of an integer. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.computeOutputShape = function (inputShape) { return inputShape; }; /** * Computes an output mask tensor. * * @param inputs Tensor or list of tensors. * @param mask Tensor or list of tensors. * * @return null or a tensor (or list of tensors, one per output tensor of the * layer). */ Layer.prototype.computeMask = function (inputs, mask) { var _this = this; if (!this.supportsMasking) { if (mask != null) { if (Array.isArray(mask)) { mask.forEach(function (maskElement) { if (maskElement != null) { throw new TypeError("Layer ".concat(_this.name, " does not support masking, ") + 'but was passed an inputMask.'); } }); } else { throw new TypeError("Layer ".concat(this.name, " does not support masking, ") + 'but was passed an inputMask.'); } } // masking not explicitly supported: return null as mask return null; } // if masking is explicitly supported, by default // carry over the input mask return mask; }; Layer.prototype.setMaskMetadata = function (inputs, outputs, previousMask) { if (!this.supportsMasking) { return; } var outputMasks = this.computeMask(inputs, previousMask); var outputsList = toList(outputs); var outputMasksList = toList(outputMasks); if (outputsList.length !== outputMasksList.length) { throw new Error("".concat(this.name, " outputs ").concat(outputsList.length, " tensors ") + "but ".concat(outputMasksList.length, " masks for those tensors")); } for (var i = 0; i < outputsList.length; i++) { outputsList[i].kerasMask = outputMasksList[i]; } }; /** * Internal method to create an inbound node for the layer. * * @param inputTensors List of input tensors. * @param outputTensors List of output tensors. * @param inputMasks List of input masks (a mask can be a tensor, or null). * @param outputMasks List of output masks (a mask can be a tensor, or null). * @param inputShapes List of input shape tuples. * @param outputShapes List of output shape tuples. * @param kwargs Dictionary of keyword arguments that were passed to the * `call` method of the layer at the call that created the node. 
*/ Layer.prototype.addInboundNode = function (inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs) { var e_6, _a; if (kwargs === void 0) { kwargs = null; } var inputTensorList = toList(inputTensors); outputTensors = toList(outputTensors); inputMasks = toList(inputMasks); outputMasks = toList(outputMasks); inputShapes = normalizeShapeList(inputShapes); outputShapes = normalizeShapeList(outputShapes); // Collect input tensor(s) coordinates. var inboundLayers = []; var nodeIndices = []; var tensorIndices = []; try { for (var inputTensorList_1 = __values(inputTensorList), inputTensorList_1_1 = inputTensorList_1.next(); !inputTensorList_1_1.done; inputTensorList_1_1 = inputTensorList_1.next()) { var x = inputTensorList_1_1.value; /* * TODO(michaelterry): Keras adds this value to tensors; it's not * clear whether we'll use this or not. */ inboundLayers.push(x.sourceLayer); nodeIndices.push(x.nodeIndex); tensorIndices.push(x.tensorIndex); } } catch (e_6_1) { e_6 = { error: e_6_1 }; } finally { try { if (inputTensorList_1_1 && !inputTensorList_1_1.done && (_a = inputTensorList_1.return)) _a.call(inputTensorList_1); } finally { if (e_6) throw e_6.error; } } // Create node, add it to inbound nodes. // (This call has side effects.) // tslint:disable-next-line:no-unused-expression new Node({ outboundLayer: this, inboundLayers: inboundLayers, nodeIndices: nodeIndices, tensorIndices: tensorIndices, inputTensors: inputTensorList, outputTensors: outputTensors, inputMasks: inputMasks, outputMasks: outputMasks, inputShapes: inputShapes, outputShapes: outputShapes }, kwargs); // Update tensor history for (var i = 0; i < outputTensors.length; i++) { // TODO(michaelterry): _uses_learning_phase not tracked. outputTensors[i].sourceLayer = this; outputTensors[i].nodeIndex = this.inboundNodes.length - 1; outputTensors[i].tensorIndex = i; } }; /** * Returns the config of the layer. * * A layer config is a TS dictionary (serializable) * containing the configuration of a layer. * The same layer can be reinstantiated later * (without its trained weights) from this configuration. * * The config of a layer does not include connectivity * information, nor the layer class name. These are handled * by 'Container' (one layer of abstraction above). * * Porting Note: The TS dictionary follows TS naming standards for * keys, and uses tfjs-layers type-safe Enums. Serialization methods * should use a helper function to convert to the pythonic storage * standard. (see serialization_utils.convertTsToPythonic) * * @returns TS dictionary of configuration. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.getConfig = function () { var config = { name: this.name, trainable: this.trainable }; if (this.batchInputShape != null) { config['batchInputShape'] = this.batchInputShape; } if (this.dtype != null) { config['dtype'] = this.dtype; } return config; }; /** * Dispose the weight variables that this Layer instance holds. * * @returns {number} Number of disposed variables. */ Layer.prototype.disposeWeights = function () { this.weights.forEach(function (weight) { return weight.dispose(); }); return this.weights.length; }; Layer.prototype.assertNotDisposed = function () { if (this._refCount === 0) { throw new Error("Layer '".concat(this.name, "' is already disposed.")); } }; /** * Attempt to dispose layer's weights. * * This method decreases the reference count of the Layer object by 1. * * A Layer is reference-counted. 
Its reference count is incremented by 1 * the first time its `apply()` method is called and when it becomes a part * of a new `Node` (through calling the `apply()` method on a * `tf.SymbolicTensor`). * * If the reference count of a Layer becomes 0, all the weights will be * disposed and the underlying memory (e.g., the textures allocated in WebGL) * will be freed. * * Note: If the reference count is greater than 0 after the decrement, the * weights of the Layer will *not* be disposed. * * After a Layer is disposed, it cannot be used in calls such as `apply()`, * `getWeights()` or `setWeights()` anymore. * * @returns A DisposeResult Object with the following fields: * - refCountAfterDispose: The reference count of the Container after this * `dispose()` call. * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed * during this `dispose()` call. * @throws {Error} If the layer is not built yet, or if the layer has already * been disposed. * * @doc {heading: 'Models', 'subheading': 'Classes'} */ Layer.prototype.dispose = function () { if (!this.built) { throw new Error("Cannot dispose Layer ".concat(this.name, " because it has not been ") + "built yet."); } if (this._refCount === null) { throw new Error("Cannot dispose Layer ".concat(this.name, " because it has not been used ") + "yet."); } this.assertNotDisposed(); var numDisposedVariables = 0; if (--this._refCount === 0) { numDisposedVariables = this.disposeWeights(); } return { refCountAfterDispose: this._refCount, numDisposedVariables: numDisposedVariables }; }; return Layer; }(tfc.serialization.Serializable)); /** * Collects the input shape(s) of a list of `tf.Tensor`s or * `tf.SymbolicTensor`s. * * TODO(michaelterry): Update PyKeras docs (backport). * * @param inputTensors List of input tensors (or single input tensor). * * @return List of shape tuples (or single tuple), one tuple per input. */ function collectInputShape(inputTensors) { var e_7, _a; inputTensors = toList(inputTensors); var shapes = []; try { for (var inputTensors_1 = __values(inputTensors), inputTensors_1_1 = inputTensors_1.next(); !inputTensors_1_1.done; inputTensors_1_1 = inputTensors_1.next()) { var x = inputTensors_1_1.value; shapes.push(x.shape); } } catch (e_7_1) { e_7 = { error: e_7_1 }; } finally { try { if (inputTensors_1_1 && !inputTensors_1_1.done && (_a = inputTensors_1.return)) _a.call(inputTensors_1); } finally { if (e_7) throw e_7.error; } } return singletonOrArray(shapes); } /** * Guesses output dtype based on inputs. * * At present, just returns 'float32' for any input. * * @param inputTensors List of input tensors (or single input tensor). * * @return The guessed DType. At present, always returns 'float32'. */ function guessOutputDType(inputTensors) { return 'float32'; } /** * Returns the list of input tensors necessary to compute `tensor`. * * Output will always be a list of tensors (potentially with 1 element). * * @param tensor The tensor to start from. * @param layer Origin layer of the tensor. * @param nodeIndex Origin node index of the tensor. * * @return Array of input tensors. 
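 *
 * A small sketch (assuming the `tf` global):
 *
 * ```js
 * const x = tf.input({shape: [3]});
 * const y = tf.layers.dense({units: 1}).apply(x);
 * // Tracing y back through the dense layer recovers the input placeholder.
 * console.log(getSourceInputs(y)[0] === x);  // true
 * ```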
*/ function getSourceInputs(tensor, layer, nodeIndex) { var e_8, _a; if (layer == null || (nodeIndex != null && nodeIndex > 0)) { layer = tensor.sourceLayer; nodeIndex = tensor.nodeIndex; } if (layer.inboundNodes.length === 0) { return [tensor]; } else { var node = layer.inboundNodes[nodeIndex]; if (node.inboundLayers.length === 0) { return node.inputTensors; } else { var sourceTensors = []; for (var i = 0; i < node.inboundLayers.length; i++) { var x = node.inputTensors[i]; var layer_1 = node.inboundLayers[i]; var nodeIndex_1 = node.nodeIndices[i]; var previousSources = getSourceInputs(x, layer_1, nodeIndex_1); try { // Avoid input redundancy. for (var previousSources_1 = (e_8 = void 0, __values(previousSources)), previousSources_1_1 = previousSources_1.next(); !previousSources_1_1.done; previousSources_1_1 = previousSources_1.next()) { var x_1 = previousSources_1_1.value; if (sourceTensors.indexOf(x_1) === -1) { sourceTensors.push(x_1); } } } catch (e_8_1) { e_8 = { error: e_8_1 }; } finally { try { if (previousSources_1_1 && !previousSources_1_1.done && (_a = previousSources_1.return)) _a.call(previousSources_1); } finally { if (e_8) throw e_8.error; } } } return sourceTensors; } } } function checkAllSymbolic(tensors) { var e_9, _a; var allAreSymbolic = true; try { for (var _b = __values(toList(tensors)), _c = _b.next(); !_c.done; _c = _b.next()) { var tensor = _c.value; if (!(tensor instanceof SymbolicTensor)) { allAreSymbolic = false; break; } } } catch (e_9_1) { e_9 = { error: e_9_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_9) throw e_9.error; } } return allAreSymbolic; } function checkNoneSymbolic(tensors) { var e_10, _a; var noneAreSymbolic = true; try { for (var _b = __values(toList(tensors)), _c = _b.next(); !_c.done; _c = _b.next()) { var tensor = _c.value; if (tensor instanceof SymbolicTensor) { noneAreSymbolic = false; break; } } } catch (e_10_1) { e_10 = { error: e_10_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_10) throw e_10.error; } } return noneAreSymbolic; } var InputLayer = /** @class */ (function (_super) { __extends(InputLayer, _super); function InputLayer(args) { var _this = _super.call(this, { dtype: args.dtype, name: args.name != null ? args.name : getUid('input').toString() }) || this; // Normalize config.batchSize and config.sparse if (args.batchSize == null) { args.batchSize = null; } if (args.sparse == null) { args.sparse = false; } _this.trainable = false; _this.built = true; _this.sparse = args.sparse; if (args.inputShape != null && args.batchInputShape != null) { throw new ValueError('Only provide the inputShape OR ' + 'batchInputShape argument to inputLayer, not both at the same time.'); } var batchInputShape = args.batchInputShape; if (batchInputShape == null) { if (args.inputShape == null) { throw new ValueError('An InputLayer should be passed either a ' + '`batchInputShape` or an `inputShape`.'); } else { batchInputShape = [args.batchSize].concat(args.inputShape); } } else { // TODO(michaelterry): Backport to PyKeras if (args.batchSize != null) { throw new ValueError('Cannot specify batchSize if batchInputShape is ' + 'specified when creating an InputLayer.'); } } var dtype = args.dtype || 'float32'; _this.batchInputShape = batchInputShape; _this.dtype = dtype; // TODO(michaelterry): Backport this to PyKeras? 
_this.inputSpec = [{ shape: batchInputShape }]; var inputTensor = new SymbolicTensor(_this.dtype, _this.batchInputShape, _this, [], {}, _this.name); inputTensor.nodeIndex = 0; inputTensor.tensorIndex = 0; // Create an input node to add to this.outboundNode. // (This call has side effects.) // tslint:disable-next-line:no-unused-expression new Node({ outboundLayer: _this, inboundLayers: [], nodeIndices: [], tensorIndices: [], inputTensors: [inputTensor], outputTensors: [inputTensor], inputMasks: [null], outputMasks: [null], inputShapes: [batchInputShape], outputShapes: [batchInputShape] }); return _this; } InputLayer.prototype.apply = function (inputs, kwargs) { throw new ValueError('Cannot pass any input to an ' + "InputLayer's apply() method. InputLayer name: ".concat(this.name)); }; InputLayer.prototype.dispose = function () { // dispose() for InputLayer is overridden as no-op. return { refCountAfterDispose: this._refCount, numDisposedVariables: 0 }; }; InputLayer.prototype.getConfig = function () { return { batchInputShape: this.batchInputShape, dtype: this.dtype, sparse: this.sparse, name: this.name }; }; return InputLayer; }(Layer)); /** @nocollapse */ InputLayer.className = 'InputLayer'; tfc.serialization.registerClass(InputLayer); function Input(config) { if (config.batchShape == null && config.shape == null) { throw new Error('Please provide to Input either a `shape`' + ' or a `batchShape` argument. Note that ' + '`shape` does not include the batch ' + 'dimension.'); } if (config.batchShape != null && config.shape != null) { // TODO(michaelterry): Backport to PyKeras. throw new ValueError('Please provide either a `shape` or `batchShape` ' + 'argument to Input, but not both.'); } var batchShape = config.batchShape; if (config.shape != null && batchShape == null) { batchShape = [null].concat(config.shape); } var dtype = config.dtype; if (dtype == null) { dtype = 'float32'; } var inputLayer = new InputLayer({ batchInputShape: batchShape, name: config.name, dtype: dtype, sparse: config.sparse }); var outputs = inputLayer.inboundNodes[0].outputTensors; return outputs[0]; } /** * Helper function to check the dtype compatibility of a feed value. */ function assertFeedCompatibility(key, val) { // Check dtype compatibility. if (key.dtype == null || key.dtype === val.dtype) { // a. If types match, return val tensor as is. return val; } try { // b. Attempt to convert to expected type. return tfc.cast(val, key.dtype); } catch (err) { // c. If conversion fails, return helpful error. throw new ValueError("The dtype of the feed (".concat(val.dtype, ") can not be cast to the dtype ") + "of the key '".concat(key.name, "' (").concat(key.dtype, ").")); } } /** * FeedDict: A mapping from unique SymbolicTensors to feed values for them. * A feed value is a concrete value represented as a `Tensor`. */ var FeedDict = /** @class */ (function () { /** * Constructor, optionally does copy-construction. * @param feeds An Array of `Feed`s, or another `FeedDict`, in which case * copy-construction will be performed.
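 *
 * A usage sketch (assuming `sym` is a `SymbolicTensor`, e.g. the return
 * value of `Input(...)` above, with a dtype matching the fed tensor):
 *
 * ```js
 * const feeds = new FeedDict([{key: sym, value: tfc.ones([1, 4])}]);
 * feeds.hasKey(sym); // true
 * const copy = new FeedDict(feeds); // copy-construction
 * copy.getValue(sym); // the tfc.ones([1, 4]) tensor
 * ```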
*/ function FeedDict(feeds) { var e_1, _a; this.id2Value = {}; this.id2Mask = {}; this.name2Id = {}; if (feeds instanceof FeedDict) { for (var id in feeds.id2Value) { this.id2Value[id] = feeds.id2Value[id]; if (id in feeds.id2Mask) { this.id2Mask[id] = feeds.id2Mask[id]; } } } else { if (feeds == null) { return; } try { for (var feeds_1 = __values(feeds), feeds_1_1 = feeds_1.next(); !feeds_1_1.done; feeds_1_1 = feeds_1.next()) { var feed = feeds_1_1.value; this.add(feed.key, feed.value); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (feeds_1_1 && !feeds_1_1.done && (_a = feeds_1.return)) _a.call(feeds_1); } finally { if (e_1) throw e_1.error; } } } } /** * Add a key-value pair to the FeedDict. * * @param key The key of the feed. * @param value The value of the tensor feed. * @param mask The value of the mask feed (optional). * @returns This `FeedDict`. * @throws ValueError: If the key `SymbolicTensor` already exists in the * `FeedDict`. */ FeedDict.prototype.add = function (key, value, mask) { if (this.id2Value[key.id] == null) { this.id2Value[key.id] = assertFeedCompatibility(key, value); this.name2Id[key.name] = key.id; if (mask != null) { this.id2Mask[key.id] = mask; } } else { throw new ValueError("Duplicate key: name=".concat(key.name, ", id=").concat(key.id)); } return this; }; /** * Add a Feed to the FeedDict. * @param feed The new `Feed` to add. * @returns This `FeedDict`. */ FeedDict.prototype.addFeed = function (feed) { this.add(feed.key, feed.value); }; /** * Probe whether a key already exists in the FeedDict. * @param key */ FeedDict.prototype.hasKey = function (key) { return this.id2Value[key.id] != null; }; /** * Get the names of all SymbolicTensors available in this FeedDict. */ FeedDict.prototype.names = function () { return Object.keys(this.name2Id); }; /** * Get the feed value for given key. * @param key The SymbolicTensor, or its name (as a string), of which the * value is sought. * @returns If `key` exists, the corresponding feed value. * @throws ValueError: If `key` does not exist in this `FeedDict`. */ FeedDict.prototype.getValue = function (key) { if (key instanceof SymbolicTensor) { if (this.id2Value[key.id] == null) { throw new ValueError("Nonexistent key: ".concat(key.name)); } else { return this.id2Value[key.id]; } } else { var id = this.name2Id[key]; if (id == null) { throw new ValueError("Feed dict has no SymbolicTensor name: ".concat(key)); } return this.id2Value[id]; } }; /** * Get the feed mask for given key. * @param key The SymbolicTensor, or its name (as a string), of which the * value is sought. * @returns If `key` exists, the corresponding feed mask. * @throws ValueError: If `key` does not exist in this `FeedDict`. */ FeedDict.prototype.getMask = function (key) { if (key instanceof SymbolicTensor) { if (this.id2Value[key.id] == null) { throw new ValueError("Nonexistent key: ".concat(key.name)); } else { return this.id2Mask[key.id]; } } else { var id = this.name2Id[key]; if (id == null) { throw new ValueError("Feed dict has no SymbolicTensor name: ".concat(key)); } return this.id2Mask[id]; } }; /** Dispose all mask Tensors held by this object. */ FeedDict.prototype.disposeMasks = function () { if (this.id2Mask != null) { tfc.dispose(this.id2Mask); } }; return FeedDict; }()); // Cache for topologically sorted SymbolicTensors for given execution // targets (i.e., fetches). var cachedSorted = new LruCache(); // Cache for recipient count maps for given execution targets (i.e., fetches).
var cachedRecipientCounts = new LruCache(); function updateCacheMaxEntries(maxEntries) { if (cachedSorted != null) { cachedSorted.setMaxEntries(maxEntries); } if (cachedRecipientCounts != null) { cachedRecipientCounts.setMaxEntries(maxEntries); } } /** * Execute a SymbolicTensor by using concrete feed values. * * A `SymbolicTensor` object is a node in a computation graph of TF.js * Layers. The object is backed by a source layer and input * `SymbolicTensor`s to the source layer. This method evaluates * the `call()` method of the source layer, using concrete values of the * inputs obtained from either * * `feedDict`, if the input key exists in `feedDict`, or else, * * a recursive call to `execute()` itself. * * @param fetches: The `SymbolicTensor`(s) to execute. * @param feedDict: The feed values, as the base condition of the recursion. * @param kwargs: Optional keyword arguments. * @param probe: A probe object (of interface `ExecutionProbe`) used for * testing memory footprint of `execute` calls. * @returns Result of the execution. * @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s * encountered during the execution lacks a feed value in `feedDict`. */ function execute(fetches, feedDict, kwargs, probe) { var e_2, _a, e_3, _b; var training = kwargs == null ? false : kwargs['training']; var arrayFetches = Array.isArray(fetches); var fetchArray = arrayFetches ? fetches : [fetches]; var outputNames = fetchArray.map(function (t) { return t.name; }); var finalOutputs = []; var feedNames = feedDict.names(); try { for (var outputNames_1 = __values(outputNames), outputNames_1_1 = outputNames_1.next(); !outputNames_1_1.done; outputNames_1_1 = outputNames_1.next()) { var outputName = outputNames_1_1.value; if (feedNames.indexOf(outputName) !== -1) { finalOutputs.push(feedDict.getValue(outputName)); } else { finalOutputs.push(null); } } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (outputNames_1_1 && !outputNames_1_1.done && (_a = outputNames_1.return)) _a.call(outputNames_1); } finally { if (e_2) throw e_2.error; } } if (probe != null) { // For optional probing of memory footprint during execution. probe.maxNumTensors = -Infinity; probe.minNumTensors = Infinity; } // Check cache. var fetchAndFeedKey = outputNames.join(',') + '|' + feedDict.names().sort().join(','); var sorted = cachedSorted.get(fetchAndFeedKey); var recipientCounts; if (sorted == null) { // Cache doesn't contain the desired combination of fetches. Compute // topological sort for the combination for the first time. var out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict); sorted = out.sorted; recipientCounts = out.recipientCounts; // Store results in cache for future use. cachedSorted.put(fetchAndFeedKey, sorted); cachedRecipientCounts.put(fetchAndFeedKey, recipientCounts); } recipientCounts = {}; if (!training) { Object.assign(recipientCounts, cachedRecipientCounts.get(fetchAndFeedKey)); } var internalFeedDict = new FeedDict(feedDict); // Start iterative execution on the topologically-sorted SymbolicTensors. for (var i = 0; i < sorted.length; ++i) { if (probe != null) { // For optional probing of memory usage during execution.
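// (`tfc.memory().numTensors` is the number of tensors currently live on
// the backend; the probe records its high- and low-water marks across the
// loop so tests can detect leaked intermediate tensors.)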
var numTensors = tfc.memory().numTensors; if (numTensors > probe.maxNumTensors) { probe.maxNumTensors = numTensors; } if (numTensors < probe.minNumTensors) { probe.minNumTensors = numTensors; } } var symbolic = sorted[i]; var srcLayer = symbolic.sourceLayer; if (srcLayer instanceof InputLayer) { continue; } var inputValues = []; var inputMasks = []; var tensorsToDispose = []; var maskExists = false; try { for (var _c = (e_3 = void 0, __values(symbolic.inputs)), _d = _c.next(); !_d.done; _d = _c.next()) { var input = _d.value; var value = internalFeedDict.getValue(input); var mask = internalFeedDict.getMask(input); inputValues.push(value); inputMasks.push(mask); if (mask != null) { maskExists = true; } if (!training) { recipientCounts[input.name]--; if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) && outputNames.indexOf(input.name) === -1 && !value.isDisposed && input.sourceLayer.stateful !== true) { tensorsToDispose.push(value); } } } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (_d && !_d.done && (_b = _c.return)) _b.call(_c); } finally { if (e_3) throw e_3.error; } } if (maskExists) { kwargs = kwargs || {}; kwargs['mask'] = inputMasks[0]; } var outputTensors = toList(srcLayer.apply(inputValues, kwargs)); var outputMask = null; if (srcLayer.supportsMasking) { outputMask = srcLayer.computeMask(inputValues, inputMasks); } var layerOutputs = getNodeOutputs(symbolic); var outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs]; for (var i_1 = 0; i_1 < outputSymbolicTensors.length; ++i_1) { if (!internalFeedDict.hasKey(outputSymbolicTensors[i_1])) { internalFeedDict.add(outputSymbolicTensors[i_1], outputTensors[i_1], Array.isArray(outputMask) ? outputMask[0] : outputMask); } var index = outputNames.indexOf(outputSymbolicTensors[i_1].name); if (index !== -1) { finalOutputs[index] = outputTensors[i_1]; } } if (!training) { // Clean up Tensors that are no longer needed. tfc.dispose(tensorsToDispose); } } // NOTE(cais): Unlike intermediate tensors, we don't discard mask // tensors as we go, because these tensors are sometimes passed over a // series of multiple layers, i.e., not obeying the immediate input // relations in the graph. If this becomes a memory-usage concern, // we can improve this in the future. internalFeedDict.disposeMasks(); return arrayFetches ? finalOutputs : finalOutputs[0]; } /** * Sort the `SymbolicTensor`s topologically, for an array of fetches. * * This function calls getTopologicalSortAndRecipientCountsForOneFetch and * merges their results. * * @param fetches The array of fetches requested. Must be a non-empty array. * @param feedDict The dictionary of fed values. * @returns sorted: Topologically-sorted array of SymbolicTensors. * recipientCounts: Recipient counts for all SymbolicTensors in `sorted`. */ function getTopologicalSortAndRecipientCounts(fetches, feedDict) { var e_4, _a, e_5, _b; tfc.util.assert(fetches != null && fetches.length > 0, function () { return "Expected at least one fetch, got none"; }); var finalSorted = []; var finalRecipientMap = {}; if (fetches.length === 1) { // Special-casing 1 fetch for efficiency.
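// (A single fetch is the common case, e.g. executing a model with one
// output tensor; it skips the set-based merging in the else-branch below.)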
var out = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict); finalSorted = out.sorted; finalRecipientMap = out.recipientMap; } else { var visited = new Set(); try { for (var fetches_1 = __values(fetches), fetches_1_1 = fetches_1.next(); !fetches_1_1.done; fetches_1_1 = fetches_1.next()) { var fetch = fetches_1_1.value; var _c = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict), sorted = _c.sorted, recipientMap = _c.recipientMap; try { // Merge sorted SymbolicTensor Arrays. for (var sorted_1 = (e_5 = void 0, __values(sorted)), sorted_1_1 = sorted_1.next(); !sorted_1_1.done; sorted_1_1 = sorted_1.next()) { var symbolicTensor = sorted_1_1.value; if (!visited.has(symbolicTensor.name)) { finalSorted.push(symbolicTensor); visited.add(symbolicTensor.name); } } } catch (e_5_1) { e_5 = { error: e_5_1 }; } finally { try { if (sorted_1_1 && !sorted_1_1.done && (_b = sorted_1.return)) _b.call(sorted_1); } finally { if (e_5) throw e_5.error; } } var _loop_1 = function (name) { if (finalRecipientMap[name] == null) { finalRecipientMap[name] = new Set(); } recipientMap[name].forEach(function (recipient) { return finalRecipientMap[name].add(recipient); }); }; // Merge recipient maps. for (var name in recipientMap) { _loop_1(name); } } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (fetches_1_1 && !fetches_1_1.done && (_a = fetches_1.return)) _a.call(fetches_1); } finally { if (e_4) throw e_4.error; } } } return { sorted: finalSorted, recipientCounts: recipientMap2Counts(finalRecipientMap) }; } function recipientMap2Counts(recipientMap) { var recipientCounts = {}; for (var name in recipientMap) { recipientCounts[name] = recipientMap[name].size; } return recipientCounts; } /** * Sort the `SymbolicTensor`s topologically, for a single fetch. * * This helper function processes the upstream SymbolicTensors of a single * fetch. * * @param fetch The single fetch requested. * @param feedDict The dictionary of fed values. * @returns sorted: Topologically-sorted array of SymbolicTensors. * recipientMap: Recipient names for all SymbolicTensors in `sorted`. */ function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) { var e_6, _a, e_7, _b; var visited = new Set(); var sorted = []; var recipientMap = {}; try { // Put keys of the feedDict into visited first, so they don't have to be // walked. This is needed in case where there are feeds for intermediate // SymbolicTensors of the graph. for (var _c = __values(feedDict.names()), _d = _c.next(); !_d.done; _d = _c.next()) { var key = _d.value; visited.add(key); } } catch (e_6_1) { e_6 = { error: e_6_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_6) throw e_6.error; } } var stack = []; var marks = []; // Initial population of stack and marks. stack.push(fetch); while (stack.length > 0) { var top = stack[stack.length - 1]; if (visited.has(top.name)) { stack.pop(); continue; } var topIsMarked = marks[marks.length - 1] === stack.length - 1; if (top.inputs.length === 0 || topIsMarked) { // Input SymbolicTensor or all children have been visited. stack.pop(); sorted.push(top); visited.add(top.name); if (topIsMarked) { marks.pop(); } } else { // A non-input SymbolicTensor whose upstream SymbolicTensors haven't // been visited yet. Push them onto the stack. marks.push(stack.length - 1); try { for (var _e = (e_7 = void 0, __values(top.inputs)), _f = _e.next(); !_f.done; _f = _e.next()) { var input = _f.value; // Increment the recipient count. 
Note that this needs to happen // regardless of whether the SymbolicTensor has been visited before. if (recipientMap[input.name] == null) { recipientMap[input.name] = new Set(); } recipientMap[input.name].add(top.name); if (visited.has(input.name)) { continue; // Avoid repeated visits to the same SymbolicTensor. } stack.push(input); } } catch (e_7_1) { e_7 = { error: e_7_1 }; } finally { try { if (_f && !_f.done && (_b = _e.return)) _b.call(_e); } finally { if (e_7) throw e_7.error; } } } } return { sorted: sorted, recipientMap: recipientMap }; } /** * Get the symbolic output tensors of the node to which a given fetch belongs. * @param fetch The fetched symbolic tensor. * @returns The Array of symbolic tensors output by the node to which `fetch` * belongs. */ function getNodeOutputs(fetch) { var e_8, _a; var layerOutputs; if (fetch.sourceLayer.inboundNodes.length === 1) { layerOutputs = fetch.sourceLayer.output; } else { var nodeIndex = null; for (var i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) { try { for (var _b = (e_8 = void 0, __values(fetch.sourceLayer.inboundNodes[i] .outputTensors)), _c = _b.next(); !_c.done; _c = _b.next()) { var outputTensor = _c.value; if (outputTensor.id === fetch.id) { nodeIndex = i; break; } } } catch (e_8_1) { e_8 = { error: e_8_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_8) throw e_8.error; } } } layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex); } return layerOutputs; } /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ENV$2 = tfc.env(); /** The max number of entries for the caches of layers' topological sort. 
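 *
 * Defaults to 100. It can be overridden at runtime through the public
 * environment API, e.g. (a sketch):
 *
 * ```js
 * tf.env().set('TOPOLOGICAL_SORT_CACHE_MAX_ENTRIES', 32);
 * ```
 *
 * The `updateCacheMaxEntries` hook registered below propagates the new
 * value to both LRU caches.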
*/ ENV$2.registerFlag('TOPOLOGICAL_SORT_CACHE_MAX_ENTRIES', function () { return 100; }, updateCacheMaxEntries); var Abs = 'Abs'; var Acos = 'Acos'; var Acosh = 'Acosh'; var Add$1 = 'Add'; var AddN = 'AddN'; var ArgMax = 'ArgMax'; var ArgMin = 'ArgMin'; var Asin = 'Asin'; var Asinh = 'Asinh'; var Atan = 'Atan'; var Atanh = 'Atanh'; var Atan2 = 'Atan2'; var AvgPool = 'AvgPool'; var AvgPoolGrad = 'AvgPoolGrad'; var AvgPool3D = 'AvgPool3D'; var AvgPool3DGrad = 'AvgPool3DGrad'; var BatchMatMul = 'BatchMatMul'; var BatchToSpaceND = 'BatchToSpaceND'; var BroadcastTo = 'BroadcastTo'; var Cast = 'Cast'; var Ceil = 'Ceil'; var ClipByValue = 'ClipByValue'; var Complex = 'Complex'; var ComplexAbs = 'ComplexAbs'; var Concat = 'Concat'; var Conv2D$1 = 'Conv2D'; var Conv2DBackpropFilter = 'Conv2DBackpropFilter'; var Conv2DBackpropInput = 'Conv2DBackpropInput'; var Conv3D$1 = 'Conv3D'; var Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2'; var Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2'; var Cos = 'Cos'; var Cosh = 'Cosh'; var Cumprod = 'Cumprod'; var Cumsum = 'Cumsum'; var DepthwiseConv2dNative = 'DepthwiseConv2dNative'; var DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter'; var DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput'; var Dilation2D = 'Dilation2D'; var Dilation2DBackpropInput = 'Dilation2DBackpropInput'; var Dilation2DBackpropFilter = 'Dilation2DBackpropFilter'; var RealDiv = 'RealDiv'; var Elu$1 = 'Elu'; var EluGrad = 'EluGrad'; var Erf = 'Erf'; var Equal = 'Equal'; var Exp = 'Exp'; var ExpandDims = 'ExpandDims'; var Expm1 = 'Expm1'; var Fill = 'Fill'; var Floor = 'Floor'; var FloorDiv = 'FloorDiv'; var FusedBatchNorm = 'FusedBatchNorm'; var GatherV2 = 'GatherV2'; var Greater = 'Greater'; var GreaterEqual = 'GreaterEqual'; var Identity = 'Identity'; var Imag = 'Imag'; var IsFinite = 'IsFinite'; var IsInf = 'IsInf'; var IsNan = 'IsNan'; var LeakyRelu = 'LeakyRelu'; var Less = 'Less'; var LessEqual = 'LessEqual'; var Log = 'Log'; var Log1p = 'Log1p'; var LogicalAnd = 'LogicalAnd'; var LogicalNot = 'LogicalNot'; var LogSoftmax$1 = 'LogSoftmax'; var LRN = 'LRN'; var LRNGrad = 'LRNGrad'; var Max = 'Max'; var Maximum$1 = 'Maximum'; var MaxPool = 'MaxPool'; var MaxPoolGrad = 'MaxPoolGrad'; var MaxPool3D = 'MaxPool3D'; var MaxPool3DGrad = 'MaxPool3DGrad'; var Mean = 'Mean'; var Min = 'Min'; var Minimum$1 = 'Minimum'; var MirrorPad = 'MirrorPad'; var Mod = 'Mod'; var Multiply$1 = 'Multiply'; var Neg = 'Neg'; var OnesLike = 'OnesLike'; var OneHot = 'OneHot'; var Pack = 'Pack'; var PadV2 = 'PadV2'; var Pow = 'Pow'; var Prelu = 'Prelu'; var Prod = 'Prod'; var Real = 'Real'; var Reciprocal = 'Reciprocal'; var Relu$1 = 'Relu'; var Reshape$1 = 'Reshape'; var ResizeNearestNeighbor = 'ResizeNearestNeighbor'; var ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad'; var ResizeBilinear = 'ResizeBilinear'; var ResizeBilinearGrad = 'ResizeBilinearGrad'; var Relu6$1 = 'Relu6'; var Reverse = 'Reverse'; var Round = 'Round'; var Rsqrt = 'Rsqrt'; var TensorScatterUpdate = 'TensorScatterUpdate'; var Select = 'Select'; var Selu$1 = 'Selu'; var Slice = 'Slice'; var Sin = 'Sin'; var Sinh = 'Sinh'; var Sign = 'Sign'; var Sigmoid$1 = 'Sigmoid'; var Softplus$1 = 'Softplus'; var Sqrt = 'Sqrt'; var Sum = 'Sum'; var SpaceToBatchND = 'SpaceToBatchND'; var SplitV = 'SplitV'; var Softmax$2 = 'Softmax'; var SquaredDifference = 'SquaredDifference'; var Square = 'Square'; var Sub = 'Sub'; var Tan = 'Tan'; var Tanh$1 = 'Tanh'; var Tile = 'Tile'; var Transpose = 'Transpose'; var 
Unpack = 'Unpack'; var UnsortedSegmentSum = 'UnsortedSegmentSum'; var ZerosLike = 'ZerosLike'; /** * TensorFlow.js-only kernels */ var Step = 'Step'; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var EPSILON_FLOAT32 = 1e-7; var EPSILON_FLOAT16 = 1e-4; /** * The interface that defines the kernels that should be implemented when * adding a new backend. New backends don't need to implement every one of the * methods; this can be done gradually (throw an error for unimplemented * methods). */ var KernelBackend = /** @class */ (function () { function KernelBackend() { } KernelBackend.prototype.refCount = function (dataId) { return notYetImplemented('refCount'); }; KernelBackend.prototype.incRef = function (dataId) { return notYetImplemented('incRef'); }; KernelBackend.prototype.timerAvailable = function () { return true; }; KernelBackend.prototype.time = function (f) { return notYetImplemented('time'); }; KernelBackend.prototype.read = function (dataId) { return notYetImplemented('read'); }; KernelBackend.prototype.readSync = function (dataId) { return notYetImplemented('readSync'); }; KernelBackend.prototype.readToGPU = function (dataId, options) { return notYetImplemented('readToGPU'); }; KernelBackend.prototype.numDataIds = function () { return notYetImplemented('numDataIds'); }; KernelBackend.prototype.disposeData = function (dataId, force) { return notYetImplemented('disposeData'); }; KernelBackend.prototype.write = function (values, shape, dtype) { return notYetImplemented('write'); }; KernelBackend.prototype.move = function (dataId, values, shape, dtype, refCount) { return notYetImplemented('move'); }; KernelBackend.prototype.createTensorFromGPUData = function (values, shape, dtype) { return notYetImplemented('createTensorFromGPUData'); }; KernelBackend.prototype.memory = function () { return notYetImplemented('memory'); }; /** Returns the highest precision for floats in bits (e.g. 16 or 32) */ KernelBackend.prototype.floatPrecision = function () { return notYetImplemented('floatPrecision'); }; /** Returns the epsilon used for numerical stability, based on the backend's float precision. */ KernelBackend.prototype.epsilon = function () { return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; }; KernelBackend.prototype.dispose = function () { return notYetImplemented('dispose'); }; return KernelBackend; }()); function notYetImplemented(kernelName) { throw new Error("'".concat(kernelName, "' not yet implemented or not found in the registry. ") + "This kernel may not be supported by the tfjs backend you have chosen"); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Asserts that the expression is true. Otherwise throws an error with the * provided message. * * ```js * const x = 2; * tf.util.assert(x === 2, 'x is not 2'); * ``` * * @param expr The expression to assert (as a boolean). * @param msg A function that returns the message to report when throwing an * error. We use a function for performance reasons. * * @doc {heading: 'Util', namespace: 'util'} */ function assert(expr, msg) { if (!expr) { throw new Error(typeof msg === 'string' ? msg : msg()); } } function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) { if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; } assert(arraysEqual(shapeA, shapeB), function () { return errorMessagePrefix + " Shapes ".concat(shapeA, " and ").concat(shapeB, " must match"); }); } /** * Returns the size (number of elements) of the tensor given its shape. * * ```js * const shape = [3, 4, 2]; * const size = tf.util.sizeFromShape(shape); * console.log(size); * ``` * * @doc {heading: 'Util', namespace: 'util'} */ function sizeFromShape(shape) { if (shape.length === 0) { // Scalar. return 1; } var size = shape[0]; for (var i = 1; i < shape.length; i++) { size *= shape[i]; } return size; } function arraysEqual(n1, n2) { if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } if (n1.length !== n2.length) { return false; } for (var i = 0; i < n1.length; i++) { if (n1[i] !== n2[i]) { return false; } } return true; } function isInt(a) { return a % 1 === 0; } function rightPad(a, size) { if (size <= a.length) { return a; } return a + ' '.repeat(size - a.length); } function parseAxisParam(axis, shape) { var rank = shape.length; // Normalize input axis = axis == null ? shape.map(function (s, i) { return i; }) : [].concat(axis); // Check for valid range assert(axis.every(function (ax) { return ax >= -rank && ax < rank; }), function () { return "All values in axis param must be in range [-".concat(rank, ", ").concat(rank, ") but ") + "got axis ".concat(axis); }); // Check for only integers assert(axis.every(function (ax) { return isInt(ax); }), function () { return "All values in axis param must be integers but " + "got axis ".concat(axis); }); // Handle negative axis. return axis.map(function (a) { return a < 0 ? rank + a : a; }); } function getArrayFromDType(dtype, size) { var values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); } else if (dtype === 'int32') { values = new Int32Array(size); } else if (dtype === 'bool') { values = new Uint8Array(size); } else if (dtype === 'string') { values = new Array(size); } else { throw new Error("Unknown data type ".concat(dtype)); } return values; } function checkConversionForErrors(vals, dtype) { for (var i = 0; i < vals.length; i++) { var num = vals[i]; if (isNaN(num) || !isFinite(num)) { throw Error("A tensor of type ".concat(dtype, " being uploaded contains ").concat(num, ".")); } } } /** Returns true if the dtype is valid. 
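 * (Valid dtypes are 'bool', 'complex64', 'float32', 'int32' and 'string'.)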
*/ function isValidDtype(dtype) { return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || dtype === 'int32' || dtype === 'string'; } function bytesPerElement(dtype) { if (dtype === 'float32' || dtype === 'int32') { return 4; } else if (dtype === 'complex64') { return 8; } else if (dtype === 'bool') { return 1; } else { throw new Error("Unknown dtype ".concat(dtype)); } } /** * Returns the approximate number of bytes allocated in the string array, * computed by summing the length of each element. Computing the exact bytes * for a native string in JS is not possible since it depends on the encoding * of the html page that serves the website. */ function bytesFromStringArray(arr) { if (arr == null) { return 0; } var bytes = 0; arr.forEach(function (x) { return bytes += x.length; }); return bytes; } /** Returns true if the value is a string. */ function isString(value) { return typeof value === 'string' || value instanceof String; } function isBoolean(value) { return typeof value === 'boolean'; } function isNumber(value) { return typeof value === 'number'; } function inferDtype(values) { if (Array.isArray(values)) { return inferDtype(values[0]); } if (values instanceof Float32Array) { return 'float32'; } else if (values instanceof Int32Array || values instanceof Uint8Array || values instanceof Uint8ClampedArray) { return 'int32'; } else if (isNumber(values)) { return 'float32'; } else if (isString(values)) { return 'string'; } else if (isBoolean(values)) { return 'bool'; } return 'float32'; } function isFunction(f) { return !!(f && f.constructor && f.call && f.apply); } function computeStrides(shape) { var rank = shape.length; if (rank < 2) { return []; } // Last dimension has implicit stride of 1, thus having D-1 (instead of D) // strides. var strides = new Array(rank - 1); strides[rank - 2] = shape[rank - 1]; for (var i = rank - 3; i >= 0; --i) { strides[i] = strides[i + 1] * shape[i + 1]; } return strides; } function createNestedArray(offset, shape, a, isComplex) { if (isComplex === void 0) { isComplex = false; } var ret = new Array(); if (shape.length === 1) { var d = shape[0] * (isComplex ? 2 : 1); for (var i = 0; i < d; i++) { ret[i] = a[offset + i]; } } else { var d = shape[0]; var rest = shape.slice(1); var len = rest.reduce(function (acc, c) { return acc * c; }) * (isComplex ? 2 : 1); for (var i = 0; i < d; i++) { ret[i] = createNestedArray(offset + i * len, rest, a, isComplex); } } return ret; } // Provide a nested array of TypedArray in given shape. function toNestedArray(shape, a, isComplex) { if (isComplex === void 0) { isComplex = false; } if (shape.length === 0) { // Scalar type should return a single number. return a[0]; } var size = shape.reduce(function (acc, c) { return acc * c; }) * (isComplex ? 2 : 1); if (size === 0) { // A tensor with shape zero should be turned into empty list. return []; } if (size !== a.length) { throw new Error("[".concat(shape, "] does not match the input size ").concat(a.length).concat(isComplex ?
' for a complex tensor' : '', ".")); } return createNestedArray(0, shape, a, isComplex); } function makeOnesTypedArray(size, dtype) { var array = makeZerosTypedArray(size, dtype); for (var i = 0; i < array.length; i++) { array[i] = 1; } return array; } function makeZerosTypedArray(size, dtype) { if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(size); } else if (dtype === 'int32') { return new Int32Array(size); } else if (dtype === 'bool') { return new Uint8Array(size); } else { throw new Error("Unknown data type ".concat(dtype)); } } function assertNonNegativeIntegerDimensions(shape) { shape.forEach(function (dimSize) { assert(Number.isInteger(dimSize) && dimSize >= 0, function () { return "Tensor must have a shape comprised of non-negative integers but got " + "shape [".concat(shape, "]."); }); }); } /** * This method checks whether an object is a Promise instance. * @param object */ // tslint:disable-next-line: no-any function isPromise(object) { // We chose not to use 'obj instanceof Promise' for two reasons: // 1. It only reliably works for es6 Promise, not other Promise // implementations. // 2. It doesn't work with frameworks that use zone.js. zone.js monkey- // patches the async calls, so it is possible the obj (patched) is // being compared to a pre-patched Promise. return object && object.then && typeof object.then === 'function'; } // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true. var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags'; /** * The environment contains evaluated flags as well as the registered platform. * This is always used as a global singleton and can be retrieved with * `tf.env()`. * * @doc {heading: 'Environment'} */ var Environment = /** @class */ (function () { // tslint:disable-next-line: no-any function Environment(global) { this.global = global; this.flags = {}; this.flagRegistry = {}; this.urlFlags = {}; // Jasmine spies on this in 'environment_test.ts' this.getQueryParams = getQueryParams; this.populateURLFlags(); } Environment.prototype.setPlatform = function (platformName, platform) { if (this.platform != null) { if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn("Platform ".concat(this.platformName, " has already been set. ") + "Overwriting the platform with ".concat(platformName, ".")); } } this.platformName = platformName; this.platform = platform; }; Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) { this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook }; // Override the flag value from the URL. This has to happen here because // the environment is initialized before flags get registered.
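// (URL overrides come from a query string of the form
// `?tfjsflags=FLAG1:1,FLAG2:true`, parsed by populateURLFlags() below.)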
if (this.urlFlags[flagName] != null) { var flagValue = this.urlFlags[flagName]; if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn("Setting feature override from URL ".concat(flagName, ": ").concat(flagValue, ".")); } this.set(flagName, flagValue); } }; Environment.prototype.getAsync = function (flagName) { return __awaiter(this, void 0, void 0, function () { var _a, _b; return __generator(this, function (_c) { switch (_c.label) { case 0: if (flagName in this.flags) { return [2 /*return*/, this.flags[flagName]]; } _a = this.flags; _b = flagName; return [4 /*yield*/, this.evaluateFlag(flagName)]; case 1: _a[_b] = _c.sent(); return [2 /*return*/, this.flags[flagName]]; } }); }); }; Environment.prototype.get = function (flagName) { if (flagName in this.flags) { return this.flags[flagName]; } var flagValue = this.evaluateFlag(flagName); if (isPromise(flagValue)) { throw new Error("Flag ".concat(flagName, " cannot be synchronously evaluated. ") + "Please use getAsync() instead."); } this.flags[flagName] = flagValue; return this.flags[flagName]; }; Environment.prototype.getNumber = function (flagName) { return this.get(flagName); }; Environment.prototype.getBool = function (flagName) { return this.get(flagName); }; Environment.prototype.getString = function (flagName) { return this.get(flagName); }; Environment.prototype.getFlags = function () { return this.flags; }; Object.defineProperty(Environment.prototype, "features", { // For backwards compatibility. get: function () { return this.flags; }, enumerable: false, configurable: true }); Environment.prototype.set = function (flagName, value) { if (this.flagRegistry[flagName] == null) { throw new Error("Cannot set flag ".concat(flagName, " as it has not been registered.")); } this.flags[flagName] = value; if (this.flagRegistry[flagName].setHook != null) { this.flagRegistry[flagName].setHook(value); } }; Environment.prototype.evaluateFlag = function (flagName) { if (this.flagRegistry[flagName] == null) { throw new Error("Cannot evaluate flag '".concat(flagName, "': no evaluation function found.")); } return this.flagRegistry[flagName].evaluationFn(); }; Environment.prototype.setFlags = function (flags) { this.flags = Object.assign({}, flags); }; Environment.prototype.reset = function () { this.flags = {}; this.urlFlags = {}; this.populateURLFlags(); }; Environment.prototype.populateURLFlags = function () { var _this = this; if (typeof this.global === 'undefined' || typeof this.global.location === 'undefined' || typeof this.global.location.search === 'undefined') { return; } var urlParams = this.getQueryParams(this.global.location.search); if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(','); keyValues.forEach(function (keyValue) { var _a = __read(keyValue.split(':'), 2), key = _a[0], value = _a[1]; _this.urlFlags[key] = parseValue(key, value); }); } }; return Environment; }()); function getQueryParams(queryString) { var params = {}; queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) { var t = []; for (var _i = 1; _i < arguments.length; _i++) { t[_i - 1] = arguments[_i]; } decodeParam(params, t[0], t[1]); return t.join('='); }); return params; } function decodeParam(params, name, value) { params[decodeURIComponent(name)] = decodeURIComponent(value || ''); } function parseValue(flagName, value) { var lowerCaseValue = value.toLowerCase(); if (lowerCaseValue === 'true' || lowerCaseValue === 'false') { return lowerCaseValue === 'true'; } else if 
("".concat(+lowerCaseValue) === lowerCaseValue) { return +lowerCaseValue; } else { return value; } } /** * Returns the current environment (a global singleton). * * The environment object contains the evaluated feature values as well as the * active platform. * * @doc {heading: 'Environment'} */ function env() { return ENV$1; } var ENV$1 = null; function setEnvironmentGlobal(environment) { ENV$1 = environment; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Note that the identifier globalNameSpace is scoped to this module, but will // always resolve to the same global object regardless of how the module is // resolved. // tslint:disable-next-line:no-any var globalNameSpace; // tslint:disable-next-line:no-any function getGlobalNamespace() { if (globalNameSpace == null) { // tslint:disable-next-line:no-any var ns = void 0; if (typeof (window) !== 'undefined') { ns = window; } else if (typeof (global) !== 'undefined') { ns = global; } else if (typeof (process) !== 'undefined') { ns = process; } else if (typeof (self) !== 'undefined') { ns = self; } else { throw new Error('Could not find a global object'); } globalNameSpace = ns; } return globalNameSpace; } // tslint:disable-next-line:no-any function getGlobalMap() { var ns = getGlobalNamespace(); if (ns._tfGlobals == null) { ns._tfGlobals = new Map(); } return ns._tfGlobals; } /** * Returns a globally accessible 'singleton' object. * * @param key the name of the object * @param init a function to initialize this object * the first time it is fetched. */ function getGlobal(key, init) { var globalMap = getGlobalMap(); if (globalMap.has(key)) { return globalMap.get(key); } else { var singleton = init(); globalMap.set(key, singleton); return globalMap.get(key); } } function warn() { var msg = []; for (var _i = 0; _i < arguments.length; _i++) { msg[_i] = arguments[_i]; } if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn.apply(console, __spreadArray([], __read(msg), false)); } } var kernelRegistry = getGlobal('kernelRegistry', function () { return new Map(); }); var gradRegistry = getGlobal('gradRegistry', function () { return new Map(); }); /** * Returns the kernel function (code) associated with the provided names. * * @param kernelName The official name of the kernel. * @param backendName The official name of the backend. */ function getKernel(kernelName, backendName) { var key = makeKey(kernelName, backendName); return kernelRegistry.get(key); } /** * Returns the registered gradient info associated with the provided kernel. * @param kernelName The official TF kernel name.
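 * @returns The registered gradient config, or `undefined` if no gradient
 *     has been registered under `kernelName`.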
*/ function getGradient(kernelName) { return gradRegistry.get(kernelName); } function getKernelsForBackend(backendName) { var it = kernelRegistry.entries(); var result = []; while (true) { var _a = it.next(), done = _a.done, value = _a.value; if (done) { break; } var _b = __read(value, 2), key = _b[0], config = _b[1]; var _c = __read(key.split('_'), 1), backend = _c[0]; if (backend === backendName) { result.push(config); } } return result; } /** * Registers a gradient function for a given kernel in the global registry, * to be used during the back-propagation of that kernel. * * @param config An object with the following properties: * - `kernelName` The name of the kernel that the gradient function is for. * - `gradFunc` The function to run during back-propagation. */ function registerGradient(config) { var kernelName = config.kernelName; if (gradRegistry.has(kernelName)) { // TODO (yassogba) after 3.0 assess whether we need to keep this gated // to debug mode. if (env().getBool('DEBUG')) { warn("Overriding the gradient for '".concat(kernelName, "'")); } } gradRegistry.set(kernelName, config); } function makeKey(kernelName, backendName) { return "".concat(backendName, "_").concat(kernelName); } /** * @license * Copyright 2023 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function isTypedArrayBrowser(a) { return a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array || a instanceof Uint8ClampedArray; } var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {}; function getDefaultExportFromCjs(x) { return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x; } function getAugmentedNamespace(n) { if (n.__esModule) return n; var f = n.default; if (typeof f == "function") { var a = function a() { if (this instanceof a) { var args = [null]; args.push.apply(args, arguments); var Ctor = Function.bind.apply(f, args); return new Ctor(); } return f.apply(this, arguments); }; a.prototype = f.prototype; } else a = {}; Object.defineProperty(a, '__esModule', { value: true }); Object.keys(n).forEach(function (k) { var d = Object.getOwnPropertyDescriptor(n, k); Object.defineProperty(a, k, d.get ? 
d : { enumerable: true, get: function () { return n[k]; } }); }); return a; } var long = Long$1; /** * wasm optimizations, to do native i64 multiplication and divide */ var wasm = null; try { wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([ 0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11 ])), {}).exports; } catch (e) { // no wasm support :( } /** * Constructs a 64 bit two's-complement integer, given its low and high 32 bit values as *signed* integers. * See the from* functions below for more convenient ways of constructing Longs. * @exports Long * @class A Long class for representing a 64 bit two's-complement integer value. * @param {number} low The low (signed) 32 bits of the long * @param {number} high The high (signed) 32 bits of the long * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @constructor */ function Long$1(low, high, unsigned) { /** * The low 32 bits as a signed value. * @type {number} */ this.low = low | 0; /** * The high 32 bits as a signed value. * @type {number} */ this.high = high | 0; /** * Whether unsigned or not. * @type {boolean} */ this.unsigned = !!unsigned; } // The internal representation of a long is the two given signed, 32-bit values. // We use 32-bit pieces because these are the size of integers on which // Javascript performs bit-operations. For operations like addition and // multiplication, we split each number into 16 bit pieces, which can easily be // multiplied within Javascript's floating-point representation without overflow // or change in sign. // // In the algorithms below, we frequently reduce the negative case to the // positive case by negating the input(s) and then post-processing the result. // Note that we must ALWAYS check specially whether those values are MIN_VALUE // (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as // a positive number, it overflows back into a negative). Not handling this // case would often result in infinite recursion. // // Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from* // methods on which they depend. /** * An indicator used to reliably determine if an object is a Long or not. 
* @type {boolean} * @const * @private */ Long$1.prototype.__isLong__; Object.defineProperty(Long$1.prototype, "__isLong__", { value: true }); /** * @function * @param {*} obj Object * @returns {boolean} * @inner */ function isLong(obj) { return (obj && obj["__isLong__"]) === true; } /** * Tests if the specified object is a Long. * @function * @param {*} obj Object * @returns {boolean} */ Long$1.isLong = isLong; /** * A cache of the Long representations of small integer values. * @type {!Object} * @inner */ var INT_CACHE = {}; /** * A cache of the Long representations of small unsigned integer values. * @type {!Object} * @inner */ var UINT_CACHE = {}; /** * @param {number} value * @param {boolean=} unsigned * @returns {!Long} * @inner */ function fromInt(value, unsigned) { var obj, cachedObj, cache; if (unsigned) { value >>>= 0; if (cache = (0 <= value && value < 256)) { cachedObj = UINT_CACHE[value]; if (cachedObj) return cachedObj; } obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true); if (cache) UINT_CACHE[value] = obj; return obj; } else { value |= 0; if (cache = (-128 <= value && value < 128)) { cachedObj = INT_CACHE[value]; if (cachedObj) return cachedObj; } obj = fromBits(value, value < 0 ? -1 : 0, false); if (cache) INT_CACHE[value] = obj; return obj; } } /** * Returns a Long representing the given 32 bit integer value. * @function * @param {number} value The 32 bit integer in question * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {!Long} The corresponding Long value */ Long$1.fromInt = fromInt; /** * @param {number} value * @param {boolean=} unsigned * @returns {!Long} * @inner */ function fromNumber(value, unsigned) { if (isNaN(value)) return unsigned ? UZERO : ZERO; if (unsigned) { if (value < 0) return UZERO; if (value >= TWO_PWR_64_DBL) return MAX_UNSIGNED_VALUE; } else { if (value <= -TWO_PWR_63_DBL) return MIN_VALUE; if (value + 1 >= TWO_PWR_63_DBL) return MAX_VALUE; } if (value < 0) return fromNumber(-value, unsigned).neg(); return fromBits((value % TWO_PWR_32_DBL) | 0, (value / TWO_PWR_32_DBL) | 0, unsigned); } /** * Returns a Long representing the given value, provided that it is a finite number. Otherwise, zero is returned. * @function * @param {number} value The number in question * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {!Long} The corresponding Long value */ Long$1.fromNumber = fromNumber; /** * @param {number} lowBits * @param {number} highBits * @param {boolean=} unsigned * @returns {!Long} * @inner */ function fromBits(lowBits, highBits, unsigned) { return new Long$1(lowBits, highBits, unsigned); } /** * Returns a Long representing the 64 bit integer that comes by concatenating the given low and high bits. Each is * assumed to use 32 bits. 
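 * For example, `fromBits(0, 1)` represents 2^32 = 4294967296, and
 * `fromBits(0xFFFFFFFF, 0x7FFFFFFF)` is the maximum signed value.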
* @function * @param {number} lowBits The low 32 bits * @param {number} highBits The high 32 bits * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {!Long} The corresponding Long value */ Long$1.fromBits = fromBits; /** * @function * @param {number} base * @param {number} exponent * @returns {number} * @inner */ var pow_dbl = Math.pow; // Used 4 times (4*8 to 15+4) /** * @param {string} str * @param {(boolean|number)=} unsigned * @param {number=} radix * @returns {!Long} * @inner */ function fromString(str, unsigned, radix) { if (str.length === 0) throw Error('empty string'); if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity") return ZERO; if (typeof unsigned === 'number') { // For goog.math.long compatibility radix = unsigned, unsigned = false; } else { unsigned = !!unsigned; } radix = radix || 10; if (radix < 2 || 36 < radix) throw RangeError('radix'); var p; if ((p = str.indexOf('-')) > 0) throw Error('interior hyphen'); else if (p === 0) { return fromString(str.substring(1), unsigned, radix).neg(); } // Do several (8) digits each time through the loop, so as to // minimize the calls to the very expensive emulated div. var radixToPower = fromNumber(pow_dbl(radix, 8)); var result = ZERO; for (var i = 0; i < str.length; i += 8) { var size = Math.min(8, str.length - i), value = parseInt(str.substring(i, i + size), radix); if (size < 8) { var power = fromNumber(pow_dbl(radix, size)); result = result.mul(power).add(fromNumber(value)); } else { result = result.mul(radixToPower); result = result.add(fromNumber(value)); } } result.unsigned = unsigned; return result; } /** * Returns a Long representation of the given string, written using the specified radix. * @function * @param {string} str The textual representation of the Long * @param {(boolean|number)=} unsigned Whether unsigned or not, defaults to signed * @param {number=} radix The radix in which the text is written (2-36), defaults to 10 * @returns {!Long} The corresponding Long value */ Long$1.fromString = fromString; /** * @function * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val * @param {boolean=} unsigned * @returns {!Long} * @inner */ function fromValue(val, unsigned) { if (typeof val === 'number') return fromNumber(val, unsigned); if (typeof val === 'string') return fromString(val, unsigned); // Throws for non-objects, converts non-instanceof Long: return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned); } /** * Converts the specified value to a Long using the appropriate from* function for its type. * @function * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val Value * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {!Long} */ Long$1.fromValue = fromValue; // NOTE: the compiler should inline these constant values below and then remove these variables, so there should be // no runtime penalty for these. 
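// A brief usage sketch of the from* constructors above (illustrative values):
//   Long$1.fromNumber(2);                  // {low: 2, high: 0, unsigned: false}
//   Long$1.fromString('9007199254740993'); // exact, even though > 2^53
//   Long$1.fromValue({low: 0, high: 1});   // 2^32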
/** * @type {number} * @const * @inner */ var TWO_PWR_16_DBL = 1 << 16; /** * @type {number} * @const * @inner */ var TWO_PWR_24_DBL = 1 << 24; /** * @type {number} * @const * @inner */ var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL; /** * @type {number} * @const * @inner */ var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL; /** * @type {number} * @const * @inner */ var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2; /** * @type {!Long} * @const * @inner */ var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL); /** * @type {!Long} * @inner */ var ZERO = fromInt(0); /** * Signed zero. * @type {!Long} */ Long$1.ZERO = ZERO; /** * @type {!Long} * @inner */ var UZERO = fromInt(0, true); /** * Unsigned zero. * @type {!Long} */ Long$1.UZERO = UZERO; /** * @type {!Long} * @inner */ var ONE = fromInt(1); /** * Signed one. * @type {!Long} */ Long$1.ONE = ONE; /** * @type {!Long} * @inner */ var UONE = fromInt(1, true); /** * Unsigned one. * @type {!Long} */ Long$1.UONE = UONE; /** * @type {!Long} * @inner */ var NEG_ONE = fromInt(-1); /** * Signed negative one. * @type {!Long} */ Long$1.NEG_ONE = NEG_ONE; /** * @type {!Long} * @inner */ var MAX_VALUE = fromBits(0xFFFFFFFF | 0, 0x7FFFFFFF | 0, false); /** * Maximum signed value. * @type {!Long} */ Long$1.MAX_VALUE = MAX_VALUE; /** * @type {!Long} * @inner */ var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF | 0, 0xFFFFFFFF | 0, true); /** * Maximum unsigned value. * @type {!Long} */ Long$1.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE; /** * @type {!Long} * @inner */ var MIN_VALUE = fromBits(0, 0x80000000 | 0, false); /** * Minimum signed value. * @type {!Long} */ Long$1.MIN_VALUE = MIN_VALUE; /** * @alias Long.prototype * @inner */ var LongPrototype = Long$1.prototype; /** * Converts the Long to a 32 bit integer, assuming it is a 32 bit integer. * @returns {number} */ LongPrototype.toInt = function toInt() { return this.unsigned ? this.low >>> 0 : this.low; }; /** * Converts the Long to the nearest floating-point representation of this value (double, 53 bit mantissa). * @returns {number} */ LongPrototype.toNumber = function toNumber() { if (this.unsigned) return ((this.high >>> 0) * TWO_PWR_32_DBL) + (this.low >>> 0); return this.high * TWO_PWR_32_DBL + (this.low >>> 0); }; /** * Converts the Long to a string written in the specified radix. * @param {number=} radix Radix (2-36), defaults to 10 * @returns {string} * @override * @throws {RangeError} If `radix` is out of range */ LongPrototype.toString = function toString(radix) { radix = radix || 10; if (radix < 2 || 36 < radix) throw RangeError('radix'); if (this.isZero()) return '0'; if (this.isNegative()) { // Unsigned Longs are never negative if (this.eq(MIN_VALUE)) { // We need to change the Long value before it can be negated, so we remove // the bottom-most digit in this base and then recurse to do the rest. var radixLong = fromNumber(radix), div = this.div(radixLong), rem1 = div.mul(radixLong).sub(this); return div.toString(radix) + rem1.toInt().toString(radix); } else return '-' + this.neg().toString(radix); } // Do several (6) digits each time through the loop, so as to // minimize the calls to the very expensive emulated div.
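// (36^6 = 2176782336 < 2^32, so each per-chunk remainder fits an unsigned
// 32-bit integer and can be stringified natively via Number#toString.)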
var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned), rem = this; var result = ''; while (true) { var remDiv = rem.div(radixToPower), intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0, digits = intval.toString(radix); rem = remDiv; if (rem.isZero()) return digits + result; else { while (digits.length < 6) digits = '0' + digits; result = '' + digits + result; } } }; /** * Gets the high 32 bits as a signed integer. * @returns {number} Signed high bits */ LongPrototype.getHighBits = function getHighBits() { return this.high; }; /** * Gets the high 32 bits as an unsigned integer. * @returns {number} Unsigned high bits */ LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() { return this.high >>> 0; }; /** * Gets the low 32 bits as a signed integer. * @returns {number} Signed low bits */ LongPrototype.getLowBits = function getLowBits() { return this.low; }; /** * Gets the low 32 bits as an unsigned integer. * @returns {number} Unsigned low bits */ LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() { return this.low >>> 0; }; /** * Gets the number of bits needed to represent the absolute value of this Long. * @returns {number} */ LongPrototype.getNumBitsAbs = function getNumBitsAbs() { if (this.isNegative()) // Unsigned Longs are never negative return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs(); var val = this.high != 0 ? this.high : this.low; for (var bit = 31; bit > 0; bit--) if ((val & (1 << bit)) != 0) break; return this.high != 0 ? bit + 33 : bit + 1; }; /** * Tests if this Long's value equals zero. * @returns {boolean} */ LongPrototype.isZero = function isZero() { return this.high === 0 && this.low === 0; }; /** * Tests if this Long's value equals zero. This is an alias of {@link Long#isZero}. * @returns {boolean} */ LongPrototype.eqz = LongPrototype.isZero; /** * Tests if this Long's value is negative. * @returns {boolean} */ LongPrototype.isNegative = function isNegative() { return !this.unsigned && this.high < 0; }; /** * Tests if this Long's value is positive. * @returns {boolean} */ LongPrototype.isPositive = function isPositive() { return this.unsigned || this.high >= 0; }; /** * Tests if this Long's value is odd. * @returns {boolean} */ LongPrototype.isOdd = function isOdd() { return (this.low & 1) === 1; }; /** * Tests if this Long's value is even. * @returns {boolean} */ LongPrototype.isEven = function isEven() { return (this.low & 1) === 0; }; /** * Tests if this Long's value equals the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.equals = function equals(other) { if (!isLong(other)) other = fromValue(other); if (this.unsigned !== other.unsigned && (this.high >>> 31) === 1 && (other.high >>> 31) === 1) return false; return this.high === other.high && this.low === other.low; }; /** * Tests if this Long's value equals the specified's. This is an alias of {@link Long#equals}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.eq = LongPrototype.equals; /** * Tests if this Long's value differs from the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.notEquals = function notEquals(other) { return !this.eq(/* validates */ other); }; /** * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}. 
* @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.neq = LongPrototype.notEquals; /** * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.ne = LongPrototype.notEquals; /** * Tests if this Long's value is less than the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.lessThan = function lessThan(other) { return this.comp(/* validates */ other) < 0; }; /** * Tests if this Long's value is less than the specified's. This is an alias of {@link Long#lessThan}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.lt = LongPrototype.lessThan; /** * Tests if this Long's value is less than or equal the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) { return this.comp(/* validates */ other) <= 0; }; /** * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.lte = LongPrototype.lessThanOrEqual; /** * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.le = LongPrototype.lessThanOrEqual; /** * Tests if this Long's value is greater than the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.greaterThan = function greaterThan(other) { return this.comp(/* validates */ other) > 0; }; /** * Tests if this Long's value is greater than the specified's. This is an alias of {@link Long#greaterThan}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.gt = LongPrototype.greaterThan; /** * Tests if this Long's value is greater than or equal the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) { return this.comp(/* validates */ other) >= 0; }; /** * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.gte = LongPrototype.greaterThanOrEqual; /** * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.ge = LongPrototype.greaterThanOrEqual; /** * Compares this Long's value with the specified's. * @param {!Long|number|string} other Other value * @returns {number} 0 if they are the same, 1 if the this is greater and -1 * if the given one is greater */ LongPrototype.compare = function compare(other) { if (!isLong(other)) other = fromValue(other); if (this.eq(other)) return 0; var thisNeg = this.isNegative(), otherNeg = other.isNegative(); if (thisNeg && !otherNeg) return -1; if (!thisNeg && otherNeg) return 1; // At this point the sign bits are the same if (!this.unsigned) return this.sub(other).isNegative() ? 
-1 : 1; // Both are positive if at least one is unsigned return (other.high >>> 0) > (this.high >>> 0) || (other.high === this.high && (other.low >>> 0) > (this.low >>> 0)) ? -1 : 1; }; /** * Compares this Long's value with the specified's. This is an alias of {@link Long#compare}. * @function * @param {!Long|number|string} other Other value * @returns {number} 0 if they are the same, 1 if the this is greater and -1 * if the given one is greater */ LongPrototype.comp = LongPrototype.compare; /** * Negates this Long's value. * @returns {!Long} Negated Long */ LongPrototype.negate = function negate() { if (!this.unsigned && this.eq(MIN_VALUE)) return MIN_VALUE; return this.not().add(ONE); }; /** * Negates this Long's value. This is an alias of {@link Long#negate}. * @function * @returns {!Long} Negated Long */ LongPrototype.neg = LongPrototype.negate; /** * Returns the sum of this and the specified Long. * @param {!Long|number|string} addend Addend * @returns {!Long} Sum */ LongPrototype.add = function add(addend) { if (!isLong(addend)) addend = fromValue(addend); // Divide each number into 4 chunks of 16 bits, and then sum the chunks. var a48 = this.high >>> 16; var a32 = this.high & 0xFFFF; var a16 = this.low >>> 16; var a00 = this.low & 0xFFFF; var b48 = addend.high >>> 16; var b32 = addend.high & 0xFFFF; var b16 = addend.low >>> 16; var b00 = addend.low & 0xFFFF; var c48 = 0, c32 = 0, c16 = 0, c00 = 0; c00 += a00 + b00; c16 += c00 >>> 16; c00 &= 0xFFFF; c16 += a16 + b16; c32 += c16 >>> 16; c16 &= 0xFFFF; c32 += a32 + b32; c48 += c32 >>> 16; c32 &= 0xFFFF; c48 += a48 + b48; c48 &= 0xFFFF; return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned); }; /** * Returns the difference of this and the specified Long. * @param {!Long|number|string} subtrahend Subtrahend * @returns {!Long} Difference */ LongPrototype.subtract = function subtract(subtrahend) { if (!isLong(subtrahend)) subtrahend = fromValue(subtrahend); return this.add(subtrahend.neg()); }; /** * Returns the difference of this and the specified Long. This is an alias of {@link Long#subtract}. * @function * @param {!Long|number|string} subtrahend Subtrahend * @returns {!Long} Difference */ LongPrototype.sub = LongPrototype.subtract; /** * Returns the product of this and the specified Long. * @param {!Long|number|string} multiplier Multiplier * @returns {!Long} Product */ LongPrototype.multiply = function multiply(multiplier) { if (this.isZero()) return ZERO; if (!isLong(multiplier)) multiplier = fromValue(multiplier); // use wasm support if present if (wasm) { var low = wasm.mul(this.low, this.high, multiplier.low, multiplier.high); return fromBits(low, wasm.get_high(), this.unsigned); } if (multiplier.isZero()) return ZERO; if (this.eq(MIN_VALUE)) return multiplier.isOdd() ? MIN_VALUE : ZERO; if (multiplier.eq(MIN_VALUE)) return this.isOdd() ? MIN_VALUE : ZERO; if (this.isNegative()) { if (multiplier.isNegative()) return this.neg().mul(multiplier.neg()); else return this.neg().mul(multiplier).neg(); } else if (multiplier.isNegative()) return this.mul(multiplier.neg()).neg(); // If both longs are small, use float multiplication if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24)) return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned); // Divide each long into 4 chunks of 16 bits, and then add up 4x4 products. // We can skip products that would overflow. 
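/*
 * Illustrative sketch (not part of the library): the code below is schoolbook
 * multiplication over four 16-bit limbs. A 16x16-bit product is at most 32
 * bits, so it is exact in a float64; carries are propagated by hand, e.g. for
 * the lowest limb:
 *
 *   c00 = a00 * b00;   // exact partial product
 *   c16 = c00 >>> 16;  // carry into the next limb
 *   c00 &= 0xFFFF;     // keep the low 16 bits
 */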
var a48 = this.high >>> 16; var a32 = this.high & 0xFFFF; var a16 = this.low >>> 16; var a00 = this.low & 0xFFFF; var b48 = multiplier.high >>> 16; var b32 = multiplier.high & 0xFFFF; var b16 = multiplier.low >>> 16; var b00 = multiplier.low & 0xFFFF; var c48 = 0, c32 = 0, c16 = 0, c00 = 0; c00 += a00 * b00; c16 += c00 >>> 16; c00 &= 0xFFFF; c16 += a16 * b00; c32 += c16 >>> 16; c16 &= 0xFFFF; c16 += a00 * b16; c32 += c16 >>> 16; c16 &= 0xFFFF; c32 += a32 * b00; c48 += c32 >>> 16; c32 &= 0xFFFF; c32 += a16 * b16; c48 += c32 >>> 16; c32 &= 0xFFFF; c32 += a00 * b32; c48 += c32 >>> 16; c32 &= 0xFFFF; c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48; c48 &= 0xFFFF; return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned); }; /** * Returns the product of this and the specified Long. This is an alias of {@link Long#multiply}. * @function * @param {!Long|number|string} multiplier Multiplier * @returns {!Long} Product */ LongPrototype.mul = LongPrototype.multiply; /** * Returns this Long divided by the specified. The result is signed if this Long is signed or * unsigned if this Long is unsigned. * @param {!Long|number|string} divisor Divisor * @returns {!Long} Quotient */ LongPrototype.divide = function divide(divisor) { if (!isLong(divisor)) divisor = fromValue(divisor); if (divisor.isZero()) throw Error('division by zero'); // use wasm support if present if (wasm) { // guard against signed division overflow: the largest // negative number / -1 would be 1 larger than the largest // positive number, due to two's complement. if (!this.unsigned && this.high === -0x80000000 && divisor.low === -1 && divisor.high === -1) { // be consistent with non-wasm code path return this; } var low = (this.unsigned ? wasm.div_u : wasm.div_s)(this.low, this.high, divisor.low, divisor.high); return fromBits(low, wasm.get_high(), this.unsigned); } if (this.isZero()) return this.unsigned ? UZERO : ZERO; var approx, rem, res; if (!this.unsigned) { // This section is only relevant for signed longs and is derived from the // closure library as a whole. if (this.eq(MIN_VALUE)) { if (divisor.eq(ONE) || divisor.eq(NEG_ONE)) return MIN_VALUE; // recall that -MIN_VALUE == MIN_VALUE else if (divisor.eq(MIN_VALUE)) return ONE; else { // At this point, we have |other| >= 2, so |this/other| < |MIN_VALUE|. var halfThis = this.shr(1); approx = halfThis.div(divisor).shl(1); if (approx.eq(ZERO)) { return divisor.isNegative() ? ONE : NEG_ONE; } else { rem = this.sub(divisor.mul(approx)); res = approx.add(rem.div(divisor)); return res; } } } else if (divisor.eq(MIN_VALUE)) return this.unsigned ? UZERO : ZERO; if (this.isNegative()) { if (divisor.isNegative()) return this.neg().div(divisor.neg()); return this.neg().div(divisor).neg(); } else if (divisor.isNegative()) return this.div(divisor.neg()).neg(); res = ZERO; } else { // The algorithm below has not been made for unsigned longs. It's therefore // required to take special care of the MSB prior to running it. if (!divisor.unsigned) divisor = divisor.toUnsigned(); if (divisor.gt(this)) return UZERO; if (divisor.gt(this.shru(1))) // 15 >>> 1 = 7 ; with divisor = 8 ; true return UONE; res = UZERO; } // Repeat the following until the remainder is less than other: find a // floating-point that approximates remainder / other *from below*, add this // into the result, and subtract it from the remainder. It is critical that // the approximate value is less than or equal to the real value so that the // remainder never becomes negative. 
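/*
 * Illustrative walk-through (not from the source): dividing 100 by 7, the
 * first float approximation is floor(100 / 7) = 14; since 14 * 7 = 98 <= 100,
 * res becomes 14 and rem becomes 2. As 2 < 7 the loop exits and 14 is
 * returned; the remainder is recovered separately by modulo() below.
 */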
rem = this; while (rem.gte(divisor)) { // Approximate the result of division. This may be a little greater or // smaller than the actual value. approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber())); // We will tweak the approximate result by changing it in the 48-th digit or // the smallest non-fractional digit, whichever is larger. var log2 = Math.ceil(Math.log(approx) / Math.LN2), delta = (log2 <= 48) ? 1 : pow_dbl(2, log2 - 48), // Decrease the approximation until it is smaller than the remainder. Note // that if it is too large, the product overflows and is negative. approxRes = fromNumber(approx), approxRem = approxRes.mul(divisor); while (approxRem.isNegative() || approxRem.gt(rem)) { approx -= delta; approxRes = fromNumber(approx, this.unsigned); approxRem = approxRes.mul(divisor); } // We know the answer can't be zero... and actually, zero would cause // infinite recursion since we would make no progress. if (approxRes.isZero()) approxRes = ONE; res = res.add(approxRes); rem = rem.sub(approxRem); } return res; }; /** * Returns this Long divided by the specified. This is an alias of {@link Long#divide}. * @function * @param {!Long|number|string} divisor Divisor * @returns {!Long} Quotient */ LongPrototype.div = LongPrototype.divide; /** * Returns this Long modulo the specified. * @param {!Long|number|string} divisor Divisor * @returns {!Long} Remainder */ LongPrototype.modulo = function modulo(divisor) { if (!isLong(divisor)) divisor = fromValue(divisor); // use wasm support if present if (wasm) { var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)(this.low, this.high, divisor.low, divisor.high); return fromBits(low, wasm.get_high(), this.unsigned); } return this.sub(this.div(divisor).mul(divisor)); }; /** * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}. * @function * @param {!Long|number|string} divisor Divisor * @returns {!Long} Remainder */ LongPrototype.mod = LongPrototype.modulo; /** * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}. * @function * @param {!Long|number|string} divisor Divisor * @returns {!Long} Remainder */ LongPrototype.rem = LongPrototype.modulo; /** * Returns the bitwise NOT of this Long. * @returns {!Long} */ LongPrototype.not = function not() { return fromBits(~this.low, ~this.high, this.unsigned); }; /** * Returns the bitwise AND of this Long and the specified. * @param {!Long|number|string} other Other Long * @returns {!Long} */ LongPrototype.and = function and(other) { if (!isLong(other)) other = fromValue(other); return fromBits(this.low & other.low, this.high & other.high, this.unsigned); }; /** * Returns the bitwise OR of this Long and the specified. * @param {!Long|number|string} other Other Long * @returns {!Long} */ LongPrototype.or = function or(other) { if (!isLong(other)) other = fromValue(other); return fromBits(this.low | other.low, this.high | other.high, this.unsigned); }; /** * Returns the bitwise XOR of this Long and the given one. * @param {!Long|number|string} other Other Long * @returns {!Long} */ LongPrototype.xor = function xor(other) { if (!isLong(other)) other = fromValue(other); return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned); }; /** * Returns this Long with bits shifted to the left by the given amount. 
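 *
 * For example (illustrative, not from the source): shifts of 32 or more move
 * the low word into the high word, so `fromInt(1).shiftLeft(33).toString(16)`
 * is `'200000000'`.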
* @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */
LongPrototype.shiftLeft = function shiftLeft(numBits) { if (isLong(numBits)) numBits = numBits.toInt(); if ((numBits &= 63) === 0) return this; else if (numBits < 32) return fromBits(this.low << numBits, (this.high << numBits) | (this.low >>> (32 - numBits)), this.unsigned); else return fromBits(0, this.low << (numBits - 32), this.unsigned); };
/** * Returns this Long with bits shifted to the left by the given amount. This is an alias of {@link Long#shiftLeft}. * @function * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */
LongPrototype.shl = LongPrototype.shiftLeft;
/** * Returns this Long with bits arithmetically shifted to the right by the given amount. * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */
LongPrototype.shiftRight = function shiftRight(numBits) { if (isLong(numBits)) numBits = numBits.toInt(); if ((numBits &= 63) === 0) return this; else if (numBits < 32) return fromBits((this.low >>> numBits) | (this.high << (32 - numBits)), this.high >> numBits, this.unsigned); else return fromBits(this.high >> (numBits - 32), this.high >= 0 ? 0 : -1, this.unsigned); };
/** * Returns this Long with bits arithmetically shifted to the right by the given amount. This is an alias of {@link Long#shiftRight}. * @function * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */
LongPrototype.shr = LongPrototype.shiftRight;
/** * Returns this Long with bits logically shifted to the right by the given amount. * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */
LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) { if (isLong(numBits)) numBits = numBits.toInt(); numBits &= 63; if (numBits === 0) return this; else { var high = this.high; if (numBits < 32) { var low = this.low; return fromBits((low >>> numBits) | (high << (32 - numBits)), high >>> numBits, this.unsigned); } else if (numBits === 32) return fromBits(high, 0, this.unsigned); else return fromBits(high >>> (numBits - 32), 0, this.unsigned); } };
/** * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}. * @function * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */
LongPrototype.shru = LongPrototype.shiftRightUnsigned;
/** * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}. * @function * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */
LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;
/** * Converts this Long to signed. * @returns {!Long} Signed long */
LongPrototype.toSigned = function toSigned() { if (!this.unsigned) return this; return fromBits(this.low, this.high, false); };
/** * Converts this Long to unsigned. * @returns {!Long} Unsigned long */
LongPrototype.toUnsigned = function toUnsigned() { if (this.unsigned) return this; return fromBits(this.low, this.high, true); };
/** * Converts this Long to its byte representation. * @param {boolean=} le Whether little or big endian, defaults to big endian * @returns {!Array.<number>} Byte representation */
LongPrototype.toBytes = function toBytes(le) { return le ? this.toBytesLE() : this.toBytesBE(); };
/** * Converts this Long to its little endian byte representation.
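 *
 * For example (illustrative): `fromInt(1).toBytesLE()` yields
 * `[1, 0, 0, 0, 0, 0, 0, 0]`, while `toBytesBE()` returns the same bytes in
 * reverse order.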
* @returns {!Array.<number>} Little endian byte representation */
LongPrototype.toBytesLE = function toBytesLE() { var hi = this.high, lo = this.low; return [ lo & 0xff, lo >>> 8 & 0xff, lo >>> 16 & 0xff, lo >>> 24, hi & 0xff, hi >>> 8 & 0xff, hi >>> 16 & 0xff, hi >>> 24 ]; };
/** * Converts this Long to its big endian byte representation. * @returns {!Array.<number>} Big endian byte representation */
LongPrototype.toBytesBE = function toBytesBE() { var hi = this.high, lo = this.low; return [ hi >>> 24, hi >>> 16 & 0xff, hi >>> 8 & 0xff, hi & 0xff, lo >>> 24, lo >>> 16 & 0xff, lo >>> 8 & 0xff, lo & 0xff ]; };
/** * Creates a Long from its byte representation. * @param {!Array.<number>} bytes Byte representation * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @param {boolean=} le Whether little or big endian, defaults to big endian * @returns {Long} The corresponding Long value */
Long$1.fromBytes = function fromBytes(bytes, unsigned, le) { return le ? Long$1.fromBytesLE(bytes, unsigned) : Long$1.fromBytesBE(bytes, unsigned); };
/** * Creates a Long from its little endian byte representation. * @param {!Array.<number>} bytes Little endian byte representation * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {Long} The corresponding Long value */
Long$1.fromBytesLE = function fromBytesLE(bytes, unsigned) { return new Long$1(bytes[0] | bytes[1] << 8 | bytes[2] << 16 | bytes[3] << 24, bytes[4] | bytes[5] << 8 | bytes[6] << 16 | bytes[7] << 24, unsigned); };
/** * Creates a Long from its big endian byte representation. * @param {!Array.<number>} bytes Big endian byte representation * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {Long} The corresponding Long value */
Long$1.fromBytesBE = function fromBytesBE(bytes, unsigned) { return new Long$1(bytes[4] << 24 | bytes[5] << 16 | bytes[6] << 8 | bytes[7], bytes[0] << 24 | bytes[1] << 16 | bytes[2] << 8 | bytes[3], unsigned); };
var long$1 = /*@__PURE__*/ getDefaultExportFromCjs(long);
var LongExports = /*#__PURE__*/_mergeNamespaces({ __proto__: null, default: long$1 }, [long]);
// tslint:disable-next-line
var Long =
// tslint:disable-next-line
long$1 || LongExports;
function hexToLong(hex) { return Long.fromString(hex, true, 16); }
// Some primes between 2^63 and 2^64 for various uses.
// Hex 0xc3a5c85c97cb3127
hexToLong('c3a5c85c97cb3127');
// Hex 0xb492b66fbe98f273
hexToLong('b492b66fbe98f273');
// Hex 0x9ae16a3b2f90404f
hexToLong('9ae16a3b2f90404f');
function noConversionNeeded(a, dtype) { return (a instanceof Float32Array && dtype === 'float32') || (a instanceof Int32Array && dtype === 'int32') || (a instanceof Uint8Array && dtype === 'bool'); }
function toTypedArray(a, dtype) { if (dtype === 'string') { throw new Error('Cannot convert a string[] to a TypedArray'); } if (Array.isArray(a)) { a = flatten$1(a); } if (env().getBool('DEBUG')) { checkConversionForErrors(a, dtype); } if (noConversionNeeded(a, dtype)) { return a; } if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(a); } else if (dtype === 'int32') { return new Int32Array(a); } else if (dtype === 'bool') { var bool = new Uint8Array(a.length); for (var i = 0; i < bool.length; ++i) { if (Math.round(a[i]) !== 0) { bool[i] = 1; } } return bool; } else { throw new Error("Unknown data type ".concat(dtype)); } }
/** * Returns the current high-resolution time in milliseconds relative to an * arbitrary time in the past. It works across different platforms (node.js, * browsers).
* * ```js * console.log(tf.util.now()); * ``` * * @doc {heading: 'Util', namespace: 'util'} */ function now() { return env().platform.now(); } /** * Encodes the provided string into bytes using the provided encoding scheme. * * @param s The string to encode. * @param encoding The encoding scheme. Defaults to utf-8. * * @doc {heading: 'Util'} */ function encodeString(s, encoding) { if (encoding === void 0) { encoding = 'utf-8'; } encoding = encoding || 'utf-8'; return env().platform.encode(s, encoding); } /** * Decodes the provided bytes into a string using the provided encoding scheme. * @param bytes The bytes to decode. * * @param encoding The encoding scheme. Defaults to utf-8. * * @doc {heading: 'Util'} */ function decodeString(bytes, encoding) { if (encoding === void 0) { encoding = 'utf-8'; } encoding = encoding || 'utf-8'; return env().platform.decode(bytes, encoding); } function isTypedArray(a) { // TODO(mattsoulanille): Remove this fallback in 5.0.0 if (env().platform.isTypedArray != null) { return env().platform.isTypedArray(a); } else { return isTypedArrayBrowser(a); } } // NOTE: We explicitly type out what T extends instead of any so that // util.flatten on a nested array of number doesn't try to infer T as a // number[][], causing us to explicitly type util.flatten(). /** * Flattens an arbitrarily nested array. * * ```js * const a = [[1, 2], [3, 4], [5, [6, [7]]]]; * const flat = tf.util.flatten(a); * console.log(flat); * ``` * * @param arr The nested array to flatten. * @param result The destination array which holds the elements. * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults * to false. * * @doc {heading: 'Util', namespace: 'util'} */ function flatten$1(arr, result, skipTypedArray) { var e_1, _a; if (result === void 0) { result = []; } if (skipTypedArray === void 0) { skipTypedArray = false; } if (result == null) { result = []; } if (typeof arr === 'boolean' || typeof arr === 'number' || typeof arr === 'string' || isPromise(arr) || arr == null || isTypedArray(arr) && skipTypedArray) { result.push(arr); } else if (Array.isArray(arr) || isTypedArray(arr)) { for (var i = 0; i < arr.length; ++i) { flatten$1(arr[i], result, skipTypedArray); } } else { var maxIndex = -1; try { for (var _b = __values(Object.keys(arr)), _c = _b.next(); !_c.done; _c = _b.next()) { var key = _c.value; // 0 or positive integer. 
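/* Illustrative note (not from the source): this branch flattens array-like
   objects. For {0: 'a', 1: 'b', length: 2}, the numeric keys are scanned to
   find the largest index (1) and the entries are then visited positionally,
   exactly as for a real array. */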
if (/^([1-9]+[0-9]*|0)$/.test(key)) { maxIndex = Math.max(maxIndex, Number(key)); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_1) throw e_1.error; } } for (var i = 0; i <= maxIndex; i++) { // tslint:disable-next-line: no-unnecessary-type-assertion flatten$1(arr[i], result, skipTypedArray); } } return result; } var Profiler = /** @class */ (function () { function Profiler(backendTimer, logger) { this.backendTimer = backendTimer; this.logger = logger; if (logger == null) { this.logger = new Logger(); } } Profiler.prototype.profileKernel = function (kernelName, inputs, f) { var e_1, _a; var outputs; var holdResultWrapperFn = function () { outputs = f(); }; var timer; var start = now(); if (this.backendTimer.timerAvailable()) { timer = this.backendTimer.time(holdResultWrapperFn); } else { holdResultWrapperFn(); try { for (var outputs_1 = __values(outputs), outputs_1_1 = outputs_1.next(); !outputs_1_1.done; outputs_1_1 = outputs_1.next()) { var output = outputs_1_1.value; output.dataSync(); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (outputs_1_1 && !outputs_1_1.done && (_a = outputs_1.return)) _a.call(outputs_1); } finally { if (e_1) throw e_1.error; } } timer = Promise.resolve({ kernelMs: now() - start }); } if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) { var _loop_1 = function (i) { var output = outputs[i]; // Dangling promise here because we don't want to propagate up // asynchronicity. output.data().then(function (tensorVals) { checkComputationForErrors(tensorVals, output.dtype, kernelName); }); }; for (var i = 0; i < outputs.length; i++) { _loop_1(i); } } var kernelProfile = { kernelName: kernelName, outputs: outputs, inputs: inputs, timeMs: timer.then(function (timing) { return timing.kernelMs; }), extraInfo: timer.then(function (timing) { return timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : ''; }) }; return kernelProfile; }; Profiler.prototype.logKernelProfile = function (kernelProfile) { var _this = this; var kernelName = kernelProfile.kernelName, outputs = kernelProfile.outputs, timeMs = kernelProfile.timeMs, inputs = kernelProfile.inputs, extraInfo = kernelProfile.extraInfo; outputs.forEach(function (result) { Promise.all([result.data(), timeMs, extraInfo]).then(function (valueContainer) { _this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]); }); }); }; return Profiler; }()); function checkComputationForErrors(vals, dtype, kernelName) { if (dtype !== 'float32') { // Only floating point computations will generate NaN values return false; } for (var i = 0; i < vals.length; i++) { var num = vals[i]; if (isNaN(num) || !isFinite(num)) { // Throwing custom exception so behavior is testable. console.warn("Found ".concat(num, " in the result of '").concat(kernelName, "'")); return true; } } return false; } var Logger = /** @class */ (function () { function Logger() { } Logger.prototype.logKernelProfile = function (name, result, vals, timeMs, inputs, extraInfo) { var time = typeof timeMs === 'number' ? 
rightPad("".concat(timeMs, "ms"), 9) : timeMs['error']; var paddedName = rightPad(name, 25); var rank = result.rank; var size = result.size; var shape = rightPad(result.shape.toString(), 14); var inputShapesDescription = ''; for (var name_1 in inputs) { var input = inputs[name_1]; if (input != null) { // The input might be a non-tensor (e.g HTMLImageElement), in which case // we claim the output shape as input shape. var inputShape = input.shape || result.shape; var inputRank = inputShape.length; inputShapesDescription += "".concat(name_1, ": ").concat(inputRank, "D ").concat(inputRank > 0 ? inputShape : '', " "); } } console.log("%c".concat(paddedName, "\t%c").concat(time, "\t%c").concat(rank, "D ").concat(shape, "\t%c").concat(size, "\t%c").concat(inputShapesDescription, "\t%c").concat(extraInfo), 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue'); }; return Logger; }()); /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes a list of TapeNodes that connect x to y, filtering everything else * out and preserving the order of the original tape elements. * * @param tape The tape elements to filter. * @param xs The input Tensors. * @param y The output Tensor. */ function getFilteredNodesXToY(tape, xs, y) { // Forward pass to compute all the nodes and Tensors that are transitively a // function of x. var tensorsFromX = {}; var nodesFromX = {}; for (var i = 0; i < xs.length; i++) { tensorsFromX[xs[i].id] = true; } for (var i = 0; i < tape.length; i++) { var node = tape[i]; var nodeInputs = node.inputs; for (var inputName in nodeInputs) { var input = nodeInputs[inputName]; var anyInputFromX = false; for (var j = 0; j < xs.length; j++) { if (tensorsFromX[input.id]) { node.outputs.forEach(function (output) { return tensorsFromX[output.id] = true; }); anyInputFromX = true; nodesFromX[node.id] = true; break; } } if (anyInputFromX) { break; } } } // Backward pass to find all of the nodes and Tensors that lead to y. var tensorsLeadToY = {}; tensorsLeadToY[y.id] = true; var nodesToY = {}; for (var i = tape.length - 1; i >= 0; i--) { var node = tape[i]; var nodeInputs = node.inputs; // If any of the outputs lead to y, mark all of the inputs as leading to y. for (var j = 0; j < node.outputs.length; j++) { if (tensorsLeadToY[node.outputs[j].id]) { for (var inputName in nodeInputs) { tensorsLeadToY[nodeInputs[inputName].id] = true; nodesToY[node.id] = true; } break; } } } // Return the paths that come from x and lead to y. var filteredTape = []; for (var i = 0; i < tape.length; i++) { var node = tape[i]; if (nodesFromX[node.id] && nodesToY[node.id]) { // Prune the inputs from the node that aren't a function of x. 
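/*
 * Illustrative sketch (not from the source): for a tape
 *   x --A--> t1 --B--> y,   u --C--> v   (C touches neither x nor y)
 * the forward pass marks A and B as reachable from x, the backward pass marks
 * A and B as leading to y, and C is dropped. The code below then prunes any
 * inputs of the surviving nodes that are not themselves functions of x.
 */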
var prunedInputs = {}; for (var inputName in node.inputs) { var nodeInput = node.inputs[inputName]; if (tensorsFromX[nodeInput.id]) { prunedInputs[inputName] = nodeInput; } } // Copy the node and overwrite inputsAndArgs to the pruned version. var prunedNode = Object.assign({}, node); prunedNode.inputs = prunedInputs; prunedNode.outputs = node.outputs; filteredTape.push(prunedNode); } } return filteredTape; } /** * Backpropagate gradients through the filtered TapeNodes. * * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map * is mutated by this method. * @param filteredTape The filtered TapeNodes to backprop through. */ function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) { var _loop_1 = function (i) { var node = filteredTape[i]; var dys = []; node.outputs.forEach(function (o) { var gradTensor = tensorAccumulatedGradientMap[o.id]; if (gradTensor != null) { dys.push(gradTensor); } else { // This particular output is not in the back-propagation subgraph, so it // does not affect the final output, thus we put null for its dy. dys.push(null); } }); if (node.gradient == null) { throw new Error("Cannot compute gradient: gradient function not found " + "for ".concat(node.kernelName, ".")); } // Backprop dy through this node and accumulate gradients over the inputs. var inputGradients = node.gradient(dys); var _loop_2 = function (inputName) { if (!(inputName in inputGradients)) { throw new Error("Cannot backprop through input ".concat(inputName, ". ") + "Available gradients found: ".concat(Object.keys(inputGradients), ".")); } // Call the gradient function. var dx = tidy(function () { return inputGradients[inputName](); }); if (dx.dtype !== 'float32') { throw new Error("Error in gradient for op ".concat(node.kernelName, ". The gradient of input ") + "".concat(inputName, " must have 'float32' dtype, but has '").concat(dx.dtype, "'")); } var x = node.inputs[inputName]; if (!arraysEqual(dx.shape, x.shape)) { throw new Error("Error in gradient for op ".concat(node.kernelName, ". The gradient of input ") + "'".concat(inputName, "' has shape '").concat(dx.shape, "', which does not match ") + "the shape of the input '".concat(x.shape, "'")); } if (tensorAccumulatedGradientMap[x.id] == null) { tensorAccumulatedGradientMap[x.id] = dx; } else { var curGradient = tensorAccumulatedGradientMap[x.id]; tensorAccumulatedGradientMap[x.id] = add(curGradient, dx); curGradient.dispose(); } }; for (var inputName in node.inputs) { _loop_2(inputName); } }; // Walk the tape backward and keep a map of Tensor to its gradient. for (var i = filteredTape.length - 1; i >= 0; i--) { _loop_1(i); } } // Maximum number of values before we decide to show ellipsis. var FORMAT_LIMIT_NUM_VALS = 20; // Number of first and last values to show when displaying a, b,...,y, z. var FORMAT_NUM_FIRST_LAST_VALS = 3; // Number of significant digits to show. 
var FORMAT_NUM_SIG_DIGITS = 7; function tensorToString(vals, shape, dtype, verbose) { var strides = computeStrides(shape); var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); var rank = shape.length; var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); var lines = ['Tensor']; if (verbose) { lines.push(" dtype: ".concat(dtype)); lines.push(" rank: ".concat(rank)); lines.push(" shape: [".concat(shape, "]")); lines.push(" values:"); } lines.push(valsLines.map(function (l) { return ' ' + l; }).join('\n')); return lines.join('\n'); } function computeMaxSizePerColumn(vals, shape, dtype, strides) { var n = sizeFromShape(shape); var numCols = strides[strides.length - 1]; var padPerCol = new Array(numCols).fill(0); var rank = shape.length; var valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals; if (rank > 1) { for (var row = 0; row < n / numCols; row++) { var offset = row * numCols; for (var j = 0; j < numCols; j++) { padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); } } } return padPerCol; } function valToString(val, pad, dtype) { var valStr; if (Array.isArray(val)) { valStr = "".concat(parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)), " + ") + "".concat(parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)), "j"); } else if (isString(val)) { valStr = "'".concat(val, "'"); } else if (dtype === 'bool') { valStr = boolNumToString(val); } else { valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); } return rightPad(valStr, pad); } function boolNumToString(v) { return v === 0 ? 'false' : 'true'; } function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) { if (isLast === void 0) { isLast = true; } var storagePerElement = dtype === 'complex64' ? 2 : 1; var size = shape[0]; var rank = shape.length; if (rank === 0) { if (dtype === 'complex64') { var complexTuple = createComplexTuples(vals); return [valToString(complexTuple[0], 0, dtype)]; } if (dtype === 'bool') { return [boolNumToString(vals[0])]; } return [vals[0].toString()]; } if (rank === 1) { if (size > FORMAT_LIMIT_NUM_VALS) { var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; var firstVals = Array.from(vals.slice(0, firstValsSize)); var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); if (dtype === 'complex64') { firstVals = createComplexTuples(firstVals); lastVals = createComplexTuples(lastVals); } return [ '[' + firstVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) .join(', ') + ', ..., ' + lastVals .map(function (x, i) { return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype); }) .join(', ') + ']' ]; } var displayVals = dtype === 'complex64' ? createComplexTuples(vals) : Array.from(vals); return [ '[' + displayVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) .join(', ') + ']' ]; } // The array is rank 2 or more. 
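/*
 * Illustrative sketch (not from the source): a shape [2, 2] tensor is
 * rendered by recursing on each rank-1 slice and indenting the result:
 *
 *   [[1, 2],
 *    [3, 4]]
 */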
var subshape = shape.slice(1); var substrides = strides.slice(1); var stride = strides[0] * storagePerElement; var lines = []; if (size > FORMAT_LIMIT_NUM_VALS) { for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { var start = i * stride; var end = start + stride; lines.push.apply(lines, __spreadArray([], __read(subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */)), false)); } lines.push('...'); for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { var start = i * stride; var end = start + stride; lines.push.apply(lines, __spreadArray([], __read(subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)), false)); } } else { for (var i = 0; i < size; i++) { var start = i * stride; var end = start + stride; lines.push.apply(lines, __spreadArray([], __read(subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)), false)); } } var sep = rank === 2 ? ',' : ''; lines[0] = '[' + (size > 0 ? lines[0] + sep : ''); for (var i = 1; i < lines.length - 1; i++) { lines[i] = ' ' + lines[i] + sep; } var newLineSep = ',\n'; for (var i = 2; i < rank; i++) { newLineSep += '\n'; } lines[lines.length - 1] = ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep); return lines; } function createComplexTuples(vals) { var complexTuples = []; for (var i = 0; i < vals.length; i += 2) { complexTuples.push([vals[i], vals[i + 1]]); } return complexTuples; } /** * A mutable object, similar to `tf.Tensor`, that allows users to set values * at locations before converting to an immutable `tf.Tensor`. * * See `tf.buffer` for creating a tensor buffer. * * @doc {heading: 'Tensors', subheading: 'Classes'} */ /** @class */ ((function () { function TensorBuffer(shape, dtype, values) { var _this = this; this.dtype = dtype; this.shape = shape.slice(); this.size = sizeFromShape(shape); if (values != null) { var n_1 = values.length; assert(n_1 === this.size, function () { return "Length of values '".concat(n_1, "' does not match the size ") + "inferred by the shape '".concat(_this.size, "'."); }); } if (dtype === 'complex64') { throw new Error("complex64 dtype TensorBuffers are not supported. Please create " + "a TensorBuffer for the real and imaginary parts separately and " + "call tf.complex(real, imag)."); } this.values = values || getArrayFromDType(dtype, this.size); this.strides = computeStrides(shape); } /** * Sets a value in the buffer at a given location. * * @param value The value to set. * @param locs The location indices. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ TensorBuffer.prototype.set = function (value) { var _this = this; var locs = []; for (var _i = 1; _i < arguments.length; _i++) { locs[_i - 1] = arguments[_i]; } if (locs.length === 0) { locs = [0]; } assert(locs.length === this.rank, function () { return "The number of provided coordinates (".concat(locs.length, ") must ") + "match the rank (".concat(_this.rank, ")"); }); var index = this.locToIndex(locs); this.values[index] = value; }; /** * Returns the value in the buffer at the provided location. * * @param locs The location indices. 
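 *
 * For example (illustrative): for a buffer of shape `[2, 3]`,
 * `buf.get(1, 2)` reads the element at flat index `1 * 3 + 2 = 5`.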
* * @doc {heading: 'Tensors', subheading: 'Creation'} */ TensorBuffer.prototype.get = function () { var e_1, _b; var locs = []; for (var _i = 0; _i < arguments.length; _i++) { locs[_i] = arguments[_i]; } if (locs.length === 0) { locs = [0]; } var i = 0; try { for (var locs_1 = __values(locs), locs_1_1 = locs_1.next(); !locs_1_1.done; locs_1_1 = locs_1.next()) { var loc = locs_1_1.value; if (loc < 0 || loc >= this.shape[i]) { var msg = "Requested out of range element at ".concat(locs, ". ") + " Buffer shape=".concat(this.shape); throw new Error(msg); } i++; } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (locs_1_1 && !locs_1_1.done && (_b = locs_1.return)) _b.call(locs_1); } finally { if (e_1) throw e_1.error; } } var index = locs[locs.length - 1]; for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) { index += this.strides[i_1] * locs[i_1]; } return this.values[index]; }; TensorBuffer.prototype.locToIndex = function (locs) { if (this.rank === 0) { return 0; } else if (this.rank === 1) { return locs[0]; } var index = locs[locs.length - 1]; for (var i = 0; i < locs.length - 1; ++i) { index += this.strides[i] * locs[i]; } return index; }; TensorBuffer.prototype.indexToLoc = function (index) { if (this.rank === 0) { return []; } else if (this.rank === 1) { return [index]; } var locs = new Array(this.shape.length); for (var i = 0; i < locs.length - 1; ++i) { locs[i] = Math.floor(index / this.strides[i]); index -= locs[i] * this.strides[i]; } locs[locs.length - 1] = index; return locs; }; Object.defineProperty(TensorBuffer.prototype, "rank", { get: function () { return this.shape.length; }, enumerable: false, configurable: true }); /** * Creates an immutable `tf.Tensor` object from the buffer. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ TensorBuffer.prototype.toTensor = function () { return trackerFn().makeTensor(this.values, this.shape, this.dtype); }; return TensorBuffer; })()); // For tracking tensor creation and disposal. var trackerFn = null; // Used by chaining methods to call into ops. var opHandler = null; /** * An external consumer can register itself as the tensor tracker. This way * the Tensor class can notify the tracker for every tensor created and * disposed. */ function setTensorTracker(fn) { trackerFn = fn; } /** * A `tf.Tensor` object represents an immutable, multidimensional array of * numbers that has a shape and a data type. * * For performance reasons, functions that create tensors do not necessarily * perform a copy of the data passed to them (e.g. if the data is passed as a * `Float32Array`), and changes to the data will change the tensor. This is not * a feature and is not supported. To avoid this behavior, use the tensor before * changing the input data or create a copy with `copy = tf.add(yourTensor, 0)`. * * See `tf.tensor` for details on how to create a `tf.Tensor`. * * @doc {heading: 'Tensors', subheading: 'Classes'} */ var Tensor = /** @class */ (function () { function Tensor(shape, dtype, dataId, id) { /** Whether this tensor has been globally kept. */ this.kept = false; this.isDisposedInternal = false; this.shape = shape.slice(); this.dtype = dtype || 'float32'; this.size = sizeFromShape(shape); this.strides = computeStrides(shape); this.dataId = dataId; this.id = id; this.rankType = (this.rank < 5 ? 
this.rank.toString() : 'higher'); }
Object.defineProperty(Tensor.prototype, "rank", { get: function () { return this.shape.length; }, enumerable: false, configurable: true });
/** * Returns a promise of `tf.TensorBuffer` that holds the underlying data. * * @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.buffer = function () { return __awaiter(this, void 0, void 0, function () { var vals; return __generator(this, function (_b) { switch (_b.label) { case 0: return [4 /*yield*/, this.data()]; case 1: vals = _b.sent(); return [2 /*return*/, opHandler.buffer(this.shape, this.dtype, vals)]; } }); }); };
/** * Returns a `tf.TensorBuffer` that holds the underlying data. * @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.bufferSync = function () { return opHandler.buffer(this.shape, this.dtype, this.dataSync()); };
/** * Returns the tensor data as a nested array. The transfer of data is done * asynchronously. * * @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.array = function () { return __awaiter(this, void 0, void 0, function () { var vals; return __generator(this, function (_b) { switch (_b.label) { case 0: return [4 /*yield*/, this.data()]; case 1: vals = _b.sent(); return [2 /*return*/, toNestedArray(this.shape, vals, this.dtype === 'complex64')]; } }); }); };
/** * Returns the tensor data as a nested array. The transfer of data is done * synchronously. * * @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.arraySync = function () { return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64'); };
/** * Asynchronously downloads the values from the `tf.Tensor`. Returns a * promise of `TypedArray` that resolves when the computation has finished. * * @doc {heading: 'Tensors', subheading: 'Classes'} */
Tensor.prototype.data = function () { return __awaiter(this, void 0, void 0, function () { var data, bytes; return __generator(this, function (_b) { switch (_b.label) { case 0: this.throwIfDisposed(); data = trackerFn().read(this.dataId); if (!(this.dtype === 'string')) return [3 /*break*/, 2]; return [4 /*yield*/, data]; case 1: bytes = _b.sent(); try { return [2 /*return*/, bytes.map(function (b) { return decodeString(b); })]; } catch (_a) { throw new Error('Failed to decode the string bytes into utf-8. ' + 'To get the original bytes, call tensor.bytes().'); } _b.label = 2; case 2: return [2 /*return*/, data]; } }); }); };
/** * Copies the tensor's data to a new GPU resource. Compared to `dataSync()` * and `data()`, this method prevents the data from being downloaded to the CPU. * * For WebGL backend, the data will be stored on a densely packed texture. * This means that the texture will use the RGBA channels to store values. * * For WebGPU backend, the data will be stored on a buffer. There is no size * parameter, so a user-defined size cannot be used to create the buffer. * * @param options: * For WebGL, * - customTexShape: Optional. If set, will use the user defined * texture shape to create the texture. * * @returns For WebGL backend, a GPUData contains the new texture and * its information. * { * tensorRef: The tensor that is associated with this texture, * texture: WebGLTexture, * texShape: [number, number] // [height, width] * } * * For WebGPU backend, a GPUData contains the new buffer. * { * tensorRef: The tensor that is associated with this buffer, * buffer: GPUBuffer, * } * * Remember to dispose the GPUData after it is used, by calling * `res.tensorRef.dispose()`.
* * @doc {heading: 'Tensors', subheading: 'Classes'} */ Tensor.prototype.dataToGPU = function (options) { this.throwIfDisposed(); return trackerFn().readToGPU(this.dataId, options); }; /** * Synchronously downloads the values from the `tf.Tensor`. This blocks the * UI thread until the values are ready, which can cause performance issues. * * @doc {heading: 'Tensors', subheading: 'Classes'} */ Tensor.prototype.dataSync = function () { this.throwIfDisposed(); var data = trackerFn().readSync(this.dataId); if (this.dtype === 'string') { try { return data.map(function (b) { return decodeString(b); }); } catch (_a) { throw new Error('Failed to decode the string bytes into utf-8. ' + 'To get the original bytes, call tensor.bytes().'); } } return data; }; /** Returns the underlying bytes of the tensor's data. */ Tensor.prototype.bytes = function () { return __awaiter(this, void 0, void 0, function () { var data; return __generator(this, function (_b) { switch (_b.label) { case 0: this.throwIfDisposed(); return [4 /*yield*/, trackerFn().read(this.dataId)]; case 1: data = _b.sent(); if (this.dtype === 'string') { return [2 /*return*/, data]; } else { return [2 /*return*/, new Uint8Array(data.buffer)]; } } }); }); }; /** * Disposes `tf.Tensor` from memory. * * @doc {heading: 'Tensors', subheading: 'Classes'} */ Tensor.prototype.dispose = function () { if (this.isDisposed) { return; } if (this.kerasMask) { this.kerasMask.dispose(); } trackerFn().disposeTensor(this); this.isDisposedInternal = true; }; Object.defineProperty(Tensor.prototype, "isDisposed", { get: function () { return this.isDisposedInternal; }, enumerable: false, configurable: true }); Tensor.prototype.throwIfDisposed = function () { if (this.isDisposed) { throw new Error("Tensor is disposed."); } }; /** * Prints the `tf.Tensor`. See `tf.print` for details. * * @param verbose Whether to print verbose information about the tensor, * including dtype and size. * * @doc {heading: 'Tensors', subheading: 'Classes'} */ Tensor.prototype.print = function (verbose) { if (verbose === void 0) { verbose = false; } return opHandler.print(this, verbose); }; /** * Returns a copy of the tensor. See `tf.clone` for details. * @doc {heading: 'Tensors', subheading: 'Classes'} */ Tensor.prototype.clone = function () { this.throwIfDisposed(); return opHandler.clone(this); }; /** * Returns a human-readable description of the tensor. Useful for logging. * * @doc {heading: 'Tensors', subheading: 'Classes'} */ Tensor.prototype.toString = function (verbose) { if (verbose === void 0) { verbose = false; } var vals = this.dataSync(); return tensorToString(vals, this.shape, this.dtype, verbose); }; Tensor.prototype.cast = function (dtype) { this.throwIfDisposed(); return opHandler.cast(this, dtype); }; Tensor.prototype.variable = function (trainable, name, dtype) { if (trainable === void 0) { trainable = true; } this.throwIfDisposed(); return trackerFn().makeVariable(this, trainable, name, dtype); }; return Tensor; }()); Object.defineProperty(Tensor, Symbol.hasInstance, { value: function (instance) { // Implementation note: we should use properties of the object that will be // defined before the constructor body has finished executing (methods). // This is because when this code is transpiled by babel, babel will call // classCallCheck before the constructor body is run. // See https://github.com/tensorflow/tfjs/issues/3384 for backstory. 
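/* Illustrative note (not from the source): because the check below is
   duck-typed, `x instanceof Tensor` also holds for tensors created by a
   different bundled copy of this class, as long as they expose data,
   dataSync and throwIfDisposed. */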
return !!instance && instance.data != null && instance.dataSync != null && instance.throwIfDisposed != null; } });
function getGlobalTensorClass() {
// Use getGlobal so that we can augment the Tensor class across package
// boundaries because the node resolution algorithm may result in different
// modules being returned for this file depending on the path they are loaded from.
return getGlobal('Tensor', function () { return Tensor; }); }
// Global side effect. Cache global reference to Tensor class.
getGlobalTensorClass();
/** * A mutable `tf.Tensor`, useful for persisting state, e.g. for training. * * @doc {heading: 'Tensors', subheading: 'Classes'} */
var Variable = /** @class */ (function (_super) { __extends(Variable, _super); function Variable(initialValue, trainable, name, tensorId) { var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this; _this.trainable = trainable; _this.name = name; return _this; }
/** * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have * the same shape and dtype as the old `tf.Tensor`. * * @param newValue New tensor to be assigned to this variable. * * @doc {heading: 'Tensors', subheading: 'Classes'} */
Variable.prototype.assign = function (newValue) { if (newValue.dtype !== this.dtype) { throw new Error("dtype of the new value (".concat(newValue.dtype, ") and ") + "previous value (".concat(this.dtype, ") must match")); } if (!arraysEqual(newValue.shape, this.shape)) { throw new Error("shape of the new value (".concat(newValue.shape, ") and ") + "previous value (".concat(this.shape, ") must match")); } trackerFn().disposeTensor(this); this.dataId = newValue.dataId; trackerFn().incRef(this, null /* backend */); };
Variable.prototype.dispose = function () { trackerFn().disposeVariable(this); this.isDisposedInternal = true; };
return Variable; }(Tensor));
Object.defineProperty(Variable, Symbol.hasInstance, { value: function (instance) { return instance instanceof Tensor && instance.assign != null && instance.assign instanceof Function; } });
/** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */
var Rank;
(function (Rank) { Rank["R0"] = "R0"; Rank["R1"] = "R1"; Rank["R2"] = "R2"; Rank["R3"] = "R3"; Rank["R4"] = "R4"; Rank["R5"] = "R5"; Rank["R6"] = "R6"; })(Rank || (Rank = {}));
// Looks for upcasting types. Used, for example, in operations with mixed dtype
// inputs.
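/* Illustrative examples (not from the source; `upcastType` is defined just
   below the maps):
     upcastType('bool', 'int32')        === 'int32'
     upcastType('int32', 'float32')     === 'float32'
     upcastType('float32', 'complex64') === 'complex64' */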
var UpcastInt32AndMap; (function (UpcastInt32AndMap) { UpcastInt32AndMap["float32"] = "float32"; UpcastInt32AndMap["int32"] = "int32"; UpcastInt32AndMap["bool"] = "int32"; UpcastInt32AndMap["complex64"] = "complex64"; })(UpcastInt32AndMap || (UpcastInt32AndMap = {})); var UpcastBoolAndMap; (function (UpcastBoolAndMap) { UpcastBoolAndMap["float32"] = "float32"; UpcastBoolAndMap["int32"] = "int32"; UpcastBoolAndMap["bool"] = "bool"; UpcastBoolAndMap["complex64"] = "complex64"; })(UpcastBoolAndMap || (UpcastBoolAndMap = {})); var UpcastFloat32AndMap; (function (UpcastFloat32AndMap) { UpcastFloat32AndMap["float32"] = "float32"; UpcastFloat32AndMap["int32"] = "float32"; UpcastFloat32AndMap["bool"] = "float32"; UpcastFloat32AndMap["complex64"] = "complex64"; })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {})); var UpcastComplex64AndMap; (function (UpcastComplex64AndMap) { UpcastComplex64AndMap["float32"] = "complex64"; UpcastComplex64AndMap["int32"] = "complex64"; UpcastComplex64AndMap["bool"] = "complex64"; UpcastComplex64AndMap["complex64"] = "complex64"; })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {})); var upcastTypeMap = { 'float32': UpcastFloat32AndMap, 'int32': UpcastInt32AndMap, 'bool': UpcastBoolAndMap, 'complex64': UpcastComplex64AndMap }; function upcastType(typeA, typeB) { if (typeA === 'string' || typeB === 'string') { if (typeA === 'string' && typeB === 'string') { return 'string'; } throw new Error("Can not upcast ".concat(typeA, " with ").concat(typeB)); } return upcastTypeMap[typeA][typeB]; } function isWebGLData(values) { return values != null && typeof values === 'object' && 'texture' in values && values.texture instanceof WebGLTexture; } function isWebGPUData(values) { return typeof GPUBuffer !== 'undefined' && values != null && typeof values === 'object' && 'buffer' in values && values.buffer instanceof GPUBuffer; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function makeTypesMatch(a, b) { if (a.dtype === b.dtype) { return [a, b]; } var dtype = upcastType(a.dtype, b.dtype); return [a.cast(dtype), b.cast(dtype)]; } /** * Extracts any `Tensor`s found within the provided object. * * @param container an object that may be a `Tensor` or may directly contain * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it * is safe to pass any object here, except that `Promise`s are not * supported. * @returns An array of `Tensors` found within the passed object. If the * argument is simply a `Tensor', a list containing that `Tensor` is * returned. If the object is not a `Tensor` or does not * contain `Tensors`, an empty list is returned. 
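 *
 * An illustrative sketch (assuming `a` and `b` are existing `tf.Tensor`s):
 *
 * ```js
 * // Walks object values and nested arrays; returns [a, b].
 * const tensors = getTensorsInContainer({first: a, rest: [b, 'label']});
 * ```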
*/ function getTensorsInContainer(result) { var list = []; var seen = new Set(); walkTensorContainer(result, list, seen); return list; } function walkTensorContainer(container, list, seen) { if (container == null) { return; } if (container instanceof Tensor) { list.push(container); return; } if (!isIterable(container)) { return; } // Iteration over keys works also for arrays. var iterable = container; for (var k in iterable) { var val = iterable[k]; if (!seen.has(val)) { seen.add(val); walkTensorContainer(val, list, seen); } } } // tslint:disable-next-line:no-any function isIterable(obj) { return Array.isArray(obj) || typeof obj === 'object'; } function isRegisteredKernelInvocation(kernelInvocation) { return kernelInvocation.kernelName != null; } var EngineState = /** @class */ (function () { function EngineState() { // Public since optimizers will use it. this.registeredVariables = {}; this.nextTapeNodeId = 0; this.numBytes = 0; this.numTensors = 0; this.numStringTensors = 0; this.numDataBuffers = 0; // Number of nested tf.grad() statements when computing higher-order // gradients. E.g. `1` for first-order gradients and `2` for second-order // gradients. Used to track if the tape should be removed after a backprop. this.gradientDepth = 0; // Number of nested kernel calls. When kernel depth is greater than 1, we turn // off the tape. this.kernelDepth = 0; this.scopeStack = []; /** * Keeps track of the number of data moves during a kernel execution. We * maintain a stack since kernels can call other kernels, recursively. */ this.numDataMovesStack = []; this.nextScopeId = 0; this.tensorInfo = new WeakMap(); this.profiling = false; this.activeProfile = { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null, get kernelNames() { return Array.from(new Set(this.kernels.map(function (k) { return k.name; }))); } }; } EngineState.prototype.dispose = function () { for (var variableName in this.registeredVariables) { this.registeredVariables[variableName].dispose(); } }; return EngineState; }()); var Engine = /** @class */ (function () { function Engine(ENV) { this.ENV = ENV; this.registry = {}; this.registryFactory = {}; this.pendingBackendInitId = 0; this.state = new EngineState(); } Engine.prototype.ready = function () { return __awaiter(this, void 0, void 0, function () { var sortedBackends, i, backendName, success; return __generator(this, function (_a) { switch (_a.label) { case 0: if (this.pendingBackendInit != null) { return [2 /*return*/, this.pendingBackendInit.then(function () { })]; } if (this.backendInstance != null) { return [2 /*return*/]; } sortedBackends = this.getSortedBackends(); i = 0; _a.label = 1; case 1: if (!(i < sortedBackends.length)) return [3 /*break*/, 5]; backendName = sortedBackends[i]; return [4 /*yield*/, this.initializeBackend(backendName).success]; case 2: success = _a.sent(); if (!success) return [3 /*break*/, 4]; return [4 /*yield*/, this.setBackend(backendName)]; case 3: _a.sent(); return [2 /*return*/]; case 4: i++; return [3 /*break*/, 1]; case 5: throw new Error("Could not initialize any backends, all backend initializations " + "failed."); } }); }); }; Object.defineProperty(Engine.prototype, "backend", { get: function () { if (this.pendingBackendInit != null) { throw new Error("Backend '".concat(this.backendName, "' has not yet been initialized. 
Make ") + "sure to await tf.ready() or await tf.setBackend() before calling " + "other methods"); } if (this.backendInstance == null) { var _a = this.initializeBackendsAndReturnBest(), name = _a.name, asyncInit = _a.asyncInit; if (asyncInit) { throw new Error("The highest priority backend '".concat(name, "' has not yet been ") + "initialized. Make sure to await tf.ready() or " + "await tf.setBackend() before calling other methods"); } this.setBackend(name); } return this.backendInstance; }, enumerable: false, configurable: true }); Engine.prototype.backendNames = function () { return Object.keys(this.registryFactory); }; Engine.prototype.findBackend = function (backendName) { if (!(backendName in this.registry)) { // If the backend hasn't been initialized but we have a registry entry for // it, initialize it and return it. if (backendName in this.registryFactory) { var asyncInit = this.initializeBackend(backendName).asyncInit; if (asyncInit) { // Backend is not ready yet. return null; } } else { return null; } } return this.registry[backendName]; }; Engine.prototype.findBackendFactory = function (backendName) { if (!(backendName in this.registryFactory)) { return null; } return this.registryFactory[backendName].factory; }; Engine.prototype.registerBackend = function (backendName, factory, priority) { if (priority === void 0) { priority = 1; } if (backendName in this.registryFactory) { warn("".concat(backendName, " backend was already registered. ") + "Reusing existing backend factory."); return false; } this.registryFactory[backendName] = { factory: factory, priority: priority }; return true; }; Engine.prototype.setBackend = function (backendName) { return __awaiter(this, void 0, void 0, function () { var _a, success, asyncInit, result, _b; return __generator(this, function (_c) { switch (_c.label) { case 0: if (this.registryFactory[backendName] == null) { throw new Error("Backend name '".concat(backendName, "' not found in registry")); } this.backendName = backendName; if (!(this.registry[backendName] == null)) return [3 /*break*/, 4]; this.backendInstance = null; _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; if (!asyncInit) return [3 /*break*/, 2]; return [4 /*yield*/, success]; case 1: _b = _c.sent(); return [3 /*break*/, 3]; case 2: _b = success; _c.label = 3; case 3: result = _b; if (!result) { return [2 /*return*/, false]; } _c.label = 4; case 4: this.backendInstance = this.registry[backendName]; this.setupRegisteredKernels(); // Reset the profiler. this.profiler = new Profiler(this.backendInstance); return [2 /*return*/, true]; } }); }); }; Engine.prototype.setupRegisteredKernels = function () { var _this = this; var kernels = getKernelsForBackend(this.backendName); kernels.forEach(function (kernel) { if (kernel.setupFunc != null) { kernel.setupFunc(_this.backendInstance); } }); }; Engine.prototype.disposeRegisteredKernels = function (backendName) { var _this = this; var kernels = getKernelsForBackend(backendName); kernels.forEach(function (kernel) { if (kernel.disposeFunc != null) { kernel.disposeFunc(_this.registry[backendName]); } }); }; /** * Initializes a backend by looking up the backend name in the factory * registry and calling the factory method. Returns a boolean representing * whether the initialization of the backend suceeded. Throws an error if * there is no backend in the factory registry. 
*/ Engine.prototype.initializeBackend = function (backendName) { var _this = this; var registryFactoryEntry = this.registryFactory[backendName]; if (registryFactoryEntry == null) { throw new Error("Cannot initialize backend ".concat(backendName, ", no registration found.")); } try { var backend = registryFactoryEntry.factory(); /* Test if the factory returns a promise. Done in a more liberal way than previous 'Promise.resolve(backend)===backend' as we needed to account for custom Promise implementations (e.g. Angular) */ if (backend && !(backend instanceof KernelBackend) && typeof backend.then === 'function') { var promiseId_1 = ++this.pendingBackendInitId; var success = backend .then(function (backendInstance) { // Outdated promise. Another backend was set in the meantime. if (promiseId_1 < _this.pendingBackendInitId) { return false; } _this.registry[backendName] = backendInstance; _this.pendingBackendInit = null; return true; }) .catch(function (err) { // Outdated promise. Another backend was set in the meantime. if (promiseId_1 < _this.pendingBackendInitId) { return false; } _this.pendingBackendInit = null; warn("Initialization of backend ".concat(backendName, " failed")); warn(err.stack || err.message); return false; }); this.pendingBackendInit = success; return { success: success, asyncInit: true }; } else { this.registry[backendName] = backend; return { success: true, asyncInit: false }; } } catch (err) { warn("Initialization of backend ".concat(backendName, " failed")); warn(err.stack || err.message); return { success: false, asyncInit: false }; } }; Engine.prototype.removeBackend = function (backendName) { if (!(backendName in this.registryFactory)) { throw new Error("".concat(backendName, " backend not found in registry")); } if (this.backendName === backendName && this.pendingBackendInit != null) { // There is a pending promise of the backend we want to remove. Make it // obsolete. this.pendingBackendInitId++; } if (backendName in this.registry) { this.disposeRegisteredKernels(backendName); this.registry[backendName].dispose(); delete this.registry[backendName]; } delete this.registryFactory[backendName]; // Unset the backend if it is active. if (this.backendName === backendName) { this.pendingBackendInit = null; this.backendName = null; this.backendInstance = null; } }; Engine.prototype.getSortedBackends = function () { var _this = this; if (Object.keys(this.registryFactory).length === 0) { throw new Error('No backend found in registry.'); } return Object.keys(this.registryFactory).sort(function (a, b) { // Highest priority comes first. return _this.registryFactory[b].priority - _this.registryFactory[a].priority; }); }; Engine.prototype.initializeBackendsAndReturnBest = function () { var sortedBackends = this.getSortedBackends(); for (var i = 0; i < sortedBackends.length; i++) { var backendName = sortedBackends[i]; var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; if (asyncInit || success) { return { name: backendName, asyncInit: asyncInit }; } } throw new Error("Could not initialize any backends, all backend initializations " + "failed."); }; Engine.prototype.moveData = function (backend, dataId) { var info = this.state.tensorInfo.get(dataId); var srcBackend = info.backend; var values = this.readSync(dataId); var refCount = srcBackend.refCount(dataId); // Delete the tensor from the old backend and move it to the new // backend. 
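// Note: refCount was captured above so that the destination backend can
// restore the same number of outstanding references after the move.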
srcBackend.disposeData(dataId, true); info.backend = backend; backend.move(dataId, values, info.shape, info.dtype, refCount); if (this.shouldCheckForMemLeaks()) { // Track the number of moves during a kernel execution to correctly // detect memory leaks. this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; } }; Engine.prototype.tidy = function (nameOrFn, fn) { var _this = this; var name = null; if (fn == null) { // Called with only 1 argument. if (typeof nameOrFn !== 'function') { throw new Error('Please provide a function to tidy()'); } fn = nameOrFn; } else { // Called with 2 arguments. if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) { throw new Error('When calling with two arguments, the first argument ' + 'to tidy() must be a string'); } if (typeof fn !== 'function') { throw new Error('When calling with two arguments, the 2nd argument ' + 'to tidy() must be a function'); } name = nameOrFn; // TODO(nsthorat,smilkov): Do operation logging and performance // profiling. } var result; return this.scopedRun(function () { return _this.startScope(name); }, function () { return _this.endScope(result); }, function () { result = fn(); if (result instanceof Promise) { console.error('Cannot return a Promise inside of tidy.'); } return result; }); }; Engine.prototype.scopedRun = function (start, end, f) { start(); try { var res = f(); end(); return res; } catch (ex) { end(); throw ex; } }; Engine.prototype.nextTensorId = function () { return Engine.nextTensorId++; }; Engine.prototype.nextVariableId = function () { return Engine.nextVariableId++; }; /** * This method is called instead of the public-facing tensor.clone() when * saving a tensor for backwards pass. It makes sure to add the clone * operation to the tape regardless of being called inside a kernel * execution. */ Engine.prototype.clone = function (x) { var y = ENGINE.runKernel(Identity, { x: x }); var inputs = { x: x }; var grad = function (dy) { return ({ x: function () { var dtype = 'float32'; var gradInputs = { x: dy }; var attrs = { dtype: dtype }; return ENGINE.runKernel(Cast, gradInputs, // tslint:disable-next-line: no-unnecessary-type-assertion attrs); } }); }; var saved = []; this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {}); return y; }; /** * Execute a kernel with the given name and return the output tensor. * * @param kernelName The name of the kernel to execute. * @param inputs A map of input names to tensors. * @param attrs A map of attribute names to their values. An attribute is a * primitive (non-tensor) input to the kernel. * @param inputsToSave A list of tensors, inputs to save for the backprop * computation. * @param outputsToSave A list of booleans, specifying which output to save * for the backprop computation. These are booleans since the output * tensors are not visible to the user. */ Engine.prototype.runKernel = function (kernelName, inputs, attrs) { if (this.backendName == null) { // backend has not been initialized yet (backend initialization is lazy and // can be deferred until an op/kernel is run).
// The below getter has side effects that will try to initialize the // backend and set properties like this.backendName // tslint:disable-next-line: no-unused-expression this.backend; } var hasKernel = getKernel(kernelName, this.backendName) != null; if (!hasKernel) { throw new Error("Kernel '".concat(kernelName, "' not registered for backend '").concat(this.backendName, "'")); } return this.runKernelFunc({ kernelName: kernelName, inputs: inputs, attrs: attrs }); }; Engine.prototype.shouldCheckForMemLeaks = function () { return this.ENV.getBool('IS_TEST'); }; Engine.prototype.checkKernelForMemLeak = function (kernelName, numDataIdsBefore, outInfos) { var numDataIdsAfter = this.backend.numDataIds(); // Count the number of data ids associated with the result of the kernel. var numOutputDataIds = 0; outInfos.forEach(function (info) { // Complex numbers allocate 3 data ids, one for 'real', one for // 'imaginary', and one for the container that holds the former two. numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1); }); // Account for the number of moves during kernel execution. A "data move" // can happen in the middle of a kernel execution, placing a new (key,value) // pair in the data storage. Since data moves have net zero effect (we // always remove the data from the old backend), we have to cancel them out // when detecting memory leaks. var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; if (dataIdsLeaked > 0) { throw new Error("Backend '".concat(this.backendName, "' has an internal memory leak ") + "(".concat(dataIdsLeaked, " data ids) after running '").concat(kernelName, "'")); } }; /** * Internal helper method to execute a kernelFunc. * * Use `runKernel` to execute kernels from outside of engine. */ Engine.prototype.runKernelFunc = function (kernelParams) { var _this = this; var outputs; var saved = []; var isTapeOn = this.isTapeOn(); var startingBytecount = this.state.numBytes; var startingNumTensors = this.state.numTensors; if (this.shouldCheckForMemLeaks()) { this.state.numDataMovesStack.push(0); } var kernelFunc; if (this.backendName == null) { // backend has not been initialized yet (backend initialization is lazy and // can be deferred until an op/kernel is run). // The below getter has side effects that will try to initialize the // backend and set properties like this.backendName // tslint:disable-next-line: no-unused-expression this.backend; } var out; var kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ? kernelParams.kernelName : this.state.activeScope != null ? this.state.activeScope.name : ''; // Create the kernelFunc from either a registered kernel OR passed in // forward/backward functions (used by custom grad). In this context a // kernelFunc wraps a kernel implementation with some bookkeeping. if (isRegisteredKernelInvocation(kernelParams)) { var kernelName_1 = kernelParams.kernelName, inputs_1 = kernelParams.inputs, attrs_1 = kernelParams.attrs; if (this.backendName == null) { // backend has not been initialized yet (backend initialization is lazy and // can be deferred until an op/kernel is run).
// The below getter has side effects that will try to initialize the // backend and set properties like this.backendName // tslint:disable-next-line: no-unused-expression this.backend; } var kernel_1 = getKernel(kernelName_1, this.backendName); assert(kernel_1 != null, function () { return "Cannot find registered kernel '".concat(kernelName_1, "' for backend '").concat(_this.backendName, "'"); }); kernelFunc = function () { var numDataIdsBefore = _this.backend.numDataIds(); out = kernel_1.kernelFunc({ inputs: inputs_1, attrs: attrs_1, backend: _this.backend }); var outInfos = Array.isArray(out) ? out : [out]; if (_this.shouldCheckForMemLeaks()) { _this.checkKernelForMemLeak(kernelName_1, numDataIdsBefore, outInfos); } var outTensors = outInfos.map(function (outInfo) { // todo (yassogba) remove this option (Tensor) when node backend // methods have been modularized and they all return tensorInfo. // TensorInfos do not have a rank attribute. if (outInfo.rank != null) { return outInfo; } return _this.makeTensorFromTensorInfo(outInfo); }); // Save any required inputs and outputs. // Do not save unless we are recording to the tape. Otherwise it would // cause a mem leak since there would be no backprop for these tensors // (which would otherwise dispose them). if (isTapeOn) { var tensorsToSave = _this.getTensorsForGradient(kernelName_1, inputs_1, outTensors); saved = _this.saveTensorsForBackwardMode(tensorsToSave); } return outTensors; }; } else { var forwardFunc_1 = kernelParams.forwardFunc; // Running a customGrad op. var saveFunc_1 = function (tensors) { // Do not save unless we are recording to the tape. Otherwise it would // cause a mem leak since we would never run backprop, which disposes // the kept tensors. if (!isTapeOn) { return; } saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); }; kernelFunc = function () { var numDataIdsBefore = _this.backend.numDataIds(); out = _this.tidy(function () { return forwardFunc_1(_this.backend, saveFunc_1); }); var outs = (Array.isArray(out) ? out : [out]); if (_this.shouldCheckForMemLeaks()) { // Scope name is used to print a more helpful error message if needed. _this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs); } return outs; }; } // // Run the kernelFunc. Optionally profiling it. // var inputs = kernelParams.inputs, attrs = kernelParams.attrs; var backwardsFunc = isRegisteredKernelInvocation(kernelParams) ? null : kernelParams.backwardsFunc; var kernelProfile; this.scopedRun( // Stop recording to a tape when running a kernel. function () { return _this.state.kernelDepth++; }, function () { return _this.state.kernelDepth--; }, function () { if (!_this.ENV.getBool('DEBUG') && !_this.state.profiling) { outputs = kernelFunc(); } else { kernelProfile = _this.profiler.profileKernel(kernelOrScopeName, inputs, function () { return kernelFunc(); }); if (_this.ENV.getBool('DEBUG')) { _this.profiler.logKernelProfile(kernelProfile); } outputs = kernelProfile.outputs; } }); if (isTapeOn) { this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs); } if (this.state.profiling) { this.state.activeProfile.kernels.push({ name: kernelOrScopeName, bytesAdded: this.state.numBytes - startingBytecount, totalBytesSnapshot: this.state.numBytes, tensorsAdded: this.state.numTensors - startingNumTensors, totalTensorsSnapshot: this.state.numTensors, inputShapes: Object.keys(inputs).map(function (key) { return inputs[key] != null ? 
inputs[key].shape : null; }), outputShapes: outputs.map(function (item) { return item.shape; }), kernelTimeMs: kernelProfile.timeMs, extraInfo: kernelProfile.extraInfo }); } return (Array.isArray(out) ? outputs : outputs[0]); }; /** * Saves tensors used in forward mode for use in backward mode. * * @param tensors the list of tensors to save. */ Engine.prototype.saveTensorsForBackwardMode = function (tensors) { var _this = this; var saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); return saved; }; /** * Returns a list of tensors to save for a given gradient calculation. * * @param kernelName name of kernel to look up gradient for. * @param inputs a map of input tensors. * @param outputs an array of output tensors from forward mode of kernel. */ Engine.prototype.getTensorsForGradient = function (kernelName, inputs, outputs) { var gradConfig = getGradient(kernelName); if (gradConfig != null) { var inputsToSave = gradConfig.inputsToSave || []; var outputsToSave_1 = gradConfig.outputsToSave || []; // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs // specified in inputsToSave will be saved. var inputTensorsToSave = void 0; if (gradConfig.saveAllInputs) { assert(Array.isArray(inputs), function () { return 'saveAllInputs is true, expected inputs to be an array.'; }); inputTensorsToSave = Object.keys(inputs).map(function (key) { return inputs[key]; }); } else { inputTensorsToSave = inputsToSave.map(function (inputName) { return inputs[inputName]; }); } var outputTensorsToSave = outputs.filter(function (_, i) { return outputsToSave_1[i]; }); return inputTensorsToSave.concat(outputTensorsToSave); } // We return an empty list rather than throw an error because the kernel we // are looking up may not actually be relevant to backpropagating through the // overall function. // // See 'does not error if irrelevant (pruned) ops are missing grads' test // in gradients_test.ts for an example. return []; }; /** * Internal method used by public APIs for tensor creation. Makes a new * tensor with the provided shape, dtype and values. It always * creates a new data id and writes the values to the underlying backend. */ Engine.prototype.makeTensor = function (values, shape, dtype, backend) { if (values == null) { throw new Error('Values passed to engine.makeTensor() are null'); } dtype = dtype || 'float32'; backend = backend || this.backend; var backendVals = values; if (dtype === 'string' && isString(values[0])) { backendVals = values.map(function (d) { return encodeString(d); }); } var dataId = backend.write(backendVals, shape, dtype); var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); this.trackTensor(t, backend); // Count bytes for string tensors. if (dtype === 'string') { var info = this.state.tensorInfo.get(dataId); var newBytes = bytesFromStringArray(backendVals); this.state.numBytes += newBytes - info.bytes; info.bytes = newBytes; } return t; }; /** * Internal method used by backends. Makes a new tensor * that is a wrapper around an existing data id. It doesn't create * a new data id, only increments the ref count used in memory tracking. * @deprecated */ Engine.prototype.makeTensorFromDataId = function (dataId, shape, dtype, backend) { dtype = dtype || 'float32'; var tensorInfo = { dataId: dataId, shape: shape, dtype: dtype }; return this.makeTensorFromTensorInfo(tensorInfo, backend); }; /** * Internal method used by backends. Makes a new tensor that is a wrapper * around an existing data id in TensorInfo.
It doesn't create a new data id, * only increments the ref count used in memory tracking. */ Engine.prototype.makeTensorFromTensorInfo = function (tensorInfo, backend) { var dataId = tensorInfo.dataId, shape = tensorInfo.shape, dtype = tensorInfo.dtype; var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); this.trackTensor(t, backend); return t; }; Engine.prototype.makeVariable = function (initialValue, trainable, name, dtype) { if (trainable === void 0) { trainable = true; } name = name || this.nextVariableId().toString(); if (dtype != null && dtype !== initialValue.dtype) { initialValue = initialValue.cast(dtype); } var v = new Variable(initialValue, trainable, name, this.nextTensorId()); if (this.state.registeredVariables[v.name] != null) { throw new Error("Variable with name ".concat(v.name, " was already registered")); } this.state.registeredVariables[v.name] = v; this.incRef(v, this.backend); return v; }; Engine.prototype.trackTensor = function (a, backend) { this.state.numTensors++; if (a.dtype === 'string') { this.state.numStringTensors++; } // Bytes for complex numbers are counted by their components. Bytes for // string tensors are counted when writing values. var bytes = 0; if (a.dtype !== 'complex64' && a.dtype !== 'string') { bytes = a.size * bytesPerElement(a.dtype); } this.state.numBytes += bytes; if (!this.state.tensorInfo.has(a.dataId)) { this.state.numDataBuffers++; this.state.tensorInfo.set(a.dataId, { backend: backend || this.backend, dtype: a.dtype, shape: a.shape, bytes: bytes }); } if (!(a instanceof Variable)) { this.track(a); } }; // Track the tensor by dataId and increase the refCount for the dataId in the // backend. // TODO(pyu10055): This is currently used by makeVariable method, to increase // refCount on the backend for the dataId. It can potentially be replaced with // Identity op instead of calling backend directly. Engine.prototype.incRef = function (a, backend) { this.trackTensor(a, backend); this.backend.incRef(a.dataId); }; Engine.prototype.removeDataId = function (dataId, backend) { if (this.state.tensorInfo.has(dataId) && this.state.tensorInfo.get(dataId).backend === backend) { this.state.tensorInfo.delete(dataId); this.state.numDataBuffers--; } }; Engine.prototype.disposeTensor = function (a) { if (!this.state.tensorInfo.has(a.dataId)) { return; } var info = this.state.tensorInfo.get(a.dataId); this.state.numTensors--; if (a.dtype === 'string') { this.state.numStringTensors--; this.state.numBytes -= info.bytes; } // Don't count bytes for complex numbers as they are counted by their // components. if (a.dtype !== 'complex64' && a.dtype !== 'string') { var bytes = a.size * bytesPerElement(a.dtype); this.state.numBytes -= bytes; } // Remove the reference to dataId if the backend disposes the data successfully if (info.backend.disposeData(a.dataId)) { this.removeDataId(a.dataId, info.backend); } // TODO(nsthorat): Construct an error and save the stack trace for // debugging when in debug mode. Creating a stack trace is too expensive // to do unconditionally.
}; Engine.prototype.disposeVariables = function () { for (var varName in this.state.registeredVariables) { var v = this.state.registeredVariables[varName]; this.disposeVariable(v); } }; Engine.prototype.disposeVariable = function (v) { this.disposeTensor(v); if (this.state.registeredVariables[v.name] != null) { delete this.state.registeredVariables[v.name]; } }; Engine.prototype.memory = function () { var info = this.backend.memory(); info.numTensors = this.state.numTensors; info.numDataBuffers = this.state.numDataBuffers; info.numBytes = this.state.numBytes; if (this.state.numStringTensors > 0) { info.unreliable = true; if (info.reasons == null) { info.reasons = []; } info.reasons.push('Memory usage by string tensors is approximate ' + '(2 bytes per character)'); } return info; }; Engine.prototype.profile = function (query) { return __awaiter(this, void 0, void 0, function () { var startBytes, startNumTensors, _a, _b, _c, kernel, _d, _e, e_1_1; var e_1, _f; return __generator(this, function (_g) { switch (_g.label) { case 0: this.state.profiling = true; startBytes = this.state.numBytes; startNumTensors = this.state.numTensors; this.state.activeProfile.kernels = []; _a = this.state.activeProfile; return [4 /*yield*/, query()]; case 1: _a.result = _g.sent(); this.state.profiling = false; this.state.activeProfile.peakBytes = Math.max.apply(Math, __spreadArray([], __read(this.state.activeProfile.kernels.map(function (d) { return d.totalBytesSnapshot; })), false)); this.state.activeProfile.newBytes = this.state.numBytes - startBytes; this.state.activeProfile.newTensors = this.state.numTensors - startNumTensors; _g.label = 2; case 2: _g.trys.push([2, 8, 9, 10]); _b = __values(this.state.activeProfile.kernels), _c = _b.next(); _g.label = 3; case 3: if (!!_c.done) return [3 /*break*/, 7]; kernel = _c.value; _d = kernel; return [4 /*yield*/, kernel.kernelTimeMs]; case 4: _d.kernelTimeMs = _g.sent(); _e = kernel; return [4 /*yield*/, kernel.extraInfo]; case 5: _e.extraInfo = _g.sent(); _g.label = 6; case 6: _c = _b.next(); return [3 /*break*/, 3]; case 7: return [3 /*break*/, 10]; case 8: e_1_1 = _g.sent(); e_1 = { error: e_1_1 }; return [3 /*break*/, 10]; case 9: try { if (_c && !_c.done && (_f = _b.return)) _f.call(_b); } finally { if (e_1) throw e_1.error; } return [7 /*endfinally*/]; case 10: return [2 /*return*/, this.state.activeProfile]; } }); }); }; Engine.prototype.isTapeOn = function () { return this.state.gradientDepth > 0 && this.state.kernelDepth === 0; }; Engine.prototype.addTapeNode = function (kernelName, inputs, outputs, gradientsFunc, saved, attrs) { var _this = this; var tapeNode = { id: this.state.nextTapeNodeId++, kernelName: kernelName, inputs: inputs, outputs: outputs, saved: saved }; var gradConfig = getGradient(kernelName); if (gradConfig != null) { gradientsFunc = gradConfig.gradFunc; } if (gradientsFunc != null) { tapeNode.gradient = function (dys) { // TODO(smilkov): To optimize back-prop, pass dys that are not used in // the backprop graph to the user as null instead of zeros dys = dys.map(function (dy, i) { if (dy == null) { var output = outputs[i]; var vals = makeZerosTypedArray(output.size, output.dtype); return _this.makeTensor(vals, output.shape, output.dtype); } return dy; }); // Grad functions of ops with single outputs expect a dy, while ops // with multiple outputs expect dys (array of dy). return gradientsFunc(dys.length > 1 ? 
dys : dys[0], saved, attrs); }; } this.state.activeTape.push(tapeNode); }; Engine.prototype.keep = function (result) { result.kept = true; return result; }; Engine.prototype.startTape = function () { if (this.state.gradientDepth === 0) { this.state.activeTape = []; } this.state.gradientDepth++; }; Engine.prototype.endTape = function () { this.state.gradientDepth--; }; /** * Start a scope. Use this with endScope() to achieve the same functionality * as scope() without the need for a function closure. */ Engine.prototype.startScope = function (name) { var scopeInfo = { track: [], name: 'unnamed scope', id: this.state.nextScopeId++ }; if (name) { scopeInfo.name = name; } this.state.scopeStack.push(scopeInfo); this.state.activeScope = scopeInfo; }; /** * End a scope. Use this with startScope() to achieve the same functionality * as scope() without the need for a function closure. */ Engine.prototype.endScope = function (result) { var _this = this; var tensorsToTrackInParent = getTensorsInContainer(result); var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) { return t.id; })); // Dispose the arrays tracked in this scope. for (var i = 0; i < this.state.activeScope.track.length; i++) { var tensor = this.state.activeScope.track[i]; if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) { tensor.dispose(); } } var oldScope = this.state.scopeStack.pop(); this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1]; // Track the current result in the parent scope. tensorsToTrackInParent.forEach(function (tensor) { // Only track the tensor if it was allocated in the inner scope and is not // globally kept. if (!tensor.kept && tensor.scopeId === oldScope.id) { _this.track(tensor); } }); }; /** * Returns gradients of `f` with respect to each of the `xs`. The gradients * returned are of the same length as `xs`, but some might be null if `f` * was not a function of that `x`. It also takes an optional dy to multiply the * gradient, which defaults to `1`. */ Engine.prototype.gradients = function (f, xs, dy, allowNoGradients) { var _this = this; if (allowNoGradients === void 0) { allowNoGradients = false; } assert(xs.length > 0, function () { return 'gradients() received an empty list of xs.'; }); if (dy != null && dy.dtype !== 'float32') { throw new Error("dy must have 'float32' dtype, but has '".concat(dy.dtype, "'")); } var y = this.scopedRun(function () { return _this.startTape(); }, function () { return _this.endTape(); }, function () { return _this.tidy('forward', f); }); assert(y instanceof Tensor, function () { return 'The result y returned by f() must be a tensor.'; }); // Filter out the nodes that don't connect x => y. var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y); if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' + 'that the f you passed encloses all operations that lead from x ' + 'to y.'); } return this.tidy('backward', function () { var accumulatedGradientMap = {}; accumulatedGradientMap[y.id] = (dy == null) ? ones$2(y.shape) : dy; // Backprop gradients through the filtered nodes. backpropagateGradients(accumulatedGradientMap, filteredTape, // Pass the tidy function to avoid circular dep with `tape.ts`. function (f) { return _this.tidy(f); }, // Pass an add function to avoid a circular dep with `tape.ts`.
add$2); var grads = xs.map(function (x) { return accumulatedGradientMap[x.id]; }); if (_this.state.gradientDepth === 0) { // This means that we are not computing higher-order gradients // and can clean up the tape. _this.state.activeTape.forEach(function (node) { var e_2, _a; try { for (var _b = __values(node.saved), _c = _b.next(); !_c.done; _c = _b.next()) { var tensor = _c.value; tensor.dispose(); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_2) throw e_2.error; } } }); _this.state.activeTape = null; } return { value: y, grads: grads }; }); }; Engine.prototype.customGrad = function (f) { var _this = this; assert(isFunction(f), function () { return 'The f passed in customGrad(f) must be a function.'; }); return function () { var inputs = []; for (var _i = 0; _i < arguments.length; _i++) { inputs[_i] = arguments[_i]; } assert(inputs.every(function (t) { return t instanceof Tensor; }), function () { return 'The args passed in customGrad(f)(x1, x2,...) must all be ' + 'tensors'; }); var res; var inputMap = {}; inputs.forEach(function (input, i) { inputMap[i] = input; }); var forwardFunc = function (_, save) { res = f.apply(void 0, __spreadArray([], __read(__spreadArray(__spreadArray([], __read(inputs), false), [save], false)), false)); assert(res.value instanceof Tensor, function () { return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.value` is a tensor'; }); assert(isFunction(res.gradFunc), function () { return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function.'; }); return res.value; }; var backwardsFunc = function (dy, saved) { var gradRes = res.gradFunc(dy, saved); var grads = Array.isArray(gradRes) ? gradRes : [gradRes]; assert(grads.length === inputs.length, function () { return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function that returns ' + 'the same number of tensors as inputs passed to f(...).'; }); assert(grads.every(function (t) { return t instanceof Tensor; }), function () { return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function that returns ' + 'a list of only tensors.'; }); var gradMap = {}; grads.forEach(function (grad, i) { gradMap[i] = function () { return grad; }; }); return gradMap; }; return _this.runKernelFunc({ forwardFunc: forwardFunc, backwardsFunc: backwardsFunc, inputs: inputMap, }); }; }; Engine.prototype.readSync = function (dataId) { // Route the read to the correct backend. var info = this.state.tensorInfo.get(dataId); return info.backend.readSync(dataId); }; Engine.prototype.read = function (dataId) { // Route the read to the correct backend. var info = this.state.tensorInfo.get(dataId); return info.backend.read(dataId); }; Engine.prototype.readToGPU = function (dataId, options) { // Route the read to the correct backend. 
var info = this.state.tensorInfo.get(dataId); return info.backend.readToGPU(dataId, options); }; Engine.prototype.time = function (query) { return __awaiter(this, void 0, void 0, function () { var start, timingInfo; return __generator(this, function (_a) { switch (_a.label) { case 0: start = now(); return [4 /*yield*/, this.backend.time(query)]; case 1: timingInfo = _a.sent(); timingInfo.wallMs = now() - start; return [2 /*return*/, timingInfo]; } }); }); }; /** * Tracks a Tensor in the current scope to be automatically cleaned up * when the current scope ends, and returns the value. * * @param result The Tensor to track in the current scope. */ Engine.prototype.track = function (result) { if (this.state.activeScope != null) { result.scopeId = this.state.activeScope.id; this.state.activeScope.track.push(result); } return result; }; Object.defineProperty(Engine.prototype, "registeredVariables", { get: function () { return this.state.registeredVariables; }, enumerable: false, configurable: true }); /** * Resets the engine state. Removes all backends but does not remove * registered backend factories. */ Engine.prototype.reset = function () { // Make any pending promise obsolete. this.pendingBackendInitId++; this.state.dispose(); this.ENV.reset(); this.state = new EngineState(); for (var backendName in this.registry) { this.disposeRegisteredKernels(backendName); this.registry[backendName].dispose(); delete this.registry[backendName]; } this.backendName = null; this.backendInstance = null; this.pendingBackendInit = null; }; return Engine; }()); Engine.nextTensorId = 0; Engine.nextVariableId = 0; function ones$2(shape) { var values = makeOnesTypedArray(sizeFromShape(shape), 'float32'); return ENGINE.makeTensor(values, shape, 'float32'); } function getOrMakeEngine() { var ns = getGlobalNamespace(); if (ns._tfengine == null) { var environment = new Environment(ns); ns._tfengine = new Engine(environment); } setEnvironmentGlobal(ns._tfengine.ENV); // Tell the current tensor interface that the global engine is responsible // for tracking. setTensorTracker(function () { return ns._tfengine; }); return ns._tfengine; } var ENGINE = getOrMakeEngine(); /** * An implementation of the add op for use within engine and tape. * * This allows us to avoid a circular dependency between add.ts and engine. * It is exported to be available in tape tests. */ function add$2(a, b) { // We duplicate Add here to avoid a circular dependency with add.ts. var inputs = { a: a, b: b }; return ENGINE.runKernel(Add$1, inputs); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function inferShape(val, dtype) { var firstElem = val; if (isTypedArray(val)) { return dtype === 'string' ? [] : [val.length]; } if (isWebGLData(val)) { var usedChannels = val.channels || 'RGBA'; return [val.height, val.width * usedChannels.length]; } else if (isWebGPUData(val)) { return [val.buffer.size / (dtype == null ?
4 : bytesPerElement(dtype))]; } if (!Array.isArray(val)) { return []; // Scalar. } var shape = []; while (Array.isArray(firstElem) || isTypedArray(firstElem) && dtype !== 'string') { shape.push(firstElem.length); firstElem = firstElem[0]; } if (Array.isArray(val) && env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { deepAssertShapeConsistency(val, shape, []); } return shape; } function deepAssertShapeConsistency(val, shape, indices) { indices = indices || []; if (!(Array.isArray(val)) && !isTypedArray(val)) { assert(shape.length === 0, function () { return "Element arr[".concat(indices.join(']['), "] is a primitive, ") + "but should be an array/TypedArray of ".concat(shape[0], " elements"); }); return; } assert(shape.length > 0, function () { return "Element arr[".concat(indices.join(']['), "] should be a primitive, ") + "but is an array of ".concat(val.length, " elements"); }); assert(val.length === shape[0], function () { return "Element arr[".concat(indices.join(']['), "] should have ").concat(shape[0], " ") + "elements, but has ".concat(val.length, " elements"); }); var subShape = shape.slice(1); for (var i = 0; i < val.length; ++i) { deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); } } function assertDtype(expectedDtype, actualDType, argName, functionName) { if (expectedDtype === 'string_or_numeric') { return; } if (expectedDtype == null) { throw new Error("Expected dtype cannot be null."); } if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || expectedDtype === 'numeric' && actualDType === 'string') { throw new Error("Argument '".concat(argName, "' passed to '").concat(functionName, "' must ") + "be ".concat(expectedDtype, " tensor, but got ").concat(actualDType, " tensor")); } } function convertToTensor(x, argName, functionName, parseAsDtype) { if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } if (x instanceof getGlobalTensorClass()) { assertDtype(parseAsDtype, x.dtype, argName, functionName); return x; } var inferredDtype = inferDtype(x); // If the user expects a bool/int/float, use that info to update the // inferredDtype when it is not a string. if (inferredDtype !== 'string' && ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) { inferredDtype = parseAsDtype; } assertDtype(parseAsDtype, inferredDtype, argName, functionName); if ((x == null) || (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' && typeof x !== 'boolean' && typeof x !== 'string')) { var type = x == null ? 'null' : x.constructor.name; throw new Error("Argument '".concat(argName, "' passed to '").concat(functionName, "' must be a ") + "Tensor or TensorLike, but got '".concat(type, "'")); } var inferredShape = inferShape(x, inferredDtype); if (!isTypedArray(x) && !Array.isArray(x)) { x = [x]; } var skipTypedArray = true; var values = inferredDtype !== 'string' ? toTypedArray(x, inferredDtype) : flatten$1(x, [], skipTypedArray); return ENGINE.makeTensor(values, inferredShape, inferredDtype); } function convertToTensorArray(arg, argName, functionName, parseAsDtype) { if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } if (!Array.isArray(arg)) { throw new Error("Argument ".concat(argName, " passed to ").concat(functionName, " must be a ") + '`Tensor[]` or `TensorLike[]`'); } var tensors = arg; return tensors.map(function (t, i) { return convertToTensor(t, "".concat(argName, "[").concat(i, "]"), functionName, parseAsDtype); }); } var OP_SCOPE_SUFFIX = '__op'; /** * Used for wrapping functions that perform math operations on * Tensors. 
The function will be wrapped in a named scope that cleans all * memory usage after the function is done. */ function op(f) { var keys = Object.keys(f); if (keys.length !== 1) { throw new Error("Please provide an object with a single key " + "(operation name) mapping to a function. Got an object with " + "".concat(keys.length, " keys.")); } var opName = keys[0]; var fn = f[opName]; // Strip the underscore from the end of the function name. if (opName.endsWith('_')) { opName = opName.substring(0, opName.length - 1); } // add an __op suffix to distinguish ops from kernels in tf.profile opName = opName + OP_SCOPE_SUFFIX; // tslint:disable-next-line:no-any var f2 = function () { var args = []; for (var _i = 0; _i < arguments.length; _i++) { args[_i] = arguments[_i]; } ENGINE.startScope(opName); try { var result = fn.apply(void 0, __spreadArray([], __read(args), false)); if (isPromise(result)) { console.error('Cannot return a Promise inside of tidy.'); } ENGINE.endScope(result); return result; } catch (ex) { ENGINE.endScope(null); throw ex; } }; Object.defineProperty(f2, 'name', { value: opName, configurable: true }); // tslint:disable-next-line:no-any return f2; } /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Casts a `tf.Tensor` to a new dtype. * * ```js * const x = tf.tensor1d([1.5, 2.5, 3]); * tf.cast(x, 'int32').print(); * ``` * @param x The input tensor to be cast. * @param dtype The dtype to cast the input tensor to. * * @doc {heading: 'Tensors', subheading: 'Transformations'} */ function cast_(x, dtype) { var $x = convertToTensor(x, 'x', 'cast'); // Sanity checks. if (!isValidDtype(dtype)) { throw new Error("Failed to cast to unknown dtype ".concat(dtype)); } if (dtype === 'string' && $x.dtype !== 'string' || dtype !== 'string' && $x.dtype === 'string') { throw new Error('Only strings can be casted to strings'); } var inputs = { x: $x }; var attrs = { dtype: dtype }; return ENGINE.runKernel(Cast, inputs, attrs); } var cast = /* @__PURE__ */ op({ cast_: cast_ }); /** * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting. * * We also expose `tf.mulStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). * * ```js * const a = tf.tensor1d([1, 2, 3, 4]); * const b = tf.tensor1d([2, 3, 4, 5]); * * a.mul(b).print(); // or tf.mul(a, b) * ``` * * ```js * // Broadcast mul a with b. * const a = tf.tensor1d([1, 2, 3, 4]); * const b = tf.scalar(5); * * a.mul(b).print(); // or tf.mul(a, b) * ``` * @param a The first tensor to multiply. * @param b The second tensor to multiply. Must have the same dtype as `a`.
* * @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function mul_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'mul'); var $b = convertToTensor(b, 'b', 'mul'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; var inputs = { a: $a, b: $b }; return ENGINE.runKernel(Multiply$1, inputs); } var mul = /* @__PURE__ */ op({ mul_: mul_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha` * * ```js * const x = tf.tensor1d([0, 2, -1, -3]); * * x.step(.5).print(); // or tf.step(x, .5) * ``` * @param x The input tensor. * @param alpha The value returned when the input is not positive. Defaults to 0. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function step_(x, alpha) { if (alpha === void 0) { alpha = 0.0; } var $x = convertToTensor(x, 'x', 'step'); var inputs = { x: $x }; var attrs = { alpha: alpha }; return ENGINE.runKernel(Step, inputs, attrs); } var step = /* @__PURE__ */ op({ step_: step_ }); var absGradConfig = { kernelName: Abs, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(dy, step(cast(x, 'float32'), -1)); } }; } }; /** * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. * The result is rounded with the floor function. * * * ```js * const a = tf.tensor1d([1, 4, 9, 16]); * const b = tf.tensor1d([1, 2, 3, 4]); * * a.floorDiv(b).print(); // or tf.floorDiv(a, b) * ``` * * ```js * // Broadcast div a with b. * const a = tf.tensor1d([2, 4, 6, 8]); * const b = tf.scalar(2); * * a.floorDiv(b).print(); // or tf.floorDiv(a, b) * ``` * * @param a The first tensor as the numerator. * @param b The second tensor as the denominator. Must have the same dtype as * `a`. * * @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function floorDiv_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'floorDiv'); var $b = convertToTensor(b, 'b', 'floorDiv'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; var inputs = { a: $a, b: $b }; return ENGINE.runKernel(FloorDiv, inputs); } var floorDiv = /* @__PURE__ */ op({ floorDiv_: floorDiv_ }); /** * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. * * ```js * const a = tf.tensor1d([1, 4, 9, 16]); * const b = tf.tensor1d([1, 2, 3, 4]); * * a.div(b).print(); // or tf.div(a, b) * ``` * * ```js * // Broadcast div a with b. * const a = tf.tensor1d([2, 4, 6, 8]); * const b = tf.scalar(2); * * a.div(b).print(); // or tf.div(a, b) * ``` * * @param a The first tensor as the numerator. * @param b The second tensor as the denominator. Must have the same dtype as * `a`.
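 *
 * Note: when both inputs have dtype 'int32', `div` dispatches to `floorDiv`
 * (see the implementation below), so the result is floored integer division.
 * A small sketch of that behavior:
 *
 * ```js
 * const a = tf.tensor1d([7], 'int32');
 * const b = tf.tensor1d([2], 'int32');
 * a.div(b).print(); // [3], because int32 / int32 uses floorDiv
 * ```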
* * @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function div_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'div'); var $b = convertToTensor(b, 'b', 'div'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; if ($a.dtype === 'int32' && $b.dtype === 'int32') { return floorDiv($a, $b); } var inputs = { a: $a, b: $b }; var attrs = {}; // tslint:disable-next-line: no-unnecessary-type-assertion return ENGINE.runKernel(RealDiv, inputs, attrs); } var div = /* @__PURE__ */ op({ div_: div_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes `-1 * x` element-wise. * * ```js * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]); * * x.neg().print(); // or tf.neg(x) * ``` * * @param x The input tensor. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function neg_(x) { var $x = convertToTensor(x, 'x', 'neg'); var inputs = { x: $x }; return ENGINE.runKernel(Neg, inputs); } var neg = /* @__PURE__ */ op({ neg_: neg_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** This is shared code across all tensor creation methods. */ function makeTensor(values, shape, inferredShape, dtype) { if (dtype == null) { dtype = inferDtype(values); } else if (dtype === 'complex64') { throw new Error("Cannot construct a complex64 tensor directly. " + "Please use tf.complex(real, imag)."); } if (isWebGPUData(values) || isWebGLData(values)) { if (dtype !== 'float32' && dtype !== 'int32') { throw new Error("Creating tensor from GPU data only supports " + "'float32'|'int32' dtype, while the dtype is ".concat(dtype, ".")); } return ENGINE.backend.createTensorFromGPUData(values, shape || inferredShape, dtype); } if (!isTypedArray(values) && !Array.isArray(values) && typeof values !== 'number' && typeof values !== 'boolean' && typeof values !== 'string') { throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + 'an array of numbers/booleans/strings, or a TypedArray'); } // Verify that the shape matches the inferred shape. 
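// For example (hypothetical values): values [1, 2, 3, 4] with shape [2, 2]
// pass the check below since sizeFromShape([2, 2]) === 4, whereas shape
// [3, 2] would throw because the provided size 6 does not match the 4
// inferred values.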
if (shape != null) { assertNonNegativeIntegerDimensions(shape); var providedSize_1 = sizeFromShape(shape); var inferredSize_1 = sizeFromShape(inferredShape); assert(providedSize_1 === inferredSize_1, function () { return "Based on the provided shape, [".concat(shape, "], the tensor should have ") + "".concat(providedSize_1, " values but has ").concat(inferredSize_1); }); for (var i = 0; i < inferredShape.length; ++i) { var inferred = inferredShape[i]; var flatDimsDontMatch = i === inferredShape.length - 1 ? inferred !== sizeFromShape(shape.slice(i)) : true; assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () { return "Error creating a new Tensor. Inferred shape " + "(".concat(inferredShape, ") does not match the provided ") + "shape (".concat(shape, "). "); }); } } if (!isTypedArray(values) && !Array.isArray(values)) { values = [values]; } shape = shape || inferredShape; values = dtype !== 'string' ? toTypedArray(values, dtype) : flatten$1(values, [], true); return ENGINE.makeTensor(values, shape, dtype); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype. * * The same functionality can be achieved with `tf.tensor`, but in general * we recommend using `tf.scalar` as it makes the code more readable. * * ```js * tf.scalar(3.14).print(); * ``` * * @param value The value of the scalar. * @param dtype The data type. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function scalar(value, dtype) { if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) && dtype !== 'complex64') { throw new Error('Error creating a new Scalar: value must be a primitive ' + '(number|boolean|string)'); } if (dtype === 'string' && isTypedArray(value) && !(value instanceof Uint8Array)) { throw new Error('When making a scalar from encoded string, ' + 'the value must be `Uint8Array`.'); } var shape = []; var inferredShape = []; return makeTensor(value, shape, inferredShape, dtype); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)` * * ```js * const x = tf.tensor1d([1, 2, 4, -1]); * * x.sqrt().print(); // or tf.sqrt(x) * ``` * @param x The input tensor. 
* * @doc {heading: 'Operations', subheading: 'Basic math'} */ function sqrt_(x) { var $x = convertToTensor(x, 'x', 'sqrt', 'float32'); var inputs = { x: $x }; return ENGINE.runKernel(Sqrt, inputs); } var sqrt = /* @__PURE__ */ op({ sqrt_: sqrt_ }); /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes square of `x` element-wise: `x ^ 2` * * ```js * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]); * * x.square().print(); // or tf.square(x) * ``` * @param x The input Tensor. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function square_(x) { var $x = convertToTensor(x, 'x', 'square'); var attrs = {}; return ENGINE.runKernel('Square', { x: $x }, attrs); } var square = /* @__PURE__ */ op({ square_: square_ }); /** * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting. * * ```js * const a = tf.tensor1d([10, 20, 30, 40]); * const b = tf.tensor1d([1, 2, 3, 4]); * * a.sub(b).print(); // or tf.sub(a, b) * ``` * * ```js * // Broadcast subtract a with b. * const a = tf.tensor1d([10, 20, 30, 40]); * const b = tf.scalar(5); * * a.sub(b).print(); // or tf.sub(a, b) * ``` * @param a The first `tf.Tensor` to subtract from. * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as * `a`. * * @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function sub_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'sub'); var $b = convertToTensor(b, 'b', 'sub'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; var inputs = { a: $a, b: $b }; return ENGINE.runKernel(Sub, inputs); } var sub = /* @__PURE__ */ op({ sub_: sub_ }); var acosGradConfig = { kernelName: Acos, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { var a = square(cast(x, 'float32')); var b = sqrt(sub(scalar(1), a)); return neg(div(dy, b)); } }; } }; var acoshGradConfig = { kernelName: Acosh, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { var a = sqrt(sub(square(cast(x, 'float32')), 1)); return div(dy, a); } }; } }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Returns the axes in the output space that should be reduced to produce * the input space. 
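 *
 * A worked sketch (shapes are illustrative): broadcasting inShape [1, 3] to
 * outShape [2, 3] replicates axis 0, so gradients flowing back must be
 * summed over that axis.
 *
 * ```js
 * getReductionAxes([1, 3], [2, 3]); // [0]
 * // A missing leading dimension is treated as a broadcast dimension too.
 * getReductionAxes([3], [2, 3]);    // [0]
 * ```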
*/ function getReductionAxes(inShape, outShape) { var result = []; for (var i = 0; i < outShape.length; i++) { var inDim = inShape[inShape.length - i - 1]; var outAxis = outShape.length - i - 1; var outDim = outShape[outAxis]; if (inDim == null || (inDim === 1 && outDim > 1)) { result.unshift(outAxis); } } return result; } function assertAndGetBroadcastShape(shapeA, shapeB) { var l = Math.max(shapeA.length, shapeB.length); var result = new Array(l); for (var i = 0; i < l; i++) { var a = shapeA[shapeA.length - i - 1]; if (a == null) { a = 1; } var b = shapeB[shapeB.length - i - 1]; if (b == null) { b = 1; } if (a === 1) { result[l - i - 1] = b; } else if (b === 1) { result[l - i - 1] = a; } else if (a !== b) { var errMsg = "Operands could not be broadcast together with shapes " + "".concat(shapeA, " and ").concat(shapeB, "."); throw Error(errMsg); } else { result[l - i - 1] = a; } } return result; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Reshapes a `tf.Tensor` to a given shape. * * Given an input tensor, returns a new tensor with the same values as the * input tensor with shape `shape`. * * If one component of shape is the special value -1, the size of that * dimension is computed so that the total size remains constant. In * particular, a shape of [-1] flattens into 1-D. At most one component of * shape can be -1. * * If shape is 1-D or higher, then the operation returns a tensor with shape * shape filled with the values of tensor. In this case, the number of * elements implied by shape must be the same as the number of elements in * tensor. * * ```js * const x = tf.tensor1d([1, 2, 3, 4]); * x.reshape([2, 2]).print(); * ``` * * @param x The input tensor to be reshaped. * @param shape An array of integers defining the output tensor shape. * * @doc {heading: 'Tensors', subheading: 'Transformations'} */ function reshape_(x, shape) { var $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric'); var inputs = { x: $x }; var attrs = { shape: shape }; return ENGINE.runKernel(Reshape$1, inputs, attrs); } var reshape$1 = /* @__PURE__ */ op({ reshape_: reshape_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the sum of elements across dimensions of a `tf.Tensor`. 
* * Reduces the input along the dimensions given in `axes`. Unless `keepDims` * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in * `axes`. If `keepDims` is true, the reduced dimensions are retained with * length 1. If axes has no entries, all dimensions are reduced, and a * `tf.Tensor` with a single element is returned. * * ```js * const x = tf.tensor1d([1, 2, 3]); * * x.sum().print(); // or tf.sum(x) * ``` * * ```js * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); * * const axis = 1; * x.sum(axis).print(); // or tf.sum(x, axis) * ``` * * @param x The input tensor to compute the sum over. If the dtype is `bool` * it will be converted to `int32` and the output dtype will be `int32`. * @param axis The dimension(s) to reduce. By default it reduces * all dimensions. * @param keepDims If true, retains reduced dimensions with size 1. * * @doc {heading: 'Operations', subheading: 'Reduction'} */ function sum_(x, axis, keepDims) { if (axis === void 0) { axis = null; } if (keepDims === void 0) { keepDims = false; } var $x = convertToTensor(x, 'x', 'sum'); if ($x.dtype === 'bool') { $x = cast($x, 'int32'); } var inputs = { x: $x }; var attrs = { axis: axis, keepDims: keepDims }; return ENGINE.runKernel(Sum, inputs, attrs); } var sum = /* @__PURE__ */ op({ sum_: sum_ }); var addGradConfig = { kernelName: Add$1, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var outShape = assertAndGetBroadcastShape(a.shape, b.shape); var derA = function () { var res = dy; var reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(res, a.shape); }; var derB = function () { var res = dy; var reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(res, b.shape); }; return { a: derA, b: derB }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var addNGradConfig = { kernelName: AddN, saveAllInputs: true, gradFunc: function (dy, saved) { var ders = {}; saved.forEach(function (_, i) { ders[i] = function () { return dy.clone(); }; }); return ders; } }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ /** * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the * given tensor. * * ```js * const x = tf.tensor([1, 2]); * tf.zerosLike(x).print(); * ``` * * @param x The tensor of required shape. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function zerosLike_(x) { var $x = convertToTensor(x, 'x', 'zerosLike'); var inputs = { x: $x }; return ENGINE.runKernel(ZerosLike, inputs); } var zerosLike = /* @__PURE__ */ op({ zerosLike_: zerosLike_ }); var argMaxGradConfig = { kernelName: ArgMax, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return zerosLike(x); } }; } }; var argMinGradConfig = { kernelName: ArgMin, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return zerosLike(x); } }; } }; var asinGradConfig = { kernelName: Asin, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return div(dy, sqrt(sub(scalar(1), square(cast(x, 'float32'))))); } }; } }; /** * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting. * * * ```js * const a = tf.tensor1d([1, 2, 3, 4]); * const b = tf.tensor1d([10, 20, 30, 40]); * * a.add(b).print(); // or tf.add(a, b) * ``` * * ```js * // Broadcast add a with b. * const a = tf.scalar(5); * const b = tf.tensor1d([10, 20, 30, 40]); * * a.add(b).print(); // or tf.add(a, b) * ``` * @param a The first `tf.Tensor` to add. * @param b The second `tf.Tensor` to add. Must have the same type as `a`. * * @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function add_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'add'); var $b = convertToTensor(b, 'b', 'add'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; var inputs = { a: $a, b: $b }; return ENGINE.runKernel(Add$1, inputs); } var add$1 = /* @__PURE__ */ op({ add_: add_ }); var asinhGradConfig = { kernelName: Asinh, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { var a = sqrt(add$1(scalar(1), square(cast(x, 'float32')))); return div(dy, a); } }; } }; var atan2GradConfig = { kernelName: Atan2, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var outShape = assertAndGetBroadcastShape(a.shape, b.shape); var derA = function () { var d = add$1(square(a), square(b)); var res = mul(dy, div(b, d)); var reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(res, a.shape); }; var derB = function () { var d = add$1(square(a), square(b)); var res = neg(mul(dy, div(a, d))); var reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(res, b.shape); }; return { a: derA, b: derB }; } }; var atanGradConfig = { kernelName: Atan, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return div(dy, add$1(square(cast(x, 'float32')), 1)); } }; } }; var atanhGradConfig = { kernelName: Atanh, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return div(dy, sub(scalar(1), square(cast(x, 'float32')))); } }; } }; function parseTupleParam(param) { if (typeof param === 'number') { return [param, param, 
param]; } if (param.length === 2) { return [param[0], param[1], 1]; } return param; }
function tupleValuesAreOne(param) { var _a = __read(parseTupleParam(param), 3), dimA = _a[0], dimB = _a[1], dimC = _a[2]; return dimA === 1 && dimB === 1 && dimC === 1; }
function eitherStridesOrDilationsAreOne(strides, dilations) { return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations); }
function stridesOrDilationsArePositive(values) { return parseTupleParam(values).every(function (value) { return value > 0; }); }
/**
 * Check validity of pad when using dimRoundingMode.
 * @param opDesc A string of op description
 * @param pad The type of padding algorithm.
 *     - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *     - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
 *     - For more info, see this guide:
 *       [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
 *       https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 * @throws unknown padding parameter
 */
function checkPadOnDimRoundingMode(opDesc, pad, dimRoundingMode) { if (dimRoundingMode != null) { if (typeof pad === 'string') { throw Error("Error in ".concat(opDesc, ": pad must be an integer when using ") + "dimRoundingMode ".concat(dimRoundingMode, " but got pad ").concat(pad, ".")); } else if (typeof pad === 'number') { assert(isInt(pad), function () { return "Error in ".concat(opDesc, ": pad must be an integer when using ") + "dimRoundingMode ".concat(dimRoundingMode, " but got pad ").concat(pad, "."); }); } else if (typeof pad === 'object') { pad.forEach(function (p) { p.forEach(function (v) { assert(isInt(v), function () { return "Error in ".concat(opDesc, ": pad must be an integer when using ") + "dimRoundingMode ".concat(dimRoundingMode, " but got pad ").concat(v, "."); }); }); }); } else { throw Error("Error in ".concat(opDesc, ": Unknown padding parameter: ").concat(pad)); } } }
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Computes the backprop of a 3d avg pool.
 *
 * @param dy The dy error, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels]. If rank 4, batch of 1 is
 *     assumed.
 * @param input The original input image, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels]. If rank 4, batch of 1 is
 *     assumed.
 * @param filterSize The filter size:
 *     `[filterDepth, filterHeight, filterWidth]`. If `filterSize` is a single
 *     number, then `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 *     `[strideDepth, strideHeight, strideWidth]`. If `strides` is a single
 *     number, then `strideDepth == strideHeight == strideWidth`.
 * @param pad A string from: 'same', 'valid'. The type of padding algorithm
 *     used in the forward prop of the op.
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) { var $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad'); var $input = convertToTensor(input, 'input', 'avgPool3dGrad'); var dy5D = $dy; var input5D = $input; var reshapedTo5D = false; if ($input.rank === 4) { reshapedTo5D = true; dy5D = reshape$1($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]); input5D = reshape$1($input, [ 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3] ]); } assert(dy5D.rank === 5, function () { return "Error in avgPool3dGrad: dy must be rank 5 but got rank " + "".concat(dy5D.rank, "."); }); assert(input5D.rank === 5, function () { return "Error in avgPool3dGrad: input must be rank 5 but got rank " + "".concat(input5D.rank, "."); }); checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode); var inputs = { dy: dy5D, input: input5D }; var attrs = { filterSize: filterSize, strides: strides, pad: pad, dimRoundingMode: dimRoundingMode };
// tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs); if (reshapedTo5D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]); } return res; } var avgPool3dGrad = /* @__PURE__ */ op({ avgPool3dGrad_: avgPool3dGrad_ }); var avgPool3DGradConfig = { kernelName: AvgPool3D, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode; return { x: function () { return avgPool3dGrad(dy, x, filterSize, strides, pad, dimRoundingMode); } }; } };
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Computes the backprop of a 2D avg pool.
 *
 * @param dy The dy error, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param input The input image, of rank 4 or rank 3 of shape
 *     [batchSize, height, width, channels]. If rank 3, batch of 1 is
 *     assumed.
 * @param filterSize The filter size: `[filterHeight, filterWidth]`. If
 *     `filterSize` is a single number, then `filterHeight == filterWidth`.
 * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
 *     `strides` is a single number, then `strideHeight == strideWidth`.
 * @param pad The type of padding algorithm used in the forward prop of the op.
* 'same', 'valid', for more info, see this guide: * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution]( * https://www.tensorflow.org/api_docs/python/tf/nn/convolution) */ function avgPoolGrad_(dy, input, filterSize, strides, pad) { var $dy = convertToTensor(dy, 'dy', 'avgPoolGrad'); var $input = convertToTensor(input, 'input', 'avgPoolGrad'); assert($input.rank === $dy.rank, function () { return "Rank of input (".concat($input.rank, ") does not match rank of dy (").concat($dy.rank, ")"); }); var input4D = $input; var dy4D = $dy; var reshapedTo4D = false; if ($input.rank === 3) { reshapedTo4D = true; input4D = reshape$1($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]); dy4D = reshape$1($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]); } assert(dy4D.rank === 4, function () { return "Error in avgPoolGrad: dy must be rank 4 but got rank " + "".concat(dy4D.rank, "."); }); assert(input4D.rank === 4, function () { return "Error in avgPoolGrad: input must be rank 4 but got rank " + "".concat(input4D.rank, "."); }); var inputs = { dy: dy4D, input: input4D }; var attrs = { filterSize: filterSize, strides: strides, pad: pad }; // tslint:disable-next-line: no-unnecessary-type-assertion var res = ENGINE.runKernel(AvgPoolGrad, inputs, attrs); if (reshapedTo4D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3]]); } return res; } var avgPoolGrad = /* @__PURE__ */ op({ avgPoolGrad_: avgPoolGrad_ }); var avgPoolGradConfig = { kernelName: AvgPool, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad; return { x: function () { return avgPoolGrad(dy, x, filterSize, strides, pad); } }; } }; /** * Computes the dot product of two matrices, A * B. These must be matrices. * * ```js * const a = tf.tensor2d([1, 2], [1, 2]); * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); * * a.matMul(b).print(); // or tf.matMul(a, b) * ``` * @param a First matrix in dot product operation. * @param b Second matrix in dot product operation. * @param transposeA If true, `a` is transposed before multiplication. * @param transposeB If true, `b` is transposed before multiplication. 
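 *
 * The gradient registered for this kernel (`batchMatMulGradConfig`, defined
 * just below) uses the standard identities: for `C = A * B`,
 * `dA = dC * B^T` and `dB = A^T * dC`, with the transpose flags flipped
 * accordingly for the other three `transposeA`/`transposeB` cases.
 * An illustrative check of the gradient shapes:
 *
 * ```js
 * const f = (a, b) => tf.matMul(a, b).sum();
 * const [da, db] = tf.grads(f)([tf.ones([2, 3]), tf.ones([3, 4])]);
 * console.log(da.shape, db.shape); // [2, 3] [3, 4]
 * ```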
* * @doc {heading: 'Operations', subheading: 'Matrices'} */ function matMul_(a, b, transposeA, transposeB) { var _a; if (transposeA === void 0) { transposeA = false; } if (transposeB === void 0) { transposeB = false; } var $a = convertToTensor(a, 'a', 'matMul'); var $b = convertToTensor(b, 'b', 'matMul'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; var inputs = { a: $a, b: $b }; var attrs = { transposeA: transposeA, transposeB: transposeB }; return ENGINE.runKernel(BatchMatMul, inputs, attrs); } var matMul = /* @__PURE__ */ op({ matMul_: matMul_ }); var batchMatMulGradConfig = { kernelName: BatchMatMul, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var transposeA = attrs.transposeA, transposeB = attrs.transposeB; if (!transposeA && !transposeB) { return { a: function () { return matMul(dy, b, false, true); }, b: function () { return matMul(a, dy, true, false); } }; } else if (!transposeA && transposeB) { return { a: function () { return matMul(dy, b, false, false); }, b: function () { return matMul(dy, a, true, false); } }; } else if (transposeA && !transposeB) { return { a: function () { return matMul(b, dy, false, true); }, b: function () { return matMul(a, dy, false, false); } }; } else { return { a: function () { return matMul(b, dy, true, true); }, b: function () { return matMul(dy, a, true, true); } }; } } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * This operation divides "spatial" dimensions `[1, ..., M]` of the input into * a grid of blocks of shape `blockShape`, and interleaves these blocks with * the "batch" dimension (0) such that in the output, the spatial * dimensions `[1, ..., M]` correspond to the position within the grid, * and the batch dimension combines both the position within a spatial block * and the original batch position. Prior to division into blocks, * the spatial dimensions of the input are optionally zero padded * according to `paddings`. See below for a precise description. * * ```js * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); * const blockShape = [2, 2]; * const paddings = [[0, 0], [0, 0]]; * * x.spaceToBatchND(blockShape, paddings).print(); * ``` * * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + * remainingShape`, where spatialShape has `M` dimensions. * @param blockShape A 1-D array. Must have shape `[M]`, all values must * be >= 1. * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >= * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It * is required that * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0` * * This operation is equivalent to the following steps: * * 1. 
Zero-pad the start and end of dimensions `[1, ..., M]` of the input * according to `paddings` to produce `padded` of shape paddedShape. * * 2. Reshape `padded` to `reshapedPadded` of shape: * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ..., * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape` * * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded` * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ..., * paddedShape[M] / blockShape[M-1]] + remainingShape` * * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the * batch dimension, producing an output tensor of shape: * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ..., * paddedShape[M] / blockShape[M-1]] + remainingShape` * * @doc {heading: 'Tensors', subheading: 'Transformations'} */ function spaceToBatchND_(x, blockShape, paddings) { var $x = convertToTensor(x, 'x', 'spaceToBatchND'); assert($x.rank >= 1 + blockShape.length, function () { return "input rank ".concat($x.rank, " should be > than [blockShape] ").concat(blockShape.length); }); assert(paddings.length === blockShape.length, function () { return "paddings.shape[0] ".concat(paddings.length, " must be equal to [blockShape] ").concat(blockShape.length); }); assert($x.shape.reduce(function (a, b, i) { if (i > 0 && i <= blockShape.length) { return a && ((b + paddings[i - 1][0] + paddings[i - 1][1]) % blockShape[i - 1] === 0); } return a; }, true), function () { return "input spatial dimensions ".concat($x.shape.slice(1), " with paddings ").concat(paddings.toString(), " must be divisible by blockShapes ").concat(blockShape.toString()); }); var inputs = { x: $x }; var attrs = { blockShape: blockShape, paddings: paddings }; return ENGINE.runKernel(SpaceToBatchND, inputs, attrs); } var spaceToBatchND = /* @__PURE__ */ op({ spaceToBatchND_: spaceToBatchND_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var batchToSpaceNDGradConfig = { kernelName: BatchToSpaceND, gradFunc: function (dy, saved, attrs) { var blockShape = attrs.blockShape, crops = attrs.crops; return { x: function () { return spaceToBatchND(dy, blockShape, crops); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var broadcastToGradConfig = { kernelName: BroadcastTo, gradFunc: function (dy, saved, attrs) { var broadCastToAttrs = attrs; var inputShape = broadCastToAttrs.inputShape; var outputShape = broadCastToAttrs.shape; var reps = Array.from(outputShape); for (var i = inputShape.length - 1; i >= 0; i--) { if (inputShape[i] === outputShape[i]) { reps[i] = 1; } else if (inputShape[i] !== 1) { throw new Error("broadcastTo(): [".concat(inputShape, "] cannot be broadcast to [").concat(outputShape, "].")); } } var axes = []; for (var i = 0; i < reps.length; i++) { if (reps[i] > 1) { axes.push(i); } } return { x: function () { return sum(dy, axes, true /* keepDims */); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var castGradConfig = { kernelName: Cast, gradFunc: function (dy) { return { x: function () { return dy.clone(); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ceilGradConfig = { kernelName: Ceil, gradFunc: function (dy) { // TODO(manrajgrover): Return null for gradients when backprop supports it. return { x: function () { return zerosLike(dy); } }; } }; /** * Returns the truth value of (a >= b) element-wise. Supports broadcasting. * * ```js * const a = tf.tensor1d([1, 2, 3]); * const b = tf.tensor1d([2, 2, 2]); * * a.greaterEqual(b).print(); * ``` * * @param a The first input tensor. * @param b The second input tensor. Must have the same dtype as `a`. * * @doc {heading: 'Operations', subheading: 'Logical'} */ function greaterEqual_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric'); var $b = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; assertAndGetBroadcastShape($a.shape, $b.shape); var inputs = { a: $a, b: $b }; return ENGINE.runKernel(GreaterEqual, inputs); } var greaterEqual = /* @__PURE__ */ op({ greaterEqual_: greaterEqual_ }); /** * Returns the truth value of (a <= b) element-wise. Supports broadcasting. * * ```js * const a = tf.tensor1d([1, 2, 3]); * const b = tf.tensor1d([2, 2, 2]); * * a.lessEqual(b).print(); * ``` * * @param a The first input tensor. * @param b The second input tensor. 
Must have the same dtype as `a`. * * @doc {heading: 'Operations', subheading: 'Logical'} */ function lessEqual_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric'); var $b = convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; assertAndGetBroadcastShape($a.shape, $b.shape); var inputs = { a: $a, b: $b }; return ENGINE.runKernel(LessEqual, inputs); } var lessEqual = /* @__PURE__ */ op({ lessEqual_: lessEqual_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Returns the truth value of `a AND b` element-wise. Supports broadcasting. * * ```js * const a = tf.tensor1d([false, false, true, true], 'bool'); * const b = tf.tensor1d([false, true, false, true], 'bool'); * * a.logicalAnd(b).print(); * ``` * * @param a The first input tensor. Must be of dtype bool. * @param b The second input tensor. Must be of dtype bool. * * @doc {heading: 'Operations', subheading: 'Logical'} */ function logicalAnd_(a, b) { var $a = convertToTensor(a, 'a', 'logicalAnd', 'bool'); var $b = convertToTensor(b, 'b', 'logicalAnd', 'bool'); assertAndGetBroadcastShape($a.shape, $b.shape); var inputs = { a: $a, b: $b }; return ENGINE.runKernel(LogicalAnd, inputs); } var logicalAnd = /* @__PURE__ */ op({ logicalAnd_: logicalAnd_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates a new tensor with the same values and shape as the specified * tensor. * * ```js * const x = tf.tensor([1, 2]); * * x.clone().print(); * ``` * * @param x The tensor to clone. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function clone_(x) { var $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric'); var inputs = { x: $x }; // Note this op is called tf.identity in python. Hence the kernel name used // here. return ENGINE.runKernel(Identity, inputs); } var clone = /* @__PURE__ */ op({ clone_: clone_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile).
 *
 * @param x The tensor that is to be broadcast.
 * @param shape The input is to be broadcast to this shape.
 *
 * @doc {heading: 'Tensors', subheading: 'Transformations'}
 */
function broadcastTo_(x, shape) { var input = convertToTensor(x, 'x', 'broadcastTo'); var xShape = input.shape; assertNonNegativeIntegerDimensions(shape); if (shape.length < input.rank) { throw new Error("broadcastTo(): shape.length=".concat(shape.length, " < input.rank=").concat(input.rank, ".")); } if (shape.length > input.rank) { var newShape = input.shape.slice(); while (newShape.length < shape.length) { newShape.unshift(1); } input = reshape$1(input, newShape); } var inputShape = input.shape; var reps = Array.from(shape); for (var i = shape.length - 1; i >= 0; i--) { if (inputShape[i] === shape[i]) { reps[i] = 1; } else if (input.shape[i] !== 1) { throw new Error("broadcastTo(): [".concat(xShape, "] cannot be broadcast to [").concat(shape, "].")); } } var axes = reps.map(function (n, i) { return n > 1 ? i : -1; }).filter(function (i) { return i >= 0; }); if (axes.length === 0) { return clone(input); }
// TODO call broadcastTo kernel directly once backends implement broadcastTo
var inputs = { x: input }; var attrs = { reps: reps }; return ENGINE.runKernel(Tile, inputs, attrs); } var broadcastTo = /* @__PURE__ */ op({ broadcastTo_: broadcastTo_ });
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Returns the elements, either `a` or `b` depending on the `condition`.
 *
 * If the condition is true, select from `a`, otherwise select from `b`.
 *
 * ```js
 * const cond = tf.tensor1d([false, false, true], 'bool');
 * const a = tf.tensor1d([1, 2, 3]);
 * const b = tf.tensor1d([-1, -2, -3]);
 *
 * a.where(cond, b).print();
 * ```
 *
 * @param condition The input condition. Must be of dtype bool.
 * @param a If `condition` is rank 1, `a` may have a higher rank but
 *     its first dimension must match the size of `condition`.
* @param b A tensor with the same dtype as `a` and with shape that is * compatible with `a`. * @return A tensor with same dtype as `a` and `b`, and shape that is * broadcastable from `a` and `b`. * * @doc {heading: 'Operations', subheading: 'Logical'} */ function where_(condition, a, b) { var $a = convertToTensor(a, 'a', 'where'); var $b = convertToTensor(b, 'b', 'where'); var $condition = convertToTensor(condition, 'condition', 'where', 'bool'); // TODO: move this logic to forward function when the broadcastTo op is // implemented in WASM. // Find the broadcastable shape for $condition, $a, and $b. var broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape); var $broadcastedCondition = broadcastTo($condition, broadcastShape); var $broadcastedA = broadcastTo($a, broadcastShape); var $broadcastedB = broadcastTo($b, broadcastShape); var inputs = { condition: $broadcastedCondition, t: $broadcastedA, e: $broadcastedB }; return ENGINE.runKernel(Select, inputs); } var where = /* @__PURE__ */ op({ where_: where_ }); var clipByValueGradConfig = { kernelName: ClipByValue, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var clipValueMin = attrs.clipValueMin, clipValueMax = attrs.clipValueMax; return { x: function () { return where(logicalAnd(greaterEqual(x, clipValueMin), lessEqual(x, clipValueMax)), dy, zerosLike(dy)); }, }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var complexAbsGradConfig = { kernelName: ComplexAbs, inputsToSave: ['x'], gradFunc: absGradConfig.gradFunc, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Splits a `tf.Tensor` into sub tensors. * * If `numOrSizeSplits` is a number, splits `x` along dimension `axis` * into `numOrSizeSplits` smaller tensors. * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`. * * If `numOrSizeSplits` is a number array, splits `x` into * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the * same size as `x` except along dimension `axis` where the size is * `numOrSizeSplits[i]`. 
* * ```js * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); * const [a, b] = tf.split(x, 2, 1); * a.print(); * b.print(); * * const [c, d, e] = tf.split(x, [1, 2, 1], 1); * c.print(); * d.print(); * e.print(); * ``` * * @param x The input tensor to split. * @param numOrSizeSplits Either an integer indicating the number of * splits along the axis or an array of integers containing the sizes of * each output tensor along the axis. If a number then it must evenly divide * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`. * Can contain one -1 indicating that dimension is to be inferred. * @param axis The dimension along which to split. Defaults to 0 (the first * dim). * * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function split_(x, numOrSizeSplits, axis) { if (axis === void 0) { axis = 0; } var $x = convertToTensor(x, 'x', 'split'); var inputs = { x: $x }; var attr = { numOrSizeSplits: numOrSizeSplits, axis: axis }; return ENGINE.runKernel(SplitV, inputs, attr); } var split = /* @__PURE__ */ op({ split_: split_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var concatGradConfig = { kernelName: Concat, saveAllInputs: true, gradFunc: function (dy, saved, attrs) { var shapes = saved.map(function (t) { return t.shape; }); var axis = attrs.axis; var $axis = parseAxisParam(axis, saved[0].shape)[0]; var sizeSplits = shapes.map(function (s) { return s[$axis]; }); var derTensors = split(dy, sizeSplits, $axis); return derTensors.map(function (t) { return function () { return t; }; }); } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the derivative of the filter of a 2D convolution. * * @param x The input tensor, of rank 4 or rank 3 of shape * [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed. * @param dy The dy image, of rank 4 or rank 3, of shape * [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed. * @param filterShape The shape of the filter, length 4, * [filterHeight, filterWidth, inDepth, outDepth]. * @param strides The strides of the convolution: [strideHeight, * strideWidth]. * @param pad A string from: 'same', 'valid'. The type of padding algorithm * used in the forward prop of the op. 
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
 *     "NHWC". Specify the data format of the input and output data. With the
 *     default format "NHWC", the data is stored in the order of: [batch,
 *     height, width, channels].
 * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
 *     provided, it will default to truncate.
 */
function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat, dimRoundingMode) { if (dataFormat === void 0) { dataFormat = 'NHWC'; } var x4D = x; if (x.rank === 3) { x4D = reshape$1(x, [1, x.shape[0], x.shape[1], x.shape[2]]); } var dy4D = dy; if (dy4D.rank === 3) { dy4D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); } assert(x4D.rank === 4, function () { return "Error in conv2dDerFilter: input must be rank 4, but got shape " + "".concat(x4D.shape, "."); }); assert(dy4D.rank === 4, function () { return "Error in conv2dDerFilter: dy must be rank 4, but got shape " + "".concat(dy4D.shape, "."); }); assert(filterShape.length === 4, function () { return "Error in conv2dDerFilter: filterShape must be length 4, but got " + "".concat(filterShape, "."); }); var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1]; assert(inDepth === filterShape[2], function () { return "Error in conv2dDerFilter: depth of input (".concat(inDepth, ") must ") + "match input depth in filter (".concat(filterShape[2], ")."); }); assert(outDepth === filterShape[3], function () { return "Error in conv2dDerFilter: depth of dy (".concat(outDepth, ") must ") + "match output depth for filter (".concat(filterShape[3], ")."); }); checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode); var inputs = { x: x4D, dy: dy4D }; var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dimRoundingMode: dimRoundingMode, filterShape: filterShape };
// tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs); } var conv2DBackpropFilter = /* @__PURE__ */ op({ conv2DBackpropFilter_: conv2DBackpropFilter_ });
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Computes the derivative of the input of a 2D convolution.
 *
 * @param xShape The shape of the input: [batch, height, width, inDepth].
 *     If of length 3, a batch of 1 is assumed.
 * @param dy The derivative of the output, of rank 4 or rank 3 of shape
 *     `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is
 *     assumed.
 * @param filter The filter, rank 4, of shape
 *     `[filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm used:
 *     - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *     - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to * "NHWC". Specify the data format of the input and output data. With the * default format "NHWC", the data is stored in the order of: [batch, * height, width, channels]. * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is * provided, it will default to truncate. */ function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat, dimRoundingMode) { if (dataFormat === void 0) { dataFormat = 'NHWC'; } assert(xShape.length === dy.rank, function () { return "Length of inShape " + "(".concat(xShape.length, ") and rank of dy (").concat(dy.rank, ") must match"); }); var xShape4D = xShape; var dy4D = dy; var reshapedTo4D = false; if (dy.rank === 3) { reshapedTo4D = true; dy4D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); xShape4D = [1, xShape[0], xShape[1], xShape[2]]; } assert(xShape4D.length === 4, function () { return "Error in conv2dDerInput: inShape must be length 4, but got length " + "".concat(xShape4D.length, "."); }); assert(dy4D.rank === 4, function () { return "Error in conv2dDerInput: dy must be rank 4, but got " + "rank ".concat(dy4D.rank); }); assert(filter.rank === 4, function () { return "Error in conv2dDerInput: filter must be rank 4, but got " + "rank ".concat(filter.rank); }); var inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1]; var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1]; assert(inDepth === filter.shape[2], function () { return "Error in conv2dDerInput: depth of input (".concat(inDepth, ") must ") + "match input depth for filter ".concat(filter.shape[2], "."); }); assert(outDepth === filter.shape[3], function () { return "Error in conv2dDerInput: depth of output (".concat(outDepth, ") must ") + "match output depth for filter ".concat(filter.shape[3], "."); }); checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode); var inputs = { dy: dy4D, filter: filter }; var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dimRoundingMode: dimRoundingMode, inputShape: xShape4D }; // tslint:disable-next-line: no-unnecessary-type-assertion var res = ENGINE.runKernel(Conv2DBackpropInput, inputs, attrs); if (reshapedTo4D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3]]); } return res; } var conv2DBackpropInput = /* @__PURE__ */ op({ conv2DBackpropInput_: conv2DBackpropInput_ }); var conv2DGradConfig = { kernelName: Conv2D$1, inputsToSave: ['x', 'filter'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 2), x4D = _a[0], $filter = _a[1]; var dilations = attrs.dilations, strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat; assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of conv2D: dilation rates greater than 1 ' + "are not yet supported in gradients. Got dilations '".concat(dilations, "'"); }); return { x: function () { return conv2DBackpropInput(x4D.shape, dy, $filter, strides, pad, dataFormat); }, filter: function () { return conv2DBackpropFilter(x4D, dy, $filter.shape, strides, pad, dataFormat); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes a 2D convolution over the input x. * * @param x The input tensor, of rank 4 or rank 3, of shape * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is * assumed. * @param filter The filter, rank 4, of shape * `[filterHeight, filterWidth, inDepth, outDepth]`. * @param strides The strides of the convolution: `[strideHeight, * strideWidth]`. * @param pad The type of padding algorithm. * - `same` and stride 1: output will be of same size as input, * regardless of filter size. * - `valid`: output will be smaller than input if filter is larger * than 1x1. * - For more info, see this guide: * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution]( * https://www.tensorflow.org/api_docs/python/tf/nn/convolution) * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to * "NHWC". Specify the data format of the input and output data. With the * default format "NHWC", the data is stored in the order of: [batch, * height, width, channels]. * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` * in which we sample input values across the height and width dimensions * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single * number, then `dilationHeight == dilationWidth`. If it is greater than * 1, then all values of `strides` must be 1. * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is * provided, it will default to truncate. * * @doc {heading: 'Operations', subheading: 'Convolution'} */ function conv2d_(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) { if (dataFormat === void 0) { dataFormat = 'NHWC'; } if (dilations === void 0) { dilations = [1, 1]; } var $x = convertToTensor(x, 'x', 'conv2d', 'float32'); var $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32'); var x4D = $x; var reshapedTo4D = false; if ($x.rank === 3) { reshapedTo4D = true; x4D = reshape$1($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]); } assert(x4D.rank === 4, function () { return "Error in conv2d: input must be rank 4, but got rank ".concat(x4D.rank, "."); }); assert($filter.rank === 4, function () { return "Error in conv2d: filter must be rank 4, but got rank " + "".concat($filter.rank, "."); }); checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode); var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; assert(inDepth === $filter.shape[2], function () { return "Error in conv2d: depth of input (".concat(inDepth, ") must match ") + "input depth for filter ".concat($filter.shape[2], "."); }); assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in conv2D: Either strides or dilations must be 1. 
' + "Got strides ".concat(strides, " and dilations '").concat(dilations, "'"); }); assert(stridesOrDilationsArePositive(dilations), function () { return 'Error in conv2D: Dilated rates should be larger than 0.'; }); assert(stridesOrDilationsArePositive(strides), function () { return 'Error in conv2D: Strides should be larger than 0.'; }); var inputs = { x: x4D, filter: $filter }; var attrs = { strides: strides, pad: pad, dataFormat: dataFormat, dilations: dilations, dimRoundingMode: dimRoundingMode }; // tslint:disable-next-line: no-unnecessary-type-assertion var res = ENGINE.runKernel(Conv2D$1, inputs, attrs); if (reshapedTo4D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3]]); } return res; } var conv2d$1 = /* @__PURE__ */ op({ conv2d_: conv2d_ }); var conv2DBackpropInputGradConfig = { kernelName: Conv2DBackpropInput, inputsToSave: ['dy', 'filter'], gradFunc: function (ddx, saved, attrs) { var _a = __read(saved, 2), dy = _a[0], filter = _a[1]; var strides = attrs.strides, pad = attrs.pad, dataFormat = attrs.dataFormat, dimRoundingMode = attrs.dimRoundingMode; return { dy: function () { return conv2d$1(ddx, filter, strides, pad, dataFormat, 1 /* dilations */, dimRoundingMode); }, filter: function () { return conv2DBackpropFilter(ddx, dy, filter.shape, strides, pad, dataFormat, dimRoundingMode); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the derivative of the filter of a 3D convolution. * * @param x The input tensor, of rank 5 or rank 4 of shape * [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is * assumed. * @param dy The dy image, of rank 5 or rank 4, of shape * [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is * assumed. * @param filterShape The shape of the filter, length 5, * [filterDepth, filterHeight, filterWidth, inDepth, outDepth]. * @param strides The strides of the convolution: [strideDepth, strideHeight, * strideWidth]. * @param pad A string from: 'same', 'valid'. The type of padding algorithm * used in the forward prop of the op. 
 */
function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) { var x5D = x; if (x.rank === 4) { x5D = reshape$1(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]); } var dy5D = dy; if (dy5D.rank === 4) { dy5D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]); } assert(x5D.rank === 5, function () { return "Error in conv3dDerFilter: input must be rank 5, but got shape " + "".concat(x5D.shape, "."); }); assert(dy5D.rank === 5, function () { return "Error in conv3dDerFilter: dy must be rank 5, but got shape " + "".concat(dy5D.shape, "."); }); assert(filterShape.length === 5, function () { return "Error in conv3dDerFilter: filterShape must be length 5, but got " + "".concat(filterShape, "."); }); assert(x5D.shape[4] === filterShape[3], function () { return "Error in conv3dDerFilter: depth of input (".concat(x5D.shape[4], ") must ") + "match input depth in filter (".concat(filterShape[3], ")."); }); assert(dy5D.shape[4] === filterShape[4], function () { return "Error in conv3dDerFilter: depth of dy (".concat(dy5D.shape[4], ") must ") + "match output depth for filter (".concat(filterShape[4], ")."); }); var inputs = { x: x5D, dy: dy5D }; var attrs = { strides: strides, pad: pad, filterShape: filterShape };
// tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(Conv3DBackpropFilterV2, inputs, attrs); } var conv3DBackpropFilter = /* @__PURE__ */ op({ conv3DBackpropFilter_: conv3DBackpropFilter_ });
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Computes the derivative of the input of a 3D convolution.
 *
 * @param xShape The shape of the input: [batch, depth, height, width,
 *     in_channels]. If of length 4, a batch of 1 is assumed.
 * @param dy The derivative of the output, of rank 5 or rank 4 of shape
 *     `[batch, outDepth, outHeight, outWidth, outChannels]`.
 *     If rank 4, batch of 1 is assumed.
 * @param filter The filter, rank 5, of shape
 *     `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
 * @param strides The strides of the convolution: `[strideDepth, strideHeight,
 *     strideWidth]`.
 * @param pad The type of padding algorithm used:
 *     - `same` and stride 1: output will be of same size as input,
 *       regardless of filter size.
 *     - `valid`: output will be smaller than input if filter is larger
 *       than 1x1.
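 *
 * A minimal shape-level sketch (illustrative values matching the filter
 * sketch above):
 *
 * ```js
 * const dy = tf.ones([1, 8, 8, 8, 4]);
 * const filter = tf.ones([3, 3, 3, 2, 4]);
 * const dx =
 *     conv3DBackpropInput([1, 8, 8, 8, 2], dy, filter, [1, 1, 1], 'same');
 * console.log(dx.shape); // [1, 8, 8, 8, 2]
 * ```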
*/ function conv3DBackpropInput_(xShape, dy, filter, strides, pad) { assert(xShape.length === dy.rank, function () { return "Length of inShape " + "(".concat(xShape.length, ") and rank of dy (").concat(dy.rank, ") must match"); }); var xShape5D = xShape; var dy5D = dy; var reshapedTo5D = false; if (dy.rank === 4) { reshapedTo5D = true; dy5D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]); xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]]; } var inDepth = xShape5D[4]; var outDepth = dy5D.shape[4]; assert(xShape5D.length === 5, function () { return "Error in conv3dDerInput: inShape must be length 5, but got length " + "".concat(xShape5D.length, "."); }); assert(dy5D.rank === 5, function () { return "Error in conv3dDerInput: dy must be rank 5, but got " + "rank ".concat(dy5D.rank); }); assert(filter.rank === 5, function () { return "Error in conv3dDerInput: filter must be rank 5, but got " + "rank ".concat(filter.rank); }); assert(inDepth === filter.shape[3], function () { return "Error in conv3dDerInput: depth of input (".concat(inDepth, ") must ") + "match input depth for filter ".concat(filter.shape[3], "."); }); assert(outDepth === filter.shape[4], function () { return "Error in conv3dDerInput: depth of output (".concat(outDepth, ") must ") + "match output depth for filter ".concat(filter.shape[4], "."); }); var inputs = { dy: dy5D, filter: filter }; var attrs = { pad: pad, strides: strides, inputShape: xShape5D }; // tslint:disable-next-line: no-unnecessary-type-assertion var res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs); if (reshapedTo5D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]); } return res; } var conv3DBackpropInput = /* @__PURE__ */ op({ conv3DBackpropInput_: conv3DBackpropInput_ }); var conv3DGradConfig = { kernelName: Conv3D$1, inputsToSave: ['x', 'filter'], gradFunc: function (dy, saved, attrs) { var dilations = attrs.dilations, strides = attrs.strides, pad = attrs.pad; assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of conv3D: dilation rates greater than 1 are ' + "not yet supported in gradients. Got dilations '".concat(dilations, "'"); }); var _a = __read(saved, 2), x5D = _a[0], $filter = _a[1]; return { x: function () { return conv3DBackpropInput(x5D.shape, dy, $filter, strides, pad); }, filter: function () { return conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad); } }; } }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes sin of the input Tensor element-wise: `sin(x)` * * ```js * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); * * x.sin().print(); // or tf.sin(x) * ``` * @param x The input tensor. 
* * @doc {heading: 'Operations', subheading: 'Basic math'} */ function sin_(x) { var $x = convertToTensor(x, 'x', 'sin', 'float32'); var inputs = { x: $x }; return ENGINE.runKernel(Sin, inputs); } var sin = /* @__PURE__ */ op({ sin_: sin_ }); var cosGradConfig = { kernelName: Cos, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(neg(sin(cast(x, 'float32'))), dy); } }; } }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)` * * ```js * const x = tf.tensor1d([0, 1, -1, .7]); * * x.sinh().print(); // or tf.sinh(x) * ``` * @param x The input tensor. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function sinh_(x) { var $x = convertToTensor(x, 'x', 'sinh'); var inputs = { x: $x }; return ENGINE.runKernel(Sinh, inputs); } var sinh = /* @__PURE__ */ op({ sinh_: sinh_ }); var coshGradConfig = { kernelName: Cosh, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(sinh(cast(x, 'float32')), dy); } }; } }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Returns true if the axis specifies the inner most dimensions of the * array. 
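* * For example (an illustrative sketch, not from the original docs): * ```js * axesAreInnerMostDims([1, 2], 3); // true: axes 1 and 2 are the inner-most axes * axesAreInnerMostDims([0], 3); // false: axis 0 is the outer-most axis * ```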
*/ function axesAreInnerMostDims(axes, rank) { for (var i = 0; i < axes.length; ++i) { if (axes[axes.length - i - 1] !== rank - 1 - i) { return false; } } return true; } function combineLocations(outputLoc, reduceLoc, axes) { var rank = outputLoc.length + reduceLoc.length; var loc = []; var outIdx = 0; var reduceIdx = 0; for (var dim = 0; dim < rank; dim++) { if (axes.indexOf(dim) === -1) { loc.push(outputLoc[outIdx++]); } else { loc.push(reduceLoc[reduceIdx++]); } } return loc; } function computeOutAndReduceShapes(aShape, axes) { var outShape = []; var rank = aShape.length; for (var dim = 0; dim < rank; dim++) { if (axes.indexOf(dim) === -1) { outShape.push(aShape[dim]); } } var reduceShape = axes.map(function (dim) { return aShape[dim]; }); return [outShape, reduceShape]; } function expandShapeToKeepDim(shape, axes) { var reduceSubShape = axes.map(function (x) { return 1; }); return combineLocations(shape, reduceSubShape, axes); } /** * Returns the axes permutation to be used with `tf.transpose`, if such * permutation is necessary. Otherwise it returns null. This method is used by * operations that operate only on inner-most axes. */ function getAxesPermutation(axes, rank) { if (axesAreInnerMostDims(axes, rank)) { return null; } var result = []; for (var i = 0; i < rank; ++i) { if (axes.indexOf(i) === -1) { result.push(i); } } axes.forEach(function (axis) { return result.push(axis); }); return result; } /** Returns the axes permutation that undoes the original permutation. */ function getUndoAxesPermutation(axes) { return axes.map(function (axis, i) { return [i, axis]; }) .sort(function (a, b) { return a[1] - b[1]; }) .map(function (x) { return x[0]; }); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the cumulative sum of a `tf.Tensor` along `axis`. * * ```js * const x = tf.tensor([1, 2, 3, 4]); * x.cumsum().print(); * ``` * ```js * const x = tf.tensor([[1, 2], [3, 4]]); * x.cumsum().print(); * ``` * * @param x The input tensor to be summed. * @param axis The axis along which to sum. Optional. Defaults to 0. * @param exclusive Whether to perform exclusive cumulative sum. Optional. * Defaults to false. If set to true then the sum of each tensor entry * does not include its own value, but only the values previous to it * along the specified axis. * @param reverse Whether to sum in the opposite direction. Optional. * Defaults to false. * * @doc {heading: 'Operations', subheading: 'Scan'} */ function cumsum_(x, axis, exclusive, reverse) { if (axis === void 0) { axis = 0; } if (exclusive === void 0) { exclusive = false; } if (reverse === void 0) { reverse = false; } var $x = convertToTensor(x, 'x', 'cumsum'); var inputs = { x: $x }; var attrs = { axis: axis, exclusive: exclusive, reverse: reverse }; return ENGINE.runKernel(Cumsum, inputs, attrs); } var cumsum = /* @__PURE__ */ op({ cumsum_: cumsum_ }); /** * @license * Copyright 2018 Google LLC. 
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Executes the provided function `fn` and after it is executed, cleans up all * intermediate tensors allocated by `fn` except those returned by `fn`. * `fn` must not return a Promise (async functions not allowed). The returned * result can be a complex object. * * Using this method helps avoid memory leaks. In general, wrap calls to * operations in `tf.tidy` for automatic memory cleanup. * * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to * dispose variables, please use `tf.disposeVariables` or call dispose() * directly on variables. * * ```js * // y = 2 ^ 2 + 1 * const y = tf.tidy(() => { * // a, b, and one will be cleaned up when the tidy ends. * const one = tf.scalar(1); * const a = tf.scalar(2); * const b = a.square(); * * console.log('numTensors (in tidy): ' + tf.memory().numTensors); * * // The value returned inside the tidy function will return * // through the tidy, in this case to the variable y. * return b.add(one); * }); * * console.log('numTensors (outside tidy): ' + tf.memory().numTensors); * y.print(); * ``` * * @param nameOrFn The name of the closure, or the function to execute. * If a name is provided, the 2nd argument should be the function. * If debug mode is on, the timing and the memory usage of the function * will be tracked and displayed on the console using the provided name. * @param fn The function to execute. * * @doc {heading: 'Performance', subheading: 'Memory'} */ function tidy(nameOrFn, fn) { return ENGINE.tidy(nameOrFn, fn); } /** * Disposes any `tf.Tensor`s found within the provided object. * * @param container an object that may be a `tf.Tensor` or may directly * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing * happens. In general it is safe to pass any object here, except that * `Promise`s are not supported. * * @doc {heading: 'Performance', subheading: 'Memory'} */ function dispose(container) { var tensors = getTensorsInContainer(container); tensors.forEach(function (tensor) { return tensor.dispose(); }); } /** * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed * automatically. * * ```js * let b; * const y = tf.tidy(() => { * const one = tf.scalar(1); * const a = tf.scalar(2); * * // b will not be cleaned up by the tidy. a and one will be cleaned up * // when the tidy ends. * b = tf.keep(a.square()); * * console.log('numTensors (in tidy): ' + tf.memory().numTensors); * * // The value returned inside the tidy function will return * // through the tidy, in this case to the variable y. * return b.add(one); * }); * * console.log('numTensors (outside tidy): ' + tf.memory().numTensors); * console.log('y:'); * y.print(); * console.log('b:'); * b.print(); * ``` * * @param result The tensor to keep from being disposed. 
* * @doc {heading: 'Performance', subheading: 'Memory'} */ function keep(result) { return ENGINE.keep(result); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Converts two real numbers to a complex number. * * Given a tensor `real` representing the real part of a complex number, and a * tensor `imag` representing the imaginary part of a complex number, this * operation returns complex numbers elementwise of the form [r0, i0, r1, i1], * where r represents the real part and i represents the imag part. * * The input tensors real and imag must have the same shape. * * ```js * const real = tf.tensor1d([2.25, 3.25]); * const imag = tf.tensor1d([4.75, 5.75]); * const complex = tf.complex(real, imag); * * complex.print(); * ``` * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function complex_(real, imag) { var $real = convertToTensor(real, 'real', 'complex'); var $imag = convertToTensor(imag, 'imag', 'complex'); assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, ".concat($real.shape, " and ").concat($imag.shape, ", ") + "must match in call to tf.complex()."); var inputs = { real: $real, imag: $imag }; return ENGINE.runKernel(Complex, inputs); } var complex = /* @__PURE__ */ op({ complex_: complex_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Returns the imaginary part of a complex (or real) tensor. * * Given a tensor input, this operation returns a tensor of type float that is * the imaginary part of each element in input considered as a complex number. * If input is real, a tensor of all zeros is returned. * * ```js * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); * tf.imag(x).print(); * ``` * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function imag_(input) { var $input = convertToTensor(input, 'input', 'imag'); var inputs = { input: $input }; return ENGINE.runKernel(Imag, inputs); } var imag = /* @__PURE__ */ op({ imag_: imag_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Returns the real part of a complex (or real) tensor. * * Given a tensor input, this operation returns a tensor of type float that is * the real part of each element in input considered as a complex number. * * If the input is real, it simply makes a clone. * * ```js * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); * tf.real(x).print(); * ``` * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function real_(input) { var $input = convertToTensor(input, 'input', 'real'); var inputs = { input: $input }; return ENGINE.runKernel(Real, inputs); } var real = /* @__PURE__ */ op({ real_: real_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`. * * The returned `tf.Tensor`'s dimension `i` will correspond to the input * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`, * where `n` is the rank of the input `tf.Tensor`. Hence by default, this * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s. * * ```js * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); * * a.transpose().print(); // or tf.transpose(a) * ``` * * @param x The tensor to transpose. * @param perm The permutation of the dimensions of a. * @param conjugate Will conjugate complex input if true. 
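* * An additional illustrative example with an explicit `perm` (a sketch of the rule stated above; not from the original docs): * ```js * const b = tf.ones([1, 2, 3]); * tf.transpose(b, [2, 0, 1]).print(); // result has shape [3, 1, 2] * ```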
* * @doc {heading: 'Operations', subheading: 'Matrices'} */ function transpose_(x, perm, conjugate) { var $x = convertToTensor(x, 'x', 'transpose'); if (perm == null) { perm = $x.shape.map(function (s, i) { return i; }).reverse(); } assert($x.rank === perm.length, function () { return "Error in transpose: rank of input ".concat($x.rank, " ") + "must match length of perm ".concat(perm, "."); }); perm.forEach(function (axis) { assert(axis >= 0 && axis < $x.rank, function () { return "All entries in 'perm' must be between 0 and ".concat($x.rank - 1) + " but got ".concat(perm); }); }); if ($x.rank <= 1) { return $x.clone(); } var inputs = { x: $x }; var attrs = { perm: perm }; if ($x.dtype === 'complex64') { return tidy(function () { var $real = real($x); var $imag = imag($x); $real = ENGINE.runKernel(Transpose, { x: $real }, attrs); $imag = ENGINE.runKernel(Transpose, { x: $imag }, attrs); if (conjugate) { $imag = neg($imag); } return complex($real, $imag); }); } return ENGINE.runKernel(Transpose, inputs, attrs); } var transpose = /* @__PURE__ */ op({ transpose_: transpose_ }); var cumsumGradConfig = { kernelName: Cumsum, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var axis = attrs.axis, exclusive = attrs.exclusive, reverse = attrs.reverse; return { x: function () { var permutation = getAxesPermutation([axis], x.rank); var out = cumsum(dy, axis, exclusive, !reverse); if (permutation != null) { out = transpose(out, permutation); } return out; } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad, dilations, dimRoundingMode) { if (dilations === void 0) { dilations = [1, 1]; } var x4D = x; if (x.rank === 3) { x4D = reshape$1(x, [1, x.shape[0], x.shape[1], x.shape[2]]); } var dy4D = dy; if (dy4D.rank === 3) { dy4D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); } var inputs = { x: x4D, dy: dy4D }; var attrs = { strides: strides, pad: pad, dimRoundingMode: dimRoundingMode, dilations: dilations, filterShape: filterShape }; // tslint:disable-next-line: no-unnecessary-type-assertion return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, inputs, attrs); } var depthwiseConv2dNativeBackpropFilter = op({ depthwiseConv2dNativeBackpropFilter_: depthwiseConv2dNativeBackpropFilter_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad, dilations, dimRoundingMode) { if (dilations === void 0) { dilations = [1, 1]; } var dy4D = dy; var reshapedTo4D = false; if (dy.rank === 3) { reshapedTo4D = true; dy4D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); } var inputs = { dy: dy4D, filter: filter }; var attrs = { strides: strides, pad: pad, dimRoundingMode: dimRoundingMode, dilations: dilations, inputShape: xShape }; var res = // tslint:disable-next-line: no-unnecessary-type-assertion ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, inputs, attrs); if (reshapedTo4D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3]]); } return res; } var depthwiseConv2dNativeBackpropInput = op({ depthwiseConv2dNativeBackpropInput_: depthwiseConv2dNativeBackpropInput_ }); var depthwiseConv2dNativeGradConfig = { kernelName: DepthwiseConv2dNative, inputsToSave: ['x', 'filter'], gradFunc: function (dy, saved, attrs) { var dilations = attrs.dilations, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode; var $dilations = dilations == null ? [1, 1] : dilations; assert(tupleValuesAreOne($dilations), function () { return 'Error in gradient of depthwiseConv2dNative: dilation rates ' + "greater than 1 are not yet supported. Got dilations " + "'".concat($dilations, "'"); }); var _a = __read(saved, 2), x = _a[0], filter = _a[1]; assert(x.rank === 4, function () { return "Error in gradient of depthwiseConv2dNative: input must be " + "rank 4, but got rank ".concat(x.rank, "."); }); assert(filter.rank === 4, function () { return "Error in gradient of depthwiseConv2dNative: filter must be " + "rank 4, but got rank ".concat(filter.rank, "."); }); assert(x.shape[3] === filter.shape[2], function () { return "Error in gradient of depthwiseConv2d: number of input " + "channels (".concat(x.shape[3], ") must match the inChannels dimension ") + "in filter ".concat(filter.shape[2], "."); }); assert(eitherStridesOrDilationsAreOne(strides, $dilations), function () { return 'Error in gradient of depthwiseConv2d: Either strides or ' + "dilations must be 1. Got strides ".concat(strides, " and dilations ") + "'".concat($dilations, "'."); }); checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode); return { x: function () { return depthwiseConv2dNativeBackpropInput(x.shape, dy, filter, strides, pad, $dilations, dimRoundingMode); }, filter: function () { return depthwiseConv2dNativeBackpropFilter(x, dy, filter.shape, strides, pad, $dilations, dimRoundingMode); }, }; } }; var dilation2dGradConfig = { kernelName: Dilation2D, inputsToSave: ['x', 'filter'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 2), x = _a[0], filter = _a[1]; var inputInputs = { x: x, filter: filter, dy: dy }; var filterInputs = { x: x, filter: filter, dy: dy }; return { x: function () { return ENGINE.runKernel(Dilation2DBackpropInput, inputInputs, attrs); }, filter: function () { return ENGINE.runKernel(Dilation2DBackpropFilter, filterInputs, attrs); } }; } }; var eluGradConfig = { kernelName: Elu$1, outputsToSave: [true], gradFunc: function (dy, saved) { var _a = __read(saved, 1), y = _a[0]; var inputs = { dy: dy, y: y }; return { x: function () { return ENGINE.runKernel(EluGrad, inputs); } }; } }; /** * @license * Copyright 2018 Google LLC. 
All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x` * * ```js * const x = tf.tensor1d([1, 2, -3]); * * x.exp().print(); // or tf.exp(x) * ``` * @param x The input tensor. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function exp_(x) { var $x = convertToTensor(x, 'x', 'exp'); var inputs = { x: $x }; return ENGINE.runKernel(Exp, inputs); } var exp = /* @__PURE__ */ op({ exp_: exp_ }); var erfGradConfig = { kernelName: Erf, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; var a = mul(exp(neg(square(x))), 2 / Math.sqrt(Math.PI)); return { x: function () { return mul(dy, a); } }; } }; var expGradConfig = { kernelName: Exp, outputsToSave: [true], gradFunc: function (dy, saved) { var _a = __read(saved, 1), y = _a[0]; return { x: function () { return mul(dy, y); } }; } }; var expandDimsGradConfig = { kernelName: ExpandDims, inputsToSave: ['input'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), input = _a[0]; return { input: function () { return reshape$1(dy, input.shape); } }; } }; var expm1GradConfig = { kernelName: Expm1, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(dy, exp(x)); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var floorGradConfig = { kernelName: Floor, gradFunc: function (dy) { return { x: function () { return zerosLike(dy); } }; } }; var floorDivGradConfig = { kernelName: FloorDiv, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var outShape = assertAndGetBroadcastShape(a.shape, b.shape); var derA = function () { var res = div(dy, cast(b, 'float32')); var reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum(res, reduceAxes), a.shape); } return res; }; var derB = function () { var res = mul(dy, cast(a, 'float32')); var reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = reshape$1(sum(res, reduceAxes), b.shape); } var tmp = square(b); return neg(div(res, cast(tmp, 'float32'))); }; return { a: derA, b: derB }; } }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes reciprocal of square root of the input `tf.Tensor` element-wise: * `y = 1 / sqrt(x)` * * ```js * const x = tf.tensor1d([1, 2, 4, -1]); * * x.rsqrt().print(); // or tf.rsqrt(x) * ``` * @param x The input tensor. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function rsqrt_(x) { var $x = convertToTensor(x, 'x', 'rsqrt', 'float32'); var inputs = { x: $x }; return ENGINE.runKernel(Rsqrt, inputs); } var rsqrt = /* @__PURE__ */ op({ rsqrt_: rsqrt_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Construct a tensor by repeating it the number of times given by reps. * * This operation creates a new tensor by replicating `input` `reps` * times. The output tensor's `i`th dimension has `input.shape[i] * * reps[i]` elements, and the values of `input` are replicated * `reps[i]` times along the `i`th dimension. For example, tiling * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`. * * ```js * const a = tf.tensor1d([1, 2]); * * a.tile([2]).print(); // or tf.tile(a, [2]) * ``` * * ```js * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]); * * a.tile([1, 2]).print(); // or tf.tile(a, [1,2]) * ``` * @param x The tensor to tile. * @param reps Determines the number of replications per dimension. * * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function tile_(x, reps) { var $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric'); assert($x.rank === reps.length, function () { return "Error in tile: rank of input ".concat($x.rank, " ") + "must match length of reps ".concat(reps, "."); }); var inputs = { x: $x }; var attrs = { reps: reps }; return ENGINE.runKernel(Tile, inputs, attrs); } var tile = /* @__PURE__ */ op({ tile_: tile_ }); var fusedBatchNormGradConfig = { kernelName: FusedBatchNorm, inputsToSave: ['x', 'mean', 'variance', 'scale'], gradFunc: function (dy, saved, attrs) { var varianceEpsilon = attrs.varianceEpsilon; var _a = __read(saved, 4), x = _a[0], mean = _a[1], variance = _a[2], scale = _a[3]; var scaleValue = scale == null ?
scalar(1) : scale; var reductionAxes = getReductionAxes(mean.shape, x.shape); var tileShape = []; if (mean.rank === 1) { for (var i = 0; i < x.shape.length - 1; ++i) { tileShape.push(x.shape[i]); } tileShape.push(1); } var xMinusMean = sub(x, mean); var dyTimesScaleValue = mul(dy, scaleValue); var oneOverSqrtVariance = rsqrt(add$1(variance, scalar(varianceEpsilon))); var minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5)); var derX = function () { if (mean.rank === 1) { return reshape$1(mul(mul(dy, tile(reshape$1(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape)), scaleValue), x.shape); } else { return reshape$1(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape); } }; var derMean = function () { var meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue); if (mean.rank === 1) { meanDer = sum(meanDer, reductionAxes); } return reshape$1(meanDer, mean.shape); }; var derVariance = function () { var varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue); if (mean.rank === 1) { varianceDer = sum(varianceDer, reductionAxes); } return reshape$1(varianceDer, mean.shape); }; var derScale = function () { var xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance); var scaleDer = mul(dy, xMinusMean2TimesRsqrt); if (mean.rank === 1) { scaleDer = sum(scaleDer, reductionAxes); } return reshape$1(scaleDer, mean.shape); }; var derOffset = function () { var offsetDer = dy; if (mean.rank === 1) { offsetDer = sum(offsetDer, reductionAxes); } return reshape$1(offsetDer, mean.shape); }; return { x: derX, mean: derMean, variance: derVariance, scale: derScale, offset: derOffset }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`. * * ```js * const a = tf.tensor1d([1, 2]); * const b = tf.tensor1d([3, 4]); * const c = tf.tensor1d([5, 6]); * tf.stack([a, b, c]).print(); * ``` * * @param tensors A list of tensor objects with the same shape and dtype. * @param axis The axis to stack along. Defaults to 0 (the first dim). * * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function stack_(tensors, axis) { if (axis === void 0) { axis = 0; } var $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric'); assert($tensors.length >= 1, function () { return 'Pass at least one tensor to tf.stack'; }); if ($tensors.length > 0) { assert(axis <= $tensors[0].rank, function () { return 'Axis must be <= rank of the tensor'; }); } var inputs = $tensors; var attrs = { axis: axis }; return ENGINE.runKernel(Pack, inputs, attrs); } var stack = /* @__PURE__ */ op({ stack_: stack_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the sum along segments of a `tf.Tensor`. * * ```js * const x = tf.tensor1d([1, 2, 3, 4]); * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32'); * const numSegments = 3; * * x.unsortedSegmentSum(segmentIds, numSegments).print() * //or tf.unsortedSegmentSum(x, segmentIds, numSegments) * ``` * @param x The `tf.Tensor` that will be summed along its segments. * @param segmentIds A `tf.Tensor1D` whose length equals the size of `x`'s * first dimension. Maps each element of `x` to a segment. * @param numSegments The number of distinct `segmentIds`. * * @doc {heading: 'Operations', subheading: 'Segment'} */ function unsortedSegmentSum_(x, segmentIds, numSegments) { var $x = convertToTensor(x, 'x', 'unsortedSegmentSum'); var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32'); assert(isInt(numSegments), function () { return 'numSegments must be of dtype int'; }); var inputs = { x: $x, segmentIds: $segmentIds }; var attrs = { numSegments: numSegments }; return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs); } var unsortedSegmentSum = /* @__PURE__ */ op({ unsortedSegmentSum_: unsortedSegmentSum_ }); var gatherGradConfig = { kernelName: GatherV2, inputsToSave: ['x', 'indices'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 2), x = _a[0], indices = _a[1]; var axis = attrs.axis, batchDims = attrs.batchDims; var parsedAxis = parseAxisParam(axis, x.shape)[0]; var derXBatch = function (x, indices, dy) { return function () { var paramsShape = x.shape; var indicesSize = indices.size; var outerShape = paramsShape.slice(0, parsedAxis); var outerDims = outerShape.length; var innerShape = paramsShape.slice(axis, paramsShape.length).slice(1); var innerDims = innerShape.length; var outerAxesIndices = arrayRange(0, outerDims); var innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims); var valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]); var values = reshape$1(dy, valuesShape); var reshapedIndices = reshape$1(indices, [indicesSize]); var transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]); var valuesTranspose = transpose(values, transposeDims); var paramsGrad = unsortedSegmentSum(valuesTranspose, reshapedIndices, x.shape[parsedAxis]); var invertTransposeDims = getUndoAxesPermutation(transposeDims); paramsGrad = transpose(paramsGrad, invertTransposeDims); return paramsGrad; }; }; if (batchDims === 1) { var batchSize = x.shape[0]; var xBatch_1 = x.split(batchSize, 0); var derXBatched = function () { var stacked = stack(xBatch_1.map(function (x, i) { return derXBatch(x, indices.slice(i, 1), dy.slice(i, 1))(); })); return stacked.reshape(x.shape); }; return { x: derXBatched, indices: function () { return indices; } }; } else { return { x: derXBatch(x, indices, dy), indices: function () { return indices; } }; } } }; function arrayRange(start, stop) { var result = []; for (var i = start; i < stop; ++i) { result.push(i); } return result; } function arrayConcat(arrays) { var result = []; for
(var i = 0; i < arrays.length; ++i) { for (var j = 0; j < arrays[i].length; ++j) { result.push(arrays[i][j]); } } return result; } var greaterEqualGradConfig = { kernelName: GreaterEqual, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; return { a: function () { return zerosLike(a); }, b: function () { return zerosLike(b); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var identityGradConfig = { kernelName: Identity, gradFunc: function (dy) { return { x: function () { return cast(dy, 'float32'); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var isFiniteGradConfig = { kernelName: IsFinite, gradFunc: function (dy) { // TODO(nsthorat): Let gradients be null for cases where we want to stop // backpropagation. return { x: function () { return zerosLike(dy); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var isInfGradConfig = { kernelName: IsInf, gradFunc: function (dy) { // TODO(nsthorat): Let gradients be null for cases where we want to stop // backpropagation. return { x: function () { return zerosLike(dy); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var isNanGradConfig = { kernelName: IsNan, gradFunc: function (dy) { // TODO(nsthorat): Let gradients be null for cases where we want to stop // backpropagation. return { x: function () { return zerosLike(dy); } }; } }; /** * Returns the truth value of (a > b) element-wise. Supports broadcasting. * * ```js * const a = tf.tensor1d([1, 2, 3]); * const b = tf.tensor1d([2, 2, 2]); * * a.greater(b).print(); * ``` * * @param a The first input tensor. * @param b The second input tensor. Must have the same dtype as `a`. * * @doc {heading: 'Operations', subheading: 'Logical'} */ function greater_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'greater', 'string_or_numeric'); var $b = convertToTensor(b, 'b', 'greater', 'string_or_numeric'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; assertAndGetBroadcastShape($a.shape, $b.shape); var inputs = { a: $a, b: $b }; return ENGINE.runKernel(Greater, inputs); } var greater$1 = /* @__PURE__ */ op({ greater_: greater_ }); var leakyReluGradConfig = { kernelName: LeakyRelu, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var alpha = attrs.alpha; var mask = greater$1(x, 0); // Returns `gradients * (features > 0) + alpha * gradients * (features <= // 0)`. return { x: function () { return where(mask, dy, mul(dy, alpha)); } }; } }; var log1pGradConfig = { kernelName: Log1p, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return div(dy, add$1(x, 1)); } }; } }; var logGradConfig = { kernelName: Log, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return div(dy, cast(x, 'float32')); } }; } }; var logSoftmaxGradConfig = { kernelName: LogSoftmax$1, inputsToSave: [], outputsToSave: [true], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), value = _a[0]; var axis = attrs.axis; return { logits: function () { var keepDims = true; var softmax = exp(value); return sub(dy, mul(sum(dy, axis, keepDims), softmax)); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
* ============================================================================= */ function localResponseNormalizationBackprop_(x, y, dy, depthRadius, bias, alpha, beta) { if (depthRadius === void 0) { depthRadius = 5; } if (bias === void 0) { bias = 1; } if (alpha === void 0) { alpha = 1; } if (beta === void 0) { beta = 0.5; } var inputs = { x: x, y: y, dy: dy }; var attrs = { depthRadius: depthRadius, bias: bias, alpha: alpha, beta: beta }; return ENGINE.runKernel(LRNGrad, inputs, attrs); } var localResponseNormalizationBackprop = op({ localResponseNormalizationBackprop_: localResponseNormalizationBackprop_ }); var lrnGradConfig = { kernelName: LRN, inputsToSave: ['x'], outputsToSave: [true], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 2), x = _a[0], y = _a[1]; var depthRadius = attrs.depthRadius, bias = attrs.bias, alpha = attrs.alpha, beta = attrs.beta; return { x: function () { return localResponseNormalizationBackprop(x, y, dy, depthRadius, bias, alpha, beta); } }; } }; /** * Returns the truth value of (a == b) element-wise. Supports broadcasting. * * ```js * const a = tf.tensor1d([1, 2, 3]); * const b = tf.tensor1d([2, 2, 2]); * * a.equal(b).print(); * ``` * * @param a The first input tensor. * @param b The second input tensor. Must have the same dtype as `a`. * * @doc {heading: 'Operations', subheading: 'Logical'} */ function equal_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric'); var $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; assertAndGetBroadcastShape($a.shape, $b.shape); var inputs = { a: $a, b: $b }; return ENGINE.runKernel(Equal, inputs); } var equal = /* @__PURE__ */ op({ equal_: equal_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Gradient helper function for the min and max operations. */ function gradForMinAndMax(dy, y, xOrig, origAxes) { if (y.rank < xOrig.rank) { y = reshape$1(y, expandShapeToKeepDim(y.shape, origAxes)); } if (dy.rank < xOrig.rank) { dy = reshape$1(dy, expandShapeToKeepDim(dy.shape, origAxes)); } return { x: function () { var dx = mul(dy, cast(equal(xOrig, y), dy.dtype)); return dx; } }; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var maxGradConfig = { kernelName: Max, inputsToSave: ['x'], outputsToSave: [true], gradFunc: function (dy, saved, attrs) { var maxAttrs = attrs; var reductionIndices = maxAttrs.reductionIndices; var x = saved[0]; var y = saved[1]; var origAxes = parseAxisParam(reductionIndices, x.shape); var maxGrad = gradForMinAndMax(dy, y, x, origAxes); return { x: function () { return maxGrad['x'](); } }; } }; /** * Returns the truth value of (a < b) element-wise. Supports broadcasting. * * ```js * const a = tf.tensor1d([1, 2, 3]); * const b = tf.tensor1d([2, 2, 2]); * * a.less(b).print(); * ``` * @param a The first input tensor. * @param b The second input tensor. Must have the same dtype as `a`. * * @doc {heading: 'Operations', subheading: 'Logical'} */ function less_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'less', 'string_or_numeric'); var $b = convertToTensor(b, 'b', 'less', 'string_or_numeric'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; assertAndGetBroadcastShape($a.shape, $b.shape); var inputs = { a: $a, b: $b }; return ENGINE.runKernel(Less, inputs); } var less$1 = /* @__PURE__ */ op({ less_: less_ }); var maximumGradConfig = { kernelName: Maximum$1, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var derA = function () { return mul(dy, cast(greaterEqual(a, b), 'float32')); }; var derB = function () { return mul(dy, cast(less$1(a, b), 'float32')); }; return { a: derA, b: derB }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the backprop of a 3d max pool. * * @param dy The dy error, of rank 5 or rank 4 of shape * [batchSize, depth, height, width, channels]. If rank 4, batch of 1 is * assumed. * @param input The original input image, of rank 5 or rank 4 of shape * [batchSize, depth, height, width, channels]. * @param output The original output image, of rank 5 of shape * [batchSize, outDepth, outHeight, outWidth, channels]. * @param filterSize The filter size: * `[filterDepth, filterHeight, filterWidth]`. If * `filterSize` is a single number, * then `filterDepth == filterHeight == filterWidth`. * @param strides The strides of the pooling: * `[strideDepth, strideHeight, strideWidth]`. If * `strides` is a single number, then `strideDepth == strideHeight == strideWidth`. * @param pad A string from: 'same', 'valid'. The type of padding algorithm * used in the forward prop of the op. * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is * provided, it will default to truncate.
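* * A minimal illustrative sketch (not from the original docs; shapes chosen for brevity, and `maxPool3dGrad` is normally invoked via the MaxPool3D gradient config defined below): * ```js * const x = tf.ones([1, 2, 2, 2, 1]); // original input * const y = tf.maxPool3d(x, 2, 2, 'valid'); // forward output, shape [1, 1, 1, 1, 1] * const dy = tf.ones(y.shape); // upstream gradient * const dx = maxPool3dGrad(dy, x, y, 2, 2, 'valid'); // same shape as x * ```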
*/ function maxPool3dGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) { var $dy = convertToTensor(dy, 'dy', 'maxPool3dGrad'); var $input = convertToTensor(input, 'input', 'maxPool3dGrad'); var $output = convertToTensor(output, 'output', 'maxPool3dGrad'); var dy5D = $dy; var input5D = $input; var output5D = $output; var reshapedTo5D = false; if ($input.rank === 4) { reshapedTo5D = true; dy5D = reshape$1($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]); input5D = reshape$1($input, [ 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3] ]); output5D = reshape$1($output, [ 1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3] ]); } assert(dy5D.rank === 5, function () { return "Error in maxPool3dGrad: dy must be rank 5 but got rank " + "".concat(dy5D.rank, "."); }); assert(input5D.rank === 5, function () { return "Error in maxPool3dGrad: input must be rank 5 but got rank " + "".concat(input5D.rank, "."); }); assert(output5D.rank === 5, function () { return "Error in maxPool3dGrad: output must be rank 5 but got rank " + "".concat(output5D.rank, "."); }); checkPadOnDimRoundingMode('maxPool3dGrad', pad, dimRoundingMode); var inputs = { dy: dy5D, input: input5D, output: output5D }; var attrs = { filterSize: filterSize, strides: strides, pad: pad, dimRoundingMode: dimRoundingMode }; // tslint:disable-next-line: no-unnecessary-type-assertion var res = ENGINE.runKernel(MaxPool3DGrad, inputs, attrs); if (reshapedTo5D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]); } return res; } var maxPool3dGrad = /* @__PURE__ */ op({ maxPool3dGrad_: maxPool3dGrad_ }); var maxPool3DGradConfig = { kernelName: MaxPool3D, inputsToSave: ['x'], outputsToSave: [true], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 2), x = _a[0], y = _a[1]; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad, dimRoundingMode = attrs.dimRoundingMode; return { x: function () { return maxPool3dGrad(dy, x, y, filterSize, strides, pad, dimRoundingMode); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the backprop of a 2D max pool. * * @param dy The dy error, of rank 4 or rank 3 of shape * [batchSize, height, width, channels]. If rank 3, batch of 1 is * assumed. * @param input The original input image, of rank 4, of shape * [batchSize, height, width, channels]. * @param output The original output image, of rank 4, of shape * [batchSize, outHeight, outWidth, channels]. * @param filterSize The filter size: `[filterHeight, filterWidth]`. If * `filterSize` is a single number, then `filterHeight == filterWidth`. * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If * `strides` is a single number, then `strideHeight == strideWidth`. * @param pad The type of padding algorithm used in the forward prop of the op. 
* 'same', 'valid', for more info, see this guide: * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution]( * https://www.tensorflow.org/api_docs/python/tf/nn/convolution) * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is * provided, it will default to truncate. */ function maxPoolGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) { var $dy = convertToTensor(dy, 'dy', 'maxPoolGrad'); var $input = convertToTensor(input, 'input', 'maxPoolGrad'); var $output = convertToTensor(output, 'output', 'maxPoolGrad'); assert($input.rank === $dy.rank, function () { return "Rank of input (".concat($input.rank, ") does not match rank of dy ") + "(".concat($dy.rank, ")"); }); assert($dy.rank === 4, function () { return "Error in maxPoolGrad: dy must be rank 4 but got rank " + "".concat($dy.rank, "."); }); assert($input.rank === 4, function () { return "Error in maxPoolGrad: input must be rank 4 but got rank " + "".concat($input.rank, "."); }); checkPadOnDimRoundingMode('maxPoolGrad', pad, dimRoundingMode); var inputs = { dy: $dy, input: $input, output: $output }; var attrs = { filterSize: filterSize, strides: strides, pad: pad, dimRoundingMode: dimRoundingMode }; // tslint:disable-next-line: no-unnecessary-type-assertion return ENGINE.runKernel(MaxPoolGrad, inputs, attrs); } var maxPoolGrad = /* @__PURE__ */ op({ maxPoolGrad_: maxPoolGrad_ }); var maxPoolGradConfig = { kernelName: MaxPool, inputsToSave: ['x'], outputsToSave: [true], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 2), x = _a[0], y = _a[1]; var filterSize = attrs.filterSize, strides = attrs.strides, pad = attrs.pad; return { x: function () { return maxPoolGrad(dy, x, y, filterSize, strides, pad); } }; } }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates a `tf.Tensor` with all elements set to 0. * * ```js * tf.zeros([2, 2]).print(); * ``` * * @param shape An array of integers defining the output tensor shape. * @param dtype The type of an element in the resulting tensor. Can * be 'float32', 'int32' or 'bool'. Defaults to 'float32'. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function zeros$1(shape, dtype) { if (dtype === void 0) { dtype = 'float32'; } assertNonNegativeIntegerDimensions(shape); if (dtype === 'complex64') { var real = zeros$1(shape, 'float32'); var imag = zeros$1(shape, 'float32'); return complex(real, imag); } var values = makeZerosTypedArray(sizeFromShape(shape), dtype); return ENGINE.makeTensor(values, shape, dtype); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Creates a `tf.Tensor` with all elements set to 1.
 *
 * ```js
 * tf.ones([2, 2]).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param dtype The type of an element in the resulting tensor. Defaults to
 *     'float32'.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function ones$1(shape, dtype) {
    if (dtype === void 0) { dtype = 'float32'; }
    assertNonNegativeIntegerDimensions(shape);
    if (dtype === 'complex64') {
        var real = ones$1(shape, 'float32');
        var imag = zeros$1(shape, 'float32');
        return complex(real, imag);
    }
    var values = makeOnesTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(values, shape, dtype);
}
var meanGradConfig = {
    kernelName: Mean,
    inputsToSave: ['x'],
    gradFunc: function (dy, saved, attrs) {
        var _a = __read(saved, 1), x = _a[0];
        var axis = attrs.axis;
        var axes = parseAxisParam(axis, x.shape);
        var shapes = computeOutAndReduceShapes(x.shape, axes);
        var reduceShape = shapes[1];
        var reduceSize = sizeFromShape(reduceShape);
        var derX = function () {
            var expandedDyShape = x.shape.slice();
            axes.forEach(function (axis) { expandedDyShape[axis] = 1; });
            var expandedDy = reshape$1(dy, expandedDyShape);
            var res = div(mul(expandedDy, ones$1(x.shape, 'float32')), reduceSize);
            return res;
        };
        return { x: derX };
    }
};
var minGradConfig = {
    kernelName: Min,
    inputsToSave: ['x'],
    outputsToSave: [true],
    gradFunc: function (dy, saved, attrs) {
        var minAttrs = attrs;
        var axis = minAttrs.axis;
        var _a = __read(saved, 2), x = _a[0], y = _a[1];
        var origAxes = parseAxisParam(axis, x.shape);
        var minGrad = gradForMinAndMax(dy, y, x, origAxes);
        return { x: function () { return minGrad['x'](); } };
    }
};
var minimumGradConfig = {
    kernelName: Minimum$1,
    inputsToSave: ['a', 'b'],
    gradFunc: function (dy, saved) {
        var _a = __read(saved, 2), a = _a[0], b = _a[1];
        var derA = function () { return mul(dy, cast(lessEqual(a, b), 'float32')); };
        var derB = function () { return mul(dy, cast(greater$1(a, b), 'float32')); };
        return { a: derA, b: derB };
    }
};
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Extracts a slice from a `tf.Tensor` starting at coordinates `begin`
 * and of size `size`.
* * Also available are stricter rank-specific methods with the same signature * as this method that assert that `x` is of the given rank: * - `tf.slice1d` * - `tf.slice2d` * - `tf.slice3d` * - `tf.slice4d` * * ```js * const x = tf.tensor1d([1, 2, 3, 4]); * * x.slice([1], [2]).print(); * ``` * * ```js * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); * * x.slice([1, 0], [1, 2]).print(); * ``` * @param x The input `tf.Tensor` to slice from. * @param begin The coordinates to start the slice from. The length can be * less than the rank of x - the rest of the axes will have implicit 0 as * start. Can also be a single number, in which case it specifies the * first axis. * @param size The size of the slice. The length can be less than the rank of * x - the rest of the axes will have implicit -1. A value of -1 requests * the rest of the dimensions in the axis. Can also be a single number, * in which case it specifies the size of the first axis. * * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function slice_(x, begin, size) { var $x = convertToTensor(x, 'x', 'slice', 'string_or_numeric'); if ($x.rank === 0) { throw new Error('Slicing scalar is not possible'); } var inputs = { x: $x }; var attrs = { begin: begin, size: size }; return ENGINE.runKernel(Slice, inputs, attrs); } var slice = /* @__PURE__ */ op({ slice_: slice_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var mirrorPadGradConfig = { kernelName: MirrorPad, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { // Pad introduces values around the original tensor, so the gradient // slices the original shape out of the gradient. var x = saved[0]; var paddings = attrs.paddings; var begin = paddings.map(function (p) { return p[0]; }); return { x: function () { return slice(dy, begin, x.shape); } }; } }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes floor of input `tf.Tensor` element-wise: `floor(x)`. * * ```js * const x = tf.tensor1d([.6, 1.1, -3.3]); * * x.floor().print(); // or tf.floor(x) * ``` * @param x The input tensor. 
* * @doc {heading: 'Operations', subheading: 'Basic math'} */ function floor_(x) { var $x = convertToTensor(x, 'x', 'floor', 'float32'); var inputs = { x: $x }; return ENGINE.runKernel(Floor, inputs); } var floor = /* @__PURE__ */ op({ floor_: floor_ }); var modGradConfig = { kernelName: Mod, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var outShape = assertAndGetBroadcastShape(a.shape, b.shape); var derA = function () { var reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum(dy, reduceAxes), a.shape); } return dy; }; var derB = function () { var res = mul(dy, neg(floor(div(a, b)))); var reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum(res, reduceAxes), b.shape); } return res; }; return { a: derA, b: derB }; } }; var multiplyGradConfig = { kernelName: Multiply$1, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var outShape = assertAndGetBroadcastShape(a.shape, b.shape); var derA = function () { var res = mul(dy, cast(b, 'float32')); var reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum(res, reduceAxes), a.shape); } return res; }; var derB = function () { var res = mul(dy, cast(a, 'float32')); var reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum(res, reduceAxes), b.shape); } return res; }; return { a: derA, b: derB }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var negGradConfig = { kernelName: Neg, gradFunc: function (dy) { return { x: function () { return neg(dy); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var oneHotGradConfig = { kernelName: OneHot, inputsToSave: ['indices'], gradFunc: function (dy, saved) { var indices = saved[0]; return { indices: function () { return zeros$1(indices.shape, 'float32'); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var onesLikeGradConfig = { kernelName: OnesLike, gradFunc: function (dy) { return { x: function () { return zerosLike(dy); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s. * * ```js * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]); * * tf.unstack(a).forEach(tensor => tensor.print()); * ``` * * @param x A tensor object. * @param axis The axis to unstack along. Defaults to 0 (the first dim). * * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function unstack_(x, axis) { if (axis === void 0) { axis = 0; } var $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric'); assert(axis >= -$x.shape.length && axis < $x.shape.length, function () { return "Axis = ".concat(axis, " is not in [-").concat($x.shape.length, ", ").concat($x.shape.length, ")"); }); var inputs = { value: $x }; var attrs = { axis: axis }; return ENGINE.runKernel(Unpack, inputs, attrs); } var unstack = /* @__PURE__ */ op({ unstack_: unstack_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var packGradConfig = { kernelName: Pack, saveAllInputs: true, gradFunc: function (dy, saved, attrs) { var axis = attrs.axis; var derTensors = unstack(dy, axis); return derTensors.map(function (t) { return function () { return t; }; }); } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var padV2GradConfig = { kernelName: PadV2, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { // Pad introduces values around the original tensor, so the gradient // slices the original shape out of the gradient. var x = saved[0]; var paddings = attrs.paddings; var begin = paddings.map(function (p) { return p[0]; }); return { x: function () { return slice(dy, begin, x.shape); } }; } }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)` * * ```js * const x = tf.tensor1d([1, 2, Math.E]); * * x.log().print(); // or tf.log(x) * ``` * @param x The input tensor. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function log_(x) { var $x = convertToTensor(x, 'x', 'log', 'float32'); var inputs = { x: $x }; return ENGINE.runKernel(Log, inputs); } var log = /* @__PURE__ */ op({ log_: log_ }); /** * Computes the power of one `tf.Tensor` to another. Supports broadcasting. * * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for * corresponding elements in x and y. The result's dtype will be the upcasted * type of the `base` and `exp` dtypes. * * ```js * const a = tf.tensor([[2, 3], [4, 5]]) * const b = tf.tensor([[1, 2], [3, 0]]).toInt(); * * a.pow(b).print(); // or tf.pow(a, b) * ``` * * ```js * const a = tf.tensor([[1, 2], [3, 4]]) * const b = tf.tensor(2).toInt(); * * a.pow(b).print(); // or tf.pow(a, b) * ``` * We also expose `powStrict` which has the same signature as this op and * asserts that `base` and `exp` are the same shape (does not broadcast). * * @param base The base `tf.Tensor` to pow element-wise. * @param exp The exponent `tf.Tensor` to pow element-wise. 
* * @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function pow_(base, exp) { var _a; var $base = convertToTensor(base, 'base', 'pow'); var $exp = convertToTensor(exp, 'exp', 'pow'); _a = __read(makeTypesMatch($base, $exp), 2), $base = _a[0], $exp = _a[1]; var inputs = { a: $base, b: $exp }; return ENGINE.runKernel(Pow, inputs); } var pow = /* @__PURE__ */ op({ pow_: pow_ }); var powGradConfig = { kernelName: Pow, inputsToSave: ['a', 'b'], outputsToSave: [true], gradFunc: function (dy, saved) { var _a = __read(saved, 3), a = _a[0], b = _a[1], y = _a[2]; var base = a; var exp = b; var outShape = assertAndGetBroadcastShape(base.shape, exp.shape); var derBase = function () { var expFloat = cast(exp, 'float32'); var res = mul(dy, mul(expFloat, pow(base, sub(expFloat, scalar(1))))); var reduceAxes = getReductionAxes(base.shape, outShape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(res, base.shape); }; var derExp = function () { var condition = greater$1(base, 0); var logBase = where(condition, log(base), zerosLike(base)); var res = mul(dy, mul(y, logBase)); var reduceAxes = getReductionAxes(exp.shape, outShape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(res, exp.shape); }; return { a: derBase, b: derExp }; } }; var preluGradConfig = { kernelName: Prelu, inputsToSave: ['x', 'alpha'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), x = _a[0], alpha = _a[1]; var mask = greater$1(x, 0); return { x: function () { return where(mask, dy, mul(dy, alpha)); }, alpha: function () { var res = where(mask, zerosLike(dy), mul(dy, x)); var reduceAxes = getReductionAxes(alpha.shape, dy.shape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(res, alpha.shape); } }; } }; /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function isBrowser() { return (typeof window !== 'undefined' && window.document != null) || //@ts-ignore (typeof WorkerGlobalScope !== 'undefined'); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ENV = env(); /** * This file contains environment-related flag registrations. */ /** Whether to enable debug mode. 
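 *
 * A hedged usage sketch (not part of this file's logic): flags registered
 * below can be overridden at runtime through the public `tf.env()` accessor
 * before any ops run, e.g.
 *
 * ```js
 * // Enables per-op NaN checking; significantly slows execution.
 * tf.env().set('DEBUG', true);
 * ```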
 */
ENV.registerFlag('DEBUG', function () { return false; }, function (debugValue) {
    if (debugValue) {
        console.warn('Debugging mode is ON. The output of every math call will ' +
            'be downloaded to CPU and checked for NaNs. ' +
            'This significantly impacts performance.');
    }
});
/** Whether we are in a browser (as opposed to, say, node.js) environment. */
ENV.registerFlag('IS_BROWSER', function () { return isBrowser(); });
/** Whether we are in a node.js (as opposed to a browser) environment. */
ENV.registerFlag('IS_NODE', function () { return (typeof process !== 'undefined') &&
    (typeof process.versions !== 'undefined') &&
    (typeof process.versions.node !== 'undefined'); });
/** Whether this browser is Chrome. */
ENV.registerFlag('IS_CHROME', function () { return typeof navigator !== 'undefined' && navigator != null &&
    navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&
    /Google Inc/.test(navigator.vendor); });
/** Whether this browser is Safari. */
ENV.registerFlag('IS_SAFARI', function () { return typeof navigator !== 'undefined' && navigator != null &&
    navigator.userAgent != null && /Safari/.test(navigator.userAgent) &&
    /Apple/.test(navigator.vendor); });
/**
 * True when the environment is "production" where we disable safety checks
 * to gain performance.
 */
ENV.registerFlag('PROD', function () { return false; });
/**
 * Whether to do sanity checks when inferring a shape from user-provided
 * values, used when creating a new tensor.
 */
ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () { return ENV.getBool('DEBUG'); });
/** Whether deprecation warnings are enabled. */
ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () { return true; });
/** True if running unit tests. */
ENV.registerFlag('IS_TEST', function () { return false; });
/** Whether to check computation result for errors. */
ENV.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', function () { return ENV.getBool('DEBUG'); });
/** Whether the backend needs to wrap input to imageBitmap. */
ENV.registerFlag('WRAP_TO_IMAGEBITMAP', function () { return false; });
/** Whether to enable canvas2d willReadFrequently for GPU backends. */
ENV.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', function () { return false; });
/** Whether to use setTimeoutCustom. */
ENV.registerFlag('USE_SETTIMEOUTCUSTOM', function () { return false; });
/**
 * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating
 * a large ArrayBuffer.
 *
 * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads
 * its weights as a list of (usually) 4MB ArrayBuffers and then slices the
 * weight tensors out of them. For small models, it's safe to concatenate all
 * the weight buffers into a single ArrayBuffer and then slice the weight
 * tensors out of it, but for large models, a different approach is needed.
 */
var CompositeArrayBuffer = /** @class */ (function () {
    function CompositeArrayBuffer(buffers) {
        this.shards = [];
        this.previousShardIndex = 0;
        if (buffers == null) {
            return;
        }
        // Normalize the `buffers` input to be `ArrayBuffer[]`.
        if (!(buffers instanceof Array)) {
            buffers = [buffers];
        }
        buffers = buffers.map(function (bufferOrTypedArray) {
            if (isTypedArray(bufferOrTypedArray)) {
                return bufferOrTypedArray.buffer;
            }
            return bufferOrTypedArray;
        });
        // Skip setting up shards if there are no buffers.
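        // (Illustrative sketch, not executed: three 4-byte buffers yield shards
        // [{start: 0, end: 4}, {start: 4, end: 8}, {start: 8, end: 12}], so the
        // composite byteLength is 12 and byte 9 falls inside the third shard.)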
        if (buffers.length === 0) {
            return;
        }
        this.bufferUniformSize = buffers[0].byteLength;
        var start = 0;
        for (var i = 0; i < buffers.length; i++) {
            var buffer = buffers[i];
            // Check that all buffers except the last one have the same length.
            if (i !== buffers.length - 1 &&
                buffer.byteLength !== this.bufferUniformSize) {
                // Unset the buffer uniform size, since the buffer sizes are not
                // uniform.
                this.bufferUniformSize = undefined;
            }
            // Create the shards, including their start and end points.
            var end = start + buffer.byteLength;
            this.shards.push({ buffer: buffer, start: start, end: end });
            start = end;
        }
        // Set the byteLength.
        if (this.shards.length === 0) {
            this.byteLength = 0;
        }
        else {
            this.byteLength = this.shards[this.shards.length - 1].end;
        }
    }
    /**
     * Concatenate a number of ArrayBuffers into one.
     *
     * @param buffers An array of ArrayBuffers to concatenate, or a single
     *     ArrayBuffer.
     * @returns Result of concatenating `buffers` in order.
     */
    CompositeArrayBuffer.join = function (buffers) {
        return new CompositeArrayBuffer(buffers).slice();
    };
    CompositeArrayBuffer.prototype.slice = function (start, end) {
        if (start === void 0) { start = 0; }
        if (end === void 0) { end = this.byteLength; }
        // If there are no shards, then the CompositeArrayBuffer was initialized
        // with no data.
        if (this.shards.length === 0) {
            return new ArrayBuffer(0);
        }
        // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
        start = isNaN(Number(start)) ? 0 : start;
        end = isNaN(Number(end)) ? 0 : end;
        // Fix the bounds to within the array.
        start = Math.max(0, start);
        end = Math.min(this.byteLength, end);
        if (end <= start) {
            return new ArrayBuffer(0);
        }
        var startShardIndex = this.findShardForByte(start);
        if (startShardIndex === -1) {
            // This should not happen since the start and end indices are always
            // within 0 and the composite array's length.
            throw new Error("Could not find start shard for byte ".concat(start));
        }
        var size = end - start;
        var outputBuffer = new ArrayBuffer(size);
        var outputArray = new Uint8Array(outputBuffer);
        var sliced = 0;
        for (var i = startShardIndex; i < this.shards.length; i++) {
            var shard = this.shards[i];
            var globalStart = start + sliced;
            var localStart = globalStart - shard.start;
            var outputStart = sliced;
            var globalEnd = Math.min(end, shard.end);
            var localEnd = globalEnd - shard.start;
            var outputSlice = new Uint8Array(shard.buffer, localStart, localEnd - localStart);
            outputArray.set(outputSlice, outputStart);
            sliced += outputSlice.length;
            if (end < shard.end) {
                break;
            }
        }
        return outputBuffer;
    };
    /**
     * Get the index of the shard that contains the byte at `byteIndex`.
     */
    CompositeArrayBuffer.prototype.findShardForByte = function (byteIndex) {
        if (this.shards.length === 0 || byteIndex < 0 ||
            byteIndex >= this.byteLength) {
            return -1;
        }
        // If the buffers have a uniform size, compute the shard directly.
        if (this.bufferUniformSize != null) {
            this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize);
            return this.previousShardIndex;
        }
        // If the buffers don't have a uniform size, we need to search for the
        // shard. That means we need a function to check where the byteIndex lies
        // relative to a given shard.
        function check(shard) {
            if (byteIndex < shard.start) {
                return -1;
            }
            if (byteIndex >= shard.end) {
                return 1;
            }
            return 0;
        }
        // For efficiency, try the previous shard first.
        if (check(this.shards[this.previousShardIndex]) === 0) {
            return this.previousShardIndex;
        }
        // Otherwise, use a generic search function.
        // This should almost never end up being used in practice since the weight
        // entries should always be in order.
        var index = search(this.shards, check);
        if (index === -1) {
            return -1;
        }
        this.previousShardIndex = index;
        return this.previousShardIndex;
    };
    return CompositeArrayBuffer;
}());
/**
 * Search for an element of a sorted array.
 *
 * @param sortedArray The sorted array to search.
 * @param compare A function to compare the current value against the searched
 *     value. Return 0 on a match, negative if the searched value is less than
 *     the value passed to the function, and positive if the searched value is
 *     greater than the value passed to the function.
 * @returns The index of the element, or -1 if it's not in the array.
 */
function search(sortedArray, compare) {
    // Binary search over the inclusive index range [min, max]. Shrinking max to
    // middle - 1 (rather than middle) guarantees termination even when the
    // searched value is not in the array.
    var min = 0;
    var max = sortedArray.length - 1;
    while (min <= max) {
        var middle = Math.floor((max - min) / 2) + min;
        var side = compare(sortedArray[middle]);
        if (side === 0) {
            return middle;
        }
        else if (side < 0) {
            max = middle - 1;
        }
        else {
            min = middle + 1;
        }
    }
    return -1;
}
// Use Buffer on Node.js instead of Blob/atob/btoa
var useNodeBuffer = typeof Buffer !== 'undefined' &&
    (typeof Blob === 'undefined' || typeof atob === 'undefined' ||
        typeof btoa === 'undefined');
/**
 * Calculate the byte length of a JavaScript string.
 *
 * Note that a JavaScript string can contain wide characters, therefore the
 * length of the string is not necessarily equal to the byte length.
 *
 * @param str Input string.
 * @returns Byte length.
 */
function stringByteLength(str) {
    if (useNodeBuffer) {
        return Buffer.byteLength(str, 'utf8');
    }
    return new Blob([str]).size;
}
/**
 * Encode an ArrayBuffer as a base64 encoded string.
 *
 * @param buffer `ArrayBuffer` to be converted.
 * @returns A string that base64-encodes `buffer`.
 */
function arrayBufferToBase64String(buffer) {
    if (useNodeBuffer) {
        return Buffer.from(buffer).toString('base64');
    }
    var buf = new Uint8Array(buffer);
    var s = '';
    for (var i = 0, l = buf.length; i < l; i++) {
        s += String.fromCharCode(buf[i]);
    }
    return btoa(s);
}
/**
 * Decode a base64 string as an ArrayBuffer.
 *
 * @param str Base64 string.
 * @returns Decoded `ArrayBuffer`.
 */
function base64StringToArrayBuffer(str) {
    if (useNodeBuffer) {
        var buf = Buffer.from(str, 'base64');
        return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
    }
    var s = atob(str);
    var buffer = new Uint8Array(s.length);
    for (var i = 0; i < s.length; ++i) {
        buffer.set([s.charCodeAt(i)], i);
    }
    return buffer.buffer;
}
/**
 * Create `ModelJSON` from `ModelArtifacts`.
 *
 * @param artifacts Model artifacts, describing the model and its weights.
 * @param manifest Weight manifest, describing where the weights of the
 *     `ModelArtifacts` are stored, and some metadata about them.
 * @returns Object representing the `model.json` file describing the model
 *     artifacts and weights
 */
function getModelJSONForModelArtifacts(artifacts, manifest) {
    var result = {
        modelTopology: artifacts.modelTopology,
        format: artifacts.format,
        generatedBy: artifacts.generatedBy,
        convertedBy: artifacts.convertedBy,
        weightsManifest: manifest
    };
    if (artifacts.signature != null) {
        result.signature = artifacts.signature;
    }
    if (artifacts.userDefinedMetadata != null) {
        result.userDefinedMetadata = artifacts.userDefinedMetadata;
    }
    if (artifacts.modelInitializer != null) {
        result.modelInitializer = artifacts.modelInitializer;
    }
    if (artifacts.initializerSignature != null) {
        result.initializerSignature = artifacts.initializerSignature;
    }
    if (artifacts.trainingConfig != null) {
        result.trainingConfig = artifacts.trainingConfig;
    }
    return result;
}
/**
 * Create `ModelArtifacts` from a JSON file and weights.
 *
 * @param modelJSON Object containing the parsed JSON of `model.json`
 * @param weightSpecs The list of WeightsManifestEntry for the model. Must be
 *     passed if the modelJSON has a weightsManifest.
 * @param weightData An ArrayBuffer or array of ArrayBuffers of weight data for
 *     the model corresponding to the weights in weightSpecs. Must be passed if
 *     the modelJSON has a weightsManifest.
 * @returns The `ModelArtifacts`, as described by the JSON file.
 */
function getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData) {
    var modelArtifacts = {
        modelTopology: modelJSON.modelTopology,
        format: modelJSON.format,
        generatedBy: modelJSON.generatedBy,
        convertedBy: modelJSON.convertedBy
    };
    if (modelJSON.trainingConfig != null) {
        modelArtifacts.trainingConfig = modelJSON.trainingConfig;
    }
    if (modelJSON.weightsManifest != null) {
        if (!weightSpecs) {
            throw new Error('modelJSON has weightsManifest but weightSpecs is null');
        }
        if (!weightData) {
            throw new Error('modelJSON has weightsManifest but weightData is null');
        }
        modelArtifacts.weightSpecs = weightSpecs;
        modelArtifacts.weightData = weightData;
    }
    if (modelJSON.signature != null) {
        modelArtifacts.signature = modelJSON.signature;
    }
    if (modelJSON.userDefinedMetadata != null) {
        modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
    }
    if (modelJSON.modelInitializer != null) {
        modelArtifacts.modelInitializer = modelJSON.modelInitializer;
    }
    if (modelJSON.initializerSignature != null) {
        modelArtifacts.initializerSignature = modelJSON.initializerSignature;
    }
    return modelArtifacts;
}
/**
 * Create `ModelArtifacts` from a JSON file.
 *
 * @param modelJSON Object containing the parsed JSON of `model.json`
 * @param loadWeights Function that takes the JSON file's weights manifest,
 *     reads weights from the listed path(s), and returns a Promise of the
 *     weight manifest entries along with the weights data.
 * @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
 */
function getModelArtifactsForJSON(modelJSON, loadWeights) {
    return __awaiter(this, void 0, void 0, function () {
        var weightSpecs, weightData;
        var _a;
        return __generator(this, function (_b) {
            switch (_b.label) {
                case 0:
                    if (!(modelJSON.weightsManifest != null)) return [3 /*break*/, 2];
                    return [4 /*yield*/, loadWeights(modelJSON.weightsManifest)];
                case 1:
                    _a = __read.apply(void 0, [_b.sent(), 2]), weightSpecs = _a[0], weightData = _a[1];
                    _b.label = 2;
                case 2: return [2 /*return*/, getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData)];
            }
        });
    });
}
/**
 * Populate ModelArtifactsInfo fields for a model with JSON topology.
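 *
 * A minimal usage sketch (hedged; `modelArtifacts` here is any JSON-topology
 * `ModelArtifacts` object, e.g. one produced by the IO handlers below):
 *
 * ```js
 * const info = getModelArtifactsInfoForJSON(modelArtifacts);
 * console.log(info.modelTopologyBytes, info.weightDataBytes);
 * ```
 *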
 * @param modelArtifacts
 * @returns A ModelArtifactsInfo object.
 */
function getModelArtifactsInfoForJSON(modelArtifacts) {
    if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
        throw new Error('Expected JSON model topology, received ArrayBuffer.');
    }
    return {
        dateSaved: new Date(),
        modelTopologyType: 'JSON',
        modelTopologyBytes: modelArtifacts.modelTopology == null ?
            0 :
            stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
        weightSpecsBytes: modelArtifacts.weightSpecs == null ?
            0 :
            stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
        weightDataBytes: modelArtifacts.weightData == null ?
            0 :
            new CompositeArrayBuffer(modelArtifacts.weightData).byteLength,
    };
}
/**
 * Collect the weight specs stored in a WeightsManifestConfig into a flat list
 * of WeightsManifestEntry.
 *
 * @param weightsManifest The WeightsManifestConfig to extract weight specs
 *     from.
 * @returns A flat list of the WeightsManifestEntry items in `weightsManifest`.
 */
function getWeightSpecs(weightsManifest) {
    var e_3, _a;
    var weightSpecs = [];
    try {
        for (var weightsManifest_1 = __values(weightsManifest), weightsManifest_1_1 = weightsManifest_1.next(); !weightsManifest_1_1.done; weightsManifest_1_1 = weightsManifest_1.next()) {
            var entry = weightsManifest_1_1.value;
            weightSpecs.push.apply(weightSpecs, __spreadArray([], __read(entry.weights), false));
        }
    }
    catch (e_3_1) { e_3 = { error: e_3_1 }; }
    finally {
        try {
            if (weightsManifest_1_1 && !weightsManifest_1_1.done && (_a = weightsManifest_1.return)) _a.call(weightsManifest_1);
        }
        finally { if (e_3) throw e_3.error; }
    }
    return weightSpecs;
}
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
var IORouterRegistry = /** @class */ (function () {
    function IORouterRegistry() {
        this.saveRouters = [];
        this.loadRouters = [];
    }
    IORouterRegistry.getInstance = function () {
        if (IORouterRegistry.instance == null) {
            IORouterRegistry.instance = new IORouterRegistry();
        }
        return IORouterRegistry.instance;
    };
    /**
     * Register a save-handler router.
     *
     * @param saveRouter A function that maps a URL-like string onto an instance
     *   of `IOHandler` with the `save` method defined or `null`.
     */
    IORouterRegistry.registerSaveRouter = function (saveRouter) {
        IORouterRegistry.getInstance().saveRouters.push(saveRouter);
    };
    /**
     * Register a load-handler router.
     *
     * @param loadRouter A function that maps a URL-like string onto an instance
     *   of `IOHandler` with the `load` method defined or `null`.
     */
    IORouterRegistry.registerLoadRouter = function (loadRouter) {
        IORouterRegistry.getInstance().loadRouters.push(loadRouter);
    };
    /**
     * Look up IOHandlers for saving, given a URL-like string.
     *
     * @param url
     * @returns All valid handlers for `url` with the `save` method defined,
     *     given the currently registered handler routers.
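     *
     * A minimal sketch (hedged; assumes a router for the scheme, such as the
     * `indexedDBRouter` defined later in this file, has been registered):
     *
     * ```js
     * const handlers = IORouterRegistry.getSaveHandlers('indexeddb://my-model');
     * console.log(handlers.length); // 1 when exactly one router matches.
     * ```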
     */
    IORouterRegistry.getSaveHandlers = function (url) {
        return IORouterRegistry.getHandlers(url, 'save');
    };
    /**
     * Look up IOHandlers for loading, given a URL-like string.
     *
     * @param url
     * @param loadOptions Optional, custom load options.
     * @returns All valid handlers for `url`, given the currently registered
     *     handler routers.
     */
    IORouterRegistry.getLoadHandlers = function (url, loadOptions) {
        return IORouterRegistry.getHandlers(url, 'load', loadOptions);
    };
    IORouterRegistry.getHandlers = function (url, handlerType, loadOptions) {
        var validHandlers = [];
        var routers = handlerType === 'load' ?
            IORouterRegistry.getInstance().loadRouters :
            IORouterRegistry.getInstance().saveRouters;
        routers.forEach(function (router) {
            var handler = router(url, loadOptions);
            if (handler !== null) {
                validHandlers.push(handler);
            }
        });
        return validHandlers;
    };
    return IORouterRegistry;
}());
var DATABASE_NAME = 'tensorflowjs';
var DATABASE_VERSION = 1;
// Model data and ModelArtifactsInfo (metadata) are stored in two separate
// stores for efficient access of the list of stored models and their metadata.
// 1. The object store for model data: topology, weights and weight manifests.
var MODEL_STORE_NAME = 'models_store';
// 2. The object store for ModelArtifactsInfo, including meta-information such
//    as the type of topology (JSON vs binary), byte size of the topology, byte
//    size of the weights, etc.
var INFO_STORE_NAME = 'model_info_store';
function getIndexedDBFactory() {
    if (!env().getBool('IS_BROWSER')) {
        // TODO(cais): Add more info about what IOHandler subtypes are available.
        //   Maybe point to a doc page on the web and/or automatically determine
        //   the available IOHandlers and print them in the error message.
        throw new Error('Failed to obtain IndexedDB factory because the current environment ' +
            'is not a web browser.');
    }
    // tslint:disable-next-line:no-any
    var theWindow = typeof window === 'undefined' ? self : window;
    var factory = theWindow.indexedDB || theWindow.mozIndexedDB ||
        theWindow.webkitIndexedDB || theWindow.msIndexedDB ||
        theWindow.shimIndexedDB;
    if (factory == null) {
        throw new Error('The current browser does not appear to support IndexedDB.');
    }
    return factory;
}
function setUpDatabase(openRequest) {
    var db = openRequest.result;
    db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' });
    db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' });
}
/**
 * IOHandler subclass: Browser IndexedDB.
 *
 * See the doc string of `browserIndexedDB` for more details.
 */
var BrowserIndexedDB = /** @class */ (function () {
    function BrowserIndexedDB(modelPath) {
        this.indexedDB = getIndexedDBFactory();
        if (modelPath == null || !modelPath) {
            throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
        }
        this.modelPath = modelPath;
    }
    BrowserIndexedDB.prototype.save = function (modelArtifacts) {
        return __awaiter(this, void 0, void 0, function () {
            return __generator(this, function (_a) {
                // TODO(cais): Support saving GraphDef models.
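                // (Hedged usage sketch, as a comment: `await
                // browserIndexedDB('my-model').save(artifacts)` resolves to
                // `{modelArtifactsInfo}` once both IndexedDB stores below have
                // been written by `databaseAction`.)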
                if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
                    throw new Error('BrowserIndexedDB.save() does not support saving model topology ' +
                        'in binary formats yet.');
                }
                return [2 /*return*/, this.databaseAction(this.modelPath, modelArtifacts)];
            });
        });
    };
    BrowserIndexedDB.prototype.load = function () {
        return __awaiter(this, void 0, void 0, function () {
            return __generator(this, function (_a) {
                return [2 /*return*/, this.databaseAction(this.modelPath)];
            });
        });
    };
    /**
     * Perform database action to put model artifacts into or read model
     * artifacts from IndexedDB object store.
     *
     * Whether the action is put or get depends on whether `modelArtifacts` is
     * specified. If it is specified, the action will be put; otherwise the
     * action will be get.
     *
     * @param modelPath A unique string path for the model.
     * @param modelArtifacts If specified, it will be the model artifacts to be
     *   stored in IndexedDB.
     * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
     *   of `ModelArtifacts`, if the action is get.
     */
    BrowserIndexedDB.prototype.databaseAction = function (modelPath, modelArtifacts) {
        var _this = this;
        return new Promise(function (resolve, reject) {
            var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
            openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); };
            openRequest.onsuccess = function () {
                var db = openRequest.result;
                if (modelArtifacts == null) {
                    // Read model out from object store.
                    var modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
                    var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
                    var getRequest_1 = modelStore.get(_this.modelPath);
                    getRequest_1.onsuccess = function () {
                        if (getRequest_1.result == null) {
                            db.close();
                            return reject(new Error("Cannot find model with path '".concat(_this.modelPath, "' ") +
                                "in IndexedDB."));
                        }
                        else {
                            resolve(getRequest_1.result.modelArtifacts);
                        }
                    };
                    getRequest_1.onerror = function (error) {
                        db.close();
                        return reject(getRequest_1.error);
                    };
                    modelTx.oncomplete = function () { return db.close(); };
                }
                else {
                    // Put model into object store.
                    // Concatenate all the model weights into a single ArrayBuffer. Large
                    // models (~1GB) have problems saving if they are not concatenated.
                    // TODO(mattSoulanille): Save large models to multiple indexeddb
                    // records.
                    modelArtifacts.weightData = CompositeArrayBuffer.join(modelArtifacts.weightData);
                    var modelArtifactsInfo_1 = getModelArtifactsInfoForJSON(modelArtifacts);
                    // First, put ModelArtifactsInfo into info store.
                    var infoTx_1 = db.transaction(INFO_STORE_NAME, 'readwrite');
                    var infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME);
                    var putInfoRequest_1;
                    try {
                        putInfoRequest_1 = infoStore_1.put({
                            modelPath: _this.modelPath,
                            modelArtifactsInfo: modelArtifactsInfo_1
                        });
                    }
                    catch (error) {
                        return reject(error);
                    }
                    var modelTx_1;
                    putInfoRequest_1.onsuccess = function () {
                        // Second, put model data into model store.
                        modelTx_1 = db.transaction(MODEL_STORE_NAME, 'readwrite');
                        var modelStore = modelTx_1.objectStore(MODEL_STORE_NAME);
                        var putModelRequest;
                        try {
                            putModelRequest = modelStore.put({
                                modelPath: _this.modelPath,
                                modelArtifacts: modelArtifacts,
                                modelArtifactsInfo: modelArtifactsInfo_1
                            });
                        }
                        catch (error) {
                            // Sometimes, the serialized value is too large to store.
                            return reject(error);
                        }
                        putModelRequest.onsuccess = function () { return resolve({ modelArtifactsInfo: modelArtifactsInfo_1 }); };
                        putModelRequest.onerror = function (error) {
                            // If the put-model request fails, roll back the info entry as
                            // well.
                            infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME);
                            var deleteInfoRequest = infoStore_1.delete(_this.modelPath);
                            deleteInfoRequest.onsuccess = function () {
                                db.close();
                                return reject(putModelRequest.error);
                            };
                            deleteInfoRequest.onerror = function (error) {
                                db.close();
                                return reject(putModelRequest.error);
                            };
                        };
                    };
                    putInfoRequest_1.onerror = function (error) {
                        db.close();
                        return reject(putInfoRequest_1.error);
                    };
                    infoTx_1.oncomplete = function () {
                        if (modelTx_1 == null) {
                            db.close();
                        }
                        else {
                            modelTx_1.oncomplete = function () { return db.close(); };
                        }
                    };
                }
            };
            openRequest.onerror = function (error) { return reject(openRequest.error); };
        });
    };
    return BrowserIndexedDB;
}());
BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
var indexedDBRouter = function (url) {
    if (!env().getBool('IS_BROWSER')) {
        return null;
    }
    else {
        if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
            return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
        }
        else {
            return null;
        }
    }
};
IORouterRegistry.registerSaveRouter(indexedDBRouter);
IORouterRegistry.registerLoadRouter(indexedDBRouter);
/**
 * Creates a browser IndexedDB IOHandler for saving and loading models.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
 *
 * const saveResult = await model.save('indexeddb://MyModel');
 * console.log(saveResult);
 * ```
 *
 * @param modelPath A unique identifier for the model to be saved. Must be a
 *   non-empty string.
 * @returns An instance of `BrowserIndexedDB` (subclass of `IOHandler`),
 *   which can be used with, e.g., `tf.Model.save`.
 */
function browserIndexedDB(modelPath) {
    return new BrowserIndexedDB(modelPath);
}
var PATH_SEPARATOR = '/';
var PATH_PREFIX = 'tensorflowjs_models';
var INFO_SUFFIX = 'info';
var MODEL_TOPOLOGY_SUFFIX = 'model_topology';
var WEIGHT_SPECS_SUFFIX = 'weight_specs';
var WEIGHT_DATA_SUFFIX = 'weight_data';
var MODEL_METADATA_SUFFIX = 'model_metadata';
function getModelKeys(path) {
    return {
        info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
        topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
        weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
        weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
        modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
    };
}
function removeItems(keys) {
    var e_1, _a;
    try {
        for (var _b = __values(Object.values(keys)), _c = _b.next(); !_c.done; _c = _b.next()) {
            var key = _c.value;
            window.localStorage.removeItem(key);
        }
    }
    catch (e_1_1) { e_1 = { error: e_1_1 }; }
    finally {
        try {
            if (_c && !_c.done && (_a = _b.return)) _a.call(_b);
        }
        finally { if (e_1) throw e_1.error; }
    }
}
/**
 * IOHandler subclass: Browser Local Storage.
 *
 * See the doc string to `browserLocalStorage` for more details.
 */
var BrowserLocalStorage = /** @class */ (function () {
    function BrowserLocalStorage(modelPath) {
        if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
            typeof window.localStorage === 'undefined') {
            // TODO(cais): Add more info about what IOHandler subtypes are
            //   available. Maybe point to a doc page on the web and/or
            //   automatically determine the available IOHandlers and print them
            //   in the error message.
throw new Error('The current environment does not support local storage.'); } this.LS = window.localStorage; if (modelPath == null || !modelPath) { throw new Error('For local storage, modelPath must not be null, undefined or empty.'); } this.modelPath = modelPath; this.keys = getModelKeys(this.modelPath); } /** * Save model artifacts to browser local storage. * * See the documentation to `browserLocalStorage` for details on the saved * artifacts. * * @param modelArtifacts The model artifacts to be stored. * @returns An instance of SaveResult. */ BrowserLocalStorage.prototype.save = function (modelArtifacts) { return __awaiter(this, void 0, void 0, function () { var topology, weightSpecs, modelArtifactsInfo, weightBuffer, metadata; return __generator(this, function (_a) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + 'in binary formats yet.'); } else { topology = JSON.stringify(modelArtifacts.modelTopology); weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); this.LS.setItem(this.keys.topology, topology); this.LS.setItem(this.keys.weightSpecs, weightSpecs); this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(weightBuffer)); metadata = { format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy, signature: modelArtifacts.signature != null ? modelArtifacts.signature : undefined, userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ? modelArtifacts.userDefinedMetadata : undefined, modelInitializer: modelArtifacts.modelInitializer != null ? modelArtifacts.modelInitializer : undefined, initializerSignature: modelArtifacts.initializerSignature != null ? modelArtifacts.initializerSignature : undefined, trainingConfig: modelArtifacts.trainingConfig != null ? modelArtifacts.trainingConfig : undefined }; this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata)); return [2 /*return*/, { modelArtifactsInfo: modelArtifactsInfo }]; } catch (err) { // If saving failed, clean up all items saved so far. removeItems(this.keys); throw new Error("Failed to save model '".concat(this.modelPath, "' to local storage: ") + "size quota being exceeded is a possible cause of this failure: " + "modelTopologyBytes=".concat(modelArtifactsInfo.modelTopologyBytes, ", ") + "weightSpecsBytes=".concat(modelArtifactsInfo.weightSpecsBytes, ", ") + "weightDataBytes=".concat(modelArtifactsInfo.weightDataBytes, ".")); } } return [2 /*return*/]; }); }); }; /** * Load a model from local storage. * * See the documentation to `browserLocalStorage` for details on the saved * artifacts. * * @returns The loaded model (if loading succeeds). 
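     *
     * A minimal usage sketch (hedged; assumes the model was previously saved
     * under the same path):
     *
     * ```js
     * const handler = browserLocalStorage('my-model');
     * const artifacts = await handler.load();
     * ```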
     */
    BrowserLocalStorage.prototype.load = function () {
        return __awaiter(this, void 0, void 0, function () {
            var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64;
            return __generator(this, function (_a) {
                info = JSON.parse(this.LS.getItem(this.keys.info));
                if (info == null) {
                    throw new Error("In local storage, there is no model with name '".concat(this.modelPath, "'"));
                }
                if (info.modelTopologyType !== 'JSON') {
                    throw new Error('BrowserLocalStorage does not support loading non-JSON model ' +
                        'topology yet.');
                }
                out = {};
                topology = JSON.parse(this.LS.getItem(this.keys.topology));
                if (topology == null) {
                    throw new Error("In local storage, the topology of model '".concat(this.modelPath, "' ") +
                        "is missing.");
                }
                out.modelTopology = topology;
                weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
                if (weightSpecs == null) {
                    throw new Error("In local storage, the weight specs of model '".concat(this.modelPath, "' ") +
                        "are missing.");
                }
                out.weightSpecs = weightSpecs;
                metadataString = this.LS.getItem(this.keys.modelMetadata);
                if (metadataString != null) {
                    metadata = JSON.parse(metadataString);
                    out.format = metadata.format;
                    out.generatedBy = metadata.generatedBy;
                    out.convertedBy = metadata.convertedBy;
                    if (metadata.signature != null) {
                        out.signature = metadata.signature;
                    }
                    if (metadata.userDefinedMetadata != null) {
                        out.userDefinedMetadata = metadata.userDefinedMetadata;
                    }
                    if (metadata.modelInitializer != null) {
                        out.modelInitializer = metadata.modelInitializer;
                    }
                    if (metadata.initializerSignature != null) {
                        out.initializerSignature = metadata.initializerSignature;
                    }
                    if (metadata.trainingConfig != null) {
                        out.trainingConfig = metadata.trainingConfig;
                    }
                }
                weightDataBase64 = this.LS.getItem(this.keys.weightData);
                if (weightDataBase64 == null) {
                    throw new Error("In local storage, the binary weight values of model " +
                        "'".concat(this.modelPath, "' are missing."));
                }
                out.weightData = base64StringToArrayBuffer(weightDataBase64);
                return [2 /*return*/, out];
            });
        });
    };
    return BrowserLocalStorage;
}());
BrowserLocalStorage.URL_SCHEME = 'localstorage://';
var localStorageRouter = function (url) {
    if (!env().getBool('IS_BROWSER')) {
        return null;
    }
    else {
        if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
            return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
        }
        else {
            return null;
        }
    }
};
IORouterRegistry.registerSaveRouter(localStorageRouter);
IORouterRegistry.registerLoadRouter(localStorageRouter);
/**
 * Factory function for local storage IOHandler.
 *
 * This `IOHandler` supports both `save` and `load`.
 *
 * For each model's saved artifacts, four items are saved to local storage.
 *   - `${PATH_PREFIX}/${modelPath}/info`: Contains meta-info about the
 *     model, such as date saved, type of the topology, size in bytes, etc.
 *   - `${PATH_PREFIX}/${modelPath}/model_topology`: Model topology. For
 *     Keras-style models, this is a stringified JSON.
 *   - `${PATH_PREFIX}/${modelPath}/weight_specs`: Weight specs of the
 *     model, can be used to decode the saved binary weight values (see
 *     item below).
 *   - `${PATH_PREFIX}/${modelPath}/weight_data`: Concatenated binary
 *     weight values, stored as a base64-encoded string.
 *
 * Saving may throw an `Error` if the total size of the artifacts exceeds the
 * browser-specific quota.
 *
 * @param modelPath A unique identifier for the model to be saved. Must be a
 *   non-empty string.
 * @returns An instance of `IOHandler`, which can be used with, e.g.,
 *   `tf.Model.save`.
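 *
 * A minimal usage sketch (hedged; mirrors the IndexedDB example earlier in
 * this file):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
 *
 * const saveResult = await model.save('localstorage://MyModel');
 * console.log(saveResult);
 * ```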
 */
function browserLocalStorage(modelPath) {
    return new BrowserLocalStorage(modelPath);
}
var DEFAULT_FILE_NAME_PREFIX = 'model';
var DEFAULT_JSON_EXTENSION_NAME = '.json';
var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
function defer(f) {
    return new Promise(function (resolve) { return setTimeout(resolve); }).then(f);
}
var BrowserDownloads = /** @class */ (function () {
    function BrowserDownloads(fileNamePrefix) {
        if (!env().getBool('IS_BROWSER')) {
            // TODO(cais): Provide info on what IOHandlers are available under the
            //   current environment.
            throw new Error('browserDownloads() cannot proceed because the current environment ' +
                'is not a browser.');
        }
        if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
            fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
        }
        if (fileNamePrefix == null || fileNamePrefix.length === 0) {
            fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
        }
        this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
        this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
    }
    BrowserDownloads.prototype.save = function (modelArtifacts) {
        return __awaiter(this, void 0, void 0, function () {
            var weightBuffer, weightsURL, weightsManifest, modelJSON, modelJsonURL, jsonAnchor_1, weightDataAnchor_1;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (typeof (document) === 'undefined') {
                            throw new Error('Browser downloads are not supported in ' +
                                'this environment since `document` is not present');
                        }
                        weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
                        weightsURL = window.URL.createObjectURL(new Blob([weightBuffer], { type: 'application/octet-stream' }));
                        if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) return [3 /*break*/, 1];
                        throw new Error('BrowserDownloads.save() does not support saving model topology ' +
                            'in binary formats yet.');
                    case 1:
                        weightsManifest = [{
                                paths: ['./' + this.weightDataFileName],
                                weights: modelArtifacts.weightSpecs
                            }];
                        modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
                        modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], { type: 'application/json' }));
                        jsonAnchor_1 = this.modelJsonAnchor == null ? document.createElement('a') :
                            this.modelJsonAnchor;
                        jsonAnchor_1.download = this.modelJsonFileName;
                        jsonAnchor_1.href = modelJsonURL;
                        // Trigger downloads by dispatching a click event on the download
                        // anchors. When multiple downloads are started synchronously,
                        // Firefox will only save the last one.
                        return [4 /*yield*/, defer(function () { return jsonAnchor_1.dispatchEvent(new MouseEvent('click')); })];
                    case 2:
                        // Trigger downloads by dispatching a click event on the download
                        // anchors. When multiple downloads are started synchronously,
                        // Firefox will only save the last one.
                        _a.sent();
                        if (!(modelArtifacts.weightData != null)) return [3 /*break*/, 4];
                        weightDataAnchor_1 = this.weightDataAnchor == null ?
                            document.createElement('a') : this.weightDataAnchor;
                        weightDataAnchor_1.download = this.weightDataFileName;
                        weightDataAnchor_1.href = weightsURL;
                        return [4 /*yield*/, defer(function () { return weightDataAnchor_1.dispatchEvent(new MouseEvent('click')); })];
                    case 3:
                        _a.sent();
                        _a.label = 4;
                    case 4: return [2 /*return*/, { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) }];
                }
            });
        });
    };
    return BrowserDownloads;
}());
BrowserDownloads.URL_SCHEME = 'downloads://';
var browserDownloadsRouter = function (url) {
    if (!env().getBool('IS_BROWSER')) {
        return null;
    }
    else {
        if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
            return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
        }
        else {
            return null;
        }
    }
};
IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
/**
 * Creates an IOHandler that triggers file downloads from the browser.
 *
 * The returned `IOHandler` instance can be used with model exporting methods
 * such as `tf.Model.save`; it supports only saving.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense(
 *     {units: 1, inputShape: [10], activation: 'sigmoid'}));
 * const saveResult = await model.save('downloads://mymodel');
 * // This will trigger downloading of two files:
 * //   'mymodel.json' and 'mymodel.weights.bin'.
 * console.log(saveResult);
 * ```
 *
 * @param fileNamePrefix Prefix name of the files to be downloaded. For use
 *   with `tf.Model`, `fileNamePrefix` should follow either of the following
 *   two formats:
 *   1. `null` or `undefined`, in which case the default file
 *      names will be used:
 *      - 'model.json' for the JSON file containing the model topology and
 *        weights manifest.
 *      - 'model.weights.bin' for the binary file containing the binary weight
 *        values.
 *   2. A single string or an Array of a single string, as the file name
 *      prefix. For example, if `'foo'` is provided, the downloaded JSON file
 *      and binary weights file will be named 'foo.json' and 'foo.weights.bin',
 *      respectively.
 * @param config Additional configuration for triggering downloads.
 * @returns An instance of `BrowserDownloads` `IOHandler`.
 *
 * @doc {
 *   heading: 'Models',
 *   subheading: 'Loading',
 *   namespace: 'io',
 *   ignoreCI: true
 * }
 */
function browserDownloads(fileNamePrefix) {
    if (fileNamePrefix === void 0) { fileNamePrefix = 'model'; }
    return new BrowserDownloads(fileNamePrefix);
}
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Monitor Promise.all progress, fire onProgress callback function.
 *
 * @param promises Promise list to be monitored.
 * @param onProgress Callback function. Fired each time a promise resolves.
 * @param startFraction Optional fraction start. Defaults to 0.
 * @param endFraction Optional fraction end. Defaults to 1.
 */
function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
    checkPromises(promises);
    startFraction = startFraction == null ?
0 : startFraction; endFraction = endFraction == null ? 1 : endFraction; checkFraction(startFraction, endFraction); var resolvedPromise = 0; var registerMonitor = function (promise) { promise.then(function (value) { var fraction = startFraction + ++resolvedPromise / promises.length * (endFraction - startFraction); // Pass the fraction as a parameter to the callback function. onProgress(fraction); return value; }); return promise; }; function checkPromises(promises) { assert(promises != null && Array.isArray(promises) && promises.length > 0, function () { return 'promises must be a non-empty array'; }); } function checkFraction(startFraction, endFraction) { assert(startFraction >= 0 && startFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " + "got startFraction ".concat(startFraction); }); assert(endFraction >= 0 && endFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " + "got endFraction ".concat(endFraction); }); assert(endFraction >= startFraction, function () { return "startFraction must be no more than endFraction, but " + "got startFraction ".concat(startFraction, " and endFraction ") + "".concat(endFraction); }); } return Promise.all(promises.map(registerMonitor)); } /** * Reads binary weights data from a number of URLs. * * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls. * @param loadOptions Optional load options, which may include: `requestInit` * (RequestInit options for the HTTP requests), `fetchFunc` (an overriding * value for the `window.fetch` function) and `onProgress` (a progress * callback function, fired periodically before the load is completed). * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same * length as `fetchURLs`. */ function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) { return __awaiter(this, void 0, void 0, function () { var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, _b, bufferPromises, bufferStartFraction, bufferEndFraction, buffers, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: if (loadOptions == null) { loadOptions = {}; } fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : loadOptions.fetchFunc; requests = fetchURLs.map(function (fetchURL) { return fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }); }); fetchStartFraction = 0; fetchEndFraction = 0.5; if (!(loadOptions.onProgress == null)) return [3 /*break*/, 2]; return [4 /*yield*/, Promise.all(requests)]; case 1: _b = _d.sent(); return [3 /*break*/, 4]; case 2: return [4 /*yield*/, monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction)]; case 3: _b = _d.sent(); _d.label = 4; case 4: responses = _b; bufferPromises = responses.map(function (response) { return response.arrayBuffer(); }); bufferStartFraction = 0.5; bufferEndFraction = 1; if (!(loadOptions.onProgress == null)) return [3 /*break*/, 6]; return [4 /*yield*/, Promise.all(bufferPromises)]; case 5: _c = _d.sent(); return [3 /*break*/, 8]; case 6: return [4 /*yield*/, monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction)]; case 7: _c = _d.sent(); _d.label = 8; case 8: buffers = _c; return [2 /*return*/, buffers]; } }); }); } function streamWeights(fetchURLs, loadOptions) { var _this = this; var _a; var fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : loadOptions.fetchFunc; var fetchIndex = 0; var chunkReader; (_a = loadOptions.onProgress) === null || _a === void 0 ?
void 0 : _a.call(loadOptions, 0); return new ReadableStream({ pull: function (controller) { return __awaiter(_this, void 0, void 0, function () { var _a, body, _b, done, value; return __generator(this, function (_c) { switch (_c.label) { case 0: if (!(fetchIndex < fetchURLs.length)) return [3 /*break*/, 4]; if (!!chunkReader) return [3 /*break*/, 2]; return [4 /*yield*/, fetchFunc(fetchURLs[fetchIndex], loadOptions.requestInit, { isBinary: true })]; case 1: body = (_c.sent()).body; chunkReader = body.getReader(); _c.label = 2; case 2: return [4 /*yield*/, chunkReader.read()]; case 3: _b = _c.sent(), done = _b.done, value = _b.value; if (done) { fetchIndex++; chunkReader = undefined; (_a = loadOptions.onProgress) === null || _a === void 0 ? void 0 : _a.call(loadOptions, fetchIndex / fetchURLs.length); return [3 /*break*/, 0]; } controller.enqueue(value); return [2 /*return*/]; case 4: controller.close(); return [2 /*return*/]; } }); }); }, }); } var OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; var JSON_TYPE = 'application/json'; var HTTPRequest = /** @class */ (function () { function HTTPRequest(path, loadOptions) { this.DEFAULT_METHOD = 'POST'; if (loadOptions == null) { loadOptions = {}; } this.weightPathPrefix = loadOptions.weightPathPrefix; this.weightUrlConverter = loadOptions.weightUrlConverter; if (loadOptions.fetchFunc != null) { assert(typeof loadOptions.fetchFunc === 'function', function () { return 'Must pass a function that matches the signature of ' + '`fetch` (see ' + 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'; }); this.fetch = loadOptions.fetchFunc; } else { this.fetch = env().platform.fetch; } assert(path != null && path.length > 0, function () { return 'URL path for http must not be null, undefined or ' + 'empty.'; }); if (Array.isArray(path)) { assert(path.length === 2, function () { return 'URL paths for http must have a length of 2, ' + "(actual length is ".concat(path.length, ")."); }); } this.path = path; if (loadOptions.requestInit != null && loadOptions.requestInit.body != null) { throw new Error('requestInit is expected to have no pre-existing body, but has one.'); } this.requestInit = loadOptions.requestInit || {}; this.loadOptions = loadOptions; } HTTPRequest.prototype.save = function (modelArtifacts) { return __awaiter(this, void 0, void 0, function () { var init, weightsManifest, modelTopologyAndWeightManifest, weightBuffer, response; return __generator(this, function (_a) { switch (_a.label) { case 0: if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' + 'in binary formats yet.'); } init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit); init.body = new FormData(); weightsManifest = [{ paths: ['./model.weights.bin'], weights: modelArtifacts.weightSpecs, }]; modelTopologyAndWeightManifest = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest); init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json'); if (modelArtifacts.weightData != null) { weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); init.body.append('model.weights.bin', new Blob([weightBuffer], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin'); } return [4 /*yield*/, this.fetch(this.path, init)]; case 1: response = _a.sent(); if (response.ok) { return [2 /*return*/, { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts), responses: [response], }]; } else { 
throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + "".concat(response.status, ".")); } } }); }); }; HTTPRequest.prototype.loadModelJSON = function () { return __awaiter(this, void 0, void 0, function () { var modelConfigRequest, modelJSON, message, modelTopology, weightsManifest; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.fetch(this.path, this.requestInit)]; case 1: modelConfigRequest = _a.sent(); if (!modelConfigRequest.ok) { throw new Error("Request to ".concat(this.path, " failed with status code ") + "".concat(modelConfigRequest.status, ". Please verify this URL points to ") + "the model JSON of the model to load."); } _a.label = 2; case 2: _a.trys.push([2, 4, , 5]); return [4 /*yield*/, modelConfigRequest.json()]; case 3: modelJSON = _a.sent(); return [3 /*break*/, 5]; case 4: _a.sent(); message = "Failed to parse model JSON of response from ".concat(this.path, "."); // TODO(nsthorat): Remove this after some time when we're comfortable that // .pb files are mostly gone. if (this.path.endsWith('.pb')) { message += ' Your path contains a .pb file extension. ' + 'Support for .pb models have been removed in TensorFlow.js 1.0 ' + 'in favor of .json models. You can re-convert your Python ' + 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' + 'or you can convert your.pb models with the \'pb2json\'' + 'NPM script in the tensorflow/tfjs-converter repository.'; } else { message += ' Please make sure the server is serving valid ' + 'JSON for this request.'; } throw new Error(message); case 5: modelTopology = modelJSON.modelTopology; weightsManifest = modelJSON.weightsManifest; if (modelTopology == null && weightsManifest == null) { throw new Error("The JSON from HTTP path ".concat(this.path, " contains neither model ") + "topology or manifest for weights."); } return [2 /*return*/, modelJSON]; } }); }); }; /** * Load model artifacts via HTTP request(s). * * See the documentation to `tf.io.http` for details on the saved * artifacts. * * @returns The loaded model artifacts (if loading succeeds). 
*/ HTTPRequest.prototype.load = function () { return __awaiter(this, void 0, void 0, function () { var modelJSON; var _this = this; return __generator(this, function (_a) { switch (_a.label) { case 0: if (this.loadOptions.streamWeights) { return [2 /*return*/, this.loadStream()]; } return [4 /*yield*/, this.loadModelJSON()]; case 1: modelJSON = _a.sent(); return [2 /*return*/, getModelArtifactsForJSON(modelJSON, function (weightsManifest) { return _this.loadWeights(weightsManifest); })]; } }); }); }; HTTPRequest.prototype.loadStream = function () { return __awaiter(this, void 0, void 0, function () { var modelJSON, fetchURLs, weightSpecs, stream; var _this = this; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.loadModelJSON()]; case 1: modelJSON = _a.sent(); return [4 /*yield*/, this.getWeightUrls(modelJSON.weightsManifest)]; case 2: fetchURLs = _a.sent(); weightSpecs = getWeightSpecs(modelJSON.weightsManifest); stream = function () { return streamWeights(fetchURLs, _this.loadOptions); }; return [2 /*return*/, Object.assign(Object.assign({}, modelJSON), { weightSpecs: weightSpecs, getWeightStream: stream })]; } }); }); }; HTTPRequest.prototype.getWeightUrls = function (weightsManifest) { return __awaiter(this, void 0, void 0, function () { var weightPath, _a, prefix, suffix, pathPrefix, fetchURLs, urlPromises, weightsManifest_1, weightsManifest_1_1, weightsGroup, _b, _c, path, _d, _e, _f, _g; var e_2, _h, e_3, _j; return __generator(this, function (_k) { switch (_k.label) { case 0: weightPath = Array.isArray(this.path) ? this.path[1] : this.path; _a = __read(parseUrl(weightPath), 2), prefix = _a[0], suffix = _a[1]; pathPrefix = this.weightPathPrefix || prefix; fetchURLs = []; urlPromises = []; try { for (weightsManifest_1 = __values(weightsManifest), weightsManifest_1_1 = weightsManifest_1.next(); !weightsManifest_1_1.done; weightsManifest_1_1 = weightsManifest_1.next()) { weightsGroup = weightsManifest_1_1.value; try { for (_b = (e_3 = void 0, __values(weightsGroup.paths)), _c = _b.next(); !_c.done; _c = _b.next()) { path = _c.value; if (this.weightUrlConverter != null) { urlPromises.push(this.weightUrlConverter(path)); } else { fetchURLs.push(pathPrefix + path + suffix); } } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (_c && !_c.done && (_j = _b.return)) _j.call(_b); } finally { if (e_3) throw e_3.error; } } } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (weightsManifest_1_1 && !weightsManifest_1_1.done && (_h = weightsManifest_1.return)) _h.call(weightsManifest_1); } finally { if (e_2) throw e_2.error; } } if (!this.weightUrlConverter) return [3 /*break*/, 2]; _e = (_d = fetchURLs.push).apply; _f = [fetchURLs]; _g = [[]]; return [4 /*yield*/, Promise.all(urlPromises)]; case 1: _e.apply(_d, _f.concat([__spreadArray.apply(void 0, _g.concat([__read.apply(void 0, [_k.sent()]), false]))])); _k.label = 2; case 2: return [2 /*return*/, fetchURLs]; } }); }); }; HTTPRequest.prototype.loadWeights = function (weightsManifest) { return __awaiter(this, void 0, void 0, function () { var fetchURLs, weightSpecs, buffers; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.getWeightUrls(weightsManifest)]; case 1: fetchURLs = _a.sent(); weightSpecs = getWeightSpecs(weightsManifest); return [4 /*yield*/, loadWeightsAsArrayBuffer(fetchURLs, this.loadOptions)]; case 2: buffers = _a.sent(); return [2 /*return*/, [weightSpecs, buffers]]; } }); }); }; return HTTPRequest; }()); 
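/**
 * Illustrative sketch of the streaming path above (not part of the library;
 * the URL is a placeholder): when `streamWeights: true` is set in the load
 * options, `load()` resolves to artifacts that expose `getWeightStream()`, a
 * `ReadableStream` of the raw weight bytes, instead of a materialized
 * `weightData` buffer.
 *
 * ```js
 * const handler = tf.io.http(
 *     'https://example.com/model.json', {streamWeights: true});
 * const artifacts = await handler.load();
 * const reader = artifacts.getWeightStream().getReader();
 * for (let r = await reader.read(); !r.done; r = await reader.read()) {
 *   console.log(`received a chunk of ${r.value.byteLength} bytes`);
 * }
 * ```
 */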
HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//; /** * Extract the prefix and suffix of the url, where the prefix is the path before * the last file, and suffix is the search params after the last file. * ``` * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file' * [prefix, suffix] = parseUrl(url) * // prefix = 'http://tfhub.dev/model/1/' * // suffix = '?tfjs-format=file' * ``` * @param url the model url to be parsed. */ function parseUrl(url) { var lastSlash = url.lastIndexOf('/'); var lastSearchParam = url.lastIndexOf('?'); var prefix = url.substring(0, lastSlash); var suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : ''; return [prefix + '/', suffix]; } function isHTTPScheme(url) { return url.match(HTTPRequest.URL_SCHEME_REGEX) != null; } var httpRouter = function (url, loadOptions) { if (typeof fetch === 'undefined' && (loadOptions == null || loadOptions.fetchFunc == null)) { // `http` uses `fetch` or `node-fetch`. If one wants to use it in // an environment that is not the browser or node, they have to set up a // global fetch polyfill. return null; } else { var isHTTP = true; if (Array.isArray(url)) { isHTTP = url.every(function (urlItem) { return isHTTPScheme(urlItem); }); } else { isHTTP = isHTTPScheme(url); } if (isHTTP) { return http(url, loadOptions); } } return null; }; IORouterRegistry.registerSaveRouter(httpRouter); IORouterRegistry.registerLoadRouter(httpRouter); /** * Creates an IOHandler subtype that sends model artifacts to an HTTP server. * * An HTTP request of the `multipart/form-data` mime type will be sent to the * `path` URL. The form data includes artifacts that represent the topology * and/or weights of the model. In the case of Keras-style `tf.Model`, two * blobs (files) exist in form-data: * - A JSON file consisting of `modelTopology` and `weightsManifest`. * - A binary weights file consisting of the concatenated weight values. * These files are in the same format as the one generated by * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html). * * The following code snippet exemplifies the client-side code that uses this * function: * * ```js * const model = tf.sequential(); * model.add( * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); * * const saveResult = await model.save(tf.io.http( * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}})); * console.log(saveResult); * ``` * * If the default `POST` method is to be used, without any custom parameters * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`: * * ```js * const saveResult = await model.save('http://model-server:5000/upload'); * ``` * * The following GitHub Gist * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864 * implements a server based on [flask](https://github.com/pallets/flask) that * can receive the request. Upon receiving the model artifacts via the request, * this particular server reconstitutes instances of [Keras * Models](https://keras.io/models/model/) in memory. * * * @param path A URL path to the model. * Can be an absolute HTTP path (e.g., * 'http://localhost:8000/model-upload') or a relative path (e.g., * './model-upload'). * @param requestInit Request configurations to be used when sending * an HTTP request to the server using `fetch`. It can contain fields such as * `method`, `credentials`, `headers`, `mode`, etc. See * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request * for more information.
`requestInit` must not have a body, because the * body will be set by TensorFlow.js. File blobs representing the model * topology (filename: 'model.json') and the weights of the model (filename: * 'model.weights.bin') will be appended to the body. If `requestInit` has a * `body`, an Error will be thrown. * @param loadOptions Optional configuration for the loading. It includes the * following fields: * - weightPathPrefix Optional, this specifies the path prefix for weight * files, by default this is calculated from the path param. * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js, * the `fetch` from node-fetch can be used here. * - onProgress Optional, progress callback function, fired periodically * before the load is completed. * @returns An instance of `IOHandler`. * * @doc { * heading: 'Models', * subheading: 'Loading', * namespace: 'io', * ignoreCI: true * } */ function http(path, loadOptions) { return new HTTPRequest(path, loadOptions); } /** * Check whether updates.shape = indices.shape[:batchDim] + * shape[sliceDim:] * * @param shape The shape of the output tensor. * @param indices The tensor containing the indices. * @param updates The tensor containing the update values. */ function validateUpdateShape(shape, indices, updates) { var sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1; var batchDim = (indices.rank > 1) ? indices.rank - 1 : 1; var shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' + "shape[sliceDim:], got updates.shape: ".concat(updates.shape) + ", indices.shape: ".concat(indices.shape, ", shape: ").concat(shape) + ", sliceDim: ".concat(sliceDim, ", and batchDim: ").concat(batchDim, "."); if (updates.rank < batchDim) { throw new Error(shapeError + " update.rank < ".concat(batchDim, ". ")); } if (shape.length < sliceDim + (updates.rank - batchDim)) { throw new Error(shapeError + " Output shape length < ".concat(sliceDim + (updates.rank - batchDim))); } if (updates.rank !== batchDim + shape.length - sliceDim) { throw new Error(shapeError + " update.rank != ".concat(batchDim + shape.length - sliceDim)); } for (var d = 0; d < batchDim; ++d) { if (updates.shape[d] !== indices.shape[d]) { throw new Error(shapeError + " updates.shape[".concat(d, "] (").concat(updates.shape[d], ") != indices.shape[").concat(d, "] (").concat(indices.shape[d], ").")); } } for (var d = 0; d < updates.rank - batchDim; ++d) { if (updates.shape[d + batchDim] !== shape[d + sliceDim]) { throw new Error(shapeError + " updates.shape[".concat(d + batchDim, "] (").concat(updates.shape[d + batchDim], ") != shape[").concat(d + batchDim, "] (").concat(shape[d + batchDim], ")")); } } } /** * Validate scatter nd inputs. * * @param updates The tensor containing the update values. * @param indices The tensor containing the indices for the update values. * @param shape The shape of the output tensor. */ function validateInput(updates, indices, shape) { if (indices.rank < 1) { throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' + " but the rank was ".concat(indices.rank, ".")); } if (updates.rank < 1) { throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' + " but the rank was ".concat(updates.rank, ".")); } if (indices.dtype !== 'int32') { throw new Error("The dtype of 'indices' should be int32, but got dtype: ".concat(indices.dtype)); } if (shape.length < 1) { throw new Error("Output rank must be greater than or equal to 1, but got shape: ".concat(shape)); } if (shape.length === 0) { if (indices.size === 0) { throw new Error("Indices specified for empty output. indices shape: ".concat(indices.shape)); } if (updates.size === 0) { throw new Error("Updates specified for empty output. updates shape: ".concat(updates.shape)); } } validateUpdateShape(shape, indices, updates); } function parseSliceParams(x, begin, size) { // The following logic allows for more ergonomic calls. var begin_; var xRank = x.shape.length; if (typeof begin === 'number') { begin_ = __spreadArray([begin], __read(new Array(xRank - 1).fill(0)), false); } else if (begin.length < xRank) { begin_ = begin.concat(new Array(xRank - begin.length).fill(0)); } else { begin_ = begin.slice(); } begin_.forEach(function (d) { assert(d !== -1, function () { return 'slice() does not support negative begin indexing.'; }); }); var size_; if (size == null) { size_ = new Array(xRank).fill(-1); } else if (typeof size === 'number') { size_ = __spreadArray([size], __read(new Array(xRank - 1).fill(-1)), false); } else if (size.length < xRank) { size_ = size.concat(new Array(xRank - size.length).fill(-1)); } else { size_ = size; } size_ = size_.map(function (d, i) { if (d >= 0) { return d; } else { assert(d === -1, function () { return "Negative size values should be exactly -1 but got " + "".concat(d, " for the slice() size at index ").concat(i, "."); }); return x.shape[i] - begin_[i]; } }); return [begin_, size_]; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Serializable defines the serialization contract. * * TFJS requires serializable classes to return their className when asked * to avoid issues with minification. */ var Serializable = /** @class */ (function () { function Serializable() { } /** * Return the class name for this class to use in serialization contexts. * * Generally speaking this will be the same thing that constructor.name * would have returned. However, the class name needs to be robust * against minification for serialization/deserialization to work properly. * * There are also places such as initializers.VarianceScaling, where * implementation details between different languages led to different * class hierarchies and a non-leaf node is used for serialization purposes. */ Serializable.prototype.getClassName = function () { return this.constructor .className; }; /** * Creates an instance of T from a ConfigDict. * * This works for most descendants of serializable. A few need to * provide special handling. * @param cls A Constructor for the class to instantiate. * @param config The Configuration for the object. */ /** @nocollapse */ Serializable.fromConfig = function (cls, config) { return new cls(config); }; return Serializable; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes absolute value element-wise: `abs(x)` * * ```js * const x = tf.tensor1d([-1, 2, -3, 4]); * * x.abs().print(); // or tf.abs(x) * ``` * @param x The input `tf.Tensor`. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function abs_(x) { var $x = convertToTensor(x, 'x', 'abs'); if ($x.dtype === 'complex64') { var inputs = { x: $x }; return ENGINE.runKernel(ComplexAbs, inputs); } else { var inputs = { x: $x }; return ENGINE.runKernel(Abs, inputs); } } var abs = /* @__PURE__ */ op({ abs_: abs_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Concatenates a list of `tf.Tensor`s along a given axis. * * The tensors' ranks and types must match, and their sizes must match in all * dimensions except `axis`. * * Also available are stricter rank-specific methods that assert that * `tensors` are of the given rank: * - `tf.concat1d` * - `tf.concat2d` * - `tf.concat3d` * - `tf.concat4d` * * Except `tf.concat1d` (which does not have axis param), all methods have * the same signature as this method. * * ```js * const a = tf.tensor1d([1, 2]); * const b = tf.tensor1d([3, 4]); * a.concat(b).print(); // or tf.concat([a, b]) * ``` * * ```js * const a = tf.tensor1d([1, 2]); * const b = tf.tensor1d([3, 4]); * const c = tf.tensor1d([5, 6]); * tf.concat([a, b, c]).print(); * ``` * * ```js * const a = tf.tensor2d([[1, 2], [10, 20]]); * const b = tf.tensor2d([[3, 4], [30, 40]]); * const axis = 1; * tf.concat([a, b], axis).print(); * ``` * @param tensors A list of tensors to concatenate. * @param axis The axis to concatenate along. Defaults to 0 (the first dim). * * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function concat_(tensors, axis) { if (axis === void 0) { axis = 0; } assert(tensors.length >= 1, function () { return 'Pass at least one tensor to concat'; }); var $tensors = convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric'); if ($tensors[0].dtype === 'complex64') { $tensors.forEach(function (tensor) { if (tensor.dtype !== 'complex64') { throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype ".concat(tensor.dtype, ". ")); } }); } if ($tensors.length === 1) { return clone($tensors[0]); } var inputs = $tensors; var attr = { axis: axis }; return ENGINE.runKernel(Concat, inputs, attr); } var concat = /* @__PURE__ */ op({ concat_: concat_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes sigmoid element-wise, `1 / (1 + exp(-x))` * * ```js * const x = tf.tensor1d([0, -1, 2, -3]); * * x.sigmoid().print(); // or tf.sigmoid(x) * ``` * @param x The input tensor. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function sigmoid_(x) { var $x = convertToTensor(x, 'x', 'sigmoid', 'float32'); var inputs = { x: $x }; return ENGINE.runKernel(Sigmoid$1, inputs); } var sigmoid = /* @__PURE__ */ op({ sigmoid_: sigmoid_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of * shape `blockShape + [batch]`, interleaves these blocks back into the grid * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with * the same rank as the input. The spatial dimensions of this intermediate * result are then optionally cropped according to `crops` to produce the * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise * description. * * ```js * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]); * const blockShape = [2, 2]; * const crops = [[0, 0], [0, 0]]; * * x.batchToSpaceND(blockShape, crops).print(); * ``` * * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + * remainingShape`, where spatialShape has `M` dimensions. * @param blockShape A 1-D array. Must have shape `[M]`, all values must * be >= 1. * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0. * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]` * * This operation is equivalent to the following steps: * * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ..., * blockShape[M-1], batch / prod(blockShape), x.shape[1], ..., * x.shape[N-1]]` * * 2. Permute dimensions of `reshaped` to produce `permuted` of shape `[batch / * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M], * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]` * * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch / * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] * * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]` * * 4. 
Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted` * according to `crops` to produce the output of shape: `[batch / * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1], * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] - * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]` * * @doc {heading: 'Tensors', subheading: 'Transformations'} */ function batchToSpaceND_(x, blockShape, crops) { var $x = convertToTensor(x, 'x', 'batchToSpaceND'); var prod = blockShape.reduce(function (a, b) { return a * b; }); assert($x.rank >= 1 + blockShape.length, function () { return "input rank is ".concat($x.rank, " but should be greater than blockShape.length ").concat(blockShape.length); }); assert(crops.length === blockShape.length, function () { return "crops.length is ".concat(crops.length, " but should be equal to blockShape.length ").concat(blockShape.length); }); assert($x.shape[0] % prod === 0, function () { return "input tensor batch is ".concat($x.shape[0], " but is not divisible by the product of ") + "the elements of blockShape ".concat(blockShape.join(' * '), " === ").concat(prod); }); var inputs = { x: $x }; var attrs = { blockShape: blockShape, crops: crops }; return ENGINE.runKernel(BatchToSpaceND, inputs, attrs); } var batchToSpaceND = /* @__PURE__ */ op({ batchToSpaceND_: batchToSpaceND_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates a `tf.Tensor` filled with a scalar value. * * ```js * tf.fill([2, 2], 4).print(); * ``` * * @param shape An array of integers defining the output tensor shape. * @param value The scalar value to fill the tensor with. * @param dtype The type of an element in the resulting tensor. Defaults to * 'float32' if the given param value is a number, otherwise 'string'. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function fill(shape, value, dtype) { assertNonNegativeIntegerDimensions(shape); dtype = dtype || inferDtype(value); var attrs = { shape: shape, value: value, dtype: dtype }; return ENGINE.runKernel(Fill, {}, attrs); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
* ============================================================================= */ /** * Computes cos of the input `tf.Tensor` element-wise: `cos(x)` * * ```js * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); * * x.cos().print(); // or tf.cos(x) * ``` * @param x The input tensor. Must be float32 type. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function cos_(x) { var $x = convertToTensor(x, 'x', 'cos', 'float32'); var inputs = { x: $x }; return ENGINE.runKernel(Cos, inputs); } var cos = /* @__PURE__ */ op({ cos_: cos_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)` * * ```js * const x = tf.tensor1d([0, 1, -1, .7]); * * x.cosh().print(); // or tf.cosh(x) * ``` * @param x The input tensor. Must be float32 type. * * @doc {heading: 'Operations', subheading: 'Basic math'} */ function cosh_(x) { var $x = convertToTensor(x, 'x', 'cosh', 'float32'); var inputs = { x: $x }; return ENGINE.runKernel(Cosh, inputs); } var cosh = /* @__PURE__ */ op({ cosh_: cosh_ }); /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the 'License'); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an 'AS IS' BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the cumulative product of a `tf.Tensor` along `axis`. * * ```js * const x = tf.tensor([1, 2, 3, 4]); * x.cumprod().print(); * ``` * ```js * const x = tf.tensor([[1, 2], [3, 4]]); * x.cumprod().print(); * ``` * * @param x The input tensor to cumulatively multiply. * @param axis The axis along which to multiply. Optional. Defaults to 0. * @param exclusive Whether to perform exclusive cumulative product. Optional. * Defaults to false. If set to true then the product of each tensor entry * does not include its own value, but only the values previous to it * along the specified axis. * @param reverse Whether to multiply in the opposite direction. Optional. * Defaults to false. 
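 *
 * A small illustrative sketch of the `exclusive` and `reverse` flags (the
 * expected outputs follow from the definitions above):
 *
 * ```js
 * const x = tf.tensor([1, 2, 3, 4]);
 * x.cumprod(0, true).print(); // exclusive: [1, 1, 2, 6]
 * x.cumprod(0, false, true).print(); // reverse: [24, 24, 12, 4]
 * ```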
* * @doc {heading: 'Operations', subheading: 'Scan'} */ function cumprod_(x, axis, exclusive, reverse) { if (axis === void 0) { axis = 0; } if (exclusive === void 0) { exclusive = false; } if (reverse === void 0) { reverse = false; } var $x = convertToTensor(x, 'x', 'cumprod'); var inputs = { x: $x }; var attrs = { axis: axis, exclusive: exclusive, reverse: reverse }; return ENGINE.runKernel(Cumprod, inputs, attrs); } var cumprod = /* @__PURE__ */ op({ cumprod_: cumprod_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension * into the tensor's shape. * * ```js * const x = tf.tensor1d([1, 2, 3, 4]); * const axis = 1; * x.expandDims(axis).print(); * ``` * * @param x The input tensor whose dimensions are to be expanded. * @param axis The dimension index at which to insert shape of `1`. Defaults * to 0 (the first dimension). * * @doc {heading: 'Tensors', subheading: 'Transformations'} */ function expandDims_(x, axis) { if (axis === void 0) { axis = 0; } var $x = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric'); assert(axis <= $x.rank, function () { return 'Axis must be <= rank of the tensor'; }); var inputs = { input: $x }; var attrs = { dim: axis }; return ENGINE.runKernel(ExpandDims, inputs, attrs); } var expandDims = /* @__PURE__ */ op({ expandDims_: expandDims_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Gather slices from tensor `x`'s axis `axis` according to `indices`. * * ```js * const x = tf.tensor1d([1, 2, 3, 4]); * const indices = tf.tensor1d([1, 3, 3], 'int32'); * * x.gather(indices).print(); * ``` * * ```js * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); * const indices = tf.tensor1d([1, 1, 0], 'int32'); * * x.gather(indices).print(); * ``` * @param x The input tensor whose slices are to be gathered. * @param indices The indices of the values to extract. * @param axis The axis over which to select values. Defaults to 0. * @param batchDims Optional. The number of batch dimensions. It must be less * than or equal to rank(indices). Defaults to 0. 
* The output tensor will have shape of * `x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:]` * * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function gather_(x, indices, axis, batchDims) { if (axis === void 0) { axis = 0; } if (batchDims === void 0) { batchDims = 0; } var $x = convertToTensor(x, 'x', 'gather'); var $indices = convertToTensor(indices, 'indices', 'gather', 'int32'); var inputs = { x: $x, indices: $indices }; var attrs = { axis: axis, batchDims: batchDims }; return ENGINE.runKernel(GatherV2, inputs, attrs); } var gather = /* @__PURE__ */ op({ gather_: gather_ }); /** * Computes and returns the gradient of f(x) with respect to the list of * trainable variables provided by `varList`. If no list is provided, it * defaults to all trainable variables. * * ```js * const a = tf.variable(tf.tensor1d([3, 4])); * const b = tf.variable(tf.tensor1d([5, 6])); * const x = tf.tensor1d([1, 2]); * * // f(a, b) = a * x ^ 2 + b * x * const f = () => a.mul(x.square()).add(b.mul(x)).sum(); * // df/da = x ^ 2, df/db = x * const {value, grads} = tf.variableGrads(f); * * Object.keys(grads).forEach(varName => grads[varName].print()); * ``` * * @param f The function to execute. f() should return a scalar. * @param varList The list of variables to compute the gradients with respect * to. Defaults to all trainable variables. * @returns An object with the following keys and values: * - `value`: The value of the function `f`. * - `grads`: A map from the names of the variables to the gradients. * If the `varList` argument is provided explicitly and contains a subset of * non-trainable variables, this map in the return value will contain keys * that map the names of the non-trainable variables to `null`. * * @doc {heading: 'Training', subheading: 'Gradients'} */ function variableGrads(f, varList) { assert(isFunction(f), function () { return 'The f passed in variableGrads(f) must be a function'; }); assert(varList == null || Array.isArray(varList) && varList.every(function (v) { return v instanceof Variable; }), function () { return 'The varList passed in variableGrads(f, varList) must be an array ' + 'of variables'; }); var specifiedVarList = varList != null; if (!specifiedVarList) { // Get all of the trainable variables. varList = []; for (var varName in ENGINE.registeredVariables) { varList.push(ENGINE.registeredVariables[varName]); } } var specifiedNonTrainable = specifiedVarList ? varList.filter(function (variable) { return !variable.trainable; }) : null; // Prune non-trainable variables. var originalVarCount = varList.length; varList = varList.filter(function (variable) { return variable.trainable; }); assert(varList.length > 0, function () { return "variableGrads() expects at least one of the input variables to " + "be trainable, but none of the ".concat(originalVarCount, " variables is ") + "trainable."; }); var allowNoGradients = true; var _a = ENGINE.gradients(f, varList, null, allowNoGradients), value = _a.value, grads = _a.grads; assert(grads.some(function (g) { return g != null; }), function () { return 'Cannot find a connection between any variable and the result of ' + 'the loss function y=f(x). Please make sure the operations that ' + 'use variables are inside the function f passed to minimize().'; }); assert(value.rank === 0, function () { return "The f passed in variableGrads(f) must return a scalar, but it " + "returned a rank-".concat(value.rank, " tensor"); }); var namedGrads = {}; varList.forEach(function (v, i) { if (grads[i] != null) { namedGrads[v.name] = grads[i]; } }); if (specifiedNonTrainable != null) { // If varList is explicitly provided and contains non-trainable values, // add them to the returned gradients with `null` values. specifiedNonTrainable.forEach(function (v) { return namedGrads[v.name] = null; }); } return { value: value, grads: namedGrads }; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Returns the truth value of `NOT x` element-wise. * * ```js * const a = tf.tensor1d([false, true], 'bool'); * * a.logicalNot().print(); * ``` * * @param x The input tensor. Must be of dtype 'bool'. * * @doc {heading: 'Operations', subheading: 'Logical'} */ function logicalNot_(x) { var $x = convertToTensor(x, 'x', 'logicalNot', 'bool'); var inputs = { x: $x }; return ENGINE.runKernel(LogicalNot, inputs); } var logicalNot = /* @__PURE__ */ op({ logicalNot_: logicalNot_ }); /** * Returns the max of a and b (`a > b ? a : b`) element-wise. * Supports broadcasting. * * We also expose `tf.maximumStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). * * ```js * const a = tf.tensor1d([1, 4, 3, 16]); * const b = tf.tensor1d([1, 2, 9, 4]); * * a.maximum(b).print(); // or tf.maximum(a, b) * ``` * * ```js * // Broadcast maximum a with b. * const a = tf.tensor1d([2, 4, 6, 8]); * const b = tf.scalar(5); * * a.maximum(b).print(); // or tf.maximum(a, b) * ``` * * @param a The first tensor. * @param b The second tensor. Must have the same type as `a`. * * @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function maximum_(a, b) { var _a; var $a = convertToTensor(a, 'a', 'maximum'); var $b = convertToTensor(b, 'b', 'maximum'); _a = __read(makeTypesMatch($a, $b), 2), $a = _a[0], $b = _a[1]; if ($a.dtype === 'bool') { $a = cast($a, 'int32'); $b = cast($b, 'int32'); } assertAndGetBroadcastShape($a.shape, $b.shape); var inputs = { a: $a, b: $b }; return ENGINE.runKernel(Maximum$1, inputs); } var maximum$1 = /* @__PURE__ */ op({ maximum_: maximum_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Pads a `tf.Tensor` with a given value and paddings. * * This operation implements `CONSTANT` mode. For `REFLECT` and `SYMMETRIC`, * refer to `tf.mirrorPad`. * * Also available are stricter rank-specific methods with the same signature * as this method that assert that `paddings` is of the given length. * - `tf.pad1d` * - `tf.pad2d` * - `tf.pad3d` * - `tf.pad4d` * * ```js * const x = tf.tensor1d([1, 2, 3, 4]); * x.pad([[1, 2]]).print(); * ``` * @param x The tensor to pad. * @param paddings An array of length `R` (the rank of the tensor), where * each element is a length-2 tuple of ints `[padBefore, padAfter]`, * specifying how much to pad along each dimension of the tensor. * @param constantValue The pad value to use. Defaults to 0. * * @doc {heading: 'Tensors', subheading: 'Transformations'} */ function pad_(x, paddings, constantValue) { if (constantValue === void 0) { constantValue = 0; } var $x = convertToTensor(x, 'x', 'pad'); if ($x.rank === 0) { throw new Error('pad(scalar) is not defined. Pass non-scalar to pad'); } var attrs = { paddings: paddings, constantValue: constantValue }; var inputs = { x: $x }; return ENGINE.runKernel(PadV2, inputs, attrs); } var pad = /* @__PURE__ */ op({ pad_: pad_ }); var alea$1 = { exports: {} }; (function (module) { // A port of an algorithm by Johannes Baagøe, 2010 // http://baagoe.com/en/RandomMusings/javascript/ // https://github.com/nquinlan/better-random-numbers-for-javascript-mirror // Original work is under MIT license - // Copyright (C) 2010 by Johannes Baagøe // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. (function (global, module, define) { function Alea(seed) { var me = this, mash = Mash(); me.next = function () { var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32 me.s0 = me.s1; me.s1 = me.s2; return me.s2 = t - (me.c = t | 0); }; // Apply the seeding algorithm from Baagoe.
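// (Descriptive note: Mash() hashes each input into a fraction in [0, 1). The
// three state values s0..s2 start from mash(' ') and then have mash(seed)
// subtracted, wrapping back into [0, 1) when negative; c is the carry of the
// multiply-with-carry recurrence implemented by next() above.)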
me.c = 1; me.s0 = mash(' '); me.s1 = mash(' '); me.s2 = mash(' '); me.s0 -= mash(seed); if (me.s0 < 0) { me.s0 += 1; } me.s1 -= mash(seed); if (me.s1 < 0) { me.s1 += 1; } me.s2 -= mash(seed); if (me.s2 < 0) { me.s2 += 1; } mash = null; } function copy(f, t) { t.c = f.c; t.s0 = f.s0; t.s1 = f.s1; t.s2 = f.s2; return t; } function impl(seed, opts) { var xg = new Alea(seed), state = opts && opts.state, prng = xg.next; prng.int32 = function () { return (xg.next() * 0x100000000) | 0; }; prng.double = function () { return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53 }; prng.quick = prng; if (state) { if (typeof (state) == 'object') copy(state, xg); prng.state = function () { return copy(xg, {}); }; } return prng; } function Mash() { var n = 0xefc8249d; var mash = function (data) { data = String(data); for (var i = 0; i < data.length; i++) { n += data.charCodeAt(i); var h = 0.02519603282416938 * n; n = h >>> 0; h -= n; h *= n; n = h >>> 0; h -= n; n += h * 0x100000000; // 2^32 } return (n >>> 0) * 2.3283064365386963e-10; // 2^-32 }; return mash; } if (module && module.exports) { module.exports = impl; } else if (define && define.amd) { define(function () { return impl; }); } else { this.alea = impl; } })(commonjsGlobal, module, // present in node.js (typeof undefined) == 'function' // present with an AMD loader ); }(alea$1)); var aleaExports = alea$1.exports; var xor128$1 = { exports: {} }; (function (module) { // A Javascript implementation of the "xor128" prng algorithm by // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper (function (global, module, define) { function XorGen(seed) { var me = this, strseed = ''; me.x = 0; me.y = 0; me.z = 0; me.w = 0; // Set up generator function. me.next = function () { var t = me.x ^ (me.x << 11); me.x = me.y; me.y = me.z; me.z = me.w; return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8); }; if (seed === (seed | 0)) { // Integer seed. me.x = seed; } else { // String seed. strseed += seed; } // Mix in string seed, then discard an initial batch of 64 values. for (var k = 0; k < strseed.length + 64; k++) { me.x ^= strseed.charCodeAt(k) | 0; me.next(); } } function copy(f, t) { t.x = f.x; t.y = f.y; t.z = f.z; t.w = f.w; return t; } function impl(seed, opts) { var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; }; prng.double = function () { do { var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21); } while (result === 0); return result; }; prng.int32 = xg.next; prng.quick = prng; if (state) { if (typeof (state) == 'object') copy(state, xg); prng.state = function () { return copy(xg, {}); }; } return prng; } if (module && module.exports) { module.exports = impl; } else if (define && define.amd) { define(function () { return impl; }); } else { this.xor128 = impl; } })(commonjsGlobal, module, // present in node.js (typeof undefined) == 'function' // present with an AMD loader ); }(xor128$1)); var xor128Exports = xor128$1.exports; var xorwow$1 = { exports: {} }; (function (module) { // A Javascript implementation of the "xorwow" prng algorithm by // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper (function (global, module, define) { function XorGen(seed) { var me = this, strseed = ''; // Set up generator function.
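// (Descriptive note: xorwow combines a five-word xorshift state (x, y, z, w,
// v) with a Weyl sequence counter d that advances by 362437 on every call;
// next() below returns d + v mod 2^32.)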
me.next = function () { var t = (me.x ^ (me.x >>> 2)); me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v; return (me.d = (me.d + 362437 | 0)) + (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0; }; me.x = 0; me.y = 0; me.z = 0; me.w = 0; me.v = 0; if (seed === (seed | 0)) { // Integer seed. me.x = seed; } else { // String seed. strseed += seed; } // Mix in string seed, then discard an initial batch of 64 values. for (var k = 0; k < strseed.length + 64; k++) { me.x ^= strseed.charCodeAt(k) | 0; if (k == strseed.length) { me.d = me.x << 10 ^ me.x >>> 4; } me.next(); } } function copy(f, t) { t.x = f.x; t.y = f.y; t.z = f.z; t.w = f.w; t.v = f.v; t.d = f.d; return t; } function impl(seed, opts) { var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; }; prng.double = function () { do { var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21); } while (result === 0); return result; }; prng.int32 = xg.next; prng.quick = prng; if (state) { if (typeof (state) == 'object') copy(state, xg); prng.state = function () { return copy(xg, {}); }; } return prng; } if (module && module.exports) { module.exports = impl; } else if (define && define.amd) { define(function () { return impl; }); } else { this.xorwow = impl; } })(commonjsGlobal, module, // present in node.js (typeof undefined) == 'function' // present with an AMD loader ); }(xorwow$1)); var xorwowExports = xorwow$1.exports; var xorshift7$1 = { exports: {} }; (function (module) { // A Javascript implementation of the "xorshift7" algorithm by // François Panneton and Pierre L'Ecuyer: // "On the Xorshift Random Number Generators" // http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf (function (global, module, define) { function XorGen(seed) { var me = this; // Set up generator function. me.next = function () { // Update xor generator. var X = me.x, i = me.i, t, v; t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24); t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10); t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3); t = X[(i + 4) & 7]; v ^= t ^ (t << 7); t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9); X[i] = v; me.i = (i + 1) & 7; return v; }; function init(me, seed) { var j, X = []; if (seed === (seed | 0)) { // Seed state array using a 32-bit integer. X[0] = seed; } else { // Seed state using a string. seed = '' + seed; for (j = 0; j < seed.length; ++j) { X[j & 7] = (X[j & 7] << 15) ^ (seed.charCodeAt(j) + X[(j + 1) & 7] << 13); } } // Enforce an array length of 8, not all zeroes. while (X.length < 8) X.push(0); for (j = 0; j < 8 && X[j] === 0; ++j) ; if (j == 8) X[7] = -1; else X[j]; me.x = X; me.i = 0; // Discard an initial 256 values.
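// (Descriptive note: an all-zero state array would be a fixed point of the
// xorshift recurrence, which is why X[7] is forced to -1 above when every
// lane is zero; the warm-up loop below then mixes the seeded lanes before
// any output is consumed.)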
for (j = 256; j > 0; --j) { me.next(); } } init(me, seed); } function copy(f, t) { t.x = f.x.slice(); t.i = f.i; return t; } function impl(seed, opts) { if (seed == null) seed = +(new Date); var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; }; prng.double = function () { do { var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21); } while (result === 0); return result; }; prng.int32 = xg.next; prng.quick = prng; if (state) { if (state.x) copy(state, xg); prng.state = function () { return copy(xg, {}); }; } return prng; } if (module && module.exports) { module.exports = impl; } else if (define && define.amd) { define(function () { return impl; }); } else { this.xorshift7 = impl; } })(commonjsGlobal, module, // present in node.js (typeof undefined) == 'function' // present with an AMD loader ); }(xorshift7$1)); var xorshift7Exports = xorshift7$1.exports; var xor4096$1 = { exports: {} }; (function (module) { // A Javascript implementation of Richard Brent's Xorgens xor4096 algorithm. // // This fast non-cryptographic random number generator is designed for // use in Monte-Carlo algorithms. It combines a long-period xorshift // generator with a Weyl generator, and it passes all common batteries // of statistical tests for randomness while consuming only a few nanoseconds // for each value generated. For background on the generator, see Brent's // paper: "Some long-period random number generators using shifts and xors." // http://arxiv.org/pdf/1004.3115v1.pdf // // Usage: // // var xor4096 = require('xor4096'); // random = xor4096(1); // Seed with int32 or string. // assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits. // assert.equal(random.int32(), 1806534897); // signed int32, 32 bits. // // For nonzero numeric keys, this implementation provides a sequence // identical to that of Brent's xorgens 3 implementation in C. This // implementation also provides for initializing the generator with // string seeds, or for saving and restoring the state of the generator. // // On Chrome, this prng benchmarks about 2.1 times slower than // Javascript's built-in Math.random(). (function (global, module, define) { function XorGen(seed) { var me = this; // Set up generator function. me.next = function () { var w = me.w, X = me.X, i = me.i, t, v; // Update Weyl generator. me.w = w = (w + 0x61c88647) | 0; // Update xor generator. v = X[(i + 34) & 127]; t = X[i = ((i + 1) & 127)]; v ^= v << 13; t ^= t << 17; v ^= v >>> 15; t ^= t >>> 12; // Update Xor generator array state. v = X[i] = v ^ t; me.i = i; // Result is the combination. return (v + (w ^ (w >>> 16))) | 0; }; function init(me, seed) { var t, v, i, j, w, X = [], limit = 128; if (seed === (seed | 0)) { // Numeric seeds initialize v, which is used to generate X. v = seed; seed = null; } else { // String seeds are mixed into v and X one character at a time. seed = seed + '\0'; v = 0; limit = Math.max(limit, seed.length); } // Initialize circular array and weyl value. for (i = 0, j = -32; j < limit; ++j) { // Put the unicode characters into the array, and shuffle them. if (seed) v ^= seed.charCodeAt((j + 32) % seed.length); // After 32 shuffles, take v as the starting w value. if (j === 0) w = v; v ^= v << 10; v ^= v >>> 15; v ^= v << 4; v ^= v >>> 13; if (j >= 0) { w = (w + 0x61c88647) | 0; // Weyl. t = (X[j & 127] ^= (v + w)); // Combine xor and weyl to init array. i = (0 == t) ? i + 1 : 0; // Count zeroes.
} } // We have detected all zeroes; make the key nonzero. if (i >= 128) { X[(seed && seed.length || 0) & 127] = -1; } // Run the generator 512 times to further mix the state before using it. // Factoring this as a function slows the main generator, so it is just // unrolled here. The weyl generator is not advanced while warming up. i = 127; for (j = 4 * 128; j > 0; --j) { v = X[(i + 34) & 127]; t = X[i = ((i + 1) & 127)]; v ^= v << 13; t ^= t << 17; v ^= v >>> 15; t ^= t >>> 12; X[i] = v ^ t; } // Storing state as object members is faster than using closure variables. me.w = w; me.X = X; me.i = i; } init(me, seed); } function copy(f, t) { t.i = f.i; t.w = f.w; t.X = f.X.slice(); return t; } function impl(seed, opts) { if (seed == null) seed = +(new Date); var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; }; prng.double = function () { do { var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21); } while (result === 0); return result; }; prng.int32 = xg.next; prng.quick = prng; if (state) { if (state.X) copy(state, xg); prng.state = function () { return copy(xg, {}); }; } return prng; } if (module && module.exports) { module.exports = impl; } else if (define && define.amd) { define(function () { return impl; }); } else { this.xor4096 = impl; } })(commonjsGlobal, // window object or global module, // present in node.js (typeof undefined) == 'function' // present with an AMD loader ); }(xor4096$1)); var xor4096Exports = xor4096$1.exports; var tychei$1 = { exports: {} }; (function (module) { // A Javascript implementation of the "Tyche-i" prng algorithm by // Samuel Neves and Filipe Araujo. // See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf (function (global, module, define) { function XorGen(seed) { var me = this, strseed = ''; // Set up generator function. me.next = function () { var b = me.b, c = me.c, d = me.d, a = me.a; b = (b << 25) ^ (b >>> 7) ^ c; c = (c - d) | 0; d = (d << 24) ^ (d >>> 8) ^ a; a = (a - b) | 0; me.b = b = (b << 20) ^ (b >>> 12) ^ c; me.c = c = (c - d) | 0; me.d = (d << 16) ^ (c >>> 16) ^ a; return me.a = (a - b) | 0; }; /* The following is non-inverted tyche, which has better internal * bit diffusion, but which is about 25% slower than tyche-i in JS. me.next = function() { var a = me.a, b = me.b, c = me.c, d = me.d; a = (me.a + me.b | 0) >>> 0; d = me.d ^ a; d = d << 16 ^ d >>> 16; c = me.c + d | 0; b = me.b ^ c; b = b << 12 ^ d >>> 20; me.a = a = a + b | 0; d = d ^ a; me.d = d = d << 8 ^ d >>> 24; me.c = c = c + d | 0; b = b ^ c; return me.b = (b << 7 ^ b >>> 25); } */ me.a = 0; me.b = 0; me.c = 2654435769 | 0; me.d = 1367130551; if (seed === Math.floor(seed)) { // Integer seed. me.a = (seed / 0x100000000) | 0; me.b = seed | 0; } else { // String seed. strseed += seed; } // Mix in string seed, then discard an initial batch of 20 values.
for (var k = 0; k < strseed.length + 20; k++) { me.b ^= strseed.charCodeAt(k) | 0; me.next(); } } function copy(f, t) { t.a = f.a; t.b = f.b; t.c = f.c; t.d = f.d; return t; } function impl(seed, opts) { var xg = new XorGen(seed), state = opts && opts.state, prng = function () { return (xg.next() >>> 0) / 0x100000000; }; prng.double = function () { do { var top = xg.next() >>> 11, bot = (xg.next() >>> 0) / 0x100000000, result = (top + bot) / (1 << 21); } while (result === 0); return result; }; prng.int32 = xg.next; prng.quick = prng; if (state) { if (typeof (state) == 'object') copy(state, xg); prng.state = function () { return copy(xg, {}); }; } return prng; } if (module && module.exports) { module.exports = impl; } else if (define && define.amd) { define(function () { return impl; }); } else { this.tychei = impl; } })(commonjsGlobal, module, // present in node.js (typeof undefined) == 'function' // present with an AMD loader ); }(tychei$1)); var tycheiExports = tychei$1.exports; var seedrandom = { exports: {} }; var _nodeResolve_empty = {}; var _nodeResolve_empty$1 = { __proto__: null, default: _nodeResolve_empty }; var require$$0 = /*@__PURE__*/ getAugmentedNamespace(_nodeResolve_empty$1); /* Copyright 2019 David Bau. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ (function (module) { (function (global, pool, math) { // // The following constants are related to IEEE 754 limits. // var width = 256, // each RC4 output is 0 <= x < 256 chunks = 6, // at least six RC4 outputs for each double digits = 52, // there are 52 significant digits in a double rngname = 'random', // rngname: name for Math.random and Math.seedrandom startdenom = math.pow(width, chunks), significance = math.pow(2, digits), overflow = significance * 2, mask = width - 1, nodecrypto; // node.js crypto module, initialized at the bottom. // // seedrandom() // This is the seedrandom function described above. // function seedrandom(seed, options, callback) { var key = []; options = (options == true) ? { entropy: true } : (options || {}); // Flatten the seed string or build one from local entropy if needed. var shortseed = mixkey(flatten(options.entropy ? [seed, tostring(pool)] : (seed == null) ? autoseed() : seed, 3), key); // Use the seed to initialize an ARC4 generator. var arc4 = new ARC4(key); // This function returns a random double in [0, 1) that contains // randomness in every bit of the mantissa of the IEEE 754 value. var prng = function () { var n = arc4.g(chunks), // Start with a numerator n < 2 ^ 48 d = startdenom, // and denominator d = 2 ^ 48. 
x = 0; // and no 'extra last byte'. while (n < significance) { // Fill up all significant digits by n = (n + x) * width; // shifting numerator and d *= width; // denominator and generating a x = arc4.g(1); // new least-significant-byte. } while (n >= overflow) { // To avoid rounding up, before adding n /= 2; // last byte, shift everything d /= 2; // right using integer math until x >>>= 1; // we have exactly the desired bits. } return (n + x) / d; // Form the number within [0, 1). }; prng.int32 = function () { return arc4.g(4) | 0; }; prng.quick = function () { return arc4.g(4) / 0x100000000; }; prng.double = prng; // Mix the randomness into accumulated entropy. mixkey(tostring(arc4.S), pool); // Calling convention: what to return as a function of prng, seed, is_math. return (options.pass || callback || function (prng, seed, is_math_call, state) { if (state) { // Load the arc4 state from the given state if it has an S array. if (state.S) { copy(state, arc4); } // Only provide the .state method if requested via options.state. prng.state = function () { return copy(arc4, {}); }; } // If called as a method of Math (Math.seedrandom()), mutate // Math.random because that is how seedrandom.js has worked since v1.0. if (is_math_call) { math[rngname] = prng; return seed; } // Otherwise, it is a newer calling convention, so return the // prng directly. else return prng; })(prng, shortseed, 'global' in options ? options.global : (this == math), options.state); } // // ARC4 // // An ARC4 implementation. The constructor takes a key in the form of // an array of at most (width) integers that should be 0 <= x < (width). // // The g(count) method returns a pseudorandom integer that concatenates // the next (count) outputs from ARC4. Its return value is a number x // that is in the range 0 <= x < (width ^ count). // function ARC4(key) { var t, keylen = key.length, me = this, i = 0, j = me.i = me.j = 0, s = me.S = []; // The empty key [] is treated as [0]. if (!keylen) { key = [keylen++]; } // Set up S using the standard key scheduling algorithm. while (i < width) { s[i] = i++; } for (i = 0; i < width; i++) { s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))]; s[j] = t; } // The "g" method returns the next (count) outputs as one number. (me.g = function (count) { // Using instance members instead of closure state nearly doubles speed. var t, r = 0, i = me.i, j = me.j, s = me.S; while (count--) { t = s[i = mask & (i + 1)]; r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))]; } me.i = i; me.j = j; return r; // For robust unpredictability, the function call below automatically // discards an initial batch of values. This is called RC4-drop[256]. // See http://google.com/search?q=rsa+fluhrer+response&btnI })(width); } // // copy() // Copies internal state of ARC4 to or from a plain object. // function copy(f, t) { t.i = f.i; t.j = f.j; t.S = f.S.slice(); return t; } // // flatten() // Converts an object tree to nested arrays of strings. // function flatten(obj, depth) { var result = [], typ = (typeof obj), prop; if (depth && typ == 'object') { for (prop in obj) { try { result.push(flatten(obj[prop], depth - 1)); } catch (e) { } } } return (result.length ? result : typ == 'string' ? obj : obj + '\0'); } // // mixkey() // Mixes a string seed into a key that is an array of integers, and // returns a shortened string seed that is equivalent to the result key. 
// function mixkey(seed, key) { var stringseed = seed + '', smear, j = 0; while (j < stringseed.length) { key[mask & j] = mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++)); } return tostring(key); } // // autoseed() // Returns an object for autoseeding, using window.crypto and Node crypto // module if available. // function autoseed() { try { var out; if (nodecrypto && (out = nodecrypto.randomBytes)) { // The use of 'out' to remember randomBytes makes tight minified code. out = out(width); } else { out = new Uint8Array(width); (global.crypto || global.msCrypto).getRandomValues(out); } return tostring(out); } catch (e) { var browser = global.navigator, plugins = browser && browser.plugins; return [+new Date, global, plugins, global.screen, tostring(pool)]; } } // // tostring() // Converts an array of charcodes to a string // function tostring(a) { return String.fromCharCode.apply(0, a); } // // When seedrandom.js is loaded, we immediately mix a few bits // from the built-in RNG into the entropy pool. Because we do // not want to interfere with deterministic PRNG state later, // seedrandom will not call math.random on its own again after // initialization. // mixkey(math.random(), pool); // // Nodejs and AMD support: export the implementation as a module using // either convention. // if (module.exports) { module.exports = seedrandom; // When in node.js, try using crypto package for autoseeding. try { nodecrypto = require$$0; } catch (ex) { } } else { // When included as a plain script, set up Math.seedrandom global. math['seed' + rngname] = seedrandom; } // End anonymous scope, and pass initial values. })( // global: `self` in browsers (including strict mode and web workers), // otherwise `this` in Node and other environments (typeof self !== 'undefined') ? self : commonjsGlobal, [], // pool: entropy pool starts empty Math // math: package containing random, pow, and seedrandom ); }(seedrandom)); var seedrandomExports = seedrandom.exports; // A library of seedable RNGs implemented in Javascript. // // Usage: // // var seedrandom = require('seedrandom'); // var random = seedrandom(1); // or any seed. // var x = random(); // 0 <= x < 1. Every bit is random. // var x = random.quick(); // 0 <= x < 1. 32 bits of randomness. // alea, a 53-bit multiply-with-carry generator by Johannes Baagøe. // Period: ~2^116 // Reported to pass all BigCrush tests. var alea = aleaExports; // xor128, a pure xor-shift generator by George Marsaglia. // Period: 2^128-1. // Reported to fail: MatrixRank and LinearComp. var xor128 = xor128Exports; // xorwow, George Marsaglia's 160-bit xor-shift combined plus Weyl. // Period: 2^192-2^32 // Reported to fail: CollisionOver, SimpPoker, and LinearComp. var xorwow = xorwowExports; // xorshift7, by François Panneton and Pierre L'Ecuyer, takes // a different approach: it adds robustness by allowing more shifts // than Marsaglia's original three. It is a 7-shift generator // with 256 bits that passes BigCrush with no systematic failures. // Period 2^256-1. // No systematic BigCrush failures reported. var xorshift7 = xorshift7Exports; // xor4096, by Richard Brent, is a 4096-bit xor-shift with a // very long period that also adds a Weyl generator. It also passes // BigCrush with no systematic failures. Its long period may // be useful if you have many generators and need to avoid // collisions. // Period: 2^4128-2^32. // No systematic BigCrush failures reported.
var xor4096 = xor4096Exports; // Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random // number generator derived from ChaCha, a modern stream cipher. // https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf // Period: ~2^127 // No systematic BigCrush failures reported. var tychei = tycheiExports; // The original ARC4-based prng included in this library. // Period: ~2^1600 var sr = seedrandomExports; sr.alea = alea; sr.xor128 = xor128; sr.xorwow = xorwow; sr.xorshift7 = xorshift7; sr.xor4096 = xor4096; sr.tychei = tychei; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Reverses a `tf.Tensor` along a specified axis. * * Also available are stricter rank-specific methods that assert that `x` is * of the given rank: * - `tf.reverse1d` * - `tf.reverse2d` * - `tf.reverse3d` * - `tf.reverse4d` * * Except `tf.reverse1d` (which does not have an axis param), all methods have * the same signature as this method. * * ```js * const x = tf.tensor1d([1, 2, 3, 4]); * * x.reverse().print(); * ``` * * ```js * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); * * const axis = 1; * x.reverse(axis).print(); * ``` * @param x The input tensor to be reversed. * @param axis The set of dimensions to reverse. Must be in the * range [-rank(x), rank(x)). Defaults to all axes. * * @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function reverse_(x, axis) { var $x = convertToTensor(x, 'x', 'reverse'); var inputs = { x: $x }; var attrs = { dims: axis }; return ENGINE.runKernel(Reverse, inputs, attrs); } var reverse = /* @__PURE__ */ op({ reverse_: reverse_ }); /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates a new tensor by applying sparse updates to individual * values or slices of the passed-in tensor according to * indices. This operator is similar to the scatterNd op, except that the * updates are scattered on an existing tensor (as opposed to a zero-tensor). * * If indices contains duplicates, then we pick the last update for the index. * * If an out of bound index is found on CPU, an error is returned. * * Warning: There are some GPU specific semantics for this operation. * - If an out of bound index is found, the index is ignored.
* - The order in which updates are applied is nondeterministic, so the output * will be nondeterministic if indices contains duplicates. * ```js * const shape = [8]; * const tensor = tf.ones(shape); * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32'); * const updates = tf.tensor1d([9, 10, 11, 12]); * * tf.tensorScatterUpdate(tensor, indices, updates).print(); * // [1, 11, 1, 10, 9, 1, 1, 12] * ``` * * @param tensor A Tensor. Tensor to copy/update. * @param indices A tensor containing the indices into the output tensor; it * must have at least 2 axes: (num_updates, index_depth). * @param updates A tensor containing the values for the indices. * * @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */ function tensorScatterUpdate_(tensor, indices, updates) { var $tensor = convertToTensor(tensor, 'tensor', 'tensorScatterUpdate'); var $indices = convertToTensor(indices, 'indices', 'tensorScatterUpdate', 'int32'); var $updates = convertToTensor(updates, 'updates', 'tensorScatterUpdate'); validateInput($updates, $indices, $tensor.shape); if ($tensor.dtype !== $updates.dtype) { throw new Error("tensor and updates must have the same dtype, instead they are ".concat($tensor.dtype, " and ").concat($updates.dtype, ".")); } var inputs = { tensor: $tensor, indices: $indices, updates: $updates }; var attrs = {}; // tslint:disable-next-line: no-unnecessary-type-assertion return ENGINE.runKernel(TensorScatterUpdate, inputs, attrs); } op({ tensorScatterUpdate_: tensorScatterUpdate_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var Reduction; (function (Reduction) { Reduction[Reduction["NONE"] = 0] = "NONE"; Reduction[Reduction["MEAN"] = 1] = "MEAN"; Reduction[Reduction["SUM"] = 2] = "SUM"; Reduction[Reduction["SUM_BY_NONZERO_WEIGHTS"] = 3] = "SUM_BY_NONZERO_WEIGHTS"; })(Reduction || (Reduction = {})); /** @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'} */ var Optimizer = /** @class */ (function (_super) { __extends(Optimizer, _super); function Optimizer() { return _super !== null && _super.apply(this, arguments) || this; } /** * Executes `f()` and minimizes the scalar output of `f()` by computing * gradients of y with respect to the list of trainable variables provided by * `varList`. If no list is provided, it defaults to all trainable variables. * * @param f The function to execute and whose output to minimize. * @param returnCost Whether to return the scalar cost value produced by * executing `f()`. * @param varList An optional list of variables to update. If specified, only * the trainable variables in varList will be updated by minimize. Defaults to * all trainable variables.
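* * A hedged usage sketch, not from the original docs: it assumes a full TensorFlow.js bundle exposing `tf.variable`, `tf.scalar` and `tf.train.sgd`, none of which are defined in this file. * * ```js * const x = tf.variable(tf.scalar(3)); * const opt = tf.train.sgd(0.1); * // One minimize() step on f() = x*x nudges x toward 0; the cost from the * // forward pass is returned because returnCost is true. * const cost = opt.minimize(() => x.mul(x), true); * cost.print(); * x.print(); * ```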
* * @doc {heading: 'Training', subheading: 'Optimizers'} */ Optimizer.prototype.minimize = function (f, returnCost, varList) { if (returnCost === void 0) { returnCost = false; } var _a = this.computeGradients(f, varList), value = _a.value, grads = _a.grads; if (varList != null) { var gradArray = varList.map(function (v) { return ({ name: v.name, tensor: grads[v.name] }); }); this.applyGradients(gradArray); } else { this.applyGradients(grads); } // Dispose gradients. dispose(grads); if (returnCost) { return value; } else { value.dispose(); return null; } }; Object.defineProperty(Optimizer.prototype, "iterations", { /** * The number of iterations that this optimizer instance has been invoked for. */ get: function () { if (this.iterations_ == null) { this.iterations_ = 0; } return this.iterations_; }, enumerable: false, configurable: true }); Optimizer.prototype.incrementIterations = function () { this.iterations_ = this.iterations + 1; }; /** * Executes f() and computes the gradient of the scalar output of f() with * respect to the list of trainable variables provided by `varList`. If no * list is provided, it defaults to all trainable variables. * * @param f The function to execute and whose output to use for computing * gradients with respect to variables. * @param varList An optional list of variables to compute gradients with * respect to. If specified, only the trainable variables in varList will have * gradients computed with respect to. Defaults to all trainable variables. * * @doc {heading: 'Training', subheading: 'Optimizers'} */ Optimizer.prototype.computeGradients = function (f, varList) { return variableGrads(f, varList); }; /** * Dispose the variables (if any) owned by this optimizer instance. */ Optimizer.prototype.dispose = function () { if (this.iterations_ != null) { dispose(this.iterations_); } }; Optimizer.prototype.saveIterations = function () { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { if (this.iterations_ == null) { this.iterations_ = 0; } return [2 /*return*/, { name: 'iter', // TODO(cais): Use 'int64' type when available. tensor: scalar(this.iterations_, 'int32') }]; }); }); }; Optimizer.prototype.getWeights = function () { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { throw new Error('getWeights() is not implemented for this optimizer yet.'); }); }); }; Optimizer.prototype.setWeights = function (weightValues) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { throw new Error("setWeights() is not implemented for this optimizer class " + "".concat(this.getClassName())); }); }); }; /** * Extract the first element of the weight values and set it * as the iterations counter variable of this instance of optimizer. * * @param weightValues * @returns Weight values with the first element consumed and excluded. 
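* * A hedged sketch of the round trip this enables; `opt` and `restored` are illustrative names for two optimizer instances of the same subclass: * * ```js * // getWeights() prepends the iteration counter entry; setWeights() * // consumes it again through extractIterations() before restoring the * // per-variable slots. * const weights = await opt.getWeights(); * await restored.setWeights(weights); * ```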
*/ Optimizer.prototype.extractIterations = function (weightValues) { return __awaiter(this, void 0, void 0, function () { var _a; return __generator(this, function (_b) { switch (_b.label) { case 0: _a = this; return [4 /*yield*/, weightValues[0].tensor.data()]; case 1: _a.iterations_ = (_b.sent())[0]; return [2 /*return*/, weightValues.slice(1)]; } }); }); }; return Optimizer; }(Serializable)); Object.defineProperty(Optimizer, Symbol.hasInstance, { value: function (instance) { return instance.minimize != null && instance.computeGradients != null && instance.applyGradients != null; } }); /** @doclink Optimizer */ /** @class */ ((function (_super) { __extends(AdadeltaOptimizer, _super); function AdadeltaOptimizer(learningRate, rho, epsilon) { if (epsilon === void 0) { epsilon = null; } var _this = _super.call(this) || this; _this.learningRate = learningRate; _this.rho = rho; _this.epsilon = epsilon; _this.accumulatedGrads = []; _this.accumulatedUpdates = []; if (epsilon == null) { _this.epsilon = ENGINE.backend.epsilon(); } return _this; } Object.defineProperty(AdadeltaOptimizer, "className", { /** @nocollapse */ get: function () { // Name matters for Python compatibility. // This is a getter instead of a property because when it's a property, it // prevents the entire class from being tree-shaken. return 'Adadelta'; }, enumerable: false, configurable: true }); AdadeltaOptimizer.prototype.applyGradients = function (variableGradients) { var _this = this; var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) { return item.name; }) : Object.keys(variableGradients); variableNames.forEach(function (name, i) { var value = ENGINE.registeredVariables[name]; var trainable = false; if (_this.accumulatedGrads[i] == null) { _this.accumulatedGrads[i] = { originalName: "".concat(name, "/accum_grad"), variable: tidy(function () { return zerosLike(value).variable(trainable); }) }; } if (_this.accumulatedUpdates[i] == null) { _this.accumulatedUpdates[i] = { originalName: "".concat(name, "/accum_var"), variable: tidy(function () { return zerosLike(value).variable(trainable); }) }; } var gradient = Array.isArray(variableGradients) ? 
variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } var accumulatedGrad = _this.accumulatedGrads[i].variable; var accumulatedUpdate = _this.accumulatedUpdates[i].variable; tidy(function () { var newAccumulatedGrad = add$1(mul(accumulatedGrad, _this.rho), mul(square(gradient), 1 - _this.rho)); var updates = mul(div(sqrt(add$1(accumulatedUpdate, _this.epsilon)), sqrt(add$1(accumulatedGrad, _this.epsilon))), gradient); var newAccumulatedUpdate = add$1(mul(accumulatedUpdate, _this.rho), mul(square(updates), 1 - _this.rho)); accumulatedGrad.assign(newAccumulatedGrad); accumulatedUpdate.assign(newAccumulatedUpdate); var newValue = add$1(mul(updates, -_this.learningRate), value); value.assign(newValue); }); }); this.incrementIterations(); }; AdadeltaOptimizer.prototype.dispose = function () { if (this.accumulatedUpdates != null) { dispose(this.accumulatedGrads.map(function (v) { return v.variable; })); dispose(this.accumulatedUpdates.map(function (v) { return v.variable; })); } }; AdadeltaOptimizer.prototype.getWeights = function () { return __awaiter(this, void 0, void 0, function () { var variables; return __generator(this, function (_a) { switch (_a.label) { case 0: variables = __spreadArray(__spreadArray([], __read(this.accumulatedGrads), false), __read(this.accumulatedUpdates), false); return [4 /*yield*/, this.saveIterations()]; case 1: return [2 /*return*/, [_a.sent()].concat(variables.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))]; } }); }); }; AdadeltaOptimizer.prototype.setWeights = function (weightValues) { return __awaiter(this, void 0, void 0, function () { var variableCount, trainable; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.extractIterations(weightValues)]; case 1: weightValues = _a.sent(); variableCount = weightValues.length / 2; trainable = false; this.accumulatedGrads = weightValues.slice(0, variableCount).map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); this.accumulatedUpdates = weightValues.slice(variableCount, variableCount * 2) .map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); return [2 /*return*/]; } }); }); }; AdadeltaOptimizer.prototype.getConfig = function () { return { 'learningRate': this.learningRate, 'rho': this.rho, 'epsilon': this.epsilon }; }; /** @nocollapse */ AdadeltaOptimizer.fromConfig = function (cls, config) { return new cls(config['learningRate'], config['rho'], config['epsilon']); }; return AdadeltaOptimizer; })(Optimizer)); /** @doclink Optimizer */ /** @class */ ((function (_super) { __extends(AdagradOptimizer, _super); function AdagradOptimizer(learningRate, initialAccumulatorValue) { if (initialAccumulatorValue === void 0) { initialAccumulatorValue = 0.1; } var _this = _super.call(this) || this; _this.learningRate = learningRate; _this.initialAccumulatorValue = initialAccumulatorValue; _this.accumulatedGrads = []; return _this; } Object.defineProperty(AdagradOptimizer, "className", { /** @nocollapse */ get: function () { // Name matters for Python compatibility. // This is a getter instead of a property because when it's a property, it // prevents the entire class from being tree-shaken. return 'Adagrad'; }, enumerable: false, configurable: true }); AdagradOptimizer.prototype.applyGradients = function (variableGradients) { var _this = this; var variableNames = Array.isArray(variableGradients) ? 
variableGradients.map(function (item) { return item.name; }) : Object.keys(variableGradients); variableNames.forEach(function (name, i) { var value = ENGINE.registeredVariables[name]; if (_this.accumulatedGrads[i] == null) { var trainable_1 = false; _this.accumulatedGrads[i] = { originalName: "".concat(name, "/accumulator"), variable: tidy(function () { return fill(value.shape, _this.initialAccumulatorValue) .variable(trainable_1); }) }; } var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } var accumulatedGrad = _this.accumulatedGrads[i].variable; tidy(function () { var newAccumulatedGrad = add$1(accumulatedGrad, square(gradient)); accumulatedGrad.assign(newAccumulatedGrad); var newValue = add$1(mul(div(gradient, sqrt(add$1(newAccumulatedGrad, ENGINE.backend.epsilon()))), -_this.learningRate), value); value.assign(newValue); }); }); this.incrementIterations(); }; AdagradOptimizer.prototype.dispose = function () { if (this.accumulatedGrads != null) { dispose(this.accumulatedGrads.map(function (v) { return v.variable; })); } }; AdagradOptimizer.prototype.getWeights = function () { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.saveIterations()]; case 1: // Order matters for Python compatibility. return [2 /*return*/, [_a.sent()].concat(this.accumulatedGrads.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))]; } }); }); }; AdagradOptimizer.prototype.setWeights = function (weightValues) { return __awaiter(this, void 0, void 0, function () { var trainable; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.extractIterations(weightValues)]; case 1: weightValues = _a.sent(); trainable = false; this.accumulatedGrads = weightValues.map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); return [2 /*return*/]; } }); }); }; AdagradOptimizer.prototype.getConfig = function () { return { 'learningRate': this.learningRate, 'initialAccumulatorValue': this.initialAccumulatorValue, }; }; /** @nocollapse */ AdagradOptimizer.fromConfig = function (cls, config) { return new cls(config['learningRate'], config['initialAccumulatorValue']); }; return AdagradOptimizer; })(Optimizer)); /** @class */ ((function (_super) { __extends(AdamOptimizer, _super); function AdamOptimizer(learningRate, beta1, beta2, epsilon) { if (epsilon === void 0) { epsilon = null; } var _this = _super.call(this) || this; _this.learningRate = learningRate; _this.beta1 = beta1; _this.beta2 = beta2; _this.epsilon = epsilon; _this.accumulatedFirstMoment = []; _this.accumulatedSecondMoment = []; tidy(function () { // accB* will be updated by batch. _this.accBeta1 = scalar(beta1).variable(); _this.accBeta2 = scalar(beta2).variable(); }); if (epsilon == null) { _this.epsilon = ENGINE.backend.epsilon(); } return _this; } Object.defineProperty(AdamOptimizer, "className", { /** @nocollapse */ get: function () { // Name matters for Python compatibility. // This is a getter instead of a property because when it's a property, it // prevents the entire class from being tree-shaken. return 'Adam'; }, enumerable: false, configurable: true }); AdamOptimizer.prototype.applyGradients = function (variableGradients) { var _this = this; var varNames = Array.isArray(variableGradients) ? 
variableGradients.map(function (v) { return v.name; }) : Object.keys(variableGradients); tidy(function () { var oneMinusAccBeta1 = sub(1, _this.accBeta1); var oneMinusAccBeta2 = sub(1, _this.accBeta2); varNames.forEach(function (name, i) { var value = ENGINE.registeredVariables[name]; var trainable = false; if (_this.accumulatedFirstMoment[i] == null) { _this.accumulatedFirstMoment[i] = { originalName: "".concat(name, "/m"), variable: tidy(function () { return zerosLike(value).variable(trainable); }) }; } if (_this.accumulatedSecondMoment[i] == null) { _this.accumulatedSecondMoment[i] = { originalName: "".concat(name, "/v"), variable: tidy(function () { return zerosLike(value).variable(trainable); }) }; } var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } var firstMoment = _this.accumulatedFirstMoment[i].variable; var secondMoment = _this.accumulatedSecondMoment[i].variable; var newFirstMoment = add$1(mul(firstMoment, _this.beta1), mul(gradient, 1 - _this.beta1)); var newSecondMoment = add$1(mul(secondMoment, _this.beta2), mul(square(gradient), 1 - _this.beta2)); var biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1); var biasCorrectedSecondMoment = div(newSecondMoment, oneMinusAccBeta2); firstMoment.assign(newFirstMoment); secondMoment.assign(newSecondMoment); var newValue = add$1(mul(div(biasCorrectedFirstMoment, add$1(sqrt(biasCorrectedSecondMoment), _this.epsilon)), -_this.learningRate), value); value.assign(newValue); }); _this.accBeta1.assign(mul(_this.accBeta1, _this.beta1)); _this.accBeta2.assign(mul(_this.accBeta2, _this.beta2)); }); this.incrementIterations(); }; AdamOptimizer.prototype.dispose = function () { this.accBeta1.dispose(); this.accBeta2.dispose(); if (this.accumulatedFirstMoment != null) { dispose(this.accumulatedFirstMoment.map(function (v) { return v.variable; })); } if (this.accumulatedSecondMoment != null) { dispose(this.accumulatedSecondMoment.map(function (v) { return v.variable; })); } }; AdamOptimizer.prototype.getWeights = function () { return __awaiter(this, void 0, void 0, function () { var variables; return __generator(this, function (_a) { switch (_a.label) { case 0: variables = __spreadArray(__spreadArray([], __read(this.accumulatedFirstMoment), false), __read(this.accumulatedSecondMoment), false); return [4 /*yield*/, this.saveIterations()]; case 1: return [2 /*return*/, [_a.sent()].concat(variables.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))]; } }); }); }; AdamOptimizer.prototype.setWeights = function (weightValues) { return __awaiter(this, void 0, void 0, function () { var variableCount, trainable; var _this = this; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.extractIterations(weightValues)]; case 1: weightValues = _a.sent(); tidy(function () { _this.accBeta1.assign(pow(_this.beta1, _this.iterations_ + 1)); _this.accBeta2.assign(pow(_this.beta2, _this.iterations_ + 1)); }); variableCount = weightValues.length / 2; trainable = false; this.accumulatedFirstMoment = weightValues.slice(0, variableCount).map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); this.accumulatedSecondMoment = weightValues.slice(variableCount, variableCount * 2) .map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); return [2 /*return*/]; } }); }); }; AdamOptimizer.prototype.getConfig = function () { return { 
'learningRate': this.learningRate, 'beta1': this.beta1, 'beta2': this.beta2, 'epsilon': this.epsilon, }; }; /** @nocollapse */ AdamOptimizer.fromConfig = function (cls, config) { return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']); }; return AdamOptimizer; })(Optimizer)); /** @class */ ((function (_super) { __extends(AdamaxOptimizer, _super); function AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay) { if (epsilon === void 0) { epsilon = null; } if (decay === void 0) { decay = 0.0; } var _this = _super.call(this) || this; _this.learningRate = learningRate; _this.beta1 = beta1; _this.beta2 = beta2; _this.epsilon = epsilon; _this.decay = decay; _this.accumulatedFirstMoment = []; _this.accumulatedWeightedInfNorm = []; tidy(function () { _this.iteration = scalar(0).variable(); _this.accBeta1 = scalar(beta1).variable(); }); if (epsilon == null) { _this.epsilon = ENGINE.backend.epsilon(); } return _this; } Object.defineProperty(AdamaxOptimizer, "className", { /** @nocollapse */ get: function () { // Name matters for Python compatibility. // This is a getter instead of a property because when it's a property, it // prevents the entire class from being tree-shaken. return 'Adamax'; }, enumerable: false, configurable: true }); AdamaxOptimizer.prototype.applyGradients = function (variableGradients) { var _this = this; var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) { return item.name; }) : Object.keys(variableGradients); tidy(function () { var oneMinusAccBeta1 = sub(1, _this.accBeta1); var lr = div(-_this.learningRate, add$1(mul(_this.iteration, _this.decay), 1)); variableNames.forEach(function (name, i) { var value = ENGINE.registeredVariables[name]; var trainable = false; if (_this.accumulatedFirstMoment[i] == null) { _this.accumulatedFirstMoment[i] = { originalName: "".concat(name, "/m"), variable: zerosLike(value).variable(trainable) }; } if (_this.accumulatedWeightedInfNorm[i] == null) { _this.accumulatedWeightedInfNorm[i] = { originalName: "".concat(name, "/v"), variable: zerosLike(value).variable(trainable) }; } var gradient = Array.isArray(variableGradients) ? 
variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } var firstMoment = _this.accumulatedFirstMoment[i].variable; var weightedInfNorm = _this.accumulatedWeightedInfNorm[i].variable; var newFirstMoment = add$1(mul(firstMoment, _this.beta1), mul(gradient, 1 - _this.beta1)); var ut0 = mul(weightedInfNorm, _this.beta2); var ut1 = abs(gradient); var newWeightedInfNorm = maximum$1(ut0, ut1); firstMoment.assign(newFirstMoment); weightedInfNorm.assign(newWeightedInfNorm); var newValue = add$1(mul(div(lr, oneMinusAccBeta1), div(newFirstMoment, add$1(newWeightedInfNorm, _this.epsilon))), value); value.assign(newValue); }); _this.iteration.assign(add$1(_this.iteration, 1)); _this.accBeta1.assign(mul(_this.accBeta1, _this.beta1)); }); this.incrementIterations(); }; AdamaxOptimizer.prototype.dispose = function () { this.accBeta1.dispose(); this.iteration.dispose(); if (this.accumulatedFirstMoment != null) { dispose(this.accumulatedFirstMoment.map(function (v) { return v.variable; })); } if (this.accumulatedWeightedInfNorm != null) { dispose(this.accumulatedWeightedInfNorm.map(function (v) { return v.variable; })); } }; AdamaxOptimizer.prototype.getWeights = function () { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { throw new Error('getWeights() is not implemented for Adamax yet.'); }); }); }; AdamaxOptimizer.prototype.setWeights = function (weightValues) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { throw new Error('setWeights() is not implemented for Adamax yet.'); }); }); }; AdamaxOptimizer.prototype.getConfig = function () { return { 'learningRate': this.learningRate, 'beta1': this.beta1, 'beta2': this.beta2, 'epsilon': this.epsilon, 'decay': this.decay }; }; /** @nocollapse */ AdamaxOptimizer.fromConfig = function (cls, config) { return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']); }; return AdamaxOptimizer; })(Optimizer)); /** @doclink Optimizer */ var SGDOptimizer = /** @class */ (function (_super) { __extends(SGDOptimizer, _super); function SGDOptimizer(learningRate) { var _this = _super.call(this) || this; _this.learningRate = learningRate; _this.setLearningRate(learningRate); return _this; } Object.defineProperty(SGDOptimizer, "className", { /** @nocollapse */ get: function () { // Name matters for Python compatibility. // This is a getter instead of a property because when it's a property, it // prevents the entire class from being tree-shaken. return 'SGD'; }, enumerable: false, configurable: true }); SGDOptimizer.prototype.applyGradients = function (variableGradients) { var _this = this; var varNames = Array.isArray(variableGradients) ? variableGradients.map(function (v) { return v.name; }) : Object.keys(variableGradients); varNames.forEach(function (name, i) { var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } var value = ENGINE.registeredVariables[name]; tidy(function () { var newValue = add$1(mul(_this.c, gradient), value); value.assign(newValue); }); }); this.incrementIterations(); }; /** * Sets the learning rate of the optimizer. 
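* * A hedged usage sketch (assumes `tf.train.sgd` from a full bundle): * * ```js * const opt = tf.train.sgd(0.1); * // Per the code below, the cached scalar `c` holding -learningRate is * // disposed and rebuilt, so later applyGradients() calls use the new rate. * opt.setLearningRate(0.01); * ```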
*/ SGDOptimizer.prototype.setLearningRate = function (learningRate) { this.learningRate = learningRate; if (this.c != null) { this.c.dispose(); } this.c = keep(scalar(-learningRate)); }; SGDOptimizer.prototype.dispose = function () { this.c.dispose(); }; SGDOptimizer.prototype.getWeights = function () { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.saveIterations()]; case 1: return [2 /*return*/, [_a.sent()]]; } }); }); }; SGDOptimizer.prototype.setWeights = function (weightValues) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.extractIterations(weightValues)]; case 1: weightValues = _a.sent(); if (weightValues.length !== 0) { throw new Error('SGD optimizer does not have settable weights.'); } return [2 /*return*/]; } }); }); }; SGDOptimizer.prototype.getConfig = function () { return { 'learningRate': this.learningRate }; }; /** @nocollapse */ SGDOptimizer.fromConfig = function (cls, config) { return new cls(config['learningRate']); }; return SGDOptimizer; }(Optimizer)); /** @doclink Optimizer */ /** @class */ ((function (_super) { __extends(MomentumOptimizer, _super); function MomentumOptimizer(learningRate, momentum, useNesterov) { if (useNesterov === void 0) { useNesterov = false; } var _this = _super.call(this, learningRate) || this; _this.learningRate = learningRate; _this.momentum = momentum; _this.useNesterov = useNesterov; _this.accumulations = []; _this.m = scalar(_this.momentum); return _this; } Object.defineProperty(MomentumOptimizer, "className", { /** @nocollapse */ // Name matters for Python compatibility. get: function () { // Name matters for Python compatibility. // This is a getter instead of a property because when it's a property, it // prevents the entire class from being tree-shaken. return 'Momentum'; }, enumerable: false, configurable: true }); MomentumOptimizer.prototype.applyGradients = function (variableGradients) { var _this = this; var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) { return item.name; }) : Object.keys(variableGradients); variableNames.forEach(function (name, i) { var value = ENGINE.registeredVariables[name]; if (_this.accumulations[i] == null) { var trainable_1 = false; _this.accumulations[i] = { originalName: "".concat(name, "/momentum"), variable: tidy(function () { return zerosLike(value).variable(trainable_1); }) }; } var accumulation = _this.accumulations[i].variable; var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } tidy(function () { var newValue; var newAccumulation = add$1(mul(_this.m, accumulation), gradient); if (_this.useNesterov) { newValue = add$1(mul(_this.c, add$1(gradient, mul(newAccumulation, _this.m))), value); } else { newValue = add$1(mul(_this.c, newAccumulation), value); } accumulation.assign(newAccumulation); value.assign(newValue); }); }); this.incrementIterations(); }; MomentumOptimizer.prototype.dispose = function () { this.m.dispose(); if (this.accumulations != null) { dispose(this.accumulations.map(function (v) { return v.variable; })); } }; /** * Sets the momentum of the optimizer. 
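* * A hedged usage sketch (assumes `tf.train.momentum` from a full bundle): * * ```js * const opt = tf.train.momentum(0.1, 0.9); * // Note: per the surrounding code, this updates only the stored number; * // the scalar `m` used by applyGradients() is created once in the * // constructor. * opt.setMomentum(0.5); * ```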
* * @param momentum */ MomentumOptimizer.prototype.setMomentum = function (momentum) { this.momentum = momentum; }; MomentumOptimizer.prototype.getWeights = function () { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.saveIterations()]; case 1: // Order matters for Python compatibility. return [2 /*return*/, [_a.sent()].concat(this.accumulations.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))]; } }); }); }; MomentumOptimizer.prototype.setWeights = function (weightValues) { return __awaiter(this, void 0, void 0, function () { var trainable; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.extractIterations(weightValues)]; case 1: weightValues = _a.sent(); trainable = false; this.accumulations = weightValues.map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); return [2 /*return*/]; } }); }); }; MomentumOptimizer.prototype.getConfig = function () { return { 'learningRate': this.learningRate, 'momentum': this.momentum, 'useNesterov': this.useNesterov }; }; /** @nocollapse */ MomentumOptimizer.fromConfig = function (cls, config) { return new cls(config['learningRate'], config['momentum'], config['useNesterov']); }; return MomentumOptimizer; })(SGDOptimizer)); /** @doclink Optimizer */ /** @class */ ((function (_super) { __extends(RMSPropOptimizer, _super); function RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered) { if (decay === void 0) { decay = 0.9; } if (momentum === void 0) { momentum = 0.0; } if (epsilon === void 0) { epsilon = null; } if (centered === void 0) { centered = false; } var _this = _super.call(this) || this; _this.learningRate = learningRate; _this.decay = decay; _this.momentum = momentum; _this.epsilon = epsilon; _this.accumulatedMeanSquares = []; _this.accumulatedMoments = []; _this.accumulatedMeanGrads = []; _this.centered = centered; if (epsilon == null) { _this.epsilon = ENGINE.backend.epsilon(); } if (learningRate == null) { throw new Error("learningRate for RMSPropOptimizer must be defined."); } return _this; } Object.defineProperty(RMSPropOptimizer, "className", { /** @nocollapse */ get: function () { // Name matters for Python compatibility. // This is a getter instead of a property because when it's a property, it // prevents the entire class from being tree-shaken. return 'RMSProp'; }, enumerable: false, configurable: true }); RMSPropOptimizer.prototype.applyGradients = function (variableGradients) { var _this = this; var variableNames = Array.isArray(variableGradients) ? 
variableGradients.map(function (item) { return item.name; }) : Object.keys(variableGradients); variableNames.forEach(function (name, i) { var value = ENGINE.registeredVariables[name]; var trainable = false; if (_this.accumulatedMeanSquares[i] == null) { _this.accumulatedMeanSquares[i] = { originalName: "".concat(name, "/rms"), variable: tidy(function () { return zerosLike(value).variable(trainable); }) }; } if (_this.accumulatedMoments[i] == null) { _this.accumulatedMoments[i] = { originalName: "".concat(name, "/momentum"), variable: tidy(function () { return zerosLike(value).variable(trainable); }) }; } if (_this.accumulatedMeanGrads[i] == null && _this.centered) { _this.accumulatedMeanGrads[i] = { originalName: "".concat(name, "/mg"), variable: tidy(function () { return zerosLike(value).variable(trainable); }) }; } var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } var accumulatedMeanSquare = _this.accumulatedMeanSquares[i].variable; var accumulatedMoments = _this.accumulatedMoments[i].variable; tidy(function () { var newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, _this.decay), mul(square(gradient), 1 - _this.decay)); if (_this.centered) { var accumulatedMeanGrad = _this.accumulatedMeanGrads[i].variable; // Centered gradient var newAccumulatedMeanGrad = add$1(mul(accumulatedMeanGrad, _this.decay), mul(gradient, 1 - _this.decay)); var gradContribution = div(mul(gradient, _this.learningRate), sqrt(sub(newAccumulatedMeanSquare, add$1(square(newAccumulatedMeanGrad), _this.epsilon)))); var newAccumulatedMoments = add$1(mul(accumulatedMoments, _this.momentum), gradContribution); accumulatedMeanSquare.assign(newAccumulatedMeanSquare); accumulatedMeanGrad.assign(newAccumulatedMeanGrad); accumulatedMoments.assign(newAccumulatedMoments); var newValue = sub(value, newAccumulatedMoments); value.assign(newValue); } else { // Plain gradient var newAccumulatedMeanSquare_1 = add$1(mul(accumulatedMeanSquare, _this.decay), mul(square(gradient), 1 - _this.decay)); var newAccumulatedMoments = add$1(mul(accumulatedMoments, _this.momentum), div(mul(gradient, _this.learningRate), sqrt(add$1(newAccumulatedMeanSquare_1, _this.epsilon)))); accumulatedMeanSquare.assign(newAccumulatedMeanSquare_1); accumulatedMoments.assign(newAccumulatedMoments); var newValue = sub(value, newAccumulatedMoments); value.assign(newValue); } }); }); this.incrementIterations(); }; RMSPropOptimizer.prototype.dispose = function () { if (this.accumulatedMeanSquares != null) { dispose(this.accumulatedMeanSquares.map(function (v) { return v.variable; })); } if (this.accumulatedMeanGrads != null && this.centered) { dispose(this.accumulatedMeanGrads.map(function (v) { return v.variable; })); } if (this.accumulatedMoments != null) { dispose(this.accumulatedMoments.map(function (v) { return v.variable; })); } }; RMSPropOptimizer.prototype.getWeights = function () { return __awaiter(this, void 0, void 0, function () { var variables; return __generator(this, function (_a) { switch (_a.label) { case 0: variables = __spreadArray(__spreadArray([], __read(this.accumulatedMeanSquares), false), __read(this.accumulatedMoments), false); if (this.centered) { variables.push.apply(variables, __spreadArray([], __read(this.accumulatedMeanGrads), false)); } return [4 /*yield*/, this.saveIterations()]; case 1: return [2 /*return*/, [_a.sent()].concat(variables.map(function (v) { return ({ name: v.originalName, tensor: v.variable }); }))]; } }); }); }; 
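// Weight layout: getWeights() above returns [iterations, meanSquares..., moments..., meanGrads... (centered only)]. setWeights() below strips the iterations entry via extractIterations() and splits the remainder into halves (or thirds when centered).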
RMSPropOptimizer.prototype.setWeights = function (weightValues) { return __awaiter(this, void 0, void 0, function () { var variableCount, trainable; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.extractIterations(weightValues)]; case 1: weightValues = _a.sent(); variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2; trainable = false; this.accumulatedMeanSquares = weightValues.slice(0, variableCount).map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); this.accumulatedMoments = weightValues.slice(variableCount, variableCount * 2) .map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); if (this.centered) { this.accumulatedMeanGrads = weightValues.slice(variableCount * 2, variableCount * 3) .map(function (v) { return ({ originalName: v.name, variable: v.tensor.variable(trainable) }); }); } return [2 /*return*/]; } }); }); }; RMSPropOptimizer.prototype.getConfig = function () { return { 'learningRate': this.learningRate, 'decay': this.decay, 'momentum': this.momentum, 'epsilon': this.epsilon, 'centered': this.centered }; }; /** @nocollapse */ RMSPropOptimizer.fromConfig = function (cls, config) { return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']); }; return RMSPropOptimizer; })(Optimizer)); /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ ((function () { if (typeof requestAnimationFrame !== 'undefined') { return requestAnimationFrame; } else if (typeof setImmediate !== 'undefined') { return setImmediate; } return function (f) { return f(); }; // no delays }))(); /** * @license * Copyright 2022 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var RowPartitionType; (function (RowPartitionType) { RowPartitionType[RowPartitionType["FIRST_DIM_SIZE"] = 0] = "FIRST_DIM_SIZE"; RowPartitionType[RowPartitionType["VALUE_ROWIDS"] = 1] = "VALUE_ROWIDS"; RowPartitionType[RowPartitionType["ROW_LENGTHS"] = 2] = "ROW_LENGTHS"; RowPartitionType[RowPartitionType["ROW_SPLITS"] = 3] = "ROW_SPLITS"; RowPartitionType[RowPartitionType["ROW_LIMITS"] = 4] = "ROW_LIMITS"; RowPartitionType[RowPartitionType["ROW_STARTS"] = 5] = "ROW_STARTS"; })(RowPartitionType || (RowPartitionType = {})); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var SELU_SCALEALPHA = 1.7580993408473768599402175208123; var SELU_SCALE = 1.0507009873554804934193349852946; // Gradient for product operation on a single axis. function prodGradFn_(x, dy, axis) { // The gradient tensor (dy) has a set of axes removed, so we create re-shaped // versions (of size 1) for the removed axis; this supports broadcasting over // those dimensions. var expandedYShape = x.shape.slice(); expandedYShape[axis] = 1; // The actual gradient computation. var expandedDy = reshape$1(dy, expandedYShape); var xCumProd = cumprod(x, axis, true, false); var xCumRevProd = cumprod(x, axis, true, true); var dx = mul(xCumProd, xCumRevProd); return mul(expandedDy, dx); } // Support gradients when the product is done on many axes at once. // This is done by pushing all the axes on which the product is applied into a // single axis. function prodsGradFn_(x, dy, axis) { // Move all axes for doing prod over to the end of the tensor. var xRank = x.shape.length; var finalProdAxis = xRank - axis.length; var xPermutation = getAxesPermutation(axis, xRank); var permutedX = x; if (xPermutation != null) { permutedX = transpose(x, xPermutation); } // Reshape all the prod dimensions into a single one, and compute prod // gradients on that. var newShape = permutedX.shape.slice(); var removedShape = newShape.splice(xRank - axis.length, axis.length); var endPartShape = removedShape.reduce(function (p, c) { return p * c; }, 1); newShape.push(endPartShape); var reshapedPermutedX = permutedX.reshape(newShape); var prodGrad = prodGradFn_(reshapedPermutedX, dy, finalProdAxis); // Undo the re-shaping now we have the dx vector, and permute back to // original axes order.
prodGrad = prodGrad.reshape(permutedX.shape); if (xPermutation != null) { var undoPermutation = getUndoAxesPermutation(xPermutation); prodGrad = transpose(prodGrad, undoPermutation); } return prodGrad; } // Running example: // [ // [ // [3.0, 4.0], // [5.0, 6.0], // [7.0, 8.0] // ], // [ // [3.0, 5.0], // [0.0, 6.0], // [5.0, 6.0] // ] // ] // var prodGradConfig = { kernelName: Prod, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var axis = attrs.axis; var axisArr = []; if (axis === undefined || axis === null) { axisArr = x.shape.map(function (_, i) { return i; }); } else if (typeof axis === 'number') { axisArr = [axis]; } else { axisArr = axis; } return { x: function () { return prodsGradFn_(x, dy, axisArr); } }; } }; var divGradConfig = { kernelName: RealDiv, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var outShape = assertAndGetBroadcastShape(a.shape, b.shape); var derA = function () { var res = div(dy, cast(b, 'float32')); var reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum(res, reduceAxes), a.shape); } return res; }; var derB = function () { var res = mul(dy, cast(a, 'float32')); var reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = reshape$1(sum(res, reduceAxes), b.shape); } var tmp = square(b); return neg(div(res, cast(tmp, 'float32'))); }; return { a: derA, b: derB }; } }; var reciprocalGradConfig = { kernelName: Reciprocal, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return div(dy, neg(square(x))); } }; } }; var relu6GradConfig = { kernelName: Relu6$1, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; var mask = mul(lessEqual(x, 6), step(x)); return { x: function () { return mul(dy, cast(mask, 'float32')); } }; } }; var reluGradConfig = { kernelName: Relu$1, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(dy, cast(step(x), 'float32')); } }; } }; var reshapeGradConfig = { kernelName: Reshape$1, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return reshape$1(dy, x.shape); } }; } }; var resizeBilinearGradConfig = { kernelName: ResizeBilinear, inputsToSave: ['images'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), images = _a[0]; var inputs = { dy: dy, images: images }; var imagesDer = function () { // tslint:disable-next-line: no-unnecessary-type-assertion return ENGINE.runKernel(ResizeBilinearGrad, inputs, attrs); }; return { images: imagesDer }; } }; var resizeNearestNeighborGradConfig = { kernelName: ResizeNearestNeighbor, inputsToSave: ['images'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), images = _a[0]; var inputs = { dy: dy, images: images }; var imagesDer = function () { // tslint:disable-next-line: no-unnecessary-type-assertion return ENGINE.runKernel(ResizeNearestNeighborGrad, inputs, attrs); }; return { images: imagesDer }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var reverseGradConfig = { kernelName: Reverse, gradFunc: function (dy, saved, attrs) { var dims = attrs.dims; var axes = parseAxisParam(dims, dy.shape); return { x: function () { return reverse(dy, axes); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var roundGradConfig = { kernelName: Round, gradFunc: function (dy) { // TODO(nsthorat): Let gradients be null for cases where we want to stop // backpropagation. return { x: function () { return zerosLike(dy); } }; } }; var rsqrtGradConfig = { kernelName: Rsqrt, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return neg(div(dy, mul(pow(x, 1.5), 2))); } }; } }; var selectGradConfig = { kernelName: Select, inputsToSave: ['condition'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), condition = _a[0]; return { // TODO(julianoks): Return null for condition gradient // when backprop supports it. condition: function () { return cast(zerosLike(condition), 'float32'); }, t: function () { return mul(dy, cast(condition, dy.dtype)); }, e: function () { return mul(dy, cast(logicalNot(condition), dy.dtype)); } }; } }; var seluGradConfig = { kernelName: Selu$1, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { var mask = greater$1(x, scalar(0)); var scaleAlpha = scalar(SELU_SCALEALPHA); var scale = scalar(SELU_SCALE); var greaterThanZeroDer = mul(dy, scale); var lessEqualZeroDer = mul(mul(dy, scaleAlpha), exp(cast(x, 'float32'))); return where(mask, greaterThanZeroDer, lessEqualZeroDer); } }; } }; var sigmoidGradConfig = { kernelName: Sigmoid$1, outputsToSave: [true], gradFunc: function (dy, saved) { var _a = __read(saved, 1), y = _a[0]; return { x: function () { return mul(dy, mul(y, sub(scalar(1), y))); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var signGradConfig = { kernelName: Sign, gradFunc: function (dy) { return { x: function () { return zerosLike(dy); } }; } }; var sinGradConfig = { kernelName: Sin, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(cos(cast(x, 'float32')), dy); } }; } }; var sinhGradConfig = { kernelName: Sinh, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(cosh(cast(x, 'float32')), dy); } }; } }; var sliceGradConfig = { kernelName: Slice, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var begin = attrs.begin, size = attrs.size; var inputShape = x.shape; var _b = __read(parseSliceParams(x, begin, size), 2), begin_ = _b[0], size_ = _b[1]; // Create an Nx2 padding where the first column represents how many // zeros are prepended (at start) for each dimension, and the second // column indicates how many zeros are appended (at end). // The number of zeros to append is the shape of the input // elementwise-subtracted by both the begin vector and sizes vector. var paddings = []; for (var i = 0; i < dy.rank; i++) { paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]); } return { x: function () { return pad(dy, paddings); } }; } }; var softmaxGradConfig = { kernelName: Softmax$2, outputsToSave: [true], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), y = _a[0]; var dim = attrs.dim; var keepDims = true; var dyTimesY = mul(dy, y); return { logits: function () { return sub(dyTimesY, mul(sum(dyTimesY, [dim], keepDims), y)); } }; } }; var softplusGradConfig = { kernelName: Softplus$1, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(dy, sigmoid(x)); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var spaceToBatchNDGradConfig = { kernelName: SpaceToBatchND, gradFunc: function (dy, saved, attrs) { var blockShape = attrs.blockShape, paddings = attrs.paddings; return { x: function () { return batchToSpaceND(dy, blockShape, paddings); } }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var splitVGradConfig = { kernelName: SplitV, gradFunc: function (dy, saved, attrs) { var axis = attrs.axis; return { x: function () { return concat(dy, axis); } }; } }; var sqrtGradConfig = { kernelName: Sqrt, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return div(dy, mul(sqrt(cast(x, 'float32')), 2)); } }; } }; var squareGradConfig = { kernelName: Square, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return mul(dy, mul(cast(x, 'float32'), 2)); } }; } }; var squaredDifferenceGradConfig = { kernelName: SquaredDifference, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var two = scalar(2); var derA = function () { return mul(dy, mul(two, sub(a, b))); }; var derB = function () { return mul(dy, mul(two, sub(b, a))); }; return { a: derA, b: derB }; } }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var stepGradConfig = { kernelName: Step, gradFunc: function (dy) { // TODO(manrajgrover): Return null for gradients when backprop supports // it. 
return { x: function () { return zerosLike(dy); } }; } }; var subGradConfig = { kernelName: Sub, inputsToSave: ['a', 'b'], gradFunc: function (dy, saved) { var _a = __read(saved, 2), a = _a[0], b = _a[1]; var outShape = assertAndGetBroadcastShape(a.shape, b.shape); var derA = function () { var res = dy; var reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(res, a.shape); }; var derB = function () { var res = dy; var reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = sum(res, reduceAxes); } return reshape$1(neg(res), b.shape); }; return { a: derA, b: derB }; } }; var sumGradConfig = { kernelName: Sum, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var expandedDyShape = x.shape.slice(); var axis = attrs.axis; var axes = parseAxisParam(axis, x.shape); axes.forEach(function (axis) { expandedDyShape[axis] = 1; }); var expandedDy = reshape$1(dy, expandedDyShape); var derX = mul(expandedDy, ones$1(x.shape, 'float32')); return { x: function () { return derX; } }; } }; var tanGradConfig = { kernelName: Tan, inputsToSave: ['x'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), x = _a[0]; return { x: function () { return div(dy, square(cos(x))); } }; } }; var tanhGradConfig = { kernelName: Tanh$1, outputsToSave: [true], gradFunc: function (dy, saved) { var _a = __read(saved, 1), y = _a[0]; return { x: function () { return mul(sub(scalar(1), square(y)), dy); } }; } }; var tileGradConfig = { kernelName: Tile, inputsToSave: ['x'], gradFunc: function (dy, saved, attrs) { var _a = __read(saved, 1), x = _a[0]; var reps = attrs.reps; var derX = function () { var xGrad = zerosLike(x); // TODO(cais): Maybe reduce memory footprint by avoiding repeated // slicing. if (x.rank === 1) { for (var i = 0; i < reps[0]; ++i) { xGrad = add$1(xGrad, slice(dy, [i * x.shape[0]], [x.shape[0]])); } } else if (x.rank === 2) { for (var i = 0; i < reps[0]; ++i) { for (var j = 0; j < reps[1]; ++j) { xGrad = add$1(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1]], [ x.shape[0], x.shape[1] ])); } } } else if (x.rank === 3) { for (var i = 0; i < reps[0]; ++i) { for (var j = 0; j < reps[1]; ++j) { for (var k = 0; k < reps[2]; ++k) { xGrad = add$1(xGrad, slice(dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]])); } } } } else if (x.rank === 4) { for (var i = 0; i < reps[0]; ++i) { for (var j = 0; j < reps[1]; ++j) { for (var k = 0; k < reps[2]; ++k) { for (var l = 0; l < reps[3]; ++l) { xGrad = add$1(xGrad, slice(dy, [ i * x.shape[0], j * x.shape[1], k * x.shape[2], l * x.shape[3] ], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]])); } } } } } else { throw new Error("Gradient for tile operation is not implemented for rank-" + "".concat(x.rank, " tensors yet.")); } return xGrad; }; return { x: derX }; }, }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var transposeGradConfig = { kernelName: Transpose, gradFunc: function (dy, saved, attrs) { var transposeAttrs = attrs; var perm = transposeAttrs.perm; var undoPerm = getUndoAxesPermutation(perm); return { x: function () { return transpose(dy, undoPerm); } }; } }; /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var unpackGradConfig = { kernelName: Unpack, gradFunc: function (dy, saved, attrs) { var unpackAttrs = attrs; var axis = unpackAttrs.axis; return { value: function () { return stack(dy, axis); } }; } }; var unsortedSegmentSumGradConfig = { kernelName: UnsortedSegmentSum, inputsToSave: ['segmentIds'], gradFunc: function (dy, saved) { var _a = __read(saved, 1), segmentIds = _a[0]; var derX = function () { return gatherDropNegatives(dy, segmentIds); }; return { x: derX }; } }; function gatherDropNegatives(x, indices) { // Helper function for unsorted segment ops. Gathers params for // positive segment ids and gathers 0 for inputs with negative segment id. // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py var zeroClippedIndices = maximum$1(indices, zerosLike(indices)); var gathered = gather(x, zeroClippedIndices); var isPositive = greaterEqual(indices, scalar(0, 'int32')); var numIters = gathered.rank - isPositive.rank; for (var i = 0; i < numIters; ++i) { isPositive = expandDims(isPositive, i + 1); } isPositive = logicalAnd(isPositive, ones$1(gathered.shape, 'bool')); var zeroSlice = zerosLike(gathered); return where(isPositive, gathered, zeroSlice); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var zerosLikeGradConfig = { kernelName: ZerosLike, gradFunc: function (dy) { return { x: function () { return zerosLike(dy); } }; } }; var e_1, _a; // Export all kernel configs here so that the package can auto-register them var gradConfigs = [ absGradConfig, acosGradConfig, acoshGradConfig, addGradConfig, addNGradConfig, argMaxGradConfig, argMinGradConfig, asinGradConfig, asinhGradConfig, atan2GradConfig, atanGradConfig, atanhGradConfig, avgPool3DGradConfig, avgPoolGradConfig, batchMatMulGradConfig, batchToSpaceNDGradConfig, broadcastToGradConfig, castGradConfig, ceilGradConfig, clipByValueGradConfig, complexAbsGradConfig, concatGradConfig, conv2DBackpropInputGradConfig, conv2DGradConfig, conv3DGradConfig, cosGradConfig, coshGradConfig, cumsumGradConfig, depthwiseConv2dNativeGradConfig, dilation2dGradConfig, divGradConfig, eluGradConfig, erfGradConfig, expGradConfig, expandDimsGradConfig, expm1GradConfig, floorDivGradConfig, floorGradConfig, fusedBatchNormGradConfig, gatherGradConfig, greaterEqualGradConfig, identityGradConfig, isFiniteGradConfig, isInfGradConfig, isNanGradConfig, leakyReluGradConfig, log1pGradConfig, logGradConfig, logSoftmaxGradConfig, lrnGradConfig, maxGradConfig, maximumGradConfig, maxPool3DGradConfig, maxPoolGradConfig, meanGradConfig, minGradConfig, minimumGradConfig, mirrorPadGradConfig, modGradConfig, multiplyGradConfig, negGradConfig, oneHotGradConfig, onesLikeGradConfig, packGradConfig, padV2GradConfig, powGradConfig, preluGradConfig, prodGradConfig, reciprocalGradConfig, relu6GradConfig, reluGradConfig, reshapeGradConfig, resizeBilinearGradConfig, resizeNearestNeighborGradConfig, reverseGradConfig, roundGradConfig, rsqrtGradConfig, selectGradConfig, seluGradConfig, sigmoidGradConfig, signGradConfig, sinGradConfig, sinhGradConfig, sliceGradConfig, softmaxGradConfig, softplusGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, sqrtGradConfig, squaredDifferenceGradConfig, squareGradConfig, stepGradConfig, subGradConfig, sumGradConfig, tanGradConfig, tanhGradConfig, tileGradConfig, transposeGradConfig, unpackGradConfig, unsortedSegmentSumGradConfig, zerosLikeGradConfig ]; try { for (var gradConfigs_1 = __values(gradConfigs), gradConfigs_1_1 = gradConfigs_1.next(); !gradConfigs_1_1.done; gradConfigs_1_1 = gradConfigs_1.next()) { var gradientConfig = gradConfigs_1_1.value; registerGradient(gradientConfig); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (gradConfigs_1_1 && !gradConfigs_1_1.done && (_a = gradConfigs_1.return)) _a.call(gradConfigs_1); } finally { if (e_1) throw e_1.error; } } /** * Helper function used by many of the Constraints to find the L2Norms.
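* * For example (a sketch of the semantics, assuming a [3, 2] weight matrix `w` * and `axis = 0`): the result is a [1, 2] tensor holding the L2 norm of each * column of `w`, since the squared entries are summed over axis 0 with * `keepDims = true` before the square root is taken.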
*/ function calcL2Norms(w, axis) { return tfc.tidy(function () { return tfc__namespace.sqrt(tfc__namespace.sum(tfc__namespace.mul(w, w), axis, true)); }); } /** * Base class for functions that impose constraints on weight values * * @doc { * heading: 'Constraints', * subheading: 'Classes', * namespace: 'constraints' * } */ var Constraint = /** @class */ (function (_super) { __extends(Constraint, _super); function Constraint() { return _super !== null && _super.apply(this, arguments) || this; } Constraint.prototype.getConfig = function () { return {}; }; return Constraint; }(tfc.serialization.Serializable)); var MaxNorm = /** @class */ (function (_super) { __extends(MaxNorm, _super); function MaxNorm(args) { var _this = _super.call(this) || this; _this.defaultMaxValue = 2; _this.defaultAxis = 0; _this.maxValue = args.maxValue != null ? args.maxValue : _this.defaultMaxValue; _this.axis = args.axis != null ? args.axis : _this.defaultAxis; return _this; } MaxNorm.prototype.apply = function (w) { var _this = this; return tfc.tidy(function () { var norms = calcL2Norms(w, _this.axis); var desired = tfc__namespace.clipByValue(norms, 0, _this.maxValue); return tfc__namespace.mul(w, tfc__namespace.div(desired, tfc__namespace.add(epsilon(), norms))); }); }; MaxNorm.prototype.getConfig = function () { return { maxValue: this.maxValue, axis: this.axis }; }; return MaxNorm; }(Constraint)); /** @nocollapse */ MaxNorm.className = 'MaxNorm'; tfc.serialization.registerClass(MaxNorm); var UnitNorm = /** @class */ (function (_super) { __extends(UnitNorm, _super); function UnitNorm(args) { var _this = _super.call(this) || this; _this.defaultAxis = 0; _this.axis = args.axis != null ? args.axis : _this.defaultAxis; return _this; } UnitNorm.prototype.apply = function (w) { var _this = this; return tfc.tidy(function () { return tfc__namespace.div(w, tfc__namespace.add(epsilon(), calcL2Norms(w, _this.axis))); }); }; UnitNorm.prototype.getConfig = function () { return { axis: this.axis }; }; return UnitNorm; }(Constraint)); /** @nocollapse */ UnitNorm.className = 'UnitNorm'; tfc.serialization.registerClass(UnitNorm); var NonNeg = /** @class */ (function (_super) { __extends(NonNeg, _super); function NonNeg() { return _super !== null && _super.apply(this, arguments) || this; } NonNeg.prototype.apply = function (w) { return tfc__namespace.relu(w); }; return NonNeg; }(Constraint)); /** @nocollapse */ NonNeg.className = 'NonNeg'; tfc.serialization.registerClass(NonNeg); var MinMaxNorm = /** @class */ (function (_super) { __extends(MinMaxNorm, _super); function MinMaxNorm(args) { var _this = _super.call(this) || this; _this.defaultMinValue = 0.0; _this.defaultMaxValue = 1.0; _this.defaultRate = 1.0; _this.defaultAxis = 0; _this.minValue = args.minValue != null ? args.minValue : _this.defaultMinValue; _this.maxValue = args.maxValue != null ? args.maxValue : _this.defaultMaxValue; _this.rate = args.rate != null ? args.rate : _this.defaultRate; _this.axis = args.axis != null ? 
args.axis : _this.defaultAxis; return _this; } MinMaxNorm.prototype.apply = function (w) { var _this = this; return tfc.tidy(function () { var norms = calcL2Norms(w, _this.axis); var desired = tfc__namespace.add(tfc__namespace.mul(_this.rate, tfc__namespace.clipByValue(norms, _this.minValue, _this.maxValue)), tfc__namespace.mul(1.0 - _this.rate, norms)); return tfc__namespace.mul(w, tfc__namespace.div(desired, tfc__namespace.add(epsilon(), norms))); }); }; MinMaxNorm.prototype.getConfig = function () { return { minValue: this.minValue, maxValue: this.maxValue, rate: this.rate, axis: this.axis }; }; return MinMaxNorm; }(Constraint)); /** @nocollapse */ MinMaxNorm.className = 'MinMaxNorm'; tfc.serialization.registerClass(MinMaxNorm); // Maps the JavaScript-like identifier keys to the corresponding registry // symbols. var CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = { 'maxNorm': 'MaxNorm', 'minMaxNorm': 'MinMaxNorm', 'nonNeg': 'NonNeg', 'unitNorm': 'UnitNorm' }; function serializeConstraint(constraint) { return serializeKerasObject(constraint); } function deserializeConstraint(config, customObjects) { if (customObjects === void 0) { customObjects = {}; } return deserializeKerasObject(config, tfc.serialization.SerializationMap.getMap().classNameMap, customObjects, 'constraint'); } function getConstraint(identifier) { if (identifier == null) { return null; } if (typeof identifier === 'string') { var className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP ? CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier; var config = { className: className, config: {} }; return deserializeConstraint(config); } else if (identifier instanceof Constraint) { return identifier; } else { return deserializeConstraint(identifier); } } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * MaxNorm weight constraint. * * Constrains the weights incident to each hidden unit * to have a norm less than or equal to a desired value. * * References * - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting * Srivastava, Hinton, et al. * 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) * * @doc {heading: 'Constraints',namespace: 'constraints'} */ function maxNorm(args) { return new MaxNorm(args); } /** * Constrains the weights incident to each hidden unit to have unit norm. * * @doc {heading: 'Constraints', namespace: 'constraints'} */ function unitNorm(args) { return new UnitNorm(args); } /** * Constrains the weight to be non-negative. * * @doc {heading: 'Constraints', namespace: 'constraints'} */ function nonNeg() { return new NonNeg(); } /** @doc {heading: 'Constraints', namespace: 'constraints'} */ function minMaxNorm(config) { return new MinMaxNorm(config); } var exports_constraints = { __proto__: null, maxNorm: maxNorm, minMaxNorm: minMaxNorm, nonNeg: nonNeg, unitNorm: unitNorm }; /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Initializer that generates tensors initialized to 0. 
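* * Illustrative usage (a sketch, assuming the public `tf.initializers` * namespace that re-exports this factory): `const init = tf.initializers.zeros();` * The resulting initializer can then be passed as, e.g., the * `kernelInitializer` of a layer.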
* * @doc {heading: 'Initializers', namespace: 'initializers'} */ function zeros() { return new Zeros(); } /** * Initializer that generates tensors initialized to 1. * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function ones() { return new Ones(); } /** * Initializer that generates values initialized to some constant. * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function constant(args) { return new Constant(args); } /** * Initializer that generates random values initialized to a uniform * distribution. * * Values will be distributed uniformly between the configured minval and * maxval. * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function randomUniform(args) { return new RandomUniform(args); } /** * Initializer that generates random values initialized to a normal * distribution. * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function randomNormal(args) { return new RandomNormal(args); } /** * Initializer that generates random values initialized to a truncated normal * distribution. * * These values are similar to values from a `RandomNormal` except that values * more than two standard deviations from the mean are discarded and re-drawn. * This is the recommended initializer for neural network weights and filters. * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function truncatedNormal(args) { return new TruncatedNormal(args); } /** * Initializer that generates the identity matrix. * Only use for square 2D matrices. * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function identity(args) { return new Identity$1(args); } /** * Initializer capable of adapting its scale to the shape of weights. * With distribution=NORMAL, samples are drawn from a truncated normal * distribution centered on zero, with `stddev = sqrt(scale / n)` where n is: * - number of input units in the weight tensor, if mode = FAN_IN. * - number of output units, if mode = FAN_OUT. * - average of the numbers of input and output units, if mode = FAN_AVG. * With distribution=UNIFORM, * samples are drawn from a uniform distribution * within [-limit, limit], with `limit = sqrt(3 * scale / n)`. * * @doc {heading: 'Initializers',namespace: 'initializers'} */ function varianceScaling(config) { return new VarianceScaling(config); } /** * Glorot uniform initializer, also called Xavier uniform initializer. * It draws samples from a uniform distribution within [-limit, limit] * where `limit` is `sqrt(6 / (fan_in + fan_out))` * where `fan_in` is the number of input units in the weight tensor * and `fan_out` is the number of output units in the weight tensor * * Reference: * Glorot & Bengio, AISTATS 2010 * http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf. * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function glorotUniform(args) { return new GlorotUniform(args); } /** * Glorot normal initializer, also called Xavier normal initializer. * It draws samples from a truncated normal distribution centered on 0 * with `stddev = sqrt(2 / (fan_in + fan_out))` * where `fan_in` is the number of input units in the weight tensor * and `fan_out` is the number of output units in the weight tensor. * * Reference: * Glorot & Bengio, AISTATS 2010 * http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function glorotNormal(args) { return new GlorotNormal(args); } /** * He normal initializer. 
* * It draws samples from a truncated normal distribution centered on 0 * with `stddev = sqrt(2 / fanIn)` * where `fanIn` is the number of input units in the weight tensor. * * Reference: * He et al., http://arxiv.org/abs/1502.01852 * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function heNormal(args) { return new HeNormal(args); } /** * He uniform initializer. * * It draws samples from a uniform distribution within [-limit, limit] * where `limit` is `sqrt(6 / fan_in)` * where `fanIn` is the number of input units in the weight tensor. * * Reference: * He et al., http://arxiv.org/abs/1502.01852 * * @doc {heading: 'Initializers',namespace: 'initializers'} */ function heUniform(args) { return new HeUniform(args); } /** * LeCun normal initializer. * * It draws samples from a truncated normal distribution centered on 0 * with `stddev = sqrt(1 / fanIn)` * where `fanIn` is the number of input units in the weight tensor. * * References: * [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) * [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function leCunNormal(args) { return new LeCunNormal(args); } /** * LeCun uniform initializer. * * It draws samples from a uniform distribution in the interval * `[-limit, limit]` with `limit = sqrt(3 / fanIn)`, * where `fanIn` is the number of input units in the weight tensor. * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function leCunUniform(args) { return new LeCunUniform(args); } /** * Initializer that generates a random orthogonal matrix. * * Reference: * [Saxe et al., http://arxiv.org/abs/1312.6120](http://arxiv.org/abs/1312.6120) * * @doc {heading: 'Initializers', namespace: 'initializers'} */ function orthogonal(args) { return new Orthogonal(args); } var exports_initializers = { __proto__: null, constant: constant, glorotNormal: glorotNormal, glorotUniform: glorotUniform, heNormal: heNormal, heUniform: heUniform, identity: identity, leCunNormal: leCunNormal, leCunUniform: leCunUniform, ones: ones, orthogonal: orthogonal, randomNormal: randomNormal, randomUniform: randomUniform, truncatedNormal: truncatedNormal, varianceScaling: varianceScaling, zeros: zeros }; /** * Turn any Scalar values in a Logs object into actual number values. * * @param logs The `Logs` object to be resolved in place. */ function resolveScalarsInLogs(logs) { return __awaiter(this, void 0, void 0, function () { var promises, keys, scalarsToDispose, key, value, valueScalar, values, i; return __generator(this, function (_a) { switch (_a.label) { case 0: if (logs == null) { return [2 /*return*/]; } promises = []; keys = []; scalarsToDispose = []; for (key in logs) { value = logs[key]; if (typeof value !== 'number') { valueScalar = value; promises.push(valueScalar.data()); keys.push(key); scalarsToDispose.push(valueScalar); } } if (!(promises.length > 0)) return [3 /*break*/, 2]; return [4 /*yield*/, Promise.all(promises)]; case 1: values = _a.sent(); for (i = 0; i < values.length; ++i) { logs[keys[i]] = values[i][0]; } // Dispose the original scalar tensors. tfc.dispose(scalarsToDispose); _a.label = 2; case 2: return [2 /*return*/]; } }); }); } /** * Dispose all Tensors in an UnresolvedLogs object. * * @param logs An `UnresolvedLogs` object potentially containing `tf.Tensor`s in * places where the values can be `tf.Tensor` or `number`. 
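* * (Values that are not `number`s are assumed to be `tf.Tensor`s and are * disposed in place; plain numbers are left untouched.)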
*/ function disposeTensorsInLogs(logs) { if (logs == null) { return; } for (var key in logs) { var value = logs[key]; if (typeof value !== 'number') { value.dispose(); } } } /** Verbosity logging level when fitting a model. */ var ModelLoggingVerbosity; (function (ModelLoggingVerbosity) { ModelLoggingVerbosity[ModelLoggingVerbosity["SILENT"] = 0] = "SILENT"; ModelLoggingVerbosity[ModelLoggingVerbosity["VERBOSE"] = 1] = "VERBOSE"; })(ModelLoggingVerbosity || (ModelLoggingVerbosity = {})); /** How often to yield to the main thread when training (in ms). */ var DEFAULT_YIELD_EVERY_MS = 125; /** * Abstract base class used to build new callbacks. * * The `logs` dictionary that callback methods take as argument will contain * keys for quantities relevant to the current batch or epoch. * * Currently, the `.fit()` method of the `Sequential` model class * will include the following quantities in the `logs` that * it passes to its callbacks: * * onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss` * (if validation is enabled in `fit`), and `valAcc` (if validation and * accuracy monitoring are enabled). * onBatchBegin: Logs include `size`, the number of samples in the current * batch. * onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring * is enabled). */ var BaseCallback = /** @class */ (function () { function BaseCallback() { // TODO(michaelterry): This type is a best guess. this.validationData = null; } BaseCallback.prototype.setParams = function (params) { this.params = params; }; BaseCallback.prototype.onEpochBegin = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/]; }); }); }; BaseCallback.prototype.onEpochEnd = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/]; }); }); }; BaseCallback.prototype.onBatchBegin = function (batch, logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/]; }); }); }; BaseCallback.prototype.onBatchEnd = function (batch, logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/]; }); }); }; BaseCallback.prototype.onTrainBegin = function (logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/]; }); }); }; BaseCallback.prototype.onTrainEnd = function (logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/]; }); }); }; // LayersModel needs to call Callback.setModel(), but cannot actually depend // on Callback because that creates a cyclic dependency. Providing this no-op // method on BaseCallback breaks the cycle: this way LayersModel can depend on // BaseCallback but not on Callback. The argument is typed as `Container` // (the superclass of LayersModel) to avoid recapitulating the cycle. Callback // overrides this method and enforces that the argument is really a // LayersModel. BaseCallback.prototype.setModel = function (model) { // Do nothing. Use Callback instead of BaseCallback to track the model. }; return BaseCallback; }()); /** * Container abstracting a list of callbacks. */ var CallbackList = /** @class */ (function () { // TODO(cais): When the need arises, uncomment the following lines and // implement the queue for time values. 
// private deltaTBatch: number; // private deltaTsBatchBegin: Array; // private deltaTsBatchEnd: Array; /** * Constructor of CallbackList. * @param callbacks Array of `Callback` instances. * @param queueLength Queue length for keeping running statistics over * callback execution time. */ function CallbackList(callbacks, queueLength) { if (queueLength === void 0) { queueLength = 10; } // TODO(cais): Make use of queueLength when implementing the queue for time // values. if (callbacks == null) { callbacks = []; } this.callbacks = callbacks; this.queueLength = queueLength; } CallbackList.prototype.append = function (callback) { this.callbacks.push(callback); }; CallbackList.prototype.setParams = function (params) { var e_1, _a; try { for (var _b = __values(this.callbacks), _c = _b.next(); !_c.done; _c = _b.next()) { var callback = _c.value; callback.setParams(params); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_1) throw e_1.error; } } }; CallbackList.prototype.setModel = function (model) { var e_2, _a; try { for (var _b = __values(this.callbacks), _c = _b.next(); !_c.done; _c = _b.next()) { var callback = _c.value; callback.setModel(model); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_2) throw e_2.error; } } }; /** * Called at the start of an epoch. * @param epoch Index of epoch. * @param logs Dictionary of logs. */ CallbackList.prototype.onEpochBegin = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { var _a, _b, callback, e_3_1; var e_3, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: if (logs == null) { logs = {}; } _d.label = 1; case 1: _d.trys.push([1, 6, 7, 8]); _a = __values(this.callbacks), _b = _a.next(); _d.label = 2; case 2: if (!!_b.done) return [3 /*break*/, 5]; callback = _b.value; return [4 /*yield*/, callback.onEpochBegin(epoch, logs)]; case 3: _d.sent(); _d.label = 4; case 4: _b = _a.next(); return [3 /*break*/, 2]; case 5: return [3 /*break*/, 8]; case 6: e_3_1 = _d.sent(); e_3 = { error: e_3_1 }; return [3 /*break*/, 8]; case 7: try { if (_b && !_b.done && (_c = _a.return)) _c.call(_a); } finally { if (e_3) throw e_3.error; } return [7 /*endfinally*/]; case 8: return [2 /*return*/]; } }); }); }; /** * Called at the end of an epoch. * @param epoch Index of epoch. * @param logs Dictionary of logs. */ CallbackList.prototype.onEpochEnd = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { var _a, _b, callback, e_4_1; var e_4, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: if (logs == null) { logs = {}; } _d.label = 1; case 1: _d.trys.push([1, 6, 7, 8]); _a = __values(this.callbacks), _b = _a.next(); _d.label = 2; case 2: if (!!_b.done) return [3 /*break*/, 5]; callback = _b.value; return [4 /*yield*/, callback.onEpochEnd(epoch, logs)]; case 3: _d.sent(); _d.label = 4; case 4: _b = _a.next(); return [3 /*break*/, 2]; case 5: return [3 /*break*/, 8]; case 6: e_4_1 = _d.sent(); e_4 = { error: e_4_1 }; return [3 /*break*/, 8]; case 7: try { if (_b && !_b.done && (_c = _a.return)) _c.call(_a); } finally { if (e_4) throw e_4.error; } return [7 /*endfinally*/]; case 8: return [2 /*return*/]; } }); }); }; /** * Called right before processing a batch. * @param batch Index of batch within the current epoch. * @param logs Dictionary of logs. 
*/ CallbackList.prototype.onBatchBegin = function (batch, logs) { return __awaiter(this, void 0, void 0, function () { var _a, _b, callback, e_5_1; var e_5, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: if (logs == null) { logs = {}; } _d.label = 1; case 1: _d.trys.push([1, 6, 7, 8]); _a = __values(this.callbacks), _b = _a.next(); _d.label = 2; case 2: if (!!_b.done) return [3 /*break*/, 5]; callback = _b.value; return [4 /*yield*/, callback.onBatchBegin(batch, logs)]; case 3: _d.sent(); _d.label = 4; case 4: _b = _a.next(); return [3 /*break*/, 2]; case 5: return [3 /*break*/, 8]; case 6: e_5_1 = _d.sent(); e_5 = { error: e_5_1 }; return [3 /*break*/, 8]; case 7: try { if (_b && !_b.done && (_c = _a.return)) _c.call(_a); } finally { if (e_5) throw e_5.error; } return [7 /*endfinally*/]; case 8: return [2 /*return*/]; } }); }); }; /** * Called at the end of a batch. * @param batch Index of batch within the current epoch. * @param logs Dictionary of logs. */ CallbackList.prototype.onBatchEnd = function (batch, logs) { return __awaiter(this, void 0, void 0, function () { var _a, _b, callback, e_6_1; var e_6, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: if (logs == null) { logs = {}; } _d.label = 1; case 1: _d.trys.push([1, 6, 7, 8]); _a = __values(this.callbacks), _b = _a.next(); _d.label = 2; case 2: if (!!_b.done) return [3 /*break*/, 5]; callback = _b.value; return [4 /*yield*/, callback.onBatchEnd(batch, logs)]; case 3: _d.sent(); _d.label = 4; case 4: _b = _a.next(); return [3 /*break*/, 2]; case 5: return [3 /*break*/, 8]; case 6: e_6_1 = _d.sent(); e_6 = { error: e_6_1 }; return [3 /*break*/, 8]; case 7: try { if (_b && !_b.done && (_c = _a.return)) _c.call(_a); } finally { if (e_6) throw e_6.error; } return [7 /*endfinally*/]; case 8: return [2 /*return*/]; } }); }); }; /** * Called at the beginning of training. * @param logs Dictionary of logs. */ CallbackList.prototype.onTrainBegin = function (logs) { return __awaiter(this, void 0, void 0, function () { var _a, _b, callback, e_7_1; var e_7, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: if (logs == null) { logs = {}; } _d.label = 1; case 1: _d.trys.push([1, 6, 7, 8]); _a = __values(this.callbacks), _b = _a.next(); _d.label = 2; case 2: if (!!_b.done) return [3 /*break*/, 5]; callback = _b.value; return [4 /*yield*/, callback.onTrainBegin(logs)]; case 3: _d.sent(); _d.label = 4; case 4: _b = _a.next(); return [3 /*break*/, 2]; case 5: return [3 /*break*/, 8]; case 6: e_7_1 = _d.sent(); e_7 = { error: e_7_1 }; return [3 /*break*/, 8]; case 7: try { if (_b && !_b.done && (_c = _a.return)) _c.call(_a); } finally { if (e_7) throw e_7.error; } return [7 /*endfinally*/]; case 8: return [2 /*return*/]; } }); }); }; /** * Called at the end of training. * @param logs Dictionary of logs. 
*/ CallbackList.prototype.onTrainEnd = function (logs) { return __awaiter(this, void 0, void 0, function () { var _a, _b, callback, e_8_1; var e_8, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: if (logs == null) { logs = {}; } _d.label = 1; case 1: _d.trys.push([1, 6, 7, 8]); _a = __values(this.callbacks), _b = _a.next(); _d.label = 2; case 2: if (!!_b.done) return [3 /*break*/, 5]; callback = _b.value; return [4 /*yield*/, callback.onTrainEnd(logs)]; case 3: _d.sent(); _d.label = 4; case 4: _b = _a.next(); return [3 /*break*/, 2]; case 5: return [3 /*break*/, 8]; case 6: e_8_1 = _d.sent(); e_8 = { error: e_8_1 }; return [3 /*break*/, 8]; case 7: try { if (_b && !_b.done && (_c = _a.return)) _c.call(_a); } finally { if (e_8) throw e_8.error; } return [7 /*endfinally*/]; case 8: return [2 /*return*/]; } }); }); }; return CallbackList; }()); /** * Callback that accumulates epoch averages of metrics. * * This callback is automatically applied to every LayersModel. */ var BaseLogger = /** @class */ (function (_super) { __extends(BaseLogger, _super); function BaseLogger() { return _super.call(this) || this; } BaseLogger.prototype.onEpochBegin = function (epoch) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { this.seen = 0; this.totals = {}; return [2 /*return*/]; }); }); }; BaseLogger.prototype.onBatchEnd = function (batch, logs) { return __awaiter(this, void 0, void 0, function () { var batchSize, _loop_1, this_1, key; var _this = this; return __generator(this, function (_a) { if (logs == null) { logs = {}; } batchSize = logs['size'] == null ? 0 : logs['size']; this.seen += batchSize; _loop_1 = function (key) { var value = logs[key]; if (typeof value === 'number') { if (!this_1.totals.hasOwnProperty(key)) { this_1.totals[key] = 0; } this_1.totals[key] = this_1.totals[key] + value * batchSize; } else { var oldTotalsToDispose = void 0; if (key in this_1.totals) { oldTotalsToDispose = this_1.totals[key]; } else { this_1.totals[key] = 0; } var total = tfc.tidy(function () { return tfc.add((_this.totals[key]), tfc.mul(value, batchSize)); }); this_1.totals[key] = total; if (oldTotalsToDispose != null) { oldTotalsToDispose.dispose(); } } }; this_1 = this; for (key in logs) { _loop_1(key); } return [2 /*return*/]; }); }); }; BaseLogger.prototype.onEpochEnd = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { var _loop_2, this_2, _a, _b, key; var e_9, _c; var _this = this; return __generator(this, function (_d) { if (logs != null) { _loop_2 = function (key) { if (this_2.totals[key] == null) { return "continue"; } if (typeof this_2.totals[key] === 'number') { logs[key] = this_2.totals[key] / this_2.seen; } else { tfc.tidy(function () { var log = tfc.mul(tfc.div(1, _this.seen), _this.totals[key]); logs[key] = log; _this.totals[key].dispose(); tfc.keep(logs[key]); }); } }; this_2 = this; try { for (_a = __values(this.params['metrics']), _b = _a.next(); !_b.done; _b = _a.next()) { key = _b.value; _loop_2(key); } } catch (e_9_1) { e_9 = { error: e_9_1 }; } finally { try { if (_b && !_b.done && (_c = _a.return)) _c.call(_a); } finally { if (e_9) throw e_9.error; } } } return [2 /*return*/]; }); }); }; return BaseLogger; }(BaseCallback)); /** * Callback that records events into a `History` object. This callback is * automatically applied to every TF.js Layers model. The `History` object * gets returned by the `fit` method of models. 
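* * Illustrative usage (a sketch, assuming a compiled `model`): * `const history = await model.fit(xs, ys, {epochs: 5});` * `console.log(history.history['loss']); // one averaged loss per epoch`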
*/ var History = /** @class */ (function (_super) { __extends(History, _super); function History() { return _super !== null && _super.apply(this, arguments) || this; } History.prototype.onTrainBegin = function (logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { this.epoch = []; this.history = {}; return [2 /*return*/]; }); }); }; History.prototype.onEpochEnd = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { var key; return __generator(this, function (_a) { if (logs == null) { logs = {}; } this.epoch.push(epoch); for (key in logs) { if (this.history[key] == null) { this.history[key] = []; } this.history[key].push(logs[key]); } return [2 /*return*/]; }); }); }; /** * Await the values of all losses and metrics. */ History.prototype.syncData = function () { return __awaiter(this, void 0, void 0, function () { var promises, keys, indices, key, valueArray, i, valueScalar, values, n, tensorToDispose; return __generator(this, function (_a) { switch (_a.label) { case 0: promises = []; keys = []; indices = []; for (key in this.history) { valueArray = this.history[key]; for (i = 0; i < valueArray.length; ++i) { if (typeof valueArray[i] !== 'number') { valueScalar = valueArray[i]; promises.push(valueScalar.data()); keys.push(key); indices.push(i); } } } return [4 /*yield*/, Promise.all(promises)]; case 1: values = _a.sent(); for (n = 0; n < values.length; ++n) { tensorToDispose = this.history[keys[n]][indices[n]]; tensorToDispose.dispose(); this.history[keys[n]][indices[n]] = values[n][0]; } return [2 /*return*/]; } }); }); }; return History; }(BaseCallback)); /** * Custom callback for training. */ var CustomCallback = /** @class */ (function (_super) { __extends(CustomCallback, _super); function CustomCallback(args, yieldEvery) { var _this = _super.call(this) || this; _this.currentEpoch = 0; _this.nowFunc = args.nowFunc; _this.nextFrameFunc = args.nextFrameFunc || tfc.nextFrame; _this.yieldEvery = yieldEvery || 'auto'; if (_this.yieldEvery === 'auto') { _this.yieldEvery = DEFAULT_YIELD_EVERY_MS; } if (_this.yieldEvery === 'never' && args.onYield != null) { throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' + 'Either change `yieldEvery` or remove the callback'); } if (tfc.util.isNumber(_this.yieldEvery)) { // Decorate `maybeWait` so it will be called at most once every // `yieldEvery` ms. 
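// (Assumed behavior of the internal `debounce` helper: calls arriving within // the wait window return the last result without re-invoking the wrapped // function.)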
_this.maybeWait = debounce(_this.maybeWait.bind(_this), _this.yieldEvery, _this.nowFunc); } _this.trainBegin = args.onTrainBegin; _this.trainEnd = args.onTrainEnd; _this.epochBegin = args.onEpochBegin; _this.epochEnd = args.onEpochEnd; _this.batchBegin = args.onBatchBegin; _this.batchEnd = args.onBatchEnd; _this.yield = args.onYield; return _this; } CustomCallback.prototype.maybeWait = function (epoch, batch, logs) { return __awaiter(this, void 0, void 0, function () { var ps; return __generator(this, function (_a) { switch (_a.label) { case 0: ps = []; if (!(this.yield != null)) return [3 /*break*/, 2]; return [4 /*yield*/, resolveScalarsInLogs(logs)]; case 1: _a.sent(); ps.push(this.yield(epoch, batch, logs)); _a.label = 2; case 2: ps.push(this.nextFrameFunc()); return [4 /*yield*/, Promise.all(ps)]; case 3: _a.sent(); return [2 /*return*/]; } }); }); }; CustomCallback.prototype.onEpochBegin = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { switch (_a.label) { case 0: this.currentEpoch = epoch; if (!(this.epochBegin != null)) return [3 /*break*/, 3]; return [4 /*yield*/, resolveScalarsInLogs(logs)]; case 1: _a.sent(); return [4 /*yield*/, this.epochBegin(epoch, logs)]; case 2: _a.sent(); _a.label = 3; case 3: return [2 /*return*/]; } }); }); }; CustomCallback.prototype.onEpochEnd = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { var ps; return __generator(this, function (_a) { switch (_a.label) { case 0: ps = []; if (!(this.epochEnd != null)) return [3 /*break*/, 2]; return [4 /*yield*/, resolveScalarsInLogs(logs)]; case 1: _a.sent(); ps.push(this.epochEnd(epoch, logs)); _a.label = 2; case 2: if (this.yieldEvery === 'epoch') { ps.push(this.nextFrameFunc()); } return [4 /*yield*/, Promise.all(ps)]; case 3: _a.sent(); return [2 /*return*/]; } }); }); }; CustomCallback.prototype.onBatchBegin = function (batch, logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { switch (_a.label) { case 0: if (!(this.batchBegin != null)) return [3 /*break*/, 3]; return [4 /*yield*/, resolveScalarsInLogs(logs)]; case 1: _a.sent(); return [4 /*yield*/, this.batchBegin(batch, logs)]; case 2: _a.sent(); _a.label = 3; case 3: return [2 /*return*/]; } }); }); }; CustomCallback.prototype.onBatchEnd = function (batch, logs) { return __awaiter(this, void 0, void 0, function () { var ps; return __generator(this, function (_a) { switch (_a.label) { case 0: ps = []; if (!(this.batchEnd != null)) return [3 /*break*/, 2]; return [4 /*yield*/, resolveScalarsInLogs(logs)]; case 1: _a.sent(); ps.push(this.batchEnd(batch, logs)); _a.label = 2; case 2: if (this.yieldEvery === 'batch') { ps.push(this.nextFrameFunc()); } else if (tfc.util.isNumber(this.yieldEvery)) { ps.push(this.maybeWait(this.currentEpoch, batch, logs)); } return [4 /*yield*/, Promise.all(ps)]; case 3: _a.sent(); return [2 /*return*/]; } }); }); }; CustomCallback.prototype.onTrainBegin = function (logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { switch (_a.label) { case 0: if (!(this.trainBegin != null)) return [3 /*break*/, 3]; return [4 /*yield*/, resolveScalarsInLogs(logs)]; case 1: _a.sent(); return [4 /*yield*/, this.trainBegin(logs)]; case 2: _a.sent(); _a.label = 3; case 3: return [2 /*return*/]; } }); }); }; CustomCallback.prototype.onTrainEnd = function (logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function 
(_a) { switch (_a.label) { case 0: if (!(this.trainEnd != null)) return [3 /*break*/, 3]; return [4 /*yield*/, resolveScalarsInLogs(logs)]; case 1: _a.sent(); return [4 /*yield*/, this.trainEnd(logs)]; case 2: _a.sent(); _a.label = 3; case 3: return [2 /*return*/]; } }); }); }; return CustomCallback; }(BaseCallback)); /** * Standardize callbacks or configurations of them to an Array of callbacks. */ function standardizeCallbacks(callbacks, yieldEvery) { if (callbacks == null) { callbacks = {}; } if (callbacks instanceof BaseCallback) { return [callbacks]; } if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) { return callbacks; } // Convert custom callback configs to custom callback objects. var callbackConfigs = toList(callbacks); return callbackConfigs.map(function (callbackConfig) { return new CustomCallback(callbackConfig, yieldEvery); }); } /** * A global registry for callback constructors to be used during * LayersModel.fit(). */ var CallbackConstructorRegistry = /** @class */ (function () { /** * Blocks public access to constructor. */ function CallbackConstructorRegistry() { } /** * Register a tf.LayersModel.fit() callback constructor. * * The registered callback constructor will be used to instantiate * callbacks for every tf.LayersModel.fit() call afterwards. * * @param verbosityLevel Level of verbosity at which the `callbackConstructor` * is to be registered. * @param callbackConstructor A no-arg constructor for `tf.Callback`. * @throws Error, if the same callbackConstructor has been registered before, * either at the same or a different `verbosityLevel`. */ CallbackConstructorRegistry.registerCallbackConstructor = function (verbosityLevel, callbackConstructor) { tfc.util.assert(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), function () { return "Verbosity level is expected to be an integer >= 0, " + "but got ".concat(verbosityLevel); }); CallbackConstructorRegistry.checkForDuplicate(callbackConstructor); if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) { CallbackConstructorRegistry.constructors[verbosityLevel] = []; } CallbackConstructorRegistry.constructors[verbosityLevel].push(callbackConstructor); }; CallbackConstructorRegistry.checkForDuplicate = function (callbackConstructor) { for (var levelName in CallbackConstructorRegistry.constructors) { var constructors = CallbackConstructorRegistry.constructors[+levelName]; constructors.forEach(function (ctor) { if (ctor === callbackConstructor) { throw new ValueError('Duplicate callback constructor.'); } }); } }; /** * Clear all registered callback constructors. */ CallbackConstructorRegistry.clear = function () { CallbackConstructorRegistry.constructors = {}; }; /** * Create callbacks using the registered callback constructors. * * Given `verbosityLevel`, all constructors registered at or below that level * will be invoked and the instantiated callbacks will be used. * * @param verbosityLevel Level of verbosity.
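* * For example (a sketch): with constructors registered at levels 1 and 2, * `createCallbacks(1)` instantiates only those registered at level 1, while * `createCallbacks(2)` instantiates both.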
*/ CallbackConstructorRegistry.createCallbacks = function (verbosityLevel) { var constructors = []; for (var levelName in CallbackConstructorRegistry.constructors) { var level = +levelName; if (verbosityLevel >= level) { constructors.push.apply(constructors, __spreadArray([], __read(CallbackConstructorRegistry.constructors[level]), false)); } } return constructors.map(function (ctor) { return new ctor(); }); }; return CallbackConstructorRegistry; }()); CallbackConstructorRegistry.constructors = {}; function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) { var history = new History(); var actualCallbacks = __spreadArray([ new BaseLogger() ], __read(CallbackConstructorRegistry.createCallbacks(verbose)), false); if (callbacks != null) { actualCallbacks.push.apply(actualCallbacks, __spreadArray([], __read(callbacks), false)); } actualCallbacks.push(history); var callbackList = new CallbackList(actualCallbacks); // TODO(cais): Figure out when this LayersModel instance can have a // dynamically set property called 'callback_model' as in PyKeras. callbackList.setParams({ epochs: epochs, initialEpoch: initialEpoch, samples: numTrainSamples, steps: stepsPerEpoch, batchSize: batchSize, verbose: verbose, doValidation: doValidation, metrics: callbackMetrics, }); return { callbackList: callbackList, history: history }; } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Instantiate a layer from a config dictionary. * @param config dict of the form {class_name: str, config: dict} * @param customObjects dict mapping class names (or function names) * of custom (non-Keras) objects to class/functions * @param fastWeightInit Optional flag to use fast weight initialization * during deserialization. This is applicable to cases in which * the initialization will be immediately overwritten by loaded weight * values. Default: `false`. * @returns Layer instance (may be LayersModel, Sequential, Layer...) */ function deserialize(config, customObjects, fastWeightInit) { if (customObjects === void 0) { customObjects = {}; } if (fastWeightInit === void 0) { fastWeightInit = false; } return deserializeKerasObject(config, tfc.serialization.SerializationMap.getMap().classNameMap, customObjects, 'layer', fastWeightInit); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Normalizes a tensor with respect to the L2 norm along the specified axis. * @param x * @param axis Axis along which to perform normalization.
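* * (In effect this computes `x / sqrt(max(sum(x * x, axis), epsilon()))`; a * description of the code below rather than additional API.)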
*/ function l2Normalize(x, axis) { return tfc.tidy(function () { if (x.dtype !== 'float32') { x = tfc__namespace.cast(x, 'float32'); } var squareSum = tfc__namespace.sum(square$1(x), axis, true); var epsilonTensor = tfc__namespace.fill(squareSum.shape, epsilon()); var norm = tfc__namespace.sqrt(tfc__namespace.maximum(squareSum, epsilonTensor)); return tfc__namespace.div(x, norm); }); } function meanSquaredError$1(yTrue, yPred) { return tfc.tidy(function () { return tfc__namespace.mean(square$1(tfc__namespace.sub(yPred, yTrue)), -1); }); } function meanAbsoluteError$1(yTrue, yPred) { return tfc.tidy(function () { return tfc__namespace.mean(tfc__namespace.abs(tfc__namespace.sub(yPred, yTrue)), -1); }); } function meanAbsolutePercentageError$1(yTrue, yPred) { return tfc.tidy(function () { var diff = tfc__namespace.sub(yTrue, yPred); var clippedTrue = tfc__namespace.clipByValue(tfc__namespace.abs(yTrue), epsilon(), Number.MAX_VALUE); var absResult = tfc__namespace.abs(tfc__namespace.div(diff, clippedTrue)); return tfc__namespace.mul(100, tfc__namespace.mean(absResult, -1)); }); } function meanSquaredLogarithmicError(yTrue, yPred) { return tfc.tidy(function () { var clippedPred = tfc__namespace.clipByValue(yPred, epsilon(), Number.MAX_VALUE); var firstLog = tfc__namespace.log(tfc__namespace.add(1, clippedPred)); var clippedTrue = tfc__namespace.clipByValue(yTrue, epsilon(), Number.MAX_VALUE); var secondLog = tfc__namespace.log(tfc__namespace.add(1, clippedTrue)); return tfc__namespace.mean(square$1(tfc__namespace.sub(firstLog, secondLog)), -1); }); } function squaredHinge(yTrue, yPred) { return tfc.tidy(function () { var maxResult = tfc__namespace.maximum(0, tfc__namespace.sub(1, tfc__namespace.mul(yTrue, yPred))); return tfc__namespace.mean(square$1(maxResult), -1); }); } function hinge(yTrue, yPred) { return tfc.tidy(function () { var maxResult = tfc__namespace.maximum(0, tfc__namespace.sub(1, tfc__namespace.mul(yTrue, yPred))); return tfc__namespace.mean(maxResult, -1); }); } function categoricalHinge(yTrue, yPred) { return tfc.tidy(function () { var pos = tfc__namespace.sum(tfc__namespace.mul(yTrue, yPred), -1); var neg = tfc__namespace.max(tfc__namespace.mul(tfc__namespace.sub(1, yTrue), yPred), -1); return tfc__namespace.maximum(0, tfc__namespace.add(1, tfc__namespace.sub(neg, pos))); }); } /** * Logarithm of the hyperbolic cosine of the prediction error. * * `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and * to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly * like the mean squared error, but will not be so strongly affected by the * occasional wildly incorrect prediction. */ function logcosh(yTrue, yPred) { return tfc.tidy(function () { var log2 = Math.log(2); var predictionDiff = tfc__namespace.sub(yPred, yTrue); var logcoshResult = tfc__namespace.sub(tfc__namespace.add(predictionDiff, tfc__namespace.softplus(tfc__namespace.mul(-2, predictionDiff))), log2); return tfc__namespace.mean(logcoshResult, -1); }); } function categoricalCrossentropy$2(target, output, fromLogits) { if (fromLogits === void 0) { fromLogits = false; } return tfc.tidy(function () { if (fromLogits) { output = tfc__namespace.softmax(output); } else { // scale preds so that the class probabilities of each sample sum to 1. 
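// (i.e. divide by the per-sample sum taken along the last axis, computed // just below.)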
var outputSum = tfc__namespace.sum(output, output.shape.length - 1, true); output = tfc__namespace.div(output, outputSum); } output = tfc__namespace.clipByValue(output, epsilon(), 1 - epsilon()); return tfc__namespace.neg(tfc__namespace.sum(tfc__namespace.mul(tfc__namespace.cast(target, 'float32'), tfc__namespace.log(output)), output.shape.length - 1)); }); } /** * Categorical crossentropy with integer targets. * * @param target An integer tensor. * @param output A tensor resulting from a softmax (unless `fromLogits` is * `true`, in which case `output` is expected to be the logits). * @param fromLogits Boolean, whether `output` is the result of a softmax, or is * a tensor of logits. */ function sparseCategoricalCrossentropy$1(target, output, fromLogits) { if (fromLogits === void 0) { fromLogits = false; } return tfc.tidy(function () { var flatTarget = tfc__namespace.cast(tfc__namespace.floor(flatten$2(target)), 'int32'); output = tfc__namespace.clipByValue(output, epsilon(), 1 - epsilon()); var outputShape = output.shape; var oneHotTarget = tfc__namespace.reshape(tfc__namespace.oneHot(flatTarget, outputShape[outputShape.length - 1]), outputShape); return categoricalCrossentropy$2(oneHotTarget, output, fromLogits); }); } /** * From TensorFlow's implementation in nn_impl.py: * * For brevity, let `x = logits`, `z = labels`. The logistic loss is * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))) * = (1 - z) * x + log(1 + exp(-x)) * = x - x * z + log(1 + exp(-x)) * For x < 0, to avoid overflow in exp(-x), we reformulate the above * x - x * z + log(1 + exp(-x)) * = log(exp(x)) - x * z + log(1 + exp(-x)) * = - x * z + log(1 + exp(x)) * Hence, to ensure stability and avoid overflow, the implementation uses this * equivalent formulation * max(x, 0) - x * z + log(1 + exp(-abs(x))) * * @param labels The labels. * @param logits The logits. 
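 *
 * (Editorial check, not part of the original source: for x = -10, z = 1 the
 * stable form gives max(-10, 0) - (-10 * 1) + log(1 + exp(-10))
 * = 10 + log(1 + exp(-10)) ≈ 10.0000454, which matches -log(sigmoid(-10))
 * without ever evaluating exp(10).)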
*/ function sigmoidCrossEntropyWithLogits(labels, logits) { if (!tfc.util.arraysEqual(labels.shape, logits.shape)) { throw new ValueError("logits and labels must have the same shape, but got shapes " + "".concat(JSON.stringify(labels.shape), " and ").concat(JSON.stringify(logits.shape))); } return tfc.tidy(function () { // The logistic loss formula from above is // x - x * z + log(1 + exp(-x)) // For x < 0, a more numerically stable formula is // -x * z + log(1 + exp(x)) // Note that these two expressions can be combined into the following: // max(x, 0) - x * z + log(1 + exp(-abs(x))) var reluLogits = tfc__namespace.relu(logits); var negAbsLogits = tfc__namespace.neg(tfc__namespace.abs(logits)); return tfc__namespace.add(tfc__namespace.sub(reluLogits, tfc__namespace.mul(logits, labels)), tfc__namespace.log1p(tfc__namespace.exp(negAbsLogits))); }); } function binaryCrossentropy$2(yTrue, yPred) { return tfc.tidy(function () { var y; y = tfc__namespace.clipByValue(yPred, epsilon(), 1 - epsilon()); y = tfc__namespace.log(tfc__namespace.div(y, tfc__namespace.sub(1, y))); return tfc__namespace.mean(sigmoidCrossEntropyWithLogits(yTrue, y), -1); }); } function kullbackLeiblerDivergence(yTrue, yPred) { return tfc.tidy(function () { var clippedTrue = tfc__namespace.clipByValue(yTrue, epsilon(), 1); var clippedPred = tfc__namespace.clipByValue(yPred, epsilon(), 1); return tfc__namespace.sum(tfc__namespace.mul(yTrue, tfc__namespace.log(tfc__namespace.div(clippedTrue, clippedPred))), -1); }); } function poisson(yTrue, yPred) { return tfc.tidy(function () { var logPred = tfc__namespace.log(tfc__namespace.add(epsilon(), yPred)); return tfc__namespace.mean(tfc__namespace.sub(yPred, tfc__namespace.mul(yTrue, logPred)), -1); }); } function cosineProximity$1(yTrue, yPred) { return tfc.tidy(function () { var trueNormalized = l2Normalize(yTrue, -1); var predNormalized = l2Normalize(yPred, -1); var trueXPred = tfc__namespace.mul(trueNormalized, predNormalized); return tfc__namespace.neg(tfc__namespace.sum(trueXPred, -1)); }); } // TODO(michaelterry): Add deserialize() function. var lossesMap = { meanSquaredError: meanSquaredError$1, meanAbsoluteError: meanAbsoluteError$1, meanAbsolutePercentageError: meanAbsolutePercentageError$1, meanSquaredLogarithmicError: meanSquaredLogarithmicError, squaredHinge: squaredHinge, hinge: hinge, categoricalHinge: categoricalHinge, logcosh: logcosh, categoricalCrossentropy: categoricalCrossentropy$2, sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1, binaryCrossentropy: binaryCrossentropy$2, kullbackLeiblerDivergence: kullbackLeiblerDivergence, poisson: poisson, cosineProximity: cosineProximity$1 }; // Porting note: This diverges from the PyKeras implementation and may need to // change based on (de)serialization requirements. function get$1(identifierOrFn) { if (typeof identifierOrFn === 'string') { if (identifierOrFn in lossesMap) { return lossesMap[identifierOrFn]; } var errMsg = "Unknown loss ".concat(identifierOrFn); if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) { errMsg = "Unknown loss ".concat(identifierOrFn, ". 
") + 'Use "categoricalCrossentropy" as the string name for ' + 'tf.losses.softmaxCrossEntropy'; } throw new ValueError(errMsg); } else { return identifierOrFn; } } function binaryAccuracy$1(yTrue, yPred) { return tfc.tidy(function () { var threshold = tfc__namespace.mul(.5, tfc__namespace.onesLike(yPred)); var yPredThresholded = cast$1(tfc__namespace.greater(yPred, threshold), yTrue.dtype); return tfc__namespace.mean(tfc__namespace.equal(yTrue, yPredThresholded), -1); }); } function categoricalAccuracy$1(yTrue, yPred) { return tfc.tidy(function () { return cast$1(tfc__namespace.equal(tfc__namespace.argMax(yTrue, -1), tfc__namespace.argMax(yPred, -1)), 'float32'); }); } function truePositives(yTrue, yPred) { return tfc.tidy(function () { return tfc__namespace.cast(tfc__namespace.sum(tfc__namespace.logicalAnd(tfc__namespace.equal(yTrue, 1), tfc__namespace.equal(yPred, 1))), 'float32'); }); } function falseNegatives(yTrue, yPred) { return tfc.tidy(function () { return tfc__namespace.cast(tfc__namespace.sum(tfc__namespace.logicalAnd(tfc__namespace.equal(yTrue, 1), tfc__namespace.equal(yPred, 0))), 'float32'); }); } function falsePositives(yTrue, yPred) { return tfc.tidy(function () { return tfc__namespace.cast(tfc__namespace.sum(tfc__namespace.logicalAnd(tfc__namespace.equal(yTrue, 0), tfc__namespace.equal(yPred, 1))), 'float32'); }); } function precision$1(yTrue, yPred) { return tfc.tidy(function () { var tp = truePositives(yTrue, yPred); var fp = falsePositives(yTrue, yPred); var denominator = tfc__namespace.add(tp, fp); return tfc__namespace.cast(tfc__namespace.where(tfc__namespace.greater(denominator, 0), tfc__namespace.div(tp, denominator), 0), 'float32'); }); } function recall$1(yTrue, yPred) { return tfc.tidy(function () { var tp = truePositives(yTrue, yPred); var fn = falseNegatives(yTrue, yPred); var denominator = tfc__namespace.add(tp, fn); return tfc__namespace.cast(tfc__namespace.where(tfc__namespace.greater(denominator, 0), tfc__namespace.div(tp, denominator), 0), 'float32'); }); } function binaryCrossentropy$1(yTrue, yPred) { return binaryCrossentropy$2(yTrue, yPred); } function sparseCategoricalAccuracy$1(yTrue, yPred) { if (yTrue.rank === yPred.rank) { yTrue = tfc__namespace.squeeze(yTrue, [yTrue.rank - 1]); } yPred = tfc__namespace.argMax(yPred, -1); if (yPred.dtype !== yTrue.dtype) { yPred = tfc__namespace.cast(yPred, yTrue.dtype); } return tfc__namespace.cast(tfc__namespace.equal(yTrue, yPred), 'float32'); } // Aliases. var mse$1 = meanSquaredError$1; var MSE$1 = meanSquaredError$1; var mae = meanAbsoluteError$1; var MAE = meanAbsoluteError$1; var mape$1 = meanAbsolutePercentageError$1; var MAPE$1 = meanAbsolutePercentageError$1; var categoricalCrossentropy$1 = categoricalCrossentropy$2; var cosine = cosineProximity$1; var sparseCategoricalCrossentropy = sparseCategoricalCrossentropy$1; // TODO(cais, nielsene): Add serialize(). var metricsMap = { binaryAccuracy: binaryAccuracy$1, categoricalAccuracy: categoricalAccuracy$1, precision: precision$1, categoricalCrossentropy: categoricalCrossentropy$1, sparseCategoricalCrossentropy: sparseCategoricalCrossentropy, mse: mse$1, MSE: MSE$1, mae: mae, MAE: MAE, mape: mape$1, MAPE: MAPE$1, cosine: cosine }; function get(identifier) { if (typeof identifier === 'string' && identifier in metricsMap) { return metricsMap[identifier]; } else if (typeof identifier !== 'string' && identifier != null) { return identifier; } else { throw new ValueError("Unknown metric ".concat(identifier)); } } /** * Get the shortcut function name. 
 * * If the fn name is a string, * directly return the string name. * If the function is included in metricsMap or lossesMap, * return the key of the map. * - If the function maps to multiple keys, * return the first found key as the function name. * - If the function exists in both lossesMap and metricsMap, * search lossesMap first. * If the function is not included in metricsMap or lossesMap, * return the function name. * * @param fn loss function, metric function, or shortcut name. * @returns Loss or Metric name as a string. */ function getLossOrMetricName(fn) { var e_1, _a, e_2, _b; assert$1(fn !== null, "Unknown LossOrMetricFn ".concat(fn)); if (typeof fn === 'string') { return fn; } else { var fnName = void 0; try { for (var _c = __values(Object.keys(lossesMap)), _d = _c.next(); !_d.done; _d = _c.next()) { var key = _d.value; if (lossesMap[key] === fn) { fnName = key; break; } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_1) throw e_1.error; } } if (fnName !== undefined) { return fnName; } try { for (var _e = __values(Object.keys(metricsMap)), _f = _e.next(); !_f.done; _f = _e.next()) { var key = _f.value; if (metricsMap[key] === fn) { fnName = key; break; } } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_f && !_f.done && (_b = _e.return)) _b.call(_e); } finally { if (e_2) throw e_2.error; } } if (fnName !== undefined) { return fnName; } return fn.name; } } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ // Add (de)serialize() // Porting note: This diverges from the PyKeras implementation and may need to // change based on (de)serialization requirements. function getOptimizer(identifier) { var optimizerMap = { 'Adagrad': function () { return tfc.train.adagrad(0.01); }, 'Adadelta': function () { return tfc.train.adadelta(1, 0.95, epsilon()); }, 'Adam': function () { return tfc.train.adam(0.001, 0.9, 0.999, epsilon()); }, 'Adamax': function () { return tfc.train.adamax(0.002, 0.9, 0.999, epsilon(), 0); }, 'RMSProp': function () { return tfc.train.rmsprop(0.001, 0.9, 0, epsilon()); }, 'SGD': function () { return tfc.train.sgd(0.01); } }; optimizerMap['adagrad'] = optimizerMap['Adagrad']; optimizerMap['adadelta'] = optimizerMap['Adadelta']; optimizerMap['adam'] = optimizerMap['Adam']; optimizerMap['adamax'] = optimizerMap['Adamax']; optimizerMap['rmsprop'] = optimizerMap['RMSProp']; optimizerMap['sgd'] = optimizerMap['SGD']; if (identifier in optimizerMap) { return optimizerMap[identifier](); } throw new ValueError("Unknown Optimizer ".concat(identifier)); } /** * @license * Copyright 2019 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** Utility functions related to user-defined metadata. */ // Maximum recommended serialized size for user-defined metadata. // Beyond this limit, a warning message will be printed during model loading and // saving. var MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024; /** * Check validity of user-defined metadata. * * @param userDefinedMetadata * @param modelName Name of the model that the user-defined metadata belongs to. 
* Used during construction of error messages. * @param checkSize Whether to check that the size of the metadata is under the * recommended limit. Default: `false`. If `true`, will try to stringify the * JSON object and print a console warning if the serialized size is above the * limit. * @throws Error if `userDefinedMetadata` is not a plain JSON object. */ function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize) { if (checkSize === void 0) { checkSize = false; } if (userDefinedMetadata == null || typeof userDefinedMetadata !== 'object' || Object.getPrototypeOf(userDefinedMetadata) !== Object.prototype || !plainObjectCheck(userDefinedMetadata)) { throw new Error('User-defined metadata is expected to be a JSON object, but is not.'); } if (checkSize) { var out = JSON.stringify(userDefinedMetadata); if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) { console.warn("User-defined metadata of model \"".concat(modelName, "\" is too large in ") + "size (length=".concat(out.length, " when serialized). It is not ") + "recommended to store such large objects in user-defined metadata. " + "Please make sure its serialized length is <= " + "".concat(MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH, ".")); } } } /** * Check if an input is a plain JSON object or any valid subfield of it. * * @param x The input to be checked. * @param assertObject Whether to assert `x` is a JSON object, i.e., reject * cases of arrays and primitives. * @return Returns `true` if and only if `x` is a plain JSON object, * a JSON-valid primitive including string, number, boolean and null, * or an array of the said types. */ // tslint:disable-next-line:no-any function plainObjectCheck(x) { var e_1, _a, e_2, _b; if (x === null) { // Note: typeof `null` is 'object', and `null` is valid in JSON. return true; } else if (typeof x === 'object') { if (Object.getPrototypeOf(x) === Object.prototype) { // `x` is a JavaScript object and its prototype is Object. var keys = Object.keys(x); try { for (var keys_1 = __values(keys), keys_1_1 = keys_1.next(); !keys_1_1.done; keys_1_1 = keys_1.next()) { var key = keys_1_1.value; if (typeof key !== 'string') { // JSON keys must be strings. return false; } if (!plainObjectCheck(x[key])) { // Recursive call. return false; } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (keys_1_1 && !keys_1_1.done && (_a = keys_1.return)) _a.call(keys_1); } finally { if (e_1) throw e_1.error; } } return true; } else { // `x` is a JavaScript object but its prototype is not Object. if (Array.isArray(x)) { try { // `x` is a JavaScript array. for (var x_1 = __values(x), x_1_1 = x_1.next(); !x_1_1.done; x_1_1 = x_1.next()) { var item = x_1_1.value; if (!plainObjectCheck(item)) { // Recursive call. return false; } } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (x_1_1 && !x_1_1.done && (_b = x_1.return)) _b.call(x_1); } finally { if (e_2) throw e_2.error; } } return true; } else { // `x` is a JavaScript object and its prototype is not Object, // and it's not an Array. I.e., it's a complex object such as // `Error` or `Date`. return false; } } } else { // `x` is not a JavaScript object or `null`. var xType = typeof x; return xType === 'string' || xType === 'number' || xType === 'boolean'; } } /** * Print the summary of a LayersModel object. * * @param model tf.LayersModel instance. * @param lineLength Total length of printed lines. Set this to adapt the * display to different terminal or console sizes. 
* @param positions Relative or absolute positions of log elements in each * line. Each number corresponds to the right-most (i.e., ending) position of a * column. * If not provided, defaults to `[0.32, 0.61, 0.89, 1]` for sequential-like * models and `[0.24, 0.48, 0.70, 0.80, 1]` for non-sequential-like models. * @param printFn Print function to use. * It will be called on each line of the summary. You can provide a custom * function in order to capture the string summary. Defaults to `console.log`. */ function printSummary(model, lineLength, positions, // tslint:disable-next-line:no-any printFn) { if (printFn === void 0) { printFn = console.log; } var sequentialLike = isModelSequentialLike(model); // Header names for different log elements. var toDisplay = ['Layer (type)', 'Input Shape', 'Output shape', 'Param #']; if (sequentialLike) { lineLength = lineLength || 90; positions = positions || [0.32, 0.61, 0.89, 1]; } else { lineLength = lineLength || 115; positions = positions || [0.24, 0.48, 0.70, 0.80, 1]; } if (positions[positions.length - 1] <= 1) { // `positions` is relative. Convert it to absolute positioning. positions = positions.map(function (p) { return Math.floor(lineLength * p); }); } var relevantNodes; if (!sequentialLike) { toDisplay.push('Receives inputs'); relevantNodes = []; for (var depth in model.nodesByDepth) { relevantNodes.push.apply(relevantNodes, __spreadArray([], __read(model.nodesByDepth[depth]), false)); } } printFn('_'.repeat(lineLength)); printRow(toDisplay, positions, printFn); printFn('='.repeat(lineLength)); var layers = model.layers; for (var i = 0; i < layers.length; ++i) { if (sequentialLike) { printLayerSummary(layers[i], positions, printFn); } else { printLayerSummaryWithConnections(layers[i], positions, relevantNodes, printFn); } printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength)); } // tslint:disable-next-line:no-any model.checkTrainableWeightsConsistency(); var trainableCount = countTrainableParams(model); var nonTrainableCount = countParamsInWeights(model.nonTrainableWeights); printFn("Total params: ".concat(trainableCount + nonTrainableCount)); printFn("Trainable params: ".concat(trainableCount)); printFn("Non-trainable params: ".concat(nonTrainableCount)); printFn('_'.repeat(lineLength)); } function countTrainableParams(model) { var trainableCount; // tslint:disable:no-any if (model.collectedTrainableWeights != null) { trainableCount = countParamsInWeights(model.collectedTrainableWeights); } else { trainableCount = countParamsInWeights(model.trainableWeights); } // tslint:enable:no-any return trainableCount; } function isModelSequentialLike(model) { var e_1, _a, e_2, _b, e_3, _c; var sequentialLike = true; var nodesByDepth = []; var nodes = []; for (var depth in model.nodesByDepth) { nodesByDepth.push(model.nodesByDepth[depth]); } try { for (var nodesByDepth_1 = __values(nodesByDepth), nodesByDepth_1_1 = nodesByDepth_1.next(); !nodesByDepth_1_1.done; nodesByDepth_1_1 = nodesByDepth_1.next()) { var depthNodes = nodesByDepth_1_1.value; if (depthNodes.length > 1 || depthNodes.length === 1 && depthNodes[0].inboundLayers.length > 1) { sequentialLike = false; break; } nodes.push.apply(nodes, __spreadArray([], __read(depthNodes), false)); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (nodesByDepth_1_1 && !nodesByDepth_1_1.done && (_a = nodesByDepth_1.return)) _a.call(nodesByDepth_1); } finally { if (e_1) throw e_1.error; } } if (sequentialLike) { try { // Search for shared layers. 
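// (Editorial note, not part of the original source: a shared layer, i.e. one
// Layer instance applied at more than one call site, contributes several of
// its inboundNodes to `nodes`; the `flag` below trips on the second match
// and marks the model as non-sequential-like.)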
for (var _d = __values(model.layers), _e = _d.next(); !_e.done; _e = _d.next()) { var layer = _e.value; var flag = false; try { for (var _f = (e_3 = void 0, __values(layer.inboundNodes)), _g = _f.next(); !_g.done; _g = _f.next()) { var node = _g.value; if (nodes.indexOf(node) !== -1) { if (flag) { sequentialLike = false; break; } else { flag = true; } } } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (_g && !_g.done && (_c = _f.return)) _c.call(_f); } finally { if (e_3) throw e_3.error; } } if (!sequentialLike) { break; } } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_e && !_e.done && (_b = _d.return)) _b.call(_d); } finally { if (e_2) throw e_2.error; } } } return sequentialLike; } function printRow(fields, positions, // tslint:disable-next-line:no-any printFn) { if (printFn === void 0) { printFn = console.log; } var line = ''; for (var i = 0; i < fields.length; ++i) { if (i > 0) { line = line.slice(0, line.length - 1) + ' '; } line += fields[i]; line = line.slice(0, positions[i]); line += ' '.repeat(positions[i] - line.length); } printFn(line); } /** * Prints a summary for a single Layer, without connectivity information. * * @param layer: Layer instance to print. */ function printLayerSummary(layer, positions, // tslint:disable-next-line:no-any printFn) { var outputShape; var inputShape; try { inputShape = (layer.inboundNodes.map(function (x) { return JSON.stringify(x.inputShapes); })).join(','); } catch (err) { inputShape = 'multiple'; } try { outputShape = JSON.stringify(layer.outputShape); } catch (err) { outputShape = 'multiple'; } var name = layer.name; var className = layer.getClassName(); var fields = ["".concat(name, " (").concat(className, ")"), inputShape, outputShape, layer.countParams().toString()]; printRow(fields, positions, printFn); } /** * Prints a summary for a single Layer, with connectivity information. */ function printLayerSummaryWithConnections(layer, positions, relevantNodes, // tslint:disable-next-line:no-any printFn) { var e_4, _a; var outputShape; var inputShape; try { inputShape = (layer.inboundNodes.map(function (x) { return JSON.stringify(x.inputShapes); })).join(','); } catch (err) { inputShape = 'multiple'; } try { outputShape = JSON.stringify(layer.outputShape); } catch (err) { outputShape = 'multiple'; } var connections = []; try { for (var _b = __values(layer.inboundNodes), _c = _b.next(); !_c.done; _c = _b.next()) { var node = _c.value; if (relevantNodes != null && relevantNodes.length > 0 && relevantNodes.indexOf(node) === -1) { continue; } for (var i = 0; i < node.inboundLayers.length; ++i) { var inboundLayer = node.inboundLayers[i].name; var inboundLayerIndex = node.nodeIndices[i]; var inboundTensorIndex = node.tensorIndices[i]; connections.push("".concat(inboundLayer, "[").concat(inboundLayerIndex, "][").concat(inboundTensorIndex, "]")); } } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_4) throw e_4.error; } } var name = layer.name; var className = layer.getClassName(); var firstConnection = connections.length === 0 ? '' : connections[0]; var fields = [ "".concat(name, " (").concat(className, ")"), inputShape, outputShape, layer.countParams().toString(), firstConnection ]; printRow(fields, positions, printFn); for (var i = 1; i < connections.length; ++i) { printRow(['', '', '', '', connections[i]], positions, printFn); } } // tslint:enable /** * Test whether a value in an array is the name of a LayersModel or Layer. 
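 * (Editorial example, not part of the original docs: in a serialized config,
 * an `inboundNodes` entry looks like `[["dense_1", 0, 0]]`; the leading
 * string "dense_1" is a layer name, so it must be passed through verbatim
 * rather than case-converted.)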
* @param key The key name that the value is found under. Note that the key * may not be at the level immediately above the value, if the value is in a * nested array. * @param index Index of the value in the Array that it is found in. * @param value The value object. * @returns A boolean indicating whether value is a name. */ function isArrayItemInputOrOutputName(key, index, value) { return (key === 'inboundNodes' || key === 'outputLayers' || key === 'inputLayers') && index === 0 && typeof value === 'string'; } /** * Convert a Pythonic config object to TypeScript config object. * @param pythonicConfig The config object to convert. * @param key Optional key name of the object being converted. * @returns Result of the conversion. */ function convertPythonicToTs(pythonicConfig, key) { var e_1, _a; if (pythonicConfig === null) { return null; } else if (typeof pythonicConfig === 'string') { return toCamelCase(pythonicConfig); } else if ((typeof pythonicConfig === 'number') || (typeof pythonicConfig === 'boolean')) { return pythonicConfig; } else if (pythonicConfig instanceof Array) { var tsArray = []; var arrayLength = pythonicConfig.length; for (var i = 0; i < arrayLength; ++i) { var item = pythonicConfig[i]; if (isArrayItemInputOrOutputName(key, i, item)) { tsArray.push(item); } else { tsArray.push(convertPythonicToTs(item, key)); } } return tsArray; } else { var tsDict = {}; try { for (var _b = __values(Object.keys(pythonicConfig)), _c = _b.next(); !_c.done; _c = _b.next()) { var pythonicKey = _c.value; var pythonicValue = pythonicConfig[pythonicKey]; if (pythonicKey === 'name' && typeof pythonicValue === 'string') { // Special case the 'name' key with a string value. Name values, such as // the names of LayersModel and Layer instances, should not undergo the // camel-case conversion. tsDict[pythonicKey] = pythonicValue; } else { var tsKey = toCamelCase(pythonicKey); tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_1) throw e_1.error; } } return tsDict; } } /** * Convert a TypeScript config object to Python config object. * @param tsConfig The config object to convert. * @param key Optional key name of the object being converted. * @returns Result of the conversion. */ function convertTsToPythonic(tsConfig, key) { var e_2, _a; if (tsConfig === null || tsConfig === undefined) { return null; } else if (typeof tsConfig === 'string') { return toSnakeCase(tsConfig); } else if ((typeof tsConfig === 'number') || (typeof tsConfig === 'boolean')) { return tsConfig; } else if (tsConfig instanceof Array) { var pyArray = []; var arrayLength = tsConfig.length; for (var i = 0; i < arrayLength; ++i) { var item = tsConfig[i]; if (isArrayItemInputOrOutputName(key, i, item)) { pyArray.push(item); } else { pyArray.push(convertTsToPythonic(item, key)); } } return pyArray; } else { var pyDict = {}; try { for (var _b = __values(Object.keys(tsConfig)), _c = _b.next(); !_c.done; _c = _b.next()) { var tsKey = _c.value; var tsValue = tsConfig[tsKey]; var pyKey = toSnakeCase(tsKey); if ((tsKey === 'name' || tsKey === 'className') && typeof tsValue === 'string') { // Special case the 'name' key with a string value. Name values, such as // the names of LayersModel and Layer instances, should not undergo the // snake-case conversion. 
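// (Editorial example, not part of the original source: a TS config such as
// {className: 'Dense', name: 'dense_1', batchInputShape: [null, 4]} becomes
// {class_name: 'Dense', name: 'dense_1', batch_input_shape: [null, 4]};
// the keys are snake-cased, while the values 'Dense' and 'dense_1' survive
// untouched via this branch.)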
pyDict[pyKey] = tsValue; } else { pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey); } } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_2) throw e_2.error; } } return pyDict; } } /** @license See the LICENSE file. */ // This code is auto-generated, do not modify this file! var version = '4.15.0'; // get weights key from tensor map in order to check if it is from keras v3. // e.g. dense/0 var isKerasSavedModelFormat = function (weights) { var keys = Object.keys(weights); if (keys.length === 0) { return false; } var key = keys[0].split('/'); return !isNaN(parseInt(key[key.length - 1], 10)); }; /** * A Container is a directed acyclic graph of layers. * * It is the topological form of a "model". A LayersModel * is simply a Container with added training routines. * */ var Container = /** @class */ (function (_super) { __extends(Container, _super); function Container(args) { var e_1, _a, e_2, _b, e_3, _c, e_4, _d, e_5, _e, e_6, _f, e_7, _g, e_8, _h, e_9, _j, e_10, _k, e_11, _l, e_12, _m; var _this = // No args passed to super's constructor. _super.call(this, {}) || this; _this.containerNodes = new Set(); _this.name = args.name; if (_this.name == null) { var prefix = _this.getClassName().toLowerCase(); _this.name = getUid(prefix); } _this.supportsMasking = false; _this.trainable_ = true; // TODO(michaelterry): Initialize perInputLosses/Updates here. // Container-specific properties. if (Array.isArray(args.inputs)) { _this.inputs = args.inputs.slice(); } else { _this.inputs = [args.inputs]; } if (Array.isArray(args.outputs)) { _this.outputs = args.outputs.slice(); } else { _this.outputs = [args.outputs]; } // Check for redundancy in inputs. if (unique(_this.inputs).length !== _this.inputs.length) { throw new ValueError('The list of inputs passed to the model is ' + 'redundant. All inputs should only appear once. Found: ' + "".concat(_this.inputs.map(function (x) { return x.name; }))); } // Check for redundancy in outputs. if (unique(_this.outputs).length !== _this.outputs.length) { console.warn('The list of outputs passed to the model is redundant. ' + 'All outputs should only appear once. Found: ' + "".concat(_this.outputs.map(function (x) { return x.name; }))); } /* List of initial layers (1 to 1 mapping with this.inputs, hence the same layer might appear twice) */ _this.inputLayers = []; _this.inputLayersNodeIndices = []; _this.inputLayersTensorIndices = []; /* List of layers (1 to 1 mapping with this.outputs, hence the same layer might appear twice) */ _this.outputLayers = []; _this.outputLayersNodeIndices = []; _this.outputLayersTensorIndices = []; /* All layers in order of horizontal graph traversal. Entries are unique. Includes input and output layers. */ _this.layers = []; /* References to container layers that were constructed internally. We need these to properly dispose of tensors from nested containers. */ _this.internalContainerRefs = []; try { // TODO(michaelterry): Determine if caching still needed with eager // backend. /* This is for performance optimization when calling the Container on new inputs. Every time the Container is called on a set on input tensors, we compute the output tensors, output masks and output shapes in one pass, then cache them here. When one of these outputs is queried later, we retrieve it from there instead of recomputing it. 
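Editorial note (not part of the original source): each symbolic tensor x
carries a (sourceLayer, nodeIndex, tensorIndex) triple, and that triple is
exactly what the loops below record for the container's outputs and inputs.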
*/ // this.outputTensorCache = {}; // this.outputShapeCache = {}; // Build this.outputLayers: for (var _o = __values(_this.outputs), _p = _o.next(); !_p.done; _p = _o.next()) { var x = _p.value; var layer = x.sourceLayer; var nodeIndex = x.nodeIndex; var tensorIndex = x.tensorIndex; _this.outputLayers.push(layer); _this.outputLayersNodeIndices.push(nodeIndex); _this.outputLayersTensorIndices.push(tensorIndex); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_p && !_p.done && (_a = _o.return)) _a.call(_o); } finally { if (e_1) throw e_1.error; } } try { // TODO(michaelterry): Add output mask cache code. // Build this.inputLayers: for (var _q = __values(_this.inputs), _r = _q.next(); !_r.done; _r = _q.next()) { var x = _r.value; var layer = x.sourceLayer; var nodeIndex = x.nodeIndex; var tensorIndex = x.tensorIndex; /* It's supposed to be an input layer, so only one node and one tensor output. */ assert$1(nodeIndex === 0, 'input layer has >1 nodes'); assert$1(tensorIndex === 0, 'input layer has >1 tensors'); _this.inputLayers.push(layer); _this.inputLayersNodeIndices.push(nodeIndex); _this.inputLayersTensorIndices.push(tensorIndex); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_r && !_r.done && (_b = _q.return)) _b.call(_q); } finally { if (e_2) throw e_2.error; } } // Build this.inputNames and this.outputNames. _this.inputNames = []; _this.outputNames = []; _this.feedInputShapes = []; _this.feedInputNames = []; _this.feedOutputNames = []; for (var i = 0; i < _this.inputLayers.length; i++) { var layer = _this.inputLayers[i]; // Check that layer is an InputLayer. if (!(layer instanceof InputLayer)) { throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' + "Received inputs: ".concat(args.inputs, ". ") + "Input ".concat(i, " (0-based) originates ") + "from layer type ".concat(layer.getClassName(), ".")); } _this.inputNames.push(layer.name); _this.feedInputShapes.push(layer.batchInputShape); _this.feedInputNames.push(layer.name); } try { for (var _s = __values(_this.outputLayers), _t = _s.next(); !_t.done; _t = _s.next()) { var layer = _t.value; _this.outputNames.push(layer.name); } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (_t && !_t.done && (_c = _s.return)) _c.call(_s); } finally { if (e_3) throw e_3.error; } } _this.internalInputShapes = _this.inputs.map(function (x) { return x.shape; }); _this.internalOutputShapes = _this.outputs.map(function (x) { return x.shape; }); /* Container_nodes: set of nodes included in the graph (not all nodes included in the layers are relevant to the current graph). */ // ids of all nodes relevant to the Container: var nodesDepths = {}; // To recover nodes from their ID. var nodeIDToNode = {}; var layersDepths = {}; // To layers from their ID. var layerIDToLayer = {}; var layerIndices = {}; var nodesInDecreasingDepth = []; /** * Builds a map of the graph of layers. * * This recursively updates the map `layerIndices`, * the list `nodesInDecreasingDepth` and the set `containerNodes`. * * @param tensor Some tensor in a graph. * @param finishedNodes Set of nodes whose subgraphs have been traversed * completely. Useful to prevent duplicated work. * @param nodesInProgress Set of nodes that are currently active on the * recursion stack. Useful to detect cycles. * @param layer Layer from which `tensor` comes from. If not provided, * will be obtained from tensor.sourceLayer. * @param nodeIndex Node index from which `tensor` comes from. 
* @param tensorIndex TensorIndex from which `tensor` comes from. * * @exception RuntimeError if a cycle is detected. */ var buildMapOfGraph = function (tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) { if (layer == null || nodeIndex == null || tensorIndex == null) { layer = tensor.sourceLayer; nodeIndex = tensor.nodeIndex; tensorIndex = tensor.tensorIndex; } var node = layer.inboundNodes[nodeIndex]; // Prevent cycles. if (nodesInProgress.indexOf(node) !== -1) { throw new RuntimeError("The tensor ".concat(tensor.name, " at layer \"").concat(layer.name, "\" ") + 'is part of a cycle.'); } // Don't repeat work for shared subgraphs if (finishedNodes.indexOf(node) !== -1) { return; } // Update containerNodes. _this.containerNodes.add(Container.nodeKey(layer, nodeIndex)); // Store the traversal order for layer sorting. if (!(layer.id in layerIndices)) { layerIndices[layer.id] = Object.keys(layerIndices).length; } if (nodesInProgress.indexOf(node) === -1) { nodesInProgress.push(node); } // Propagate to all previous tensors connected to this node. var numInboundLayers = node.inboundLayers.length; for (var i = 0; i < numInboundLayers; i++) { var x = node.inputTensors[i]; var layer_1 = node.inboundLayers[i]; var nodeIndex_1 = node.nodeIndices[i]; var tensorIndex_1 = node.tensorIndices[i]; buildMapOfGraph(x, finishedNodes, nodesInProgress, layer_1, nodeIndex_1, tensorIndex_1); } finishedNodes.push(node); while (nodesInProgress.indexOf(node) >= 0) { nodesInProgress.splice(nodesInProgress.indexOf(node), 1); } nodesInDecreasingDepth.push(node); }; var finishedNodes = []; var nodesInProgress = []; try { for (var _u = __values(_this.outputs), _v = _u.next(); !_v.done; _v = _u.next()) { var x = _v.value; buildMapOfGraph(x, finishedNodes, nodesInProgress); } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (_v && !_v.done && (_d = _u.return)) _d.call(_u); } finally { if (e_4) throw e_4.error; } } var reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse(); try { for (var reversedNodesInDecreasingDepth_1 = __values(reversedNodesInDecreasingDepth), reversedNodesInDecreasingDepth_1_1 = reversedNodesInDecreasingDepth_1.next(); !reversedNodesInDecreasingDepth_1_1.done; reversedNodesInDecreasingDepth_1_1 = reversedNodesInDecreasingDepth_1.next()) { var node = reversedNodesInDecreasingDepth_1_1.value; nodeIDToNode[node.id] = node; // If the depth is not set, the node has no outbound nodes (depth 0). if (!(node.id in nodesDepths)) { nodesDepths[node.id] = 0; } var depth = nodesDepths[node.id]; // Update the depth of the corresponding layer var previousDepth = (layersDepths[node.outboundLayer.id] == null ? 0 : layersDepths[node.outboundLayer.id]); /* If we've seen this layer before at a higher depth, we should use that depth instead of the node depth. This is necessary for shared layers that have inputs at different depth levels in the graph. */ depth = Math.max(depth, previousDepth); layersDepths[node.outboundLayer.id] = depth; layerIDToLayer[node.outboundLayer.id] = node.outboundLayer; nodesDepths[node.id] = depth; // Update the depth of inbound nodes. for (var i = 0; i < node.inboundLayers.length; i++) { var inboundLayer = node.inboundLayers[i]; var nodeIndex = node.nodeIndices[i]; var inboundNode = inboundLayer.inboundNodes[nodeIndex]; var previousDepth_1 = (nodesDepths[inboundNode.id] == null ? 
0 : nodesDepths[inboundNode.id]); nodesDepths[inboundNode.id] = Math.max(depth + 1, previousDepth_1); nodeIDToNode[inboundNode.id] = inboundNode; } } } catch (e_5_1) { e_5 = { error: e_5_1 }; } finally { try { if (reversedNodesInDecreasingDepth_1_1 && !reversedNodesInDecreasingDepth_1_1.done && (_e = reversedNodesInDecreasingDepth_1.return)) _e.call(reversedNodesInDecreasingDepth_1); } finally { if (e_5) throw e_5.error; } } // Build a dict {depth: list of nodes with this depth} var nodesByDepth = {}; for (var nodeID in nodesDepths) { var depth = nodesDepths[nodeID]; if (!(depth in nodesByDepth)) { nodesByDepth[depth] = []; } nodesByDepth[depth].push(nodeIDToNode[nodeID]); } // Build a dict {depth: list of layers with this depth} var layersByDepth = {}; for (var layerID in layersDepths) { var depth = layersDepths[layerID]; if (!(depth in layersByDepth)) { layersByDepth[depth] = []; } layersByDepth[depth].push(layerIDToLayer[layerID]); } // Get sorted list of layer depths. var depthKeys = Object.keys(layersByDepth) .map(function (x) { return parseInt(x, 10); }) .sort(reverseNumberCompare); // Set this.layers and this.layersByDepth. _this.layers = []; try { for (var depthKeys_1 = __values(depthKeys), depthKeys_1_1 = depthKeys_1.next(); !depthKeys_1_1.done; depthKeys_1_1 = depthKeys_1.next()) { var depth = depthKeys_1_1.value; var layersForDepth = layersByDepth[depth]; // Container.layers needs to have a deterministic order: // here we order them by traversal order. layersForDepth.sort(function (a, b) { var aIndex = layerIndices[a.id]; var bIndex = layerIndices[b.id]; if (aIndex < bIndex) { return -1; } if (aIndex > bIndex) { return 1; } return 0; }); try { for (var layersForDepth_1 = (e_7 = void 0, __values(layersForDepth)), layersForDepth_1_1 = layersForDepth_1.next(); !layersForDepth_1_1.done; layersForDepth_1_1 = layersForDepth_1.next()) { var layer = layersForDepth_1_1.value; if (layer instanceof Container) { _this.internalContainerRefs.push(layer); } _this.layers.push(layer); } } catch (e_7_1) { e_7 = { error: e_7_1 }; } finally { try { if (layersForDepth_1_1 && !layersForDepth_1_1.done && (_g = layersForDepth_1.return)) _g.call(layersForDepth_1); } finally { if (e_7) throw e_7.error; } } } } catch (e_6_1) { e_6 = { error: e_6_1 }; } finally { try { if (depthKeys_1_1 && !depthKeys_1_1.done && (_f = depthKeys_1.return)) _f.call(depthKeys_1); } finally { if (e_6) throw e_6.error; } } _this.layersByDepth = layersByDepth; // Get sorted list of node depths; depthKeys = Object.keys(nodesByDepth) .map(function (x) { return parseInt(x, 10); }) .sort(reverseNumberCompare); // Check that all tensors required are computable. // computable_tensors: all tensors in the graph // that can be computed from the inputs provided. var computableTensors = _this.inputs.slice(); // To provide a better error msg. var layersWithCompleteInput = []; try { for (var depthKeys_2 = __values(depthKeys), depthKeys_2_1 = depthKeys_2.next(); !depthKeys_2_1.done; depthKeys_2_1 = depthKeys_2.next()) { var depth = depthKeys_2_1.value; try { for (var _w = (e_9 = void 0, __values(nodesByDepth[depth])), _x = _w.next(); !_x.done; _x = _w.next()) { var node = _x.value; var layer = node.outboundLayer; if (layer != null) { try { for (var _y = (e_10 = void 0, __values(node.inputTensors)), _z = _y.next(); !_z.done; _z = _y.next()) { var x = _z.value; if (computableTensors.indexOf(x) === -1) { throw new RuntimeError("Graph disconnected: cannot obtain value for tensor ".concat(x) + " at layer \"".concat(layer.name, "\". 
") + 'The following previous layers were accessed without ' + "issue: ".concat(layersWithCompleteInput)); } } } catch (e_10_1) { e_10 = { error: e_10_1 }; } finally { try { if (_z && !_z.done && (_k = _y.return)) _k.call(_y); } finally { if (e_10) throw e_10.error; } } try { for (var _0 = (e_11 = void 0, __values(node.outputTensors)), _1 = _0.next(); !_1.done; _1 = _0.next()) { var x = _1.value; computableTensors.push(x); } } catch (e_11_1) { e_11 = { error: e_11_1 }; } finally { try { if (_1 && !_1.done && (_l = _0.return)) _l.call(_0); } finally { if (e_11) throw e_11.error; } } layersWithCompleteInput.push(layer.name); } } } catch (e_9_1) { e_9 = { error: e_9_1 }; } finally { try { if (_x && !_x.done && (_j = _w.return)) _j.call(_w); } finally { if (e_9) throw e_9.error; } } } } catch (e_8_1) { e_8 = { error: e_8_1 }; } finally { try { if (depthKeys_2_1 && !depthKeys_2_1.done && (_h = depthKeys_2.return)) _h.call(depthKeys_2); } finally { if (e_8) throw e_8.error; } } // Set this.containerNodes and this.nodesByDepth. _this.nodesByDepth = nodesByDepth; // Ensure name unicity, which will be crucial for serialization // (since serialized nodes refer to layers by their name). var allNames = _this.layers.map(function (x) { return x.name; }); var _loop_1 = function (name) { var numOccurrences = allNames.filter(function (x) { return x === name; }).length; if (numOccurrences !== 1) { throw new RuntimeError("The name \"".concat(name, "\" is used ").concat(numOccurrences, " times ") + 'in the model. All layer names should be unique. Layer names: ' + JSON.stringify(allNames)); } }; try { for (var allNames_1 = __values(allNames), allNames_1_1 = allNames_1.next(); !allNames_1_1.done; allNames_1_1 = allNames_1.next()) { var name = allNames_1_1.value; _loop_1(name); } } catch (e_12_1) { e_12 = { error: e_12_1 }; } finally { try { if (allNames_1_1 && !allNames_1_1.done && (_m = allNames_1.return)) _m.call(allNames_1); } finally { if (e_12) throw e_12.error; } } // Layer parameters. // The new container starts with a single inbound node // for its inputs, and no outbound nodes. // Will be appended to by future calls to apply(). _this.outboundNodes = []; // Will be appended to below, and by future calls to apply(). _this.inboundNodes = []; // Create the node linking internal inputs to internal outputs. // (This call has side effects.) // tslint:disable-next-line:no-unused-expression new Node({ outboundLayer: _this, inboundLayers: [], nodeIndices: [], tensorIndices: [], inputTensors: _this.inputs, outputTensors: _this.outputs, inputMasks: _this.inputs.map(function (x) { return null; }), outputMasks: _this.outputs.map(function (x) { return null; }), inputShapes: _this.inputs.map(function (x) { return x.shape; }), outputShapes: _this.outputs.map(function (x) { return x.shape; }) }); _this.built = true; _this._refCount = 1; // The ref count of a container always start at 1. return _this; } Container.prototype.assertNotDisposed = function () { if (this._refCount === 0) { throw new Error("Container '".concat(this.name, "' is already disposed.")); } }; /** * Attempt to dispose a LayersModel's weights. * * This method decrease the reference count of the LayersModel object by 1. * * A LayersModel is reference-counted. Its reference count is incremented by 1 * when it is first constructed and when it is used as a Layer of another * LayersModel. * * If the reference count of a LayersModel becomes 0, the `dispose` method of * all its constituent `Layer`s will be called. 
 * * Note: If the reference count is greater than 0 after the decrement, the * `dispose` method of its constituent `Layer`s will *not* be called. * * After a LayersModel is disposed, it cannot be used in calls such as * `predict`, `evaluate` or `fit` anymore. * * @returns A DisposeResult Object with the following fields: * - refCountAfterDispose: The reference count of the LayersModel after this * `dispose()` call. * - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed * during this `dispose()` call. * @throws {Error} If the layer is not built yet, or if the LayersModel has * already been disposed. */ Container.prototype.dispose = function () { var e_13, _a, e_14, _b; this.assertNotDisposed(); var result = { refCountAfterDispose: null, numDisposedVariables: 0 }; if (--this._refCount === 0) { try { for (var _c = __values(this.layers), _d = _c.next(); !_d.done; _d = _c.next()) { var layer = _d.value; result.numDisposedVariables += layer.dispose().numDisposedVariables; } } catch (e_13_1) { e_13 = { error: e_13_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_13) throw e_13.error; } } try { // Call dispose on each internally created container layer again to ensure // their refCounts hit zero and their tensors are subsequently deleted. for (var _e = __values(this.internalContainerRefs), _f = _e.next(); !_f.done; _f = _e.next()) { var container = _f.value; result.numDisposedVariables += container.dispose().numDisposedVariables; } } catch (e_14_1) { e_14 = { error: e_14_1 }; } finally { try { if (_f && !_f.done && (_b = _e.return)) _b.call(_e); } finally { if (e_14) throw e_14.error; } } } result.refCountAfterDispose = this._refCount; return result; }; Object.defineProperty(Container.prototype, "trainable", { get: function () { return this.trainable_; }, set: function (trainable) { this.layers.forEach(function (layer) { // tslint:disable-next-line:no-any layer._trainableWeights .forEach(function (w) { return w.trainable = trainable; }); }); this.trainable_ = trainable; }, enumerable: false, configurable: true }); Object.defineProperty(Container.prototype, "trainableWeights", { get: function () { var e_15, _a; // Porting Note: This check below is to prevent errors where the // _trainableWeights inherited from the parent class (Layer) gets // inadvertently used. if (this._trainableWeights.length > 0) { throw new ValueError('Container instance unexpectedly contains _trainableWeights. ' + 'The trainable weights of a Container are a union of the ' + 'trainable weights of its constituent Layers. 
Its own ' + '_trainableWeights must remain an empty Array.'); } if (!this.trainable) { return []; } var weights = []; try { for (var _b = __values(this.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; weights = weights.concat(layer.trainableWeights); } } catch (e_15_1) { e_15 = { error: e_15_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_15) throw e_15.error; } } return weights; }, enumerable: false, configurable: true }); Object.defineProperty(Container.prototype, "nonTrainableWeights", { get: function () { var e_16, _a, e_17, _b; var weights = []; try { for (var _c = __values(this.layers), _d = _c.next(); !_d.done; _d = _c.next()) { var layer = _d.value; weights.push.apply(weights, __spreadArray([], __read(layer.nonTrainableWeights), false)); } } catch (e_16_1) { e_16 = { error: e_16_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_16) throw e_16.error; } } if (!this.trainable) { var trainableWeights = []; try { for (var _e = __values(this.layers), _f = _e.next(); !_f.done; _f = _e.next()) { var layer = _f.value; trainableWeights.push.apply(trainableWeights, __spreadArray([], __read(layer.trainableWeights), false)); } } catch (e_17_1) { e_17 = { error: e_17_1 }; } finally { try { if (_f && !_f.done && (_b = _e.return)) _b.call(_e); } finally { if (e_17) throw e_17.error; } } return trainableWeights.concat(weights); } return weights; }, enumerable: false, configurable: true }); Object.defineProperty(Container.prototype, "weights", { get: function () { return this.trainableWeights.concat(this.nonTrainableWeights); }, enumerable: false, configurable: true }); /** * Loads all layer weights from a JSON object. * * Porting Note: HDF5 weight files cannot be directly loaded in JavaScript / * TypeScript. The utility script at `scripts/pykeras.py` offers means * to convert them into JSON strings compatible with this method. * Porting Note: TensorFlow.js Layers supports only loading by name currently. * * @param weights A JSON mapping weight names to weight values as nested * arrays of numbers, or a `NamedTensorMap`, i.e., a JSON mapping weight * names to `tf.Tensor` objects. * @param strict Require that the provided weights exactly match those * required by the container. Default: `true`. Passing `false` means that * extra weights and missing weights will be silently ignored. */ Container.prototype.loadWeights = function (weights, strict) { var e_18, _a, e_19, _b; if (strict === void 0) { strict = true; } var nameToWeight = {}; var totalWeightsCount = 0; var modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights); if (modelIsKerasSavedModelFormat) { this.parseWeights(weights); } try { // Check if weights from keras v3. for (var _c = __values(this.layers), _d = _c.next(); !_d.done; _d = _c.next()) { var layer = _d.value; try { for (var _e = (e_19 = void 0, __values(layer.weights.entries())), _f = _e.next(); !_f.done; _f = _e.next()) { var _g = __read(_f.value, 2), index = _g[0], weight = _g[1]; // Parse the name to layerName/index. // e.g. dense/0, dense/1, dense_1/0, dense_1/1 var parsedName = modelIsKerasSavedModelFormat ? 
"".concat(weight.name.split('/').slice(0, -1).join('/') + '/').concat(index) : weight.originalName; if (nameToWeight[parsedName] != null) { throw new ValueError("Duplicate weight name: ".concat(parsedName)); } nameToWeight[parsedName] = weight; totalWeightsCount++; } } catch (e_19_1) { e_19 = { error: e_19_1 }; } finally { try { if (_f && !_f.done && (_b = _e.return)) _b.call(_e); } finally { if (e_19) throw e_19.error; } } } } catch (e_18_1) { e_18 = { error: e_18_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_18) throw e_18.error; } } var weightValueTuples = []; for (var name in weights) { // TF 2.2.0 added cell name to the weight name in the format of // layer_name/cell_name/weight_name, we need to remove // the inner cell name. var validatedName = name; if (nameToWeight[name] == null) { var tokens = name.split('/'); var shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]); validatedName = shortenNameArray.join('/'); } if (nameToWeight[validatedName] != null) { weightValueTuples.push([nameToWeight[validatedName], weights[name]]); } else if (strict) { throw new ValueError("Provided weight data has no target variable: ".concat(name)); } delete nameToWeight[validatedName]; } if (strict) { // Check that all weights are set. var unsetNames = []; for (var name in nameToWeight) { unsetNames.push(name); } if (unsetNames.length > 0) { throw new ValueError("".concat(unsetNames.length, " of ").concat(totalWeightsCount, " weights are not set: ") + "".concat(unsetNames)); } } batchSetValue(weightValueTuples); }; Container.prototype.parseWeights = function (weights) { var _loop_2 = function (key) { var listParts = key.split('/'); var list = ['vars', 'layer_checkpoint_dependencies']; // For keras v3, the weights name are saved based on the folder structure. // e.g. _backbone/_layer_checkpoint_dependencies/transformer/_self../ // _output_dense/vars/0 // Therefore we discard the `vars` and `layer_checkpoint_depencies` within // the saved name and only keeps the layer name and weights. // This can help to mapping the actual name of the layers and load each // weight accordingly. var newKey = listParts .map(function (str) { if (str.startsWith('_')) { return str.slice(1); } return str; }) .filter(function (str) { return !list.includes(str); }) .join('/'); if (newKey !== key) { weights[newKey] = weights[key]; delete weights[key]; } }; for (var key in Object.keys(weights)) { _loop_2(key); } }; /** * Util shared between different serialization methods. * @returns LayersModel config with Keras version information added. */ Container.prototype.updatedConfig = function () { var theConfig = this.getConfig(); var modelConfig = {}; modelConfig['className'] = this.getClassName(); modelConfig['config'] = theConfig; modelConfig['kerasVersion'] = "tfjs-layers ".concat(version); // TODO(nielsene): Replace something like K.backend() once // possible. modelConfig['backend'] = 'TensorFlow.js'; return modelConfig; }; /** * Returns a JSON string containing the network configuration. * * To load a network from a JSON save file, use * models.modelFromJSON(jsonString); * @param extraJsonArgs Unused in tfjs-layers, maintained for PyKeras * @param returnString Whether the return value should be stringified * (default: `true`). * @returns a JSON string if `returnString` (default), or a JSON object if * `!returnString`. 
*/ // tslint:disable-next-line:no-any Container.prototype.toJSON = function (unused, returnString) { if (returnString === void 0) { returnString = true; } var modelConfig = convertTsToPythonic(this.updatedConfig()); return returnString ? JSON.stringify(modelConfig) : modelConfig; }; /** * Call the model on new inputs. * * In this case `call` just reapplies all ops in the graph to the new inputs * (i.e., builds a new computational graph from the provided inputs). * * @param inputs A tensor or list of tensors. * @param mask A mask or list of masks. A mask can be either a tensor or null * (no mask). * * @return A tensor if there is a single output, or a list of tensors if there * is more than one output. */ Container.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = toList(inputs); var feedDict = new FeedDict(); for (var i = 0; i < _this.inputs.length; ++i) { feedDict.add(_this.inputs[i], inputs[i]); } return execute(_this.outputs, feedDict, kwargs); }); }; /** * Computes an output mask tensor. * * @param inputs Tensor or list of tensors. * @param mask Tensor or list of tensors. * * @return null or a tensor (or list of tensors, one per output tensor of the * layer). */ Container.prototype.computeMask = function (inputs, mask) { var _this = this; return tfc.tidy(function () { inputs = toList(inputs); var masks; if (mask == null) { masks = pyListRepeat(null, inputs.length); } else { masks = toList(mask); } // TODO(michaelterry): Add support for mask caching. return _this.runInternalGraph(inputs, masks)[1]; }); }; /** * Computes the output shape of the layer. * * Assumes that the layer will be built to match the input shape provided. * * @param inputShape A shape (tuple of integers) or a list of shape tuples * (one per output tensor of the layer). Shape tuples can include null for * free dimensions, instead of an integer. */ Container.prototype.computeOutputShape = function (inputShape) { var e_20, _a, e_21, _b; var inputShapes = normalizeShapeList(inputShape); if (inputShapes.length !== this.inputLayers.length) { throw new ValueError("Invalid inputShape argument ".concat(inputShape, ": ") + "model has ".concat(this.inputLayers.length, " tensor inputs.")); } // TODO(michaelterry): Add caching var layersToOutputShapes = {}; for (var i = 0; i < inputShapes.length; i++) { var layer = this.inputLayers[i]; var inputShape_1 = inputShapes[i]; // It's an input layer: computeOutputShape is identity, // and there is only one node and one tensor output. var shapeKey = layer.name + '_0_0'; layersToOutputShapes[shapeKey] = inputShape_1; } var depthKeys = Object.keys(this.nodesByDepth) .map(function (x) { return parseInt(x, 10); }) .sort(reverseNumberCompare); // Iterate over nodes, by depth level. if (depthKeys.length > 1) { try { for (var depthKeys_3 = __values(depthKeys), depthKeys_3_1 = depthKeys_3.next(); !depthKeys_3_1.done; depthKeys_3_1 = depthKeys_3.next()) { var depth = depthKeys_3_1.value; var nodes = this.nodesByDepth[depth]; try { for (var nodes_1 = (e_21 = void 0, __values(nodes)), nodes_1_1 = nodes_1.next(); !nodes_1_1.done; nodes_1_1 = nodes_1.next()) { var node = nodes_1_1.value; // This is always a single layer, never a list. var layer = node.outboundLayer; if (this.inputLayers.map(function (x) { return x.id; }).indexOf(layer.id) !== -1) { // We've already covered the input layers a few lines above. continue; } // Potentially redundant list, same size as node.inputTensors. 
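// (Editorial note, not part of the original source: shape bookkeeping keys
// follow the pattern `${layerName}_${nodeIndex}_${tensorIndex}`, e.g.
// 'dense_1_0_0', so each lookup below fetches a shape computed earlier in
// the depth-ordered traversal.)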
var inputShapes_1 = []; for (var j = 0; j < node.inboundLayers.length; j++) { var inboundLayer = node.inboundLayers[j]; var nodeIndex_2 = node.nodeIndices[j]; var tensorIndex = node.tensorIndices[j]; var shapeKey = "".concat(inboundLayer.name, "_").concat(nodeIndex_2, "_").concat(tensorIndex); var inputShape_2 = layersToOutputShapes[shapeKey]; inputShapes_1.push(inputShape_2); } var outputShape = layer.computeOutputShape(singletonOrArray(inputShapes_1)); var outputShapes_1 = normalizeShapeList(outputShape); var nodeIndex = layer.inboundNodes.indexOf(node); for (var j = 0; j < outputShapes_1.length; j++) { var shapeKey = "".concat(layer.name, "_").concat(nodeIndex, "_").concat(j); layersToOutputShapes[shapeKey] = outputShapes_1[j]; } } } catch (e_21_1) { e_21 = { error: e_21_1 }; } finally { try { if (nodes_1_1 && !nodes_1_1.done && (_b = nodes_1.return)) _b.call(nodes_1); } finally { if (e_21) throw e_21.error; } } } } catch (e_20_1) { e_20 = { error: e_20_1 }; } finally { try { if (depthKeys_3_1 && !depthKeys_3_1.done && (_a = depthKeys_3.return)) _a.call(depthKeys_3); } finally { if (e_20) throw e_20.error; } } } // Read final output shapes from layersToOutputShapes. var outputShapes = []; var outputShapeKeys = []; for (var i = 0; i < this.outputLayers.length; i++) { var layer = this.outputLayers[i]; var nodeIndex = this.outputLayersNodeIndices[i]; var tensorIndex = this.outputLayersTensorIndices[i]; var shapeKey = "".concat(layer.name, "_").concat(nodeIndex, "_").concat(tensorIndex); outputShapeKeys.push(shapeKey); } for (var i = 0; i < outputShapeKeys.length; i++) { var key = outputShapeKeys[i]; assert$1(key in layersToOutputShapes); outputShapes.push(layersToOutputShapes[key]); } // TODO(michaelterry): Update cache return singletonOrArray(outputShapes); }; /** * Computes output tensors for new inputs. * * Note: * - Expects `inputs` to be a list (potentially with 1 element). * * @param inputs List of tensors * @param masks List of masks (tensors or null). * @return Three lists: outputTensors, outputMasks, outputShapes */ Container.prototype.runInternalGraph = function (inputs, masks) { var e_22, _a, e_23, _b, e_24, _c, e_25, _d; if (masks == null) { masks = pyListRepeat(null, inputs.length); } // Dictionary mapping reference tensors to tuples // (computed tensor, compute mask) // we assume a 1:1 mapping from tensor to mask // TODO: raise exception when a `.computeMask()` call // does not return a list the same size as `call` var tensorMap = {}; for (var i = 0; i < this.inputs.length; ++i) { var x = this.inputs[i]; var y = inputs[i]; var mask = masks[i]; tensorMap[x.id] = [y, mask]; } var depthKeys = Object.keys(this.nodesByDepth) .map(function (x) { return parseInt(x, 10); }) .sort(reverseNumberCompare); try { for (var depthKeys_4 = __values(depthKeys), depthKeys_4_1 = depthKeys_4.next(); !depthKeys_4_1.done; depthKeys_4_1 = depthKeys_4.next()) { var depth = depthKeys_4_1.value; var nodes = this.nodesByDepth[depth]; try { for (var nodes_2 = (e_23 = void 0, __values(nodes)), nodes_2_1 = nodes_2.next(); !nodes_2_1.done; nodes_2_1 = nodes_2.next()) { var node = nodes_2_1.value; // This is always a single layer, never a list. var layer = node.outboundLayer; var referenceInputTensors = node.inputTensors; var referenceOutputTensors = node.outputTensors; // If all previous input tensors are available in tensorMap, // then call node.inboundLayer on them. 
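// (Editorial note, not part of the original source: depths are visited in
// decreasing order, so a node's inputs are normally already present in
// tensorMap; the length check further below simply skips nodes that do not
// belong to this computation.)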
// List of tuples [input, mask]:
var computedData = new Array(); try { for (var referenceInputTensors_1 = (e_24 = void 0, __values(referenceInputTensors)), referenceInputTensors_1_1 = referenceInputTensors_1.next(); !referenceInputTensors_1_1.done; referenceInputTensors_1_1 = referenceInputTensors_1.next()) { var x = referenceInputTensors_1_1.value; if (x.id in tensorMap) { computedData.push(tensorMap[x.id]); } } } catch (e_24_1) { e_24 = { error: e_24_1 }; } finally { try { if (referenceInputTensors_1_1 && !referenceInputTensors_1_1.done && (_c = referenceInputTensors_1.return)) _c.call(referenceInputTensors_1); } finally { if (e_24) throw e_24.error; } } if (computedData.length === referenceInputTensors.length) {
// TODO(michaelterry): Add K.name_scope here, if we need it.
var kwargs = {}; var computedTensors = void 0; var computedMasks = void 0; var outputTensors_1 = void 0; var outputMasks_1 = void 0;
// call layer
if (node.callArgs != null) { kwargs = node.callArgs; } if (computedData.length === 1) { var _e = __read(computedData[0], 2), computedTensor = _e[0], computedMask = _e[1]; if (kwargs['mask'] == null) { kwargs['mask'] = computedMask; } outputTensors_1 = toList(layer.call(computedTensor, kwargs)); outputMasks_1 = toList(layer.computeMask(computedTensor, computedMask)); computedTensors = [computedTensor]; computedMasks = [computedMask]; } else { computedTensors = computedData.map(function (x) { return x[0]; }); computedMasks = computedData.map(function (x) { return x[1]; }); if (kwargs['mask'] == null) { kwargs['mask'] = computedMasks; } outputTensors_1 = toList(layer.call(computedTensors, kwargs)); outputMasks_1 = toList(layer.computeMask(computedTensors, computedMasks)); } if (layer.activityRegularizer) { throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' + 'presence of activity regularizer(s) is not supported yet.'); }
// TODO(michaelterry): Add model updates and losses
// Update tensor map.
for (var i = 0; i < referenceOutputTensors.length; ++i) { var x = referenceOutputTensors[i]; var y = outputTensors_1[i]; var mask = outputMasks_1[i]; tensorMap[x.id] = [y, mask]; } } } } catch (e_23_1) { e_23 = { error: e_23_1 }; } finally { try { if (nodes_2_1 && !nodes_2_1.done && (_b = nodes_2.return)) _b.call(nodes_2); } finally { if (e_23) throw e_23.error; } } } } catch (e_22_1) { e_22 = { error: e_22_1 }; } finally { try { if (depthKeys_4_1 && !depthKeys_4_1.done && (_a = depthKeys_4.return)) _a.call(depthKeys_4); } finally { if (e_22) throw e_22.error; } } var outputTensors = []; var outputMasks = []; var outputShapes = []; try { for (var _f = __values(this.outputs), _g = _f.next(); !_g.done; _g = _f.next()) { var x = _g.value; assert$1(x.id in tensorMap, "Could not compute output ".concat(x.name, " : ").concat(x.id)); var _h = __read(tensorMap[x.id], 2), tensor = _h[0], mask = _h[1]; outputShapes.push(tensor.shape); outputTensors.push(tensor); outputMasks.push(mask); } } catch (e_25_1) { e_25 = { error: e_25_1 }; } finally { try { if (_g && !_g.done && (_d = _f.return)) _d.call(_f); } finally { if (e_25) throw e_25.error; } }
// TODO(michaelterry): Add support for caches.
return [outputTensors, outputMasks, outputShapes]; }; /** * Builds a map of internal node keys to node ordering. * Used in serialization, as node orderings may change when unused nodes are * dropped. * Porting Note: This helper method was pulled out of getConfig to * improve readability. * @param layers An array of Layers in the model.
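 *   Note: node keys are produced by `Container.nodeKey`; each key that is
 *   retained in `containerNodes` is mapped to its new, compacted index
 *   among the saved nodes of its layer.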
* @returns Map of Node Keys to index order within the layer. */ Container.prototype.buildNodeConversionMap = function (layers) { var e_26, _a; var nodeConversionMap = {}; var keptNodes; try { for (var _b = __values(this.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; keptNodes = layer instanceof Container ? 1 : 0; for (var originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) { var nodeKey = Container.nodeKey(layer, originalNodeIndex); if (this.containerNodes.has(nodeKey)) {
// i.e. we mark it to be saved
nodeConversionMap[nodeKey] = keptNodes; keptNodes += 1; } } } } catch (e_26_1) { e_26 = { error: e_26_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_26) throw e_26.error; } } return nodeConversionMap; }; Container.prototype.getLayer = function (nameOrIndex, index) { var e_27, _a; if (index != null) { return this.findLayer(index); } else { if (nameOrIndex == null) { throw new ValueError('Provide either a layer name or layer index'); } if (typeof nameOrIndex === 'number') { return this.findLayer(nameOrIndex); } } try { for (var _b = __values(this.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; if (layer.name === nameOrIndex) { return layer; } } } catch (e_27_1) { e_27 = { error: e_27_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_27) throw e_27.error; } } throw new ValueError("No such layer: ".concat(nameOrIndex)); }; Container.prototype.findLayer = function (index) { if (this.layers.length <= index) { throw new ValueError("Was asked to retrieve layer at index ".concat(index, ", but model only ") + "has ".concat(this.layers.length, " layer(s).")); } else { return this.layers[index]; } }; /** * Retrieves the Container's current loss values. * * Used for regularizers during training. */ Container.prototype.calculateLosses = function () { var _this = this;
// Porting Note: This is an augmentation to Container.loss in PyKeras.
// In PyKeras, Container.loss returns symbolic tensors. Here concrete
// Tensor (specifically Scalar) values are returned. This is due to the
// imperative backend.
return tfc.tidy(function () { var e_28, _a; var losses = []; try { for (var _b = __values(_this.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; for (var nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) { var nodeKey = Container.nodeKey(layer, nodeIndex); if (_this.containerNodes.has(nodeKey)) { losses.push.apply(losses, __spreadArray([], __read(layer.calculateLosses()), false)); } } } } catch (e_28_1) { e_28 = { error: e_28_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_28) throw e_28.error; } }
// TODO(cais): Add any unconditional model-level losses?
return losses; }); }; Container.prototype.getConfig = function () { var e_29, _a; var config = { name: this.name };
// Build a map from layer unique name (self._node_key)
// to the index of the nodes that are saved in the config.
// Only nodes in container_nodes are saved.
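/* A usage sketch for the `getLayer` API defined above (the `model`
 * instance and the layer name `dense_1` are hypothetical):
 *
 * ```js
 * const byName = model.getLayer('dense_1'); // look up by name
 * const byIndex = model.getLayer(null, 0);  // look up by index
 * ```
 */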
var nodeConversionMap = this.buildNodeConversionMap(this.layers); // Serialize and save the layers in layerConfigs var layerConfigs = []; try { for (var _b = __values(this.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; var layerClassName = layer.getClassName(); var layerConfig = layer.getConfig(); var filteredInboundNodes = []; for (var originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) { var node = layer.inboundNodes[originalNodeIndex]; var nodeKey = Container.nodeKey(layer, originalNodeIndex); var kwargs = {}; if (this.containerNodes.has(nodeKey)) { // The node is relevant to the model: // add to filteredInboundNodes. if (node.callArgs) { try { JSON.stringify(node.callArgs); kwargs = node.callArgs; } catch (err) { console.warn("Layer ".concat(layer.name, " was passed ") + "non-serializable keyword arguments: " + "".concat(node.callArgs, ". They will not be included ") + "in the serialized model (and thus will be " + "missing at deserialization time)."); kwargs = {}; } } if (node.inboundLayers.length > 0) { var nodeData = []; for (var i = 0; i < node.inboundLayers.length; i++) { var inboundLayer = node.inboundLayers[i]; var nodeIndex = node.nodeIndices[i]; var tensorIndex = node.tensorIndices[i]; var nodeKey_1 = Container.nodeKey(inboundLayer, nodeIndex); var newNodeIndex = nodeConversionMap[nodeKey_1]; if (newNodeIndex == null) { newNodeIndex = 0; } nodeData.push([inboundLayer.name, newNodeIndex, tensorIndex, kwargs]); } filteredInboundNodes.push(nodeData); } } } var dict = {}; dict['name'] = layer.name; dict['className'] = layerClassName; dict['config'] = layerConfig; dict['inboundNodes'] = filteredInboundNodes; layerConfigs.push(dict); } } catch (e_29_1) { e_29 = { error: e_29_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_29) throw e_29.error; } } config['layers'] = layerConfigs; // Gather info about inputs and outputs var modelInputs = []; for (var i = 0; i < this.inputLayers.length; i++) { var layer = this.inputLayers[i]; var nodeIndex = this.inputLayersNodeIndices[i]; var nodeKey = Container.nodeKey(layer, nodeIndex); if (!this.containerNodes.has(nodeKey)) { continue; } var newNodeIndex = nodeConversionMap[nodeKey]; if (newNodeIndex === null || newNodeIndex === undefined) { newNodeIndex = 0; } var tensorIndex = this.inputLayersTensorIndices[i]; modelInputs.push([layer.name, newNodeIndex, tensorIndex]); } config['inputLayers'] = modelInputs; var modelOutputs = []; for (var i = 0; i < this.outputLayers.length; i++) { var layer = this.outputLayers[i]; var nodeIndex = this.outputLayersNodeIndices[i]; var nodeKey = Container.nodeKey(layer, nodeIndex); if (!this.containerNodes.has(nodeKey)) { continue; } var newNodeIndex = nodeConversionMap[nodeKey]; if (newNodeIndex === null || newNodeIndex === undefined) { newNodeIndex = 0; } var tensorIndex = this.outputLayersTensorIndices[i]; modelOutputs.push([layer.name, newNodeIndex, tensorIndex]); } config['outputLayers'] = modelOutputs; return config; }; /** * Instantiates a LayersModel from its config (output of `get_config()`). * @param cls the class to create * @param config LayersModel config dictionary. * @param customObjects An optional dictionary of custom objects. * @param fastWeightInit Optional flag to use fast weight initialization * during deserialization. This is applicable to cases in which * the initialization will be immediately overwritten by loaded weight * values. Default: `false`. 
* @returns A LayersModel instance. * @throws ValueError: In case of improperly formatted config dict. */ /** @nocollapse */ Container.fromConfig = function (cls, config, customObjects, fastWeightInit) { var e_30, _a, e_31, _b, e_32, _c, e_33, _d, e_34, _e; if (fastWeightInit === void 0) { fastWeightInit = false; } // Layer instances created during // the graph reconstruction process var createdLayers = {}; // Dictionary mapping layer instances to // node data that specifies a layer call. // It acts as a queue that maintains any unprocessed // layer call until it becomes possible to process it // (i.e. until the input tensors to the call all exist). var unprocessedNodes = {}; function addUnprocessedNode(layer, nodeData) { if (!(layer.name in unprocessedNodes)) { unprocessedNodes[layer.name] = [nodeData]; } else { unprocessedNodes[layer.name].push(nodeData); } } function processNode(layer, nodeData) { var e_35, _a; var inputTensors = []; var kwargs; try { for (var nodeData_1 = __values(nodeData), nodeData_1_1 = nodeData_1.next(); !nodeData_1_1.done; nodeData_1_1 = nodeData_1.next()) { var inputData = nodeData_1_1.value; var inboundLayerName = inputData[0]; var inboundNodeIndex = inputData[1]; var inboundTensorIndex = inputData[2]; kwargs = inputData[3] == null ? {} : inputData[3]; if (!(inboundLayerName in createdLayers)) { addUnprocessedNode(layer, nodeData); return; } var inboundLayer = createdLayers[inboundLayerName]; if (inboundLayer.inboundNodes.length <= inboundNodeIndex) { addUnprocessedNode(layer, nodeData); return; } var inboundNode = inboundLayer.inboundNodes[inboundNodeIndex]; inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]); } } catch (e_35_1) { e_35 = { error: e_35_1 }; } finally { try { if (nodeData_1_1 && !nodeData_1_1.done && (_a = nodeData_1.return)) _a.call(nodeData_1); } finally { if (e_35) throw e_35.error; } } // Call layer on its inputs, thus creating the node // and building the layer if needed. // Note: This has Eager vs Graph Implications. if (inputTensors.length > 0) { layer.apply(singletonOrArray(inputTensors), kwargs); // was ** kwargs } } /** * Deserialize a layer, then call it on appropriate inputs. * @param layerData: layer config dict. * @throws ValueError: In case of improperly formatted `layer_data` * dict. */ function processLayer(layerData) { var layerName = layerData['name']; // Instantiate layer. var layer = deserialize(layerData, config['customObjects'] != null ? config['customObjects'] : {}); layer.setFastWeightInitDuringBuild(fastWeightInit); createdLayers[layerName] = layer; // Gather layer inputs. var inboundNodesData = layerData['inboundNodes']; inboundNodesData.forEach(function (nodeData) { if (!(nodeData instanceof Array)) { throw new ValueError("Corrupted configuration, expected array for nodeData: ".concat(nodeData)); } // We don't process nodes (i.e. make layer calls) // on the fly because the inbound node may not yet exist, // in case of layer shared at different topological depths // (e.g.a model such as A(B(A(B(x))))) addUnprocessedNode(layer, nodeData); }); } // First, we create all layers and enqueue nodes to be processed. 
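/* For reference, a round-trip sketch of the config-based serialization
 * implemented by `getConfig()` and `fromConfig()` (assuming an existing
 * functional `model`):
 *
 * ```js
 * const config = model.getConfig();
 * const clone = LayersModel.fromConfig(LayersModel, config);
 * ```
 */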
var name = config['name']; var layersFromConfig = config['layers']; try { for (var layersFromConfig_1 = __values(layersFromConfig), layersFromConfig_1_1 = layersFromConfig_1.next(); !layersFromConfig_1_1.done; layersFromConfig_1_1 = layersFromConfig_1.next()) { var layerData = layersFromConfig_1_1.value; processLayer(layerData); } } catch (e_30_1) { e_30 = { error: e_30_1 }; } finally { try { if (layersFromConfig_1_1 && !layersFromConfig_1_1.done && (_a = layersFromConfig_1.return)) _a.call(layersFromConfig_1); } finally { if (e_30) throw e_30.error; } } // Then we process nodes in order of layer depth. // Nodes that cannot yet be processed(if the inbound node // does not yet exist) are re - enqueued, and the process // is repeated until all nodes are processed. while (!isObjectEmpty(unprocessedNodes)) { try { for (var layersFromConfig_2 = (e_31 = void 0, __values(layersFromConfig)), layersFromConfig_2_1 = layersFromConfig_2.next(); !layersFromConfig_2_1.done; layersFromConfig_2_1 = layersFromConfig_2.next()) { var layerData = layersFromConfig_2_1.value; var layer = createdLayers[layerData['name']]; if (layer.name in unprocessedNodes) { var currentUnprocessedNodesForLayer = unprocessedNodes[layer.name]; delete unprocessedNodes[layer.name]; try { for (var currentUnprocessedNodesForLayer_1 = (e_32 = void 0, __values(currentUnprocessedNodesForLayer)), currentUnprocessedNodesForLayer_1_1 = currentUnprocessedNodesForLayer_1.next(); !currentUnprocessedNodesForLayer_1_1.done; currentUnprocessedNodesForLayer_1_1 = currentUnprocessedNodesForLayer_1.next()) { var nodeData = currentUnprocessedNodesForLayer_1_1.value; processNode(layer, nodeData); } } catch (e_32_1) { e_32 = { error: e_32_1 }; } finally { try { if (currentUnprocessedNodesForLayer_1_1 && !currentUnprocessedNodesForLayer_1_1.done && (_c = currentUnprocessedNodesForLayer_1.return)) _c.call(currentUnprocessedNodesForLayer_1); } finally { if (e_32) throw e_32.error; } } } } } catch (e_31_1) { e_31 = { error: e_31_1 }; } finally { try { if (layersFromConfig_2_1 && !layersFromConfig_2_1.done && (_b = layersFromConfig_2.return)) _b.call(layersFromConfig_2); } finally { if (e_31) throw e_31.error; } } } var inputTensors = []; var outputTensors = []; var inputLayersFromConfig = config['inputLayers']; try { for (var inputLayersFromConfig_1 = __values(inputLayersFromConfig), inputLayersFromConfig_1_1 = inputLayersFromConfig_1.next(); !inputLayersFromConfig_1_1.done; inputLayersFromConfig_1_1 = inputLayersFromConfig_1.next()) { var layerData = inputLayersFromConfig_1_1.value; var layerName = layerData[0]; var nodeIndex = layerData[1]; var tensorIndex = layerData[2]; assert$1(layerName in createdLayers); var layer = createdLayers[layerName]; var layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors; inputTensors.push(layerOutputTensors[tensorIndex]); } } catch (e_33_1) { e_33 = { error: e_33_1 }; } finally { try { if (inputLayersFromConfig_1_1 && !inputLayersFromConfig_1_1.done && (_d = inputLayersFromConfig_1.return)) _d.call(inputLayersFromConfig_1); } finally { if (e_33) throw e_33.error; } } var outputLayersFromConfig = config['outputLayers']; try { for (var outputLayersFromConfig_1 = __values(outputLayersFromConfig), outputLayersFromConfig_1_1 = outputLayersFromConfig_1.next(); !outputLayersFromConfig_1_1.done; outputLayersFromConfig_1_1 = outputLayersFromConfig_1.next()) { var layerData = outputLayersFromConfig_1_1.value; var layerName = layerData[0]; var nodeIndex = layerData[1]; var tensorIndex = layerData[2]; assert$1(layerName 
in createdLayers); var layer = createdLayers[layerName]; var layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors; outputTensors.push(layerOutputTensors[tensorIndex]); } } catch (e_34_1) { e_34 = { error: e_34_1 }; } finally { try { if (outputLayersFromConfig_1_1 && !outputLayersFromConfig_1_1.done && (_e = outputLayersFromConfig_1.return)) _e.call(outputLayersFromConfig_1); } finally { if (e_34) throw e_34.error; } } return new cls({ inputs: inputTensors, outputs: outputTensors, name: name }); }; Object.defineProperty(Container.prototype, "stateful", { /** * Determine whether the container is stateful. * * Porting Note: this is the equivalent of the stateful @property of * the Container class in PyKeras. */ get: function () { var e_36, _a; // Porting Note: This check is to prevent inadvertent setting of the // _stateful property of the Container instance. if (this._stateful) { throw new ValueError('Container instance unexpectedly has _stateful = true. The ' + 'statefulness of a Container is determined by the Layers it ' + 'contains. Its _stateful property must remain the default false.'); } try { for (var _b = __values(this.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; if (layer.stateful) { return true; } } } catch (e_36_1) { e_36 = { error: e_36_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_36) throw e_36.error; } } return false; }, enumerable: false, configurable: true }); /** * Reset the state of all stateful constituent layers (if any). * * Examples of stateful layers include RNN layers whose `stateful` property * is set as `true`. */ Container.prototype.resetStates = function () { var _this = this; tfc.tidy(function () { _this.layers.forEach(function (layer) { // tslint:disable:no-any if (layer.stateful) { layer.resetStates(); } // tslint:enable:no-any }); }); }; return Container; }(Layer)); function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) { var numOutputs = outputNames.length; if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) { return outputNames.map(function (name) { return null; }); } if (numOutputs === 1) { if (Array.isArray(xWeight) && xWeight.length === 1) { return xWeight; } else if (typeof xWeight === 'object' && outputNames[0] in xWeight) { return [xWeight[outputNames[0]]]; } else { return [xWeight]; } } if (Array.isArray(xWeight)) { if (xWeight.length !== numOutputs) { throw new Error("Provided ".concat(weightType, " is an array of ").concat(xWeight.length, " ") + "element(s), but the model has ".concat(numOutputs, " outputs. ") + "Make sure a set of weights is provided for each model output."); } return xWeight; } else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 && typeof xWeight[Object.keys(xWeight)[0]] === 'object') { var output_1 = []; outputNames.forEach(function (outputName) { if (outputName in xWeight) { output_1.push(xWeight[outputName]); } else { output_1.push(null); } }); return output_1; } else { throw new Error("The model has multiple (".concat(numOutputs, ") outputs, ") + "so ".concat(weightType, " must be either an array with ") + "".concat(numOutputs, " elements or an object with ").concat(outputNames, " keys. ") + "Provided ".concat(weightType, " not understood: ").concat(JSON.stringify(xWeight))); } } /** * Standardize class weighting objects. * * This function takes a single class-weighting object, an array of them, * or a map from output name to class-weighting object. 
It compares it to the * output name(s) of the model, based on which it outputs an array of * class-weighting objects whose length matches the number of outputs. * * @param classWeight Input class-weighting object(s). * @param outputNames All output name(s) of the model. * @return An array of class-weighting objects. The length of the array matches * the model's number of outputs. */ function standardizeClassWeights(classWeight, outputNames) { return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight'); } /** * Standardize by-sample and/or by-class weights for training. * * Note that this function operates on one model output at a time. For a model * with multiple outputs, you must call this function multiple times. * * @param y The target tensor that the by-sample and/or by-class weight is for. * The values of y are assumed to encode the classes, either directly * as an integer index, or as one-hot encoding. * @param sampleWeight By-sample weights. * @param classWeight By-class weights: an object mapping class indices * (integers) to a weight (float) to apply to the model's loss for the * samples from this class during training. This can be useful to tell the * model to "pay more attention" to samples from an under-represented class. * @param sampleWeightMode The mode for the sample weights. * @return A Promise of a weight tensor whose first-dimension size matches * that of `y`. */ function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) { return __awaiter(this, void 0, void 0, function () { var yClasses, yClassIndices, _a, _b, classSampleWeight_1; return __generator(this, function (_c) { switch (_c.label) { case 0: if (sampleWeight != null || sampleWeightMode != null) {
// TODO(cais): Once 'temporal' mode is implemented, document it in the doc
// string.
throw new Error('Support for sampleWeight is not implemented yet'); } if (!(classWeight != null)) return [3 /*break*/, 2]; yClasses = tfc.tidy(function () { if (y.shape.length === 1) {
// Assume class indices.
return tfc.clone(y); } else if (y.shape.length === 2) { if (y.shape[1] > 1) {
// Assume one-hot encoding of classes.
var axis = 1; return tfc.argMax(y, axis); } else if (y.shape[1] === 1) {
// Class index.
return tfc.reshape(y, [y.shape[0]]); } else { throw new Error("Encountered unexpected last-dimension size (".concat(y.shape[1], ") ") + "during handling of class weights. The size is expected to be " + ">= 1."); } } else { throw new Error("Unexpected rank of target (y) tensor (".concat(y.rank, ") during ") + "handling of class weights. The rank is expected to be 1 or 2."); } }); _b = (_a = Array).from; return [4 /*yield*/, yClasses.data()]; case 1: yClassIndices = _b.apply(_a, [_c.sent()]); tfc.dispose(yClasses); classSampleWeight_1 = []; yClassIndices.forEach(function (classIndex) { if (classWeight[classIndex] == null) { throw new Error("classWeight must contain all classes in the training data. " + "The class ".concat(classIndex, " exists in the data but not in ") + "classWeight"); } else { classSampleWeight_1.push(classWeight[classIndex]); } }); return [2 /*return*/, tfc.tensor1d(classSampleWeight_1, 'float32')]; case 2: return [2 /*return*/, null]; } }); }); } /** * Apply per-sample weights on the loss values from a number of samples. * * @param losses Loss tensor of shape `[batchSize]`. * @param sampleWeights Per-sample weight tensor of shape `[batchSize]`. * @returns Tensor of the same shape as `losses`.
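 *
 * A minimal sketch of the element-wise weighting (values hypothetical):
 *
 * ```js
 * const losses = tf.tensor1d([0.5, 1.0, 2.0]);
 * const sampleWeights = tf.tensor1d([1, 0.5, 2]);
 * computeWeightedLoss(losses, sampleWeights).print(); // [0.5, 0.5, 4]
 * ```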
*/ function computeWeightedLoss(losses, sampleWeights) { return tfc.mul(losses, sampleWeights); } // Default batch size used during tensor-based validation. var DEFAULT_VALIDATION_BATCH_SIZE = 32; /** * Standardize the output of a dataset iterator for use by * LayersModel.fitDataset(). * * @param model: A `tf.LayersModel` object. * @param iteratorOut The output of a dataset iterator. It is required to be * an object of the form `{xs: TensorOrArrayOrMap, ys: * TensorOrArrayOrMap}`, where `TensorOrArrayOrMap` is a single `tf.Tensor`, * a `tf.Tensor[]`, or a flat map from string names to `tf.Tensor`s. * @returns A flat array of `tf.Tensor` objects: the input `tf.Tensor`s * followed by the target `tf.Tensor`s. When `tf.Tensor`s are provided * as a map, the order in the resulting array is taken from the `inputNames` * and `outputNames` of the model. */ function standardizeDataIteratorOutput( // Type `model` as `any` here to avoid circular dependency w/ // training.ts. // tslint:disable-next-line:no-any model, iteratorOut) { var xs; var ys; var iteratorOutObj = iteratorOut; xs = iteratorOutObj['xs']; ys = iteratorOutObj['ys']; tfc__namespace.util.assert(xs != null && ys != null, function () { return 'A Dataset iterator for fitDataset() is expected to generate ' + 'objects of the form `{xs: xVal, ys: yVal}`, where the two ' + 'values may be `tf.Tensor`, an array of Tensors, or a map of ' + 'string to Tensor. The provided Dataset instead generates ' + "".concat(iteratorOut); }); var flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs); var flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys); var batchSize = flattenedXs[0].shape[0]; tfc__namespace.util.assert(flattenedXs.length === model.inputs.length, function () { return "LayersModel has ".concat(model.inputs.length, " inputs, but the dataset ") + "provides ".concat(flattenedXs.length, " inputs. (Expected input keys: ") + "".concat(JSON.stringify(model.inputNames), ")"); }); tfc__namespace.util.assert(flattenedYs.length === model.outputs.length, function () { return "LayersModel has ".concat(model.outputs.length, " outputs, but the dataset ") + "provides ".concat(flattenedYs.length, " outputs. 
(Expected output keys: ") + "".concat(JSON.stringify(model.outputNames), ")"); }); var _loop_1 = function (xIndex) { tfc__namespace.util.assert(flattenedXs[xIndex].shape[0] === batchSize, function () { return "Batch size mismatch: input " + "".concat(model.inputNames[xIndex], " has ").concat(flattenedXs[xIndex].shape[0], "; ") + "expected ".concat(batchSize, " based on input ").concat(model.inputNames[0], "."); }); }; for (var xIndex = 0; xIndex < flattenedXs.length; xIndex++) { _loop_1(xIndex); } var _loop_2 = function (yIndex) { tfc__namespace.util.assert(flattenedYs[yIndex].shape[0] === batchSize, function () { return "Batch size mismatch: output " + "".concat(model.outputNames[yIndex], " has ").concat(flattenedYs[yIndex].shape[0], "; ") + "expected ".concat(batchSize, " based on input ").concat(model.inputNames[0], "."); }); }; for (var yIndex = 0; yIndex < flattenedYs.length; yIndex++) { _loop_2(yIndex); } return { xs: flattenedXs, ys: flattenedYs }; } function flattenTensorOrArrayOrMap(inputOrOutput, names, values) { var e_1, _a; if (values instanceof tfc__namespace.Tensor) { return [values]; } else if (Array.isArray(values)) { tfc__namespace.util.assert(values.length === names.length, function () { return "Received an array of ".concat(values.length, " Tensors, but expected ").concat(names.length, " to match the ").concat(inputOrOutput, " keys ").concat(names, "."); }); return values; } else { var result = []; try { // Check that all the required keys are available. for (var names_1 = __values(names), names_1_1 = names_1.next(); !names_1_1.done; names_1_1 = names_1.next()) { var name = names_1_1.value; if (values[name] == null) { throw new ValueError("The feature data generated by the dataset lacks the required " + "".concat(inputOrOutput, " key '").concat(name, "'.")); } result.push(values[name]); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (names_1_1 && !names_1_1.done && (_a = names_1.return)) _a.call(names_1); } finally { if (e_1) throw e_1.error; } } return result; } } function standardizeTensorValidationData(data) { if (data.length === 3) { throw new NotImplementedError('Validation with sample weights is not implemented yet.'); } return { xs: data[0], ys: data[1] }; } function fitDataset( // Type `model` as `any` here to avoid circular dependency w/ // training.ts. // tslint:disable-next-line:no-any model, dataset, args) { return __awaiter(this, void 0, void 0, function () { var hasBatchesPerEpoch, doValidation, valXs, valYs, validationData, trainFunction, outLabels, callbackMetrics, callbacks, verbose, _a, callbackList, history, epoch, dataIterator, epochLogs, stepsDone, batchIndex, iteratorOut, _b, xs, ys, batchLogs, sampleWeights, standardClassWeights, i, _c, _d, ins, outs, i, label, out, valOuts, _e, i; return __generator(this, function (_f) { switch (_f.label) { case 0: hasBatchesPerEpoch = args.batchesPerEpoch != null; tfc__namespace.util.assert(model.optimizer != null, function () { return 'You must compile a model before training/testing. 
Use ' + 'LayersModel.compile(modelCompileConfig).'; }); tfc__namespace.util.assert(args != null, function () { return "For fitDataset(), the 2nd argument (config) is required, " + "but it is not provided in this call."; }); tfc__namespace.util.assert(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), function () { return "For fitDataset(), config.epochs is expected to be a positive " + "integer, but got ".concat(args.epochs); }); tfc__namespace.util.assert(!hasBatchesPerEpoch || (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)), function () { return "For fitDataset(), config.batchesPerEpoch is expected to be a " + "positive integer if specified, but got ".concat(args.batchesPerEpoch); }); tfc__namespace.util.assert(
// tslint:disable-next-line:no-any
args['validationSplit'] == null, function () { return '`validationSplit` is not supported by `fitDataset()`. ' + 'Use validationData instead.'; }); if (model.isTraining) { throw new Error('Cannot start training because another fit() call is ongoing.'); } model.isTraining = true; _f.label = 1; case 1: _f.trys.push([1, , 26, 27]); doValidation = args.validationData != null; valXs = void 0; valYs = void 0; if (doValidation) { if (isDatasetObject(args.validationData)) { tfc__namespace.util.assert(args.validationBatches == null || (args.validationBatches > 0 && Number.isInteger(args.validationBatches)), function () { return "For fitDataset() with dataset-based validation, " + "config.validationBatches is expected not to be provided, " + "or to be a positive integer, " + "but got ".concat(args.validationBatches); }); } else { validationData = standardizeTensorValidationData(args.validationData); valXs = validationData.xs; valYs = validationData.ys; } } trainFunction = model.makeTrainFunction(); outLabels = model.getDedupedMetricsNames(); callbackMetrics = void 0; if (doValidation) { callbackMetrics = outLabels.slice().concat(outLabels.map(function (n) { return 'val_' + n; })); } else { callbackMetrics = outLabels.slice(); } callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery); verbose = args.verbose == null ? 1 : args.verbose; _a = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null, // Batch size determined by the dataset itself.
doValidation, callbackMetrics), callbackList = _a.callbackList, history = _a.history; callbackList.setModel(model); model.history = history; return [4 /*yield*/, callbackList.onTrainBegin()]; case 2: _f.sent(); model.stopTraining_ = false; epoch = args.initialEpoch == null ? 0 : args.initialEpoch; return [4 /*yield*/, dataset.iterator()]; case 3: dataIterator = _f.sent(); _f.label = 4; case 4: if (!(epoch < args.epochs)) return [3 /*break*/, 23]; epochLogs = {}; return [4 /*yield*/, callbackList.onEpochBegin(epoch)]; case 5: _f.sent(); stepsDone = 0; batchIndex = 0; if (!!hasBatchesPerEpoch) return [3 /*break*/, 7]; return [4 /*yield*/, dataset.iterator()]; case 6: dataIterator = _f.sent(); _f.label = 7; case 7: if (!(hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true)) return [3 /*break*/, 21]; return [4 /*yield*/, dataIterator.next()]; case 8: iteratorOut = _f.sent();
// If `batchesPerEpoch` is specified, the dataset should not be
// exhausted until all epochs are done.
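/* For reference, a typical `fitDataset()` call that satisfies this
 * contract (the compiled single-input `model` and the toy data are
 * hypothetical; the dataset must yield `{xs, ys}` batches):
 *
 * ```js
 * const xs = tf.data.array([[1], [2], [3], [4]]);
 * const ys = tf.data.array([[1], [3], [5], [7]]);
 * const ds = tf.data.zip({xs: xs, ys: ys}).batch(2).repeat();
 * await model.fitDataset(ds, {epochs: 3, batchesPerEpoch: 2});
 * ```
 */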
if (hasBatchesPerEpoch && iteratorOut.done) { console.warn('You provided `batchesPerEpoch` as ' + "".concat(args.batchesPerEpoch, ", ") + 'but your dataset iterator ran out of data after ' + "".concat(stepsDone, " batches; ") + 'interrupting training. Make sure that your ' + 'dataset can generate at least `batchesPerEpoch * epochs` ' + 'batches (in this case, ' + "".concat(args.batchesPerEpoch * args.epochs, " batches). ") + 'You may need to use the repeat() function when building ' + 'your dataset.'); return [3 /*break*/, 21]; } if (!(iteratorOut.value != null)) return [3 /*break*/, 15]; _b = standardizeDataIteratorOutput(model, iteratorOut.value), xs = _b.xs, ys = _b.ys; batchLogs = {}; batchLogs['batch'] = batchIndex; batchLogs['size'] = xs[0].shape[0]; return [4 /*yield*/, callbackList.onBatchBegin(batchIndex, batchLogs)]; case 9: _f.sent(); sampleWeights = []; if (!(args.classWeight != null)) return [3 /*break*/, 13]; standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames); i = 0; _f.label = 10; case 10: if (!(i < standardClassWeights.length)) return [3 /*break*/, 13]; _d = (_c = sampleWeights).push; return [4 /*yield*/, standardizeWeights(ys[i], null, standardClassWeights[i])]; case 11: _d.apply(_c, [_f.sent()]); _f.label = 12; case 12: ++i; return [3 /*break*/, 10]; case 13: ins = xs.concat(ys).concat(sampleWeights); outs = trainFunction(ins); tfc__namespace.dispose(ins); for (i = 0; i < outLabels.length; ++i) { label = outLabels[i]; out = outs[i]; batchLogs[label] = out; tfc__namespace.keep(out); } return [4 /*yield*/, callbackList.onBatchEnd(batchIndex, batchLogs)]; case 14: _f.sent(); disposeTensorsInLogs(batchLogs); batchIndex++; stepsDone++; _f.label = 15; case 15: if (!(hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch : iteratorOut.done)) return [3 /*break*/, 20]; if (!doValidation) return [3 /*break*/, 19]; valOuts = void 0; if (!isDatasetObject(args.validationData)) return [3 /*break*/, 17]; _e = toList; return [4 /*yield*/, model.evaluateDataset(args.validationData, { batches: args.validationBatches })]; case 16: valOuts = _e.apply(void 0, [_f.sent()]); return [3 /*break*/, 18]; case 17: valOuts = toList(model.evaluate(valXs, valYs, { batchSize: args.validationBatchSize == null ? DEFAULT_VALIDATION_BATCH_SIZE : args.validationBatchSize, verbose: 0 })); _f.label = 18; case 18: for (i = 0; i < model.metricsNames.length; ++i) { epochLogs["val_".concat(model.metricsNames[i])] = valOuts[i]; } _f.label = 19; case 19:
// Call `break` to exit the epoch loop after validation is done. If
// config.batchesPerEpoch is specified, an epoch while loop will
// stop when `stepsDone >= config.batchesPerEpoch`. When
// config.batchesPerEpoch is not provided, the following `break` is
// required to exit the while loop after the dataset is exhausted.
return [3 /*break*/, 21]; case 20: if (model.stopTraining_) { return [3 /*break*/, 21]; } return [3 /*break*/, 7]; case 21: return [4 /*yield*/, callbackList.onEpochEnd(epoch, epochLogs)]; case 22: _f.sent(); epoch++; if (model.stopTraining_) { return [3 /*break*/, 23]; } return [3 /*break*/, 4]; case 23: return [4 /*yield*/, callbackList.onTrainEnd()]; case 24: _f.sent(); return [4 /*yield*/, model.history.syncData()]; case 25: _f.sent(); return [2 /*return*/, model.history]; case 26: model.isTraining = false; return [7 /*endfinally*/]; case 27: return [2 /*return*/]; } }); }); } /** Helper function that determines number of steps (batches) per epoch.
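 * Returns `batchesPerEpoch` when it is provided; otherwise the dataset's
 * own finite `size`, if known; otherwise `null` (unknown ahead of time).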
*/ function getStepsPerEpoch(dataset, args) {
// Attempt to determine # of batches in an epoch.
var stepsPerEpoch = null; if (args.batchesPerEpoch != null) { stepsPerEpoch = args.batchesPerEpoch; } else if (Number.isFinite(dataset.size)) { stepsPerEpoch = dataset.size; } return stepsPerEpoch; }
// Check if provided object is a Dataset object by checking its .iterator
// element.
function isDatasetObject(dataset) { return (typeof dataset.iterator === 'function'); }
// Check if provided object is a LazyIterator object by checking its .next
// element.
function isLazyIteratorObject(iterator) { return (typeof iterator.next === 'function'); } function evaluateDataset(
// Type `model` as `any` here to avoid circular dependency w/
// training.ts.
// tslint:disable-next-line:no-any
model, dataset, args) { return __awaiter(this, void 0, void 0, function () { var hasBatches, f, outs, dataIterator, _a, numExamples, batch, _loop_3, state_1, i, oldScalar; return __generator(this, function (_b) { switch (_b.label) { case 0: args = args || {}; hasBatches = args.batches != null; f = model.testFunction; outs = []; if (args.verbose > 0) { throw new NotImplementedError('Verbose mode is not implemented yet.'); } tfc__namespace.util.assert(!hasBatches || (args.batches > 0 && Number.isInteger(args.batches)), function () { return 'Test loop expects `batches` to be a positive integer, but ' + "received ".concat(JSON.stringify(args.batches)); }); if (!isLazyIteratorObject(dataset)) return [3 /*break*/, 1]; _a = dataset; return [3 /*break*/, 3]; case 1: return [4 /*yield*/, dataset.iterator()]; case 2: _a = _b.sent(); _b.label = 3; case 3: dataIterator = _a; numExamples = 0; batch = 0; _loop_3 = function () { var iteratorOut; return __generator(this, function (_c) { switch (_c.label) { case 0: return [4 /*yield*/, dataIterator.next()]; case 1: iteratorOut = _c.sent(); outs = tfc__namespace.tidy(function () { if (iteratorOut.value) {
// TODO(cais): Once real dataset is available, use
// `map(x => standardizeDataIteratorOutput(model, x).map(f)`.
var _a = standardizeDataIteratorOutput(model, iteratorOut.value), xs = _a.xs, ys = _a.ys; var xsAndYs_1 = xs.concat(ys); var batchOuts = tfc__namespace.tidy(function () { return f(xsAndYs_1); }); tfc__namespace.dispose(xsAndYs_1); if (batch === 0) { for (var i = 0; i < batchOuts.length; ++i) { outs.push(tfc.scalar(0)); } } var batchSize_1 = xsAndYs_1[0].shape[0]; var _loop_4 = function (i) { var batchOut = batchOuts[i]; var oldScalar = outs[i]; outs[i] = tfc__namespace.tidy(function () { return tfc__namespace.add(outs[i], tfc__namespace.mul(batchSize_1, batchOut)); }); if (batch > 0) { tfc__namespace.dispose(oldScalar); } }; for (var i = 0; i < batchOuts.length; ++i) { _loop_4(i); } tfc__namespace.dispose(batchOuts); numExamples += batchSize_1; ++batch; } return outs; }); if (iteratorOut.done) { if (hasBatches) { console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' + 'Interrupting evaluation. Make sure that your ' + 'dataset can generate at least `batches` ' + "batches (in this case, ".concat(args.batches, " batches). ") + 'You may need to use the repeat() function when building ' + 'your dataset.'); } return [2 /*return*/, "break"]; } return [2 /*return*/]; } }); }; _b.label = 4; case 4: if (!(hasBatches ?
batch < args.batches : true)) return [3 /*break*/, 6]; return [5 /*yield**/, _loop_3()]; case 5: state_1 = _b.sent(); if (state_1 === "break") return [3 /*break*/, 6]; return [3 /*break*/, 4]; case 6: for (i = 0; i < outs.length; ++i) { oldScalar = outs[i]; outs[i] = tfc__namespace.div(outs[i], numExamples); tfc__namespace.dispose(oldScalar); } return [2 /*return*/, singletonOrArray(outs)]; } }); }); } /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ function checkBatchSize(batchSize) { tfc__namespace.util.assert(batchSize > 0 && Number.isInteger(batchSize), function () { return "batchSize is required to be a positive integer, but got ".concat(batchSize); }); } /** * Slice a Tensor or an Array of Tensors, by start and stop indices. * * Porting Note: The `_slice_arrays` function in PyKeras is covered by this * function and `sliceArraysByIndices()` together. * * @param arrays: the input. * @param start: the starting index (inclusive). * @param stop: the stopping index (exclusive). * @returns The result of the slicing. If `arrays` is an `Array` of * `tf.Tensor`s, the slicing will be applied to all elements of the `Array` * in the same way. */ function sliceArrays(arrays, start, stop) { if (arrays == null) { return [null]; } else if (Array.isArray(arrays)) { return arrays.map(function (array) { return sliceAlongFirstAxis(array, start, stop - start); }); } else { // Tensor. return sliceAlongFirstAxis(arrays, start, stop - start); } } /** * Slice a Tensor or an Array of Tensors, by random-order indices. * * Porting Note: The `_slice_arrays` function in PyKeras is covered by this * function and `sliceArrays()` together. * * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice. * If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the * same fashion. * @param indices The indices to use for slicing along the first (batch) * dimension. * @returns Result(s) of the slicing. */ function sliceArraysByIndices(arrays, indices) { return tfc__namespace.tidy(function () { if (arrays == null) { return null; } else if (Array.isArray(arrays)) { return arrays.map(function (array) { return sliceArraysByIndices(array, indices); }); } else { // TODO(cais): indices should be a pre-constructed Tensor1D to avoid // tensor1d() calls. return gather$1(arrays, indices.dtype === 'int32' ? indices : tfc__namespace.cast(indices, 'int32')); } }); } /** * Returns a list of batch indices (tuples of indices). * @param size: Integer, total size of the data to slice into batches. * @param batchSize: Integer, batch size. * @returns An Array of [batchStart, batchEnd] tuples. batchStart is * inclusive; batchEnd is exclusive. I.e., each batch consists of indices x * that satisfy batchStart <= x < batchEnd. */ function makeBatches(size, batchSize) { var output = []; var batchStart = 0; var batchEnd = null; while (batchStart < size) { batchEnd = batchStart + batchSize; if (batchEnd >= size) { batchEnd = size; } output.push([batchStart, batchEnd]); batchStart = batchEnd; } return output; } /** * Ensure tensors all have a rank of at least 2. * * If a tensor has a rank of 1, it is dimension-expanded to rank 2. * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown. 
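 *
 * A minimal sketch of the expected behavior:
 *
 * ```js
 * const out = ensureTensorsRank2OrHigher(tf.tensor1d([1, 2, 3]));
 * console.log(out[0].shape); // [3, 1] -- the rank-1 input was expanded
 * ```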
*/ function ensureTensorsRank2OrHigher(tensors) { var outs = []; if (tensors instanceof tfc.Tensor) { tensors = [tensors]; }
// Make Tensors at least 2D.
for (var i = 0; i < tensors.length; ++i) { var tensor = tensors[i]; if (tensor.rank === 1) { outs.push(expandDims$1(tensor, 1)); } else if (tensor.rank === 0) { throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' + '(scalar).'); } else { outs.push(tensor); } } return outs; } /** * Compare a set of tensors with a reference (old) set, discard the ones * in the new set that are not present in the reference set. * * This method is used for memory cleanup during calls such as * LayersModel.fit(). * * @param tensors New set which may contain Tensors not present in * `refTensors`. * @param refTensors Reference Tensor set. */
// TODO(cais, kangyizhang): Deduplicate with tfjs-data.
function disposeNewTensors(tensors, refTensors) { if (tensors == null) { return; } var oldTensorIds = []; if (refTensors instanceof tfc.Tensor) { oldTensorIds.push(refTensors.id); } else if (Array.isArray(refTensors)) { refTensors.forEach(function (t) { return oldTensorIds.push(t.id); }); } else if (refTensors != null) {
// `oldTensors` is a map from string name to Tensor.
for (var name in refTensors) { var oldTensor = refTensors[name]; oldTensorIds.push(oldTensor.id); } } var tensorsToDispose = []; if (tensors instanceof tfc.Tensor) { if (oldTensorIds.indexOf(tensors.id) === -1) { tensorsToDispose.push(tensors); } } else if (Array.isArray(tensors)) { tensors.forEach(function (t) { if (oldTensorIds.indexOf(t.id) === -1) { tensorsToDispose.push(t); } }); } else if (tensors != null) {
// `oldTensors` is a map from string name to Tensor.
for (var name in tensors) { var tensor = tensors[name]; if (oldTensorIds.indexOf(tensor.id) === -1) { tensorsToDispose.push(tensor); } } } tensorsToDispose.forEach(function (t) { if (!t.isDisposed) { t.dispose(); } }); } /** * Helper function for polymorphic input data: 1. singleton Tensor. */ function isDataTensor(x) { return x instanceof tfc.Tensor; } /** * Helper function for polymorphic input data: 2. Array of Tensor. */ function isDataArray(x) { return Array.isArray(x); } /** * Helper function for polymorphic input data: 3. "dict" of Tensor. */ function isDataDict(x) { return !isDataTensor(x) && !isDataArray(x); } /** * Normalizes inputs and targets provided by users. * @param data User-provided input data (polymorphic). * @param names An Array of expected Tensor names. * @param shapes Optional Array of expected Tensor shapes. * @param checkBatchAxis Whether to check that the batch axis of the arrays * matches the expected value found in `shapes`. * @param exceptionPrefix String prefix used for exception formatting. * @returns List of standardized input Tensors (one Tensor per model input). * @throws ValueError: in case of improperly formatted user data. */ function standardizeInputData(data, names, shapes, checkBatchAxis, exceptionPrefix) { var e_1, _a; if (checkBatchAxis === void 0) { checkBatchAxis = true; } if (exceptionPrefix === void 0) { exceptionPrefix = ''; } if (names == null || names.length === 0) {
// Check for the case where the model expected no data, but some data got
// sent.
if (data != null) { var gotUnexpectedData = false; if (isDataArray(data) && data.length > 0) { gotUnexpectedData = true; } else if (isDataDict(data)) { for (var key in data) { if (data.hasOwnProperty(key)) { gotUnexpectedData = true; break; } } } else {
// `data` is a singleton Tensor in this case.
gotUnexpectedData = true; } if (gotUnexpectedData) { throw new ValueError("Error when checking model ".concat(exceptionPrefix, " expected no data, ") + "but got ".concat(data)); } } return []; } if (data == null) { return names.map(function (name) { return null; }); } var arrays; if (isDataDict(data)) { data = data; arrays = []; try { for (var names_1 = __values(names), names_1_1 = names_1.next(); !names_1_1.done; names_1_1 = names_1.next()) { var name = names_1_1.value; if (data[name] == null) { throw new ValueError("No data provided for \"".concat(name, "\". Need data for each key in: ") + "".concat(names)); } arrays.push(data[name]); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (names_1_1 && !names_1_1.done && (_a = names_1.return)) _a.call(names_1); } finally { if (e_1) throw e_1.error; } } } else if (isDataArray(data)) { data = data; if (data.length !== names.length) { throw new ValueError("Error when checking model ".concat(exceptionPrefix, ": the Array of ") + "Tensors that you are passing to your model is not the size the " + "model expected. Expected to see ".concat(names.length, " Tensor(s), but ") + "instead got the following list of Tensor(s): ".concat(data)); } arrays = data; } else { data = data; if (names.length > 1) { throw new ValueError("The model ".concat(exceptionPrefix, " expects ").concat(names.length, " Tensor(s), ") + "but only received one Tensor. Found: Tensor with shape ".concat(data.shape)); } arrays = [data]; } arrays = ensureTensorsRank2OrHigher(arrays);
// Check shape compatibility.
if (shapes != null) { for (var i = 0; i < names.length; ++i) { if (shapes[i] == null) { continue; } var array = arrays[i]; if (array.shape.length !== shapes[i].length) { throw new ValueError("Error when checking ".concat(exceptionPrefix, ": expected ").concat(names[i], " ") + "to have ".concat(shapes[i].length, " dimension(s), but got array with ") + "shape ".concat(array.shape)); } for (var j = 0; j < shapes[i].length; ++j) { if (j === 0 && !checkBatchAxis) {
// Skip the first (batch) axis.
continue; } var dim = array.shape[j]; var refDim = shapes[i][j]; if (refDim != null && refDim >= 0 && dim !== refDim) { throw new ValueError("".concat(exceptionPrefix, " expected a batch of elements where each ") + "example has shape [".concat(shapes[i].slice(1, shapes[i].length), "] ") + "(i.e., tensor shape [*,".concat(shapes[i].slice(1, shapes[i].length), "])") + " but the ".concat(exceptionPrefix, " received an input with ").concat(array.shape[0]) + " examples, each with shape [".concat(array.shape.slice(1, array.shape.length), "]") + " (tensor shape [".concat(array.shape, "])")); } } } } return arrays; } /** * User input validation for Tensors. * @param inputs `Array` of `tf.Tensor`s for inputs. * @param targets `Array` of `tf.Tensor`s for targets. * @param weights Optional `Array` of `tf.Tensor`s for sample weights. * @throws ValueError: in case of incorrectly formatted data. */ function checkArrayLengths(inputs, targets, weights) { var setX = unique(inputs.map(function (input) { return input.shape[0]; })); setX.sort(); var setY = unique(targets.map(function (target) { return target.shape[0]; })); setY.sort();
// TODO(cais): Check `weights` as well.
if (setX.length > 1) { throw new ValueError("All input Tensors (x) should have the same number of samples. " + "Got array shapes: " + "".concat(JSON.stringify(inputs.map(function (input) { return input.shape; })))); } if (setY.length > 1) { throw new ValueError("All target Tensors (y) should have the same number of samples. " + "Got array shapes: " + "".concat(JSON.stringify(targets.map(function (target) { return target.shape; })))); } if (setX.length > 0 && setY.length > 0 && !tfc.util.arraysEqual(setX, setY)) { throw new ValueError("Input Tensors should have the same number of samples as target " + "Tensors. Found ".concat(setX[0], " input sample(s) and ").concat(setY[0], " target ") + "sample(s)."); } } /** * Validation on the compatibility of targets and loss functions. * * This helps prevent users from using loss functions incorrectly. * * @param targets `Array` of `tf.Tensor`s of targets. * @param lossFns `Array` of loss functions. * @param outputShapes `Array` of shapes of model outputs. */ function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
// TODO(cais): Dedicated test coverage?
var keyLosses = [ meanSquaredError$1, binaryCrossentropy$2, categoricalCrossentropy$2 ]; for (var i = 0; i < targets.length; ++i) { var y = targets[i]; var loss = lossFns[i]; var shape = outputShapes[i]; if (loss == null) { continue; } if (loss === categoricalCrossentropy$2) { if (y.shape[y.shape.length - 1] === 1) { throw new ValueError("You are passing a target array of shape ".concat(y.shape, " while using ") + "a loss 'categorical_crossentropy'. 'categorical_crossentropy' " + "expects targets to be binary matrices (1s and 0s) of shape " + "[samples, classes].");
// TODO(cais): Example code in error message.
} } if (keyLosses.indexOf(loss) !== -1) { var slicedYShape = y.shape.slice(1); var slicedShape = shape.slice(1); for (var j = 0; j < slicedYShape.length; ++j) { var targetDim = slicedYShape[j]; var outDim = slicedShape[j]; if (outDim != null && targetDim !== outDim) { throw new ValueError("A target Tensor with shape ".concat(y.shape, " was passed for an ") + "output of shape ".concat(shape, ", while using a loss function that ") + "expects targets to have the same shape as the output."); } } } } } /** * Check inputs provided by the user. * * Porting Note: This corresponds to _standardize_input_data() in Python * Keras. Because of the strong typing in TF.js, we do not need to convert * the data. Specifically: * 1) in PyKeras, `data` can be `DataFrame` instances from pandas, for * example. We don't need to worry about that here because there is no * widely popular javascript/typescript equivalent of pandas (so far). * If one becomes available in the future, we can add support. * 2) in PyKeras, inputs can be Python dict. But here we are stipulating * that the data is either a single `tf.Tensor` or an Array of `tf.Tensor`s. We * may add support for `Object` data inputs in the future when the need * arises. * * Instead, we perform basic checks for number of parameters and shapes. * * @param data: The input data. * @param names: Name for the inputs, from the model. * @param shapes: Expected shapes for the input data, from the model. * @param checkBatchAxis: Whether the size along the batch axis (i.e., the * first dimension) will be checked for matching. * @param exceptionPrefix: Exception prefix message, used in generating error * messages. * @throws ValueError: on incorrect number of inputs or mismatches in shapes.
*/ function checkInputData(data, names, shapes, checkBatchAxis, exceptionPrefix) { if (checkBatchAxis === void 0) { checkBatchAxis = true; } if (exceptionPrefix === void 0) { exceptionPrefix = ''; } var arrays; if (Array.isArray(data)) { if (data.length !== names.length) { throw new ValueError("Error when checking model ".concat(exceptionPrefix, ": the Array of ") + "Tensors that you are passing to your model is not the size the " + "model expected. Expected to see ".concat(names.length, " Tensor(s),") + " but instead got ".concat(data.length, " Tensor(s).")); } arrays = data; } else { if (names.length > 1) { throw new ValueError("The model expects ".concat(names.length, " ").concat(exceptionPrefix, " Tensors, ") + "but only received one Tensor. Found: array with shape " + "".concat(JSON.stringify(data.shape), ".")); } arrays = [data]; } if (shapes != null) { for (var i = 0; i < names.length; ++i) { if (shapes[i] == null) { continue; } var array = arrays[i]; if (array.shape.length !== shapes[i].length) { throw new ValueError("Error when checking ".concat(exceptionPrefix, ": expected ").concat(names[i], " ") + "to have ".concat(shapes[i].length, " dimension(s), but got array with ") + "shape ".concat(JSON.stringify(array.shape))); } for (var j = 0; j < shapes[i].length; ++j) { if (j === 0 && !checkBatchAxis) { continue; } var dim = array.shape[j]; var refDim = shapes[i][j]; if (refDim != null) { if (refDim !== dim) { throw new ValueError("Error when checking ".concat(exceptionPrefix, ": expected ") + "".concat(names[i], " to have shape ").concat(JSON.stringify(shapes[i]), " but ") + "got array with shape ".concat(JSON.stringify(array.shape), ".")); } } } } } } /** * Maps metric functions to model outputs. * @param metrics A shortcut metric name (string), metric function, `Array`, * or dict (`Object`) of metric functions. * @param outputNames An `Array` of the names of model outputs. * @returns An `Array` (one entry per model output) of `Array` of metric * functions. For instance, if the model has 2 outputs, and for the first * output we want to compute `binaryAccuracy` and `binaryCrossentropy`, * and just `binaryAccuracy` for the second output, the `Array` would look * like: * `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]` * @throws TypeError: incompatible metrics format. */ function collectMetrics(metrics, outputNames) { var e_2, _a; if (metrics == null || Array.isArray(metrics) && metrics.length === 0) { return outputNames.map(function (name) { return []; }); } var wrappedMetrics; if (typeof metrics === 'string' || typeof metrics === 'function') { wrappedMetrics = [metrics]; } else if (Array.isArray(metrics) || typeof metrics === 'object') { wrappedMetrics = metrics; } else { throw new TypeError('Type of metrics argument not understood. Expected a string, ' + "function, Array, or Object, found: ".concat(metrics)); } if (Array.isArray(wrappedMetrics)) {
// We then apply all metrics to all outputs.
return outputNames.map(function (name) { return wrappedMetrics; }); } else {
// In this case, metrics is a dict.
var nestedMetrics = []; try { for (var outputNames_1 = __values(outputNames), outputNames_1_1 = outputNames_1.next(); !outputNames_1_1.done; outputNames_1_1 = outputNames_1.next()) { var name = outputNames_1_1.value; var outputMetrics = wrappedMetrics.hasOwnProperty(name) ?
wrappedMetrics[name] : []; if (!Array.isArray(outputMetrics)) { outputMetrics = [outputMetrics]; } nestedMetrics.push(outputMetrics); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (outputNames_1_1 && !outputNames_1_1.done && (_a = outputNames_1.return)) _a.call(outputNames_1); } finally { if (e_2) throw e_2.error; } } return nestedMetrics; } } var LAYERS_MODEL_FORMAT_NAME = 'layers-model'; /** * A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods * for training, evaluation, prediction and saving. * * `tf.LayersModel` is the basic unit of training, inference and evaluation in * TensorFlow.js. To create a `tf.LayersModel`, use `tf.model`. * * See also: * `tf.Sequential`, `tf.loadLayersModel`. * * @doc {heading: 'Models', subheading: 'Classes'} */ var LayersModel = /** @class */ (function (_super) { __extends(LayersModel, _super); function LayersModel(args) { var _this = _super.call(this, args) || this; _this.isTraining = false; return _this; } /** * Print a text summary of the model's layers. * * The summary includes * - Name and type of all layers that comprise the model. * - Output shape(s) of the layers * - Number of weight parameters of each layer * - If the model has non-sequential-like topology, the inputs each layer * receives * - The total number of trainable and non-trainable parameters of the model. * * ```js * const input1 = tf.input({shape: [10]}); * const input2 = tf.input({shape: [20]}); * const dense1 = tf.layers.dense({units: 4}).apply(input1); * const dense2 = tf.layers.dense({units: 8}).apply(input2); * const concat = tf.layers.concatenate().apply([dense1, dense2]); * const output = * tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat); * * const model = tf.model({inputs: [input1, input2], outputs: output}); * model.summary(); * ``` * * @param lineLength Custom line length, in number of characters. * @param positions Custom widths of each of the columns, as either * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number * of characters (e.g., `[30, 50, 65]`). Each number corresponds to * right-most (i.e., ending) position of a column. * @param printFn Custom print function. Can be used to replace the default * `console.log`. For example, you can use `x => {}` to mute the printed * messages in the console. * * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.summary = function (lineLength, positions, printFn) { if (printFn === void 0) { printFn = console.log; } if (!this.built) { throw new ValueError("This model has never been called, thus its weights have not been " + "created yet. So no summary can be displayed. Build the model " + "first (e.g., by calling it on some test data)."); } printSummary(this, lineLength, positions, printFn); }; /** * Configures and prepares the model for training and evaluation. Compiling * outfits the model with an optimizer, loss, and/or metrics. Calling `fit` * or `evaluate` on an un-compiled model will throw an error. * * @param args a `ModelCompileArgs` specifying the loss, optimizer, and * metrics to be used for fitting and evaluating this model.
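 *
 * For example, a typical single-output configuration (the optimizer and
 * loss choices here are illustrative, not prescribed):
 *
 * ```js
 * model.compile({
 *   optimizer: 'sgd',
 *   loss: 'meanSquaredError',
 *   metrics: ['accuracy']
 * });
 * ```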
* * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.compile = function (args) { var e_3, _a; var _this = this; if (args.loss == null) { args.loss = []; } this.loss = args.loss; if (typeof args.optimizer === 'string') { this.optimizer_ = getOptimizer(args.optimizer); this.isOptimizerOwned = true; } else { if (!(args.optimizer instanceof tfc.Optimizer)) { throw new ValueError("User-defined optimizer must be an instance of tf.Optimizer."); } this.optimizer_ = args.optimizer; this.isOptimizerOwned = false; } // TODO(cais): Add lossWeights. // TODO(cais): Add sampleWeightMode. // Prepare loss functions. var lossFunctions = []; if (!Array.isArray(args.loss) && typeof args.loss !== 'string' && typeof args.loss !== 'function') { args.loss = args.loss; for (var name in args.loss) { if (this.outputNames.indexOf(name) === -1) { throw new ValueError("Unknown entry in loss dictionary: \"".concat(name, "\". ") + "Only expected the following keys: ".concat(this.outputNames)); } } try { for (var _b = __values(this.outputNames), _c = _b.next(); !_c.done; _c = _b.next()) { var name = _c.value; if (args.loss[name] == null) { console.warn("Output \"".concat(name, "\" is missing from loss dictionary. We assume ") + "this was done on purpose, and we will not be expecting data " + "to be passed to ".concat(name, " during training")); } lossFunctions.push(get$1(args.loss[name])); } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_3) throw e_3.error; } } } else if (Array.isArray(args.loss)) { if (args.loss.length !== this.outputs.length) { throw new ValueError("When passing an Array as loss, it should have one entry per " + "model output. The model has ".concat(this.outputs.length, " output(s), ") + "but you passed loss=".concat(args.loss, ".")); } var theLosses = args.loss; lossFunctions = theLosses.map(function (l) { return get$1(l); }); } else { var lossFunction_1 = get$1(args.loss); this.outputs.forEach(function (_) { lossFunctions.push(lossFunction_1); }); } this.lossFunctions = lossFunctions; this.feedOutputNames = []; this.feedOutputShapes = []; this.feedLossFns = []; for (var i = 0; i < this.outputs.length; ++i) { // TODO(cais): Logic for skipping target(s). var shape = this.internalOutputShapes[i]; var name = this.outputNames[i]; this.feedOutputNames.push(name); this.feedOutputShapes.push(shape); this.feedLossFns.push(this.lossFunctions[i]); } // TODO(cais): Add logic for output masks. // TODO(cais): Add logic for sample weights. var skipTargetIndices = []; // Prepare metrics. this.metrics = args.metrics; // TODO(cais): Add weightedMetrics. this.metricsNames = ['loss']; this.metricsTensors = []; // Compute total loss. // Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects. // Here, metricsTensors are TypeScript functions. This difference is due // to the difference in symbolic/imperative property of the backends. nameScope('loss', function () { for (var i = 0; i < _this.outputs.length; ++i) { if (skipTargetIndices.indexOf(i) !== -1) { continue; } // TODO(cais): Add weightedLoss, sampleWeight and mask. // The following line should be weightedLoss var weightedLoss = _this.lossFunctions[i]; if (_this.outputs.length > 1) { _this.metricsTensors.push([weightedLoss, i]); _this.metricsNames.push(_this.outputNames[i] + '_loss'); } } // Porting Note: Due to the imperative nature of the backend, we calculate // the regularizer penalties in the totalLossFunction, instead of here. 
}); var nestedMetrics = collectMetrics(args.metrics, this.outputNames); // TODO(cais): Add nestedWeightedMetrics. /** * Helper function used in loop below. */ var appendMetric = function (outputIndex, metricName, metricTensor) { if (_this.outputNames.length > 1) { metricName = _this.outputNames[outputIndex] + '_' + metricName; } _this.metricsNames.push(metricName); _this.metricsTensors.push([metricTensor, outputIndex]); }; nameScope('metric', function () { var _loop_1 = function (i) { if (skipTargetIndices.indexOf(i) !== -1) { return "continue"; } var outputMetrics = nestedMetrics[i]; // TODO(cais): Add weights and outputWeightedMetrics. // TODO(cais): Add optional arg `weights` to the following function. var handleMetrics = function (metrics) { var e_4, _a; var metricNamePrefix = ''; var metricName; var accFn; var weightedMetricFn; var _loop_2 = function (metric) { if (typeof metric === 'string' && ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !== -1) { var outputShape = _this.internalOutputShapes[i]; if (outputShape[outputShape.length - 1] === 1 || _this.lossFunctions[i] === binaryCrossentropy$2) { // case: binary accuracy/crossentropy. if (['accuracy', 'acc'].indexOf(metric) !== -1) { accFn = binaryAccuracy$1; } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { accFn = binaryCrossentropy$1; } } else if (_this.lossFunctions[i] === sparseCategoricalCrossentropy$1) { // case: categorical accuracy / crossentropy with sparse // targets. if (['accuracy', 'acc'].indexOf(metric) !== -1) { accFn = sparseCategoricalAccuracy$1; } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { accFn = sparseCategoricalCrossentropy; } } else { // case: categorical accuracy / crossentropy. if (['accuracy', 'acc'].indexOf(metric) !== -1) { accFn = categoricalAccuracy$1; } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { accFn = categoricalCrossentropy$1; } } var suffix = void 0; if (['accuracy', 'acc'].indexOf(metric) !== -1) { suffix = 'acc'; } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { suffix = 'ce'; } // TODO(cais): Add weighting actually. weightedMetricFn = accFn; metricName = metricNamePrefix + suffix; } else { var metricFn = get(metric); // TODO(cais): Add weighting actually. weightedMetricFn = metricFn; metricName = metricNamePrefix + getLossOrMetricName(metric); } // TODO(cais): Add weighting and masking to metricResult. var metricResult; nameScope(metricName, function () { metricResult = weightedMetricFn; }); appendMetric(i, metricName, metricResult); }; try { // TODO(cais): Use 'weights_' for weighted metrics. for (var metrics_1 = (e_4 = void 0, __values(metrics)), metrics_1_1 = metrics_1.next(); !metrics_1_1.done; metrics_1_1 = metrics_1.next()) { var metric = metrics_1_1.value; _loop_2(metric); } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (metrics_1_1 && !metrics_1_1.done && (_a = metrics_1.return)) _a.call(metrics_1); } finally { if (e_4) throw e_4.error; } } }; handleMetrics(outputMetrics); }; for (var i = 0; i < _this.outputs.length; ++i) { _loop_1(i); } }); // Porting Notes: Given the imperative backend of tfjs-core, // there is no need for constructing the symbolic graph and placeholders. this.collectedTrainableWeights = this.trainableWeights; }; /** * Check trainable weights count consistency. * * This will raise a warning if `this.trainableWeights` and * `this.collectedTrainableWeights` are inconsistent (i.e., have different * numbers of parameters). 
* Inconsistency will typically arise when one modifies `model.trainable` * without calling `model.compile()` again. */ LayersModel.prototype.checkTrainableWeightsConsistency = function () { if (this.collectedTrainableWeights == null) { return; } if (this.trainableWeights.length !== this.collectedTrainableWeights.length) { console.warn('Discrepancy between trainable weights and collected trainable ' + 'weights. Did you set `model.trainable` without calling ' + '`model.compile()` afterwards?'); } }; /** * Returns the loss value & metrics values for the model in test mode. * * Loss and metrics are specified during `compile()`, which needs to happen * before calls to `evaluate()`. * * Computation is done in batches. * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); * const result = model.evaluate( * tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4}); * result.print(); * ``` * * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the * model has multiple inputs. * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the * model has multiple outputs. * @param args A `ModelEvaluateArgs`, containing optional fields. * * @return `Scalar` test loss (if the model has a single output and no * metrics) or `Array` of `Scalar`s (if the model has multiple outputs * and/or metrics). The attribute `model.metricsNames` * will give you the display labels for the scalar outputs. * * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.evaluate = function (x, y, args) { if (args === void 0) { args = {}; } var batchSize = args.batchSize == null ? 32 : args.batchSize; checkBatchSize(batchSize); // TODO(cais): Standardize `config.sampleWeights` as well. // Validate user data. var checkBatchAxis = true; var standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize); try { // TODO(cais): If uses `useLearningPhase`, set the corresponding element // of the input to 0. var ins = standardizedOuts[0].concat(standardizedOuts[1]); this.makeTestFunction(); var f = this.testFunction; var testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps); return singletonOrArray(testOuts); } finally { disposeNewTensors(standardizedOuts[0], x); disposeNewTensors(standardizedOuts[1], y); } }; // TODO(cais): Add code snippet below once real dataset objects are // available. /** * Evaluate model using a dataset object. * * Note: Unlike `evaluate()`, this method is asynchronous (`async`). * * @param dataset A dataset object. Its `iterator()` method is expected * to generate a dataset iterator object, the `next()` method of which * is expected to produce data batches for evaluation. The return value * of the `next()` call ought to contain a boolean `done` field and a * `value` field. The `value` field is expected to be an array of two * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former * case is for models with exactly one input and one output (e.g. * a sequential model). The latter case is for models with multiple * inputs and/or multiple outputs. Of the two items in the array, the * first is the input feature(s) and the second is the output target(s). * @param args A configuration object for the dataset-based evaluation. * @returns Loss and metric values as an Array of `Scalar` objects.
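 *
 * A usage sketch, assuming the `tf.data` API from tfjs-data is available
 * alongside tfjs-layers and that each dataset element is an `[xs, ys]`
 * pair as described above (shapes and batch count are illustrative):
 *
 * ```js
 * const model = tf.sequential({
 *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
 * });
 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 * const dataset = tf.data.array([
 *   [tf.ones([4, 10]), tf.ones([4, 1])],
 *   [tf.ones([4, 10]), tf.ones([4, 1])]
 * ]);
 * const result = await model.evaluateDataset(dataset, {batches: 2});
 * result.print();
 * ```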
* * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.evaluateDataset = function (dataset, args) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { this.makeTestFunction(); return [2 /*return*/, evaluateDataset(this, dataset, args)]; }); }); }; /** * Get number of samples provided for training, evaluation or prediction. * * @param ins Input `tf.Tensor`. * @param batchSize Integer batch size, optional. * @param steps Total number of steps (batches of samples) before * declaring loop finished. Optional. * @param stepsName The public API's parameter name for `steps`. * @returns Number of samples provided. */ LayersModel.prototype.checkNumSamples = function (ins, batchSize, steps, stepsName) { if (stepsName === void 0) { stepsName = 'steps'; } var numSamples; if (steps != null) { numSamples = null; if (batchSize != null) { throw new ValueError("If ".concat(stepsName, " is set, batchSize must be null or undefined. ") + "Got batchSize = ".concat(batchSize)); } } else if (ins != null) { if (Array.isArray(ins)) { numSamples = ins[0].shape[0]; } else { numSamples = ins.shape[0]; } } else { throw new ValueError("Either the input data should have a defined shape, or " + "".concat(stepsName, " should be specified.")); } return numSamples; }; /** * Execute internal tensors of the model with input data feed. * @param inputs Input data feed. Must match the inputs of the model. * @param outputs Names of the output tensors to be fetched. Must match * names of the SymbolicTensors that belong to the graph. * @returns Fetched values for `outputs`. */ LayersModel.prototype.execute = function (inputs, outputs) { var e_5, _a; if (Array.isArray(outputs) && outputs.length === 0) { throw new ValueError('`outputs` is an empty Array, which is not allowed.'); } var outputsIsArray = Array.isArray(outputs); var outputNames = (outputsIsArray ? outputs : [outputs]); var outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames); // Format the input into a FeedDict. var feedDict = new FeedDict(); if (inputs instanceof tfc.Tensor) { inputs = [inputs]; } if (Array.isArray(inputs)) { if (inputs.length !== this.inputs.length) { throw new ValueError("The number of inputs provided (".concat(inputs.length, ") ") + "does not match the number of inputs of this model " + "(".concat(this.inputs.length, ").")); } for (var i = 0; i < this.inputs.length; ++i) { feedDict.add(this.inputs[i], inputs[i]); } } else { try { for (var _b = __values(this.inputs), _c = _b.next(); !_c.done; _c = _b.next()) { var input = _c.value; var tensorValue = inputs[input.name]; if (tensorValue == null) { throw new ValueError("No value is provided for the model's input ".concat(input.name)); } feedDict.add(input, tensorValue); } } catch (e_5_1) { e_5 = { error: e_5_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_5) throw e_5.error; } } } // Run execution. var executeOutputs = execute(outputSymbolicTensors, feedDict); return outputsIsArray ? executeOutputs : executeOutputs[0]; }; /** * Retrieve the model's internal symbolic tensors from symbolic-tensor names. */ LayersModel.prototype.retrieveSymbolicTensors = function (symbolicTensorNames) { var e_6, _a; var outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length); var outputsRemaining = symbolicTensorNames.length; try { for (var _b = __values(this.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; var layerOutputs = Array.isArray(layer.output) ?
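/* A usage sketch for `execute()` above (the model wiring is illustrative):
 * fetching an intermediate activation via its SymbolicTensor name.
 *
 * ```js
 * const input = tf.input({shape: [4]});
 * const dense1 = tf.layers.dense({units: 3}).apply(input);
 * const output = tf.layers.dense({units: 1}).apply(dense1);
 * const model = tf.model({inputs: input, outputs: output});
 * // `dense1.name` is the symbolic-tensor name that
 * // retrieveSymbolicTensors() resolves.
 * const activation = model.execute(tf.ones([2, 4]), dense1.name);
 * activation.print();
 * ```
 */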
layer.output : [layer.output]; var layerOutputNames = layerOutputs.map(function (output) { return output.name; }); for (var i = 0; i < symbolicTensorNames.length; ++i) { var index = layerOutputNames.indexOf(symbolicTensorNames[i]); if (index !== -1) { outputSymbolicTensors[i] = layerOutputs[index]; outputsRemaining--; } if (outputsRemaining === 0) { break; } } if (outputsRemaining === 0) { break; } } } catch (e_6_1) { e_6 = { error: e_6_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_6) throw e_6.error; } } if (outputsRemaining > 0) { var remainingNames_1 = []; outputSymbolicTensors.forEach(function (tensor, i) { if (tensor == null) { remainingNames_1.push(symbolicTensorNames[i]); } }); throw new ValueError("Cannot find SymbolicTensors for output name(s): " + "".concat(JSON.stringify(remainingNames_1))); } return outputSymbolicTensors; }; /** * Helper method to loop over some data in batches. * * Porting Note: Not using the functional approach in the Python equivalent * due to the imperative backend. * Porting Note: Does not support step mode currently. * * @param ins: input data * @param batchSize: integer batch size. * @param verbose: verbosity mode. * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of * `tf.Tensor` (if multiple outputs). */ LayersModel.prototype.predictLoop = function (ins, batchSize, verbose) { var _this = this; if (batchSize === void 0) { batchSize = 32; } if (verbose === void 0) { verbose = false; } return tfc__namespace.tidy(function () { var numSamples = _this.checkNumSamples(ins); if (verbose) { throw new NotImplementedError('Verbose predictLoop() is not implemented yet.'); } // Sample-based predictions. // Porting Note: Tensor currently does not support sliced assignments as // in numpy, e.g., x[1:3] = y. Therefore we use concatenation while // iterating over the batches. var batches = makeBatches(numSamples, batchSize); var outsBatches = _this.outputs.map(function (output) { return []; }); var _loop_3 = function (batchIndex) { var batchOuts = tfc__namespace.tidy(function () { var batchStart = batches[batchIndex][0]; var batchEnd = batches[batchIndex][1]; // TODO(cais): Take care of the case where the last element is a flag for // training/test. var insBatch = sliceArrays(ins, batchStart, batchEnd); // Construct the feeds for execute(). var feeds = []; if (Array.isArray(insBatch)) { for (var i = 0; i < insBatch.length; ++i) { feeds.push({ key: _this.inputs[i], value: insBatch[i] }); } } else { feeds.push({ key: _this.inputs[0], value: insBatch }); } var feedDict = new FeedDict(feeds); return execute(_this.outputs, feedDict); }); batchOuts.forEach(function (batchOut, i) { return outsBatches[i].push(batchOut); }); }; // TODO(cais): Can the scope() be pushed down inside the for loop? for (var batchIndex = 0; batchIndex < batches.length; ++batchIndex) { _loop_3(batchIndex); } return singletonOrArray(outsBatches.map(function (batches) { return tfc__namespace.concat(batches, 0); })); }); }; /** * Generates output predictions for the input samples. * * Computation is done in batches. * * Note: the "step" mode of predict() is currently not supported. * This is because the TensorFlow.js core backend is imperative only. * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.predict(tf.ones([8, 10]), {batchSize: 4}).print(); * ``` * * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if * the model has multiple inputs.
* @param args A `ModelPredictArgs` object containing optional fields. * * @return Prediction results as a `tf.Tensor`(s). * * @exception ValueError In case of mismatch between the provided input data * and the model's expectations, or in case a stateful model receives a * number of samples that is not a multiple of the batch size. * * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.predict = function (x, args) { if (args === void 0) { args = {}; } var xsRank2OrHigher = ensureTensorsRank2OrHigher(x); checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false); try { // TODO(cais): Take care of stateful models. // if (this.stateful) ... // TODO(cais): Take care of the learning_phase boolean flag. // if (this.useLearningPhase) ... var batchSize = args.batchSize == null ? 32 : args.batchSize; checkBatchSize(batchSize); return this.predictLoop(xsRank2OrHigher, batchSize); } finally { disposeNewTensors(xsRank2OrHigher, x); } }; /** * Returns predictions for a single batch of samples. * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.predictOnBatch(tf.ones([8, 10])).print(); * ``` * @param x: Input samples, as a Tensor (for models with exactly one * input) or an array of Tensors (for models with more than one input). * @return Tensor(s) of predictions * * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.predictOnBatch = function (x) { checkInputData(x, this.inputNames, this.feedInputShapes, true); // TODO(cais): Take care of the learning_phase boolean flag. // if (this.useLearningPhase) ... var batchSize = (Array.isArray(x) ? x[0] : x).shape[0]; return this.predictLoop(x, batchSize); }; LayersModel.prototype.standardizeUserDataXY = function (x, y, checkBatchAxis, batchSize) { // TODO(cais): Add sampleWeight, classWeight if (this.optimizer_ == null) { throw new RuntimeError('You must compile a model before training/testing. Use ' + 'LayersModel.compile(modelCompileArgs).'); } var outputShapes = []; for (var i = 0; i < this.feedOutputShapes.length; ++i) { var outputShape = this.feedOutputShapes[i]; var lossFn = this.feedLossFns[i]; if (lossFn === sparseCategoricalCrossentropy$1) { outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1])); } else { // Porting Note: Because of strong typing `lossFn` must be a function. outputShapes.push(outputShape); } } x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input'); y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target'); // TODO(cais): Standardize sampleWeights & classWeights. checkArrayLengths(x, y); // TODO(cais): Check sampleWeights as well. checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes); if (this.stateful && batchSize != null && batchSize > 0) { if (x[0].shape[0] % batchSize !== 0) { throw new ValueError("In a stateful network, you should only pass inputs with a " + "number of samples that is divisible by the batch size " + "".concat(batchSize, ". 
Found: ").concat(x[0].shape[0], " sample(s).")); } } return [x, y]; }; LayersModel.prototype.standardizeUserData = function (x, y, sampleWeight, classWeight, checkBatchAxis, batchSize) { if (checkBatchAxis === void 0) { checkBatchAxis = true; } return __awaiter(this, void 0, void 0, function () { var _a, standardXs, standardYs, standardSampleWeights, classWeights, i, _b, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: _a = __read(this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize), 2), standardXs = _a[0], standardYs = _a[1]; // TODO(cais): Handle sampleWeights. if (sampleWeight != null) { throw new Error('sample weight is not supported yet.'); } standardSampleWeights = null; if (!(classWeight != null)) return [3 /*break*/, 4]; classWeights = standardizeClassWeights(classWeight, this.outputNames); standardSampleWeights = []; i = 0; _d.label = 1; case 1: if (!(i < classWeights.length)) return [3 /*break*/, 4]; _c = (_b = standardSampleWeights).push; return [4 /*yield*/, standardizeWeights(standardYs[i], null, classWeights[i])]; case 2: _c.apply(_b, [_d.sent()]); _d.label = 3; case 3: ++i; return [3 /*break*/, 1]; case 4: // TODO(cais): Deal with the case of model.stateful == true. return [2 /*return*/, [standardXs, standardYs, standardSampleWeights]]; } }); }); }; /** * Loop over some test data in batches. * @param f A Function returning a list of tensors. * @param ins Array of tensors to be fed to `f`. * @param batchSize Integer batch size or `null` / `undefined`. * @param verbose verbosity mode. * @param steps Total number of steps (batches of samples) before * declaring test finished. Ignored with the default value of `null` / * `undefined`. * @returns Array of Scalars. */ LayersModel.prototype.testLoop = function (f, ins, batchSize, verbose, steps) { var _this = this; if (verbose === void 0) { verbose = 0; } return tfc__namespace.tidy(function () { var numSamples = _this.checkNumSamples(ins, batchSize, steps, 'steps'); var outs = []; if (verbose > 0) { throw new NotImplementedError('Verbose mode is not implemented yet.'); } // TODO(cais): Use `indicesForConversionToDense' to prevent slow down. if (steps != null) { throw new NotImplementedError('steps mode in testLoop() is not implemented yet'); } else { var batches = makeBatches(numSamples, batchSize); var indexArray = tfc.tensor1d(range(0, numSamples)); for (var batchIndex = 0; batchIndex < batches.length; ++batchIndex) { var batchStart = batches[batchIndex][0]; var batchEnd = batches[batchIndex][1]; var batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart); // TODO(cais): In ins, train flag can be a number, instead of an // Tensor? Do we need to handle this in tfjs-layers? var insBatch = sliceArraysByIndices(ins, batchIds); var batchOuts = f(insBatch); if (batchIndex === 0) { for (var i = 0; i < batchOuts.length; ++i) { outs.push(tfc.scalar(0)); } } for (var i = 0; i < batchOuts.length; ++i) { var batchOut = batchOuts[i]; outs[i] = tfc__namespace.add(outs[i], tfc__namespace.mul(batchEnd - batchStart, batchOut)); } } for (var i = 0; i < outs.length; ++i) { outs[i] = tfc__namespace.div(outs[i], numSamples); } } return outs; }); }; LayersModel.prototype.getDedupedMetricsNames = function () { var outLabels = this.metricsNames; // Rename duplicated metrics names (can happen with an output layer // shared among multiple dataflows). 
var dedupedOutLabels = []; for (var i = 0; i < outLabels.length; ++i) { var label = outLabels[i]; var newLabel = label; if (count(outLabels, label) > 1) { var dupIndex = count(outLabels.slice(0, i), label); newLabel += "_".concat(dupIndex); } dedupedOutLabels.push(newLabel); } return dedupedOutLabels; }; /** * Creates a function that performs the following actions: * * 1. computes the losses * 2. sums them to get the total loss * 3. calls the optimizer to compute the gradients of the LayersModel's * trainable weights w.r.t. the total loss and updates the variables * 4. calculates the metrics * 5. returns the values of the losses and metrics. */ LayersModel.prototype.makeTrainFunction = function () { var _this = this; return function (data) { var lossValues = []; var inputs = data.slice(0, _this.inputs.length); var targets = data.slice(_this.inputs.length, _this.inputs.length + _this.outputs.length); var sampleWeights = data.slice(_this.inputs.length + _this.outputs.length, _this.inputs.length + _this.outputs.length * 2); var metricsValues = []; // Create a function that computes the total loss based on the // inputs. This function is used for obtaining gradients through // backprop. var totalLossFunction = function () { var feeds = []; for (var i = 0; i < _this.inputs.length; ++i) { feeds.push({ key: _this.inputs[i], value: inputs[i] }); } var feedDict = new FeedDict(feeds); var outputs = execute(_this.outputs, feedDict, { 'training': true }); // TODO(cais): Take care of the case of multiple outputs from a // single layer? var totalLoss; for (var i = 0; i < _this.lossFunctions.length; ++i) { var lossFunction = _this.lossFunctions[i]; var loss = lossFunction(targets[i], outputs[i]); if (sampleWeights[i] != null) { loss = computeWeightedLoss(loss, sampleWeights[i]); } // TODO(cais): push Scalar instead. var meanLoss = tfc__namespace.mean(loss); // TODO(cais): Use a scope() instead, to avoid ownership. lossValues.push(meanLoss); if (i === 0) { totalLoss = loss; } else { totalLoss = tfc__namespace.add(totalLoss, loss); } } // Compute the metrics. // TODO(cais): These should probably be calculated outside // totalLossFunction to benefit speed? for (var i = 0; i < _this.metricsTensors.length; ++i) { var weightedMetric = void 0; if (_this.outputs.length > 1 && i < _this.outputs.length) { weightedMetric = lossValues[i]; } else { var metric = _this.metricsTensors[i][0]; var outputIndex = _this.metricsTensors[i][1]; weightedMetric = tfc__namespace.mean(metric(targets[outputIndex], outputs[outputIndex])); } tfc__namespace.keep(weightedMetric); // TODO(cais): Use a scope() instead, to avoid ownership. metricsValues.push(weightedMetric); } totalLoss = tfc__namespace.mean(totalLoss); // Add regularizer penalties. _this.calculateLosses().forEach(function (regularizerLoss) { totalLoss = tfc__namespace.add(totalLoss, regularizerLoss); }); return totalLoss; }; var variables = _this.collectedTrainableWeights.map(function (param) { return param.read(); }); var returnCost = true; var totalLossValue = _this.optimizer_.minimize(totalLossFunction, returnCost, variables); return [totalLossValue].concat(metricsValues); }; }; /** * Create a function which, when invoked with an array of `tf.Tensor`s as a * batch of inputs, returns the prespecified loss and metrics of the model * under the batch of input data.
*/ LayersModel.prototype.makeTestFunction = function () { var _this = this; this.testFunction = function (data) { return tfc__namespace.tidy(function () { var valOutputs = []; var totalLoss; var inputs = data.slice(0, _this.inputs.length); var targets = data.slice(_this.inputs.length, _this.inputs.length + _this.outputs.length); var feeds = []; for (var i = 0; i < _this.inputs.length; ++i) { feeds.push({ key: _this.inputs[i], value: inputs[i] }); } var feedDict = new FeedDict(feeds); var outputs = execute(_this.outputs, feedDict); // Compute total loss. for (var i = 0; i < _this.lossFunctions.length; ++i) { var lossFunction = _this.lossFunctions[i]; // TODO(cais): Add sample weighting and replace the simple // averaging. var loss = tfc__namespace.mean(lossFunction(targets[i], outputs[i])); if (i === 0) { totalLoss = loss; } else { totalLoss = tfc__namespace.add(totalLoss, loss); } valOutputs.push(totalLoss); } // Compute the metrics. for (var i = 0; i < _this.metricsTensors.length; ++i) { var metric = _this.metricsTensors[i][0]; var outputIndex = _this.metricsTensors[i][1]; // TODO(cais): Replace K.mean() with a proper weighting function. var meanMetric = tfc__namespace.mean(metric(targets[outputIndex], outputs[outputIndex])); valOutputs.push(meanMetric); } return valOutputs; }); }; }; /** * Trains the model for a fixed number of epochs (iterations on a * dataset). * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); * for (let i = 1; i < 5 ; ++i) { * const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), { * batchSize: 4, * epochs: 3 * }); * console.log("Loss after Epoch " + i + " : " + h.history.loss[0]); * } * ``` * * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the * model has multiple inputs. If all inputs in the model are named, you * can also pass a dictionary mapping input names to `tf.Tensor`s. * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if * the model has multiple outputs. If all outputs in the model are named, * you can also pass a dictionary mapping output names to `tf.Tensor`s. * @param args A `ModelFitArgs`, containing optional fields. * * @return A `History` instance. Its `history` attribute contains all * information collected during training. * * @exception ValueError In case of mismatch between the provided input * data and what the model expects. * * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.fit = function (x, y, args) { if (args === void 0) { args = {}; } return __awaiter(this, void 0, void 0, function () { var inputs, targets, originalInputs, originalTargets, inputValX, inputValY, valX, valY, sampleWeights, batchSize, checkBatchAxis, standardizedOuts, doValidation, valIns, checkBatchAxis_1, valStandardized, splitAt, originalBatchSize, ins, trainFunction, outLabels, valFunction, callbackMetrics, callbacks, out; return __generator(this, function (_a) { switch (_a.label) { case 0: if (this.isTraining) { throw new Error('Cannot start training because another fit() call is ongoing.'); } this.isTraining = true; _a.label = 1; case 1: _a.trys.push([1, , 7, 8]); batchSize = args.batchSize == null ? 
32 : args.batchSize; checkBatchSize(batchSize); checkBatchAxis = false; return [4 /*yield*/, this.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize)]; case 2: standardizedOuts = _a.sent(); inputs = standardizedOuts[0]; targets = standardizedOuts[1]; sampleWeights = standardizedOuts[2]; doValidation = false; valIns = void 0; if (!(args.validationData != null && args.validationData.length > 0)) return [3 /*break*/, 4]; doValidation = true; if (args.validationData.length === 2) { // config.validationData consists of valX and valY. inputValX = args.validationData[0]; inputValY = args.validationData[1]; } else if (args.validationData.length === 3) { throw new NotImplementedError('validationData including sample weights is not supported yet.'); } else { throw new ValueError("When passing validation data, it must contain 2 (valX, valY) " + "or 3 (valX, valY, valSampleWeight) items; " + "".concat(args.validationData, " is invalid.")); } checkBatchAxis_1 = true; return [4 /*yield*/, this.standardizeUserData(inputValX, inputValY, null, /** Unused sample weights. */ null, /** Unused class weights. */ checkBatchAxis_1, batchSize)]; case 3: valStandardized = _a.sent(); valX = valStandardized[0]; valY = valStandardized[1]; valIns = valX.concat(valY); return [3 /*break*/, 5]; case 4: if (args.validationSplit != null && args.validationSplit > 0 && args.validationSplit < 1) { doValidation = true; splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit)); originalBatchSize = inputs[0].shape[0]; valX = sliceArrays(inputs, splitAt, originalBatchSize); originalInputs = inputs; inputs = sliceArrays(inputs, 0, splitAt); valY = sliceArrays(targets, splitAt, originalBatchSize); originalTargets = targets; targets = sliceArrays(targets, 0, splitAt); // TODO(cais): Once sampleWeights becomes available, slice it to get // valSampleWeights. valIns = valX.concat(valY); // TODO(cais): Add useLearningPhase data properly. } else if (args.validationSteps != null) { doValidation = true; // TODO(cais): Add useLearningPhase. } _a.label = 5; case 5: ins = inputs.concat(targets).concat(sampleWeights); this.checkTrainableWeightsConsistency(); trainFunction = this.makeTrainFunction(); outLabels = this.getDedupedMetricsNames(); valFunction = void 0; callbackMetrics = void 0; if (doValidation) { this.makeTestFunction(); valFunction = this.testFunction; callbackMetrics = outLabels.slice().concat(outLabels.map(function (n) { return 'val_' + n; })); } else { valFunction = null; valIns = []; callbackMetrics = outLabels.slice(); } callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery); return [4 /*yield*/, this.fitLoop(trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null)]; case 6: out = _a.sent(); return [2 /*return*/, out]; case 7: this.isTraining = false; // Memory clean up. disposeNewTensors(inputs, x); disposeNewTensors(targets, y); disposeNewTensors(originalInputs, x); disposeNewTensors(originalTargets, y); disposeNewTensors(valX, inputValX); disposeNewTensors(valY, inputValY); if (sampleWeights != null) { tfc__namespace.dispose(sampleWeights); } return [7 /*endfinally*/]; case 8: return [2 /*return*/]; } }); }); }; /** * Abstract fit function for `f(ins)`. * @param f A Function returning a list of tensors. For training, this * function is expected to perform the updates to the variables. * @param ins List of tensors to be fed to `f`. 
* @param outLabels List of strings, display names of the outputs of `f`. * @param batchSize Integer batch size or `== null` if unknown. Default: 32. * @param epochs Number of times to iterate over the data. Default: 1. * @param verbose Verbosity mode: 0, 1, or 2. Default: 1. * @param callbacks List of callbacks to be called during training. * @param valF Function to call for validation. * @param valIns List of tensors to be fed to `valF`. * @param shuffle Whether to shuffle the data at the beginning of every * epoch. Default: true. * @param callbackMetrics List of strings, the display names of the metrics * passed to the callbacks. They should be the concatenation of the * display names of the outputs of `f` and the list of display names * of the outputs of `valF`. * @param initialEpoch Epoch at which to start training (useful for * resuming a previous training run). Default: 0. * @param stepsPerEpoch Total number of steps (batches of samples) before * declaring one epoch finished and starting the next epoch. Ignored with * the default value of `undefined` or `null`. * @param validationSteps Number of steps to run validation for (only if * doing validation from data tensors). Not applicable for tfjs-layers. * @returns A `History` object. */ LayersModel.prototype.fitLoop = function (f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) { return __awaiter(this, void 0, void 0, function () { var doValidation, numTrainSamples, indexArray, _a, callbackList, history, _loop_4, this_1, epoch, state_1; var _this = this; return __generator(this, function (_b) { switch (_b.label) { case 0: if (batchSize == null) { batchSize = 32; } if (epochs == null) { epochs = 1; } if (shuffle == null) { shuffle = true; } if (initialEpoch == null) { initialEpoch = 0; } doValidation = false; if (valF != null && valIns != null) { doValidation = true; // TODO(cais): verbose message.
} if (validationSteps != null) { doValidation = true; if (stepsPerEpoch == null) { throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' + 'i.e., `stepsPerEpoch` must be set.'); } } numTrainSamples = this.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch'); if (numTrainSamples != null) { indexArray = range(0, numTrainSamples); } if (verbose == null) { verbose = 1; } _a = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics), callbackList = _a.callbackList, history = _a.history; callbackList.setModel(this); this.history = history; return [4 /*yield*/, callbackList.onTrainBegin()]; case 1: _b.sent(); this.stopTraining_ = false; _loop_4 = function (epoch) { var epochLogs, epochIndexArray1D_1, batches_1, _loop_5, batchIndex, state_2; return __generator(this, function (_c) { switch (_c.label) { case 0: return [4 /*yield*/, callbackList.onEpochBegin(epoch)]; case 1: _c.sent(); epochLogs = {}; if (!(stepsPerEpoch != null)) return [3 /*break*/, 2]; throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.'); case 2: if (shuffle === 'batch') { throw new NotImplementedError('batch shuffling is not implemented' + ' yet'); } else if (shuffle) { tfc.util.shuffle(indexArray); } epochIndexArray1D_1 = tfc.tensor1d(indexArray); batches_1 = makeBatches(numTrainSamples, batchSize); _loop_5 = function (batchIndex) { var batchLogs; return __generator(this, function (_d) { switch (_d.label) { case 0: batchLogs = {}; return [4 /*yield*/, callbackList.onBatchBegin(batchIndex, batchLogs)]; case 1: _d.sent(); tfc__namespace.tidy(function () { var batchStart = batches_1[batchIndex][0]; var batchEnd = batches_1[batchIndex][1]; var batchIds = sliceAlongFirstAxis(epochIndexArray1D_1, batchStart, batchEnd - batchStart); batchLogs['batch'] = batchIndex; batchLogs['size'] = batchEnd - batchStart; // TODO(cais): In ins, train flag can be a number, instead of a // Tensor? Do we need to handle this in tfjs-layers? var insBatch = sliceArraysByIndices(ins, batchIds); var outs = f(insBatch); for (var i = 0; i < outLabels.length; ++i) { var label = outLabels[i]; var out = outs[i]; batchLogs[label] = out; tfc__namespace.keep(out); // TODO(cais): Use scope() to avoid ownership. } if (batchIndex === batches_1.length - 1) { // Last batch. if (doValidation) { var valOuts = _this.testLoop(valF, valIns, batchSize); // Porting Notes: In tfjs-layers, valOuts is always an Array. for (var i = 0; i < outLabels.length; ++i) { var label = outLabels[i]; var out = valOuts[i]; tfc__namespace.keep(out); // TODO(cais): Use scope() to avoid ownership. epochLogs['val_' + label] = out; } } } }); return [4 /*yield*/, callbackList.onBatchEnd(batchIndex, batchLogs)]; case 2: _d.sent(); disposeTensorsInLogs(batchLogs); if (this_1.stopTraining_) { return [2 /*return*/, "break"]; } return [2 /*return*/]; } }); }; batchIndex = 0; _c.label = 3; case 3: if (!(batchIndex < batches_1.length)) return [3 /*break*/, 6]; return [5 /*yield**/, _loop_5(batchIndex)]; case 4: state_2 = _c.sent(); if (state_2 === "break") return [3 /*break*/, 6]; _c.label = 5; case 5: ++batchIndex; return [3 /*break*/, 3]; case 6: epochIndexArray1D_1.dispose(); _c.label = 7; case 7: // TODO(cais): Run validation at the end of the epoch. return [4 /*yield*/, callbackList.onEpochEnd(epoch, epochLogs)]; case 8: // TODO(cais): Run validation at the end of the epoch.
_c.sent(); if (this_1.stopTraining_) { return [2 /*return*/, "break"]; } return [2 /*return*/]; } }); }; this_1 = this; epoch = initialEpoch; _b.label = 2; case 2: if (!(epoch < epochs)) return [3 /*break*/, 5]; return [5 /*yield**/, _loop_4(epoch)]; case 3: state_1 = _b.sent(); if (state_1 === "break") return [3 /*break*/, 5]; _b.label = 4; case 4: ++epoch; return [3 /*break*/, 2]; case 5: return [4 /*yield*/, callbackList.onTrainEnd()]; case 6: _b.sent(); return [4 /*yield*/, this.history.syncData()]; case 7: _b.sent(); return [2 /*return*/, this.history]; } }); }); }; // TODO(cais): Add code snippet below when it's possible to instantiate // actual dataset objects. /** * Trains the model using a dataset object. * * @param dataset A dataset object. Its `iterator()` method is expected * to generate a dataset iterator object, the `next()` method of which * is expected to produce data batches for training. The return value * of the `next()` call ought to contain a boolean `done` field and a * `value` field. The `value` field is expected to be an array of two * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former * case is for models with exactly one input and one output (e.g. * a sequential model). The latter case is for models with multiple * inputs and/or multiple outputs. * Of the two items in the array, the first is the input feature(s) and * the second is the output target(s). * @param args A `ModelFitDatasetArgs`, containing optional fields. * * @return A `History` instance. Its `history` attribute contains all * information collected during training. * * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.fitDataset = function (dataset, args) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/, fitDataset(this, dataset, args)]; }); }); }; /** * Runs a single gradient update on a single batch of data. * * This method differs from `fit()` and `fitDataset()` in the following * regards: * - It operates on exactly one batch of data. * - It returns only the loss and metric values, instead of * returning the batch-by-batch loss and metric values. * - It doesn't support fine-grained options such as verbosity and * callbacks. * * @param x Input data. It could be one of the following: * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has * multiple inputs). * - An Object mapping input names to corresponding `tf.Tensor` (if the * model has named inputs). * @param y Target data. It could be either a `tf.Tensor` or multiple * `tf.Tensor`s. It should be consistent with `x`. * @returns Training loss or losses (in case the model has * multiple outputs), along with metrics (if any), as numbers. 
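 *
 * A minimal sketch (the shapes are illustrative):
 *
 * ```js
 * const model = tf.sequential({
 *   layers: [tf.layers.dense({units: 1, inputShape: [3]})]
 * });
 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 * const loss = await model.trainOnBatch(tf.ones([4, 3]), tf.ones([4, 1]));
 * console.log(loss);  // A single number: one output, no metrics.
 * ```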
* * @doc {heading: 'Models', subheading: 'Classes'} */ LayersModel.prototype.trainOnBatch = function (x, y) { return __awaiter(this, void 0, void 0, function () { var standardizeOut, inputs, targets, trainFunction, losses, lossValues, losses_1, losses_1_1, loss, v, e_7_1; var e_7, _a; return __generator(this, function (_b) { switch (_b.label) { case 0: return [4 /*yield*/, this.standardizeUserData(x, y)]; case 1: standardizeOut = _b.sent(); inputs = standardizeOut[0]; targets = standardizeOut[1]; trainFunction = this.makeTrainFunction(); losses = trainFunction(inputs.concat(targets)); lossValues = []; _b.label = 2; case 2: _b.trys.push([2, 7, 8, 9]); losses_1 = __values(losses), losses_1_1 = losses_1.next(); _b.label = 3; case 3: if (!!losses_1_1.done) return [3 /*break*/, 6]; loss = losses_1_1.value; return [4 /*yield*/, loss.data()]; case 4: v = _b.sent(); lossValues.push(v[0]); _b.label = 5; case 5: losses_1_1 = losses_1.next(); return [3 /*break*/, 3]; case 6: return [3 /*break*/, 9]; case 7: e_7_1 = _b.sent(); e_7 = { error: e_7_1 }; return [3 /*break*/, 9]; case 8: try { if (losses_1_1 && !losses_1_1.done && (_a = losses_1.return)) _a.call(losses_1); } finally { if (e_7) throw e_7.error; } return [7 /*endfinally*/]; case 9: tfc__namespace.dispose(losses); disposeNewTensors(standardizeOut[0], x); disposeNewTensors(standardizeOut[1], y); return [2 /*return*/, singletonOrArray(lossValues)]; } }); }); }; /** * Extract weight values of the model. * * @param config: An instance of `io.SaveConfig`, which specifies * model-saving options such as whether only trainable weights are to be * saved. * @returns A `NamedTensorMap` mapping original weight names (i.e., * non-uniqueified weight names) to their values. */ LayersModel.prototype.getNamedWeights = function (config) { var namedWeights = []; var trainableOnly = config != null && config.trainableOnly; var weights = trainableOnly ? this.trainableWeights : this.weights; var weightValues = this.getWeights(trainableOnly); for (var i = 0; i < weights.length; ++i) { if (trainableOnly && !weights[i].trainable) { // Optionally skip non-trainable weights. continue; } namedWeights.push({ name: weights[i].originalName, tensor: weightValues[i] }); } return namedWeights; }; Object.defineProperty(LayersModel.prototype, "stopTraining", { get: function () { return this.stopTraining_; }, /** * Setter used for force stopping of LayersModel.fit() (i.e., training). * * Example: * * ```js * const input = tf.input({shape: [10]}); * const output = tf.layers.dense({units: 1}).apply(input); * const model = tf.model({inputs: [input], outputs: [output]}); * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); * const xs = tf.ones([8, 10]); * const ys = tf.zeros([8, 1]); * * const history = await model.fit(xs, ys, { * epochs: 10, * callbacks: { * onEpochEnd: async (epoch, logs) => { * if (epoch === 2) { * model.stopTraining = true; * } * } * } * }); * * // There should be only 3 values in the loss array, instead of 10 * values, * // due to the stopping after 3 epochs. 
* console.log(history.history.loss); * ``` */ set: function (stop) { this.stopTraining_ = stop; }, enumerable: false, configurable: true }); Object.defineProperty(LayersModel.prototype, "optimizer", { get: function () { return this.optimizer_; }, set: function (optimizer) { if (this.optimizer_ !== optimizer) { this.optimizer_ = optimizer; this.isOptimizerOwned = false; } }, enumerable: false, configurable: true }); LayersModel.prototype.dispose = function () { var result = _super.prototype.dispose.call(this); if (result.refCountAfterDispose === 0 && this.optimizer != null && this.isOptimizerOwned) { var numTensorsBeforeOptimizerDisposal = tfc__namespace.memory().numTensors; this.optimizer_.dispose(); result.numDisposedVariables += numTensorsBeforeOptimizerDisposal - tfc__namespace.memory().numTensors; } return result; }; LayersModel.prototype.getLossIdentifiers = function () { var e_8, _a, e_9, _b; var lossNames; if (typeof this.loss === 'string') { lossNames = toSnakeCase(this.loss); } else if (Array.isArray(this.loss)) { try { for (var _c = __values(this.loss), _d = _c.next(); !_d.done; _d = _c.next()) { var loss = _d.value; if (typeof loss !== 'string') { throw new Error('Serialization of non-string loss is not supported.'); } } } catch (e_8_1) { e_8 = { error: e_8_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_8) throw e_8.error; } } lossNames = this.loss.map(function (name) { return toSnakeCase(name); }); } else { var outputNames = Object.keys(this.loss); lossNames = {}; var losses_2 = this.loss; try { for (var outputNames_2 = __values(outputNames), outputNames_2_1 = outputNames_2.next(); !outputNames_2_1.done; outputNames_2_1 = outputNames_2.next()) { var outputName = outputNames_2_1.value; if (typeof losses_2[outputName] === 'string') { lossNames[outputName] = toSnakeCase(losses_2[outputName]); } else { throw new Error('Serialization of non-string loss is not supported.'); } } } catch (e_9_1) { e_9 = { error: e_9_1 }; } finally { try { if (outputNames_2_1 && !outputNames_2_1.done && (_b = outputNames_2.return)) _b.call(outputNames_2); } finally { if (e_9) throw e_9.error; } } } return lossNames; }; LayersModel.prototype.getMetricIdentifiers = function () { if (typeof this.metrics === 'string' || typeof this.metrics === 'function') { return [toSnakeCase(getLossOrMetricName(this.metrics))]; } else if (Array.isArray(this.metrics)) { return this.metrics.map(function (metric) { return toSnakeCase(getLossOrMetricName(metric)); }); } else { var metricsIdentifiers = {}; for (var key in this.metrics) { metricsIdentifiers[key] = toSnakeCase(getLossOrMetricName(this.metrics[key])); } return metricsIdentifiers; } }; LayersModel.prototype.getTrainingConfig = function () { return { loss: this.getLossIdentifiers(), metrics: this.getMetricIdentifiers(), optimizer_config: { class_name: this.optimizer.getClassName(), config: this.optimizer.getConfig() } }; // TODO(cais): Add weight_metrics when they are supported. // TODO(cais): Add sample_weight_mode when it's supported. // TODO(cais): Add loss_weights when it's supported.
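// Illustrative return shape for a model compiled with
// {optimizer: 'sgd', loss: 'meanSquaredError', metrics: ['accuracy']};
// the exact `config` fields depend on the optimizer's getConfig():
//   {loss: 'mean_squared_error', metrics: ['accuracy'],
//    optimizer_config: {class_name: 'SGD', config: {learningRate: 0.01}}}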
}; LayersModel.prototype.loadTrainingConfig = function (trainingConfig) { if (trainingConfig.weighted_metrics != null) { throw new Error('Loading weight_metrics is not supported yet.'); } if (trainingConfig.loss_weights != null) { throw new Error('Loading loss_weights is not supported yet.'); } if (trainingConfig.sample_weight_mode != null) { throw new Error('Loading sample_weight_mode is not supported yet.'); } var tsConfig = convertPythonicToTs(trainingConfig.optimizer_config); var optimizer = deserialize(tsConfig); var loss; if (typeof trainingConfig.loss === 'string') { loss = toCamelCase(trainingConfig.loss); } else if (Array.isArray(trainingConfig.loss)) { loss = trainingConfig.loss.map(function (lossEntry) { return toCamelCase(lossEntry); }); } else if (trainingConfig.loss != null) { loss = {}; for (var key in trainingConfig.loss) { loss[key] = toCamelCase(trainingConfig.loss[key]); } } var metrics; if (Array.isArray(trainingConfig.metrics)) { metrics = trainingConfig.metrics.map(function (metric) { return toCamelCase(metric); }); } else if (trainingConfig.metrics != null) { metrics = {}; for (var key in trainingConfig.metrics) { metrics[key] = toCamelCase(trainingConfig.metrics[key]); } } this.compile({ loss: loss, metrics: metrics, optimizer: optimizer }); }; /** * Save the configuration and/or weights of the LayersModel. * * An `IOHandler` is an object that has a `save` method of the proper * signature defined. The `save` method manages the storing or * transmission of serialized data ("artifacts") that represent the * model's topology and weights onto or via a specific medium, such as * file downloads, local storage, IndexedDB in the web browser and HTTP * requests to a server. TensorFlow.js provides `IOHandler` * implementations for a number of frequently used saving mediums, such as * `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io` * for more details. * * This method also allows you to refer to certain types of `IOHandler`s * as URL-like string shortcuts, such as 'localstorage://' and * 'indexeddb://'. * * Example 1: Save `model`'s topology and weights to browser [local * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage); * then load it back. * * ```js * const model = tf.sequential( * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]}); * console.log('Prediction from original model:'); * model.predict(tf.ones([1, 3])).print(); * * const saveResults = await model.save('localstorage://my-model-1'); * * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1'); * console.log('Prediction from loaded model:'); * loadedModel.predict(tf.ones([1, 3])).print(); * ``` * * Example 2. Saving `model`'s topology and weights to browser * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API); * then load it back. * * ```js * const model = tf.sequential( * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]}); * console.log('Prediction from original model:'); * model.predict(tf.ones([1, 3])).print(); * * const saveResults = await model.save('indexeddb://my-model-1'); * * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1'); * console.log('Prediction from loaded model:'); * loadedModel.predict(tf.ones([1, 3])).print(); * ``` * * Example 3. Saving `model`'s topology and weights as two files * (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from * browser. 
* * ```js * const model = tf.sequential( * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]}); * const saveResults = await model.save('downloads://my-model-1'); * ``` * * Example 4. Send `model`'s topology and weights to an HTTP server. * See the documentation of `tf.io.http` for more details * including specifying request parameters and implementation of the * server. * * ```js * const model = tf.sequential( * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]}); * const saveResults = await model.save('http://my-server/model/upload'); * ``` * * @param handlerOrURL An instance of `IOHandler` or a URL-like, * scheme-based string shortcut for `IOHandler`. * @param config Options for saving the model. * @returns A `Promise` of `SaveResult`, which summarizes the result of * the saving, such as byte sizes of the saved artifacts for the model's * topology and weight values. * * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} */ LayersModel.prototype.save = function (handlerOrURL, config) { return __awaiter(this, void 0, void 0, function () { var handlers, weightDataAndSpecs, returnString, unusedArg, modelConfig, modelArtifacts, includeOptimizer, weightType, _a, optimizerWeightData, optimizerWeightSpecs, _b, _c, checkSize; var _d; return __generator(this, function (_e) { switch (_e.label) { case 0: if (typeof handlerOrURL === 'string') { handlers = tfc.io.getSaveHandlers(handlerOrURL); if (handlers.length === 0) { throw new ValueError("Cannot find any save handlers for URL '".concat(handlerOrURL, "'")); } else if (handlers.length > 1) { throw new ValueError("Found more than one (".concat(handlers.length, ") save handlers for ") + "URL '".concat(handlerOrURL, "'")); } handlerOrURL = handlers[0]; } if (handlerOrURL.save == null) { throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' + 'provided does not have the `save` attribute defined.'); } return [4 /*yield*/, tfc.io.encodeWeights(this.getNamedWeights(config))]; case 1: weightDataAndSpecs = _e.sent(); returnString = false; unusedArg = null; modelConfig = this.toJSON(unusedArg, returnString); modelArtifacts = { modelTopology: modelConfig, format: LAYERS_MODEL_FORMAT_NAME, generatedBy: "TensorFlow.js tfjs-layers v".concat(version), convertedBy: null, }; includeOptimizer = config == null ? false : config.includeOptimizer; if (!(includeOptimizer && this.optimizer != null)) return [3 /*break*/, 4]; modelArtifacts.trainingConfig = this.getTrainingConfig(); weightType = 'optimizer'; _c = (_b = tfc.io).encodeWeights; return [4 /*yield*/, this.optimizer.getWeights()]; case 2: return [4 /*yield*/, _c.apply(_b, [_e.sent(), weightType])]; case 3: _a = _e.sent(), optimizerWeightData = _a.data, optimizerWeightSpecs = _a.specs; (_d = weightDataAndSpecs.specs).push.apply(_d, __spreadArray([], __read(optimizerWeightSpecs), false)); weightDataAndSpecs.data = tfc.io.concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]); _e.label = 4; case 4: if (this.userDefinedMetadata != null) { checkSize = true; checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize); modelArtifacts.userDefinedMetadata = this.userDefinedMetadata; } modelArtifacts.weightData = weightDataAndSpecs.data; modelArtifacts.weightSpecs = weightDataAndSpecs.specs; return [2 /*return*/, handlerOrURL.save(modelArtifacts)]; } }); }); }; /** * Set user-defined metadata. * * The set metadata will be serialized together with the topology * and weights of the model during `save()` calls. 
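 *
 * A minimal sketch (the metadata keys and the storage URL are arbitrary,
 * illustrative choices; the value must be JSON-serializable):
 *
 * ```js
 * model.setUserDefinedMetadata({language: 'en', revision: 3});
 * await model.save('localstorage://my-model-1');  // Metadata rides along.
 * console.log(model.getUserDefinedMetadata());
 * ```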
* * @param userDefinedMetadata The metadata to be set. */ LayersModel.prototype.setUserDefinedMetadata = function (userDefinedMetadata) { checkUserDefinedMetadata(userDefinedMetadata, this.name); this.userDefinedMetadata = userDefinedMetadata; }; /** * Get user-defined metadata. * * The metadata is supplied via one of the two routes: * 1. By calling `setUserDefinedMetadata()`. * 2. Loaded during model loading (if the model is constructed * via `tf.loadLayersModel()`). * * If no user-defined metadata is available from either of the * two routes, this function will return `undefined`. */ LayersModel.prototype.getUserDefinedMetadata = function () { return this.userDefinedMetadata; }; return LayersModel; }(Container)); // The class name is 'Model' rather than 'LayersModel' for backwards // compatibility since this class name shows up in the serialization format. /** @nocollapse */ LayersModel.className = 'Model'; tfc.serialization.registerClass(LayersModel); /** * A `tf.Functional` is an alias to `tf.LayersModel`. * * See also: * `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`. */ /** @doc {heading: 'Models', subheading: 'Classes'} */ var Functional = /** @class */ (function (_super) { __extends(Functional, _super); function Functional() { return _super !== null && _super.apply(this, arguments) || this; } return Functional; }(LayersModel)); Functional.className = 'Functional'; tfc.serialization.registerClass(Functional); /** * Parses a JSON model configuration file and returns a model instance. * * ```js * // This example shows how to serialize a model using `toJSON()` and * // deserialize it as another model using `tf.models.modelFromJSON()`. * // Note: this example serializes and deserializes only the topology * // of the model; the weights of the loaded model will be different * // from those of the original model, due to random weight * // initialization. * // To load the topology and weights of a model, use `tf.loadLayersModel()`. * const model1 = tf.sequential(); * model1.add(tf.layers.repeatVector({inputShape: [2], n: 4})); * // Serialize `model1` as a JSON object. * const model1JSON = model1.toJSON(null, false); * model1.summary(); * * const model2 = await tf.models.modelFromJSON(model1JSON); * model2.summary(); * ``` * * @param modelAndWeightsConfig JSON object or string encoding a model and * weights configuration. It can also be only the topology JSON of the * model, in which case the weights will not be loaded. * @param customObjects Optional dictionary mapping names * (strings) to custom classes or functions to be * considered during deserialization. * @returns A TensorFlow.js Layers `tf.LayersModel` instance (uncompiled). */ function modelFromJSON(modelAndWeightsConfig, customObjects) { return __awaiter(this, void 0, void 0, function () { var modelTopology, tsConfig, model, weightValues, uniqueWeightValues, _a, _b, weight; var e_1, _c; return __generator(this, function (_d) { switch (_d.label) { case 0: if (!('modelTopology' in modelAndWeightsConfig)) { modelAndWeightsConfig = { modelTopology: modelAndWeightsConfig }; } modelAndWeightsConfig = modelAndWeightsConfig; modelTopology = modelAndWeightsConfig.modelTopology; if (modelTopology['model_config'] != null) { // If the model-topology JSON contains a 'model_config' field, then it is // a full model JSON (e.g., from `keras.Model.save()`), which contains // not only the model's architecture in its 'model_config' field, but // additional information such as the model's optimizer. We use only the // 'model_config' field currently.
modelTopology = modelTopology['model_config']; } tsConfig = convertPythonicToTs(modelTopology); model = deserialize(tsConfig, customObjects); if (!(modelAndWeightsConfig.weightsManifest != null)) return [3 /*break*/, 2]; return [4 /*yield*/, tfc.io.loadWeights(modelAndWeightsConfig.weightsManifest, modelAndWeightsConfig.pathPrefix, model.weights.map(function (weight) { return weight.originalName; }))]; case 1: weightValues = _d.sent(); uniqueWeightValues = {}; try { for (_a = __values(model.weights), _b = _a.next(); !_b.done; _b = _a.next()) { weight = _b.value; uniqueWeightValues[weight.originalName] = weightValues[weight.originalName]; } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_b && !_b.done && (_c = _a.return)) _c.call(_a); } finally { if (e_1) throw e_1.error; } } model.loadWeights(uniqueWeightValues); // Dispose temporary weight values. tfc.dispose(weightValues); _d.label = 2; case 2: return [2 /*return*/, model]; } }); }); } /** * Load a model composed of Layer objects, including its topology and optionally * weights. See the Tutorial named "How to import a Keras Model" for usage * examples. * * This method is applicable to: * * 1. Models created with the `tf.layers.*`, `tf.sequential`, and * `tf.model` APIs of TensorFlow.js and later saved with the * `tf.LayersModel.save` method. * 2. Models converted from Keras or TensorFlow tf.keras using the * [tensorflowjs_converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter). * * This method is *not* applicable to TensorFlow `SavedModel`s or their converted * forms. For those models, use `tf.loadGraphModel`. * * Example 1. Load a model from an HTTP server. * * ```js * const model = await tf.loadLayersModel( * 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json'); * model.summary(); * ``` * * Example 2. Save `model`'s topology and weights to browser [local * storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage); * then load it back. * * ```js * const model = tf.sequential( * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]}); * console.log('Prediction from original model:'); * model.predict(tf.ones([1, 3])).print(); * * const saveResults = await model.save('localstorage://my-model-1'); * * const loadedModel = await tf.loadLayersModel('localstorage://my-model-1'); * console.log('Prediction from loaded model:'); * loadedModel.predict(tf.ones([1, 3])).print(); * ``` * * Example 3. Save `model`'s topology and weights to browser * [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API); * then load it back. * * ```js * const model = tf.sequential( * {layers: [tf.layers.dense({units: 1, inputShape: [3]})]}); * console.log('Prediction from original model:'); * model.predict(tf.ones([1, 3])).print(); * * const saveResults = await model.save('indexeddb://my-model-1'); * * const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1'); * console.log('Prediction from loaded model:'); * loadedModel.predict(tf.ones([1, 3])).print(); * ``` * * Example 4. Load a model from user-selected files from HTML * [file input * elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file). 
 * * ```js * // Note: this code snippet will not work without the HTML elements in the * // page * const jsonUpload = document.getElementById('json-upload'); * const weightsUpload = document.getElementById('weights-upload'); * * const model = await tf.loadLayersModel( * tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]])); * ``` * * @param pathOrIOHandler Can be one of the two formats * 1. A string path to the `ModelAndWeightsConfig` JSON describing * the model in the canonical TensorFlow.js format. For file:// * (tfjs-node-only), http:// and https:// schemes, the path can be * either absolute or relative. The content of the JSON file is assumed to * be a JSON object with the following fields and values: * - 'modelTopology': A JSON object that can be either of: * 1. a model architecture JSON consistent with the format of the return * value of `keras.Model.to_json()` * 2. a full model JSON in the format of `keras.models.save_model()`. * - 'weightsManifest': A TensorFlow.js weights manifest. * See the Python converter function `save_model()` for more details. * It is also assumed that model weights can be accessed from relative * paths described by the `paths` fields in weights manifest. * 2. A `tf.io.IOHandler` object that loads model artifacts with its `load` * method. * @param options Optional configuration arguments for the model loading, * including: * - `strict`: Require that the provided weights exactly match those required * by the layers. Default true. Passing false means that both extra * weights and missing weights will be silently ignored. * - `onProgress`: A progress callback of the form: * `(fraction: number) => void`. This callback can be used to monitor the * model-loading process. * @returns A `Promise` of `tf.LayersModel`, with the topology and weights * loaded. * * @doc {heading: 'Models', subheading: 'Loading'} */ function loadLayersModel(pathOrIOHandler, options) { return __awaiter(this, void 0, void 0, function () { var handlers; return __generator(this, function (_a) { if (options == null) { options = {}; } if (typeof pathOrIOHandler === 'string') { handlers = tfc.io.getLoadHandlers(pathOrIOHandler, options); if (handlers.length === 0) { // For backward compatibility: if no load handler can be found, // assume it is a relative http path. // TODO(cais): Reformat the args into a single `LoadOptions` once the core // is refactored. handlers.push(tfc.io.browserHTTPRequest(pathOrIOHandler, options)); } else if (handlers.length > 1) { throw new ValueError("Found more than one (".concat(handlers.length, ") load handlers for ") + "URL '".concat(pathOrIOHandler, "'")); } pathOrIOHandler = handlers[0]; } return [2 /*return*/, loadLayersModelFromIOHandler(pathOrIOHandler, undefined, options)]; }); }); } /** * Load a model and optionally its weights, using an IOHandler object. * * @param handler The instance of `IOHandler` to be used during the model * loading. * @param customObjects Any optional custom objects to be used during model * loading. * @param options Optional configuration arguments for the model loading, * including `strict`: whether the weight loading will be done in strict mode. * Default: `true`. 
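 * @returns A `Promise` of the loaded `tf.LayersModel`.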
*/ function loadLayersModelFromIOHandler(handler, customObjects, options) { return __awaiter(this, void 0, void 0, function () { var artifacts, modelTopology, strict, fastWeightInit, model, trainingConfig, _a, modelWeights, optimizerWeights; return __generator(this, function (_b) { switch (_b.label) { case 0: if (options == null) { options = {}; } if (handler.load == null) { throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' + 'does not have the `load` method implemented.'); } return [4 /*yield*/, handler.load()]; case 1: artifacts = _b.sent(); modelTopology = artifacts.modelTopology; if (modelTopology['model_config'] != null) { modelTopology = modelTopology['model_config']; } strict = options.strict == null ? true : options.strict; fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict; model = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit); trainingConfig = artifacts.trainingConfig; if (trainingConfig != null) { model.loadTrainingConfig(trainingConfig); } if (artifacts.userDefinedMetadata != null) { model.setUserDefinedMetadata(artifacts.userDefinedMetadata); } if (!(artifacts.weightData != null)) return [3 /*break*/, 4]; // Loading weights requires weightSpecs. if (artifacts.weightSpecs == null) { throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. ' + 'Therefore loading of weights cannot proceed.'); } _a = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs), modelWeights = _a.modelWeights, optimizerWeights = _a.optimizerWeights; model.loadWeights(modelWeights, strict); if (!(model.optimizer != null && optimizerWeights.length > 0)) return [3 /*break*/, 3]; return [4 /*yield*/, model.optimizer.setWeights(optimizerWeights)]; case 2: _b.sent(); _b.label = 3; case 3: // Dispose temporary weight values. tfc.dispose(modelWeights); tfc.dispose(optimizerWeights.map(function (w) { return w.tensor; })); _b.label = 4; case 4: return [2 /*return*/, model]; } }); }); } function decodeModelAndOptimizerWeights(weightData, specs) { var name2Tensor = tfc.io.decodeWeights(weightData, specs); var modelWeights = {}; var optimizerWeights = []; specs.forEach(function (spec) { if (spec.group === 'optimizer') { optimizerWeights.push({ name: spec.name, tensor: name2Tensor[spec.name] }); } else { modelWeights[spec.name] = name2Tensor[spec.name]; } }); return { modelWeights: modelWeights, optimizerWeights: optimizerWeights }; } /** * A model with a stack of layers, feeding linearly from one to the next. * * `tf.sequential` is a factory function that creates an instance of * `tf.Sequential`. * * ```js * // Define a model for linear regression. * const model = tf.sequential(); * model.add(tf.layers.dense({units: 1, inputShape: [1]})); * * // Prepare the model for training: Specify the loss and the optimizer. * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); * * // Generate some synthetic data for training. 
 * const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]); * const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]); * * // Train the model using the data, then do inference on a data point the * // model hasn't seen: * await model.fit(xs, ys); * model.predict(tf.tensor2d([5], [1, 1])).print(); * ``` * * @doc {heading: 'Models', subheading: 'Classes'} */ var Sequential = /** @class */ (function (_super) { __extends(Sequential, _super); function Sequential(args) { var e_2, _a; var _this = _super.call(this, { inputs: [], outputs: [] }) || this; args = args || {}; _this.trainable = true; _this.built = false; // Set model name. _this.name = (args.name != null) ? args.name : getUid('sequential_'); // Add to the model any layers passed to the constructor. if (args.layers != null) { try { for (var _b = __values(args.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; _this.add(layer); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_2) throw e_2.error; } } } return _this; } // Helper function for Sequential.add(): throws if the new output shape would // be invalid. Sequential.prototype.checkShape = function (layer) { var shape = layer.inboundNodes[0].outputTensors[0].shape; if (shape.some(function (x) { return x < 0; })) { throw new ValueError('Negative dimension size caused by adding layer ' + "".concat(layer.name, " with input shape [") + "".concat(layer.inboundNodes[0].inputTensors[0].shape, "]")); } }; /** * Adds a layer instance on top of the layer stack. * * ```js * const model = tf.sequential(); * model.add(tf.layers.dense({units: 8, inputShape: [1]})); * model.add(tf.layers.dense({units: 4, activation: 'relu6'})); * model.add(tf.layers.dense({units: 1, activation: 'relu6'})); * // Note that the untrained model is random at this point. * model.predict(tf.randomNormal([10, 1])).print(); * ``` * @param layer Layer instance. * * @exception ValueError In case the `layer` argument does not know its * input shape. * @exception ValueError In case the `layer` argument has multiple output * tensors, or is already connected somewhere else (forbidden in * `Sequential` models). * * @doc {heading: 'Models', subheading: 'Classes'} */ Sequential.prototype.add = function (layer) { var isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel; var modelLayer; if (isLayerModelInstance) { modelLayer = layer; if (modelLayer.outputs.length !== 1) { throw new ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.'); } if (modelLayer.inputs.length !== 1) { throw new ValueError('All layers in a Sequential model ' + 'should have a single input tensor. ' + 'For multi-input layers, ' + 'use the functional API.'); } } if (this.outputs.length === 0) { // first layer in model: check that it is an input layer if (layer.inboundNodes.length === 0) { // create an input layer if (layer.batchInputShape == null) { throw new ValueError('The first layer in a Sequential model must ' + 'get an `inputShape` or `batchInputShape` argument.'); } // Instantiate the input layer. var x = Input({ batchShape: layer.batchInputShape, dtype: layer.dtype, name: layer.name + '_input' }); // This will build the current layer and create the node connecting // the current layer to the input layer we just created. 
layer.apply(x); } if (isLayerModelInstance) { this.outputs = modelLayer.outputs; this.inputs = modelLayer.inputs; } else { if (layer.inboundNodes.length !== 1) { throw new ValueError('A layer added to a Sequential model must not already be ' + "connected somewhere else. LayersModel received layer ".concat(layer.name, " ") + "which has ".concat(layer.inboundNodes.length, " pre-existing inbound ") + 'connections.'); } if (layer.inboundNodes[0].outputTensors.length !== 1) { throw new ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.'); } this.checkShape(layer); this.outputs = [layer.inboundNodes[0].outputTensors[0]]; this.inputs = getSourceInputs(this.outputs[0]); } this.inboundNodes = []; // We create an input node, which we will keep updated // as we add more layers. // (This call has side effects.) // tslint:disable-next-line:no-unused-expression new Node({ outboundLayer: this, inboundLayers: [], nodeIndices: [], tensorIndices: [], inputTensors: this.inputs, outputTensors: this.outputs, // no model-level masking for now inputMasks: pyListRepeat(null, this.inputs.length), outputMasks: [null], inputShapes: this.inputs.map(function (x) { return x.shape; }), outputShapes: this.outputs[0].shape }); } else { var outputTensor = layer.apply(this.outputs[0]); if (Array.isArray(outputTensor)) { throw new TypeError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.'); } this.checkShape(layer); this.outputs = [outputTensor]; // update self.inbound_nodes this.inboundNodes[0].outputTensors = this.outputs; this.inboundNodes[0].outputShapes = [this.outputs[0].shape]; } this.layers.push(layer); this.built = false; }; /** * Removes the last layer in the model. * * @exception TypeError if there are no layers in the model. */ Sequential.prototype.pop = function () { if (this.layers.length === 0) { throw new TypeError('There are no layers in the model.'); } this.layers.pop(); if (this.layers.length === 0) { this.outputs = []; this.inboundNodes = []; this.outboundNodes = []; } else { var lastLayerIndex = this.layers.length - 1; this.layers[lastLayerIndex].outboundNodes = []; this.outputs = [this.layers[lastLayerIndex].output]; // update self.inbound_nodes this.inboundNodes[0].outputTensors = this.outputs; this.inboundNodes[0].outputShapes = [this.outputs[0].shape]; } }; Sequential.prototype.call = function (inputs, kwargs) { if (this.model == null) { this.build(); } return this.model.call(inputs, kwargs); }; Sequential.prototype.build = function (inputShape) { // Call `getExactlyOneShape` without using its return value, // to verify that exactly one input shape is provided. getExactlyOneShape(inputShape); if (this.inputs.length === 0 || this.outputs.length === 0) { throw new TypeError('Sequential model cannot be built: model is empty.' 
+ ' Add some layers first.'); } // actually create the model this.model = new LayersModel({ inputs: this.inputs, outputs: this.outputs[0], name: this.name + '_model' }); this.model.trainable = this.trainable; // mirror model attributes this.supportsMasking = this.model.supportsMasking; // TODO(michaelterry): Add caches this.inputLayers = this.model.inputLayers; this.inputLayersNodeIndices = this.model.inputLayersNodeIndices; this.inputLayersTensorIndices = this.model.inputLayersTensorIndices; this.outputLayers = this.model.outputLayers; this.outputLayersNodeIndices = this.model.outputLayersNodeIndices; this.outputLayersTensorIndices = this.model.outputLayersTensorIndices; this.nodesByDepth = this.model.nodesByDepth; this.containerNodes = this.model.containerNodes; this.outputNames = this.model.outputNames; this.inputNames = this.model.inputNames; // TODO(michaelterry): Add feedInputNames, feedInputs, if needed. // TODO(michaelterry): Add callbackModel if needed. this.built = true; }; Sequential.prototype.countParams = function () { if (!this.built) { this.build(); } return _super.prototype.countParams.call(this); }; /** * Print a text summary of the Sequential model's layers. * * The summary includes * - Name and type of all layers that comprise the model. * - Output shape(s) of the layers * - Number of weight parameters of each layer * - The total number of trainable and non-trainable parameters of the * model. * * ```js * const model = tf.sequential(); * model.add( * tf.layers.dense({units: 100, inputShape: [10], activation: 'relu'})); * model.add(tf.layers.dense({units: 1, activation: 'sigmoid'})); * * model.summary(); * ``` * * @param lineLength Custom line length, in number of characters. * @param positions Custom widths of each of the columns, as either * fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number * of characters (e.g., `[30, 50, 65]`). Each number corresponds to * right-most (i.e., ending) position of a column. * @param printFn Custom print function. Can be used to replace the default * `console.log`. For example, you can use `x => {}` to mute the printed * messages in the console. * * @doc {heading: 'Models', subheading: 'Classes'} */ Sequential.prototype.summary = function (lineLength, positions, printFn) { if (printFn === void 0) { printFn = console.log; } if (!this.built) { this.build(); } _super.prototype.summary.call(this, lineLength, positions, printFn); }; /** * Sets the weights of the model. * * @param weights Should be a list of Tensors with shapes and types matching * the output of `model.getWeights()`. */ Sequential.prototype.setWeights = function (weights) { if (this.model == null) { this.build(); } this.model.setWeights(weights); }; /** * Returns the loss value & metrics values for the model in test mode. * * Loss and metrics are specified during `compile()`, which needs to happen * before calls to `evaluate()`. * * Computation is done in batches. * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); * const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), { * batchSize: 4, * }); * result.print(); * ``` * * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the * model has multiple inputs. * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the * model has multiple outputs. * @param args A `ModelEvaluateConfig`, containing optional fields. 
 * * @return `Scalar` test loss (if the model has a single output and no * metrics) or `Array` of `Scalar`s (if the model has multiple outputs * and/or metrics). The attribute `model.metricsNames` * will give you the display labels for the scalar outputs. * * @doc {heading: 'Models', subheading: 'Classes'} */ Sequential.prototype.evaluate = function (x, y, args) { if (args === void 0) { args = {}; } if (!this.built) { throw new RuntimeError('The model needs to be compiled before being used.'); } return this.model.evaluate(x, y, args); }; // TODO(cais): Add code snippet below once real dataset objects are // available. /** * Evaluate model using a dataset object. * * Note: Unlike `evaluate()`, this method is asynchronous (`async`). * * @param dataset A dataset object. Its `iterator()` method is expected * to generate a dataset iterator object, the `next()` method of which * is expected to produce data batches for evaluation. The return value * of the `next()` call ought to contain a boolean `done` field and a * `value` field. The `value` field is expected to be an array of two * `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former * case is for models with exactly one input and one output (e.g. * a sequential model). The latter case is for models with multiple * inputs and/or multiple outputs. Of the two items in the array, the * first is the input feature(s) and the second is the output target(s). * @param args A configuration object for the dataset-based evaluation. * @returns Loss and metric values as an Array of `Scalar` objects. * * @doc {heading: 'Models', subheading: 'Classes'} */ Sequential.prototype.evaluateDataset = function (dataset, args) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { if (!this.built) { throw new RuntimeError('The model needs to be compiled before being used.'); } return [2 /*return*/, this.model.evaluateDataset(dataset, args)]; }); }); }; /** * Generates output predictions for the input samples. * * Computation is done in batches. * * Note: the "step" mode of predict() is currently not supported. * This is because the TensorFlow.js core backend is imperative only. * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.predict(tf.ones([2, 10])).print(); * ``` * * @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if * the model has multiple inputs. * @param args A `ModelPredictConfig` object containing optional fields. * * @return `tf.Tensor`(s) of predictions. * * @exception ValueError In case of mismatch between the provided input data * and the model's expectations, or in case a stateful model receives a * number of samples that is not a multiple of the batch size. * * @doc {heading: 'Models', subheading: 'Classes'} */ Sequential.prototype.predict = function (x, args) { if (args === void 0) { args = {}; } if (this.model == null) { this.build(); } return this.model.predict(x, args); }; /** * Returns predictions for a single batch of samples. * * @param x Input samples, as a Tensor, or list of Tensors (if the model * has multiple inputs). * @return Tensor(s) of predictions */ Sequential.prototype.predictOnBatch = function (x) { if (this.model == null) { this.build(); } return this.model.predictOnBatch(x); }; /** * See `LayersModel.compile`. 
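 * * A minimal usage sketch (hedged; the `'sgd'` and `'meanSquaredError'` identifiers appear in other examples in this file, while `['mse']` as a metrics list is an assumption): * * ```js * const model = tf.sequential( * {layers: [tf.layers.dense({units: 1, inputShape: [4]})]}); * model.compile( * {optimizer: 'sgd', loss: 'meanSquaredError', metrics: ['mse']}); * ```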
 * * @param args */ Sequential.prototype.compile = function (args) { this.build(); this.model.compile(args); this.optimizer_ = this.model.optimizer; // tslint:disable-next-line:no-any this.isOptimizerOwned = this.model.isOptimizerOwned; this.loss = this.model.loss; this.metrics = this.model.metrics; // TODO(cais): Add this.lossWeights, this.sampleWeightMode, // this.weightedMetrics, this.targets. this.metricsTensors = this.model.metricsTensors; this.metricsNames = this.model.metricsNames; // TODO(cais): Add sampleWeights. }; Object.defineProperty(Sequential.prototype, "optimizer", { get: function () { return this.model == null ? undefined : this.model.optimizer; }, set: function (optimizer) { this.model.optimizer = optimizer; }, enumerable: false, configurable: true }); /** * Trains the model for a fixed number of epochs (iterations on a dataset). * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); * const history = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), { * batchSize: 4, * epochs: 3 * }); * console.log(history.history.loss[0]); * ``` * * @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the * model has multiple inputs. If all inputs in the model are named, you can * also pass a dictionary mapping input names to `tf.Tensor`s. * @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if * the model has multiple outputs. If all outputs in the model are named, you * can also pass a dictionary mapping output names to `tf.Tensor`s. * @param args A `ModelFitConfig`, containing optional fields. * * @return A `History` instance. Its `history` attribute contains all * information collected during training. * * @exception ValueError In case of mismatch between the provided input data * and what the model expects. * * @doc {heading: 'Models', subheading: 'Classes'} */ Sequential.prototype.fit = function (x, y, args) { if (args === void 0) { args = {}; } return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { if (!this.built) { throw new RuntimeError('The model needs to be compiled before ' + 'being used.'); } return [2 /*return*/, this.model.fit(x, y, args)]; }); }); }; /** * Trains the model using a dataset object. * * ```js * const xArray = [ * [1, 1, 1, 1, 1, 1, 1, 1, 1], * [1, 1, 1, 1, 1, 1, 1, 1, 1], * [1, 1, 1, 1, 1, 1, 1, 1, 1], * [1, 1, 1, 1, 1, 1, 1, 1, 1], * ]; * const yArray = [1, 1, 1, 1]; * // Create a dataset from the JavaScript array. * const xDataset = tf.data.array(xArray); * const yDataset = tf.data.array(yArray); * // Zip combines the `x` and `y` Datasets into a single Dataset, the * // iterator of which will return an object containing two tensors, * // corresponding to `x` and `y`. The call to `batch(4)` will bundle * // four such samples into a single object, with the same keys now pointing * // to tensors that hold 4 examples, organized along the batch dimension. * // The call to `shuffle(4)` causes each iteration through the dataset to * // happen in a different order. The size of the shuffle window is 4. 
 * const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset}) * .batch(4) * .shuffle(4); * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [9]})] * }); * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); * const history = await model.fitDataset(xyDataset, { * epochs: 4, * callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)} * }); * ``` * * @param dataset A dataset object. Its `iterator()` method is expected to * generate a dataset iterator object, the `next()` method of which is * expected to produce data batches for evaluation. The return value of the * `next()` call ought to contain a boolean `done` field and a `value` * field. * * The `value` field is expected to be an object with fields * `xs` and `ys`, which point to the feature tensor and the target tensor, * respectively. This case is for models with exactly one input and one * output (e.g. a sequential model). For example: * ```js * {value: {xs: xsTensor, ys: ysTensor}, done: false} * ``` * * If the model has multiple inputs, the `xs` field of `value` should * be an object mapping input names to their respective feature tensors. * For example: * ```js * { * value: { * xs: { * input_1: xsTensor1, * input_2: xsTensor2 * }, * ys: ysTensor * }, * done: false * } * ``` * If the model has multiple outputs, the `ys` field of `value` should * be an object mapping output names to their respective target tensors. * For example: * ```js * { * value: { * xs: xsTensor, * ys: { * output_1: ysTensor1, * output_2: ysTensor2 * }, * }, * done: false * } * ``` * @param args A `ModelFitDatasetArgs`, containing optional fields. * * @return A `History` instance. Its `history` attribute contains all * information collected during training. * * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} */ Sequential.prototype.fitDataset = function (dataset, args) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { if (!this.built) { throw new RuntimeError('The model needs to be compiled before ' + 'being used.'); } return [2 /*return*/, this.model.fitDataset(dataset, args)]; }); }); }; /** * Runs a single gradient update on a single batch of data. * * This method differs from `fit()` and `fitDataset()` in the following * regards: * - It operates on exactly one batch of data. * - It returns only the loss and metric values, instead of * returning the batch-by-batch loss and metric values. * - It doesn't support fine-grained options such as verbosity and * callbacks. * * @param x Input data. It could be one of the following: * - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has * multiple inputs). * - An Object mapping input names to corresponding `tf.Tensor` (if the * model has named inputs). * @param y Target data. It could be either a `tf.Tensor` or multiple * `tf.Tensor`s. It should be consistent with `x`. * @returns Training loss or losses (in case the model has * multiple outputs), along with metrics (if any), as numbers. 
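 * * A minimal sketch (the shapes below are chosen only for illustration): * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); * const loss = await model.trainOnBatch(tf.ones([4, 10]), tf.zeros([4, 1])); * console.log(loss); * ```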
* * @doc {heading: 'Models', subheading: 'Classes'} */ Sequential.prototype.trainOnBatch = function (x, y) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/, this.model.trainOnBatch(x, y)]; }); }); }; /* See parent class for JsDoc */ /** @nocollapse */ Sequential.fromConfig = function (cls, config, customObjects, fastWeightInit) { var e_3, _a; if (fastWeightInit === void 0) { fastWeightInit = false; } var configArray; var extraModelConfig = {}; if (config instanceof Array) { if (!(config[0].className != null) || config[0]['className'] === 'Merge') { throw new ValueError('Legacy serialization format not supported yet.'); } configArray = config; } else { tfc.util.assert(config['layers'] != null, function () { return "When the config data for a Sequential model is not an Array, " + "it must be an Object that contains the 'layers' field."; }); configArray = config['layers']; delete config['layers']; extraModelConfig = config; } var model = new cls(extraModelConfig); if (!(model instanceof Sequential)) { throw new NotImplementedError("Sequential.fromConfig called on non-Sequential input: ".concat(model)); } try { for (var configArray_1 = __values(configArray), configArray_1_1 = configArray_1.next(); !configArray_1_1.done; configArray_1_1 = configArray_1.next()) { var conf = configArray_1_1.value; var customObjects_1 = undefined; var layer = deserialize(conf, customObjects_1, fastWeightInit); if (fastWeightInit) { layer.setFastWeightInitDuringBuild(true); } model.add(layer); } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (configArray_1_1 && !configArray_1_1.done && (_a = configArray_1.return)) _a.call(configArray_1); } finally { if (e_3) throw e_3.error; } } return model; }; Object.defineProperty(Sequential.prototype, "stopTraining", { get: function () { if (this.model == null) { throw new ValueError('Cannot get the stopTraining property of a sequential model before ' + 'it is compiled.'); } return this.model.stopTraining; }, /** * Setter used for force stopping of LayersModel.fit() (i.e., training). * * Example: * * ```js * const model = tf.sequential(); * model.add(tf.layers.dense({units: 1, inputShape: [10]})); * model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); * const xs = tf.ones([8, 10]); * const ys = tf.zeros([8, 1]); * * const history = await model.fit(xs, ys, { * epochs: 10, * callbacks: { * onEpochEnd: async (epoch, logs) => { * if (epoch === 2) { * model.stopTraining = true; * } * } * } * }); * * // There should be only 3 values in the loss array, instead of 10 values, * // due to the stopping after 3 epochs. * console.log(history.history.loss); * ``` */ set: function (stop) { // TODO(cais): When refactoring to remove the composition pattern happens, // remove this method overriding. if (this.model == null) { throw new ValueError('Cannot set the stopTraining property of a sequential model before ' + 'it is compiled.'); } this.model.stopTraining = stop; }, enumerable: false, configurable: true }); // TODO(cais): Override get trainableWeights() here // tslint:disable-next-line:no-any Sequential.prototype.getConfig = function () { var e_4, _a; // NOTE(cais): We override the return type of getConfig() to `any` here, // because the `Sequential` class is a special case among `Container` // subtypes in that its getConfig() method returns an Array (not a // dict). 
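        // The returned value has the form // {name: <model name>, layers: [{className, config}, ...]}, where each // entry in `layers` serializes one layer of the model, in order.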
var layers = []; try { for (var _b = __values(this.layers), _c = _b.next(); !_c.done; _c = _b.next()) { var layer = _c.value; var dict = {}; dict['className'] = layer.getClassName(); dict['config'] = layer.getConfig(); layers.push(dict); } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_4) throw e_4.error; } } return { name: this.name, layers: layers }; }; return Sequential; }(LayersModel)); /** @nocollapse */ Sequential.className = 'Sequential'; tfc.serialization.registerClass(Sequential); /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ // TODO(cais): Add doc string to all the public static functions in this // class; include executable JavaScript code snippets where applicable // (b/74074458). // LayersModel and related factory methods. /** * A model is a data structure that consists of `Layers` and defines inputs * and outputs. * * The key difference between `tf.model` and `tf.sequential` is that * `tf.model` is more generic, supporting an arbitrary graph (without * cycles) of layers. `tf.sequential` is less generic and supports only a linear * stack of layers. * * When creating a `tf.LayersModel`, specify its input(s) and output(s). Layers * are used to wire input(s) to output(s). * * For example, the following code snippet defines a model consisting of * two `dense` layers, with 10 and 4 units, respectively. * * ```js * // Define input, which has a size of 5 (not including batch dimension). * const input = tf.input({shape: [5]}); * * // First dense layer uses relu activation. * const denseLayer1 = tf.layers.dense({units: 10, activation: 'relu'}); * // Second dense layer uses softmax activation. * const denseLayer2 = tf.layers.dense({units: 4, activation: 'softmax'}); * * // Obtain the output symbolic tensor by applying the layers on the input. * const output = denseLayer2.apply(denseLayer1.apply(input)); * * // Create the model based on the inputs. * const model = tf.model({inputs: input, outputs: output}); * * // The model can be used for training, evaluation and prediction. * // For example, the following line runs prediction with the model on * // some fake data. * model.predict(tf.ones([2, 5])).print(); * ``` * See also: * `tf.sequential`, `tf.loadLayersModel`. * * @doc {heading: 'Models', subheading: 'Creation'} */ function model(args) { return new LayersModel(args); } /** * Creates a `tf.Sequential` model. A sequential model is any model where the * outputs of one layer are the inputs to the next layer, i.e. the model * topology is a simple 'stack' of layers, with no branching or skipping. * * This means that the first layer passed to a `tf.Sequential` model should have * a defined input shape. What that means is that it should have received an * `inputShape` or `batchInputShape` argument, or for some types of layers * (recurrent, Dense...) an `inputDim` argument. * * The key difference between `tf.model` and `tf.sequential` is that * `tf.sequential` is less generic, supporting only a linear stack of layers. * `tf.model` is more generic and supports an arbitrary graph (without * cycles) of layers. * * Examples: * * ```js * const model = tf.sequential(); * * // First layer must have an input shape defined. 
 * model.add(tf.layers.dense({units: 32, inputShape: [50]})); * // Afterwards, TF.js does automatic shape inference. * model.add(tf.layers.dense({units: 4})); * * // Inspect the inferred shape of the model's output, which equals * // `[null, 4]`. The 1st dimension is the undetermined batch dimension; the * // 2nd is the output size of the model's last layer. * console.log(JSON.stringify(model.outputs[0].shape)); * ``` * * It is also possible to specify a batch size (with potentially undetermined * batch dimension, denoted by "null") for the first layer using the * `batchInputShape` key. The following example is equivalent to the above: * * ```js * const model = tf.sequential(); * * // First layer must have a defined input shape * model.add(tf.layers.dense({units: 32, batchInputShape: [null, 50]})); * // Afterwards, TF.js does automatic shape inference. * model.add(tf.layers.dense({units: 4})); * * // Inspect the inferred shape of the model's output. * console.log(JSON.stringify(model.outputs[0].shape)); * ``` * * You can also use an `Array` of already-constructed `Layer`s to create * a `tf.Sequential` model: * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 32, inputShape: [50]}), * tf.layers.dense({units: 4})] * }); * console.log(JSON.stringify(model.outputs[0].shape)); * ``` * * @doc {heading: 'Models', subheading: 'Creation'} */ function sequential(config) { return new Sequential(config); } /** * Used to instantiate an input to a model as a `tf.SymbolicTensor`. * * Users should call the `input` factory function for * consistency with other generator functions. * * Example: * * ```js * // Defines a simple logistic regression model with 32 dimensional input * // and 3 dimensional output. * const x = tf.input({shape: [32]}); * const y = tf.layers.dense({units: 3, activation: 'softmax'}).apply(x); * const model = tf.model({inputs: x, outputs: y}); * model.predict(tf.ones([2, 32])).print(); * ``` * * Note: `input` is only necessary when using `model`. When using * `sequential`, specify `inputShape` for the first layer or use `inputLayer` * as the first layer. * * @doc {heading: 'Models', subheading: 'Inputs'} */ function input(config) { return Input(config); } function registerCallbackConstructor(verbosityLevel, callbackConstructor) { CallbackConstructorRegistry.registerCallbackConstructor(verbosityLevel, callbackConstructor); } /** * Base class for Activations. * * Special note: due to cross-language compatibility reasons, the * static readonly className field in this family of classes must be set to * the initialLowerCamelCase name of the activation. */ var Activation$1 = /** @class */ (function (_super) { __extends(Activation, _super); function Activation() { return _super !== null && _super.apply(this, arguments) || this; } Activation.prototype.getConfig = function () { return {}; }; return Activation; }(tfc.serialization.Serializable)); /** * Exponential linear unit (ELU). * Reference: https://arxiv.org/abs/1511.07289 */ var Elu = /** @class */ (function (_super) { __extends(Elu, _super); function Elu() { return _super !== null && _super.apply(this, arguments) || this; } /** * Calculate the activation function. * * @param x: Input. * @param alpha: Scaling factor for the negative section. * @return Output of the ELU activation. 
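 * * ELU computes `x` for `x > 0` and `alpha * (exp(x) - 1)` for `x <= 0`.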
 */ Elu.prototype.apply = function (x, alpha) { if (alpha === void 0) { alpha = 1; } return elu$1(x, alpha); }; return Elu; }(Activation$1)); /** @nocollapse */ Elu.className = 'elu'; tfc.serialization.registerClass(Elu); /** * Scaled Exponential Linear Unit. (Klambauer et al., 2017). * Reference: Self-Normalizing Neural Networks, https://arxiv.org/abs/1706.02515 * Notes: * - To be used together with the initialization "lecunNormal". * - To be used together with the dropout variant "AlphaDropout". */ var Selu = /** @class */ (function (_super) { __extends(Selu, _super); function Selu() { return _super !== null && _super.apply(this, arguments) || this; } Selu.prototype.apply = function (x) { return tfc__namespace.selu(x); }; return Selu; }(Activation$1)); /** @nocollapse */ Selu.className = 'selu'; tfc.serialization.registerClass(Selu); /** * Rectified linear unit */ var Relu = /** @class */ (function (_super) { __extends(Relu, _super); function Relu() { return _super !== null && _super.apply(this, arguments) || this; } Relu.prototype.apply = function (x) { return tfc__namespace.relu(x); }; return Relu; }(Activation$1)); /** @nocollapse */ Relu.className = 'relu'; tfc.serialization.registerClass(Relu); /** * Rectified linear unit activation maxing out at 6.0. */ var Relu6 = /** @class */ (function (_super) { __extends(Relu6, _super); function Relu6() { return _super !== null && _super.apply(this, arguments) || this; } Relu6.prototype.apply = function (x) { return tfc.tidy(function () { return tfc__namespace.minimum(6.0, tfc__namespace.relu(x)); }); }; return Relu6; }(Activation$1)); /** @nocollapse */ Relu6.className = 'relu6'; tfc.serialization.registerClass(Relu6); /** Linear activation (no-op). */ var Linear = /** @class */ (function (_super) { __extends(Linear, _super); function Linear() { return _super !== null && _super.apply(this, arguments) || this; } Linear.prototype.apply = function (x) { return x; }; return Linear; }(Activation$1)); /** @nocollapse */ Linear.className = 'linear'; tfc.serialization.registerClass(Linear); /** * Sigmoid activation function. */ var Sigmoid = /** @class */ (function (_super) { __extends(Sigmoid, _super); function Sigmoid() { return _super !== null && _super.apply(this, arguments) || this; } Sigmoid.prototype.apply = function (x) { return tfc__namespace.sigmoid(x); }; return Sigmoid; }(Activation$1)); /** @nocollapse */ Sigmoid.className = 'sigmoid'; tfc.serialization.registerClass(Sigmoid); /** * Segment-wise linear approximation of sigmoid. */ var HardSigmoid = /** @class */ (function (_super) { __extends(HardSigmoid, _super); function HardSigmoid() { return _super !== null && _super.apply(this, arguments) || this; } HardSigmoid.prototype.apply = function (x) { return hardSigmoid(x); }; return HardSigmoid; }(Activation$1)); /** @nocollapse */ HardSigmoid.className = 'hardSigmoid'; tfc.serialization.registerClass(HardSigmoid); /** * Softplus activation function. */ var Softplus = /** @class */ (function (_super) { __extends(Softplus, _super); function Softplus() { return _super !== null && _super.apply(this, arguments) || this; } Softplus.prototype.apply = function (x) { return tfc__namespace.softplus(x); }; return Softplus; }(Activation$1)); /** @nocollapse */ Softplus.className = 'softplus'; tfc.serialization.registerClass(Softplus); /** * Softsign activation function. 
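 * * Softsign computes `x / (1 + |x|)`.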
*/ var Softsign = /** @class */ (function (_super) { __extends(Softsign, _super); function Softsign() { return _super !== null && _super.apply(this, arguments) || this; } Softsign.prototype.apply = function (x) { return softsign(x); }; return Softsign; }(Activation$1)); /** @nocollapse */ Softsign.className = 'softsign'; tfc.serialization.registerClass(Softsign); /** * Hyperbolic tangent function. */ var Tanh = /** @class */ (function (_super) { __extends(Tanh, _super); function Tanh() { return _super !== null && _super.apply(this, arguments) || this; } Tanh.prototype.apply = function (x) { return tfc__namespace.tanh(x); }; return Tanh; }(Activation$1)); /** @nocollapse */ Tanh.className = 'tanh'; tfc.serialization.registerClass(Tanh); /** * Softmax activation function */ var Softmax$1 = /** @class */ (function (_super) { __extends(Softmax, _super); function Softmax() { return _super !== null && _super.apply(this, arguments) || this; } /** * Calculate the activation function. * * @param x Tensor. * @param axis Integer, axis along which the softmax normalization is applied. * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be * an error. * * @returns a Tensor of the same shape as x * * @throws ValueError: In case `dim(x) < 2`. */ Softmax.prototype.apply = function (x, axis) { if (axis === void 0) { axis = (-1); } return tfc__namespace.softmax(x, axis); }; return Softmax; }(Activation$1)); /** @nocollapse */ Softmax$1.className = 'softmax'; tfc.serialization.registerClass(Softmax$1); /** * Log softmax activation function */ var LogSoftmax = /** @class */ (function (_super) { __extends(LogSoftmax, _super); function LogSoftmax() { return _super !== null && _super.apply(this, arguments) || this; } /** * Calculate the activation function of log softmax: * log( exp(x_i) / sum(exp(x)) ) * * @param x Tensor. * @param axis Integer, axis along which the softmax normalization is applied. * Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be * an error. * * @returns a Tensor of the same shape as x * * @throws ValueError: In case `dim(x) < 2`. */ LogSoftmax.prototype.apply = function (x, axis) { if (axis === void 0) { axis = (-1); } return tfc__namespace.logSoftmax(x, axis); }; return LogSoftmax; }(Activation$1)); /** @nocollapse */ LogSoftmax.className = 'logSoftmax'; tfc.serialization.registerClass(LogSoftmax); /** * Swish activation function */ var Swish = /** @class */ (function (_super) { __extends(Swish, _super); function Swish() { return _super !== null && _super.apply(this, arguments) || this; } /** * Calculate the activation function. * * @param x Tensor. * @param alpha Scaling factor for the sigmoid function. * @returns a Tensor of the same shape as x */ Swish.prototype.apply = function (x, alpha) { if (alpha === void 0) { alpha = 1; } return tfc.tidy(function () { return tfc__namespace.mul(tfc__namespace.sigmoid(tfc__namespace.mul(x, alpha)), x); }); }; return Swish; }(Activation$1)); /** @nocollapse */ Swish.className = 'swish'; tfc.serialization.registerClass(Swish); /** * Mish activation function */ var Mish = /** @class */ (function (_super) { __extends(Mish, _super); function Mish() { return _super !== null && _super.apply(this, arguments) || this; } /** * Calculate the activation function. * * @param x Tensor. 
* @returns a Tensor of the same shape as x */ Mish.prototype.apply = function (x) { return tfc.tidy(function () { return tfc__namespace.mul(x, tfc__namespace.tanh(tfc__namespace.softplus(x))); }); }; return Mish; }(Activation$1)); /** @nocollapse */ Mish.className = 'mish'; tfc.serialization.registerClass(Mish); function serializeActivation(activation) { return activation.getClassName(); } function deserializeActivation(config, customObjects) { if (customObjects === void 0) { customObjects = {}; } return deserializeKerasObject(config, tfc.serialization.SerializationMap.getMap().classNameMap, customObjects, 'activation'); } function getActivation(identifier) { if (identifier == null) { var config = {}; config['className'] = 'linear'; config['config'] = {}; return deserializeActivation(config); } if (typeof identifier === 'string') { var config = {}; config['className'] = identifier; config['config'] = {}; return deserializeActivation(config); } else if (identifier instanceof Activation$1) { return identifier; } else { return deserializeActivation(identifier); } } function assertObjectArgs(args) { if (args != null && typeof args !== 'object') { throw new Error("Argument to L1L2 regularizer's constructor is expected to be an " + "object, but received: ".concat(args)); } } /** * Regularizer base class. */ var Regularizer = /** @class */ (function (_super) { __extends(Regularizer, _super); function Regularizer() { return _super !== null && _super.apply(this, arguments) || this; } return Regularizer; }(tfc.serialization.Serializable)); var L1L2 = /** @class */ (function (_super) { __extends(L1L2, _super); function L1L2(args) { var _this = _super.call(this) || this; assertObjectArgs(args); _this.l1 = args == null || args.l1 == null ? 0.01 : args.l1; _this.l2 = args == null || args.l2 == null ? 0.01 : args.l2; _this.hasL1 = _this.l1 !== 0; _this.hasL2 = _this.l2 !== 0; return _this; } /** * Porting note: Renamed from __call__. * @param x Variable of which to calculate the regularization score. */ L1L2.prototype.apply = function (x) { var _this = this; return tfc.tidy(function () { var regularization = tfc.zeros([1]); if (_this.hasL1) { regularization = tfc.add(regularization, tfc.sum(tfc__namespace.mul(_this.l1, tfc.abs(x)))); } if (_this.hasL2) { regularization = tfc.add(regularization, tfc.sum(tfc__namespace.mul(_this.l2, square$1(x)))); } return tfc__namespace.reshape(regularization, []); }); }; L1L2.prototype.getConfig = function () { return { 'l1': this.l1, 'l2': this.l2 }; }; /** @nocollapse */ L1L2.fromConfig = function (cls, config) { return new cls({ l1: config['l1'], l2: config['l2'] }); }; return L1L2; }(Regularizer)); /** @nocollapse */ L1L2.className = 'L1L2'; tfc.serialization.registerClass(L1L2); function l1$1(args) { assertObjectArgs(args); return new L1L2({ l1: args != null ? args.l1 : null, l2: 0 }); } function l2$1(args) { assertObjectArgs(args); return new L1L2({ l2: args != null ? args.l2 : null, l1: 0 }); } // Maps the JavaScript-like identifier keys to the corresponding keras symbols. 
var REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = { 'l1l2': 'L1L2' }; function serializeRegularizer(constraint) { return serializeKerasObject(constraint); } function deserializeRegularizer(config, customObjects) { if (customObjects === void 0) { customObjects = {}; } return deserializeKerasObject(config, tfc.serialization.SerializationMap.getMap().classNameMap, customObjects, 'regularizer'); } function getRegularizer(identifier) { if (identifier == null) { return null; } if (typeof identifier === 'string') { var className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ? REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier; var config = { className: className, config: {} }; return deserializeRegularizer(config); } else if (identifier instanceof Regularizer) { return identifier; } else { return deserializeRegularizer(identifier); } } var ReLU = /** @class */ (function (_super) { __extends(ReLU, _super); function ReLU(args) { var _this = _super.call(this, args == null ? {} : args) || this; _this.supportsMasking = true; if (args != null) { _this.maxValue = args.maxValue; } return _this; } ReLU.prototype.call = function (inputs, kwargs) { inputs = getExactlyOneTensor(inputs); var output = tfc.relu(inputs); if (this.maxValue != null) { output = tfc.clipByValue(output, 0, this.maxValue); } return output; }; ReLU.prototype.computeOutputShape = function (inputShape) { return inputShape; }; ReLU.prototype.getConfig = function () { var config = { maxValue: this.maxValue }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return ReLU; }(Layer)); /** @nocollapse */ ReLU.className = 'ReLU'; tfc.serialization.registerClass(ReLU); var LeakyReLU = /** @class */ (function (_super) { __extends(LeakyReLU, _super); function LeakyReLU(args) { var _this = _super.call(this, args == null ? {} : args) || this; _this.DEFAULT_ALPHA = 0.3; if (args == null) { args = {}; } _this.alpha = args.alpha == null ? _this.DEFAULT_ALPHA : args.alpha; return _this; } LeakyReLU.prototype.call = function (inputs, kwargs) { var x = getExactlyOneTensor(inputs); return tfc.leakyRelu(x, this.alpha); }; LeakyReLU.prototype.computeOutputShape = function (inputShape) { return inputShape; }; LeakyReLU.prototype.getConfig = function () { var config = { alpha: this.alpha }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return LeakyReLU; }(Layer)); /** @nocollapse */ LeakyReLU.className = 'LeakyReLU'; tfc.serialization.registerClass(LeakyReLU); var PReLU = /** @class */ (function (_super) { __extends(PReLU, _super); function PReLU(args) { var _this = _super.call(this, args == null ? 
{} : args) || this; _this.DEFAULT_ALPHA_INITIALIZER = 'zeros'; if (args == null) { args = {}; } _this.supportsMasking = true; _this.alphaInitializer = getInitializer(args.alphaInitializer || _this.DEFAULT_ALPHA_INITIALIZER); _this.alphaRegularizer = getRegularizer(args.alphaRegularizer); _this.alphaConstraint = getConstraint(args.alphaConstraint); if (args.sharedAxes == null) { _this.sharedAxes = null; } else if (Array.isArray(args.sharedAxes)) { _this.sharedAxes = args.sharedAxes; } else if (typeof args.sharedAxes === 'number') { _this.sharedAxes = [args.sharedAxes]; } else { throw new ValueError("Expected sharedAxes to be a number or an array of numbers, " + "but got ".concat(args.sharedAxes)); } return _this; } PReLU.prototype.build = function (inputShape) { var e_1, _a; inputShape = getExactlyOneShape(inputShape); var paramShape = inputShape.slice(1); if (this.sharedAxes != null) { try { for (var _b = __values(this.sharedAxes), _c = _b.next(); !_c.done; _c = _b.next()) { var i = _c.value; paramShape[i - 1] = 1; } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_1) throw e_1.error; } } } this.alpha = this.addWeight('alpha', paramShape, 'float32', this.alphaInitializer, this.alphaRegularizer, true, this.alphaConstraint); // Set input spec. var axes = {}; if (this.sharedAxes != null) { for (var i = 1; i < inputShape.length; ++i) { axes[i] = inputShape[i]; } } this.inputSpec = [new InputSpec({ ndim: inputShape.length, axes: axes, })]; this.built = true; }; PReLU.prototype.call = function (inputs, kwargs) { inputs = getExactlyOneTensor(inputs); return tfc.prelu(inputs, this.alpha.read()); }; PReLU.prototype.getConfig = function () { var config = { alphaInitializer: serializeInitializer(this.alphaInitializer), alphaRegularizer: serializeRegularizer(this.alphaRegularizer), alphaConstraint: serializeConstraint(this.alphaConstraint), sharedAxes: this.sharedAxes }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return PReLU; }(Layer)); /** @nocollapse */ PReLU.className = 'PReLU'; tfc.serialization.registerClass(PReLU); var ELU = /** @class */ (function (_super) { __extends(ELU, _super); function ELU(args) { var _this = _super.call(this, args == null ? {} : args) || this; _this.DEFAULT_ALPHA = 1.0; if (args == null) { args = {}; } if (args.alpha != null && args.alpha !== _this.DEFAULT_ALPHA) { throw new NotImplementedError("Non-default alpha value (".concat(args.alpha, ") is not supported by the ") + "ELU layer yet."); } _this.alpha = args.alpha == null ? _this.DEFAULT_ALPHA : args.alpha; return _this; } ELU.prototype.call = function (inputs, kwargs) { var x = getExactlyOneTensor(inputs); return tfc.elu(x); }; ELU.prototype.computeOutputShape = function (inputShape) { return inputShape; }; ELU.prototype.getConfig = function () { var config = { alpha: this.alpha }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return ELU; }(Layer)); /** @nocollapse */ ELU.className = 'ELU'; tfc.serialization.registerClass(ELU); var ThresholdedReLU = /** @class */ (function (_super) { __extends(ThresholdedReLU, _super); function ThresholdedReLU(args) { var _this = _super.call(this, args == null ? {} : args) || this; _this.DEFAULT_THETA = 1.0; if (args == null) { args = {}; } _this.theta = args.theta == null ? 
_this.DEFAULT_THETA : args.theta; return _this; } ThresholdedReLU.prototype.call = function (inputs, kwargs) { var x = getExactlyOneTensor(inputs); return tfc.mul(x, tfc.cast(tfc.greater(x, this.theta), 'float32')); }; ThresholdedReLU.prototype.computeOutputShape = function (inputShape) { return inputShape; }; ThresholdedReLU.prototype.getConfig = function () { var config = { theta: this.theta }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return ThresholdedReLU; }(Layer)); /** @nocollapse */ ThresholdedReLU.className = 'ThresholdedReLU'; tfc.serialization.registerClass(ThresholdedReLU); var Softmax = /** @class */ (function (_super) { __extends(Softmax, _super); function Softmax(args) { var _this = _super.call(this, args == null ? {} : args) || this; _this.DEFAULT_AXIS = 1.0; if (args == null) { args = {}; } _this.softmax = new Softmax$1().apply; _this.axis = args.axis == null ? _this.DEFAULT_AXIS : args.axis; return _this; } Softmax.prototype.call = function (inputs, kwargs) { var _this = this; // TODO(pforderique): Add tests for when `this.axis` is a number[]. return tfc.tidy(function () { var x = getExactlyOneTensor(inputs); var mask = kwargs['mask']; if (mask != null) { // Since mask is 1.0 for positions we want to keep and 0.0 for masked // positions, this operation will create a tensor which is 0.0 for // positions we want to attend and -1e9 for masked positions. var adder = tfc.mul(tfc.sub(tfc.ones(x.shape), tfc.cast(mask, x.dtype)), tfc.scalar(-1e9)); // Since we are adding it to the raw scores before the softmax, this // is effectively the same as removing these entirely. x = tfc.add(x, adder); } if (_this.axis instanceof Array) { if (_this.axis.length > 1) { return tfc.exp(tfc.sub(x, tfc.logSumExp(x, _this.axis, true))); } else { return _this.softmax(x, _this.axis[0]); } } return _this.softmax(x, _this.axis); }); }; Softmax.prototype.computeOutputShape = function (inputShape) { return inputShape; }; Softmax.prototype.getConfig = function () { var config = { axis: this.axis }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Softmax; }(Layer)); /** @nocollapse */ Softmax.className = 'Softmax'; tfc.serialization.registerClass(Softmax); /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Transforms a single number or an array of numbers into an array of numbers. * @param value * @param n: The size of the tuple to be returned. * @param name: Name of the parameter, used for generating error messages. * @returns An array of numbers. */ function normalizeArray(value, n, name) { if (typeof value === 'number') { return pyListRepeat(value, n); } else { if (value.length !== n) { throw new ValueError("The ".concat(name, " argument must be an integer or tuple of ").concat(n, " integers.") + " Received: ".concat(value.length, " elements.")); } for (var i = 0; i < n; ++i) { var singleValue = value[i]; if (!isInteger(singleValue)) { throw new ValueError("The ".concat(name, " argument must be an integer or tuple of ").concat(n) + " integers. Received: ".concat(JSON.stringify(value), " including a") + " non-integer number ".concat(singleValue)); } } return value; } } /** * Determines output length of a convolution given input length. 
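 * * The effective (dilated) filter size is `filterSize + (filterSize - 1) * (dilation - 1)`. For `'same'` padding the output length is `ceil(inputLength / stride)`; for `'valid'` padding it is `ceil((inputLength - dilatedFilterSize + 1) / stride)`.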
* @param inputLength * @param filterSize * @param padding * @param stride * @param dilation: dilation rate. */ function convOutputLength(inputLength, filterSize, padding, stride, dilation) { if (dilation === void 0) { dilation = 1; } if (inputLength == null) { return inputLength; } var dilatedFilterSize = filterSize + (filterSize - 1) * (dilation - 1); var outputLength; if (padding === 'same') { outputLength = inputLength; } else { // VALID outputLength = inputLength - dilatedFilterSize + 1; } return Math.floor((outputLength + stride - 1) / stride); } function deconvLength(dimSize, strideSize, kernelSize, padding) { if (dimSize == null) { return null; } if (padding === 'valid') { dimSize = dimSize * strideSize + max([kernelSize - strideSize, 0]); } else if (padding === 'same') { dimSize = dimSize * strideSize; } else { throw new ValueError("Unsupported padding mode: ".concat(padding, ".")); } return dimSize; } /** * Transpose and cast the input before the conv2d. * @param x Input image tensor. * @param dataFormat */ function preprocessConv2DInput(x, dataFormat) { // TODO(cais): Cast type to float32 if not. return tfc.tidy(function () { checkDataFormat(dataFormat); if (dataFormat === 'channelsFirst') { return tfc__namespace.transpose(x, [0, 2, 3, 1]); // NCHW -> NHWC. } else { return x; } }); } /** * Transpose and cast the input before the conv3d. * @param x Input image tensor. * @param dataFormat */ function preprocessConv3DInput(x, dataFormat) { return tfc.tidy(function () { checkDataFormat(dataFormat); if (dataFormat === 'channelsFirst') { return tfc__namespace.transpose(x, [0, 2, 3, 4, 1]); // NCDHW -> NDHWC. } else { return x; } }); } /** * 1D-convolution with bias added. * * Porting Note: This function does not exist in the Python Keras backend. * It is exactly the same as `conv1d`, except for the added `bias`. * * @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`. * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`. * @param bias Bias, rank-1, of shape `[outDepth]`. * @param strides * @param padding Padding mode. * @param dataFormat Data format. * @param dilationRate * @returns The result of the 1D convolution. * @throws ValueError, if `x`, `kernel` or `bias` is not of the correct rank. */ function conv1dWithBias(x, kernel, bias, strides, padding, dataFormat, dilationRate) { if (strides === void 0) { strides = 1; } if (padding === void 0) { padding = 'valid'; } if (dilationRate === void 0) { dilationRate = 1; } return tfc.tidy(function () { if (dataFormat == null) { dataFormat = imageDataFormat(); } checkDataFormat(dataFormat); // Check the ranks of x, kernel and bias. if (x.shape.length !== 3) { throw new ValueError("The rank of the input to a conv1dWithBias operation should be 3, " + "but is ".concat(x.shape.length, " instead.")); } if (kernel.shape.length !== 3) { throw new ValueError("The rank of the kernel for a conv1dWithBias operation should be 3, " + "but is ".concat(kernel.shape.length, " instead.")); } if (bias != null && bias.shape.length !== 1) { throw new ValueError("The rank of the bias for a conv1dWithBias operation should be 1, " + "but is ".concat(bias.shape.length, " instead.")); } // TODO(cais): Support CAUSAL padding mode. if (dataFormat === 'channelsFirst') { x = tfc__namespace.transpose(x, [0, 2, 1]); // NCW -> NWC. } if (padding === 'causal') { throw new NotImplementedError('The support for CAUSAL padding mode in conv1dWithBias is not ' + 'implemented yet.'); } var y = tfc__namespace.conv1d(x, kernel, strides, padding === 'same' ?
'same' : 'valid', 'NWC', dilationRate); if (bias != null) { y = biasAdd(y, bias); } return y; }); } /** * 2D Convolution with an added bias and optional activation. * Note: This function does not exist in the Python Keras Backend. This function * is exactly the same as `conv2d`, except for the added `bias` and the * optional fused `activation`. */ function conv2dWithBiasActivation(x, kernel, bias, strides, padding, dataFormat, dilationRate, activation) { if (strides === void 0) { strides = [1, 1]; } if (padding === void 0) { padding = 'valid'; } if (activation === void 0) { activation = null; } return tfc.tidy(function () { if (dataFormat == null) { dataFormat = imageDataFormat(); } checkDataFormat(dataFormat); if (x.rank !== 3 && x.rank !== 4) { throw new ValueError("conv2dWithBiasActivation expects input to be of rank 3 or 4, " + "but received ".concat(x.rank, ".")); } if (kernel.rank !== 3 && kernel.rank !== 4) { throw new ValueError("conv2dWithBiasActivation expects kernel to be of rank 3 or 4, " + "but received ".concat(kernel.rank, ".")); } var y = preprocessConv2DInput(x, dataFormat); if (padding === 'causal') { throw new NotImplementedError('The support for CAUSAL padding mode in conv2dWithBiasActivation is ' + 'not implemented yet.'); } y = tfc__namespace.fused.conv2d({ x: y, filter: kernel, strides: strides, pad: padding === 'same' ? 'same' : 'valid', dilations: dilationRate, dataFormat: 'NHWC', bias: bias, activation: activation }); if (dataFormat === 'channelsFirst') { y = tfc__namespace.transpose(y, [0, 3, 1, 2]); } return y; }); } /** * 3D Convolution with an added bias. * Note: This function does not exist in the Python Keras Backend. This function * is exactly the same as `conv3d`, except for the added `bias`. */ function conv3dWithBias(x, kernel, bias, strides, padding, dataFormat, dilationRate) { if (strides === void 0) { strides = [1, 1, 1]; } if (padding === void 0) { padding = 'valid'; } return tfc.tidy(function () { if (dataFormat == null) { dataFormat = imageDataFormat(); } checkDataFormat(dataFormat); if (x.rank !== 4 && x.rank !== 5) { throw new ValueError("conv3dWithBias expects input to be of rank 4 or 5, but received " + "".concat(x.rank, ".")); } if (kernel.rank !== 4 && kernel.rank !== 5) { throw new ValueError("conv3dWithBias expects kernel to be of rank 4 or 5, but received " + "".concat(kernel.rank, ".")); } var y = preprocessConv3DInput(x, dataFormat); if (padding === 'causal') { throw new NotImplementedError('The support for CAUSAL padding mode in conv3dWithBias is not ' + 'implemented yet.'); } y = tfc__namespace.conv3d(y, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NDHWC', dilationRate); if (bias != null) { y = biasAdd(y, bias); } if (dataFormat === 'channelsFirst') { y = tfc__namespace.transpose(y, [0, 4, 1, 2, 3]); } return y; }); } /** * Abstract convolution layer. */ var BaseConv = /** @class */ (function (_super) { __extends(BaseConv, _super); function BaseConv(rank, args) { var _this = _super.call(this, args) || this; _this.bias = null; _this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal'; _this.DEFAULT_BIAS_INITIALIZER = 'zeros'; BaseConv.verifyArgs(args); _this.rank = rank; assertPositiveInteger(_this.rank, 'rank'); if (_this.rank !== 1 && _this.rank !== 2 && _this.rank !== 3) { throw new NotImplementedError("Convolution layer for rank other than 1, 2, or 3 (".concat(_this.rank, ") is ") + "not implemented yet."); } _this.kernelSize = normalizeArray(args.kernelSize, rank, 'kernelSize'); _this.strides = normalizeArray(args.strides == null ?
1 : args.strides, rank, 'strides'); _this.padding = args.padding == null ? 'valid' : args.padding; checkPaddingMode(_this.padding); _this.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat; checkDataFormat(_this.dataFormat); _this.activation = getActivation(args.activation); _this.useBias = args.useBias == null ? true : args.useBias; _this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER); _this.biasConstraint = getConstraint(args.biasConstraint); _this.biasRegularizer = getRegularizer(args.biasRegularizer); _this.activityRegularizer = getRegularizer(args.activityRegularizer); _this.dilationRate = normalizeArray(args.dilationRate == null ? 1 : args.dilationRate, rank, 'dilationRate'); if (_this.rank === 1 && (Array.isArray(_this.dilationRate) && _this.dilationRate.length !== 1)) { throw new ValueError("dilationRate must be a number or an array of a single number " + "for 1D convolution, but received " + "".concat(JSON.stringify(_this.dilationRate))); } else if (_this.rank === 2) { if (typeof _this.dilationRate === 'number') { _this.dilationRate = [_this.dilationRate, _this.dilationRate]; } else if (_this.dilationRate.length !== 2) { throw new ValueError("dilationRate must be a number or array of two numbers for 2D " + "convolution, but received ".concat(JSON.stringify(_this.dilationRate))); } } else if (_this.rank === 3) { if (typeof _this.dilationRate === 'number') { _this.dilationRate = [_this.dilationRate, _this.dilationRate, _this.dilationRate]; } else if (_this.dilationRate.length !== 3) { throw new ValueError("dilationRate must be a number or array of three numbers for 3D " + "convolution, but received ".concat(JSON.stringify(_this.dilationRate))); } } return _this; } BaseConv.verifyArgs = function (args) { // Check config.kernelSize type and shape. assert$1('kernelSize' in args, "required key 'kernelSize' not in config"); if (typeof args.kernelSize !== 'number' && !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 3)) { throw new ValueError("BaseConv expects config.kernelSize to be number or number[] with " + "length 1, 2, or 3, but received ".concat(JSON.stringify(args.kernelSize), ".")); } }; BaseConv.prototype.getConfig = function () { var config = { kernelSize: this.kernelSize, strides: this.strides, padding: this.padding, dataFormat: this.dataFormat, dilationRate: this.dilationRate, activation: serializeActivation(this.activation), useBias: this.useBias, biasInitializer: serializeInitializer(this.biasInitializer), biasRegularizer: serializeRegularizer(this.biasRegularizer), activityRegularizer: serializeRegularizer(this.activityRegularizer), biasConstraint: serializeConstraint(this.biasConstraint) }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return BaseConv; }(Layer)); /** * Abstract nD convolution layer. Ancestor of convolution layers which reduce * across channels, i.e., Conv1D and Conv2D, but not DepthwiseConv2D. 
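 *
 * Note: subclasses build a kernel of shape
 * `kernelSize.concat([inputChannels, filters])`. For example (illustrative
 * values), a Conv2D layer with kernelSize [3, 3], 16 input channels and 32
 * filters builds a kernel of shape [3, 3, 16, 32].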
*/ var Conv = /** @class */ (function (_super) { __extends(Conv, _super); function Conv(rank, args) { var _this = _super.call(this, rank, args) || this; _this.kernel = null; Conv.verifyArgs(args); _this.filters = args.filters; assertPositiveInteger(_this.filters, 'filters'); _this.kernelInitializer = getInitializer(args.kernelInitializer || _this.DEFAULT_KERNEL_INITIALIZER); _this.kernelConstraint = getConstraint(args.kernelConstraint); _this.kernelRegularizer = getRegularizer(args.kernelRegularizer); return _this; } Conv.prototype.build = function (inputShape) { var _a; inputShape = getExactlyOneShape(inputShape); var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1; if (inputShape[channelAxis] == null) { throw new ValueError("The channel dimension of the input should be defined. " + "Found ".concat(inputShape[channelAxis])); } var inputDim = inputShape[channelAxis]; var kernelShape = this.kernelSize.concat([inputDim, this.filters]); this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [this.filters], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint); } this.inputSpec = [{ ndim: this.rank + 2, axes: (_a = {}, _a[channelAxis] = inputDim, _a) }]; this.built = true; }; Conv.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = getExactlyOneTensor(inputs); var outputs; var biasValue = _this.bias == null ? null : _this.bias.read(); var fusedActivationName = mapActivationToFusedKernel(_this.activation.getClassName()); if (fusedActivationName != null && _this.rank === 2) { outputs = conv2dWithBiasActivation(inputs, _this.kernel.read(), biasValue, _this.strides, _this.padding, _this.dataFormat, _this.dilationRate, fusedActivationName); } else { if (_this.rank === 1) { outputs = conv1dWithBias(inputs, _this.kernel.read(), biasValue, _this.strides[0], _this.padding, _this.dataFormat, _this.dilationRate[0]); } else if (_this.rank === 2) { // TODO(cais): Move up to constructor. outputs = conv2dWithBiasActivation(inputs, _this.kernel.read(), biasValue, _this.strides, _this.padding, _this.dataFormat, _this.dilationRate); } else if (_this.rank === 3) { outputs = conv3dWithBias(inputs, _this.kernel.read(), biasValue, _this.strides, _this.padding, _this.dataFormat, _this.dilationRate); } else { throw new NotImplementedError('convolutions greater than 3D are not implemented yet.'); } if (_this.activation != null) { outputs = _this.activation.apply(outputs); } } return outputs; }); }; Conv.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var newSpace = []; var space = (this.dataFormat === 'channelsLast') ? inputShape.slice(1, inputShape.length - 1) : inputShape.slice(2); for (var i = 0; i < space.length; ++i) { var newDim = convOutputLength(space[i], this.kernelSize[i], this.padding, this.strides[i], typeof this.dilationRate === 'number' ? 
this.dilationRate : this.dilationRate[i]); newSpace.push(newDim); } var outputShape = [inputShape[0]]; if (this.dataFormat === 'channelsLast') { outputShape = outputShape.concat(newSpace); outputShape.push(this.filters); } else { outputShape.push(this.filters); outputShape = outputShape.concat(newSpace); } return outputShape; }; Conv.prototype.getConfig = function () { var config = { filters: this.filters, kernelInitializer: serializeInitializer(this.kernelInitializer), kernelRegularizer: serializeRegularizer(this.kernelRegularizer), kernelConstraint: serializeConstraint(this.kernelConstraint) }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; Conv.verifyArgs = function (args) { // Check config.filters type, shape, and value. if (!('filters' in args) || typeof args.filters !== 'number' || args.filters < 1) { throw new ValueError("Convolution layer expected config.filters to be a 'number' > 0 " + "but got ".concat(JSON.stringify(args.filters))); } }; return Conv; }(BaseConv)); var Conv2D = /** @class */ (function (_super) { __extends(Conv2D, _super); function Conv2D(args) { var _this = _super.call(this, 2, args) || this; Conv2D.verifyArgs(args); return _this; } Conv2D.prototype.getConfig = function () { var config = _super.prototype.getConfig.call(this); delete config['rank']; return config; }; Conv2D.verifyArgs = function (args) { // config.kernelSize must be a number or array of numbers. if ((typeof args.kernelSize !== 'number') && !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 2)) { throw new ValueError("Conv2D expects config.kernelSize to be number or number[] with " + "length 1 or 2, but received ".concat(JSON.stringify(args.kernelSize), ".")); } }; return Conv2D; }(Conv)); /** @nocollapse */ Conv2D.className = 'Conv2D'; tfc.serialization.registerClass(Conv2D); var Conv3D = /** @class */ (function (_super) { __extends(Conv3D, _super); function Conv3D(args) { var _this = _super.call(this, 3, args) || this; Conv3D.verifyArgs(args); return _this; } Conv3D.prototype.getConfig = function () { var config = _super.prototype.getConfig.call(this); delete config['rank']; return config; }; Conv3D.verifyArgs = function (args) { // config.kernelSize must be a number or array of numbers. if (typeof args.kernelSize !== 'number') { if (!(Array.isArray(args.kernelSize) && (args.kernelSize.length === 1 || args.kernelSize.length === 3))) { throw new ValueError("Conv3D expects config.kernelSize to be number or" + " [number, number, number], but received ".concat(JSON.stringify(args.kernelSize), ".")); } } }; return Conv3D; }(Conv)); /** @nocollapse */ Conv3D.className = 'Conv3D'; tfc.serialization.registerClass(Conv3D); var Conv2DTranspose = /** @class */ (function (_super) { __extends(Conv2DTranspose, _super); function Conv2DTranspose(args) { var _this = _super.call(this, args) || this; _this.inputSpec = [new InputSpec({ ndim: 4 })]; if (_this.padding !== 'same' && _this.padding !== 'valid') { throw new ValueError("Conv2DTranspose currently supports only padding modes 'same' " + "and 'valid', but received padding mode ".concat(_this.padding)); } return _this; } Conv2DTranspose.prototype.build = function (inputShape) { var _a; inputShape = getExactlyOneShape(inputShape); if (inputShape.length !== 4) { throw new ValueError('Input should have rank 4; Received input shape: ' + JSON.stringify(inputShape)); } var channelAxis = this.dataFormat === 'channelsFirst' ? 
1 : inputShape.length - 1; if (inputShape[channelAxis] == null) { throw new ValueError('The channel dimension of the inputs should be defined. ' + 'Found `None`.'); } var inputDim = inputShape[channelAxis]; var kernelShape = this.kernelSize.concat([this.filters, inputDim]); this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint); } // Set input spec. this.inputSpec = [new InputSpec({ ndim: 4, axes: (_a = {}, _a[channelAxis] = inputDim, _a) })]; this.built = true; }; Conv2DTranspose.prototype.call = function (inputs, kwargs) { var _this = this; return tfc__namespace.tidy(function () { var input = getExactlyOneTensor(inputs); if (input.shape.length !== 4) { throw new ValueError("Conv2DTranspose.call() expects input tensor to be rank-4, but " + "received a tensor of rank-".concat(input.shape.length)); } var inputShape = input.shape; var batchSize = inputShape[0]; var hAxis; var wAxis; if (_this.dataFormat === 'channelsFirst') { hAxis = 2; wAxis = 3; } else { hAxis = 1; wAxis = 2; } var height = inputShape[hAxis]; var width = inputShape[wAxis]; var kernelH = _this.kernelSize[0]; var kernelW = _this.kernelSize[1]; var strideH = _this.strides[0]; var strideW = _this.strides[1]; // Infer the dynamic output shape. var outHeight = deconvLength(height, strideH, kernelH, _this.padding); var outWidth = deconvLength(width, strideW, kernelW, _this.padding); // Porting Note: We don't branch based on `this.dataFormat` here, // because // the tfjs-core function `conv2dTranspose` called below always // assumes channelsLast. var outputShape = [batchSize, outHeight, outWidth, _this.filters]; if (_this.dataFormat !== 'channelsLast') { input = tfc__namespace.transpose(input, [0, 2, 3, 1]); } var outputs = tfc__namespace.conv2dTranspose(input, _this.kernel.read(), outputShape, _this.strides, _this.padding); if (_this.dataFormat !== 'channelsLast') { outputs = tfc__namespace.transpose(outputs, [0, 3, 1, 2]); } if (_this.bias != null) { outputs = biasAdd(outputs, _this.bias.read(), _this.dataFormat); } if (_this.activation != null) { outputs = _this.activation.apply(outputs); } return outputs; }); }; Conv2DTranspose.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var outputShape = inputShape.slice(); var channelAxis; var heightAxis; var widthAxis; if (this.dataFormat === 'channelsFirst') { channelAxis = 1; heightAxis = 2; widthAxis = 3; } else { channelAxis = 3; heightAxis = 1; widthAxis = 2; } var kernelH = this.kernelSize[0]; var kernelW = this.kernelSize[1]; var strideH = this.strides[0]; var strideW = this.strides[1]; outputShape[channelAxis] = this.filters; outputShape[heightAxis] = deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding); outputShape[widthAxis] = deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding); return outputShape; }; Conv2DTranspose.prototype.getConfig = function () { var config = _super.prototype.getConfig.call(this); delete config['dilationRate']; return config; }; return Conv2DTranspose; }(Conv2D)); /** @nocollapse */ Conv2DTranspose.className = 'Conv2DTranspose'; tfc.serialization.registerClass(Conv2DTranspose); var Conv3DTranspose = /** @class */ (function (_super) { __extends(Conv3DTranspose, _super); function Conv3DTranspose(args) { var _this = _super.call(this,
args) || this; _this.inputSpec = [new InputSpec({ ndim: 5 })]; if (_this.padding !== 'same' && _this.padding !== 'valid') { throw new ValueError("Conv3DTranspose currently supports only padding modes 'same' " + "and 'valid', but received padding mode ".concat(_this.padding)); } return _this; } Conv3DTranspose.prototype.build = function (inputShape) { var _a; inputShape = getExactlyOneShape(inputShape); if (inputShape.length !== 5) { throw new ValueError('Input should have rank 5; Received input shape: ' + JSON.stringify(inputShape)); } var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1; if (inputShape[channelAxis] == null) { throw new ValueError('The channel dimension of the inputs should be defined. ' + 'Found `None`.'); } var inputDim = inputShape[channelAxis]; var kernelShape = this.kernelSize.concat([this.filters, inputDim]); this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint); } // Set input spec. this.inputSpec = [new InputSpec({ ndim: 5, axes: (_a = {}, _a[channelAxis] = inputDim, _a) })]; this.built = true; }; Conv3DTranspose.prototype.call = function (inputs, kwargs) { var _this = this; return tfc__namespace.tidy(function () { var input = getExactlyOneTensor(inputs); if (input.shape.length !== 5) { throw new ValueError("Conv3DTranspose.call() expects input tensor to be rank-5, but " + "received a tensor of rank-".concat(input.shape.length)); } var inputShape = input.shape; var batchSize = inputShape[0]; var hAxis; var wAxis; var dAxis; if (_this.dataFormat === 'channelsFirst') { dAxis = 2; hAxis = 3; wAxis = 4; } else { dAxis = 1; hAxis = 2; wAxis = 3; } var depth = inputShape[dAxis]; var height = inputShape[hAxis]; var width = inputShape[wAxis]; var kernelD = _this.kernelSize[0]; var kernelH = _this.kernelSize[1]; var kernelW = _this.kernelSize[2]; var strideD = _this.strides[0]; var strideH = _this.strides[1]; var strideW = _this.strides[2]; // Infer the dynamic output shape. var outDepth = deconvLength(depth, strideD, kernelD, _this.padding); var outHeight = deconvLength(height, strideH, kernelH, _this.padding); var outWidth = deconvLength(width, strideW, kernelW, _this.padding); // Same as `conv2dTranspose`. We always assume channelsLast.
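// Illustrative: with 'same' padding, each inferred spatial output dim is
// dim * stride; with 'valid' padding, it is dim * stride +
// max(kernel - stride, 0). See deconvLength above.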
var outputShape = [batchSize, outDepth, outHeight, outWidth, _this.filters]; if (_this.dataFormat !== 'channelsLast') { input = tfc__namespace.transpose(input, [0, 2, 3, 4, 1]); } var outputs = tfc__namespace.conv3dTranspose(input, _this.kernel.read(), outputShape, _this.strides, _this.padding); if (_this.dataFormat !== 'channelsLast') { outputs = tfc__namespace.transpose(outputs, [0, 4, 1, 2, 3]); } if (_this.bias !== null) { outputs = biasAdd(outputs, _this.bias.read(), _this.dataFormat); } if (_this.activation !== null) { outputs = _this.activation.apply(outputs); } return outputs; }); }; Conv3DTranspose.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var outputShape = inputShape.slice(); var channelAxis; var depthAxis; var heightAxis; var widthAxis; if (this.dataFormat === 'channelsFirst') { channelAxis = 1; depthAxis = 2; heightAxis = 3; widthAxis = 4; } else { channelAxis = 4; depthAxis = 1; heightAxis = 2; widthAxis = 3; } var kernelD = this.kernelSize[0]; var kernelH = this.kernelSize[1]; var kernelW = this.kernelSize[2]; var strideD = this.strides[0]; var strideH = this.strides[1]; var strideW = this.strides[2]; outputShape[channelAxis] = this.filters; outputShape[depthAxis] = deconvLength(outputShape[depthAxis], strideD, kernelD, this.padding); outputShape[heightAxis] = deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding); outputShape[widthAxis] = deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding); return outputShape; }; Conv3DTranspose.prototype.getConfig = function () { var config = _super.prototype.getConfig.call(this); delete config['dilationRate']; return config; }; return Conv3DTranspose; }(Conv3D)); /** @nocollapse */ Conv3DTranspose.className = 'Conv3DTranspose'; tfc.serialization.registerClass(Conv3DTranspose); var SeparableConv = /** @class */ (function (_super) { __extends(SeparableConv, _super); function SeparableConv(rank, config) { var _this = _super.call(this, rank, config) || this; _this.DEFAULT_DEPTHWISE_INITIALIZER = 'glorotUniform'; _this.DEFAULT_POINTWISE_INITIALIZER = 'glorotUniform'; _this.depthwiseKernel = null; _this.pointwiseKernel = null; if (config.filters == null) { throw new ValueError('The `filters` configuration field is required by SeparableConv, ' + 'but is unspecified.'); } if (config.kernelInitializer != null || config.kernelRegularizer != null || config.kernelConstraint != null) { throw new ValueError('Fields kernelInitializer, kernelRegularizer and kernelConstraint ' + 'are invalid for SeparableConv2D. Use depthwiseInitializer, ' + 'depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, ' + 'pointwiseRegularizer and pointwiseConstraint instead.'); } if (config.padding != null && config.padding !== 'same' && config.padding !== 'valid') { throw new ValueError("SeparableConv".concat(_this.rank, "D supports only padding modes: ") + "'same' and 'valid', but received ".concat(JSON.stringify(config.padding))); } _this.depthMultiplier = config.depthMultiplier == null ? 
1 : config.depthMultiplier; _this.depthwiseInitializer = getInitializer(config.depthwiseInitializer || _this.DEFAULT_DEPTHWISE_INITIALIZER); _this.depthwiseRegularizer = getRegularizer(config.depthwiseRegularizer); _this.depthwiseConstraint = getConstraint(config.depthwiseConstraint); _this.pointwiseInitializer = getInitializer(config.pointwiseInitializer || _this.DEFAULT_POINTWISE_INITIALIZER); _this.pointwiseRegularizer = getRegularizer(config.pointwiseRegularizer); _this.pointwiseConstraint = getConstraint(config.pointwiseConstraint); return _this; } SeparableConv.prototype.build = function (inputShape) { var _a; inputShape = getExactlyOneShape(inputShape); if (inputShape.length < this.rank + 2) { throw new ValueError("Inputs to SeparableConv".concat(this.rank, "D should have rank ") + "".concat(this.rank + 2, ", but received input shape: ") + "".concat(JSON.stringify(inputShape))); } var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1; if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) { throw new ValueError("The channel dimension of the inputs should be defined, " + "but found ".concat(JSON.stringify(inputShape[channelAxis]))); } var inputDim = inputShape[channelAxis]; var depthwiseKernelShape = this.kernelSize.concat([inputDim, this.depthMultiplier]); var pointwiseKernelShape = []; for (var i = 0; i < this.rank; ++i) { pointwiseKernelShape.push(1); } pointwiseKernelShape.push(inputDim * this.depthMultiplier, this.filters); var trainable = true; this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, 'float32', this.depthwiseInitializer, this.depthwiseRegularizer, trainable, this.depthwiseConstraint); this.pointwiseKernel = this.addWeight('pointwise_kernel', pointwiseKernelShape, 'float32', this.pointwiseInitializer, this.pointwiseRegularizer, trainable, this.pointwiseConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, trainable, this.biasConstraint); } else { this.bias = null; } this.inputSpec = [new InputSpec({ ndim: this.rank + 2, axes: (_a = {}, _a[channelAxis] = inputDim, _a) })]; this.built = true; }; SeparableConv.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = getExactlyOneTensor(inputs); var output; if (_this.rank === 1) { throw new NotImplementedError('1D separable convolution is not implemented yet.'); } else if (_this.rank === 2) { if (_this.dataFormat === 'channelsFirst') { inputs = tfc__namespace.transpose(inputs, [0, 2, 3, 1]); // NCHW -> NHWC. } output = tfc__namespace.separableConv2d(inputs, _this.depthwiseKernel.read(), _this.pointwiseKernel.read(), _this.strides, _this.padding, _this.dilationRate, 'NHWC'); } if (_this.useBias) { output = biasAdd(output, _this.bias.read(), _this.dataFormat); } if (_this.activation != null) { output = _this.activation.apply(output); } if (_this.dataFormat === 'channelsFirst') { output = tfc__namespace.transpose(output, [0, 3, 1, 2]); // NHWC -> NCHW.
} return output; }); }; SeparableConv.prototype.getConfig = function () { var config = _super.prototype.getConfig.call(this); delete config['rank']; delete config['kernelInitializer']; delete config['kernelRegularizer']; delete config['kernelConstraint']; config['depthwiseInitializer'] = serializeInitializer(this.depthwiseInitializer); config['pointwiseInitializer'] = serializeInitializer(this.pointwiseInitializer); config['depthwiseRegularizer'] = serializeRegularizer(this.depthwiseRegularizer); config['pointwiseRegularizer'] = serializeRegularizer(this.pointwiseRegularizer); config['depthwiseConstraint'] = serializeConstraint(this.depthwiseConstraint); config['pointwiseConstraint'] = serializeConstraint(this.pointwiseConstraint); return config; }; return SeparableConv; }(Conv)); /** @nocollapse */ SeparableConv.className = 'SeparableConv'; var SeparableConv2D = /** @class */ (function (_super) { __extends(SeparableConv2D, _super); function SeparableConv2D(args) { return _super.call(this, 2, args) || this; } return SeparableConv2D; }(SeparableConv)); /** @nocollapse */ SeparableConv2D.className = 'SeparableConv2D'; tfc.serialization.registerClass(SeparableConv2D); var Conv1D = /** @class */ (function (_super) { __extends(Conv1D, _super); function Conv1D(args) { var _this = _super.call(this, 1, args) || this; Conv1D.verifyArgs(args); _this.inputSpec = [{ ndim: 3 }]; return _this; } Conv1D.prototype.getConfig = function () { var config = _super.prototype.getConfig.call(this); delete config['rank']; delete config['dataFormat']; return config; }; Conv1D.verifyArgs = function (args) { // config.kernelSize must be a number or array of numbers. if (typeof args.kernelSize !== 'number' && !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 1)) { throw new ValueError("Conv1D expects config.kernelSize to be number or number[] with " + "length 1, but received ".concat(JSON.stringify(args.kernelSize), ".")); } }; return Conv1D; }(Conv)); /** @nocollapse */ Conv1D.className = 'Conv1D'; tfc.serialization.registerClass(Conv1D); var Cropping2D = /** @class */ (function (_super) { __extends(Cropping2D, _super); function Cropping2D(args) { var _this = _super.call(this, args) || this; if (typeof args.cropping === 'number') { _this.cropping = [[args.cropping, args.cropping], [args.cropping, args.cropping]]; } else if (typeof args.cropping[0] === 'number') { _this.cropping = [ [args.cropping[0], args.cropping[0]], [args.cropping[1], args.cropping[1]] ]; } else { _this.cropping = args.cropping; } _this.dataFormat = args.dataFormat === undefined ? 
'channelsLast' : args.dataFormat; _this.inputSpec = [{ ndim: 4 }]; return _this; } Cropping2D.prototype.computeOutputShape = function (inputShape) { if (this.dataFormat === 'channelsFirst') { return [ inputShape[0], inputShape[1], inputShape[2] - this.cropping[0][0] - this.cropping[0][1], inputShape[3] - this.cropping[1][0] - this.cropping[1][1] ]; } else { return [ inputShape[0], inputShape[1] - this.cropping[0][0] - this.cropping[0][1], inputShape[2] - this.cropping[1][0] - this.cropping[1][1], inputShape[3] ]; } }; Cropping2D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = getExactlyOneTensor(inputs); if (_this.dataFormat === 'channelsLast') { var hSliced = sliceAlongAxis(inputs, _this.cropping[0][0], inputs.shape[1] - _this.cropping[0][0] - _this.cropping[0][1], 2); return sliceAlongAxis(hSliced, _this.cropping[1][0], inputs.shape[2] - _this.cropping[1][1] - _this.cropping[1][0], 3); } else { var hSliced = sliceAlongAxis(inputs, _this.cropping[0][0], inputs.shape[2] - _this.cropping[0][0] - _this.cropping[0][1], 3); return sliceAlongAxis(hSliced, _this.cropping[1][0], inputs.shape[3] - _this.cropping[1][1] - _this.cropping[1][0], 4); } }); }; Cropping2D.prototype.getConfig = function () { var config = { cropping: this.cropping, dataFormat: this.dataFormat }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Cropping2D; }(Layer)); /** @nocollapse */ Cropping2D.className = 'Cropping2D'; tfc.serialization.registerClass(Cropping2D); var UpSampling2D = /** @class */ (function (_super) { __extends(UpSampling2D, _super); function UpSampling2D(args) { var _this = _super.call(this, args) || this; _this.DEFAULT_SIZE = [2, 2]; _this.inputSpec = [{ ndim: 4 }]; _this.size = args.size == null ? _this.DEFAULT_SIZE : args.size; _this.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat; checkDataFormat(_this.dataFormat); _this.interpolation = args.interpolation == null ? 'nearest' : args.interpolation; checkInterpolationFormat(_this.interpolation); return _this; } UpSampling2D.prototype.computeOutputShape = function (inputShape) { if (this.dataFormat === 'channelsFirst') { var height = inputShape[2] == null ? null : this.size[0] * inputShape[2]; var width = inputShape[3] == null ? null : this.size[1] * inputShape[3]; return [inputShape[0], inputShape[1], height, width]; } else { var height = inputShape[1] == null ? null : this.size[0] * inputShape[1]; var width = inputShape[2] == null ? null : this.size[1] * inputShape[2]; return [inputShape[0], height, width, inputShape[3]]; } }; UpSampling2D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc__namespace.tidy(function () { var input = getExactlyOneTensor(inputs); var inputShape = input.shape; if (_this.dataFormat === 'channelsFirst') { input = tfc__namespace.transpose(input, [0, 2, 3, 1]); var height = _this.size[0] * inputShape[2]; var width = _this.size[1] * inputShape[3]; var resized = _this.interpolation === 'nearest' ? tfc__namespace.image.resizeNearestNeighbor(input, [height, width]) : tfc__namespace.image.resizeBilinear(input, [height, width]); return tfc__namespace.transpose(resized, [0, 3, 1, 2]); } else { var height = _this.size[0] * inputShape[1]; var width = _this.size[1] * inputShape[2]; return _this.interpolation === 'nearest' ? 
tfc__namespace.image.resizeNearestNeighbor(input, [height, width]) : tfc__namespace.image.resizeBilinear(input, [height, width]); } }); }; UpSampling2D.prototype.getConfig = function () { var config = { size: this.size, dataFormat: this.dataFormat, interpolation: this.interpolation }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return UpSampling2D; }(Layer)); /** @nocollapse */ UpSampling2D.className = 'UpSampling2D'; tfc.serialization.registerClass(UpSampling2D); /** * 2D depthwise convolution (separable filters applied per input channel). * @param x Input tensor. * @param depthwiseKernel Convolution kernel for depthwise convolution. * @param strides Strides (Array of two integers). * @param padding Padding mode. * @param dataFormat Data format. * @param dilationRate Array of two integers, dilation rates for the depthwise * convolution. * @returns Output tensor. * @throws ValueError If depthwiseKernel is not a 4D array. */ function depthwiseConv2d$1(x, depthwiseKernel, strides, padding, dataFormat, dilationRate) { if (strides === void 0) { strides = [1, 1]; } if (padding === void 0) { padding = 'valid'; } return tfc.tidy(function () { if (dataFormat == null) { dataFormat = imageDataFormat(); } checkDataFormat(dataFormat); var y = preprocessConv2DInput(x, dataFormat); if (x.rank !== 4) { throw new ValueError("Input for depthwiseConv2d is required to be 4-D, but is instead " + "".concat(x.rank, "-D")); } if (depthwiseKernel.rank !== 4) { throw new ValueError("depthwiseKernel is required to be 4-D, but is instead " + "".concat(depthwiseKernel.rank, "-D")); } y = tfc__namespace.depthwiseConv2d(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate); if (dataFormat === 'channelsFirst') { y = tfc__namespace.transpose(y, [0, 3, 1, 2]); } return y; }); } var DepthwiseConv2D = /** @class */ (function (_super) { __extends(DepthwiseConv2D, _super); function DepthwiseConv2D(args) { var _this = _super.call(this, 2, args) || this; _this.depthwiseKernel = null; _this.depthMultiplier = args.depthMultiplier == null ? 1 : args.depthMultiplier; _this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || _this.DEFAULT_KERNEL_INITIALIZER); _this.depthwiseConstraint = getConstraint(args.depthwiseConstraint); _this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer); return _this; } DepthwiseConv2D.prototype.build = function (inputShape) { inputShape = getExactlyOneShape(inputShape); if (inputShape.length < 4) { throw new ValueError("Inputs to DepthwiseConv2D should have rank 4. " + "Received input shape: ".concat(JSON.stringify(inputShape), ".")); } var channelAxis = this.dataFormat === 'channelsFirst' ?
1 : 3; if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) { throw new ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' + "be defined, but is not (".concat(inputShape[channelAxis], ").")); } var inputDim = inputShape[channelAxis]; var depthwiseKernelShape = [ this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier ]; this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint); } else { this.bias = null; } this.built = true; }; DepthwiseConv2D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = getExactlyOneTensor(inputs); var outputs = depthwiseConv2d$1(inputs, _this.depthwiseKernel.read(), _this.strides, _this.padding, _this.dataFormat, null); // TODO(cais): Add support for dilation. if (_this.useBias) { outputs = biasAdd(outputs, _this.bias.read(), _this.dataFormat); } if (_this.activation != null) { outputs = _this.activation.apply(outputs); } return outputs; }); }; DepthwiseConv2D.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1]; var cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2]; var outFilters = this.dataFormat === 'channelsFirst' ? inputShape[1] * this.depthMultiplier : inputShape[3] * this.depthMultiplier; var outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]); var outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]); if (this.dataFormat === 'channelsFirst') { return [inputShape[0], outFilters, outRows, outCols]; } else { // In this case, assume 'channelsLast'. return [inputShape[0], outRows, outCols, outFilters]; } }; DepthwiseConv2D.prototype.getConfig = function () { var config = _super.prototype.getConfig.call(this); config['depthMultiplier'] = this.depthMultiplier; config['depthwiseInitializer'] = serializeInitializer(this.depthwiseInitializer); config['depthwiseRegularizer'] = serializeRegularizer(this.depthwiseRegularizer); config['depthwiseConstraint'] = serializeConstraint(this.depthwiseConstraint); return config; }; return DepthwiseConv2D; }(BaseConv)); /** @nocollapse */ DepthwiseConv2D.className = 'DepthwiseConv2D'; tfc.serialization.registerClass(DepthwiseConv2D); /** * Standardize `apply()` args to a single list of tensor inputs. * * When running a model loaded from file, the input tensors `initialState` and * `constants` are passed to `RNN.apply()` as part of `inputs` instead of the * dedicated kwargs fields. `inputs` consists of * `[inputs, initialState0, initialState1, ..., constant0, constant1]` in this * case. * This method makes sure that arguments are * separated and that `initialState` and `constants` are `Array`s of tensors * (or `null`). * * @param inputs Tensor or `Array` of tensors. * @param initialState Tensor or `Array` of tensors or `null`/`undefined`. * @param constants Tensor or `Array` of tensors or `null`/`undefined`. * @returns An object consisting of * inputs: A tensor. * initialState: `Array` of tensors or `null`. * constants: `Array` of tensors or `null`.
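 *
 * Example (illustrative): `standardizeArgs([x, s1, s2, c1], null, null, 1)`
 * yields `{inputs: x, initialState: [s1, s2], constants: [c1]}`: the last
 * `numConstants` elements are split off as constants, and the remaining
 * trailing elements become the initial states.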
* @throws ValueError, if `inputs` is an `Array` but either `initialState` or * `constants` is provided. */ function standardizeArgs(inputs, initialState, constants, numConstants) { if (Array.isArray(inputs)) { if (initialState != null || constants != null) { throw new ValueError('When inputs is an array, neither initialState nor constants ' + 'should be provided'); } if (numConstants != null) { constants = inputs.slice(inputs.length - numConstants, inputs.length); inputs = inputs.slice(0, inputs.length - numConstants); } if (inputs.length > 1) { initialState = inputs.slice(1, inputs.length); } inputs = inputs[0]; } function toListOrNull(x) { if (x == null || Array.isArray(x)) { return x; } else { return [x]; } } initialState = toListOrNull(initialState); constants = toListOrNull(constants); return { inputs: inputs, initialState: initialState, constants: constants }; } /** * Iterates over the time dimension of a tensor. * * @param stepFunction RNN step function. * Parameters: * inputs: tensor with shape `[samples, ...]` (no time dimension), * representing input for the batch of samples at a certain time step. * states: an Array of tensors. * Returns: * outputs: tensor with shape `[samples, outputDim]` (no time dimension). * newStates: list of tensors, same length and shapes as `states`. The first * state in the list must be the output tensor at the previous timestep. * @param inputs Tensor of temporal data of shape `[samples, time, ...]` (at * least 3D). * @param initialStates Tensor with shape `[samples, outputDim]` (no time * dimension), containing the initial values of the states used in the step * function. * @param goBackwards If `true`, do the iteration over the time dimension in * reverse order and return the reversed sequence. * @param mask Binary tensor with shape `[samples, time, 1]`, with a zero for * every element that is masked. * @param constants An Array of constant values passed at each step. * @param unroll Whether to unroll the RNN or to use a symbolic loop. *Not* * applicable to this imperative deeplearn.js backend. Its value is ignored. * @param needPerStepOutputs Whether the per-step outputs are to be * concatenated into a single tensor and returned (as the second return * value). Default: `false`. This arg is included so that the relatively * expensive concatenation of the stepwise outputs can be omitted unless * the stepwise outputs need to be kept (e.g., for an LSTM layer of which * `returnSequences` is `true`.) * @returns An Array: `[lastOutput, outputs, newStates]`. * lastOutput: the latest output of the RNN, of shape `[samples, ...]`. * outputs: tensor with shape `[samples, time, ...]` where each entry * `output[s, t]` is the output of the step function at time `t` for sample * `s`. This return value is provided if and only if `needPerStepOutputs` * is set to `true`. If it is set to `false`, this return value will be * `undefined`. * newStates: Array of tensors, latest states returned by the step function, * of shape `(samples, ...)`. * @throws ValueError If input dimension is less than 3. * * TODO(nielsene): This needs to be tidy-ed.
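 *
 * Illustrative step function (hypothetical; assumes a single state whose
 * shape matches the per-step output):
 *
 *   function step(cellInputs, states) {
 *     var output = tfc.tanh(tfc.add(cellInputs, states[0]));
 *     return [output, [output]];
 *   }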
*/ function rnn$1(stepFunction, inputs, initialStates, goBackwards, mask, constants, unroll, needPerStepOutputs) { if (goBackwards === void 0) { goBackwards = false; } if (unroll === void 0) { unroll = false; } if (needPerStepOutputs === void 0) { needPerStepOutputs = false; } return tfc__namespace.tidy(function () { var ndim = inputs.shape.length; if (ndim < 3) { throw new ValueError("Input should be at least 3D, but is ".concat(ndim, "D.")); } // Transpose to time-major, i.e., from [batch, time, ...] to [time, batch, // ...]. var axes = [1, 0].concat(range(2, ndim)); inputs = tfc__namespace.transpose(inputs, axes); if (constants != null) { throw new NotImplementedError('The rnn() function of the deeplearn.js backend does not support ' + 'constants yet.'); } // Porting Note: the unroll option is ignored by the imperative backend. if (unroll) { console.warn('Backend rnn(): the unroll = true option is not applicable to the ' + 'imperative deeplearn.js backend.'); } if (mask != null) { mask = tfc__namespace.cast(tfc__namespace.cast(mask, 'bool'), 'float32'); if (mask.rank === ndim - 1) { mask = tfc__namespace.expandDims(mask, -1); } mask = tfc__namespace.transpose(mask, axes); } if (goBackwards) { inputs = tfc__namespace.reverse(inputs, 0); if (mask != null) { mask = tfc__namespace.reverse(mask, 0); } } // Porting Note: PyKeras with TensorFlow backend uses a symbolic loop // (tf.while_loop). But for the imperative deeplearn.js backend, we just // use the usual TypeScript control flow to iterate over the time steps in // the inputs. // Porting Note: PyKeras patches a "_use_learning_phase" attribute to // outputs. // This is not idiomatic in TypeScript. The info regarding whether we are // in a learning (i.e., training) phase for RNN is passed in a different // way. var perStepOutputs = []; var lastOutput; var states = initialStates; var timeSteps = inputs.shape[0]; var perStepInputs = tfc__namespace.unstack(inputs); var perStepMasks; if (mask != null) { perStepMasks = tfc__namespace.unstack(mask); } var _loop_1 = function (t) { var currentInput = perStepInputs[t]; var stepOutputs = tfc__namespace.tidy(function () { return stepFunction(currentInput, states); }); if (mask == null) { lastOutput = stepOutputs[0]; states = stepOutputs[1]; } else { var maskedOutputs = tfc__namespace.tidy(function () { var stepMask = perStepMasks[t]; var negStepMask = tfc__namespace.sub(tfc__namespace.onesLike(stepMask), stepMask); // TODO(cais): Would tfc.where() be better for performance?
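// The masked update below computes a convex combination:
// output[t] = stepOutput * mask[t] + previousOutput * (1 - mask[t]),
// so fully-masked samples carry their previous output and states forward
// unchanged.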
var output = tfc__namespace.add(tfc__namespace.mul(stepOutputs[0], stepMask), tfc__namespace.mul(states[0], negStepMask)); var newStates = states.map(function (state, i) { return tfc__namespace.add(tfc__namespace.mul(stepOutputs[1][i], stepMask), tfc__namespace.mul(state, negStepMask)); }); return { output: output, newStates: newStates }; }); lastOutput = maskedOutputs.output; states = maskedOutputs.newStates; } if (needPerStepOutputs) { perStepOutputs.push(lastOutput); } }; for (var t = 0; t < timeSteps; ++t) { _loop_1(t); } var outputs; if (needPerStepOutputs) { var axis = 1; outputs = tfc__namespace.stack(perStepOutputs, axis); } return [lastOutput, outputs, states]; }); } var RNN = /** @class */ (function (_super) { __extends(RNN, _super); function RNN(args) { var _this = _super.call(this, args) || this; var cell; if (args.cell == null) { throw new ValueError('cell property is missing for the constructor of RNN.'); } else if (Array.isArray(args.cell)) { cell = new StackedRNNCells({ cells: args.cell }); } else { cell = args.cell; } if (cell.stateSize == null) { throw new ValueError('The RNN cell should have an attribute `stateSize` (tuple of ' + 'integers, one integer per RNN state).'); } _this.cell = cell; _this.returnSequences = args.returnSequences == null ? false : args.returnSequences; _this.returnState = args.returnState == null ? false : args.returnState; _this.goBackwards = args.goBackwards == null ? false : args.goBackwards; _this._stateful = args.stateful == null ? false : args.stateful; _this.unroll = args.unroll == null ? false : args.unroll; _this.supportsMasking = true; _this.inputSpec = [new InputSpec({ ndim: 3 })]; _this.stateSpec = null; _this.states_ = null; // TODO(cais): Add constantsSpec and numConstants. _this.numConstants = null; // TODO(cais): Look into the use of initial_state in the kwargs of the // constructor. _this.keptStates = []; return _this; } // Porting Note: This is the equivalent of `RNN.states` property getter in // PyKeras. RNN.prototype.getStates = function () { if (this.states_ == null) { var numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1; return range(0, numStates).map(function (x) { return null; }); } else { return this.states_; } }; // Porting Note: This is the equivalent of the `RNN.states` property setter in // PyKeras. RNN.prototype.setStates = function (states) { this.states_ = states; }; RNN.prototype.computeOutputShape = function (inputShape) { var e_1, _b; if (isArrayOfShapes(inputShape)) { inputShape = inputShape[0]; } inputShape = inputShape; // TODO(cais): Remove the casting once stacked RNN cells become supported. 
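// Illustrative shapes: for inputShape [batch, time, inputDim] and output
// dimension stateSize[0], returnSequences = true yields
// [batch, time, outputDim], otherwise [batch, outputDim]; returnState
// appends one [batch, dim] shape per state.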
var stateSize = this.cell.stateSize; if (!Array.isArray(stateSize)) { stateSize = [stateSize]; } var outputDim = stateSize[0]; var outputShape; if (this.returnSequences) { outputShape = [inputShape[0], inputShape[1], outputDim]; } else { outputShape = [inputShape[0], outputDim]; } if (this.returnState) { var stateShape = []; try { for (var stateSize_1 = __values(stateSize), stateSize_1_1 = stateSize_1.next(); !stateSize_1_1.done; stateSize_1_1 = stateSize_1.next()) { var dim = stateSize_1_1.value; stateShape.push([inputShape[0], dim]); } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (stateSize_1_1 && !stateSize_1_1.done && (_b = stateSize_1.return)) _b.call(stateSize_1); } finally { if (e_1) throw e_1.error; } } return [outputShape].concat(stateShape); } else { return outputShape; } }; RNN.prototype.computeMask = function (inputs, mask) { var _this = this; return tfc__namespace.tidy(function () { if (Array.isArray(mask)) { mask = mask[0]; } var outputMask = _this.returnSequences ? mask : null; if (_this.returnState) { var stateMask = _this.states.map(function (s) { return null; }); return [outputMask].concat(stateMask); } else { return outputMask; } }); }; Object.defineProperty(RNN.prototype, "states", { /** * Get the current state tensors of the RNN. * * If the state hasn't been set, return an array of `null`s of the correct * length. */ get: function () { if (this.states_ == null) { var numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1; var output = []; for (var i = 0; i < numStates; ++i) { output.push(null); } return output; } else { return this.states_; } }, set: function (s) { this.states_ = s; }, enumerable: false, configurable: true }); RNN.prototype.build = function (inputShape) { if (this.numConstants != null) { throw new NotImplementedError('Constants support is not implemented in RNN yet.'); } if (isArrayOfShapes(inputShape)) { inputShape = inputShape[0]; } inputShape = inputShape; var batchSize = this.stateful ? inputShape[0] : null; var inputDim = inputShape.slice(2); this.inputSpec[0] = new InputSpec({ shape: __spreadArray([batchSize, null], __read(inputDim), false) }); // Allow cell (if RNNCell Layer) to build before we set or validate // stateSpec. var stepInputShape = [inputShape[0]].concat(inputShape.slice(2)); { this.cell.build(stepInputShape); } // Set or validate stateSpec. var stateSize; if (Array.isArray(this.cell.stateSize)) { stateSize = this.cell.stateSize; } else { stateSize = [this.cell.stateSize]; } if (this.stateSpec != null) { if (!tfc.util.arraysEqual(this.stateSpec.map(function (spec) { return spec.shape[spec.shape.length - 1]; }), stateSize)) { throw new ValueError("An initialState was passed that is not compatible with " + "cell.stateSize. Received stateSpec=".concat(this.stateSpec, "; ") + "However cell.stateSize is ".concat(this.cell.stateSize)); } } else { this.stateSpec = stateSize.map(function (dim) { return new InputSpec({ shape: [null, dim] }); }); } if (this.stateful) { this.resetStates(); } }; /** * Reset the state tensors of the RNN. * * If the `states` argument is `undefined` or `null`, will set the * state tensor(s) of the RNN to all-zero tensors of the appropriate * shape(s). * * If `states` is provided, will set the state tensors of the RNN to its * value. * * @param states Optional externally-provided initial states. * @param training Whether this call is done during training. For stateful * RNNs, this affects whether the old states are kept or discarded. 
In * particular, if `training` is `true`, the old states will be kept so * that subsequent backpropagation through time (BPTT) may work properly. * Else, the old states will be discarded. */ RNN.prototype.resetStates = function (states, training) { var _this = this; if (training === void 0) { training = false; } tfc.tidy(function () { if (!_this.stateful) { throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.'); } var batchSize = _this.inputSpec[0].shape[0]; if (batchSize == null) { throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' + 'the batch size of your input tensors: \n' + '- If using a Sequential model, specify the batch size by ' + 'passing a `batchInputShape` option to your first layer.\n' + '- If using the functional API, specify the batch size by ' + 'passing a `batchShape` option to your Input layer.'); } // Initialize state if null. if (_this.states_ == null) { if (Array.isArray(_this.cell.stateSize)) { _this.states_ = _this.cell.stateSize.map(function (dim) { return tfc__namespace.zeros([batchSize, dim]); }); } else { _this.states_ = [tfc__namespace.zeros([batchSize, _this.cell.stateSize])]; } } else if (states == null) { // Dispose old state tensors. tfc__namespace.dispose(_this.states_); // For stateful RNNs, fully dispose kept old states. if (_this.keptStates != null) { tfc__namespace.dispose(_this.keptStates); _this.keptStates = []; } if (Array.isArray(_this.cell.stateSize)) { _this.states_ = _this.cell.stateSize.map(function (dim) { return tfc__namespace.zeros([batchSize, dim]); }); } else { _this.states_[0] = tfc__namespace.zeros([batchSize, _this.cell.stateSize]); } } else { if (!Array.isArray(states)) { states = [states]; } if (states.length !== _this.states_.length) { throw new ValueError("Layer ".concat(_this.name, " expects ").concat(_this.states_.length, " state(s), ") + "but it received ".concat(states.length, " state value(s). Input ") + "received: ".concat(states)); } if (training === true) { // Store old state tensors for complete disposal later, i.e., during // the next no-arg call to this method. We do not dispose the old // states immediately because BPTT (among other things) requires // them. _this.keptStates.push(_this.states_.slice()); } else { tfc__namespace.dispose(_this.states_); } for (var index = 0; index < _this.states_.length; ++index) { var value = states[index]; var dim = Array.isArray(_this.cell.stateSize) ? _this.cell.stateSize[index] : _this.cell.stateSize; var expectedShape = [batchSize, dim]; if (!tfc.util.arraysEqual(value.shape, expectedShape)) { throw new ValueError("State ".concat(index, " is incompatible with layer ").concat(_this.name, ": ") + "expected shape=".concat(expectedShape, ", received shape=").concat(value.shape)); } _this.states_[index] = value; } } _this.states_ = _this.states_.map(function (state) { return tfc__namespace.keep(state.clone()); }); }); }; RNN.prototype.apply = function (inputs, kwargs) { var e_2, _b; // TODO(cais): Figure out whether initialState is in kwargs or inputs. var initialState = kwargs == null ? null : kwargs['initialState']; var constants = kwargs == null ?
null : kwargs['constants']; if (kwargs == null) { kwargs = {}; } var standardized = standardizeArgs(inputs, initialState, constants, this.numConstants); inputs = standardized.inputs; initialState = standardized.initialState; constants = standardized.constants; // If any of `initial_state` or `constants` are specified and are // `tf.SymbolicTensor`s, then add them to the inputs and temporarily modify // the input_spec to include them. var additionalInputs = []; var additionalSpecs = []; if (initialState != null) { kwargs['initialState'] = initialState; additionalInputs = additionalInputs.concat(initialState); this.stateSpec = []; try { for (var initialState_1 = __values(initialState), initialState_1_1 = initialState_1.next(); !initialState_1_1.done; initialState_1_1 = initialState_1.next()) { var state = initialState_1_1.value; this.stateSpec.push(new InputSpec({ shape: state.shape })); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (initialState_1_1 && !initialState_1_1.done && (_b = initialState_1.return)) _b.call(initialState_1); } finally { if (e_2) throw e_2.error; } } // TODO(cais): Use the following instead. // this.stateSpec = initialState.map(state => new InputSpec({shape: // state.shape})); additionalSpecs = additionalSpecs.concat(this.stateSpec); } if (constants != null) { kwargs['constants'] = constants; additionalInputs = additionalInputs.concat(constants); // TODO(cais): Add this.constantsSpec. this.numConstants = constants.length; } var isTensor = additionalInputs[0] instanceof SymbolicTensor; if (isTensor) { // Compute full input spec, including state and constants. var fullInput = [inputs].concat(additionalInputs); var fullInputSpec = this.inputSpec.concat(additionalSpecs); // Perform the call with temporarily replaced inputSpec. var originalInputSpec = this.inputSpec; this.inputSpec = fullInputSpec; var output = _super.prototype.apply.call(this, fullInput, kwargs); this.inputSpec = originalInputSpec; return output; } else { return _super.prototype.apply.call(this, inputs, kwargs); } }; // tslint:disable-next-line:no-any RNN.prototype.call = function (inputs, kwargs) { var _this = this; // Input shape: `[samples, time (padded with zeros), input_dim]`. // Note that the .build() method of subclasses **must** define // this.inputSpec and this.stateSpec with complete input shapes. return tfc.tidy(function () { var mask = kwargs == null ? null : kwargs['mask']; var training = kwargs == null ? null : kwargs['training']; var initialState = kwargs == null ? null : kwargs['initialState']; inputs = getExactlyOneTensor(inputs); if (initialState == null) { if (_this.stateful) { initialState = _this.states_; } else { initialState = _this.getInitialState(inputs); } } var numStates = Array.isArray(_this.cell.stateSize) ? _this.cell.stateSize.length : 1; if (initialState.length !== numStates) { throw new ValueError("RNN Layer has ".concat(numStates, " state(s) but was passed ") + "".concat(initialState.length, " initial state(s).")); } if (_this.unroll) { console.warn('Ignoring unroll = true for RNN layer, due to imperative backend.'); } var cellCallKwargs = { training: training }; // TODO(cais): Add support for constants. var step = function (inputs, states) { // `inputs` and `states` are concatenated to form a single `Array` of // `tf.Tensor`s as the input to `cell.call()`. var outputs = _this.cell.call([inputs].concat(states), cellCallKwargs); // Marshall the return value into output and new states.
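// E.g., a hypothetical cell with two states returns
// [output, newState0, newState1]; the slice below separates the output
// from the new states.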
return [outputs[0], outputs.slice(1)]; }; // TODO(cais): Add support for constants. var rnnOutputs = rnn$1(step, inputs, initialState, _this.goBackwards, mask, null, _this.unroll, _this.returnSequences); var lastOutput = rnnOutputs[0]; var outputs = rnnOutputs[1]; var states = rnnOutputs[2]; if (_this.stateful) { _this.resetStates(states, training); } var output = _this.returnSequences ? outputs : lastOutput; // TODO(cais): Properly set learning phase flag. if (_this.returnState) { return [output].concat(states); } else { return output; } }); }; RNN.prototype.getInitialState = function (inputs) { var _this = this; return tfc.tidy(function () { // Build an all-zero tensor of shape [samples, outputDim]. // [Samples, timeSteps, inputDim]. var initialState = tfc__namespace.zeros(inputs.shape); // [Samples]. initialState = tfc__namespace.sum(initialState, [1, 2]); initialState = expandDims$1(initialState); // [Samples, 1]. if (Array.isArray(_this.cell.stateSize)) { return _this.cell.stateSize.map(function (dim) { return dim > 1 ? tile$1(initialState, [1, dim]) : initialState; }); } else { return _this.cell.stateSize > 1 ? [tile$1(initialState, [1, _this.cell.stateSize])] : [initialState]; } }); }; Object.defineProperty(RNN.prototype, "trainableWeights", { get: function () { if (!this.trainable) { return []; } // Porting Note: In TypeScript, `this` is always an instance of `Layer`. return this.cell.trainableWeights; }, enumerable: false, configurable: true }); Object.defineProperty(RNN.prototype, "nonTrainableWeights", { get: function () { // Porting Note: In TypeScript, `this` is always an instance of `Layer`. if (!this.trainable) { return this.cell.weights; } return this.cell.nonTrainableWeights; }, enumerable: false, configurable: true }); RNN.prototype.setFastWeightInitDuringBuild = function (value) { _super.prototype.setFastWeightInitDuringBuild.call(this, value); if (this.cell != null) { this.cell.setFastWeightInitDuringBuild(value); } }; RNN.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var config = { returnSequences: this.returnSequences, returnState: this.returnState, goBackwards: this.goBackwards, stateful: this.stateful, unroll: this.unroll, }; if (this.numConstants != null) { config['numConstants'] = this.numConstants; } var cellConfig = this.cell.getConfig(); if (this.getClassName() === RNN.className) { config['cell'] = { 'className': this.cell.getClassName(), 'config': cellConfig, }; } // This order is necessary, to prevent the cell name from replacing the layer name. return Object.assign(Object.assign(Object.assign({}, cellConfig), baseConfig), config); }; /** @nocollapse */ RNN.fromConfig = function (cls, config, customObjects) { if (customObjects === void 0) { customObjects = {}; } var cellConfig = config['cell']; var cell = deserialize(cellConfig, customObjects); return new cls(Object.assign(config, { cell: cell })); }; return RNN; }(Layer)); /** @nocollapse */ RNN.className = 'RNN'; tfc.serialization.registerClass(RNN); // Porting Note: This is a common parent class for RNN cells. There is no // equivalent of this in PyKeras. Having a common parent class forgoes the // need for `has_attr(cell, ...)` checks or its TypeScript equivalent. /** * An RNNCell layer.
 * * @doc {heading: 'Layers', subheading: 'Classes'} */ var RNNCell = /** @class */ (function (_super) { __extends(RNNCell, _super); function RNNCell() { return _super !== null && _super.apply(this, arguments) || this; } return RNNCell; }(Layer)); var SimpleRNNCell = /** @class */ (function (_super) { __extends(SimpleRNNCell, _super); function SimpleRNNCell(args) { var _this = _super.call(this, args) || this; _this.DEFAULT_ACTIVATION = 'tanh'; _this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal'; _this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal'; _this.DEFAULT_BIAS_INITIALIZER = 'zeros'; _this.units = args.units; assertPositiveInteger(_this.units, 'units'); _this.activation = getActivation(args.activation == null ? _this.DEFAULT_ACTIVATION : args.activation); _this.useBias = args.useBias == null ? true : args.useBias; _this.kernelInitializer = getInitializer(args.kernelInitializer || _this.DEFAULT_KERNEL_INITIALIZER); _this.recurrentInitializer = getInitializer(args.recurrentInitializer || _this.DEFAULT_RECURRENT_INITIALIZER); _this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER); _this.kernelRegularizer = getRegularizer(args.kernelRegularizer); _this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer); _this.biasRegularizer = getRegularizer(args.biasRegularizer); _this.kernelConstraint = getConstraint(args.kernelConstraint); _this.recurrentConstraint = getConstraint(args.recurrentConstraint); _this.biasConstraint = getConstraint(args.biasConstraint); _this.dropout = min([1, max([0, args.dropout == null ? 0 : args.dropout])]); _this.recurrentDropout = min([ 1, max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout]) ]); _this.dropoutFunc = args.dropoutFunc; _this.stateSize = _this.units; _this.dropoutMask = null; _this.recurrentDropoutMask = null; return _this; } SimpleRNNCell.prototype.build = function (inputShape) { inputShape = getExactlyOneShape(inputShape); // TODO(cais): Use regularizer. this.kernel = this.addWeight('kernel', [inputShape[inputShape.length - 1], this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint); } else { this.bias = null; } this.built = true; }; // Porting Note: PyKeras' equivalent of this method takes two tensor inputs: // `inputs` and `states`. Here, the two tensors are combined into a // `Tensor[]` Array as the first input argument. // Similarly, PyKeras' equivalent of this method returns two values: // `output` and `[output]`. Here the two are combined into one length-2 // `Tensor[]`, consisting of `output` repeated. SimpleRNNCell.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = inputs; if (inputs.length !== 2) { throw new ValueError("SimpleRNNCell expects 2 input Tensors, got ".concat(inputs.length, ".")); } var prevOutput = inputs[1]; inputs = inputs[0]; var training = kwargs['training'] == null ?
false : kwargs['training']; if (0 < _this.dropout && _this.dropout < 1 && _this.dropoutMask == null) { _this.dropoutMask = generateDropoutMask({ ones: function () { return tfc__namespace.onesLike(inputs); }, rate: _this.dropout, training: training, dropoutFunc: _this.dropoutFunc, }); } if (0 < _this.recurrentDropout && _this.recurrentDropout < 1 && _this.recurrentDropoutMask == null) { _this.recurrentDropoutMask = generateDropoutMask({ ones: function () { return tfc__namespace.onesLike(prevOutput); }, rate: _this.recurrentDropout, training: training, dropoutFunc: _this.dropoutFunc, }); } var h; var dpMask = _this.dropoutMask; var recDpMask = _this.recurrentDropoutMask; if (dpMask != null) { h = dot$1(tfc__namespace.mul(inputs, dpMask), _this.kernel.read()); } else { h = dot$1(inputs, _this.kernel.read()); } if (_this.bias != null) { h = biasAdd(h, _this.bias.read()); } if (recDpMask != null) { prevOutput = tfc__namespace.mul(prevOutput, recDpMask); } var output = tfc__namespace.add(h, dot$1(prevOutput, _this.recurrentKernel.read())); if (_this.activation != null) { output = _this.activation.apply(output); } // TODO(cais): Properly set learning phase on output tensor? return [output, output]; }); }; SimpleRNNCell.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var config = { units: this.units, activation: serializeActivation(this.activation), useBias: this.useBias, kernelInitializer: serializeInitializer(this.kernelInitializer), recurrentInitializer: serializeInitializer(this.recurrentInitializer), biasInitializer: serializeInitializer(this.biasInitializer), kernelRegularizer: serializeRegularizer(this.kernelRegularizer), recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer), biasRegularizer: serializeRegularizer(this.biasRegularizer), activityRegularizer: serializeRegularizer(this.activityRegularizer), kernelConstraint: serializeConstraint(this.kernelConstraint), recurrentConstraint: serializeConstraint(this.recurrentConstraint), biasConstraint: serializeConstraint(this.biasConstraint), dropout: this.dropout, recurrentDropout: this.recurrentDropout, }; return Object.assign(Object.assign({}, baseConfig), config); }; return SimpleRNNCell; }(RNNCell)); /** @nocollapse */ SimpleRNNCell.className = 'SimpleRNNCell'; tfc.serialization.registerClass(SimpleRNNCell); var SimpleRNN = /** @class */ (function (_super) { __extends(SimpleRNN, _super); function SimpleRNN(args) { args.cell = new SimpleRNNCell(args); return _super.call(this, args) || this; // TODO(cais): Add activityRegularizer. } SimpleRNN.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { if (_this.cell.dropoutMask != null) { tfc__namespace.dispose(_this.cell.dropoutMask); _this.cell.dropoutMask = null; } if (_this.cell.recurrentDropoutMask != null) { tfc__namespace.dispose(_this.cell.recurrentDropoutMask); _this.cell.recurrentDropoutMask = null; } var mask = kwargs == null ? null : kwargs['mask']; var training = kwargs == null ? null : kwargs['training']; var initialState = kwargs == null ? 
null : kwargs['initialState']; return _super.prototype.call.call(_this, inputs, { mask: mask, training: training, initialState: initialState }); }); }; /** @nocollapse */ SimpleRNN.fromConfig = function (cls, config) { return new cls(config); }; return SimpleRNN; }(RNN)); /** @nocollapse */ SimpleRNN.className = 'SimpleRNN'; tfc.serialization.registerClass(SimpleRNN); var GRUCell = /** @class */ (function (_super) { __extends(GRUCell, _super); function GRUCell(args) { var _this = _super.call(this, args) || this; _this.DEFAULT_ACTIVATION = 'tanh'; _this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid'; _this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal'; _this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal'; _this.DEFAULT_BIAS_INITIALIZER = 'zeros'; if (args.resetAfter) { throw new ValueError("GRUCell does not support reset_after parameter set to true."); } _this.units = args.units; assertPositiveInteger(_this.units, 'units'); _this.activation = getActivation(args.activation === undefined ? _this.DEFAULT_ACTIVATION : args.activation); _this.recurrentActivation = getActivation(args.recurrentActivation === undefined ? _this.DEFAULT_RECURRENT_ACTIVATION : args.recurrentActivation); _this.useBias = args.useBias == null ? true : args.useBias; _this.kernelInitializer = getInitializer(args.kernelInitializer || _this.DEFAULT_KERNEL_INITIALIZER); _this.recurrentInitializer = getInitializer(args.recurrentInitializer || _this.DEFAULT_RECURRENT_INITIALIZER); _this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER); _this.kernelRegularizer = getRegularizer(args.kernelRegularizer); _this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer); _this.biasRegularizer = getRegularizer(args.biasRegularizer); _this.kernelConstraint = getConstraint(args.kernelConstraint); _this.recurrentConstraint = getConstraint(args.recurrentConstraint); _this.biasConstraint = getConstraint(args.biasConstraint); _this.dropout = min([1, max([0, args.dropout == null ? 0 : args.dropout])]); _this.recurrentDropout = min([ 1, max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout]) ]); _this.dropoutFunc = args.dropoutFunc; _this.implementation = args.implementation; _this.stateSize = _this.units; _this.dropoutMask = null; _this.recurrentDropoutMask = null; return _this; } GRUCell.prototype.build = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var inputDim = inputShape[inputShape.length - 1]; this.kernel = this.addWeight('kernel', [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint); } else { this.bias = null; } // Porting Notes: Unlike the PyKeras implementation, we perform slicing // of the weights and bias in the call() method, at execution time. this.built = true; }; GRUCell.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = inputs; if (inputs.length !== 2) { throw new ValueError("GRUCell expects 2 input Tensors (inputs, h), got " + "".concat(inputs.length, ".")); } var training = kwargs['training'] == null ? false : kwargs['training']; var hTMinus1 = inputs[1]; // Previous memory state.
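// For reference, the block below implements the standard GRU update
// (a sketch of the math, not extra code; sigma is the recurrent
// activation, a() the cell activation -- hardSigmoid and tanh by
// default -- (*) is elementwise multiply, and biases are omitted):
//   z_t  = sigma(x_t . W_z + h_{t-1} . U_z)          (update gate)
//   r_t  = sigma(x_t . W_r + h_{t-1} . U_r)          (reset gate)
//   hh_t = a(x_t . W_h + (r_t (*) h_{t-1}) . U_h)    (candidate state)
//   h_t  = z_t (*) h_{t-1} + (1 - z_t) (*) hh_t
// The W_* blocks live concatenated in `kernel` and the U_* blocks in
// `recurrentKernel`; both are split at execution time below.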
inputs = inputs[0]; // Note: For superior performance, TensorFlow.js always uses // implementation 2, regardless of the actual value of // config.implementation. if (0 < _this.dropout && _this.dropout < 1 && _this.dropoutMask == null) { _this.dropoutMask = generateDropoutMask({ ones: function () { return tfc__namespace.onesLike(inputs); }, rate: _this.dropout, training: training, count: 3, dropoutFunc: _this.dropoutFunc, }); } if (0 < _this.recurrentDropout && _this.recurrentDropout < 1 && _this.recurrentDropoutMask == null) { _this.recurrentDropoutMask = generateDropoutMask({ ones: function () { return tfc__namespace.onesLike(hTMinus1); }, rate: _this.recurrentDropout, training: training, count: 3, dropoutFunc: _this.dropoutFunc, }); } var dpMask = _this.dropoutMask; var recDpMask = _this.recurrentDropoutMask; var z; var r; var hh; if (0 < _this.dropout && _this.dropout < 1) { inputs = tfc__namespace.mul(inputs, dpMask[0]); } var matrixX = dot$1(inputs, _this.kernel.read()); if (_this.useBias) { matrixX = biasAdd(matrixX, _this.bias.read()); } if (0 < _this.recurrentDropout && _this.recurrentDropout < 1) { hTMinus1 = tfc__namespace.mul(hTMinus1, recDpMask[0]); } var recurrentKernelValue = _this.recurrentKernel.read(); var _b = __read(tfc__namespace.split(recurrentKernelValue, [2 * _this.units, _this.units], recurrentKernelValue.rank - 1), 2), rk1 = _b[0], rk2 = _b[1]; var matrixInner = dot$1(hTMinus1, rk1); var _c = __read(tfc__namespace.split(matrixX, 3, matrixX.rank - 1), 3), xZ = _c[0], xR = _c[1], xH = _c[2]; var _d = __read(tfc__namespace.split(matrixInner, 2, matrixInner.rank - 1), 2), recurrentZ = _d[0], recurrentR = _d[1]; z = _this.recurrentActivation.apply(tfc__namespace.add(xZ, recurrentZ)); r = _this.recurrentActivation.apply(tfc__namespace.add(xR, recurrentR)); var recurrentH = dot$1(tfc__namespace.mul(r, hTMinus1), rk2); hh = _this.activation.apply(tfc__namespace.add(xH, recurrentH)); var h = tfc__namespace.add(tfc__namespace.mul(z, hTMinus1), tfc__namespace.mul(tfc__namespace.add(1, tfc__namespace.neg(z)), hh)); // TODO(cais): Add use_learning_phase flag properly. 
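// The new hidden state doubles as the cell's output, so it is returned
// twice to satisfy the `[output, ...newStates]` step contract noted in
// RNN.prototype.call above.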
return [h, h]; }); }; GRUCell.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var config = { units: this.units, activation: serializeActivation(this.activation), recurrentActivation: serializeActivation(this.recurrentActivation), useBias: this.useBias, kernelInitializer: serializeInitializer(this.kernelInitializer), recurrentInitializer: serializeInitializer(this.recurrentInitializer), biasInitializer: serializeInitializer(this.biasInitializer), kernelRegularizer: serializeRegularizer(this.kernelRegularizer), recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer), biasRegularizer: serializeRegularizer(this.biasRegularizer), activityRegularizer: serializeRegularizer(this.activityRegularizer), kernelConstraint: serializeConstraint(this.kernelConstraint), recurrentConstraint: serializeConstraint(this.recurrentConstraint), biasConstraint: serializeConstraint(this.biasConstraint), dropout: this.dropout, recurrentDropout: this.recurrentDropout, implementation: this.implementation, resetAfter: false }; return Object.assign(Object.assign({}, baseConfig), config); }; return GRUCell; }(RNNCell)); /** @nocollapse */ GRUCell.className = 'GRUCell'; tfc.serialization.registerClass(GRUCell); var GRU = /** @class */ (function (_super) { __extends(GRU, _super); function GRU(args) { if (args.implementation === 0) { console.warn('`implementation=0` has been deprecated, and now defaults to ' + '`implementation=1`. Please update your layer call.'); } args.cell = new GRUCell(args); return _super.call(this, args) || this; // TODO(cais): Add activityRegularizer. } GRU.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { if (_this.cell.dropoutMask != null) { tfc__namespace.dispose(_this.cell.dropoutMask); _this.cell.dropoutMask = null; } if (_this.cell.recurrentDropoutMask != null) { tfc__namespace.dispose(_this.cell.recurrentDropoutMask); _this.cell.recurrentDropoutMask = null; } var mask = kwargs == null ? null : kwargs['mask']; var training = kwargs == null ? null : kwargs['training']; var initialState = kwargs == null ? null : kwargs['initialState']; return _super.prototype.call.call(_this, inputs, { mask: mask, training: training, initialState: initialState }); }); }; /** @nocollapse */ GRU.fromConfig = function (cls, config) { if (config['implementation'] === 0) { config['implementation'] = 1; } return new cls(config); }; return GRU; }(RNN)); /** @nocollapse */ GRU.className = 'GRU'; tfc.serialization.registerClass(GRU); var LSTMCell = /** @class */ (function (_super) { __extends(LSTMCell, _super); function LSTMCell(args) { var _this = _super.call(this, args) || this; _this.DEFAULT_ACTIVATION = 'tanh'; _this.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid'; _this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal'; _this.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal'; _this.DEFAULT_BIAS_INITIALIZER = 'zeros'; _this.units = args.units; assertPositiveInteger(_this.units, 'units'); _this.activation = getActivation(args.activation === undefined ? _this.DEFAULT_ACTIVATION : args.activation); _this.recurrentActivation = getActivation(args.recurrentActivation === undefined ? _this.DEFAULT_RECURRENT_ACTIVATION : args.recurrentActivation); _this.useBias = args.useBias == null ?
true : args.useBias; _this.kernelInitializer = getInitializer(args.kernelInitializer || _this.DEFAULT_KERNEL_INITIALIZER); _this.recurrentInitializer = getInitializer(args.recurrentInitializer || _this.DEFAULT_RECURRENT_INITIALIZER); _this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER); _this.unitForgetBias = args.unitForgetBias; _this.kernelRegularizer = getRegularizer(args.kernelRegularizer); _this.recurrentRegularizer = getRegularizer(args.recurrentRegularizer); _this.biasRegularizer = getRegularizer(args.biasRegularizer); _this.kernelConstraint = getConstraint(args.kernelConstraint); _this.recurrentConstraint = getConstraint(args.recurrentConstraint); _this.biasConstraint = getConstraint(args.biasConstraint); _this.dropout = min([1, max([0, args.dropout == null ? 0 : args.dropout])]); _this.recurrentDropout = min([ 1, max([0, args.recurrentDropout == null ? 0 : args.recurrentDropout]) ]); _this.dropoutFunc = args.dropoutFunc; _this.implementation = args.implementation; _this.stateSize = [_this.units, _this.units]; _this.dropoutMask = null; _this.recurrentDropoutMask = null; return _this; } LSTMCell.prototype.build = function (inputShape) { var _a; inputShape = getExactlyOneShape(inputShape); var inputDim = inputShape[inputShape.length - 1]; this.kernel = this.addWeight('kernel', [inputDim, this.units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint); var biasInitializer; if (this.useBias) { if (this.unitForgetBias) { var capturedBiasInit_1 = this.biasInitializer; var capturedUnits_1 = this.units; biasInitializer = new (_a = /** @class */ (function (_super) { __extends(CustomInit, _super); function CustomInit() { return _super !== null && _super.apply(this, arguments) || this; } CustomInit.prototype.apply = function (shape, dtype) { // TODO(cais): More informative variable names? var bI = capturedBiasInit_1.apply([capturedUnits_1]); var bF = (new Ones()).apply([capturedUnits_1]); var bCAndH = capturedBiasInit_1.apply([capturedUnits_1 * 2]); return concatAlongFirstAxis(concatAlongFirstAxis(bI, bF), bCAndH); }; return CustomInit; }(Initializer)), /** @nocollapse */ _a.className = 'CustomInit', _a)(); } else { biasInitializer = this.biasInitializer; } this.bias = this.addWeight('bias', [this.units * 4], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint); } else { this.bias = null; } // Porting Notes: Unlike the PyKeras implementation, we perform slicing // of the weights and bias in the call() method, at execution time. this.built = true; }; LSTMCell.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var training = kwargs['training'] == null ? false : kwargs['training']; inputs = inputs; if (inputs.length !== 3) { throw new ValueError("LSTMCell expects 3 input Tensors (inputs, h, c), got " + "".concat(inputs.length, ".")); } var hTMinus1 = inputs[1]; // Previous memory state. var cTMinus1 = inputs[2]; // Previous carry state. 
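// For reference, the block below implements the standard LSTM update
// (a sketch of the math; sigma is the recurrent activation, a() the
// cell activation -- hardSigmoid and tanh by default -- and (*) is
// elementwise multiply). One fused matmul computes all four gate
// pre-activations, z = x_t . W + h_{t-1} . U + b, split into z0..z3:
//   i_t = sigma(z0), f_t = sigma(z1), o_t = sigma(z3)
//   c_t = f_t (*) c_{t-1} + i_t (*) a(z2)
//   h_t = o_t (*) a(c_t)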
inputs = inputs[0]; if (0 < _this.dropout && _this.dropout < 1 && _this.dropoutMask == null) { _this.dropoutMask = generateDropoutMask({ ones: function () { return tfc__namespace.onesLike(inputs); }, rate: _this.dropout, training: training, count: 4, dropoutFunc: _this.dropoutFunc }); } if (0 < _this.recurrentDropout && _this.recurrentDropout < 1 && _this.recurrentDropoutMask == null) { _this.recurrentDropoutMask = generateDropoutMask({ ones: function () { return tfc__namespace.onesLike(hTMinus1); }, rate: _this.recurrentDropout, training: training, count: 4, dropoutFunc: _this.dropoutFunc }); } var dpMask = _this.dropoutMask; var recDpMask = _this.recurrentDropoutMask; // Note: For superior performance, TensorFlow.js always uses // implementation 2 regardless of the actual value of // config.implementation. var i; var f; var c; var o; if (0 < _this.dropout && _this.dropout < 1) { inputs = tfc__namespace.mul(inputs, dpMask[0]); } var z = dot$1(inputs, _this.kernel.read()); if (0 < _this.recurrentDropout && _this.recurrentDropout < 1) { hTMinus1 = tfc__namespace.mul(hTMinus1, recDpMask[0]); } z = tfc__namespace.add(z, dot$1(hTMinus1, _this.recurrentKernel.read())); if (_this.useBias) { z = biasAdd(z, _this.bias.read()); } var _b = __read(tfc__namespace.split(z, 4, z.rank - 1), 4), z0 = _b[0], z1 = _b[1], z2 = _b[2], z3 = _b[3]; i = _this.recurrentActivation.apply(z0); f = _this.recurrentActivation.apply(z1); c = tfc__namespace.add(tfc__namespace.mul(f, cTMinus1), tfc__namespace.mul(i, _this.activation.apply(z2))); o = _this.recurrentActivation.apply(z3); var h = tfc__namespace.mul(o, _this.activation.apply(c)); // TODO(cais): Add use_learning_phase flag properly. return [h, h, c]; }); }; LSTMCell.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var config = { units: this.units, activation: serializeActivation(this.activation), recurrentActivation: serializeActivation(this.recurrentActivation), useBias: this.useBias, kernelInitializer: serializeInitializer(this.kernelInitializer), recurrentInitializer: serializeInitializer(this.recurrentInitializer), biasInitializer: serializeInitializer(this.biasInitializer), unitForgetBias: this.unitForgetBias, kernelRegularizer: serializeRegularizer(this.kernelRegularizer), recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer), biasRegularizer: serializeRegularizer(this.biasRegularizer), activityRegularizer: serializeRegularizer(this.activityRegularizer), kernelConstraint: serializeConstraint(this.kernelConstraint), recurrentConstraint: serializeConstraint(this.recurrentConstraint), biasConstraint: serializeConstraint(this.biasConstraint), dropout: this.dropout, recurrentDropout: this.recurrentDropout, implementation: this.implementation, }; return Object.assign(Object.assign({}, baseConfig), config); }; return LSTMCell; }(RNNCell)); /** @nocollapse */ LSTMCell.className = 'LSTMCell'; tfc.serialization.registerClass(LSTMCell); var LSTM = /** @class */ (function (_super) { __extends(LSTM, _super); function LSTM(args) { if (args.implementation === 0) { console.warn('`implementation=0` has been deprecated, and now defaults to ' + '`implementation=1`. Please update your layer call.'); } args.cell = new LSTMCell(args); return _super.call(this, args) || this; // TODO(cais): Add activityRegularizer. 
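// A minimal usage sketch of this layer via the public factory function
// (assumes the `tf.layers.lstm` export of @tensorflow/tfjs-layers;
// shapes shown are illustrative, not executed here):
//   var lstm = tf.layers.lstm({units: 8, returnSequences: true});
//   var y = lstm.apply(tf.input({shape: [10, 4]}));  // [null, 10, 8]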
} LSTM.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { if (_this.cell.dropoutMask != null) { tfc__namespace.dispose(_this.cell.dropoutMask); _this.cell.dropoutMask = null; } if (_this.cell.recurrentDropoutMask != null) { tfc__namespace.dispose(_this.cell.recurrentDropoutMask); _this.cell.recurrentDropoutMask = null; } var mask = kwargs == null ? null : kwargs['mask']; var training = kwargs == null ? null : kwargs['training']; var initialState = kwargs == null ? null : kwargs['initialState']; return _super.prototype.call.call(_this, inputs, { mask: mask, training: training, initialState: initialState }); }); }; /** @nocollapse */ LSTM.fromConfig = function (cls, config) { if (config['implementation'] === 0) { config['implementation'] = 1; } return new cls(config); }; return LSTM; }(RNN)); /** @nocollapse */ LSTM.className = 'LSTM'; tfc.serialization.registerClass(LSTM); var StackedRNNCells = /** @class */ (function (_super) { __extends(StackedRNNCells, _super); function StackedRNNCells(args) { var _this = _super.call(this, args) || this; _this.cells = args.cells; return _this; } Object.defineProperty(StackedRNNCells.prototype, "stateSize", { get: function () { var e_3, _b; // States are a flat list in reverse order of the cell stack. // This allows preserving the requirement `stack.stateSize[0] === // outputDim`. E.g., states of a 2-layer LSTM would be `[h2, c2, h1, c1]`, // assuming one LSTM has states `[h, c]`. var stateSize = []; try { for (var _c = __values(this.cells.slice().reverse()), _d = _c.next(); !_d.done; _d = _c.next()) { var cell = _d.value; if (Array.isArray(cell.stateSize)) { stateSize.push.apply(stateSize, __spreadArray([], __read(cell.stateSize), false)); } else { stateSize.push(cell.stateSize); } } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (_d && !_d.done && (_b = _c.return)) _b.call(_c); } finally { if (e_3) throw e_3.error; } } return stateSize; }, enumerable: false, configurable: true }); StackedRNNCells.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var e_4, _b, e_5, _c; inputs = inputs; var states = inputs.slice(1); // Recover per-cell states. var nestedStates = []; try { for (var _d = __values(_this.cells.slice().reverse()), _e = _d.next(); !_e.done; _e = _d.next()) { var cell = _e.value; if (Array.isArray(cell.stateSize)) { nestedStates.push(states.splice(0, cell.stateSize.length)); } else { nestedStates.push(states.splice(0, 1)); } } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (_e && !_e.done && (_b = _d.return)) _b.call(_d); } finally { if (e_4) throw e_4.error; } } nestedStates.reverse(); // Call the cells in order and store the returned states. var newNestedStates = []; var callInputs; for (var i = 0; i < _this.cells.length; ++i) { var cell = _this.cells[i]; states = nestedStates[i]; // TODO(cais): Take care of constants. if (i === 0) { callInputs = [inputs[0]].concat(states); } else { callInputs = [callInputs[0]].concat(states); } callInputs = cell.call(callInputs, kwargs); newNestedStates.push(callInputs.slice(1)); } // Format the new states as a flat list in reverse cell order.
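// E.g., with two stacked LSTM cells whose per-cell new states are
// [h1, c1] (first cell) and [h2, c2] (second cell), the flattening
// below yields [h2, c2, h1, c1], matching the `stateSize` getter above.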
states = []; try { for (var _f = __values(newNestedStates.slice().reverse()), _g = _f.next(); !_g.done; _g = _f.next()) { var cellStates = _g.value; states.push.apply(states, __spreadArray([], __read(cellStates), false)); } } catch (e_5_1) { e_5 = { error: e_5_1 }; } finally { try { if (_g && !_g.done && (_c = _f.return)) _c.call(_f); } finally { if (e_5) throw e_5.error; } } return [callInputs[0]].concat(states); }); }; StackedRNNCells.prototype.build = function (inputShape) { if (isArrayOfShapes(inputShape)) { // TODO(cais): Take care of input constants. // const constantShape = inputShape.slice(1); inputShape = inputShape[0]; } inputShape = inputShape; var outputDim; this.cells.forEach(function (cell, i) { nameScope("RNNCell_".concat(i), function () { // TODO(cais): Take care of input constants. cell.build(inputShape); if (Array.isArray(cell.stateSize)) { outputDim = cell.stateSize[0]; } else { outputDim = cell.stateSize; } inputShape = [inputShape[0], outputDim]; }); }); this.built = true; }; StackedRNNCells.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var getCellConfig = function (cell) { return { 'className': cell.getClassName(), 'config': cell.getConfig(), }; }; var cellConfigs = this.cells.map(getCellConfig); var config = { 'cells': cellConfigs }; return Object.assign(Object.assign({}, baseConfig), config); }; /** @nocollapse */ StackedRNNCells.fromConfig = function (cls, config, customObjects) { var e_6, _b; if (customObjects === void 0) { customObjects = {}; } var cells = []; try { for (var _c = __values(config['cells']), _d = _c.next(); !_d.done; _d = _c.next()) { var cellConfig = _d.value; cells.push(deserialize(cellConfig, customObjects)); } } catch (e_6_1) { e_6 = { error: e_6_1 }; } finally { try { if (_d && !_d.done && (_b = _c.return)) _b.call(_c); } finally { if (e_6) throw e_6.error; } } return new cls({ cells: cells }); }; Object.defineProperty(StackedRNNCells.prototype, "trainableWeights", { get: function () { var e_7, _b; if (!this.trainable) { return []; } var weights = []; try { for (var _c = __values(this.cells), _d = _c.next(); !_d.done; _d = _c.next()) { var cell = _d.value; weights.push.apply(weights, __spreadArray([], __read(cell.trainableWeights), false)); } } catch (e_7_1) { e_7 = { error: e_7_1 }; } finally { try { if (_d && !_d.done && (_b = _c.return)) _b.call(_c); } finally { if (e_7) throw e_7.error; } } return weights; }, enumerable: false, configurable: true }); Object.defineProperty(StackedRNNCells.prototype, "nonTrainableWeights", { get: function () { var e_8, _b, e_9, _c; var weights = []; try { for (var _d = __values(this.cells), _e = _d.next(); !_e.done; _e = _d.next()) { var cell = _e.value; weights.push.apply(weights, __spreadArray([], __read(cell.nonTrainableWeights), false)); } } catch (e_8_1) { e_8 = { error: e_8_1 }; } finally { try { if (_e && !_e.done && (_b = _d.return)) _b.call(_d); } finally { if (e_8) throw e_8.error; } } if (!this.trainable) { var trainableWeights = []; try { for (var _f = __values(this.cells), _g = _f.next(); !_g.done; _g = _f.next()) { var cell = _g.value; trainableWeights.push.apply(trainableWeights, __spreadArray([], __read(cell.trainableWeights), false)); } } catch (e_9_1) { e_9 = { error: e_9_1 }; } finally { try { if (_g && !_g.done && (_c = _f.return)) _c.call(_f); } finally { if (e_9) throw e_9.error; } } return trainableWeights.concat(weights); } return weights; }, enumerable: false, configurable: true }); /** * Retrieve the weights of the model.
 * * @returns A flat `Array` of `tf.Tensor`s. */ StackedRNNCells.prototype.getWeights = function () { var e_10, _b; var weights = []; try { for (var _c = __values(this.cells), _d = _c.next(); !_d.done; _d = _c.next()) { var cell = _d.value; weights.push.apply(weights, __spreadArray([], __read(cell.weights), false)); } } catch (e_10_1) { e_10 = { error: e_10_1 }; } finally { try { if (_d && !_d.done && (_b = _c.return)) _b.call(_c); } finally { if (e_10) throw e_10.error; } } return batchGetValue(weights); }; /** * Set the weights of the model. * * @param weights An `Array` of `tf.Tensor`s with shapes and types matching * the output of `getWeights()`. */ StackedRNNCells.prototype.setWeights = function (weights) { var e_11, _b; var tuples = []; try { for (var _c = __values(this.cells), _d = _c.next(); !_d.done; _d = _c.next()) { var cell = _d.value; var numParams = cell.weights.length; // Take this cell's weights off the front of the flat list. var inputWeights = weights.splice(0, numParams); for (var i = 0; i < cell.weights.length; ++i) { tuples.push([cell.weights[i], inputWeights[i]]); } } } catch (e_11_1) { e_11 = { error: e_11_1 }; } finally { try { if (_d && !_d.done && (_b = _c.return)) _b.call(_c); } finally { if (e_11) throw e_11.error; } } batchSetValue(tuples); }; return StackedRNNCells; }(RNNCell)); /** @nocollapse */ StackedRNNCells.className = 'StackedRNNCells'; tfc.serialization.registerClass(StackedRNNCells); function generateDropoutMask(args) { var ones = args.ones, rate = args.rate, _b = args.training, training = _b === void 0 ? false : _b, _c = args.count, count = _c === void 0 ? 1 : _c, dropoutFunc = args.dropoutFunc; var droppedInputs = function () { return dropoutFunc != null ? dropoutFunc(ones(), rate) : dropout$1(ones(), rate); }; var createMask = function () { return inTrainPhase(droppedInputs, ones, training); }; // Just in case count is provided with null or undefined. if (!count || count <= 1) { return tfc__namespace.keep(createMask().clone()); } var masks = Array(count).fill(undefined).map(createMask); return masks.map(function (m) { return tfc__namespace.keep(m.clone()); }); } /** * @license * Copyright 2020 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ var __rest = (undefined && undefined.__rest) || function (s, e) { var t = {}; for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0) t[p] = s[p]; if (s != null && typeof Object.getOwnPropertySymbols === "function") for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) { if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i])) t[p[i]] = s[p[i]]; } return t; }; /** @class */ ((function (_super) { __extends(ConvRNN2DCell, _super); function ConvRNN2DCell() { return _super !== null && _super.apply(this, arguments) || this; } return ConvRNN2DCell; })(RNNCell)); /** * Base class for convolutional-recurrent layers.
*/ var ConvRNN2D = /** @class */ (function (_super) { __extends(ConvRNN2D, _super); function ConvRNN2D(args) { var _this = this; if (args.unroll) { throw new NotImplementedError('Unrolling is not possible with convolutional RNNs.'); } if (Array.isArray(args.cell)) { throw new NotImplementedError('It is not possible at the moment to stack convolutional cells.'); } _this = _super.call(this, args) || this; _this.inputSpec = [new InputSpec({ ndim: 5 })]; return _this; } ConvRNN2D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc__namespace.tidy(function () { if (_this.cell.dropoutMask != null) { tfc__namespace.dispose(_this.cell.dropoutMask); _this.cell.dropoutMask = null; } if (_this.cell.recurrentDropoutMask != null) { tfc__namespace.dispose(_this.cell.recurrentDropoutMask); _this.cell.recurrentDropoutMask = null; } if (kwargs && kwargs['constants']) { throw new ValueError('ConvRNN2D cell does not support constants'); } var mask = kwargs == null ? null : kwargs['mask']; var training = kwargs == null ? null : kwargs['training']; var initialState = kwargs == null ? null : kwargs['initialState']; return _super.prototype.call.call(_this, inputs, { mask: mask, training: training, initialState: initialState }); }); }; ConvRNN2D.prototype.computeOutputShape = function (inputShape) { var outShape = this.computeSingleOutputShape(inputShape); if (!this.returnSequences) { outShape = __spreadArray([outShape[0]], __read(outShape.slice(2)), false); } if (this.returnState) { outShape = __spreadArray([outShape], __read(Array(2).fill(__spreadArray([inputShape[0]], __read(outShape.slice(-3)), false))), false); } return outShape; }; ConvRNN2D.prototype.getInitialState = function (inputs) { var _this = this; return tfc__namespace.tidy(function () { var stateSize = _this.cell.stateSize; var inputShape = inputs.shape; var outputShape = _this.computeSingleOutputShape(inputShape); var stateShape = __spreadArray([outputShape[0]], __read(outputShape.slice(2)), false); var initialState = tfc__namespace.zeros(stateShape); if (Array.isArray(stateSize)) { return Array(stateSize.length).fill(initialState); } return [initialState]; }); }; ConvRNN2D.prototype.resetStates = function (states, training) { var _this = this; if (training === void 0) { training = false; } tfc__namespace.tidy(function () { if (!_this.stateful) { throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.'); } var inputShape = _this.inputSpec[0].shape; var outputShape = _this.computeSingleOutputShape(inputShape); var stateShape = __spreadArray([outputShape[0]], __read(outputShape.slice(2)), false); var batchSize = inputShape[0]; if (batchSize == null) { throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' + 'the batch size of your input tensors: \n' + '- If using a Sequential model, specify the batch size by ' + 'passing a `batchInputShape` option to your first layer.\n' + '- If using the functional API, specify the batch size by ' + 'passing a `batchShape` option to your Input layer.'); } // Initialize state if null. if (_this.getStates() == null) { if (Array.isArray(_this.cell.stateSize)) { _this.states_ = _this.cell.stateSize.map(function () { return tfc__namespace.zeros(stateShape); }); } else { _this.states_ = [tfc__namespace.zeros(stateShape)]; } } else if (states == null) { // Dispose old state tensors. tfc__namespace.dispose(_this.states_); // For stateful RNNs, fully dispose kept old states. 
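// (These were retained by an earlier `resetStates(states, true)` call so
// that BPTT could still read them; by this point they can be released.)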
if (_this.keptStates != null) { tfc__namespace.dispose(_this.keptStates); _this.keptStates = []; } if (Array.isArray(_this.cell.stateSize)) { _this.states_ = _this.cell.stateSize.map(function () { return tfc__namespace.zeros(stateShape); }); } else { _this.states_[0] = tfc__namespace.zeros(stateShape); } } else { if (!Array.isArray(states)) { states = [states]; } if (states.length !== _this.states_.length) { throw new ValueError("Layer ".concat(_this.name, " expects ").concat(_this.states_.length, " state(s), ") + "but it received ".concat(states.length, " state value(s). Input ") + "received: ".concat(states)); } if (training) { // Store old state tensors for complete disposal later, i.e., during // the next no-arg call to this method. We do not dispose the old // states immediately because BPTT (among other things) requires // them. _this.keptStates.push(_this.states_.slice()); } else { tfc__namespace.dispose(_this.states_); } for (var index = 0; index < _this.states_.length; ++index) { var value = states[index]; var expectedShape = stateShape; if (!tfc.util.arraysEqual(value.shape, expectedShape)) { throw new ValueError("State ".concat(index, " is incompatible with layer ").concat(_this.name, ": ") + "expected shape=".concat(expectedShape, ", received shape=").concat(value.shape)); } _this.states_[index] = value; } } _this.states_ = _this.states_.map(function (state) { return tfc__namespace.keep(state.clone()); }); }); }; ConvRNN2D.prototype.computeSingleOutputShape = function (inputShape) { var _b = this.cell, dataFormat = _b.dataFormat, filters = _b.filters, kernelSize = _b.kernelSize, padding = _b.padding, strides = _b.strides, dilationRate = _b.dilationRate; var isChannelsFirst = dataFormat === 'channelsFirst'; var h = inputShape[isChannelsFirst ? 3 : 2]; var w = inputShape[isChannelsFirst ? 4 : 3]; var hOut = convOutputLength(h, kernelSize[0], padding, strides[0], dilationRate[0]); var wOut = convOutputLength(w, kernelSize[1], padding, strides[1], dilationRate[1]); var outShape = __spreadArray(__spreadArray([], __read(inputShape.slice(0, 2)), false), __read((isChannelsFirst ?
[filters, hOut, wOut] : [hOut, wOut, filters])), false); return outShape; }; return ConvRNN2D; }(RNN)); /** @nocollapse */ ConvRNN2D.className = 'ConvRNN2D'; var ConvLSTM2DCell = /** @class */ (function (_super) { __extends(ConvLSTM2DCell, _super); function ConvLSTM2DCell(args) { var _this = this; var filters = args.filters, kernelSize = args.kernelSize, strides = args.strides, padding = args.padding, dataFormat = args.dataFormat, dilationRate = args.dilationRate; _this = _super.call(this, Object.assign(Object.assign({}, args), { units: filters })) || this; _this.filters = filters; assertPositiveInteger(_this.filters, 'filters'); _this.kernelSize = normalizeArray(kernelSize, 2, 'kernelSize'); _this.kernelSize.forEach(function (size) { return assertPositiveInteger(size, 'kernelSize'); }); _this.strides = normalizeArray(strides || 1, 2, 'strides'); _this.strides.forEach(function (stride) { return assertPositiveInteger(stride, 'strides'); }); _this.padding = padding || 'valid'; checkPaddingMode(_this.padding); _this.dataFormat = dataFormat || 'channelsLast'; checkDataFormat(_this.dataFormat); _this.dilationRate = normalizeArray(dilationRate || 1, 2, 'dilationRate'); _this.dilationRate.forEach(function (rate) { return assertPositiveInteger(rate, 'dilationRate'); }); return _this; } ConvLSTM2DCell.prototype.build = function (inputShape) { var _a; inputShape = getExactlyOneShape(inputShape); var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1; if (inputShape[channelAxis] == null) { throw new ValueError("The channel dimension of the input should be defined. " + "Found ".concat(inputShape[channelAxis])); } var inputDim = inputShape[channelAxis]; var numOfKernels = 4; var kernelShape = this.kernelSize.concat([inputDim, this.filters * numOfKernels]); this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); var recurrentKernelShape = this.kernelSize.concat([this.filters, this.filters * numOfKernels]); this.recurrentKernel = this.addWeight('recurrent_kernel', recurrentKernelShape, null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint); if (this.useBias) { var biasInitializer = void 0; if (this.unitForgetBias) { var init_1 = this.biasInitializer; var filters_1 = this.filters; biasInitializer = new (_a = /** @class */ (function (_super) { __extends(CustomInit, _super); function CustomInit() { return _super !== null && _super.apply(this, arguments) || this; } CustomInit.prototype.apply = function (shape, dtype) { var biasI = init_1.apply([filters_1]); var biasF = tfc__namespace.ones([filters_1]); var biasCAndO = init_1.apply([filters_1 * 2]); return concatenate$1([biasI, biasF, biasCAndO]); }; return CustomInit; }(Initializer)), /** @nocollapse */ _a.className = 'CustomInit', _a)(); } else { biasInitializer = this.biasInitializer; } this.bias = this.addWeight('bias', [this.filters * numOfKernels], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint); } this.built = true; }; ConvLSTM2DCell.prototype.call = function (inputs, kwargs) { var _this = this; return tfc__namespace.tidy(function () { if (inputs.length !== 3) { throw new ValueError("ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got " + "".concat(inputs.length, ".")); } var training = kwargs['training'] || false; var x = inputs[0]; // Current input var hTMinus1 = inputs[1]; // Previous memory state. var cTMinus1 = inputs[2]; // Previous carry state. 
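// Same gate algebra as LSTMCell.call() above, except every matmul is
// replaced by a convolution: inputs pass through `inputConv` (layer
// strides and padding) and states through `recurrentConv` (stride 1,
// 'same' padding). The four gate kernels are stacked along the last
// kernel axis and split below.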
var numOfKernels = 4; if (0 < _this.dropout && _this.dropout < 1 && _this.dropoutMask == null) { _this.dropoutMask = generateDropoutMask({ ones: function () { return tfc__namespace.onesLike(x); }, rate: _this.dropout, training: training, count: numOfKernels, dropoutFunc: _this.dropoutFunc }); } var dropoutMask = _this.dropoutMask; var applyDropout = function (x, mask, index) { if (!mask || !mask[index]) { return x; } return tfc__namespace.mul(mask[index], x); }; var xI = applyDropout(x, dropoutMask, 0); var xF = applyDropout(x, dropoutMask, 1); var xC = applyDropout(x, dropoutMask, 2); var xO = applyDropout(x, dropoutMask, 3); if (0 < _this.recurrentDropout && _this.recurrentDropout < 1 && _this.recurrentDropoutMask == null) { _this.recurrentDropoutMask = generateDropoutMask({ ones: function () { return tfc__namespace.onesLike(hTMinus1); }, rate: _this.recurrentDropout, training: training, count: numOfKernels, dropoutFunc: _this.dropoutFunc }); } var recDropoutMask = _this.recurrentDropoutMask; var hI = applyDropout(hTMinus1, recDropoutMask, 0); var hF = applyDropout(hTMinus1, recDropoutMask, 1); var hC = applyDropout(hTMinus1, recDropoutMask, 2); var hO = applyDropout(hTMinus1, recDropoutMask, 3); var kernelChannelAxis = 3; var _b = __read(tfc__namespace.split(_this.kernel.read(), numOfKernels, kernelChannelAxis), 4), kernelI = _b[0], kernelF = _b[1], kernelC = _b[2], kernelO = _b[3]; var _c = __read(_this.useBias ? tfc__namespace.split(_this.bias.read(), numOfKernels) : [null, null, null, null], 4), biasI = _c[0], biasF = _c[1], biasC = _c[2], biasO = _c[3]; xI = _this.inputConv(xI, kernelI, biasI, _this.padding); xF = _this.inputConv(xF, kernelF, biasF, _this.padding); xC = _this.inputConv(xC, kernelC, biasC, _this.padding); xO = _this.inputConv(xO, kernelO, biasO, _this.padding); var _d = __read(tfc__namespace.split(_this.recurrentKernel.read(), numOfKernels, kernelChannelAxis), 4), recKernelI = _d[0], recKernelF = _d[1], recKernelC = _d[2], recKernelO = _d[3]; hI = _this.recurrentConv(hI, recKernelI); hF = _this.recurrentConv(hF, recKernelF); hC = _this.recurrentConv(hC, recKernelC); hO = _this.recurrentConv(hO, recKernelO); var i = _this.recurrentActivation.apply(tfc__namespace.add(xI, hI)); var f = _this.recurrentActivation.apply(tfc__namespace.add(xF, hF)); var c = tfc__namespace.add(tfc__namespace.mul(f, cTMinus1), tfc__namespace.mul(i, _this.activation.apply(tfc__namespace.add(xC, hC)))); var h = tfc__namespace.mul(_this.recurrentActivation.apply(tfc__namespace.add(xO, hO)), _this.activation.apply(c)); return [h, h, c]; }); }; ConvLSTM2DCell.prototype.getConfig = function () { var _a = _super.prototype.getConfig.call(this); _a["units"]; var baseConfig = __rest(_a, ['units']); var config = { filters: this.filters, kernelSize: this.kernelSize, padding: this.padding, dataFormat: this.dataFormat, dilationRate: this.dilationRate, strides: this.strides, }; return Object.assign(Object.assign({}, baseConfig), config); }; ConvLSTM2DCell.prototype.inputConv = function (x, w, b, padding) { var out = tfc__namespace.conv2d(x, w, this.strides, (padding || 'valid'), this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC', this.dilationRate); if (b) { return biasAdd(out, b, this.dataFormat); } return out; }; ConvLSTM2DCell.prototype.recurrentConv = function (x, w) { var strides = 1; return tfc__namespace.conv2d(x, w, strides, 'same', this.dataFormat === 'channelsFirst' ? 
'NCHW' : 'NHWC'); }; return ConvLSTM2DCell; }(LSTMCell)); /** @nocollapse */ ConvLSTM2DCell.className = 'ConvLSTM2DCell'; tfc__namespace.serialization.registerClass(ConvLSTM2DCell); var ConvLSTM2D = /** @class */ (function (_super) { __extends(ConvLSTM2D, _super); function ConvLSTM2D(args) { var cell = new ConvLSTM2DCell(args); return _super.call(this, Object.assign(Object.assign({}, args), { cell: cell })) || this; } /** @nocollapse */ ConvLSTM2D.fromConfig = function (cls, config) { return new cls(config); }; return ConvLSTM2D; }(ConvRNN2D)); /** @nocollapse */ ConvLSTM2D.className = 'ConvLSTM2D'; tfc__namespace.serialization.registerClass(ConvLSTM2D); var Dropout = /** @class */ (function (_super) { __extends(Dropout, _super); function Dropout(args) { var _this = _super.call(this, args) || this; _this.rate = Math.max(Math.min(args.rate, 1), 0); // So that the scalar doesn't get tidied up between executions. _this.noiseShape = args.noiseShape; _this.seed = args.seed; _this.supportsMasking = true; return _this; } Dropout.prototype.getNoiseShape = function (input) { if (this.noiseShape == null) { return this.noiseShape; } var inputShape = input.shape; var noiseShape = []; for (var i = 0; i < this.noiseShape.length; ++i) { noiseShape.push(this.noiseShape[i] == null ? inputShape[i] : this.noiseShape[i]); } return noiseShape; }; Dropout.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); var input = getExactlyOneTensor(inputs); if (0 < _this.rate && _this.rate < 1) { var training = kwargs['training'] == null ? false : kwargs['training']; var noiseShape_1 = _this.getNoiseShape(input); var output = inTrainPhase(function () { return dropout$1(input, _this.rate, noiseShape_1, _this.seed); }, function () { return input; }, training); return output; } return inputs; }); }; Dropout.prototype.getConfig = function () { var config = { rate: this.rate, noiseShape: this.noiseShape, seed: this.seed, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; Dropout.prototype.dispose = function () { return _super.prototype.dispose.call(this); }; return Dropout; }(Layer)); /** @nocollapse */ Dropout.className = 'Dropout'; tfc.serialization.registerClass(Dropout); var SpatialDropout1D = /** @class */ (function (_super) { __extends(SpatialDropout1D, _super); function SpatialDropout1D(args) { var _this = _super.call(this, args) || this; _this.inputSpec = [{ ndim: 3 }]; return _this; } SpatialDropout1D.prototype.getNoiseShape = function (input) { var inputShape = input.shape; return [inputShape[0], 1, inputShape[2]]; }; return SpatialDropout1D; }(Dropout)); /** @nocollapse */ SpatialDropout1D.className = 'SpatialDropout1D'; tfc.serialization.registerClass(SpatialDropout1D); var Dense = /** @class */ (function (_super) { __extends(Dense, _super); function Dense(args) { var _this = _super.call(this, args) || this; // Default activation: Linear (none). _this.activation = null; _this.useBias = true; _this.kernel = null; _this.bias = null; _this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal'; _this.DEFAULT_BIAS_INITIALIZER = 'zeros'; if (args.batchInputShape == null && args.inputShape == null && args.inputDim != null) { // This logic is copied from Layer's constructor, since we can't // do exactly what the Python constructor does for Dense(). 
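// (I.e., when only `inputDim` is given, synthesize `batchInputShape =
// [batchSize or null, inputDim]`, as the Python constructor would.)
// For reference, this layer computes
//   output = activation(dot(input, kernel) + bias)
// A minimal usage sketch via the public factory (assumes the
// `tf.layers.dense` export of @tensorflow/tfjs-layers; illustrative only):
//   var dense = tf.layers.dense({units: 3, activation: 'relu', inputDim: 4});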
var batchSize = null; if (args.batchSize != null) { batchSize = args.batchSize; } _this.batchInputShape = [batchSize, args.inputDim]; } _this.units = args.units; assertPositiveInteger(_this.units, 'units'); _this.activation = getActivation(args.activation); if (args.useBias != null) { _this.useBias = args.useBias; } _this.kernelInitializer = getInitializer(args.kernelInitializer || _this.DEFAULT_KERNEL_INITIALIZER); _this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER); _this.kernelConstraint = getConstraint(args.kernelConstraint); _this.biasConstraint = getConstraint(args.biasConstraint); _this.kernelRegularizer = getRegularizer(args.kernelRegularizer); _this.biasRegularizer = getRegularizer(args.biasRegularizer); _this.activityRegularizer = getRegularizer(args.activityRegularizer); _this.supportsMasking = true; _this.inputSpec = [{ minNDim: 2 }]; return _this; } Dense.prototype.build = function (inputShape) { var _a; inputShape = getExactlyOneShape(inputShape); var inputLastDim = inputShape[inputShape.length - 1]; if (this.kernel == null) { this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint); } } this.inputSpec = [{ minNDim: 2, axes: (_a = {}, _a[-1] = inputLastDim, _a) }]; this.built = true; }; Dense.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var outputShape = inputShape.slice(); outputShape[outputShape.length - 1] = this.units; return outputShape; }; Dense.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); // Dense layer accepts only a single input. var input = getExactlyOneTensor(inputs); var fusedActivationName = mapActivationToFusedKernel(_this.activation.getClassName()); var output; if (fusedActivationName != null) { output = dot$1(input, _this.kernel.read(), fusedActivationName, _this.bias ? 
_this.bias.read() : null); } else { output = dot$1(input, _this.kernel.read()); if (_this.bias != null) { output = biasAdd(output, _this.bias.read()); } if (_this.activation != null) { output = _this.activation.apply(output); } } return output; }); }; Dense.prototype.getConfig = function () { var config = { units: this.units, activation: serializeActivation(this.activation), useBias: this.useBias, kernelInitializer: serializeInitializer(this.kernelInitializer), biasInitializer: serializeInitializer(this.biasInitializer), kernelRegularizer: serializeRegularizer(this.kernelRegularizer), biasRegularizer: serializeRegularizer(this.biasRegularizer), activityRegularizer: serializeRegularizer(this.activityRegularizer), kernelConstraint: serializeConstraint(this.kernelConstraint), biasConstraint: serializeConstraint(this.biasConstraint) }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Dense; }(Layer)); /** @nocollapse */ Dense.className = 'Dense'; tfc.serialization.registerClass(Dense); var Flatten = /** @class */ (function (_super) { __extends(Flatten, _super); function Flatten(args) { var _this = this; args = args || {}; _this = _super.call(this, args) || this; _this.inputSpec = [{ minNDim: 3 }]; _this.dataFormat = args.dataFormat; return _this; } Flatten.prototype.computeOutputShape = function (inputShape) { var e_1, _a; inputShape = getExactlyOneShape(inputShape); try { for (var _b = __values(inputShape.slice(1)), _c = _b.next(); !_c.done; _c = _b.next()) { var dim = _c.value; if (dim == null) { throw new ValueError("The shape of the input to \"Flatten\" is not fully defined " + "(got ".concat(inputShape.slice(1), "). Make sure to pass a complete ") + "\"input_shape\" or \"batch_input_shape\" argument to the first " + "layer in your model."); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_1) throw e_1.error; } } return [inputShape[0], arrayProd(inputShape, 1)]; }; Flatten.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); var input = getExactlyOneTensor(inputs); if (_this.dataFormat === 'channelsFirst' && input.rank > 1) { var permutation = [0]; for (var i = 2; i < input.rank; ++i) { permutation.push(i); } permutation.push(1); input = tfc.transpose(input, permutation); } return batchFlatten(input); }); }; Flatten.prototype.getConfig = function () { var config = {}; if (this.dataFormat != null) { config['dataFormat'] = this.dataFormat; } var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Flatten; }(Layer)); /** @nocollapse */ Flatten.className = 'Flatten'; tfc.serialization.registerClass(Flatten); var Activation = /** @class */ (function (_super) { __extends(Activation, _super); function Activation(args) { var _this = _super.call(this, args) || this; _this.supportsMasking = true; _this.activation = getActivation(args.activation); return _this; } Activation.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); var input = getExactlyOneTensor(inputs); return _this.activation.apply(input); }); }; Activation.prototype.getConfig = function () { var config = { activation: serializeActivation(this.activation) }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return 
Activation; }(Layer)); /** @nocollapse */ Activation.className = 'Activation'; tfc.serialization.registerClass(Activation); var RepeatVector = /** @class */ (function (_super) { __extends(RepeatVector, _super); function RepeatVector(args) { var _this = _super.call(this, args) || this; _this.n = args.n; _this.inputSpec = [{ ndim: 2 }]; return _this; } RepeatVector.prototype.computeOutputShape = function (inputShape) { return [inputShape[0], this.n, inputShape[1]]; }; RepeatVector.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = getExactlyOneTensor(inputs); return repeat(inputs, _this.n); }); }; RepeatVector.prototype.getConfig = function () { var config = { n: this.n, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return RepeatVector; }(Layer)); /** @nocollapse */ RepeatVector.className = 'RepeatVector'; tfc.serialization.registerClass(RepeatVector); var Reshape = /** @class */ (function (_super) { __extends(Reshape, _super); function Reshape(args) { var _this = _super.call(this, args) || this; _this.targetShape = args.targetShape; // Make sure that all unknown dimensions are represented as `null`. for (var i = 0; i < _this.targetShape.length; ++i) { if (_this.isUnknown(_this.targetShape[i])) { _this.targetShape[i] = null; } } return _this; } Reshape.prototype.isUnknown = function (dim) { return dim < 0 || dim == null; }; /** * Finds and replaces a missing dimension in output shape. * * This is a near direct port of the internal Numpy function * `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`. * * @param inputShape: Original shape of the array being reshaped. * @param outputShape: Target shape of the array, with at most a single * `null` or negative number, which indicates an underdetermined dimension * that should be derived from `inputShape` and the known dimensions of * `outputShape`. * @returns: The output shape with `null` replaced with its computed value. * @throws: ValueError: If `inputShape` and `outputShape` do not match.
*/ Reshape.prototype.fixUnknownDimension = function (inputShape, outputShape) { var errorMsg = 'Total size of new array must be unchanged.'; var finalShape = outputShape.slice(); var known = 1; var unknown = null; for (var i = 0; i < finalShape.length; ++i) { var dim = finalShape[i]; if (this.isUnknown(dim)) { if (unknown === null) { unknown = i; } else { throw new ValueError('Can only specify one unknown dimension.'); } } else { known *= dim; } } var originalSize = arrayProd(inputShape); if (unknown !== null) { if (known === 0 || originalSize % known !== 0) { throw new ValueError(errorMsg); } finalShape[unknown] = originalSize / known; } else if (originalSize !== known) { throw new ValueError(errorMsg); } return finalShape; }; Reshape.prototype.computeOutputShape = function (inputShape) { var anyUnknownDims = false; for (var i = 0; i < inputShape.length; ++i) { if (this.isUnknown(inputShape[i])) { anyUnknownDims = true; break; } } if (anyUnknownDims) { return inputShape.slice(0, 1).concat(this.targetShape); } else { return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape)); } }; Reshape.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); var input = getExactlyOneTensor(inputs); var inputShape = input.shape; var outputShape = inputShape.slice(0, 1).concat(_this.fixUnknownDimension(inputShape.slice(1), _this.targetShape)); return tfc.reshape(input, outputShape); }); }; Reshape.prototype.getConfig = function () { var config = { targetShape: this.targetShape, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Reshape; }(Layer)); /** @nocollapse */ Reshape.className = 'Reshape'; tfc.serialization.registerClass(Reshape); var Permute = /** @class */ (function (_super) { __extends(Permute, _super); function Permute(args) { var _this = _super.call(this, args) || this; if (args.dims == null) { throw new Error('Required configuration field `dims` is missing during Permute ' + 'constructor call.'); } if (!Array.isArray(args.dims)) { throw new Error('Permute constructor requires `dims` to be an Array, but received ' + "".concat(args.dims, " instead.")); } // Check the validity of the permutation indices.
var expectedSortedIndices = range(1, args.dims.length + 1); if (!tfc.util.arraysEqual(args.dims.slice().sort(), expectedSortedIndices)) { throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) + ' `dims` must contain consecutive integers starting from 1.'); } _this.dims = args.dims; _this.dimsIncludingBatch = [0].concat(_this.dims); _this.inputSpec = [new InputSpec({ ndim: _this.dims.length + 1 })]; return _this; } Permute.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var outputShape = inputShape.slice(); this.dims.forEach(function (dim, i) { outputShape[i + 1] = inputShape[dim]; }); return outputShape; }; Permute.prototype.call = function (inputs, kwargs) { return tfc.transpose(getExactlyOneTensor(inputs), this.dimsIncludingBatch); }; Permute.prototype.getConfig = function () { var config = { dims: this.dims, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Permute; }(Layer)); /** @nocollapse */ Permute.className = 'Permute'; tfc.serialization.registerClass(Permute); var Masking = /** @class */ (function (_super) { __extends(Masking, _super); function Masking(args) { var _this = _super.call(this, args == null ? {} : args) || this; _this.supportsMasking = true; if (args != null) { _this.maskValue = args.maskValue == null ? 0 : args.maskValue; } else { _this.maskValue = 0; } return _this; } Masking.prototype.computeOutputShape = function (inputShape) { return inputShape; }; Masking.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var config = { maskValue: this.maskValue }; Object.assign(config, baseConfig); return config; }; Masking.prototype.computeMask = function (inputs, mask) { var input = getExactlyOneTensor(inputs); var axis = -1; return tfc.any(tfc.notEqual(input, this.maskValue), axis); }; Masking.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); var input = getExactlyOneTensor(inputs); var axis = -1; var keepDims = true; var booleanMask = tfc.any(tfc.notEqual(input, _this.maskValue), axis, keepDims); var output = tfc.mul(input, tfc.cast(booleanMask, input.dtype)); return output; }); }; return Masking; }(Layer)); /** @nocollapse */ Masking.className = 'Masking'; tfc.serialization.registerClass(Masking); var Embedding = /** @class */ (function (_super) { __extends(Embedding, _super); function Embedding(args) { var _this = _super.call(this, args) || this; _this.embeddings = null; _this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform'; if (args.batchInputShape == null && args.inputShape == null) { // Porting Note: This logic is copied from Layer's constructor, since we // can't do exactly what the Python constructor does for Embedding(). // Specifically, the super constructor can not be called after the // mutation of the `config` argument. 
var batchSize = null; if (args.batchSize != null) { batchSize = args.batchSize; } if (args.inputLength == null) { // Fix super-constructor to what it would have done if // 'config.inputShape' were (None, ) _this.batchInputShape = [batchSize, null]; } else { // Fix super-constructor to what it would have done if // 'config.inputShape' were (config.inputLength, ) _this.batchInputShape = [batchSize].concat(toList(args.inputLength)); } } _this.inputDim = args.inputDim; assertPositiveInteger(_this.inputDim, 'inputDim'); _this.outputDim = args.outputDim; assertPositiveInteger(_this.outputDim, 'outputDim'); _this.embeddingsInitializer = getInitializer(args.embeddingsInitializer || _this.DEFAULT_EMBEDDINGS_INITIALIZER); _this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer); _this.activityRegularizer = getRegularizer(args.activityRegularizer); _this.embeddingsConstraint = getConstraint(args.embeddingsConstraint); _this.maskZero = args.maskZero; _this.supportsMasking = args.maskZero; _this.inputLength = args.inputLength; return _this; } Embedding.prototype.build = function (inputShape) { this.embeddings = this.addWeight('embeddings', [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, true, this.embeddingsConstraint); this.built = true; }; // Override warnOnIncompatibleInputShape because an embedding layer allows // the input to have varying ranks. Embedding.prototype.warnOnIncompatibleInputShape = function (inputShape) { }; Embedding.prototype.computeMask = function (inputs, mask) { var _this = this; return tfc.tidy(function () { if (!_this.maskZero) { return null; } else { inputs = getExactlyOneTensor(inputs); return tfc.notEqual(inputs, tfc.zerosLike(inputs)); } }); }; Embedding.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); if (this.inputLength == null) { return __spreadArray(__spreadArray([], __read(inputShape), false), [this.outputDim], false); } // inputLength can be an array if input is 3D or higher. var inLens = toList(this.inputLength); if (inLens.length !== inputShape.length - 1) { throw new ValueError("\"inputLength\" is ".concat(this.inputLength, ", but the ") + "received input shape is ".concat(inputShape)); } else { var i = 0; for (var k = 0; k < inLens.length; ++k) { var s1 = inLens[k]; var s2 = inputShape[k + 1]; if ((s1 != null) && (s2 != null) && (s1 !== s2)) { throw new ValueError("\"inputLength\" is ".concat(this.inputLength, ", but the ") + "received input shape is ".concat(inputShape)); } else if (s1 == null) { inLens[i] = s2; } i++; } } return __spreadArray(__spreadArray([inputShape[0]], __read(inLens), false), [this.outputDim], false); }; Embedding.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); // Embedding layer accepts only a single input.
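// The lookup below casts the indices to int32, flattens them, gathers the
// matching rows of the [inputDim, outputDim] embeddings matrix, and reshapes
// the result back to the input shape plus a trailing outputDim dimension.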
var input = getExactlyOneTensor(inputs); if (input.dtype !== 'int32') { input = cast$1(input, 'int32'); } var output = gather$1(_this.embeddings.read(), tfc.reshape(input, [input.size])); return tfc.reshape(output, getExactlyOneShape(_this.computeOutputShape(input.shape))); }); }; Embedding.prototype.getConfig = function () { var config = { inputDim: this.inputDim, outputDim: this.outputDim, embeddingsInitializer: serializeInitializer(this.embeddingsInitializer), embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer), activityRegularizer: serializeRegularizer(this.activityRegularizer), embeddingsConstraint: serializeConstraint(this.embeddingsConstraint), maskZero: this.maskZero, inputLength: this.inputLength }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Embedding; }(Layer)); /** @nocollapse */ Embedding.className = 'Embedding'; tfc.serialization.registerClass(Embedding); /** * Generic Merge layer for element-wise merge functions. * * Used to implement `Sum`, `Average`, `Concatenate`, etc. */ var Merge = /** @class */ (function (_super) { __extends(Merge, _super); function Merge(args) { var _this = _super.call(this, args || {}) || this; _this.supportsMasking = true; return _this; } /** * Logic for merging multiple tensors, to be overridden by subclasses. * @param inputs */ Merge.prototype.mergeFunction = function (inputs) { throw new NotImplementedError(); }; /** * Computes the shape of the result of an elementwise operation. * * @param shape1: Shape of the first tensor. * @param shape2: Shape of the second tensor. * @returns Expected output shape when an elementwise operation is carried * out on 2 tensors with shapes `shape1` and `shape2`. * @throws ValueError: If `shape1` and `shape2` are not compatible for * element-wise operations. */ Merge.prototype.computeElementwiseOpOutputShape = function (shape1, shape2) { if (shape1 == null || shape2 == null) { return null; } else if (shape1.length < shape2.length) { return this.computeElementwiseOpOutputShape(shape2, shape1); } else if (shape2.length === 0) { return shape1; } var outputShape = shape1.slice(0, shape1.length - shape2.length); for (var k = 0; k < shape2.length; ++k) { var i = shape1[shape1.length - shape2.length + k]; var j = shape2[k]; if (i == null || j == null || i < 0 || j < 0) { outputShape.push(null); } else if (i === 1) { outputShape.push(j); } else if (j === 1) { outputShape.push(i); } else { if (i !== j) { throw new ValueError('Operands could not be broadcast together with shapes ' + JSON.stringify(shape1) + ' ' + JSON.stringify(shape2)); } outputShape.push(i); } } return outputShape; }; Merge.prototype.build = function (inputShape) { var e_1, _a; // Used purely for shape validation. if (Array.isArray(inputShape) && !Array.isArray(inputShape[0])) { // Make sure that inputShape is an Array of shape. inputShape = [getExactlyOneShape(inputShape)]; } inputShape = inputShape; if (inputShape.length < 2) { throw new ValueError('A merge layer should be called on an Array of at least 2 inputs.' + " Got ".concat(inputShape.length, " input(s).")); } // Make sure that there is at most one unique batch size among the input // shapes. 
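// For example (illustrative): shapes [[null, 4], [8, 4]] yield batchSizes = [8]
// after the null filtering and deduplication below, which is accepted;
// [[4, 2], [8, 2]] would yield [4, 8] and throw.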
var batchSizes = []; try { for (var inputShape_1 = __values(inputShape), inputShape_1_1 = inputShape_1.next(); !inputShape_1_1.done; inputShape_1_1 = inputShape_1.next()) { var shape = inputShape_1_1.value; if (shape != null && shape[0] !== null) { batchSizes.push(shape[0]); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (inputShape_1_1 && !inputShape_1_1.done && (_a = inputShape_1.return)) _a.call(inputShape_1); } finally { if (e_1) throw e_1.error; } } batchSizes = unique(batchSizes); if (batchSizes.length > 1) { throw new ValueError("Can not merge tensors with different batch sizes. " + "Got tensors with shapes: ".concat(JSON.stringify(inputShape), ".")); } var outputShape = inputShape[0] == null ? null : inputShape[0].slice(1); for (var i = 1; i < inputShape.length; ++i) { var shape = inputShape[i] == null ? null : inputShape[i].slice(1); outputShape = this.computeElementwiseOpOutputShape(outputShape, shape); } // If the inputs have different ranks, we have to reshape them to make them // broadcastable. var allRanks = inputShape.map(function (shape) { return shape.length; }); if (inputShape.indexOf(null) === -1 && unique(allRanks).length === 1) { this.reshapeRequired = false; } else { this.reshapeRequired = true; } }; Merge.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var e_2, _a, e_3, _b; inputs = inputs; if (_this.reshapeRequired) { var reshapedInputs = []; var inputDims = inputs.map(function (input) { return input.rank; }); if (inputDims.indexOf(null) === -1) { // If ranks of all inputs are available, we simply expand each of them // at axis=1 until all of them have the same rank. var maxNDim = max(inputDims); try { for (var inputs_1 = __values(inputs), inputs_1_1 = inputs_1.next(); !inputs_1_1.done; inputs_1_1 = inputs_1.next()) { var x = inputs_1_1.value; var xNDim = x.rank; for (var k = 0; k < maxNDim - xNDim; ++k) { x = expandDims$1(x, 1); } reshapedInputs.push(x); } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (inputs_1_1 && !inputs_1_1.done && (_a = inputs_1.return)) _a.call(inputs_1); } finally { if (e_2) throw e_2.error; } } return _this.mergeFunction(reshapedInputs); } else { // Transpose all inputs so that batch size is the last dimension. // [batchSize, dim1, dim2, ...] -> [dim1, dim2, ..., batchSize] var transposed = false; try { for (var inputs_2 = __values(inputs), inputs_2_1 = inputs_2.next(); !inputs_2_1.done; inputs_2_1 = inputs_2.next()) { var x = inputs_2_1.value; var xNDim = x.rank; if (xNDim == null) { var xShape = x.shape; var batchSize = xShape[0]; var newShape = xShape.slice(1).concat([batchSize]); var xTransposed = tfc__namespace.reshape(x, [batchSize].concat(arrayProd(xShape.slice(1)))); xTransposed = tfc__namespace.transpose(xTransposed, [1, 0]); xTransposed = tfc__namespace.reshape(xTransposed, newShape); reshapedInputs.push(xTransposed); transposed = true; } else if (xNDim > 1) { var dims = range(1, xNDim).concat([0]); reshapedInputs.push(tfc__namespace.transpose(x, dims)); transposed = true; } else { // We don't transpose inputs if they are 1D vectors or scalars. reshapedInputs.push(x); } } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (inputs_2_1 && !inputs_2_1.done && (_b = inputs_2.return)) _b.call(inputs_2); } finally { if (e_3) throw e_3.error; } } var y = _this.mergeFunction(reshapedInputs); var yNDim = y.rank; if (transposed) { // If inputs have been transposed, we have to transpose the output // too. 
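// [dim1, dim2, ..., batchSize] -> [batchSize, dim1, dim2, ...], i.e. the
// inverse of the permutation applied to the inputs above.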
if (yNDim == null) { var yShape = y.shape; var yNDim_1 = yShape.length; var batchSize = yShape[yNDim_1 - 1]; var newShape = [batchSize].concat(yShape.slice(0, yShape.length - 1)); y = tfc__namespace.reshape(tfc__namespace.transpose(tfc__namespace.reshape(y, [-1, batchSize]), [1, 0]), newShape); } else if (yNDim > 1) { var dims = [yNDim - 1].concat(range(0, yNDim - 1)); y = tfc__namespace.transpose(y, dims); } } return y; } } else { return _this.mergeFunction(inputs); } }); }; Merge.prototype.computeOutputShape = function (inputShape) { var e_4, _a; inputShape = inputShape; var outputShape; if (inputShape[0] == null) { outputShape = null; } else { outputShape = inputShape[0].slice(1); } for (var i = 1; i < inputShape.length; ++i) { var shape = inputShape[i] == null ? null : inputShape[i].slice(1); outputShape = this.computeElementwiseOpOutputShape(outputShape, shape); } var batchSizes = []; try { for (var inputShape_2 = __values(inputShape), inputShape_2_1 = inputShape_2.next(); !inputShape_2_1.done; inputShape_2_1 = inputShape_2.next()) { var shape = inputShape_2_1.value; if (shape != null && shape[0] !== null) { batchSizes.push(shape[0]); } } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (inputShape_2_1 && !inputShape_2_1.done && (_a = inputShape_2.return)) _a.call(inputShape_2); } finally { if (e_4) throw e_4.error; } } batchSizes = unique(batchSizes); if (batchSizes.length === 1) { outputShape = batchSizes.concat(outputShape); } else { outputShape = [null].concat(outputShape); } return outputShape; }; Merge.prototype.computeMask = function (inputs, mask) { return tfc__namespace.tidy(function () { if (mask == null) { return null; } if (!Array.isArray(mask)) { throw new ValueError('`mask` should be an Array'); } if (!Array.isArray(inputs)) { throw new ValueError('`inputs` should be an Array'); } if (mask.length !== inputs.length) { throw new ValueError("The Array 'inputs' and 'mask' are expected to have the same " + "length, but have different lengths " + "(".concat(inputs.length, " vs ").concat(mask.length, ")")); } if (mask.every(function (m) { return m == null; })) { return null; } mask = mask.map(function (m) { return m == null ? 
m : tfc__namespace.expandDims(m, 0); }); // AND together all non-null masks; a null entry means the corresponding input is unmasked. var output = mask[0]; for (var i = 1; i < mask.length; ++i) { if (mask[i] == null) { continue; } output = output == null ? mask[i] : tfc__namespace.logicalAnd(output, mask[i]); } return output; }); }; return Merge; }(Layer)); var Add = /** @class */ (function (_super) { __extends(Add, _super); function Add(args) { return _super.call(this, args) || this; } Add.prototype.mergeFunction = function (inputs) { return tfc.tidy(function () { var output = inputs[0].clone(); for (var i = 1; i < inputs.length; ++i) { output = tfc__namespace.add(output, inputs[i]); } return output; }); }; return Add; }(Merge)); /** @nocollapse */ Add.className = 'Add'; tfc.serialization.registerClass(Add); var Multiply = /** @class */ (function (_super) { __extends(Multiply, _super); function Multiply(args) { return _super.call(this, args) || this; } Multiply.prototype.mergeFunction = function (inputs) { return tfc.tidy(function () { var output = inputs[0].clone(); for (var i = 1; i < inputs.length; ++i) { output = tfc__namespace.mul(output, inputs[i]); } return output; }); }; return Multiply; }(Merge)); /** @nocollapse */ Multiply.className = 'Multiply'; tfc.serialization.registerClass(Multiply); var Average = /** @class */ (function (_super) { __extends(Average, _super); function Average(args) { return _super.call(this, args) || this; } Average.prototype.mergeFunction = function (inputs) { return tfc.tidy(function () { var output = inputs[0].clone(); for (var i = 1; i < inputs.length; ++i) { output = tfc__namespace.add(output, inputs[i]); } return tfc__namespace.mul(1 / inputs.length, output); }); }; return Average; }(Merge)); /** @nocollapse */ Average.className = 'Average'; tfc.serialization.registerClass(Average); var Maximum = /** @class */ (function (_super) { __extends(Maximum, _super); function Maximum(args) { return _super.call(this, args) || this; } Maximum.prototype.mergeFunction = function (inputs) { return tfc.tidy(function () { var output = inputs[0]; for (var i = 1; i < inputs.length; ++i) { output = tfc__namespace.maximum(output, inputs[i]); } return output; }); }; return Maximum; }(Merge)); /** @nocollapse */ Maximum.className = 'Maximum'; tfc.serialization.registerClass(Maximum); var Minimum = /** @class */ (function (_super) { __extends(Minimum, _super); function Minimum(args) { return _super.call(this, args) || this; } Minimum.prototype.mergeFunction = function (inputs) { return tfc.tidy(function () { var output = inputs[0]; for (var i = 1; i < inputs.length; ++i) { output = tfc__namespace.minimum(output, inputs[i]); } return output; }); }; return Minimum; }(Merge)); /** @nocollapse */ Minimum.className = 'Minimum'; tfc.serialization.registerClass(Minimum); var Concatenate = /** @class */ (function (_super) { __extends(Concatenate, _super); function Concatenate(args) { var _this = _super.call(this, args) || this; _this.DEFAULT_AXIS = -1; if (args == null) { args = {}; } _this.axis = args.axis == null ? _this.DEFAULT_AXIS : args.axis; _this.supportsMasking = true; _this.reshapeRequired = false; return _this; } Concatenate.prototype.build = function (inputShape) { var e_5, _a, e_6, _b; // Used purely for shape validation.
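// Each input shape with the concat axis removed must match all the others;
// e.g. (illustrative) [null, 3, 4] and [null, 3, 7] are compatible along
// axis -1 because both reduce to [null, 3].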
if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0])) || inputShape.length === 1) { throw new ValueError('A `Concatenate` layer should be called on a list of at least 2 ' + 'inputs'); } inputShape = inputShape; var allNoneShape = true; try { for (var inputShape_3 = __values(inputShape), inputShape_3_1 = inputShape_3.next(); !inputShape_3_1.done; inputShape_3_1 = inputShape_3.next()) { var shape = inputShape_3_1.value; if (shape != null) { allNoneShape = false; break; } } } catch (e_5_1) { e_5 = { error: e_5_1 }; } finally { try { if (inputShape_3_1 && !inputShape_3_1.done && (_a = inputShape_3.return)) _a.call(inputShape_3); } finally { if (e_5) throw e_5.error; } } if (allNoneShape) { return; } var shapeSet = []; for (var i = 0; i < inputShape.length; ++i) { var shapeWithoutConcatAxis = inputShape[i].slice(); shapeWithoutConcatAxis.splice(this.axis, 1); var exists = false; try { for (var shapeSet_1 = (e_6 = void 0, __values(shapeSet)), shapeSet_1_1 = shapeSet_1.next(); !shapeSet_1_1.done; shapeSet_1_1 = shapeSet_1.next()) { var shape = shapeSet_1_1.value; if (tfc.util.arraysEqual(shape, shapeWithoutConcatAxis)) { exists = true; break; } } } catch (e_6_1) { e_6 = { error: e_6_1 }; } finally { try { if (shapeSet_1_1 && !shapeSet_1_1.done && (_b = shapeSet_1.return)) _b.call(shapeSet_1); } finally { if (e_6) throw e_6.error; } } if (!exists) { shapeSet.push(shapeWithoutConcatAxis); } } if (shapeSet.length > 1) { throw new ValueError('A `Concatenate` layer requires inputs with matching shapes ' + 'except for the concat axis. Got input shapes: ' + JSON.stringify(inputShape)); } }; Concatenate.prototype.mergeFunction = function (inputs) { var _this = this; return tfc.tidy(function () { return concatenate$1(inputs, _this.axis); }); }; Concatenate.prototype.computeOutputShape = function (inputShape) { var e_7, _a; if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0]))) { throw new ValueError('A `Concatenate` layer should be called on a list of inputs.'); } var inputShapes = inputShape; var outputShape = inputShapes[0].slice(); var axis = this.axis < 0 ? outputShape.length + this.axis : this.axis; try { // Porting Note: the line above is because TypeScript doesn't support // negative indices. for (var _b = __values(inputShapes.slice(1)), _c = _b.next(); !_c.done; _c = _b.next()) { var shape = _c.value; if (outputShape[axis] == null || shape[axis] == null) { outputShape[axis] = null; break; } outputShape[axis] += shape[axis]; } } catch (e_7_1) { e_7 = { error: e_7_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_7) throw e_7.error; } } return outputShape; }; Concatenate.prototype.computeMask = function (inputs, mask) { var _this = this; if (mask == null) { return null; } if (!Array.isArray(mask)) { throw new ValueError('`mask` should be an array for Concatenate'); } if (!Array.isArray(inputs)) { throw new ValueError('`inputs` should be an array for Concatenate'); } if (mask.length !== inputs.length) { throw new ValueError("Mismatch in the length of mask (".concat(mask.length, ") ") + "and the length of inputs (".concat(inputs.length, ")")); } return tfc__namespace.tidy(function () { var allNullMasks = true; mask.forEach(function (m) { if (m != null) { allNullMasks = false; return; } }); if (allNullMasks) { return null; } var outputMasks = []; for (var i = 0; i < inputs.length; ++i) { if (mask[i] == null) { // Input is unmasked. Append all 1's to masks.
outputMasks.push(tfc__namespace.cast(tfc__namespace.onesLike(inputs[i]), 'bool')); } else if (mask[i].rank < inputs[i].rank) { // Mask is smaller than the input, expand it. outputMasks.push(tfc__namespace.expandDims(mask[i], -1)); } else { outputMasks.push(mask[i]); } } var concatenatedMasks = tfc__namespace.concat(outputMasks, _this.axis); return tfc__namespace.all(concatenatedMasks, -1, false); }); }; Concatenate.prototype.getConfig = function () { var config = { 'axis': this.axis, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Concatenate; }(Merge)); /** @nocollapse */ Concatenate.className = 'Concatenate'; tfc.serialization.registerClass(Concatenate); /** * Interprets a potentially negative axis index. * * For example, given axis = -1 and dim = 3, this function will return 2. * * @param axis The axis index, may be a positive, zero or negative integer. * @param dim Total number of dimensions, a positive integer. * @returns A non-negative axis index equivalent to the input `axis`. */ function interpretAxis(axis, dim) { while (axis < 0) { axis += dim; } return axis; } function batchDot(x, y, axes) { if (x.shape.length > 3 || y.shape.length > 3) { throw new NotImplementedError('batchDot is not implemented for tensors of 4D or higher rank yet'); } tfc__namespace.util.assert(x.shape.length >= 2, function () { return "batchDot requires the rank of x to be >= 2, " + "but got ".concat(x.shape.length); }); tfc__namespace.util.assert(y.shape.length >= 2, function () { return "batchDot requires the rank of y to be >= 2, " + "but got ".concat(y.shape.length); }); if (typeof axes === 'number') { axes = [axes, axes]; } if (x.dtype === 'complex64' || y.dtype === 'complex64') { throw new NotImplementedError('batchDot is not implemented for complex64-type Tensors yet.'); } var xNDim = x.shape.length; var yNDim = y.shape.length; if (axes == null) { // Behave like batchMatmul by default. axes = [xNDim - 1, yNDim - 2]; } var axesArray = axes; return tfc__namespace.tidy(function () { var diff; if (xNDim > yNDim) { diff = xNDim - yNDim; var diffShape = []; for (var i = 0; i < diff; ++i) { diffShape.push(1); } y = tfc__namespace.reshape(y, y.shape.concat(diffShape)); } else if (yNDim > xNDim) { diff = yNDim - xNDim; var diffShape = []; for (var i = 0; i < diff; ++i) { diffShape.push(1); } x = tfc__namespace.reshape(x, x.shape.concat(diffShape)); } else { diff = 0; } var out; if (x.shape.length === 2 && y.shape.length === 2) { if (axesArray[0] === axesArray[1]) { out = tfc__namespace.sum(tfc__namespace.mul(x, y), axesArray[0]); } else { out = tfc__namespace.sum(tfc__namespace.mul(tfc__namespace.transpose(x, [1, 0]), y), axesArray[1]); } } else { var adjX = axesArray[0] !== x.shape.length - 1; var adjY = axesArray[1] === y.shape.length - 1; out = tfc__namespace.matMul(x, y, adjX, adjY); } if (diff > 0) { var idx = void 0; if (xNDim > yNDim) { idx = xNDim + yNDim - 3; } else { idx = xNDim - 1; } var squeezeAxes = []; for (var i = idx; i < idx + diff; ++i) { squeezeAxes.push(i); } out = tfc__namespace.squeeze(out, squeezeAxes); } if (out.shape.length === 1) { out = tfc__namespace.expandDims(out, 1); } return out; }); } var Dot = /** @class */ (function (_super) { __extends(Dot, _super); function Dot(args) { var _this = _super.call(this, args) || this; _this.axes = args.axes; _this.normalize = args.normalize == null ?
false : args.normalize; _this.supportsMasking = true; _this.reshapeRequired = false; return _this; } Dot.prototype.build = function (inputShape) { tfc__namespace.util.assert(Array.isArray(inputShape) && inputShape.length === 2 && Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), function () { return 'A `Dot` layer should be called on a list of exactly 2 inputs.'; }); var shape1 = inputShape[0]; var shape2 = inputShape[1]; if (shape1.length > 3 || shape2.length > 3) { throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.'); } var axes = this.interpretAxes(shape1, shape2); if (shape1[axes[0]] !== shape2[axes[1]]) { throw new ValueError("Dimension incompatibility: " + "".concat(shape1[axes[0]], " !== ").concat(shape2[axes[1]])); } }; Dot.prototype.mergeFunction = function (inputs) { if (inputs.length !== 2) { throw new ValueError('A `Dot` layer must be called on exactly 2 inputs, ' + "but received ".concat(inputs.length, " input(s).")); } var x1 = inputs[0]; var x2 = inputs[1]; var axes; if (!Array.isArray(this.axes)) { axes = [ interpretAxis(this.axes, x1.shape.length), interpretAxis(this.axes, x2.shape.length) ]; } else { axes = this.axes.map(function (axis, i) { return interpretAxis(axis, inputs[i].shape.length); }); } if (this.normalize) { x1 = l2Normalize(x1, axes[0]); x2 = l2Normalize(x2, axes[1]); } return batchDot(x1, x2, axes); }; Dot.prototype.interpretAxes = function (shape1, shape2) { var axes; if (!Array.isArray(this.axes)) { // `this.axes` is a single integer. axes = [ interpretAxis(this.axes, shape1.length), interpretAxis(this.axes, shape2.length) ]; } else { // `this.axes` is an Array of integers. axes = this.axes; } return axes; }; Dot.prototype.computeOutputShape = function (inputShape) { tfc__namespace.util.assert(Array.isArray(inputShape) && inputShape.length === 2 && Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), function () { return 'A `Dot` layer should be called on a list of exactly 2 inputs.'; }); var shape1 = inputShape[0].slice(); var shape2 = inputShape[1].slice(); if (shape1.length > 3 || shape2.length > 3) { throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.'); } var axes = this.interpretAxes(shape1, shape2); shape1.splice(axes[0], 1); shape2.splice(axes[1], 1); shape2.splice(0, 1); var outputShape = shape1.concat(shape2); if (outputShape.length === 1) { outputShape.push(1); } return outputShape; }; Dot.prototype.computeMask = function (inputs, mask) { return null; }; Dot.prototype.getConfig = function () { var config = { 'axes': this.axes, 'normalize': this.normalize }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Dot; }(Merge)); /** @nocollapse */ Dot.className = 'Dot'; tfc.serialization.registerClass(Dot); // TODO(cais): Add functional interfaces for the merge layers. 
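/**
 * Usage sketch (illustrative only; assumes the `tf.input`, `tf.layers` and
 * `tf.model` factory functions this library exports elsewhere): merging two
 * symbolic tensors with the merge layers defined above.
 *
 * ```js
 * const a = tf.input({shape: [4]});
 * const b = tf.input({shape: [4]});
 * // Element-wise sum; average(), maximum(), minimum(), multiply() and
 * // concatenate() follow the same pattern.
 * const summed = tf.layers.add().apply([a, b]);
 * const model = tf.model({inputs: [a, b], outputs: summed});
 * model.predict([tf.ones([1, 4]), tf.ones([1, 4])]).print(); // [[2, 2, 2, 2]]
 * ```
 */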
var GaussianNoise = /** @class */ (function (_super) { __extends(GaussianNoise, _super); function GaussianNoise(args) { var _this = _super.call(this, args) || this; _this.supportsMasking = true; _this.stddev = args.stddev; return _this; } GaussianNoise.prototype.computeOutputShape = function (inputShape) { return inputShape; }; GaussianNoise.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var config = { stddev: this.stddev }; Object.assign(config, baseConfig); return config; }; GaussianNoise.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); var input = getExactlyOneTensor(inputs); var noised = function () { return tfc.add(randomNormal$1(input.shape, 0, _this.stddev), input); }; var output = inTrainPhase(noised, function () { return input; }, kwargs['training'] || false); return output; }); }; return GaussianNoise; }(Layer)); /** @nocollapse */ GaussianNoise.className = 'GaussianNoise'; tfc.serialization.registerClass(GaussianNoise); var GaussianDropout = /** @class */ (function (_super) { __extends(GaussianDropout, _super); function GaussianDropout(args) { var _this = _super.call(this, args) || this; _this.supportsMasking = true; _this.rate = args.rate; return _this; } GaussianDropout.prototype.computeOutputShape = function (inputShape) { return inputShape; }; GaussianDropout.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var config = { rate: this.rate }; Object.assign(config, baseConfig); return config; }; GaussianDropout.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); var input = getExactlyOneTensor(inputs); if (_this.rate > 0 && _this.rate < 1) { var noised = function () { var stddev = Math.sqrt(_this.rate / (1 - _this.rate)); return tfc.mul(input, randomNormal$1(input.shape, 1, stddev)); }; return inTrainPhase(noised, function () { return input; }, kwargs['training'] || false); } return input; }); }; return GaussianDropout; }(Layer)); /** @nocollapse */ GaussianDropout.className = 'GaussianDropout'; tfc.serialization.registerClass(GaussianDropout); /** * Applies Alpha Dropout to the input. * * As it is a regularization layer, it is only active at training time. * * Alpha Dropout is a `Dropout` variant that keeps the mean and variance of * its inputs at their original values, in order to preserve the * self-normalizing property even after this dropout. * Alpha Dropout pairs well with Scaled Exponential Linear Units (SELU), * because it randomly sets activations to the negative saturation value. * * Arguments: * - `rate`: float, drop probability (as with `Dropout`). * The multiplicative noise will have * standard deviation `sqrt(rate / (1 - rate))`. * - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the * shape for randomly generated keep/drop flags. * * Input shape: * Arbitrary. Use the keyword argument `inputShape` * (tuple of integers, does not include the samples axis) * when using this layer as the first layer in a model. * * Output shape: * Same shape as input.
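 *
 * Mechanism (as implemented below): kept activations pass through unchanged,
 * dropped activations are set to `alphaP = -alpha * scale` (the SELU negative
 * saturation value), and the result is then transformed as `a * x + b` with
 * `a = ((1 - rate) * (1 + rate * alphaP^2))^(-1/2)` and
 * `b = -a * alphaP * rate`, which approximately restores the original mean
 * and variance.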
* * References: * - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) */ var AlphaDropout = /** @class */ (function (_super) { __extends(AlphaDropout, _super); function AlphaDropout(args) { var _this = _super.call(this, args) || this; _this.supportsMasking = true; _this.rate = args.rate; _this.noiseShape = args.noiseShape; return _this; } AlphaDropout.prototype._getNoiseShape = function (inputs) { return this.noiseShape || getExactlyOneTensor(inputs).shape; }; AlphaDropout.prototype.computeOutputShape = function (inputShape) { return inputShape; }; AlphaDropout.prototype.getConfig = function () { var baseConfig = _super.prototype.getConfig.call(this); var config = { rate: this.rate }; Object.assign(config, baseConfig); return config; }; AlphaDropout.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { if (_this.rate < 1 && _this.rate > 0) { var noiseShape_1 = _this._getNoiseShape(inputs); var droppedInputs = function () { var input = getExactlyOneTensor(inputs); var alpha = 1.6732632423543772848170429916717; var scale = 1.0507009873554804934193349852946; var alphaP = -alpha * scale; var keptIdx = tfc.greaterEqual(tfc.randomUniform(noiseShape_1), _this.rate); keptIdx = cast$1(keptIdx, 'float32'); // get default dtype. // Get affine transformation params. var a = Math.pow(((1 - _this.rate) * (1 + _this.rate * Math.pow(alphaP, 2))), -0.5); var b = -a * alphaP * _this.rate; // Apply mask: kept units pass through; dropped units are set to alphaP. var x = tfc.add(tfc.mul(input, keptIdx), tfc.mul(tfc.sub(1, keptIdx), alphaP)); return tfc.add(tfc.mul(x, a), b); }; return inTrainPhase(droppedInputs, function () { return getExactlyOneTensor(inputs); }, kwargs['training'] || false); } return inputs; }); }; return AlphaDropout; }(Layer)); /** @nocollapse */ AlphaDropout.className = 'AlphaDropout'; tfc.serialization.registerClass(AlphaDropout); /** * Applies batch normalization on x given mean, var, beta and gamma. * * I.e. returns: * `output = (x - mean) / sqrt(var + epsilon) * gamma + beta` * * @param x Input tensor. * @param mean Mean of batch. * @param variance Variance of batch. * @param beta Tensor with which to center the input. * @param gamma Tensor by which to scale the input. * @param epsilon Fuzz factor. * @returns The result of the batch normalization. */ function batchNormalization$1(x, mean, variance, beta, gamma, epsilon) { if (epsilon === void 0) { epsilon = 1e-3; } var out; if (x.rank === 2) { out = tfc__namespace.batchNorm2d(x, mean, variance, beta, gamma, epsilon); } else if (x.rank === 3) { // TODO(cais): Check rank; give proper error message. out = tfc__namespace.batchNorm3d(x, mean, variance, beta, gamma, epsilon); } else if (x.rank === 4) { out = tfc__namespace.batchNorm4d(x, mean, variance, beta, gamma, epsilon); } else { throw new NotImplementedError("batchNormalization is not implemented for array of rank ".concat(x.rank, " ") + "yet"); } return out; } /** * Non-broadcasting batch normalization for use in training (not inference). * * The input is normalized to zero mean and unit variance along the * `reductionAxes`, followed by scaling with `gamma` and shifted by `beta`. * The result of that is returned as the first element * of the returned `Array`. The other two elements are the mean and variance, * respectively. * * @param x Input tensor to be normalized. * @param gamma Tensor by which to scale the input. * @param beta Tensor by which to center the input. * @param reductionAxes Axes over which to normalize. * @param epsilon Fuzz factor.
* @returns An `Array` of three `Tensors`: * [normalized tensor, mean of input, variance of input]. */ function regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon) { if (epsilon === void 0) { epsilon = 1e-3; } return tfc.tidy(function () { var meanAndVariance = tfc__namespace.moments(x, reductionAxes); var mean = meanAndVariance.mean; var variance = meanAndVariance.variance; var normed = batchNormalization$1(x, mean, variance, beta, gamma, epsilon); return [normed, mean, variance]; }); } /** * Broadcasting batch normalization for use in training (not inference). * * The input is normalized to zero mean and unit variance along the * `reductionAxes`, followed by scaling with `gamma` and shifted by `beta`. * The result of that is returned as the first element * of the returned `Array`. The other two elements are the mean and variance, * respectively. * * @param x Input tensor to be normalized. * @param gamma Tensor by which to scale the input. * @param beta Tensor by which to center the input. * @param reductionAxes Axes over which to normalize. * @param epsilon Fuzz factor. * @returns An `Array` of three `Tensors`: * [normalized tensor, mean of input, variance of input]. */ function broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon) { if (epsilon === void 0) { epsilon = 1e-3; } return tfc.tidy(function () { var e_1, _a; var meanAndVariance = tfc__namespace.moments(x, reductionAxes); var mean = meanAndVariance.mean; var variance = meanAndVariance.variance; var targetShape = []; try { for (var _b = __values(range(0, x.rank)), _c = _b.next(); !_c.done; _c = _b.next()) { var axis = _c.value; if (reductionAxes.indexOf(axis) !== -1) { targetShape.push(1); } else { targetShape.push(x.shape[axis]); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_1) throw e_1.error; } } var broadcastMean = tfc.reshape(mean, targetShape); var broadcastVariance = tfc.reshape(variance, targetShape); var broadcastGamma = gamma == null ? null : tfc.reshape(gamma, targetShape); var broadcastBeta = beta == null ? null : tfc.reshape(beta, targetShape); var normed = batchNormalization$1(x, broadcastMean, broadcastVariance, broadcastBeta, broadcastGamma, epsilon); return [normed, mean, variance]; }); } /** * Batch normalization for use in training (not inference). * * @param x Input tensor to be normalized. * @param gamma Tensor by which to scale the input. * @param beta Tensor by which to center the input. * @param reductionAxes Axes over which to normalize. * @param epsilon Fuzz factor. * @returns An `Array` of three `Tensors`: * [normalized tensor, mean of input, variance of input]. */ function normalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon) { if (epsilon === void 0) { epsilon = 1e-3; } if (tfc.util.arraysEqual(reductionAxes.slice().sort(), range(0, x.rank - 1))) { return regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon); } else { return broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon); } } var BatchNormalization = /** @class */ (function (_super) { __extends(BatchNormalization, _super); function BatchNormalization(args) { var _this = this; if (args == null) { args = {}; } _this = _super.call(this, args) || this; _this.supportsMasking = true; _this.axis = args.axis == null ? -1 : args.axis; _this.momentum = args.momentum == null ? 0.99 : args.momentum; _this.epsilon = args.epsilon == null ? 
1e-3 : args.epsilon; _this.center = args.center == null ? true : args.center; _this.scale = args.scale == null ? true : args.scale; _this.betaInitializer = getInitializer(args.betaInitializer || 'zeros'); _this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones'); _this.movingMeanInitializer = getInitializer(args.movingMeanInitializer || 'zeros'); _this.movingVarianceInitializer = getInitializer(args.movingVarianceInitializer || 'ones'); _this.betaConstraint = getConstraint(args.betaConstraint); _this.gammaConstraint = getConstraint(args.gammaConstraint); _this.betaRegularizer = getRegularizer(args.betaRegularizer); _this.gammaRegularizer = getRegularizer(args.gammaRegularizer); return _this; } BatchNormalization.prototype.build = function (inputShape) { var _a; inputShape = getExactlyOneShape(inputShape); var axis = this.axis >= 0 ? this.axis : (this.axis + inputShape.length); var dim = inputShape[axis]; if (dim == null) { throw new ValueError("Axis ".concat(axis, " of input tensor should have a defined dimension but ") + "the layer received an input with shape " + "".concat(JSON.stringify(inputShape), ".")); } this.inputSpec = [new InputSpec({ ndim: inputShape.length, axes: (_a = {}, _a[axis] = dim, _a) })]; var shape = [dim]; if (this.scale) { this.gamma = this.addWeight('gamma', shape, null, this.gammaInitializer, this.gammaRegularizer, true, this.gammaConstraint); } if (this.center) { this.beta = this.addWeight('beta', shape, null, this.betaInitializer, this.betaRegularizer, true, this.betaConstraint); } this.movingMean = this.addWeight('moving_mean', shape, null, this.movingMeanInitializer, null, false); this.movingVariance = this.addWeight('moving_variance', shape, null, this.movingVarianceInitializer, null, false); this.built = true; }; BatchNormalization.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var training = kwargs['training'] == null ? false : kwargs['training']; var input = getExactlyOneTensor(inputs); var inputShape = input.shape; var ndim = inputShape.length; var reductionAxes = range(0, ndim); var axis = _this.axis >= 0 ? _this.axis : (_this.axis + ndim); reductionAxes.splice(axis, 1); var broadcastShape = pyListRepeat(1, ndim); broadcastShape[axis] = inputShape[axis]; var sortedReductionAxes = reductionAxes.slice(); sortedReductionAxes.sort(); var needsBroadcasting = !tfc.util.arraysEqual(sortedReductionAxes, range(0, ndim).slice(0, ndim - 1)); var normalizeInference = function () { if (needsBroadcasting) { var broadcastMovingMean = tfc.reshape(_this.movingMean.read(), broadcastShape); var broadcastMovingVariance = tfc.reshape(_this.movingVariance.read(), broadcastShape); var broadcastBeta = _this.center ? tfc.reshape(_this.beta.read(), broadcastShape) : null; var broadcastGamma = _this.scale ? tfc.reshape(_this.gamma.read(), broadcastShape) : null; return batchNormalization$1(input, broadcastMovingMean, broadcastMovingVariance, broadcastBeta, broadcastGamma, _this.epsilon); } else { return batchNormalization$1(input, _this.movingMean.read(), _this.movingVariance.read(), _this.beta == null ? null : _this.beta.read(), _this.gamma == null ? 
null : _this.gamma.read(), _this.epsilon); } }; if (!training) { return normalizeInference(); } var _a = __read(normalizeBatchInTraining(input, _this.gamma.read(), _this.beta.read(), reductionAxes, _this.epsilon), 3), normedTraining = _a[0], mean = _a[1], variance = _a[2]; var doMovingAverage = function (variable, value, momentum) { tfc__namespace.tidy(function () { var decay = 1 - momentum; var origValue = variable.read(); var updateDelta = tfc__namespace.mul(tfc__namespace.sub(origValue, value), decay); variable.write(tfc__namespace.sub(origValue, updateDelta)); }); }; // Perform updates to moving mean and moving variance for training. // Porting Note: In PyKeras, these updates to `movingMean` and // `movingVariance` are done as a deferred Graph, added to the `Layer`'s // `update`s using the `add_update()` method. Here we do it imperatively // and encapsulate the updates in a function that is invoked // immediately. var updateMovingMeanAndVariance = function () { doMovingAverage(_this.movingMean, mean, _this.momentum); doMovingAverage(_this.movingVariance, variance, _this.momentum); }; updateMovingMeanAndVariance(); return normedTraining; }); }; BatchNormalization.prototype.getConfig = function () { var config = { axis: this.axis, momentum: this.momentum, epsilon: this.epsilon, center: this.center, scale: this.scale, betaInitializer: serializeInitializer(this.betaInitializer), gammaInitializer: serializeInitializer(this.gammaInitializer), movingMeanInitializer: serializeInitializer(this.movingMeanInitializer), movingVarianceInitializer: serializeInitializer(this.movingVarianceInitializer), betaRegularizer: serializeRegularizer(this.betaRegularizer), gammaRegularizer: serializeRegularizer(this.gammaRegularizer), betaConstraint: serializeConstraint(this.betaConstraint), gammaConstraint: serializeConstraint(this.gammaConstraint) }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return BatchNormalization; }(Layer)); /** @nocollapse */ BatchNormalization.className = 'BatchNormalization'; tfc.serialization.registerClass(BatchNormalization); var LayerNormalization = /** @class */ (function (_super) { __extends(LayerNormalization, _super); function LayerNormalization(args) { var e_2, _a; var _this = this; if (args == null) { args = {}; } _this = _super.call(this, args) || this; _this.axis = args.axis == null ? -1 : args.axis; if (typeof _this.axis === 'number') { if (!Number.isInteger(_this.axis)) { throw new Error("Expected axis to be an integer, but received ".concat(_this.axis)); } } else if (Array.isArray(_this.axis)) { try { for (var _b = __values(_this.axis), _c = _b.next(); !_c.done; _c = _b.next()) { var axis = _c.value; if (!Number.isInteger(axis)) { throw new Error("Expected axis to be an array of integers, " + "but received ".concat(JSON.stringify(_this.axis))); } } } catch (e_2_1) { e_2 = { error: e_2_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_2) throw e_2.error; } } } else { throw new Error("Expected axis to be an integer or an array of integers, " + "but received ".concat(JSON.stringify(_this.axis))); } _this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon; _this.center = args.center == null ? true : args.center; _this.scale = args.scale == null ?
true : args.scale; _this.betaInitializer = getInitializer(args.betaInitializer || 'zeros'); _this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones'); _this.betaRegularizer = getRegularizer(args.betaRegularizer); _this.gammaRegularizer = getRegularizer(args.gammaRegularizer); _this.supportsMasking = true; return _this; } LayerNormalization.prototype.build = function (inputShape) { var e_3, _a; inputShape = getExactlyOneShape(inputShape); var nDims = inputShape.length; // Convert axis to array and resolve negatives. if (typeof this.axis === 'number') { this.axis = [this.axis]; } for (var i = 0; i < this.axis.length; ++i) { if (this.axis[i] < 0) { this.axis[i] += nDims; } } try { // Further validate axes. for (var _b = __values(this.axis), _c = _b.next(); !_c.done; _c = _b.next()) { var axis = _c.value; if (axis < 0 || axis >= nDims) { throw new Error("Invalid axis: ".concat(axis)); } } } catch (e_3_1) { e_3 = { error: e_3_1 }; } finally { try { if (_c && !_c.done && (_a = _b.return)) _a.call(_b); } finally { if (e_3) throw e_3.error; } } if (this.axis.length !== unique(this.axis).length) { throw new Error("Found duplicate axes in: ".concat(this.axis)); } var paramShape = this.axis.map(function (axis) { return inputShape[axis]; }); var trainable = true; if (this.scale) { this.gamma = this.addWeight('gamma', paramShape, 'float32', this.gammaInitializer, this.gammaRegularizer, trainable); } else { this.gamma = null; } if (this.center) { this.beta = this.addWeight('beta', paramShape, 'float32', this.betaInitializer, this.betaRegularizer, trainable); } else { this.beta = null; } this.built = true; }; LayerNormalization.prototype.call = function (inputs, kwargs) { var _this = this; var input = getExactlyOneTensor(inputs); var inputShape = input.shape; var nDims = inputShape.length; return tfc.tidy(function () { var e_4, _a; var keepDims = true; var _b = tfc.moments(input, _this.axis, keepDims), mean = _b.mean, variance = _b.variance; var broadcastShape = pyListRepeat(1, nDims); try { for (var _c = __values(_this.axis), _d = _c.next(); !_d.done; _d = _c.next()) { var dim = _d.value; broadcastShape[dim] = inputShape[dim]; } } catch (e_4_1) { e_4 = { error: e_4_1 }; } finally { try { if (_d && !_d.done && (_a = _c.return)) _a.call(_c); } finally { if (e_4) throw e_4.error; } } var broadcast = function (v) { if (v != null && v.shape.length !== nDims) { return tfc__namespace.reshape(v, broadcastShape); } else { return v; } }; var scale = _this.scale ? broadcast(_this.gamma.read()) : null; var offset = _this.center ? broadcast(_this.beta.read()) : null; // TODO(https://github.com/tensorflow/tfjs/issues/2120): The tiling below // is a workaround for the limitation that core's batchNorm2d/3d/4d ops // don't support broadcasting in their gradients. In addition, the tiling is // necessary to ensure correctness on the browser CPU backend regardless // of forward or backward computation. Remove this workaround once the // limitation is addressed.
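// e.g. (illustrative) for an input of shape [2, 3, 4] normalized over the
// last axis: momentsTiling = [1, 1, 4] expands the keepDims moments from
// [2, 3, 1] to [2, 3, 4], and scaleOffsetTiling = [2, 3, 1] expands
// gamma/beta from [1, 1, 4] to [2, 3, 4].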
var momentsTiling = []; var scaleOffsetTiling = []; for (var i = 0; i < nDims; ++i) { if (_this.axis.indexOf(i) !== -1) { momentsTiling.push(inputShape[i]); scaleOffsetTiling.push(1); } else { momentsTiling.push(1); scaleOffsetTiling.push(inputShape[i]); } } mean = tfc__namespace.tile(mean, momentsTiling); variance = tfc__namespace.tile(variance, momentsTiling); if (scale != null) { scale = tfc__namespace.tile(scale, scaleOffsetTiling); } if (offset != null) { offset = tfc__namespace.tile(offset, scaleOffsetTiling); } return batchNormalization$1(input, mean, variance, offset, scale, _this.epsilon); }); }; LayerNormalization.prototype.getConfig = function () { var config = { axis: this.axis, epsilon: this.epsilon, center: this.center, scale: this.scale, betaInitializer: serializeInitializer(this.betaInitializer), gammaInitializer: serializeInitializer(this.gammaInitializer), betaRegularizer: serializeRegularizer(this.betaRegularizer), gammaRegularizer: serializeRegularizer(this.gammaRegularizer) }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return LayerNormalization; }(Layer)); /** @nocollapse */ LayerNormalization.className = 'LayerNormalization'; tfc.serialization.registerClass(LayerNormalization); /** * Pads the 2nd and 3rd dimensions of a 4D tensor. * * @param x Input `tf.Tensor` to be padded. * @param padding `Array` of two `Array`s, each of which is an `Array` of two * integers. The amount of padding at the beginning and end of the 2nd and 3rd * dimensions, respectively. * @param dataFormat 'channelsLast' (default) or 'channelsFirst'. * @return Padded 4D `tf.Tensor`. */ function spatial2dPadding(x, padding, dataFormat) { return tfc.tidy(function () { if (x.rank !== 4) { throw new ValueError("spatial2dPadding expects input tensor to be 4-D, but received a " + "".concat(x.rank, "-D tensor.")); } if (padding == null) { padding = [[1, 1], [1, 1]]; } if (padding.length !== 2 || padding[0].length !== 2 || padding[1].length !== 2) { throw new ValueError('spatial2dPadding expects `padding` to be an Array of two Arrays, ' + 'each of which is an Array of two integers.'); } if (dataFormat == null) { dataFormat = imageDataFormat(); } if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') { throw new ValueError("Unknown data format: ".concat(dataFormat, ". ") + "Supported data formats are 'channelsLast' and 'channelsFirst'."); } var pattern; if (dataFormat === 'channelsFirst') { pattern = [[0, 0], [0, 0], padding[0], padding[1]]; } else { pattern = [[0, 0], padding[0], padding[1], [0, 0]]; } return tfc__namespace.pad(x, pattern); }); } var ZeroPadding2D = /** @class */ (function (_super) { __extends(ZeroPadding2D, _super); function ZeroPadding2D(args) { var _this = this; if (args == null) { args = {}; } _this = _super.call(this, args) || this; _this.dataFormat = args.dataFormat == null ? imageDataFormat() : args.dataFormat; // TODO(cais): Maybe refactor the following logic surrounding `padding` // into a helper method.
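// `padding` may be a single number (symmetric on both spatial dimensions),
// a pair of numbers [symmetricHeight, symmetricWidth], or a pair of pairs
// [[top, bottom], [left, right]]; the branches below normalize all three
// forms to the pair-of-pairs representation.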
if (args.padding == null) { _this.padding = [[1, 1], [1, 1]]; } else if (typeof args.padding === 'number') { _this.padding = [[args.padding, args.padding], [args.padding, args.padding]]; } else { args.padding = args.padding; if (args.padding.length !== 2) { throw new ValueError("ZeroPadding2D expects padding to be a length-2 array, but " + "received a length-".concat(args.padding.length, " array.")); } var heightPadding = void 0; var widthPadding = void 0; if (typeof args.padding[0] === 'number') { heightPadding = [args.padding[0], args.padding[0]]; widthPadding = [args.padding[1], args.padding[1]]; } else { args.padding = args.padding; if (args.padding[0].length !== 2) { throw new ValueError("ZeroPadding2D expects height padding to be a length-2 array, " + "but received a length-".concat(args.padding[0].length, " array.")); } heightPadding = args.padding[0]; if (args.padding[1].length !== 2) { throw new ValueError("ZeroPadding2D expects width padding to be a length-2 array, " + "but received a length-".concat(args.padding[1].length, " array.")); } widthPadding = args.padding[1]; } _this.padding = [heightPadding, widthPadding]; } _this.inputSpec = [new InputSpec({ ndim: 4 })]; return _this; } ZeroPadding2D.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var rows; var cols; if (this.dataFormat === 'channelsFirst') { if (inputShape[2] != null && inputShape[2] >= 0) { rows = inputShape[2] + this.padding[0][0] + this.padding[0][1]; } else { rows = null; } if (inputShape[3] != null && inputShape[3] >= 0) { cols = inputShape[3] + this.padding[1][0] + this.padding[1][1]; } else { cols = null; } return [inputShape[0], inputShape[1], rows, cols]; } else { if (inputShape[1] != null && inputShape[1] >= 0) { rows = inputShape[1] + this.padding[0][0] + this.padding[0][1]; } else { rows = null; } if (inputShape[2] != null && inputShape[2] >= 0) { cols = inputShape[2] + this.padding[1][0] + this.padding[1][1]; } else { cols = null; } return [inputShape[0], rows, cols, inputShape[3]]; } }; ZeroPadding2D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { return spatial2dPadding(getExactlyOneTensor(inputs), _this.padding, _this.dataFormat); }); }; ZeroPadding2D.prototype.getConfig = function () { var config = { padding: this.padding, dataFormat: this.dataFormat, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return ZeroPadding2D; }(Layer)); /** @nocollapse */ ZeroPadding2D.className = 'ZeroPadding2D'; tfc.serialization.registerClass(ZeroPadding2D); /** * 2D pooling. * @param x * @param poolSize * @param strides strides. Defaults to [1, 1]. * @param padding padding. Defaults to 'valid'. * @param dataFormat data format. Defaults to 'channelsLast'. * @param poolMode Mode of pooling. Defaults to 'max'. * @returns Result of the 2D pooling. */ function pool2d(x, poolSize, strides, padding, dataFormat, poolMode) { return tfc.tidy(function () { checkDataFormat(dataFormat); checkPoolMode(poolMode); checkPaddingMode(padding); if (strides == null) { strides = [1, 1]; } if (padding == null) { padding = 'valid'; } if (dataFormat == null) { dataFormat = imageDataFormat(); } if (poolMode == null) { poolMode = 'max'; } // TODO(cais): Remove the preprocessing step once deeplearn.js supports // dataFormat as an input argument. x = preprocessConv2DInput(x, dataFormat); // x is NHWC after preprocessing. var y; var paddingString = (padding === 'same') ? 
'same' : 'valid'; if (poolMode === 'max') { // TODO(cais): Rank check? y = tfc__namespace.maxPool(x, poolSize, strides, paddingString); } else { // 'avg' // TODO(cais): Check the dtype and rank of x and give clear error message // if those are incorrect. y = tfc__namespace.avgPool( // TODO(cais): Rank check? x, poolSize, strides, paddingString); } if (dataFormat === 'channelsFirst') { y = tfc__namespace.transpose(y, [0, 3, 1, 2]); // NHWC -> NCHW. } return y; }); } /** * 3D pooling. * @param x * @param poolSize Pool size. Defaults to [1, 1, 1]. * @param strides strides. Defaults to [1, 1, 1]. * @param padding padding. Defaults to 'valid'. * @param dataFormat data format. Defaults to 'channelsLast'. * @param poolMode Mode of pooling. Defaults to 'max'. * @returns Result of the 3D pooling. */ function pool3d(x, poolSize, strides, padding, dataFormat, poolMode) { return tfc.tidy(function () { checkDataFormat(dataFormat); checkPoolMode(poolMode); checkPaddingMode(padding); if (strides == null) { strides = [1, 1, 1]; } if (padding == null) { padding = 'valid'; } if (dataFormat == null) { dataFormat = imageDataFormat(); } if (poolMode == null) { poolMode = 'max'; } // x is NDHWC after preprocessing. x = preprocessConv3DInput(x, dataFormat); var y; var paddingString = (padding === 'same') ? 'same' : 'valid'; if (poolMode === 'max') { y = tfc__namespace.maxPool3d(x, poolSize, strides, paddingString); } else { // 'avg' y = tfc__namespace.avgPool3d(x, poolSize, strides, paddingString); } if (dataFormat === 'channelsFirst') { y = tfc__namespace.transpose(y, [0, 4, 1, 2, 3]); // NDHWC -> NCDHW. } return y; }); } /** * Abstract class for different pooling 1D layers. */ var Pooling1D = /** @class */ (function (_super) { __extends(Pooling1D, _super); /** * * @param args Parameters for the Pooling layer. * * config.poolSize defaults to 2. */ function Pooling1D(args) { var _this = this; if (args.poolSize == null) { args.poolSize = 2; } _this = _super.call(this, args) || this; if (typeof args.poolSize === 'number') { _this.poolSize = [args.poolSize]; } else if (Array.isArray(args.poolSize) && args.poolSize.length === 1 && typeof args.poolSize[0] === 'number') { _this.poolSize = args.poolSize; } else { throw new ValueError("poolSize for a 1D pooling layer must be a number or an " + "Array of a single number, but received " + "".concat(JSON.stringify(args.poolSize))); } assertPositiveInteger(_this.poolSize, 'poolSize'); if (args.strides == null) { _this.strides = _this.poolSize; } else { if (typeof args.strides === 'number') { _this.strides = [args.strides]; } else if (Array.isArray(args.strides) && args.strides.length === 1 && typeof args.strides[0] === 'number') { _this.strides = args.strides; } else { throw new ValueError("strides for a 1D pooling layer must be a number or an " + "Array of a single number, but received " + "".concat(JSON.stringify(args.strides))); } } assertPositiveInteger(_this.strides, 'strides'); _this.padding = args.padding == null ? 'valid' : args.padding; checkPaddingMode(_this.padding); _this.inputSpec = [new InputSpec({ ndim: 3 })]; return _this; } Pooling1D.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var length = convOutputLength(inputShape[1], this.poolSize[0], this.padding, this.strides[0]); return [inputShape[0], length, inputShape[2]]; }; Pooling1D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); // Add dummy last dimension.
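// 1D pooling is delegated to the 2D implementation by treating the input
// [batch, length, channels] as [batch, length, 1, channels] and pooling
// with size [poolSize, 1] and strides [stride, 1], then squeezing the
// dummy dimension back out.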
inputs = expandDims$1(getExactlyOneTensor(inputs), 2); var output = _this.poolingFunction(getExactlyOneTensor(inputs), [_this.poolSize[0], 1], [_this.strides[0], 1], _this.padding, 'channelsLast'); // Remove dummy last dimension. return tfc__namespace.squeeze(output, [2]); }); }; Pooling1D.prototype.getConfig = function () { var config = { poolSize: this.poolSize, padding: this.padding, strides: this.strides, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Pooling1D; }(Layer)); var MaxPooling1D = /** @class */ (function (_super) { __extends(MaxPooling1D, _super); function MaxPooling1D(args) { return _super.call(this, args) || this; } MaxPooling1D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) { checkDataFormat(dataFormat); checkPaddingMode(padding); return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max'); }; return MaxPooling1D; }(Pooling1D)); /** @nocollapse */ MaxPooling1D.className = 'MaxPooling1D'; tfc.serialization.registerClass(MaxPooling1D); var AveragePooling1D = /** @class */ (function (_super) { __extends(AveragePooling1D, _super); function AveragePooling1D(args) { return _super.call(this, args) || this; } AveragePooling1D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) { checkDataFormat(dataFormat); checkPaddingMode(padding); return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg'); }; return AveragePooling1D; }(Pooling1D)); /** @nocollapse */ AveragePooling1D.className = 'AveragePooling1D'; tfc.serialization.registerClass(AveragePooling1D); /** * Abstract class for different pooling 2D layers. */ var Pooling2D = /** @class */ (function (_super) { __extends(Pooling2D, _super); function Pooling2D(args) { var _this = this; if (args.poolSize == null) { args.poolSize = [2, 2]; } _this = _super.call(this, args) || this; _this.poolSize = Array.isArray(args.poolSize) ? args.poolSize : [args.poolSize, args.poolSize]; if (args.strides == null) { _this.strides = _this.poolSize; } else if (Array.isArray(args.strides)) { if (args.strides.length !== 2) { throw new ValueError("If the strides property of a 2D pooling layer is an Array, " + "it is expected to have a length of 2, but received length " + "".concat(args.strides.length, ".")); } _this.strides = args.strides; } else { // `config.strides` is a number. _this.strides = [args.strides, args.strides]; } assertPositiveInteger(_this.poolSize, 'poolSize'); assertPositiveInteger(_this.strides, 'strides'); _this.padding = args.padding == null ? 'valid' : args.padding; _this.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat; checkDataFormat(_this.dataFormat); checkPaddingMode(_this.padding); _this.inputSpec = [new InputSpec({ ndim: 4 })]; return _this; } Pooling2D.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1]; var cols = this.dataFormat === 'channelsFirst' ? 
inputShape[3] : inputShape[2]; rows = convOutputLength(rows, this.poolSize[0], this.padding, this.strides[0]); cols = convOutputLength(cols, this.poolSize[1], this.padding, this.strides[1]); if (this.dataFormat === 'channelsFirst') { return [inputShape[0], inputShape[1], rows, cols]; } else { return [inputShape[0], rows, cols, inputShape[3]]; } }; Pooling2D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); return _this.poolingFunction(getExactlyOneTensor(inputs), _this.poolSize, _this.strides, _this.padding, _this.dataFormat); }); }; Pooling2D.prototype.getConfig = function () { var config = { poolSize: this.poolSize, padding: this.padding, strides: this.strides, dataFormat: this.dataFormat }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Pooling2D; }(Layer)); var MaxPooling2D = /** @class */ (function (_super) { __extends(MaxPooling2D, _super); function MaxPooling2D(args) { return _super.call(this, args) || this; } MaxPooling2D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) { checkDataFormat(dataFormat); checkPaddingMode(padding); return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max'); }; return MaxPooling2D; }(Pooling2D)); /** @nocollapse */ MaxPooling2D.className = 'MaxPooling2D'; tfc.serialization.registerClass(MaxPooling2D); var AveragePooling2D = /** @class */ (function (_super) { __extends(AveragePooling2D, _super); function AveragePooling2D(args) { return _super.call(this, args) || this; } AveragePooling2D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) { checkDataFormat(dataFormat); checkPaddingMode(padding); return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg'); }; return AveragePooling2D; }(Pooling2D)); /** @nocollapse */ AveragePooling2D.className = 'AveragePooling2D'; tfc.serialization.registerClass(AveragePooling2D); /** * Abstract class for different pooling 3D layers. */ var Pooling3D = /** @class */ (function (_super) { __extends(Pooling3D, _super); function Pooling3D(args) { var _this = this; if (args.poolSize == null) { args.poolSize = [2, 2, 2]; } _this = _super.call(this, args) || this; _this.poolSize = Array.isArray(args.poolSize) ? args.poolSize : [args.poolSize, args.poolSize, args.poolSize]; if (args.strides == null) { _this.strides = _this.poolSize; } else if (Array.isArray(args.strides)) { if (args.strides.length !== 3) { throw new ValueError("If the strides property of a 3D pooling layer is an Array, " + "it is expected to have a length of 3, but received length " + "".concat(args.strides.length, ".")); } _this.strides = args.strides; } else { // `config.strides` is a number. _this.strides = [args.strides, args.strides, args.strides]; } assertPositiveInteger(_this.poolSize, 'poolSize'); assertPositiveInteger(_this.strides, 'strides'); _this.padding = args.padding == null ? 'valid' : args.padding; _this.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat; checkDataFormat(_this.dataFormat); checkPaddingMode(_this.padding); _this.inputSpec = [new InputSpec({ ndim: 5 })]; return _this; } Pooling3D.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var depths = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1]; var rows = this.dataFormat === 'channelsFirst' ? 
inputShape[3] : inputShape[2]; var cols = this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3]; depths = convOutputLength(depths, this.poolSize[0], this.padding, this.strides[0]); rows = convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]); cols = convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]); if (this.dataFormat === 'channelsFirst') { return [inputShape[0], inputShape[1], depths, rows, cols]; } else { return [inputShape[0], depths, rows, cols, inputShape[4]]; } }; Pooling3D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { _this.invokeCallHook(inputs, kwargs); return _this.poolingFunction(getExactlyOneTensor(inputs), _this.poolSize, _this.strides, _this.padding, _this.dataFormat); }); }; Pooling3D.prototype.getConfig = function () { var config = { poolSize: this.poolSize, padding: this.padding, strides: this.strides, dataFormat: this.dataFormat }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return Pooling3D; }(Layer)); var MaxPooling3D = /** @class */ (function (_super) { __extends(MaxPooling3D, _super); function MaxPooling3D(args) { return _super.call(this, args) || this; } MaxPooling3D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) { checkDataFormat(dataFormat); checkPaddingMode(padding); return pool3d(inputs, poolSize, strides, padding, dataFormat, 'max'); }; return MaxPooling3D; }(Pooling3D)); /** @nocollapse */ MaxPooling3D.className = 'MaxPooling3D'; tfc.serialization.registerClass(MaxPooling3D); var AveragePooling3D = /** @class */ (function (_super) { __extends(AveragePooling3D, _super); function AveragePooling3D(args) { return _super.call(this, args) || this; } AveragePooling3D.prototype.poolingFunction = function (inputs, poolSize, strides, padding, dataFormat) { checkDataFormat(dataFormat); checkPaddingMode(padding); return pool3d(inputs, poolSize, strides, padding, dataFormat, 'avg'); }; return AveragePooling3D; }(Pooling3D)); /** @nocollapse */ AveragePooling3D.className = 'AveragePooling3D'; tfc.serialization.registerClass(AveragePooling3D); /** * Abstract class for different global pooling 1D layers. 
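 *
 * A minimal usage sketch (not from the original docs; `globalAveragePooling1d`
 * is one concrete subclass exposed via `tf.layers`):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.globalAveragePooling1d({inputShape: [10, 8]}));
 * // The time dimension is pooled away: output shape is [null, 8].
 * console.log(JSON.stringify(model.outputShape));
 * ```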
*/ var GlobalPooling1D = /** @class */ (function (_super) { __extends(GlobalPooling1D, _super); function GlobalPooling1D(args) { var _this = _super.call(this, args) || this; _this.inputSpec = [new InputSpec({ ndim: 3 })]; return _this; } GlobalPooling1D.prototype.computeOutputShape = function (inputShape) { return [inputShape[0], inputShape[2]]; }; GlobalPooling1D.prototype.call = function (inputs, kwargs) { throw new NotImplementedError(); }; return GlobalPooling1D; }(Layer)); var GlobalAveragePooling1D = /** @class */ (function (_super) { __extends(GlobalAveragePooling1D, _super); function GlobalAveragePooling1D(args) { return _super.call(this, args || {}) || this; } GlobalAveragePooling1D.prototype.call = function (inputs, kwargs) { return tfc.tidy(function () { var input = getExactlyOneTensor(inputs); return tfc__namespace.mean(input, 1); }); }; return GlobalAveragePooling1D; }(GlobalPooling1D)); /** @nocollapse */ GlobalAveragePooling1D.className = 'GlobalAveragePooling1D'; tfc.serialization.registerClass(GlobalAveragePooling1D); var GlobalMaxPooling1D = /** @class */ (function (_super) { __extends(GlobalMaxPooling1D, _super); function GlobalMaxPooling1D(args) { return _super.call(this, args || {}) || this; } GlobalMaxPooling1D.prototype.call = function (inputs, kwargs) { return tfc.tidy(function () { var input = getExactlyOneTensor(inputs); return tfc__namespace.max(input, 1); }); }; return GlobalMaxPooling1D; }(GlobalPooling1D)); /** @nocollapse */ GlobalMaxPooling1D.className = 'GlobalMaxPooling1D'; tfc.serialization.registerClass(GlobalMaxPooling1D); /** * Abstract class for different global pooling 2D layers. */ var GlobalPooling2D = /** @class */ (function (_super) { __extends(GlobalPooling2D, _super); function GlobalPooling2D(args) { var _this = _super.call(this, args) || this; _this.dataFormat = args.dataFormat == null ? 
'channelsLast' : args.dataFormat; checkDataFormat(_this.dataFormat); _this.inputSpec = [new InputSpec({ ndim: 4 })]; return _this; } GlobalPooling2D.prototype.computeOutputShape = function (inputShape) { inputShape = inputShape; if (this.dataFormat === 'channelsLast') { return [inputShape[0], inputShape[3]]; } else { return [inputShape[0], inputShape[1]]; } }; GlobalPooling2D.prototype.call = function (inputs, kwargs) { throw new NotImplementedError(); }; GlobalPooling2D.prototype.getConfig = function () { var config = { dataFormat: this.dataFormat }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return GlobalPooling2D; }(Layer)); var GlobalAveragePooling2D = /** @class */ (function (_super) { __extends(GlobalAveragePooling2D, _super); function GlobalAveragePooling2D() { return _super !== null && _super.apply(this, arguments) || this; } GlobalAveragePooling2D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var input = getExactlyOneTensor(inputs); if (_this.dataFormat === 'channelsLast') { return tfc__namespace.mean(input, [1, 2]); } else { return tfc__namespace.mean(input, [2, 3]); } }); }; return GlobalAveragePooling2D; }(GlobalPooling2D)); /** @nocollapse */ GlobalAveragePooling2D.className = 'GlobalAveragePooling2D'; tfc.serialization.registerClass(GlobalAveragePooling2D); var GlobalMaxPooling2D = /** @class */ (function (_super) { __extends(GlobalMaxPooling2D, _super); function GlobalMaxPooling2D() { return _super !== null && _super.apply(this, arguments) || this; } GlobalMaxPooling2D.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var input = getExactlyOneTensor(inputs); if (_this.dataFormat === 'channelsLast') { return tfc__namespace.max(input, [1, 2]); } else { return tfc__namespace.max(input, [2, 3]); } }); }; return GlobalMaxPooling2D; }(GlobalPooling2D)); /** @nocollapse */ GlobalMaxPooling2D.className = 'GlobalMaxPooling2D'; tfc.serialization.registerClass(GlobalMaxPooling2D); /** * Abstract wrapper base class. * * Wrappers take another layer and augment it in various ways. * Do not use this class as a layer, it is only an abstract base class. * Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers. */ var Wrapper = /** @class */ (function (_super) { __extends(Wrapper, _super); function Wrapper(args) { var _this = // Porting Note: In PyKeras, `self.layer` is set prior to the calling // `super()`. But we can't do that here due to TypeScript's restriction. // See: https://github.com/Microsoft/TypeScript/issues/8277 // As a result, we have to add checks in `get trainable()` and // `set trainable()` below in order to prevent using `this.layer` when // its value is `undefined`. The super constructor does use the getter // and the setter of `this.layer`. _super.call(this, args) || this; _this.layer = args.layer; return _this; } Wrapper.prototype.build = function (inputShape) { this.built = true; }; Object.defineProperty(Wrapper.prototype, "trainable", { // TODO(cais): Implement activityRegularizer getter. get: function () { // Porting Note: the check of `this.layer` here is necessary due to the // way the `constructor` of this class is written (see Porting Note // above). 
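// Fall back to `false` while `this.layer` is still undefined; the super
// constructor may invoke this getter before `layer` has been assigned.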
if (this.layer != null) { return this.layer.trainable; } else { return false; } }, set: function (value) { // Porting Note: the check of `this.layer` here is necessary due to the // way the `constructor` of this class is written (see Porting Note // above). if (this.layer != null) { this.layer.trainable = value; } }, enumerable: false, configurable: true }); Object.defineProperty(Wrapper.prototype, "trainableWeights", { get: function () { return this.layer.trainableWeights; }, enumerable: false, configurable: true }); Object.defineProperty(Wrapper.prototype, "nonTrainableWeights", { // TODO(cais): Implement setter for trainableWeights. get: function () { return this.layer.nonTrainableWeights; }, enumerable: false, configurable: true }); Object.defineProperty(Wrapper.prototype, "updates", { // TODO(cais): Implement setter for nonTrainableWeights. get: function () { // tslint:disable-next-line:no-any return this.layer._updates; }, enumerable: false, configurable: true }); Object.defineProperty(Wrapper.prototype, "losses", { // TODO(cais): Implement getUpdatesFor(). get: function () { return this.layer.losses; }, enumerable: false, configurable: true }); // TODO(cais): Implement getLossesFor(). Wrapper.prototype.getWeights = function () { return this.layer.getWeights(); }; Wrapper.prototype.setWeights = function (weights) { this.layer.setWeights(weights); }; Wrapper.prototype.getConfig = function () { var config = { 'layer': { 'className': this.layer.getClassName(), 'config': this.layer.getConfig(), } }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; Wrapper.prototype.setFastWeightInitDuringBuild = function (value) { _super.prototype.setFastWeightInitDuringBuild.call(this, value); if (this.layer != null) { this.layer.setFastWeightInitDuringBuild(value); } }; /** @nocollapse */ Wrapper.fromConfig = function (cls, config, customObjects) { if (customObjects === void 0) { customObjects = {}; } var layerConfig = config['layer']; var layer = deserialize(layerConfig, customObjects); delete config['layer']; var newConfig = { layer: layer }; Object.assign(newConfig, config); return new cls(newConfig); }; return Wrapper; }(Layer)); var TimeDistributed = /** @class */ (function (_super) { __extends(TimeDistributed, _super); function TimeDistributed(args) { var _this = _super.call(this, args) || this; _this.supportsMasking = true; return _this; } TimeDistributed.prototype.build = function (inputShape) { inputShape = getExactlyOneShape(inputShape); if (inputShape.length < 3) { throw new ValueError("TimeDistributed layer expects an input shape >= 3D, but received " + "input shape ".concat(JSON.stringify(inputShape))); } this.inputSpec = [{ shape: inputShape }]; var childInputShape = [inputShape[0]].concat(inputShape.slice(2)); if (!this.layer.built) { this.layer.build(childInputShape); this.layer.built = true; } _super.prototype.build.call(this, inputShape); }; TimeDistributed.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var childInputShape = [inputShape[0]].concat(inputShape.slice(2)); var childOutputShape = this.layer.computeOutputShape(childInputShape); var timesteps = inputShape[1]; return [childOutputShape[0], timesteps].concat(childOutputShape.slice(1)); }; TimeDistributed.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { // TODO(cais): Add 'training' and 'useLearningPhase' to kwargs. 
inputs = getExactlyOneTensor(inputs); // Porting Note: In tfjs-layers, `inputs` are always concrete tensor // values. Hence the inputs can't have an undetermined first (batch) // dimension, which is why we always use the K.rnn approach here. var step = function (inputs, states) { // TODO(cais): Add useLearningPhase. // NOTE(cais): `layer.call` may return a length-1 array of Tensor in // some cases (e.g., `layer` is a `Sequential` instance), which is // why `getExactlyOneTensor` is used below. var output = getExactlyOneTensor(_this.layer.call(inputs, kwargs)); return [output, []]; }; var rnnOutputs = rnn$1(step, inputs, [], false /* goBackwards */, null /* mask */, null /* constants */, false /* unroll */, true /* needPerStepOutputs */); var y = rnnOutputs[1]; // TODO(cais): Add activity regularization. // TODO(cais): Add useLearningPhase. return y; }); }; return TimeDistributed; }(Wrapper)); /** @nocollapse */ TimeDistributed.className = 'TimeDistributed'; tfc.serialization.registerClass(TimeDistributed); function checkBidirectionalMergeMode(value) { checkStringTypeUnionValue(VALID_BIDIRECTIONAL_MERGE_MODES, 'BidirectionalMergeMode', value); } var DEFAULT_BIDIRECTIONAL_MERGE_MODE = 'concat'; var Bidirectional = /** @class */ (function (_super) { __extends(Bidirectional, _super); function Bidirectional(args) { var _this = _super.call(this, args) || this; // Note: When creating `this.forwardLayer`, the original Layer object // (`config.layer`) ought to be cloned. This is why we call // `getConfig()` followed by `deserialize()`. Without this cloning, // the layer names saved during serialization will incorrectly contain // the 'forward_' prefix. In Python Keras, this is done using // `copy.copy` (shallow copy), which does not have a simple equivalent // in JavaScript. JavaScript's `Object.assign()` does not copy // methods. var layerConfig = args.layer.getConfig(); var forwDict = {}; forwDict['className'] = args.layer.getClassName(); forwDict['config'] = layerConfig; _this.forwardLayer = deserialize(forwDict); layerConfig['goBackwards'] = layerConfig['goBackwards'] === true ? false : true; var backDict = {}; backDict['className'] = args.layer.getClassName(); backDict['config'] = layerConfig; _this.backwardLayer = deserialize(backDict); _this.forwardLayer.name = 'forward_' + _this.forwardLayer.name; _this.backwardLayer.name = 'backward_' + _this.backwardLayer.name; _this.mergeMode = args.mergeMode === undefined ? DEFAULT_BIDIRECTIONAL_MERGE_MODE : args.mergeMode; checkBidirectionalMergeMode(_this.mergeMode); if (args.weights) { throw new NotImplementedError('weights support is not implemented for Bidirectional layer yet.'); } _this._stateful = args.layer.stateful; _this.returnSequences = args.layer.returnSequences; _this.returnState = args.layer.returnState; _this.supportsMasking = true; _this._trainable = true; _this.inputSpec = args.layer.inputSpec; _this.numConstants = null; return _this; } Object.defineProperty(Bidirectional.prototype, "trainable", { get: function () { return this._trainable; }, set: function (value) { // Porting Note: the check of `this.layer` here is necessary due to the // way the `constructor` of this class is written (see Porting Note // above). 
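// Record the flag locally, then mirror it onto both wrapped directions
// once they have been constructed.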
this._trainable = value; if (this.forwardLayer != null) { this.forwardLayer.trainable = value; } if (this.backwardLayer != null) { this.backwardLayer.trainable = value; } }, enumerable: false, configurable: true }); Bidirectional.prototype.getWeights = function () { return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights()); }; Bidirectional.prototype.setWeights = function (weights) { var numWeights = weights.length; var numWeightsOver2 = Math.floor(numWeights / 2); this.forwardLayer.setWeights(weights.slice(0, numWeightsOver2)); this.backwardLayer.setWeights(weights.slice(numWeightsOver2)); }; Bidirectional.prototype.computeOutputShape = function (inputShape) { var layerShapes = this.forwardLayer.computeOutputShape(inputShape); if (!(Array.isArray(layerShapes) && Array.isArray(layerShapes[0]))) { layerShapes = [layerShapes]; } layerShapes = layerShapes; var outputShape; var outputShapes; var stateShape; if (this.returnState) { stateShape = layerShapes.slice(1); outputShape = layerShapes[0]; } else { outputShape = layerShapes[0]; } outputShape = outputShape; if (this.mergeMode === 'concat') { outputShape[outputShape.length - 1] *= 2; outputShapes = [outputShape]; } else if (this.mergeMode == null) { outputShapes = [outputShape, outputShape.slice()]; } else { outputShapes = [outputShape]; } if (this.returnState) { if (this.mergeMode == null) { return outputShapes.concat(stateShape).concat(stateShape.slice()); } return [outputShape].concat(stateShape).concat(stateShape.slice()); } return singletonOrArray(outputShapes); }; Bidirectional.prototype.apply = function (inputs, kwargs) { var e_1, _a; var initialState = kwargs == null ? null : kwargs['initialState']; var constants = kwargs == null ? null : kwargs['constants']; if (kwargs == null) { kwargs = {}; } var standardized = standardizeArgs(inputs, initialState, constants, this.numConstants); inputs = standardized.inputs; initialState = standardized.initialState; constants = standardized.constants; if (Array.isArray(inputs)) { initialState = inputs.slice(1); inputs = inputs[0]; } if ((initialState == null || initialState.length === 0) && constants == null) { return _super.prototype.apply.call(this, inputs, kwargs); } var additionalInputs = []; var additionalSpecs = []; if (initialState != null) { var numStates = initialState.length; if (numStates % 2 > 0) { throw new ValueError('When passing `initialState` to a Bidirectional RNN, ' + 'the state should be an Array containing the states of ' + 'the underlying RNNs.'); } kwargs['initialState'] = initialState; additionalInputs.push.apply(additionalInputs, __spreadArray([], __read(initialState), false)); var stateSpecs = initialState .map(function (state) { return new InputSpec({ shape: state.shape }); }); this.forwardLayer.stateSpec = stateSpecs.slice(0, numStates / 2); this.backwardLayer.stateSpec = stateSpecs.slice(numStates / 2); additionalSpecs.push.apply(additionalSpecs, __spreadArray([], __read(stateSpecs), false)); } if (constants != null) { throw new NotImplementedError('Support for constants in Bidirectional layers is not ' + 'implemented yet.'); } var isSymbolicTensor = additionalInputs[0] instanceof SymbolicTensor; try { for (var additionalInputs_1 = __values(additionalInputs), additionalInputs_1_1 = additionalInputs_1.next(); !additionalInputs_1_1.done; additionalInputs_1_1 = additionalInputs_1.next()) { var tensor = additionalInputs_1_1.value; if (tensor instanceof SymbolicTensor !== isSymbolicTensor) { throw new ValueError('The initial state of a Bidirectional layer 
cannot be ' + 'specified as a mix of symbolic and non-symbolic tensors'); } } } catch (e_1_1) { e_1 = { error: e_1_1 }; } finally { try { if (additionalInputs_1_1 && !additionalInputs_1_1.done && (_a = additionalInputs_1.return)) _a.call(additionalInputs_1); } finally { if (e_1) throw e_1.error; } } if (isSymbolicTensor) { // Compute the full input and specs, including the states. var fullInput = [inputs].concat(additionalInputs); var fullInputSpec = this.inputSpec.concat(additionalSpecs); // Perform the call temporarily and replace inputSpec. // Note: with initial states symbolic calls and non-symbolic calls to // this method differ in how the initial states are passed. For // symbolic calls, the initial states are passed in the first arg, as // an Array of SymbolicTensors; for non-symbolic calls, they are // passed in the second arg as a part of the kwargs. Hence the need to // temporarily modify inputSpec here. // TODO(cais): Make refactoring so that this hacky code below is no // longer needed. var originalInputSpec = this.inputSpec; this.inputSpec = fullInputSpec; var output = _super.prototype.apply.call(this, fullInput, kwargs); this.inputSpec = originalInputSpec; return output; } else { return _super.prototype.apply.call(this, inputs, kwargs); } }; Bidirectional.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var initialState = kwargs['initialState']; var y; var yRev; if (initialState == null) { y = _this.forwardLayer.call(inputs, kwargs); yRev = _this.backwardLayer.call(inputs, kwargs); } else { var forwardState = initialState.slice(0, initialState.length / 2); var backwardState = initialState.slice(initialState.length / 2); y = _this.forwardLayer.call(inputs, Object.assign(kwargs, { initialState: forwardState })); yRev = _this.backwardLayer.call(inputs, Object.assign(kwargs, { initialState: backwardState })); } var states; if (_this.returnState) { if (Array.isArray(y)) { states = y.slice(1).concat(yRev.slice(1)); } y = y[0]; yRev = yRev[0]; } if (_this.returnSequences) { yRev = tfc__namespace.reverse(yRev, 1); } var output; if (_this.mergeMode === 'concat') { output = concatenate$1([y, yRev]); } else if (_this.mergeMode === 'sum') { output = tfc__namespace.add(y, yRev); } else if (_this.mergeMode === 'ave') { output = tfc__namespace.mul(.5, tfc__namespace.add(y, yRev)); } else if (_this.mergeMode === 'mul') { output = tfc__namespace.mul(y, yRev); } else if (_this.mergeMode == null) { output = [y, yRev]; } // TODO(cais): Properly set learning phase. 
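// When `returnState` is set, re-attach the forward and backward states
// (collected above) after the merged output.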
if (_this.returnState) { if (_this.mergeMode == null) { return output.concat(states); } return [output].concat(states); } return output; }); }; Bidirectional.prototype.resetStates = function (states) { this.forwardLayer.resetStates(); this.backwardLayer.resetStates(); }; Bidirectional.prototype.build = function (inputShape) { var _this = this; nameScope(this.forwardLayer.name, function () { _this.forwardLayer.build(inputShape); }); nameScope(this.backwardLayer.name, function () { _this.backwardLayer.build(inputShape); }); this.built = true; }; Bidirectional.prototype.computeMask = function (inputs, mask) { if (Array.isArray(mask)) { mask = mask[0]; } var outputMask; if (this.returnSequences) { if (this.mergeMode == null) { outputMask = [mask, mask]; } else { outputMask = mask; } } else { if (this.mergeMode == null) { outputMask = [null, null]; } else { outputMask = null; } } if (this.returnState) { var states = this.forwardLayer.states; var stateMask = states.map(function (state) { return null; }); if (Array.isArray(outputMask)) { return outputMask.concat(stateMask).concat(stateMask); } else { return [outputMask].concat(stateMask).concat(stateMask); } } else { return outputMask; } }; Object.defineProperty(Bidirectional.prototype, "trainableWeights", { get: function () { return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights); }, enumerable: false, configurable: true }); Object.defineProperty(Bidirectional.prototype, "nonTrainableWeights", { get: function () { return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights); }, enumerable: false, configurable: true }); // TODO(cais): Implement constraints(). Bidirectional.prototype.setFastWeightInitDuringBuild = function (value) { _super.prototype.setFastWeightInitDuringBuild.call(this, value); if (this.forwardLayer != null) { this.forwardLayer.setFastWeightInitDuringBuild(value); } if (this.backwardLayer != null) { this.backwardLayer.setFastWeightInitDuringBuild(value); } }; Bidirectional.prototype.getConfig = function () { var config = { 'mergeMode': this.mergeMode, }; // TODO(cais): Add logic for `numConstants` once the property is added. var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; /** @nocollapse */ Bidirectional.fromConfig = function (cls, config) { var rnnLayer = deserialize(config['layer']); delete config['layer']; // TODO(cais): Add logic for `numConstants` once the property is added. 
if (config['numConstants'] != null) { throw new NotImplementedError("Deserialization of a Bidirectional layer with numConstants " + "present is not supported yet."); } // tslint:disable-next-line:no-any var newConfig = config; newConfig['layer'] = rnnLayer; return new cls(newConfig); }; return Bidirectional; }(Wrapper)); /** @nocollapse */ Bidirectional.className = 'Bidirectional'; tfc.serialization.registerClass(Bidirectional); /** * Preprocessing Rescaling Layer * * This rescales images by a scaling and offset factor */ var Rescaling = /** @class */ (function (_super) { __extends(Rescaling, _super); function Rescaling(args) { var _this = _super.call(this, args) || this; _this.scale = args.scale; if (args.offset) { _this.offset = args.offset; } else { _this.offset = 0; } return _this; } Rescaling.prototype.getConfig = function () { var config = { 'scale': this.scale, 'offset': this.offset }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; Rescaling.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = getExactlyOneTensor(inputs); if (inputs.dtype !== 'float32') { inputs = cast$1(inputs, 'float32'); } return tfc.add(tfc.mul(inputs, _this.scale), _this.offset); }); }; return Rescaling; }(Layer)); /** @nocollapse */ Rescaling.className = 'Rescaling'; tfc.serialization.registerClass(Rescaling); var resizeBilinear = tfc.image.resizeBilinear, cropAndResize = tfc.image.cropAndResize; var CenterCrop = /** @class */ (function (_super) { __extends(CenterCrop, _super); function CenterCrop(args) { var _this = _super.call(this, args) || this; _this.height = args.height; _this.width = args.width; return _this; } CenterCrop.prototype.centerCrop = function (inputs, hBuffer, wBuffer, height, width, inputHeight, inputWidth, dtype) { return tfc.tidy(function () { var input; var isRank3 = false; var top = hBuffer / inputHeight; var left = wBuffer / inputWidth; var bottom = ((height) + hBuffer) / inputHeight; var right = ((width) + wBuffer) / inputWidth; var bound = [top, left, bottom, right]; var boxesArr = []; if (inputs.rank === 3) { isRank3 = true; input = tfc.stack([inputs]); } else { input = inputs; } for (var i = 0; i < input.shape[0]; i++) { boxesArr.push(bound); } var boxes = tfc.tensor(boxesArr, [boxesArr.length, 4]); var boxInd = tfc.range(0, boxesArr.length, 1, 'int32'); var cropSize = [height, width]; var cropped = cropAndResize(input, boxes, boxInd, cropSize, 'nearest'); if (isRank3) { return cast$1(getExactlyOneTensor(tfc.unstack(cropped)), dtype); } return cast$1(cropped, dtype); }); }; CenterCrop.prototype.upsize = function (inputs, height, width, dtype) { return tfc.tidy(function () { var outputs = resizeBilinear(inputs, [height, width]); return cast$1(outputs, dtype); }); }; CenterCrop.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var rankedInputs = getExactlyOneTensor(inputs); var dtype = rankedInputs.dtype; var inputShape = rankedInputs.shape; var inputHeight = inputShape[inputShape.length - 3]; var inputWidth = inputShape[inputShape.length - 2]; var hBuffer = 0; if (inputHeight !== _this.height) { hBuffer = Math.floor((inputHeight - _this.height) / 2); } var wBuffer = 0; if (inputWidth !== _this.width) { wBuffer = Math.floor((inputWidth - _this.width) / 2); if (wBuffer === 0) { wBuffer = 1; } } if (hBuffer >= 0 && wBuffer >= 0) { return _this.centerCrop(rankedInputs, hBuffer, wBuffer, _this.height, _this.width, inputHeight, 
inputWidth, dtype); } else { return _this.upsize(inputs, _this.height, _this.width, dtype); } }); }; CenterCrop.prototype.getConfig = function () { var config = { 'height': this.height, 'width': this.width }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; CenterCrop.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var hAxis = inputShape.length - 3; var wAxis = inputShape.length - 2; inputShape[hAxis] = this.height; inputShape[wAxis] = this.width; return inputShape; }; return CenterCrop; }(Layer)); /** @nocollapse */ CenterCrop.className = 'CenterCrop'; tfc.serialization.registerClass(CenterCrop); /** * @license * Copyright 2022 CodeSmith LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ function encodeCategoricalInputs(inputs, outputMode, depth, weights) { var input = getExactlyOneTensor(inputs); if (input.dtype !== 'int32') { input = cast$1(input, 'int32'); } if (outputMode === 'int') { return input; } var originalShape = input.shape; if (input.rank === 0) { input = tfc.expandDims(input, -1); } if (outputMode === 'oneHot') { if (input.shape[input.shape.length - 1] !== 1) { input = tfc.expandDims(input, -1); } } if (input.rank > 2) { throw new ValueError("When outputMode is not int, maximum output rank is 2" + " Received outputMode ".concat(outputMode, " and input shape ").concat(originalShape) + " which would result in output rank ".concat(input.rank, ".")); } var binaryOutput = ['multiHot', 'oneHot'].includes(outputMode); var denseBincountInput = input; var binCounts; if ((typeof weights) !== 'undefined' && outputMode === 'count') { binCounts = tfc.denseBincount(denseBincountInput, weights, depth, binaryOutput); } else { binCounts = tfc.denseBincount(denseBincountInput, [], depth, binaryOutput); } if (outputMode !== 'tfIdf') { return binCounts; } if (weights) { return tfc.mul(binCounts, weights); } else { throw new ValueError("When outputMode is 'tfIdf', weights must be provided."); } } var CategoryEncoding = /** @class */ (function (_super) { __extends(CategoryEncoding, _super); function CategoryEncoding(args) { var _this = _super.call(this, args) || this; _this.numTokens = args.numTokens; if (args.outputMode) { _this.outputMode = args.outputMode; } else { _this.outputMode = 'multiHot'; } return _this; } CategoryEncoding.prototype.getConfig = function () { var config = { 'numTokens': this.numTokens, 'outputMode': this.outputMode, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; CategoryEncoding.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); if (inputShape == null) { return [this.numTokens]; } if (this.outputMode === 'oneHot' && inputShape[inputShape.length - 1] !== 1) { inputShape.push(this.numTokens); return inputShape; } inputShape[inputShape.length - 1] = this.numTokens; return inputShape; }; CategoryEncoding.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { inputs = getExactlyOneTensor(inputs); if (inputs.dtype !== 'int32') { inputs = cast$1(inputs, 'int32'); } var countWeights; if ((typeof kwargs['countWeights']) !== 'undefined') { if (_this.outputMode !== 'count') { throw new ValueError("countWeights is not used when outputMode !== 
count.\n Received countWeights=".concat(kwargs['countWeights'])); } countWeights = getExactlyOneTensor(kwargs['countWeights']); } var maxValue = tfc.max(inputs); var minValue = tfc.min(inputs); var greaterEqualMax = tfc.greater(_this.numTokens, maxValue) .bufferSync().get(0); var greaterMin = tfc.greaterEqual(minValue, 0).bufferSync().get(0); if (!(greaterEqualMax && greaterMin)) { throw new ValueError('Input values must be in the range 0 <= values <' + " numTokens with numTokens=".concat(_this.numTokens)); } return encodeCategoricalInputs(inputs, _this.outputMode, _this.numTokens, countWeights); }); }; return CategoryEncoding; }(Layer)); /** @nocollapse */ CategoryEncoding.className = 'CategoryEncoding'; tfc.serialization.registerClass(CategoryEncoding); // tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5', // 'gaussian', 'mitchellcubic' var INTERPOLATION_KEYS$1 = ['bilinear', 'nearest']; var INTERPOLATION_METHODS$1 = new Set(INTERPOLATION_KEYS$1); /** * Preprocessing Resizing Layer * * This layer resizes images to a target height and width */ var Resizing = /** @class */ (function (_super) { __extends(Resizing, _super); function Resizing(args) { var _this = _super.call(this, args) || this; _this.height = args.height; _this.width = args.width; if (args.interpolation) { if (INTERPOLATION_METHODS$1.has(args.interpolation)) { _this.interpolation = args.interpolation; } else { throw new ValueError("Invalid interpolation parameter: ".concat(args.interpolation, " is not implemented")); } } else { _this.interpolation = 'bilinear'; } _this.cropToAspectRatio = Boolean(args.cropToAspectRatio); return _this; } Resizing.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var numChannels = inputShape[2]; return [this.height, this.width, numChannels]; }; Resizing.prototype.getConfig = function () { var config = { 'height': this.height, 'width': this.width, 'interpolation': this.interpolation, 'cropToAspectRatio': this.cropToAspectRatio }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; Resizing.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var size = [_this.height, _this.width]; if (_this.interpolation === 'bilinear') { return tfc.image.resizeBilinear(inputs, size, !_this.cropToAspectRatio); } else if (_this.interpolation === 'nearest') { return tfc.image.resizeNearestNeighbor(inputs, size, !_this.cropToAspectRatio); } else { throw new Error("Interpolation is ".concat(_this.interpolation, " but only ").concat(__spreadArray([], __read(INTERPOLATION_METHODS$1), false), " are supported")); } }); }; return Resizing; }(Layer)); /** @nocollapse */ Resizing.className = 'Resizing'; tfc.serialization.registerClass(Resizing); /** * @license * Copyright 2023 CodeSmith LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. 
* ============================================================================= */ /** * Keeps track of seed and handles pseudorandomness * Instance created in BaseRandomLayer class * Utilized for random preprocessing layers */ var RandomSeed = /** @class */ (function () { function RandomSeed(seed) { this.seed = seed; } RandomSeed.prototype.next = function () { if (this.seed === undefined) { return undefined; } return this.seed++; }; return RandomSeed; }()); RandomSeed.className = 'RandomSeed'; var BaseRandomLayer = /** @class */ (function (_super) { __extends(BaseRandomLayer, _super); function BaseRandomLayer(args) { var _this = _super.call(this, args) || this; _this.randomGenerator = new RandomSeed(args.seed); return _this; } BaseRandomLayer.prototype.getConfig = function () { var config = { 'seed': this.randomGenerator.seed }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; return BaseRandomLayer; }(Layer)); // A layer that handles random number creation and save-model behavior. /** @nocollapse */ BaseRandomLayer.className = 'BaseRandomLayer'; var INTERPOLATION_KEYS = ['bilinear', 'nearest']; var INTERPOLATION_METHODS = new Set(INTERPOLATION_KEYS); /** * Preprocessing layer that randomly varies image width during training * * This layer randomly adjusts the width of a batch of images by a random factor. * * The input should be a 3D (unbatched) or * 4D (batched) tensor in the `"channels_last"` image data format. Input pixel * values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and of integer * or floating point dtype. By default, the layer will output floats. * * tf methods implemented in tfjs: 'bilinear', 'nearest', * tf methods unimplemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5', * 'gaussian', 'mitchellcubic' * */ var RandomWidth = /** @class */ (function (_super) { __extends(RandomWidth, _super); function RandomWidth(args) { var _this = _super.call(this, args) || this; var factor = args.factor, _a = args.interpolation, interpolation = _a === void 0 ? 'bilinear' : _a; _this.factor = factor; if (Array.isArray(_this.factor) && _this.factor.length === 2) { _this.widthLower = _this.factor[0]; _this.widthUpper = _this.factor[1]; } else if (!Array.isArray(_this.factor) && _this.factor > 0) { _this.widthLower = -_this.factor; _this.widthUpper = _this.factor; } else { throw new ValueError("Invalid factor: ".concat(_this.factor, ". Must be a positive number or a tuple of 2 numbers")); } if (_this.widthLower < -1.0 || _this.widthUpper < -1.0) { throw new ValueError("factor must have values larger than -1. 
Got: ".concat(_this.factor)); } if (_this.widthUpper < _this.widthLower) { throw new ValueError("factor cannot have upper bound less than lower bound.\n Got upper bound: ".concat(_this.widthUpper, ".\n Got lower bound: ").concat(_this.widthLower, "\n ")); } if (interpolation) { if (INTERPOLATION_METHODS.has(interpolation)) { _this.interpolation = interpolation; } else { throw new ValueError("Invalid interpolation parameter: ".concat(interpolation, " is not implemented")); } } return _this; } RandomWidth.prototype.getConfig = function () { var config = { 'factor': this.factor, 'interpolation': this.interpolation, }; var baseConfig = _super.prototype.getConfig.call(this); Object.assign(config, baseConfig); return config; }; RandomWidth.prototype.computeOutputShape = function (inputShape) { inputShape = getExactlyOneShape(inputShape); var numChannels = inputShape[2]; return [this.imgHeight, -1, numChannels]; }; RandomWidth.prototype.call = function (inputs, kwargs) { var _this = this; return tfc.tidy(function () { var input = getExactlyOneTensor(inputs); _this.imgHeight = input.shape[input.shape.length - 3]; var imgWidth = input.shape[input.shape.length - 2]; _this.widthFactor = tfc.randomUniform([1], (1.0 + _this.widthLower), (1.0 + _this.widthUpper), 'float32', _this.randomGenerator.next()); var adjustedWidth = _this.widthFactor.dataSync()[0] * imgWidth; adjustedWidth = Math.round(adjustedWidth); var size = [_this.imgHeight, adjustedWidth]; switch (_this.interpolation) { case 'bilinear': return tfc.image.resizeBilinear(inputs, size); case 'nearest': return tfc.image.resizeNearestNeighbor(inputs, size); default: throw new Error("Interpolation is ".concat(_this.interpolation, "\n but only ").concat(__spreadArray([], __read(INTERPOLATION_METHODS), false), " are supported")); } }); }; return RandomWidth; }(BaseRandomLayer)); /** @nocollapse */ RandomWidth.className = 'RandomWidth'; tfc.serialization.registerClass(RandomWidth); /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ // TODO(cais): Add doc string to all the public static functions in this // class; include executable JavaScript code snippets where applicable // (b/74074458). // Input Layer. /** * An input layer is an entry point into a `tf.LayersModel`. * * `InputLayer` is generated automatically for `tf.Sequential` models by * specifying the `inputShape` or `batchInputShape` for the first layer. It * should not be specified explicitly. However, it can be useful sometimes, * e.g., when constructing a sequential model from a subset of another * sequential model's layers, as the code snippet below shows. * * ```js * // Define a model which simply adds two inputs. * const model1 = tf.sequential(); * model1.add(tf.layers.dense({inputShape: [4], units: 3, activation: 'relu'})); * model1.add(tf.layers.dense({units: 1, activation: 'sigmoid'})); * model1.summary(); * model1.predict(tf.zeros([1, 4])).print(); * * // Construct another model, reusing the second layer of `model1` while * // not using the first layer of `model1`. Note that you cannot add the second * // layer of `model1` directly as the first layer of the new sequential model, * // because doing so will lead to an error related to the fact that the layer * // is not an input layer. 
Instead, you need to create an `inputLayer` and add * // it to the new sequential model before adding the reused layer. * const model2 = tf.sequential(); * // Use an inputShape that matches the input shape of `model1`'s second * // layer. * model2.add(tf.layers.inputLayer({inputShape: [3]})); * model2.add(model1.layers[1]); * model2.summary(); * model2.predict(tf.zeros([1, 3])).print(); * ``` * * @doc {heading: 'Layers', subheading: 'Inputs', namespace: 'layers'} */ function inputLayer(args) { return new InputLayer(args); } // Advanced Activation Layers. /** * Exponential Linear Unit (ELU). * * It follows: * `f(x) = alpha * (exp(x) - 1.) for x < 0`, * `f(x) = x for x >= 0`. * * Input shape: * Arbitrary. Use the configuration `inputShape` when using this layer as the * first layer in a model. * * Output shape: * Same shape as the input. * * References: * - [Fast and Accurate Deep Network Learning by Exponential Linear Units * (ELUs)](https://arxiv.org/abs/1511.07289v1) * * @doc { * heading: 'Layers', * subheading: 'Advanced Activation', * namespace: 'layers' * } */ function elu(args) { return new ELU(args); } /** * Rectified Linear Unit activation function. * * Input shape: * Arbitrary. Use the config field `inputShape` (Array of integers, does * not include the sample axis) when using this layer as the first layer * in a model. * * Output shape: * Same shape as the input. * * @doc { * heading: 'Layers', * subheading: 'Advanced Activation', * namespace: 'layers' * } */ function reLU(args) { return new ReLU(args); } /** * Leaky version of a rectified linear unit. * * It allows a small gradient when the unit is not active: * `f(x) = alpha * x for x < 0.` * `f(x) = x for x >= 0.` * * Input shape: * Arbitrary. Use the configuration `inputShape` when using this layer as the * first layer in a model. * * Output shape: * Same shape as the input. * * @doc { * heading: 'Layers', * subheading: 'Advanced Activation', * namespace: 'layers' * } */ function leakyReLU(args) { return new LeakyReLU(args); } /** * Parameterized version of a leaky rectified linear unit. * * It follows * `f(x) = alpha * x for x < 0.` * `f(x) = x for x >= 0.` * wherein `alpha` is a trainable weight. * * Input shape: * Arbitrary. Use the configuration `inputShape` when using this layer as the * first layer in a model. * * Output shape: * Same shape as the input. * * @doc { * heading: 'Layers', * subheading: 'Advanced Activation', * namespace: 'layers' * } */ function prelu(args) { return new PReLU(args); } /** * Softmax activation layer. * * Input shape: * Arbitrary. Use the configuration `inputShape` when using this layer as the * first layer in a model. * * Output shape: * Same shape as the input. * * @doc { * heading: 'Layers', * subheading: 'Advanced Activation', * namespace: 'layers' * } */ function softmax(args) { return new Softmax(args); } /** * Thresholded Rectified Linear Unit. * * It follows: * `f(x) = x for x > theta`, * `f(x) = 0 otherwise`. * * Input shape: * Arbitrary. Use the configuration `inputShape` when using this layer as the * first layer in a model. * * Output shape: * Same shape as the input. * * References: * - [Zero-Bias Autoencoders and the Benefits of Co-Adapting * Features](http://arxiv.org/abs/1402.3337) * * @doc { * heading: 'Layers', * subheading: 'Advanced Activation', * namespace: 'layers' * } */ function thresholdedReLU(args) { return new ThresholdedReLU(args); } // Convolutional Layers. /** * 1D convolution layer (e.g., temporal convolution). 
* * This layer creates a convolution kernel that is convolved * with the layer input over a single spatial (or temporal) dimension * to produce a tensor of outputs. * * If `use_bias` is True, a bias vector is created and added to the outputs. * * If `activation` is not `null`, it is applied to the outputs as well. * * When using this layer as the first layer in a model, provide an * `inputShape` argument `Array` or `null`. * * For example, `inputShape` would be: * - `[10, 128]` for sequences of 10 vectors of 128-dimensional vectors * - `[null, 128]` for variable-length sequences of 128-dimensional vectors. * * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'} */ function conv1d(args) { return new Conv1D(args); } /** * 2D convolution layer (e.g. spatial convolution over images). * * This layer creates a convolution kernel that is convolved * with the layer input to produce a tensor of outputs. * * If `useBias` is True, a bias vector is created and added to the outputs. * * If `activation` is not `null`, it is applied to the outputs as well. * * When using this layer as the first layer in a model, * provide the keyword argument `inputShape` * (Array of integers, does not include the sample axis), * e.g. `inputShape=[128, 128, 3]` for 128x128 RGB pictures * in `dataFormat='channelsLast'`. * * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'} */ function conv2d(args) { return new Conv2D(args); } /** * Transposed convolutional layer (sometimes called Deconvolution). * * The need for transposed convolutions generally arises * from the desire to use a transformation going in the opposite direction of * a normal convolution, i.e., from something that has the shape of the output * of some convolution to something that has the shape of its input while * maintaining a connectivity pattern that is compatible with said * convolution. * * When using this layer as the first layer in a model, provide the * configuration `inputShape` (`Array` of integers, does not include the * sample axis), e.g., `inputShape: [128, 128, 3]` for 128x128 RGB pictures in * `dataFormat: 'channelsLast'`. * * Input shape: * 4D tensor with shape: * `[batch, channels, rows, cols]` if `dataFormat` is `'channelsFirst'`. * or 4D tensor with shape * `[batch, rows, cols, channels]` if `dataFormat` is `'channelsLast'`. * * Output shape: * 4D tensor with shape: * `[batch, filters, newRows, newCols]` if `dataFormat` is * `'channelsFirst'`. or 4D tensor with shape: * `[batch, newRows, newCols, filters]` if `dataFormat` is `'channelsLast'`. * * References: * - [A guide to convolution arithmetic for deep * learning](https://arxiv.org/abs/1603.07285v1) * - [Deconvolutional * Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf) * * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'} */ function conv2dTranspose(args) { return new Conv2DTranspose(args); } /** * 3D convolution layer (e.g. spatial convolution over volumes). * * This layer creates a convolution kernel that is convolved * with the layer input to produce a tensor of outputs. * * If `useBias` is True, a bias vector is created and added to the outputs. * * If `activation` is not `null`, it is applied to the outputs as well. * * When using this layer as the first layer in a model, * provide the keyword argument `inputShape` * (Array of integers, does not include the sample axis), * e.g. `inputShape=[128, 128, 128, 1]` for 128x128x128 grayscale volumes * in `dataFormat='channelsLast'`. 
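 *
 * A minimal sketch (the option values below are illustrative assumptions,
 * not taken from the original docs):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.conv3d(
 *     {filters: 8, kernelSize: 3, inputShape: [16, 16, 16, 1]}));
 * // Default 'valid' padding shrinks each spatial dim by kernelSize - 1.
 * console.log(JSON.stringify(model.outputShape)); // [null, 14, 14, 14, 8]
 * ```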
* * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'} */ function conv3d(args) { return new Conv3D(args); } function conv3dTranspose(args) { return new Conv3DTranspose(args); } /** * Depthwise separable 2D convolution. * * Separable convolution consists of first performing * a depthwise spatial convolution * (which acts on each input channel separately) * followed by a pointwise convolution which mixes together the resulting * output channels. The `depthMultiplier` argument controls how many * output channels are generated per input channel in the depthwise step. * * Intuitively, separable convolutions can be understood as * a way to factorize a convolution kernel into two smaller kernels, * or as an extreme version of an Inception block. * * Input shape: * 4D tensor with shape: * `[batch, channels, rows, cols]` if data_format='channelsFirst' * or 4D tensor with shape: * `[batch, rows, cols, channels]` if data_format='channelsLast'. * * Output shape: * 4D tensor with shape: * `[batch, filters, newRows, newCols]` if data_format='channelsFirst' * or 4D tensor with shape: * `[batch, newRows, newCols, filters]` if data_format='channelsLast'. * `rows` and `cols` values might have changed due to padding. * * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'} */ function separableConv2d(args) { return new SeparableConv2D(args); } /** * Cropping layer for 2D input (e.g., image). * * This layer can crop an input * at the top, bottom, left and right side of an image tensor. * * Input shape: * 4D tensor with shape: * - If `dataFormat` is `"channelsLast"`: * `[batch, rows, cols, channels]` * - If `data_format` is `"channels_first"`: * `[batch, channels, rows, cols]`. * * Output shape: * 4D with shape: * - If `dataFormat` is `"channelsLast"`: * `[batch, croppedRows, croppedCols, channels]` * - If `dataFormat` is `"channelsFirst"`: * `[batch, channels, croppedRows, croppedCols]`. * * Examples * ```js * * const model = tf.sequential(); * model.add(tf.layers.cropping2D({cropping:[[2, 2], [2, 2]], * inputShape: [128, 128, 3]})); * //now output shape is [batch, 124, 124, 3] * ``` * * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'} */ function cropping2D(args) { return new Cropping2D(args); } /** * Upsampling layer for 2D inputs. * * Repeats the rows and columns of the data * by size[0] and size[1] respectively. * * * Input shape: * 4D tensor with shape: * - If `dataFormat` is `"channelsLast"`: * `[batch, rows, cols, channels]` * - If `dataFormat` is `"channelsFirst"`: * `[batch, channels, rows, cols]` * * Output shape: * 4D tensor with shape: * - If `dataFormat` is `"channelsLast"`: * `[batch, upsampledRows, upsampledCols, channels]` * - If `dataFormat` is `"channelsFirst"`: * `[batch, channels, upsampledRows, upsampledCols]` * * * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'} */ function upSampling2d(args) { return new UpSampling2D(args); } // Convolutional(depthwise) Layers. /** * Depthwise separable 2D convolution. * * Depthwise Separable convolutions consists in performing just the first step * in a depthwise spatial convolution (which acts on each input channel * separately). The `depthMultiplier` argument controls how many output channels * are generated per input channel in the depthwise step. * * @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'} */ function depthwiseConv2d(args) { return new DepthwiseConv2D(args); } // Basic Layers. 
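/*
 * A minimal usage sketch for the `depthwiseConv2d` factory defined above
 * (the option values are illustrative assumptions, not from the original
 * docs):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.depthwiseConv2d(
 *     {kernelSize: 3, depthMultiplier: 2, inputShape: [28, 28, 3]}));
 * // Each of the 3 input channels yields depthMultiplier = 2 output channels.
 * console.log(JSON.stringify(model.outputShape)); // [null, 26, 26, 6]
 * ```
 */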
/** * Applies an activation function to an output. * * This layer applies an element-wise activation function. Other layers, notably * `dense`, can also apply activation functions. Use this isolated activation * function to extract the values before and after the * activation. For instance: * * ```js * const input = tf.input({shape: [5]}); * const denseLayer = tf.layers.dense({units: 1}); * const activationLayer = tf.layers.activation({activation: 'relu6'}); * * // Obtain the output symbolic tensors by applying the layers in order. * const denseOutput = denseLayer.apply(input); * const activationOutput = activationLayer.apply(denseOutput); * * // Create the model based on the inputs. * const model = tf.model({ * inputs: input, * outputs: [denseOutput, activationOutput] * }); * * // Collect both outputs and print separately. * const [denseOut, activationOut] = model.predict(tf.randomNormal([6, 5])); * denseOut.print(); * activationOut.print(); * ``` * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function activation(args) { return new Activation(args); } /** * Creates a dense (fully connected) layer. * * This layer implements the operation: * `output = activation(dot(input, kernel) + bias)` * * `activation` is the element-wise activation function * passed as the `activation` argument. * * `kernel` is a weights matrix created by the layer. * * `bias` is a bias vector created by the layer (only applicable if `useBias` * is `true`). * * **Input shape:** * * nD `tf.Tensor` with shape: `(batchSize, ..., inputDim)`. * * The most common situation would be * a 2D input with shape `(batchSize, inputDim)`. * * **Output shape:** * * nD tensor with shape: `(batchSize, ..., units)`. * * For instance, for a 2D input with shape `(batchSize, inputDim)`, * the output would have shape `(batchSize, units)`. * * Note: if the input to the layer has a rank greater than 2, then it is * flattened prior to the initial dot product with the kernel. * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function dense(args) { return new Dense(args); } /** * Applies * [dropout](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) to * the input. * * Dropout consists of randomly setting a fraction `rate` of input units to 0 at * each update during training time, which helps prevent overfitting. * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function dropout(args) { return new Dropout(args); } /** * Spatial 1D version of Dropout. * * This Layer type performs the same function as the Dropout layer, but it drops * entire 1D feature maps instead of individual elements. For example, if an * input example consists of 3 timesteps and the feature map for each timestep * has a size of 4, a `spatialDropout1d` layer may zero out the feature maps * of the 1st and 2nd timesteps completely while sparing all feature * elements of the 3rd timestep. * * If adjacent frames (timesteps) are strongly correlated (as is normally the * case in early convolution layers), regular dropout will not regularize the * activations and will otherwise merely result in an effective learning * rate decrease. In this case, `spatialDropout1d` will help promote * independence among feature maps and should be used instead. * * **Arguments:** * rate: A floating-point number >=0 and <=1. Fraction of the input elements * to drop. * * **Input shape:** * 3D tensor with shape `(samples, timesteps, channels)`. * * **Output shape:** * Same as the input shape. 
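 *
 * A minimal sketch (the `rate` value is an illustrative assumption):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.spatialDropout1d({rate: 0.5, inputShape: [3, 4]}));
 * // During training, whole feature maps (channels) are zeroed together
 * // rather than individual elements.
 * model.summary();
 * ```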
* * References: * - [Efficient Object Localization Using Convolutional * Networks](https://arxiv.org/abs/1411.4280) * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function spatialDropout1d(args) { return new SpatialDropout1D(args); } /** * Flattens the input. Does not affect the batch size. * * A `Flatten` layer flattens each batch in its inputs to 1D (making the output * 2D). * * For example: * * ```js * const input = tf.input({shape: [4, 3]}); * const flattenLayer = tf.layers.flatten(); * // Inspect the inferred output shape of the flatten layer, which * // equals `[null, 12]`. The 2nd dimension is 4 * 3, i.e., the result of the * // flattening. (The 1st dimension is the undetermined batch size.) * console.log(JSON.stringify(flattenLayer.apply(input).shape)); * ``` * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function flatten(args) { return new Flatten(args); } /** * Repeats the input n times in a new dimension. * * ```js * const model = tf.sequential(); * model.add(tf.layers.repeatVector({n: 4, inputShape: [2]})); * const x = tf.tensor2d([[10, 20]]); * // Use the model to do inference on a data point the model hasn't seen * model.predict(x).print(); * // output shape is now [batch, 4, 2] * ``` * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function repeatVector(args) { return new RepeatVector(args); } /** * Reshapes an input to a certain shape. * * ```js * const input = tf.input({shape: [4, 3]}); * const reshapeLayer = tf.layers.reshape({targetShape: [2, 6]}); * // Inspect the inferred output shape of the Reshape layer, which * // equals `[null, 2, 6]`. (The 1st dimension is the undetermined batch size.) * console.log(JSON.stringify(reshapeLayer.apply(input).shape)); * ``` * * Input shape: * Arbitrary, although all dimensions in the input shape must be fixed. * Use the configuration `inputShape` when using this layer as the * first layer in a model. * * * Output shape: * [batchSize, targetShape[0], targetShape[1], ..., * targetShape[targetShape.length - 1]]. * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function reshape(args) { return new Reshape(args); } /** * Permutes the dimensions of the input according to a given pattern. * * Useful for, e.g., connecting RNNs and convnets together. * * Example: * * ```js * const model = tf.sequential(); * model.add(tf.layers.permute({ * dims: [2, 1], * inputShape: [10, 64] * })); * console.log(model.outputShape); * // Now model's output shape is [null, 64, 10], where null is the * // unpermuted sample (batch) dimension. * ``` * * Input shape: * Arbitrary. Use the configuration field `inputShape` when using this * layer as the first layer in a model. * * Output shape: * Same rank as the input shape, but with the dimensions re-ordered (i.e., * permuted) according to the `dims` configuration of this layer. * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function permute(args) { return new Permute(args); } /** * Maps positive integers (indices) into dense vectors of fixed size. * E.g. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]] * * **Input shape:** 2D tensor with shape: `[batchSize, sequenceLength]`. * * **Output shape:** 3D tensor with shape: `[batchSize, sequenceLength, * outputDim]`. * * @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */ function embedding(args) { return new Embedding(args); } // Merge Layers. /** * Layer that performs element-wise addition on an `Array` of inputs.
* * It takes as input a list of tensors, all of the same shape, and returns a * single tensor (also of the same shape). The inputs are specified as an * `Array` when the `apply` method of the `Add` layer instance is called. For * example: * * ```js * const input1 = tf.input({shape: [2, 2]}); * const input2 = tf.input({shape: [2, 2]}); * const addLayer = tf.layers.add(); * const sum = addLayer.apply([input1, input2]); * console.log(JSON.stringify(sum.shape)); * // You get [null, 2, 2], with the first dimension as the undetermined batch * // dimension. * ``` * * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'} */ function add(args) { return new Add(args); } /** * Layer that performs element-wise averaging on an `Array` of inputs. * * It takes as input a list of tensors, all of the same shape, and returns a * single tensor (also of the same shape). For example: * * ```js * const input1 = tf.input({shape: [2, 2]}); * const input2 = tf.input({shape: [2, 2]}); * const averageLayer = tf.layers.average(); * const average = averageLayer.apply([input1, input2]); * console.log(JSON.stringify(average.shape)); * // You get [null, 2, 2], with the first dimension as the undetermined batch * // dimension. * ``` * * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'} */ function average(args) { return new Average(args); } /** * Layer that concatenates an `Array` of inputs. * * It takes a list of tensors, all of the same shape except for the * concatenation axis, and returns a single tensor, the concatenation * of all inputs. For example: * * ```js * const input1 = tf.input({shape: [2, 2]}); * const input2 = tf.input({shape: [2, 3]}); * const concatLayer = tf.layers.concatenate(); * const output = concatLayer.apply([input1, input2]); * console.log(JSON.stringify(output.shape)); * // You get [null, 2, 5], with the first dimension as the undetermined batch * // dimension. The last dimension (5) is the result of concatenating the * // last dimensions of the inputs (2 and 3). * ``` * * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'} */ function concatenate(args) { return new Concatenate(args); } /** * Layer that computes the element-wise maximum of an `Array` of inputs. * * It takes as input a list of tensors, all of the same shape, and returns a * single tensor (also of the same shape). For example: * * ```js * const input1 = tf.input({shape: [2, 2]}); * const input2 = tf.input({shape: [2, 2]}); * const maxLayer = tf.layers.maximum(); * const max = maxLayer.apply([input1, input2]); * console.log(JSON.stringify(max.shape)); * // You get [null, 2, 2], with the first dimension as the undetermined batch * // dimension. * ``` * * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'} */ function maximum(args) { return new Maximum(args); } /** * Layer that computes the element-wise minimum of an `Array` of inputs. * * It takes as input a list of tensors, all of the same shape, and returns a * single tensor (also of the same shape). For example: * * ```js * const input1 = tf.input({shape: [2, 2]}); * const input2 = tf.input({shape: [2, 2]}); * const minLayer = tf.layers.minimum(); * const min = minLayer.apply([input1, input2]); * console.log(JSON.stringify(min.shape)); * // You get [null, 2, 2], with the first dimension as the undetermined batch * // dimension. 
* ``` * * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'} */ function minimum(args) { return new Minimum(args); } /** * Layer that multiplies (element-wise) an `Array` of inputs. * * It takes as input an Array of tensors, all of the same * shape, and returns a single tensor (also of the same shape). * For example: * * ```js * const input1 = tf.input({shape: [2, 2]}); * const input2 = tf.input({shape: [2, 2]}); * const input3 = tf.input({shape: [2, 2]}); * const multiplyLayer = tf.layers.multiply(); * const product = multiplyLayer.apply([input1, input2, input3]); * console.log(product.shape); * // You get [null, 2, 2], with the first dimension as the undetermined batch * // dimension. * ``` * * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'} */ function multiply(args) { return new Multiply(args); } /** * Layer that computes a dot product between samples in two tensors. * * E.g., if applied to a list of two tensors `a` and `b` both of shape * `[batchSize, n]`, the output will be a tensor of shape `[batchSize, 1]`, * where each entry at index `[i, 0]` will be the dot product between * `a[i, :]` and `b[i, :]`. * * Example: * * ```js * const dotLayer = tf.layers.dot({axes: -1}); * const x1 = tf.tensor2d([[10, 20], [30, 40]]); * const x2 = tf.tensor2d([[-1, -2], [-3, -4]]); * * // Invoke the layer's apply() method in eager (imperative) mode. * const y = dotLayer.apply([x1, x2]); * y.print(); * ``` * * @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'} */ function dot(args) { return new Dot(args); } // Normalization Layers. /** * Batch normalization layer (Ioffe and Szegedy, 2015). * * Normalizes the activations of the previous layer at each batch, * i.e., applies a transformation that maintains the mean activation * close to 0 and the activation standard deviation close to 1. * * Input shape: * Arbitrary. Use the keyword argument `inputShape` (Array of integers, does * not include the sample axis) when calling the constructor of this class, * if this layer is used as a first layer in a model. * * Output shape: * Same shape as input. * * References: * - [Batch Normalization: Accelerating Deep Network Training by Reducing * Internal Covariate Shift](https://arxiv.org/abs/1502.03167) * * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'} */ function batchNormalization(args) { return new BatchNormalization(args); } /** * Layer-normalization layer (Ba et al., 2016). * * Normalizes the activations of the previous layer for each given example in a * batch independently, instead of across a batch like in `batchNormalization`. * In other words, this layer applies a transformation that maintains the mean * activation within each example close to 0 and activation variance close to 1. * * Input shape: * Arbitrary. Use the argument `inputShape` when using this layer as the first * layer in a model. * * Output shape: * Same as input. * * References: * - [Layer Normalization](https://arxiv.org/abs/1607.06450) * * @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'} */ function layerNormalization(args) { return new LayerNormalization(args); } // Padding Layers. /** * Zero-padding layer for 2D input (e.g., image). * * This layer can add rows and columns of zeros * at the top, bottom, left and right side of an image tensor. * * Input shape: * 4D tensor with shape: * - If `dataFormat` is `"channelsLast"`: * `[batch, rows, cols, channels]` * - If `dataFormat` is `"channelsFirst"`: * `[batch, channels, rows, cols]`.
* * Output shape: * 4D tensor with shape: * - If `dataFormat` is `"channelsLast"`: * `[batch, paddedRows, paddedCols, channels]` * - If `dataFormat` is `"channelsFirst"`: * `[batch, channels, paddedRows, paddedCols]`. * * @doc {heading: 'Layers', subheading: 'Padding', namespace: 'layers'} */ function zeroPadding2d(args) { return new ZeroPadding2D(args); } // Pooling Layers. /** * Average pooling operation for temporal data. * * Input shape: `[batchSize, inLength, channels]` * * Output shape: `[batchSize, pooledLength, channels]` * * `tf.avgPool1d` is an alias. * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function averagePooling1d(args) { return new AveragePooling1D(args); } function avgPool1d(args) { return averagePooling1d(args); } // For backwards compatibility. // See https://github.com/tensorflow/tfjs/issues/152 function avgPooling1d(args) { return averagePooling1d(args); } /** * Average pooling operation for spatial data. * * Input shape: * - If `dataFormat === CHANNEL_LAST`: * 4D tensor with shape: * `[batchSize, rows, cols, channels]` * - If `dataFormat === CHANNEL_FIRST`: * 4D tensor with shape: * `[batchSize, channels, rows, cols]` * * Output shape: * - If `dataFormat === CHANNEL_LAST`: * 4D tensor with shape: * `[batchSize, pooledRows, pooledCols, channels]` * - If `dataFormat === CHANNEL_FIRST`: * 4D tensor with shape: * `[batchSize, channels, pooledRows, pooledCols]` * * `tf.avgPool2d` is an alias. * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function averagePooling2d(args) { return new AveragePooling2D(args); } function avgPool2d(args) { return averagePooling2d(args); } // For backwards compatibility. // See https://github.com/tensorflow/tfjs/issues/152 function avgPooling2d(args) { return averagePooling2d(args); } /** * Average pooling operation for 3D data. * * Input shape: * - If `dataFormat === channelsLast`: * 5D tensor with shape: * `[batchSize, depths, rows, cols, channels]` * - If `dataFormat === channelsFirst`: * 5D tensor with shape: * `[batchSize, channels, depths, rows, cols]` * * Output shape: * - If `dataFormat === channelsLast`: * 5D tensor with shape: * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]` * - If `dataFormat === channelsFirst`: * 5D tensor with shape: * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]` * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function averagePooling3d(args) { return new AveragePooling3D(args); } function avgPool3d(args) { return averagePooling3d(args); } // For backwards compatibility. // See https://github.com/tensorflow/tfjs/issues/152 function avgPooling3d(args) { return averagePooling3d(args); } /** * Global average pooling operation for temporal data. * * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`. * * Output Shape: 2D tensor with shape: `[batchSize, features]`. * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function globalAveragePooling1d(args) { return new GlobalAveragePooling1D(args); } /** * Global average pooling operation for spatial data. * * Input shape: * - If `dataFormat` is `CHANNEL_LAST`: * 4D tensor with shape: `[batchSize, rows, cols, channels]`. * - If `dataFormat` is `CHANNEL_FIRST`: * 4D tensor with shape: `[batchSize, channels, rows, cols]`. * * Output shape: * 2D tensor with shape: `[batchSize, channels]`.
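* * A minimal usage sketch (illustrative shapes): * * ```js * const model = tf.sequential(); * model.add(tf.layers.globalAveragePooling2d({inputShape: [4, 4, 3]})); * // Averages over rows and cols: [null, 4, 4, 3] -> [null, 3]. * console.log(JSON.stringify(model.outputShape)); * ```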
* * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function globalAveragePooling2d(args) { return new GlobalAveragePooling2D(args); } /** * Global max pooling operation for temporal data. * * Input Shape: 3D tensor with shape: `[batchSize, steps, features]`. * * Output Shape: 2D tensor with shape: `[batchSize, features]`. * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function globalMaxPooling1d(args) { return new GlobalMaxPooling1D(args); } /** * Global max pooling operation for spatial data. * * Input shape: * - If `dataFormat` is `CHANNEL_LAST`: * 4D tensor with shape: `[batchSize, rows, cols, channels]`. * - If `dataFormat` is `CHANNEL_FIRST`: * 4D tensor with shape: `[batchSize, channels, rows, cols]`. * * Output shape: * 2D tensor with shape: `[batchSize, channels]`. * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function globalMaxPooling2d(args) { return new GlobalMaxPooling2D(args); } /** * Max pooling operation for temporal data. * * Input shape: `[batchSize, inLength, channels]` * * Output shape: `[batchSize, pooledLength, channels]` * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function maxPooling1d(args) { return new MaxPooling1D(args); } /** * Max pooling operation for spatial data. * * Input shape: * - If `dataFormat === CHANNEL_LAST`: * 4D tensor with shape: * `[batchSize, rows, cols, channels]` * - If `dataFormat === CHANNEL_FIRST`: * 4D tensor with shape: * `[batchSize, channels, rows, cols]` * * Output shape: * - If `dataFormat === CHANNEL_LAST`: * 4D tensor with shape: * `[batchSize, pooledRows, pooledCols, channels]` * - If `dataFormat === CHANNEL_FIRST`: * 4D tensor with shape: * `[batchSize, channels, pooledRows, pooledCols]` * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function maxPooling2d(args) { return new MaxPooling2D(args); } /** * Max pooling operation for 3D data. * * Input shape: * - If `dataFormat === channelsLast`: * 5D tensor with shape: * `[batchSize, depths, rows, cols, channels]` * - If `dataFormat === channelsFirst`: * 5D tensor with shape: * `[batchSize, channels, depths, rows, cols]` * * Output shape: * - If `dataFormat === channelsLast`: * 5D tensor with shape: * `[batchSize, pooledDepths, pooledRows, pooledCols, channels]` * - If `dataFormat === channelsFirst`: * 5D tensor with shape: * `[batchSize, channels, pooledDepths, pooledRows, pooledCols]` * * @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'} */ function maxPooling3d(args) { return new MaxPooling3D(args); } // Recurrent Layers. /** * Gated Recurrent Unit - Cho et al. 2014. * * This is an `RNN` layer consisting of one `GRUCell`. However, unlike * the underlying `GRUCell`, the `apply` method of `GRU` operates * on a sequence of inputs. The shape of the input (not including the first, * batch dimension) needs to be at least 2-D, with the first dimension being * time steps. For example: * * ```js * const rnn = tf.layers.gru({units: 8, returnSequences: true}); * * // Create an input with 10 time steps. * const input = tf.input({shape: [10, 20]}); * const output = rnn.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the * // same as the sequence length of `input`, due to `returnSequences`: `true`; * // 3rd dimension is the `GRUCell`'s number of units. * ```
* * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function gru(args) { return new GRU(args); } /** * Cell class for `GRU`. * * `GRUCell` is distinct from the `RNN` subclass `GRU` in that its * `apply` method takes the input data of only a single time step and returns * the cell's output at the time step, while `GRU` takes the input data * over a number of time steps. For example: * * ```js * const cell = tf.layers.gruCell({units: 2}); * const input = tf.input({shape: [10]}); * const output = cell.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10]: This is the cell's output at a single time step. The 1st * // dimension is the unknown batch size. * ``` * * Instance(s) of `GRUCell` can be used to construct `RNN` layers. The * most typical use of this workflow is to combine a number of cells into a * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an * RNN. For example: * * ```js * const cells = [ * tf.layers.gruCell({units: 4}), * tf.layers.gruCell({units: 8}), * ]; * const rnn = tf.layers.rnn({cell: cells, returnSequences: true}); * * // Create an input with 10 time steps and a length-20 vector at each step. * const input = tf.input({shape: [10, 20]}); * const output = rnn.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the * // same as the sequence length of `input`, due to `returnSequences`: `true`; * // 3rd dimension is the last `gruCell`'s number of units. * ``` * * To create an `RNN` consisting of only *one* `GRUCell`, use the * `tf.layers.gru`. * * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function gruCell(args) { return new GRUCell(args); } /** * Long Short-Term Memory layer - Hochreiter & Schmidhuber 1997. * * This is an `RNN` layer consisting of one `LSTMCell`. However, unlike * the underlying `LSTMCell`, the `apply` method of `LSTM` operates * on a sequence of inputs. The shape of the input (not including the first, * batch dimension) needs to be at least 2-D, with the first dimension being * time steps. For example: * * ```js * const lstm = tf.layers.lstm({units: 8, returnSequences: true}); * * // Create an input with 10 time steps. * const input = tf.input({shape: [10, 20]}); * const output = lstm.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the * // same as the sequence length of `input`, due to `returnSequences`: `true`; * // 3rd dimension is the `LSTMCell`'s number of units. * ``` * * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function lstm(args) { return new LSTM(args); } /** * Cell class for `LSTM`. * * `LSTMCell` is distinct from the `RNN` subclass `LSTM` in that its * `apply` method takes the input data of only a single time step and returns * the cell's output at the time step, while `LSTM` takes the input data * over a number of time steps. For example: * * ```js * const cell = tf.layers.lstmCell({units: 2}); * const input = tf.input({shape: [10]}); * const output = cell.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10]: This is the cell's output at a single time step. The 1st * // dimension is the unknown batch size. * ``` * * Instance(s) of `LSTMCell` can be used to construct `RNN` layers.
The * most typical use of this workflow is to combine a number of cells into a * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an * RNN. For example: * * ```js * const cells = [ * tf.layers.lstmCell({units: 4}), * tf.layers.lstmCell({units: 8}), * ]; * const rnn = tf.layers.rnn({cell: cells, returnSequences: true}); * * // Create an input with 10 time steps and a length-20 vector at each step. * const input = tf.input({shape: [10, 20]}); * const output = rnn.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the * // same as the sequence length of `input`, due to `returnSequences`: `true`; * // 3rd dimension is the last `lstmCell`'s number of units. * ``` * * To create an `RNN` consisting of only *one* `LSTMCell`, use the * `tf.layers.lstm`. * * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function lstmCell(args) { return new LSTMCell(args); } /** * Fully-connected RNN where the output is to be fed back to input. * * This is an `RNN` layer consisting of one `SimpleRNNCell`. However, unlike * the underlying `SimpleRNNCell`, the `apply` method of `SimpleRNN` operates * on a sequence of inputs. The shape of the input (not including the first, * batch dimension) needs to be at least 2-D, with the first dimension being * time steps. For example: * * ```js * const rnn = tf.layers.simpleRNN({units: 8, returnSequences: true}); * * // Create an input with 10 time steps. * const input = tf.input({shape: [10, 20]}); * const output = rnn.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the * // same as the sequence length of `input`, due to `returnSequences`: `true`; * // 3rd dimension is the `SimpleRNNCell`'s number of units. * ``` * * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function simpleRNN(args) { return new SimpleRNN(args); } /** * Cell class for `SimpleRNN`. * * `SimpleRNNCell` is distinct from the `RNN` subclass `SimpleRNN` in that its * `apply` method takes the input data of only a single time step and returns * the cell's output at the time step, while `SimpleRNN` takes the input data * over a number of time steps. For example: * * ```js * const cell = tf.layers.simpleRNNCell({units: 2}); * const input = tf.input({shape: [10]}); * const output = cell.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10]: This is the cell's output at a single time step. The 1st * // dimension is the unknown batch size. * ``` * * Instance(s) of `SimpleRNNCell` can be used to construct `RNN` layers. The * most typical use of this workflow is to combine a number of cells into a * stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an * RNN. For example: * * ```js * const cells = [ * tf.layers.simpleRNNCell({units: 4}), * tf.layers.simpleRNNCell({units: 8}), * ]; * const rnn = tf.layers.rnn({cell: cells, returnSequences: true}); * * // Create an input with 10 time steps and a length-20 vector at each step. * const input = tf.input({shape: [10, 20]}); * const output = rnn.apply(input); * * console.log(JSON.stringify(output.shape)); * // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the * // same as the sequence length of `input`, due to `returnSequences`: `true`; * // 3rd dimension is the last `SimpleRNNCell`'s number of units. 
* ``` * * To create an `RNN` consisting of only *one* `SimpleRNNCell`, use the * `tf.layers.simpleRNN`. * * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function simpleRNNCell(args) { return new SimpleRNNCell(args); } /** * Convolutional LSTM layer - Xingjian Shi 2015. * * This is a `ConvRNN2D` layer consisting of one `ConvLSTM2DCell`. However, * unlike the underlying `ConvLSTM2DCell`, the `apply` method of `ConvLSTM2D` * operates on a sequence of inputs. The shape of the input (not including the * first, batch dimension) needs to be 4-D, with the first dimension being time * steps. For example: * * ```js * const filters = 3; * const kernelSize = 3; * * const batchSize = 4; * const sequenceLength = 2; * const size = 5; * const channels = 3; * * const inputShape = [batchSize, sequenceLength, size, size, channels]; * const input = tf.ones(inputShape); * * const layer = tf.layers.convLstm2d({filters, kernelSize}); * * const output = layer.apply(input); * ``` */ /** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function convLstm2d(args) { return new ConvLSTM2D(args); } /** * Cell class for `ConvLSTM2D`. * * `ConvLSTM2DCell` is distinct from the `ConvRNN2D` subclass `ConvLSTM2D` in * that its `call` method takes the input data of only a single time step and * returns the cell's output at the time step, while `ConvLSTM2D` takes the * input data over a number of time steps. For example: * * ```js * const filters = 3; * const kernelSize = 3; * * const sequenceLength = 1; * const size = 5; * const channels = 3; * * const inputShape = [sequenceLength, size, size, channels]; * const input = tf.ones(inputShape); * * const cell = tf.layers.convLstm2dCell({filters, kernelSize}); * * cell.build(input.shape); * * const outputSize = size - kernelSize + 1; * const outShape = [sequenceLength, outputSize, outputSize, filters]; * * const initialH = tf.zeros(outShape); * const initialC = tf.zeros(outShape); * * const [o, h, c] = cell.call([input, initialH, initialC], {}); * ``` */ /** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function convLstm2dCell(args) { return new ConvLSTM2DCell(args); } /** * Base class for recurrent layers. * * Input shape: * 3D tensor with shape `[batchSize, timeSteps, inputDim]`. * * Output shape: * - if `returnState`, an Array of tensors (i.e., `tf.Tensor`s). The first * tensor is the output. The remaining tensors are the states at the * last time step, each with shape `[batchSize, units]`. * - if `returnSequences`, the output will have shape * `[batchSize, timeSteps, units]`. * - else, the output will have shape `[batchSize, units]`. * * Masking: * This layer supports masking for input data with a variable number * of timesteps. To introduce masks to your data, * use an embedding layer with the `maskZero` parameter * set to `true`. * * Notes on using statefulness in RNNs: * You can set RNN layers to be 'stateful', which means that the states * computed for the samples in one batch will be reused as initial states * for the samples in the next batch. This assumes a one-to-one mapping * between samples in different successive batches. * * To enable statefulness: * - specify `stateful: true` in the layer constructor. * - specify a fixed batch size for your model: * if using a sequential model, pass * `batchInputShape: [...]` to the first layer in your model; * else, for a functional model with 1 or more Input layers, pass * `batchShape: [...]` to all the first layers in your model.
* This is the expected shape of your inputs *including the batch size*. * It should be an Array of integers, e.g., `[32, 10, 100]`. * - specify `shuffle: false` when calling `fit()`. * * To reset the states of your model, call `.resetStates()` on either * a specific layer, or on your entire model. * * Note on specifying the initial state of RNNs * You can specify the initial state of RNN layers symbolically by * calling them with the option `initialState`. The value of * `initialState` should be a tensor or list of tensors representing * the initial state of the RNN layer. * * You can specify the initial state of RNN layers numerically by * calling `resetStates` with the keyword argument `states`. The value of * `states` should be a tensor or an Array of tensors representing * the initial state of the RNN layer. * * Note on passing external constants to RNNs * You can pass "external" constants to the cell using the `constants` * keyword argument of the `RNN.call` method. This requires that the `cell.call` * method accepts the same keyword argument `constants`. Such constants * can be used to condition the cell transformation on additional static * inputs (not changing over time), a.k.a. an attention mechanism. * * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function rnn(args) { return new RNN(args); } /** * Wrapper allowing a stack of RNN cells to behave as a single cell. * * Used to implement efficient stacked RNNs. * * @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */ function stackedRNNCells(args) { return new StackedRNNCells(args); } // Wrapper Layers. /** @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'} */ function bidirectional(args) { return new Bidirectional(args); } /** * This wrapper applies a layer to every temporal slice of an input. * * The input should be at least 3D, and the dimension at index `1` will be * considered the temporal dimension. * * Consider a batch of 32 samples, where each sample is a sequence of 10 vectors * of 16 dimensions. The batch input shape of the layer is then `[32, 10, * 16]`, and the `inputShape`, not including the sample dimension, is * `[10, 16]`. * * You can then use `TimeDistributed` to apply a `Dense` layer to each of the 10 * timesteps, independently: * * ```js * const model = tf.sequential(); * model.add(tf.layers.timeDistributed({ * layer: tf.layers.dense({units: 8}), * inputShape: [10, 16], * })); * * // Now model.outputShape = [null, 10, 8]. * // The output will then have shape `[32, 10, 8]`. * * // In subsequent layers, there is no need for `inputShape`: * model.add(tf.layers.timeDistributed({layer: tf.layers.dense({units: 32})})); * console.log(JSON.stringify(model.outputs[0].shape)); * // Now model.outputShape = [null, 10, 32]. * ``` * * The output will then have shape `[32, 10, 32]`. * * `TimeDistributed` can be used with arbitrary layers, not just `Dense`, for * instance a `Conv2D` layer. * * ```js * const model = tf.sequential(); * model.add(tf.layers.timeDistributed({ * layer: tf.layers.conv2d({filters: 64, kernelSize: [3, 3]}), * inputShape: [10, 299, 299, 3], * })); * console.log(JSON.stringify(model.outputs[0].shape)); * ``` * * @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'} */ function timeDistributed(args) { return new TimeDistributed(args); }
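/** * A minimal usage sketch for the `bidirectional` wrapper above (illustrative * values; with `mergeMode: 'concat'`, the forward and backward outputs are * concatenated, doubling the wrapped RNN's output size): * * ```js * const model = tf.sequential(); * model.add(tf.layers.bidirectional({ * layer: tf.layers.lstm({units: 8}), * mergeMode: 'concat', * inputShape: [10, 20] * })); * console.log(JSON.stringify(model.outputShape)); // [null, 16] * ``` */ // Aliases for pooling.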
var globalMaxPool1d = globalMaxPooling1d; var globalMaxPool2d = globalMaxPooling2d; var maxPool1d = maxPooling1d; var maxPool2d = maxPooling2d; /** * Apply additive zero-centered Gaussian noise. * * As it is a regularization layer, it is only active at training time. * * This is useful to mitigate overfitting * (you could see it as a form of random data augmentation). * Gaussian noise (GN) is a natural choice as corruption process * for real valued inputs. * * Arguments: * - `stddev`: float, standard deviation of the noise distribution. * * Input shape: * Arbitrary. Use the keyword argument `inputShape` * (Array of integers, does not include the samples axis) * when using this layer as the first layer in a model. * * Output shape: * Same shape as input. * * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'} */ function gaussianNoise(args) { return new GaussianNoise(args); } /** * Apply multiplicative 1-centered Gaussian noise. * * As it is a regularization layer, it is only active at training time. * * Arguments: * - `rate`: float, drop probability (as with `Dropout`). * The multiplicative noise will have * standard deviation `sqrt(rate / (1 - rate))`. * * Input shape: * Arbitrary. Use the keyword argument `inputShape` * (Array of integers, does not include the samples axis) * when using this layer as the first layer in a model. * * Output shape: * Same shape as input. * * References: * - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting]( * http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) * * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'} */ function gaussianDropout(args) { return new GaussianDropout(args); } /** * Applies Alpha Dropout to the input. * * As it is a regularization layer, it is only active at training time. * * Alpha Dropout is a `Dropout` that keeps the mean and variance of the inputs * at their original values, in order to ensure the self-normalizing property * even after this dropout. * Alpha Dropout works well with Scaled Exponential Linear Units (SELUs) * by randomly setting activations to the negative saturation value. * * Arguments: * - `rate`: float, drop probability (as with `Dropout`). * The multiplicative noise will have * standard deviation `sqrt(rate / (1 - rate))`. * - `noiseShape`: A 1-D `Tensor` of type `int32`, representing the * shape for randomly generated keep/drop flags. * * Input shape: * Arbitrary. Use the keyword argument `inputShape` * (Array of integers, does not include the samples axis) * when using this layer as the first layer in a model. * * Output shape: * Same shape as input. * * References: * - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) * * @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'} */ function alphaDropout(args) { return new AlphaDropout(args); } /** * Masks a sequence by using a mask value to skip timesteps. * * If all features for a given sample timestep are equal to `maskValue`, * then the sample timestep will be masked (skipped) in all downstream layers * (as long as they support masking). * * If any downstream layer does not support masking yet receives such * an input mask, an exception will be raised. * * Arguments: * - `maskValue`: The value to mask (skip). Defaults to 0. * * Input shape: * Arbitrary. Use the keyword argument `inputShape` * (Array of integers, does not include the samples axis) * when using this layer as the first layer in a model. * * Output shape: * Same shape as input.
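* * A minimal usage sketch (masking timesteps whose features all equal the * `maskValue`; the downstream LSTM is illustrative): * * ```js * const model = tf.sequential(); * model.add(tf.layers.masking({maskValue: 0, inputShape: [4, 2]})); * model.add(tf.layers.lstm({units: 3})); * // The all-zero timesteps are skipped by the mask-consuming LSTM. * const x = tf.tensor3d([[[1, 2], [0, 0], [3, 4], [0, 0]]]); * model.predict(x).print(); * ```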
* * @doc {heading: 'Layers', subheading: 'Mask', namespace: 'layers'} */ function masking(args) { return new Masking(args); } /** * A preprocessing layer which rescales input values to a new range. * * This layer rescales every value of an input (often an image) by multiplying * by `scale` and adding `offset`. * * For instance: * 1. To rescale an input in the `[0, 255]` range * to be in the `[0, 1]` range, you would pass `scale: 1/255`. * 2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` * range, you would pass `scale: 1/127.5, offset: -1`. * The rescaling is applied both during training and inference. Inputs can be * of integer or floating point dtype, and by default the layer will output * floats. * * Arguments: * - `scale`: Float, the scale to apply to the inputs. * - `offset`: Float, the offset to apply to the inputs. * * Input shape: * Arbitrary. * * Output shape: * Same as input. * * @doc {heading: 'Layers', subheading: 'Rescaling', namespace: 'layers'} */ function rescaling(args) { return new Rescaling(args); } /** * A preprocessing layer which center crops images. * * This layer crops the central portion of the images to a target size. If an * image is smaller than the target size, it will be resized and cropped so as * to return the largest possible window in the image that matches the target * aspect ratio. * * Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and * of integer or floating point dtype. * * If the input height/width is even and the target height/width is odd (or * inversely), the input image is left-padded by 1 pixel. * * Arguments: * `height`: Integer, the height of the output shape. * `width`: Integer, the width of the output shape. * * Input shape: * 3D (unbatched) or 4D (batched) tensor with shape: * `(..., height, width, channels)`, in `channelsLast` format. * * Output shape: * 3D (unbatched) or 4D (batched) tensor with shape: * `(..., targetHeight, targetWidth, channels)`. * * * @doc {heading: 'Layers', subheading: 'CenterCrop', namespace: 'layers'} */ function centerCrop(args) { return new CenterCrop(args); } /** * A preprocessing layer which resizes images. * This layer resizes an image input to a target height and width. The input * should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"` * format. Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, * 255]`) and of integer or floating point dtype. By default, the layer will * output floats. * * Arguments: * - `height`: number, the height for the output tensor. * - `width`: number, the width for the output tensor. * - `interpolation`: string, the method for image resizing interpolation. * - `cropToAspectRatio`: boolean, whether to keep image aspect ratio. * * Input shape: * 3D (unbatched) or 4D (batched) tensor with shape: * `(..., height, width, channels)`, in `"channels_last"` format. * * Output shape: * 3D (unbatched) or 4D (batched) tensor with shape: * `(..., height, width, channels)`, where `height` and `width` are the * target sizes. * * @doc {heading: 'Layers', subheading: 'Resizing', namespace: 'layers'} */ function resizing(args) { return new Resizing(args); } /** * A preprocessing layer which encodes integer features. * * This layer provides options for condensing data into a categorical encoding * when the total number of tokens is known in advance. It accepts integer * values as inputs, and it outputs a dense representation of those * inputs. * * Arguments: * * numTokens: The total number of tokens the layer should support. All * inputs to the layer must be integers in the range `0 <= value < * numTokens`, or an error will be thrown. * * outputMode: Specification for the output of the layer. * Defaults to `multiHot`.
Values can be `oneHot`, `multiHot`, or * `count`, configuring the layer as follows: * * oneHot: Encodes each individual element in the input into an * array of `numTokens` size, containing a 1 at the element index. If * the last dimension is size 1, will encode on that dimension. If the * last dimension is not size 1, will append a new dimension for the * encoded output. * * multiHot: Encodes each sample in the input into a single array * of `numTokens` size, containing a 1 for each vocabulary term * present in the sample. Treats the last dimension as the sample * dimension: if input shape is `(..., sampleLength)`, output shape * will be `(..., numTokens)`. * * count: Like `multiHot`, but the int array contains a count of * the number of times the token at that index appeared in the sample. * * For all output modes, currently only outputs up to rank 2 are supported. * Call arguments: * inputs: A 1D or 2D tensor of integer inputs. * countWeights: A tensor in the same shape as `inputs` indicating the * weight for each sample value when summing up in `count` mode. Not used * in `multiHot` or `oneHot` modes. * * * @doc {heading: 'Layers', subheading: 'CategoryEncoding', namespace: 'layers'} */ function categoryEncoding(args) { return new CategoryEncoding(args); } /** * A preprocessing layer which randomly varies image width during training. * * This layer randomly adjusts the width of a batch * of images by a random factor. * * The input should be a 3D (unbatched) or 4D (batched) tensor in * the `"channels_last"` image data format. Input pixel values can be of any * range (e.g. `[0., 1.)` or `[0, 255]`) and of integer or floating point * dtype. By default, the layer will output floats. By default, this layer is * inactive during inference. For an overview and full list of preprocessing * layers, see the preprocessing * [guide](https://www.tensorflow.org/guide/keras/preprocessing_layers). * * Arguments: * * factor: * A positive float (fraction of original width), or a tuple of size 2 * representing lower and upper bound for resizing horizontally. * When represented as a single float, this value is used for both the upper * and lower bound. For instance, `factor: [0.2, 0.3]` results in an output * with width changed by a random amount in the range `[20%, 30%]`. * `factor: [-0.2, 0.3]` results in an output with width changed by a random * amount in the range `[-20%, +30%]`. `factor: 0.2` results in an output * with width changed by a random amount in the range `[-20%, +20%]`. * interpolation: * String, the interpolation method. * Defaults to `bilinear`. * Supports `"bilinear"`, `"nearest"`. * The tf methods `"bicubic"`, `"area"`, `"lanczos3"`, `"lanczos5"`, * `"gaussian"`, `"mitchellcubic"` are unimplemented in tfjs. * seed: * Integer. Used to create a random seed. * * Input shape: * 3D (unbatched) or 4D (batched) tensor with shape: * `(..., height, width, channels)`, in `"channels_last"` format. * Output shape: * 3D (unbatched) or 4D (batched) tensor with shape: * `(..., height, randomWidth, channels)`.
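* * A minimal usage sketch (`factor` and `seed` are illustrative; the width is * only randomized when the layer is invoked in training mode, as noted above): * * ```js * const layer = tf.layers.randomWidth({factor: 0.2, seed: 42}); * const images = tf.ones([2, 8, 8, 3]); * const augmented = layer.apply(images, {training: true}); * console.log(augmented.shape); // [2, 8, <randomized width>, 3] * ```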
* * * @doc {heading: 'Layers', subheading: 'RandomWidth', namespace: 'layers'} */ function randomWidth(args) { return new RandomWidth(args); } var exports_layers = { __proto__: null, Layer: Layer, RNN: RNN, RNNCell: RNNCell, activation: activation, add: add, alphaDropout: alphaDropout, average: average, averagePooling1d: averagePooling1d, averagePooling2d: averagePooling2d, averagePooling3d: averagePooling3d, avgPool1d: avgPool1d, avgPool2d: avgPool2d, avgPool3d: avgPool3d, avgPooling1d: avgPooling1d, avgPooling2d: avgPooling2d, avgPooling3d: avgPooling3d, batchNormalization: batchNormalization, bidirectional: bidirectional, categoryEncoding: categoryEncoding, centerCrop: centerCrop, concatenate: concatenate, conv1d: conv1d, conv2d: conv2d, conv2dTranspose: conv2dTranspose, conv3d: conv3d, conv3dTranspose: conv3dTranspose, convLstm2d: convLstm2d, convLstm2dCell: convLstm2dCell, cropping2D: cropping2D, dense: dense, depthwiseConv2d: depthwiseConv2d, dot: dot, dropout: dropout, elu: elu, embedding: embedding, flatten: flatten, gaussianDropout: gaussianDropout, gaussianNoise: gaussianNoise, globalAveragePooling1d: globalAveragePooling1d, globalAveragePooling2d: globalAveragePooling2d, globalMaxPool1d: globalMaxPool1d, globalMaxPool2d: globalMaxPool2d, globalMaxPooling1d: globalMaxPooling1d, globalMaxPooling2d: globalMaxPooling2d, gru: gru, gruCell: gruCell, input: input, inputLayer: inputLayer, layerNormalization: layerNormalization, leakyReLU: leakyReLU, lstm: lstm, lstmCell: lstmCell, masking: masking, maxPool1d: maxPool1d, maxPool2d: maxPool2d, maxPooling1d: maxPooling1d, maxPooling2d: maxPooling2d, maxPooling3d: maxPooling3d, maximum: maximum, minimum: minimum, multiply: multiply, permute: permute, prelu: prelu, randomWidth: randomWidth, reLU: reLU, repeatVector: repeatVector, rescaling: rescaling, reshape: reshape, resizing: resizing, rnn: rnn, separableConv2d: separableConv2d, simpleRNN: simpleRNN, simpleRNNCell: simpleRNNCell, softmax: softmax, spatialDropout1d: spatialDropout1d, stackedRNNCells: stackedRNNCells, thresholdedReLU: thresholdedReLU, timeDistributed: timeDistributed, upSampling2d: upSampling2d, zeroPadding2d: zeroPadding2d }; /** * Binary accuracy metric function. * * `yTrue` and `yPred` can have 0-1 values. Example: * ```js * const x = tf.tensor2d([[1, 1, 1, 1], [0, 0, 0, 0]], [2, 4]); * const y = tf.tensor2d([[1, 0, 1, 0], [0, 0, 0, 1]], [2, 4]); * const accuracy = tf.metrics.binaryAccuracy(x, y); * accuracy.print(); * ``` * * `yTrue` and `yPred` can also have floating-point values between 0 and 1, in * which case the values will be thresholded at 0.5 to yield 0-1 values (i.e., * a value >= 0.5 and <= 1.0 is interpreted as 1). * * Example: * ```js * const x = tf.tensor1d([1, 1, 1, 1, 0, 0, 0, 0]); * const y = tf.tensor1d([0.2, 0.4, 0.6, 0.8, 0.2, 0.3, 0.4, 0.7]); * const accuracy = tf.metrics.binaryAccuracy(x, y); * accuracy.print(); * ``` * * @param yTrue Binary Tensor of truth. * @param yPred Binary Tensor of prediction. * @return Accuracy Tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function binaryAccuracy(yTrue, yPred) { return binaryAccuracy$1(yTrue, yPred); } /** * Binary crossentropy metric function. * * Example: * ```js * const x = tf.tensor2d([[0], [1], [1], [1]]); * const y = tf.tensor2d([[0], [0], [0.5], [1]]); * const crossentropy = tf.metrics.binaryCrossentropy(x, y); * crossentropy.print(); * ``` * * @param yTrue Binary Tensor of truth. * @param yPred Binary Tensor of prediction, probabilities for the `1` case. * @return Crossentropy Tensor.
* * @doc {heading: 'Metrics', namespace: 'metrics'} */ function binaryCrossentropy(yTrue, yPred) { return binaryCrossentropy$1(yTrue, yPred); } /** * Sparse categorical accuracy metric function. * * Example: * ```js * * const yTrue = tf.tensor1d([1, 1, 2, 2, 0]); * const yPred = tf.tensor2d( * [[0, 1, 0], [1, 0, 0], [0, 0.4, 0.6], [0, 0.6, 0.4], [0.7, 0.3, 0]]); * const accuracy = tf.metrics.sparseCategoricalAccuracy(yTrue, yPred); * accuracy.print(); * ``` * * @param yTrue True labels: indices. * @param yPred Predicted probabilities or logits. * @returns Accuracy tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function sparseCategoricalAccuracy(yTrue, yPred) { return sparseCategoricalAccuracy$1(yTrue, yPred); } /** * Categorical accuracy metric function. * * Example: * ```js * const x = tf.tensor2d([[0, 0, 0, 1], [0, 0, 0, 1]]); * const y = tf.tensor2d([[0.1, 0.8, 0.05, 0.05], [0.1, 0.05, 0.05, 0.8]]); * const accuracy = tf.metrics.categoricalAccuracy(x, y); * accuracy.print(); * ``` * * @param yTrue Binary Tensor of truth: one-hot encoding of categories. * @param yPred Binary Tensor of prediction: probabilities or logits for the * same categories as in `yTrue`. * @return Accuracy Tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function categoricalAccuracy(yTrue, yPred) { return categoricalAccuracy$1(yTrue, yPred); } /** * Categorical crossentropy between an output tensor and a target tensor. * * @param yTrue A tensor of the same shape as `yPred`, containing the target * (e.g., one-hot) values. * @param yPred A tensor resulting from a softmax, containing the predicted * category probabilities. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function categoricalCrossentropy(yTrue, yPred) { return categoricalCrossentropy$1(yTrue, yPred); } /** * Computes the precision of the predictions with respect to the labels. * * Example: * ```js * const x = tf.tensor2d( * [ * [0, 0, 0, 1], * [0, 1, 0, 0], * [0, 0, 0, 1], * [1, 0, 0, 0], * [0, 0, 1, 0] * ] * ); * * const y = tf.tensor2d( * [ * [0, 0, 1, 0], * [0, 1, 0, 0], * [0, 0, 0, 1], * [0, 1, 0, 0], * [0, 1, 0, 0] * ] * ); * * const precision = tf.metrics.precision(x, y); * precision.print(); * ``` * * @param yTrue The ground truth values. Expected to contain only 0-1 values. * @param yPred The predicted values. Expected to contain only 0-1 values. * @return Precision Tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function precision(yTrue, yPred) { return precision$1(yTrue, yPred); } /** * Computes the recall of the predictions with respect to the labels. * * Example: * ```js * const x = tf.tensor2d( * [ * [0, 0, 0, 1], * [0, 1, 0, 0], * [0, 0, 0, 1], * [1, 0, 0, 0], * [0, 0, 1, 0] * ] * ); * * const y = tf.tensor2d( * [ * [0, 0, 1, 0], * [0, 1, 0, 0], * [0, 0, 0, 1], * [0, 1, 0, 0], * [0, 1, 0, 0] * ] * ); * * const recall = tf.metrics.recall(x, y); * recall.print(); * ``` * * @param yTrue The ground truth values. Expected to contain only 0-1 values. * @param yPred The predicted values. Expected to contain only 0-1 values. * @return Recall Tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function recall(yTrue, yPred) { return recall$1(yTrue, yPred); } /** * Loss or metric function: Cosine proximity.
* * Mathematically, cosine proximity is defined as: * `-sum(l2Normalize(yTrue) * l2Normalize(yPred))`, * wherein `l2Normalize()` normalizes the L2 norm of the input to 1 and `*` * represents element-wise multiplication. * * ```js * const yTrue = tf.tensor2d([[1, 0], [1, 0]]); * const yPred = tf.tensor2d([[1 / Math.sqrt(2), 1 / Math.sqrt(2)], [0, 1]]); * const proximity = tf.metrics.cosineProximity(yTrue, yPred); * proximity.print(); * ``` * * @param yTrue Truth Tensor. * @param yPred Prediction Tensor. * @return Cosine proximity Tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function cosineProximity(yTrue, yPred) { return cosineProximity$1(yTrue, yPred); } /** * Loss or metric function: Mean absolute error. * * Mathematically, mean absolute error is defined as: * `mean(abs(yPred - yTrue))`, * wherein the `mean` is applied over feature dimensions. * * ```js * const yTrue = tf.tensor2d([[0, 1], [0, 0], [2, 3]]); * const yPred = tf.tensor2d([[0, 1], [0, 1], [-2, -3]]); * const mae = tf.metrics.meanAbsoluteError(yTrue, yPred); * mae.print(); * ``` * * @param yTrue Truth Tensor. * @param yPred Prediction Tensor. * @return Mean absolute error Tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function meanAbsoluteError(yTrue, yPred) { return meanAbsoluteError$1(yTrue, yPred); } /** * Loss or metric function: Mean absolute percentage error. * * ```js * const yTrue = tf.tensor2d([[0, 1], [10, 20]]); * const yPred = tf.tensor2d([[0, 1], [11, 24]]); * const mape = tf.metrics.meanAbsolutePercentageError(yTrue, yPred); * mape.print(); * ``` * * Aliases: `tf.metrics.MAPE`, `tf.metrics.mape`. * * @param yTrue Truth Tensor. * @param yPred Prediction Tensor. * @return Mean absolute percentage error Tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function meanAbsolutePercentageError(yTrue, yPred) { return meanAbsolutePercentageError$1(yTrue, yPred); } function MAPE(yTrue, yPred) { return meanAbsolutePercentageError$1(yTrue, yPred); } function mape(yTrue, yPred) { return meanAbsolutePercentageError$1(yTrue, yPred); } /** * Loss or metric function: Mean squared error. * * ```js * const yTrue = tf.tensor2d([[0, 1], [3, 4]]); * const yPred = tf.tensor2d([[0, 1], [-3, -4]]); * const mse = tf.metrics.meanSquaredError(yTrue, yPred); * mse.print(); * ``` * * Aliases: `tf.metrics.MSE`, `tf.metrics.mse`. * * @param yTrue Truth Tensor. * @param yPred Prediction Tensor. * @return Mean squared error Tensor. * * @doc {heading: 'Metrics', namespace: 'metrics'} */ function meanSquaredError(yTrue, yPred) { return meanSquaredError$1(yTrue, yPred); } function MSE(yTrue, yPred) { return meanSquaredError$1(yTrue, yPred); } function mse(yTrue, yPred) { return meanSquaredError$1(yTrue, yPred); } var exports_metrics = { __proto__: null, MAPE: MAPE, MSE: MSE, binaryAccuracy: binaryAccuracy, binaryCrossentropy: binaryCrossentropy, categoricalAccuracy: categoricalAccuracy, categoricalCrossentropy: categoricalCrossentropy, cosineProximity: cosineProximity, mape: mape, meanAbsoluteError: meanAbsoluteError, meanAbsolutePercentageError: meanAbsolutePercentageError, meanSquaredError: meanSquaredError, mse: mse, precision: precision, recall: recall, sparseCategoricalAccuracy: sparseCategoricalAccuracy }; /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT.
* ============================================================================= */ var exports_models = { __proto__: null, modelFromJSON: modelFromJSON }; /** * @license * Copyright 2018 Google LLC * * Use of this source code is governed by an MIT-style * license that can be found in the LICENSE file or at * https://opensource.org/licenses/MIT. * ============================================================================= */ /** * Regularizer for L1 and L2 regularization. * * Adds a term to the loss to penalize large weights: * loss += sum(l1 * abs(x)) + sum(l2 * x^2) * * @doc {heading: 'Regularizers', namespace: 'regularizers'} */ function l1l2(config) { return new L1L2(config); } /** * Regularizer for L1 regularization. * * Adds a term to the loss to penalize large weights: * loss += sum(l1 * abs(x)) * @param config l1 config. * * @doc {heading: 'Regularizers', namespace: 'regularizers'} */ function l1(config) { return l1$1(config); } /** * Regularizer for L2 regularization. * * Adds a term to the loss to penalize large weights: * loss += sum(l2 * x^2) * @param config l2 config. * * @doc {heading: 'Regularizers', namespace: 'regularizers'} */ function l2(config) { return l2$1(config); } var exports_regularizers = { __proto__: null, l1: l1, l1l2: l1l2, l2: l2 }; var Callback = /** @class */ (function (_super) { __extends(Callback, _super); function Callback() { var _this = _super.apply(this, __spreadArray([], __read(arguments), false)) || this; /** Reference to the `LayersModel` being trained. */ _this.model = null; return _this; } Callback.prototype.setModel = function (model) { if (!(model instanceof LayersModel)) { throw new Error('model must be a LayersModel, not some other Container'); } this.model = model; }; return Callback; }(BaseCallback)); function less(currVal, prevVal) { return currVal < prevVal; } function greater(currVal, prevVal) { return currVal > prevVal; } /** * A Callback that stops training when a monitored quantity has stopped * improving. */ var EarlyStopping = /** @class */ (function (_super) { __extends(EarlyStopping, _super); function EarlyStopping(args) { var _this = _super.call(this) || this; if (args == null) { args = {}; } if (args.restoreBestWeights) { throw new NotImplementedError('restoreBestWeights = true is not implemented in EarlyStopping yet.'); } _this.monitor = args.monitor || 'val_loss'; _this.minDelta = Math.abs(args.minDelta || 0); _this.patience = args.patience || 0; _this.verbose = args.verbose || 0; _this.mode = args.mode || 'auto'; _this.baseline = args.baseline; if (['auto', 'min', 'max'].indexOf(_this.mode) === -1) { console.warn("EarlyStopping mode '".concat(_this.mode, "' is invalid. ") + "Falling back to mode 'auto'."); _this.mode = 'auto'; } if (_this.mode === 'min') { _this.monitorFunc = less; } else if (_this.mode === 'max') { _this.monitorFunc = greater; } else { // For mode === 'auto'. if (_this.monitor.indexOf('acc') !== -1) { _this.monitorFunc = greater; } else { _this.monitorFunc = less; } } if (_this.monitorFunc === less) { _this.minDelta *= -1; } return _this; } EarlyStopping.prototype.onTrainBegin = function (logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { this.wait = 0; this.stoppedEpoch = 0; if (this.baseline != null) { this.best = this.baseline; } else { this.best = this.monitorFunc === less ?
Infinity : -Infinity; } return [2 /*return*/]; }); }); }; EarlyStopping.prototype.onEpochEnd = function (epoch, logs) { return __awaiter(this, void 0, void 0, function () { var current; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, resolveScalarsInLogs(logs)]; case 1: _a.sent(); current = this.getMonitorValue(logs); if (current == null) { return [2 /*return*/]; } if (this.monitorFunc(current - this.minDelta, this.best)) { this.best = current; this.wait = 0; // TODO(cais): Logic for restoreBestWeights. } else { this.wait++; if (this.wait >= this.patience) { this.stoppedEpoch = epoch; this.model.stopTraining = true; } // TODO(cais): Logic for restoreBestWeights. } return [2 /*return*/]; } }); }); }; EarlyStopping.prototype.onTrainEnd = function (logs) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { if (this.stoppedEpoch > 0 && this.verbose) { console.log("Epoch ".concat(this.stoppedEpoch, ": early stopping.")); } return [2 /*return*/]; }); }); }; EarlyStopping.prototype.getMonitorValue = function (logs) { if (logs == null) { logs = {}; } var monitorValue = logs[this.monitor]; if (monitorValue == null) { console.warn("Metric for EarlyStopping ".concat(this.monitor, " is not available. ") + "Available metrics are: ".concat(Object.keys(logs))); } return monitorValue; }; return EarlyStopping; }(Callback)); /** * Factory function for a Callback that stops training when a monitored * quantity has stopped improving. * * Early stopping is a type of regularization, and protects the model against * overfitting. * * The following example, based on synthetic data, illustrates how this callback * can be used during `tf.LayersModel.fit()`: * * ```js * const model = tf.sequential(); * model.add(tf.layers.dense({ * units: 3, * activation: 'softmax', * kernelInitializer: 'ones', * inputShape: [2] * })); * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]); * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]); * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]); * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]); * model.compile( * {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']}); * * // Without the EarlyStopping callback, the val_acc value would be: * // 0.5, 0.5, 0.5, 0.5, ... * // With val_acc being monitored, training should stop after the 2nd epoch. * const history = await model.fit(xs, ys, { * epochs: 10, * validationData: [xsVal, ysVal], * callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'}) * }); * * // Expect to see a length-2 array.
* console.log(history.history.val_acc); * ``` * * @doc { * heading: 'Callbacks', * namespace: 'callbacks' * } */ function earlyStopping(args) { return new EarlyStopping(args); } var callbacks = { earlyStopping: earlyStopping }; exports.Callback = Callback; exports.CallbackList = CallbackList; exports.CustomCallback = CustomCallback; exports.EarlyStopping = EarlyStopping; exports.History = History; exports.InputSpec = InputSpec; exports.LayerVariable = LayerVariable; exports.LayersModel = LayersModel; exports.RNN = RNN; exports.Sequential = Sequential; exports.SymbolicTensor = SymbolicTensor; exports.callbacks = callbacks; exports.constraints = exports_constraints; exports.initializers = exports_initializers; exports.input = input; exports.layers = exports_layers; exports.loadLayersModel = loadLayersModel; exports.metrics = exports_metrics; exports.model = model; exports.models = exports_models; exports.registerCallbackConstructor = registerCallbackConstructor; exports.regularizers = exports_regularizers; exports.sequential = sequential; exports.version_layers = version; //# sourceMappingURL=tf-layers.node.js.map