/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { dispose } from '../globals';
import { variableGrads } from '../gradients';
import { scalar } from '../ops/ops';
import { Serializable } from '../serialization';

/**
 * Base class for optimizers.
 *
 * This class provides the shared `minimize()` / `computeGradients()` plumbing
 * and the bookkeeping of the iteration counter. It calls
 * `this.applyGradients(...)` but does not define it: concrete subclasses must
 * supply that method (it is declared abstract in the TypeScript source).
 *
 * @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'}
 */
export class Optimizer extends Serializable {
  /**
   * Executes `f()` and minimizes the scalar output of `f()` by computing
   * gradients of y with respect to the list of trainable variables provided by
   * `varList`. If no list is provided, it defaults to all trainable variables.
   *
   * @param f The function to execute and whose output to minimize.
   * @param returnCost Whether to return the scalar cost value produced by
   * executing `f()`.
   * @param varList An optional list of variables to update. If specified, only
   * the trainable variables in varList will be updated by minimize. Defaults to
   * all trainable variables.
   * @returns The cost value when `returnCost` is true; otherwise `null` (the
   * cost value is disposed before returning).
   *
   * @doc {heading: 'Training', subheading: 'Optimizers'}
   */
  minimize(f, returnCost = false, varList) {
    const { value, grads } = this.computeGradients(f, varList);

    if (varList != null) {
      // Hand the gradients over as an array ordered like `varList`.
      const gradArray =
          varList.map(v => ({ name: v.name, tensor: grads[v.name] }));
      this.applyGradients(gradArray);
    } else {
      // No explicit list: pass the name -> gradient map through unchanged.
      this.applyGradients(grads);
    }

    // Dispose gradients.
    dispose(grads);

    if (returnCost) {
      return value;
    } else {
      value.dispose();
      return null;
    }
  }

  /**
   * The number of iterations that this optimizer instance has been invoked for.
   *
   * Reading this property lazily initializes the counter to 0 if it has not
   * been set yet.
   */
  get iterations() {
    if (this.iterations_ == null) {
      this.iterations_ = 0;
    }
    return this.iterations_;
  }

  /**
   * Advances the iteration counter by one (initializing it to 0 first if
   * necessary, via the `iterations` getter).
   */
  incrementIterations() {
    this.iterations_ = this.iterations + 1;
  }

  /**
   * Executes f() and computes the gradient of the scalar output of f() with
   * respect to the list of trainable variables provided by `varList`. If no
   * list is provided, it defaults to all trainable variables.
   *
   * @param f The function to execute and whose output to use for computing
   * gradients with respect to variables.
   * @param varList An optional list of variables to compute gradients with
   * respect to. If specified, only the trainable variables in varList will have
   * gradients computed with respect to. Defaults to all trainable variables.
   * @returns Whatever `variableGrads(f, varList)` returns; `minimize()` reads
   * its `value` and `grads` fields.
   *
   * @doc {heading: 'Training', subheading: 'Optimizers'}
   */
  computeGradients(f, varList) {
    return variableGrads(f, varList);
  }

  /**
   * Dispose the variables (if any) owned by this optimizer instance.
   */
  dispose() {
    // NOTE(review): in this base class `iterations_` only ever holds a plain
    // number; presumably the global `dispose` ignores non-tensor values, which
    // would make this call a no-op here - confirm.
    if (this.iterations_ != null) {
      dispose(this.iterations_);
    }
  }

  /**
   * Packages the iteration counter as a named int32 scalar tensor, initializing
   * the counter to 0 first if it has not been set yet.
   *
   * @returns A promise of `{name: 'iter', tensor}`.
   */
  async saveIterations() {
    if (this.iterations_ == null) {
      this.iterations_ = 0;
    }
    return {
      name: 'iter',  // Named for Python compatibility.
      // TODO(cais): Use 'int64' type when available.
      tensor: scalar(this.iterations_, 'int32')
    };
  }

  /**
   * Placeholder that always rejects; this base class does not implement it.
   *
   * @throws Error always.
   */
  async getWeights() {
    throw new Error('getWeights() is not implemented for this optimizer yet.');
  }

  /**
   * Placeholder that always rejects; this base class does not implement it.
   *
   * @param weightValues Unused by this base implementation.
   * @throws Error always; the message includes `this.getClassName()`.
   */
  async setWeights(weightValues) {
    throw new Error(
        `setWeights() is not implemented for this optimizer class ` +
        `${this.getClassName()}`);
  }

  /**
   * Extract the first element of the weight values and set it
   * as the iterations counter variable of this instance of optimizer.
   *
   * @param weightValues Named tensors; the first entry's tensor data supplies
   * the counter. Assumes the array is non-empty - an empty array makes the
   * `weightValues[0].tensor` access throw.
   * @returns Weight values with the first element consumed and excluded.
   */
  async extractIterations(weightValues) {
    this.iterations_ = (await weightValues[0].tensor.data())[0];
    return weightValues.slice(1);
  }
}

