/** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ import { assert } from './util'; /** * Maps to mapping between the custom object and its name. * * After registering a custom class, these two maps will add key-value pairs * for the class object and the registered name. * * Therefore we can get the relative registered name by calling * getRegisteredName() function. * * For example: * GLOBAL_CUSTOM_OBJECT: {key=registeredName: value=corresponding * CustomObjectClass} * * GLOBAL_CUSTOM_NAMES: {key=CustomObjectClass: value=corresponding * registeredName} * */ const GLOBAL_CUSTOM_OBJECT = new Map(); const GLOBAL_CUSTOM_NAMES = new Map(); /** * Serializable defines the serialization contract. * * TFJS requires serializable classes to return their className when asked * to avoid issues with minification. */ export class Serializable { /** * Return the class name for this class to use in serialization contexts. * * Generally speaking this will be the same thing that constructor.name * would have returned. However, the class name needs to be robust * against minification for serialization/deserialization to work properly. * * There's also places such as initializers.VarianceScaling, where * implementation details between different languages led to different * class hierarchies and a non-leaf node is used for serialization purposes. */ getClassName() { return this.constructor .className; } /** * Creates an instance of T from a ConfigDict. * * This works for most descendants of serializable. A few need to * provide special handling. * @param cls A Constructor for the class to instantiate. * @param config The Configuration for the object. */ /** @nocollapse */ static fromConfig(cls, config) { return new cls(config); } } /** * Maps string keys to class constructors. * * Used during (de)serialization from the cross-language JSON format, which * requires the class name in the serialization format matches the class * names as used in Python, should it exist. */ export class SerializationMap { constructor() { this.classNameMap = {}; } /** * Returns the singleton instance of the map. */ static getMap() { if (SerializationMap.instance == null) { SerializationMap.instance = new SerializationMap(); } return SerializationMap.instance; } /** * Registers the class as serializable. */ static register(cls) { SerializationMap.getMap().classNameMap[cls.className] = [cls, cls.fromConfig]; } } /** * Register a class with the serialization map of TensorFlow.js. * * This is often used for registering custom Layers, so they can be * serialized and deserialized. * * Example 1. Register the class without package name and specified name. * * ```js * class MyCustomLayer extends tf.layers.Layer { * static className = 'MyCustomLayer'; * * constructor(config) { * super(config); * } * } * tf.serialization.registerClass(MyCustomLayer); * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get("Custom>MyCustomLayer")); * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer)); * ``` * * Example 2. Register the class with package name: "Package" and specified * name: "MyLayer". * ```js * class MyCustomLayer extends tf.layers.Layer { * static className = 'MyCustomLayer'; * * constructor(config) { * super(config); * } * } * tf.serialization.registerClass(MyCustomLayer, "Package", "MyLayer"); * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get("Package>MyLayer")); * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer)); * ``` * * Example 3. Register the class with specified name: "MyLayer". * ```js * class MyCustomLayer extends tf.layers.Layer { * static className = 'MyCustomLayer'; * * constructor(config) { * super(config); * } * } * tf.serialization.registerClass(MyCustomLayer, undefined, "MyLayer"); * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get("Custom>MyLayer")); * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer)); * ``` * * Example 4. Register the class with specified package name: "Package". * ```js * class MyCustomLayer extends tf.layers.Layer { * static className = 'MyCustomLayer'; * * constructor(config) { * super(config); * } * } * tf.serialization.registerClass(MyCustomLayer, "Package"); * console.log(tf.serialization.GLOBALCUSTOMOBJECT * .get("Package>MyCustomLayer")); * console.log(tf.serialization.GLOBALCUSTOMNAMES * .get(MyCustomLayer)); * ``` * * @param cls The class to be registered. It must have a public static member * called `className` defined and the value must be a non-empty string. * @param pkg The pakcage name that this class belongs to. This used to define * the key in GlobalCustomObject. If not defined, it defaults to `Custom`. * @param name The name that user specified. It defaults to the actual name of * the class as specified by its static `className` property. * @doc {heading: 'Models', subheading: 'Serialization', ignoreCI: true} */ export function registerClass(cls, pkg, name) { assert(cls.className != null, () => `Class being registered does not have the static className ` + `property defined.`); assert(typeof cls.className === 'string', () => `className is required to be a string, but got type ` + typeof cls.className); assert(cls.className.length > 0, () => `Class being registered has an empty-string as its className, ` + `which is disallowed.`); if (typeof pkg === 'undefined') { pkg = 'Custom'; } if (typeof name === 'undefined') { name = cls.className; } const className = name; const registerName = pkg + '>' + className; SerializationMap.register(cls); GLOBAL_CUSTOM_OBJECT.set(registerName, cls); GLOBAL_CUSTOM_NAMES.set(cls, registerName); return cls; } /** * Get the registered name of a class. If the class has not been registered, * return the class name. * * @param cls The class we want to get register name for. It must have a public * static member called `className` defined. * @returns registered name or class name. */ export function getRegisteredName(cls) { if (GLOBAL_CUSTOM_NAMES.has(cls)) { return GLOBAL_CUSTOM_NAMES.get(cls); } else { return cls.className; } } //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"serialization.js","sourceRoot":"","sources":["../../../../../tfjs-core/src/serialization.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,MAAM,EAAC,MAAM,QAAQ,CAAC;AAuC9B;;;;;;;;;;;;;;;;GAgBG;AACH,MAAM,oBAAoB,GACtB,IAAI,GAAG,EAAiD,CAAC;AAE7D,MAAM,mBAAmB,GACrB,IAAI,GAAG,EAAiD,CAAC;AAE7D;;;;;GAKG;AACH,MAAM,OAAgB,YAAY;IAChC;;;;;;;;;;OAUG;IACH,YAAY;QACV,OAAQ,IAAI,CAAC,WAAqD;aAC7D,SAAS,CAAC;IACjB,CAAC;IAOD;;;;;;;OAOG;IACH,kBAAkB;IAClB,MAAM,CAAC,UAAU,CACb,GAA+B,EAAE,MAAkB;QACrD,OAAO,IAAI,GAAG,CAAC,MAAM,CAAC,CAAC;IACzB,CAAC;CACF;AAED;;;;;;GAMG;AACH,MAAM,OAAO,gBAAgB;IAO3B;QACE,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC;IACzB,CAAC;IAED;;OAEG;IACH,MAAM,CAAC,MAAM;QACX,IAAI,gBAAgB,CAAC,QAAQ,IAAI,IAAI,EAAE;YACrC,gBAAgB,CAAC,QAAQ,GAAG,IAAI,gBAAgB,EAAE,CAAC;SACpD;QACD,OAAO,gBAAgB,CAAC,QAAQ,CAAC;IACnC,CAAC;IAED;;OAEG;IACH,MAAM,CAAC,QAAQ,CAAyB,GAA+B;QACrE,gBAAgB,CAAC,MAAM,EAAE,CAAC,YAAY,CAAC,GAAG,CAAC,SAAS,CAAC;YACjD,CAAC,GAAG,EAAE,GAAG,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;CACF;AAED;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAyEG;AACH,MAAM,UAAU,aAAa,CACzB,GAA+B,EAAE,GAAY,EAAE,IAAa;IAC9D,MAAM,CACF,GAAG,CAAC,SAAS,IAAI,IAAI,EACrB,GAAG,EAAE,CAAC,4DAA4D;QAC9D,mBAAmB,CAAC,CAAC;IAC7B,MAAM,CACF,OAAO,GAAG,CAAC,SAAS,KAAK,QAAQ,EACjC,GAAG,EAAE,CAAC,qDAAqD;QACvD,OAAO,GAAG,CAAC,SAAS,CAAC,CAAC;IAC9B,MAAM,CACF,GAAG,CAAC,SAAS,CAAC,MAAM,GAAG,CAAC,EACxB,GAAG,EAAE,CAAC,+DAA+D;QACjE,sBAAsB,CAAC,CAAC;IAEhC,IAAI,OAAO,GAAG,KAAK,WAAW,EAAE;QAC9B,GAAG,GAAG,QAAQ,CAAC;KAChB;IAED,IAAI,OAAO,IAAI,KAAK,WAAW,EAAE;QAC/B,IAAI,GAAG,GAAG,CAAC,SAAS,CAAC;KACtB;IAED,MAAM,SAAS,GAAG,IAAI,CAAC;IACvB,MAAM,YAAY,GAAG,GAAG,GAAG,GAAG,GAAG,SAAS,CAAC;IAE3C,gBAAgB,CAAC,QAAQ,CAAC,GAAG,CAAC,CAAC;IAC/B,oBAAoB,CAAC,GAAG,CAAC,YAAY,EAAE,GAAG,CAAC,CAAC;IAC5C,mBAAmB,CAAC,GAAG,CAAC,GAAG,EAAE,YAAY,CAAC,CAAC;IAE3C,OAAO,GAAG,CAAC;AACb,CAAC;AAED;;;;;;;GAOG;AACH,MAAM,UAAU,iBAAiB,CAC7B,GAA+B;IACjC,IAAI,mBAAmB,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE;QAChC,OAAO,mBAAmB,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;KACrC;SAAM;QACL,OAAO,GAAG,CAAC,SAAS,CAAC;KACtB;AACH,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {assert} from './util';\n\n/**\n * Types to support JSON-esque data structures internally.\n *\n * Internally ConfigDict's use camelCase keys and values where the\n * values are class names to be instantiated.  On the python side, these\n * will be snake_case.  Internally we allow Enums into the values for better\n * type safety, but these need to be converted to raw primitives (usually\n * strings) for round-tripping with python.\n *\n * toConfig returns the TS-friendly representation. model.toJSON() returns\n * the pythonic version as that's the portable format.  If you need to\n * python-ify a non-model level toConfig output, you'll need to use a\n * convertTsToPythonic from serialization_utils in -Layers.\n *\n */\nexport declare type ConfigDictValue =\n    boolean | number | string | null | ConfigDictArray | ConfigDict;\nexport declare interface ConfigDict {\n  [key: string]: ConfigDictValue;\n}\nexport declare interface ConfigDictArray extends Array<ConfigDictValue> {}\n\n/**\n * Type to represent the class-type of Serializable objects.\n *\n * Ie the class prototype with access to the constructor and any\n * static members/methods. Instance methods are not listed here.\n *\n * Source for this idea: https://stackoverflow.com/a/43607255\n */\nexport declare type SerializableConstructor<T extends Serializable> = {\n  // tslint:disable-next-line:no-any\n  new (...args: any[]): T; className: string; fromConfig: FromConfigMethod<T>;\n};\nexport declare type FromConfigMethod<T extends Serializable> =\n    (cls: SerializableConstructor<T>, config: ConfigDict) => T;\n\n/**\n * Maps to mapping between the custom object and its name.\n *\n * After registering a custom class, these two maps will add key-value pairs\n * for the class object and the registered name.\n *\n * Therefore we can get the relative registered name by calling\n * getRegisteredName() function.\n *\n * For example:\n * GLOBAL_CUSTOM_OBJECT: {key=registeredName: value=corresponding\n * CustomObjectClass}\n *\n * GLOBAL_CUSTOM_NAMES: {key=CustomObjectClass: value=corresponding\n * registeredName}\n *\n */\nconst GLOBAL_CUSTOM_OBJECT =\n    new Map<string, SerializableConstructor<Serializable>>();\n\nconst GLOBAL_CUSTOM_NAMES =\n    new Map<SerializableConstructor<Serializable>, string>();\n\n/**\n * Serializable defines the serialization contract.\n *\n * TFJS requires serializable classes to return their className when asked\n * to avoid issues with minification.\n */\nexport abstract class Serializable {\n  /**\n   * Return the class name for this class to use in serialization contexts.\n   *\n   * Generally speaking this will be the same thing that constructor.name\n   * would have returned.  However, the class name needs to be robust\n   * against minification for serialization/deserialization to work properly.\n   *\n   * There's also places such as initializers.VarianceScaling, where\n   * implementation details between different languages led to different\n   * class hierarchies and a non-leaf node is used for serialization purposes.\n   */\n  getClassName(): string {\n    return (this.constructor as SerializableConstructor<Serializable>)\n        .className;\n  }\n\n  /**\n   * Return all the non-weight state needed to serialize this object.\n   */\n  abstract getConfig(): ConfigDict;\n\n  /**\n   * Creates an instance of T from a ConfigDict.\n   *\n   * This works for most descendants of serializable.  A few need to\n   * provide special handling.\n   * @param cls A Constructor for the class to instantiate.\n   * @param config The Configuration for the object.\n   */\n  /** @nocollapse */\n  static fromConfig<T extends Serializable>(\n      cls: SerializableConstructor<T>, config: ConfigDict): T {\n    return new cls(config);\n  }\n}\n\n/**\n * Maps string keys to class constructors.\n *\n * Used during (de)serialization from the cross-language JSON format, which\n * requires the class name in the serialization format matches the class\n * names as used in Python, should it exist.\n */\nexport class SerializationMap {\n  private static instance: SerializationMap;\n  classNameMap: {\n    [className: string]:\n        [SerializableConstructor<Serializable>, FromConfigMethod<Serializable>]\n  };\n\n  private constructor() {\n    this.classNameMap = {};\n  }\n\n  /**\n   * Returns the singleton instance of the map.\n   */\n  static getMap(): SerializationMap {\n    if (SerializationMap.instance == null) {\n      SerializationMap.instance = new SerializationMap();\n    }\n    return SerializationMap.instance;\n  }\n\n  /**\n   * Registers the class as serializable.\n   */\n  static register<T extends Serializable>(cls: SerializableConstructor<T>) {\n    SerializationMap.getMap().classNameMap[cls.className] =\n        [cls, cls.fromConfig];\n  }\n}\n\n/**\n * Register a class with the serialization map of TensorFlow.js.\n *\n * This is often used for registering custom Layers, so they can be\n * serialized and deserialized.\n *\n * Example 1. Register the class without package name and specified name.\n *\n * ```js\n * class MyCustomLayer extends tf.layers.Layer {\n *   static className = 'MyCustomLayer';\n *\n *   constructor(config) {\n *     super(config);\n *   }\n * }\n * tf.serialization.registerClass(MyCustomLayer);\n * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get(\"Custom>MyCustomLayer\"));\n * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer));\n * ```\n *\n * Example 2. Register the class with package name: \"Package\" and specified\n * name: \"MyLayer\".\n * ```js\n * class MyCustomLayer extends tf.layers.Layer {\n *   static className = 'MyCustomLayer';\n *\n *   constructor(config) {\n *     super(config);\n *   }\n * }\n * tf.serialization.registerClass(MyCustomLayer, \"Package\", \"MyLayer\");\n * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get(\"Package>MyLayer\"));\n * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer));\n * ```\n *\n * Example 3. Register the class with specified name: \"MyLayer\".\n * ```js\n * class MyCustomLayer extends tf.layers.Layer {\n *   static className = 'MyCustomLayer';\n *\n *   constructor(config) {\n *     super(config);\n *   }\n * }\n * tf.serialization.registerClass(MyCustomLayer, undefined, \"MyLayer\");\n * console.log(tf.serialization.GLOBALCUSTOMOBJECT.get(\"Custom>MyLayer\"));\n * console.log(tf.serialization.GLOBALCUSTOMNAMES.get(MyCustomLayer));\n * ```\n *\n * Example 4. Register the class with specified package name: \"Package\".\n * ```js\n * class MyCustomLayer extends tf.layers.Layer {\n *   static className = 'MyCustomLayer';\n *\n *   constructor(config) {\n *     super(config);\n *   }\n * }\n * tf.serialization.registerClass(MyCustomLayer, \"Package\");\n * console.log(tf.serialization.GLOBALCUSTOMOBJECT\n * .get(\"Package>MyCustomLayer\"));\n * console.log(tf.serialization.GLOBALCUSTOMNAMES\n * .get(MyCustomLayer));\n * ```\n *\n * @param cls The class to be registered. It must have a public static member\n *   called `className` defined and the value must be a non-empty string.\n * @param pkg The pakcage name that this class belongs to. This used to define\n *     the key in GlobalCustomObject. If not defined, it defaults to `Custom`.\n * @param name The name that user specified. It defaults to the actual name of\n *     the class as specified by its static `className` property.\n * @doc {heading: 'Models', subheading: 'Serialization', ignoreCI: true}\n */\nexport function registerClass<T extends Serializable>(\n    cls: SerializableConstructor<T>, pkg?: string, name?: string) {\n  assert(\n      cls.className != null,\n      () => `Class being registered does not have the static className ` +\n          `property defined.`);\n  assert(\n      typeof cls.className === 'string',\n      () => `className is required to be a string, but got type ` +\n          typeof cls.className);\n  assert(\n      cls.className.length > 0,\n      () => `Class being registered has an empty-string as its className, ` +\n          `which is disallowed.`);\n\n  if (typeof pkg === 'undefined') {\n    pkg = 'Custom';\n  }\n\n  if (typeof name === 'undefined') {\n    name = cls.className;\n  }\n\n  const className = name;\n  const registerName = pkg + '>' + className;\n\n  SerializationMap.register(cls);\n  GLOBAL_CUSTOM_OBJECT.set(registerName, cls);\n  GLOBAL_CUSTOM_NAMES.set(cls, registerName);\n\n  return cls;\n}\n\n/**\n * Get the registered name of a class. If the class has not been registered,\n * return the class name.\n *\n * @param cls The class we want to get register name for. It must have a public\n *     static member called `className` defined.\n * @returns registered name or class name.\n */\nexport function getRegisteredName<T extends Serializable>(\n    cls: SerializableConstructor<T>) {\n  if (GLOBAL_CUSTOM_NAMES.has(cls)) {\n    return GLOBAL_CUSTOM_NAMES.get(cls);\n  } else {\n    return cls.className;\n  }\n}\n"]}