// Duck-typed `instanceof`: any non-nullish value exposing `minimize`,
// `computeGradients` and `applyGradients` is treated as an Optimizer.
Object.defineProperty(Optimizer, Symbol.hasInstance, {
  value: (instance) => {
    // Guard against null/undefined so that `null instanceof Optimizer`
    // evaluates to false instead of throwing a TypeError on property access.
    return instance != null && instance.minimize != null &&
        instance.computeGradients != null && instance.applyGradients != null;
  }
});
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"optimizer.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/optimizers/optimizer.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,EAAC,OAAO,EAAC,MAAM,YAAY,CAAC;AACnC,OAAO,EAAC,aAAa,EAAC,MAAM,cAAc,CAAC;AAC3C,OAAO,EAAC,MAAM,EAAC,MAAM,YAAY,CAAC;AAClC,OAAO,EAAC,YAAY,EAAC,MAAM,kBAAkB,CAAC;AAoB9C,4EAA4E;AAC5E,MAAM,OAAgB,SAAU,SAAQ,YAAY;IAGlD;;;;;;;;;;;;;OAaG;IACH,QAAQ,CAAC,CAAe,EAAE,UAAU,GAAG,KAAK,EAAE,OAAoB;QAEhE,MAAM,EAAC,KAAK,EAAE,KAAK,EAAC,GAAG,IAAI,CAAC,gBAAgB,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAEzD,IAAI,OAAO,IAAI,IAAI,EAAE;YACnB,MAAM,SAAS,GACX,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAC,IAAI,EAAE,CAAC,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,EAAC,CAAC,CAAC,CAAC;YAC9D,IAAI,CAAC,cAAc,CAAC,SAAS,CAAC,CAAC;SAChC;aAAM;YACL,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC;SAC5B;QAED,qBAAqB;QACrB,OAAO,CAAC,KAAK,CAAC,CAAC;QAEf,IAAI,UAAU,EAAE;YACd,OAAO,KAAK,CAAC;SACd;aAAM;YACL,KAAK,CAAC,OAAO,EAAE,CAAC;YAChB,OAAO,IAAI,CAAC;SACb;IACH,CAAC;IAED;;OAEG;IACH,IAAI,UAAU;QACZ,IAAI,IAAI,CAAC,WAAW,IAAI,IAAI,EAAE;YAC5B,IAAI,CAAC,WAAW,GAAG,CAAC,CAAC;SACtB;QACD,OAAO,IAAI,CAAC,WAAW,CAAC;IAC1B,CAAC;IAES,mBAAmB;QAC3B,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,UAAU,GAAG,CAAC,CAAC;IACzC,CAAC;IAED;;;;;;;;;;;;OAYG;IACH,gBAAgB,CAAC,CAAe,EAAE,OAAoB;QAEpD,OAAO,aAAa,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;IACnC,CAAC;IAYD;;OAEG;IACH,OAAO;QACL,IAAI,IAAI,CAAC,WAAW,IAAI,IAAI,EAAE;YAC5B,OAAO,CAAC,IAAI,CAAC,WAAW,CAAC,CAAC;SAC3B;IACH,CAAC;IAED,KAAK,CAAC,cAAc;QAClB,IAAI,IAAI,CAAC,WAAW,IAAI,IAAI,EAAE;YAC5B,IAAI,CAAC,WAAW,GAAG,CAAC,CAAC;SACtB;QACD,OAAO;YACL,IAAI,EAAE,MAAM;YACZ,+CAA+C;YAC/C,MAAM,EAAE,MAAM,CAAC,IAAI
,CAAC,WAAW,EAAE,OAAO,CAAC;SAC1C,CAAC;IACJ,CAAC;IAED,KAAK,CAAC,UAAU;QACd,MAAM,IAAI,KAAK,CAAC,yDAAyD,CAAC,CAAC;IAC7E,CAAC;IAED,KAAK,CAAC,UAAU,CAAC,YAA2B;QAC1C,MAAM,IAAI,KAAK,CACX,2DAA2D;YAC3D,GAAG,IAAI,CAAC,YAAY,EAAE,EAAE,CAAC,CAAC;IAChC,CAAC;IAED;;;;;;OAMG;IACO,KAAK,CAAC,iBAAiB,CAAC,YAA2B;QAE3D,IAAI,CAAC,WAAW,GAAG,CAAC,MAAM,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QAC5D,OAAO,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;IAC/B,CAAC;CACF;AAED,MAAM,CAAC,cAAc,CAAC,SAAS,EAAE,MAAM,CAAC,WAAW,EAAE;IACnD,KAAK,EAAE,CAAC,QAAmB,EAAE,EAAE;QAC7B,OAAO,QAAQ,CAAC,QAAQ,IAAI,IAAI,IAAI,QAAQ,CAAC,gBAAgB,IAAI,IAAI;YACjE,QAAQ,CAAC,cAAc,IAAI,IAAI,CAAC;IACtC,CAAC;CACF,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport {dispose} from '../globals';\nimport {variableGrads} from '../gradients';\nimport {scalar} from '../ops/ops';\nimport {Serializable} from '../serialization';\nimport {Scalar, Variable} from '../tensor';\nimport {NamedTensor, NamedTensorMap} from '../tensor_types';\n\n/**\n * A variable that belongs to an optimizer.\n *\n * The `originalName` field is required for keeping track of the canonical\n * name of the variable, which is usually the name of the model weight that\n * the variable is related to plus a suffix, e.g., 'dense1/kernel/momentum'.\n * The name of 
the `Variable` object itself cannot be used directly due to\n * possible deduplication: Every `Variable` must have a unique name but more\n * than one optimizer objects of the same type may be created for the same model\n * or the same `Variable`.\n */\nexport interface OptimizerVariable {\n  originalName: string;\n  variable: Variable;\n}\n\n/** @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'} */\nexport abstract class Optimizer extends Serializable {\n  protected iterations_: number;\n\n  /**\n   * Executes `f()` and minimizes the scalar output of `f()` by computing\n   * gradients of y with respect to the list of trainable variables provided by\n   * `varList`. If no list is provided, it defaults to all trainable variables.\n   *\n   * @param f The function to execute and whose output to minimize.\n   * @param returnCost Whether to return the scalar cost value produced by\n   * executing `f()`.\n   * @param varList An optional list of variables to update. If specified, only\n   * the trainable variables in varList will be updated by minimize. 
Defaults to\n   * all trainable variables.\n   *\n   * @doc {heading: 'Training', subheading: 'Optimizers'}\n   */\n  minimize(f: () => Scalar, returnCost = false, varList?: Variable[]): Scalar\n      |null {\n    const {value, grads} = this.computeGradients(f, varList);\n\n    if (varList != null) {\n      const gradArray: NamedTensor[] =\n          varList.map(v => ({name: v.name, tensor: grads[v.name]}));\n      this.applyGradients(gradArray);\n    } else {\n      this.applyGradients(grads);\n    }\n\n    // Dispose gradients.\n    dispose(grads);\n\n    if (returnCost) {\n      return value;\n    } else {\n      value.dispose();\n      return null;\n    }\n  }\n\n  /**\n   * The number of iterations that this optimizer instance has been invoked for.\n   */\n  get iterations(): number {\n    if (this.iterations_ == null) {\n      this.iterations_ = 0;\n    }\n    return this.iterations_;\n  }\n\n  protected incrementIterations() {\n    this.iterations_ = this.iterations + 1;\n  }\n\n  /**\n   * Executes f() and computes the gradient of the scalar output of f() with\n   * respect to the list of trainable variables provided by `varList`. If no\n   * list is provided, it defaults to all trainable variables.\n   *\n   * @param f The function to execute and whose output to use for computing\n   * gradients with respect to variables.\n   * @param varList An optional list of variables to compute gradients with\n   * respect to. If specified, only the trainable variables in varList will have\n   * gradients computed with respect to. 
Defaults to all trainable variables.\n   *\n   * @doc {heading: 'Training', subheading: 'Optimizers'}\n   */\n  computeGradients(f: () => Scalar, varList?: Variable[]):\n      {value: Scalar, grads: NamedTensorMap} {\n    return variableGrads(f, varList);\n  }\n\n  /**\n   * Updates variables by using the computed gradients.\n   *\n   * @param variableGradients A mapping of variable name to its gradient value.\n   *\n   * @doc {heading: 'Training', subheading: 'Optimizers'}\n   */\n  abstract applyGradients(variableGradients: NamedTensorMap|\n                          NamedTensor[]): void;\n\n  /**\n   * Dispose the variables (if any) owned by this optimizer instance.\n   */\n  dispose(): void {\n    if (this.iterations_ != null) {\n      dispose(this.iterations_);\n    }\n  }\n\n  async saveIterations(): Promise<NamedTensor> {\n    if (this.iterations_ == null) {\n      this.iterations_ = 0;\n    }\n    return {\n      name: 'iter',  // Named for Python compatibility.\n      // TODO(cais): Use 'int64' type when available.\n      tensor: scalar(this.iterations_, 'int32')\n    };\n  }\n\n  async getWeights(): Promise<NamedTensor[]> {\n    throw new Error('getWeights() is not implemented for this optimizer yet.');\n  }\n\n  async setWeights(weightValues: NamedTensor[]): Promise<void> {\n    throw new Error(\n        `setWeights() is not implemented for this optimizer class ` +\n        `${this.getClassName()}`);\n  }\n\n  /**\n   * Extract the first element of the weight values and set it\n   * as the iterations counter variable of this instance of optimizer.\n   *\n   * @param weightValues\n   * @returns Weight values with the first element consumed and excluded.\n   */\n  protected async extractIterations(weightValues: NamedTensor[]):\n      Promise<NamedTensor[]> {\n    this.iterations_ = (await weightValues[0].tensor.data())[0];\n    return weightValues.slice(1);\n  }\n}\n\nObject.defineProperty(Optimizer, Symbol.hasInstance, {\n  value: (instance: Optimizer) => 
{\n    return instance.minimize != null && instance.computeGradients != null &&\n        instance.applyGradients != null;\n  }\n});\n"]}