/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { ENGINE } from '../engine';
import { dispose, tidy } from '../globals';
import { add } from '../ops/add';
import { div } from '../ops/div';
import { mul } from '../ops/mul';
import { pow } from '../ops/pow';
import { scalar } from '../ops/scalar';
import { sqrt } from '../ops/sqrt';
import { square } from '../ops/square';
import { sub } from '../ops/sub';
import { zerosLike } from '../ops/zeros_like';
import { Optimizer } from './optimizer';
/**
 * Adam optimizer.
 *
 * For every variable passed to `applyGradients` it keeps a first-moment slot
 * (`m`) and a second-moment slot (`v`). Two scalar variables hold the running
 * products `accBeta1 = beta1^t` and `accBeta2 = beta2^t`, which are used for
 * bias correction. Each call to `applyGradients` performs:
 *
 *   m      <- beta1 * m + (1 - beta1) * g
 *   v      <- beta2 * v + (1 - beta2) * g^2
 *   m_hat   = m / (1 - accBeta1)
 *   v_hat   = v / (1 - accBeta2)
 *   value  <- value - learningRate * m_hat / (sqrt(v_hat) + epsilon)
 *
 * Moment slots are keyed by the *position* of a variable in the gradient
 * list/map, not by its name. Callers should therefore pass the same variables
 * in the same order on every call.
 */
export class AdamOptimizer extends Optimizer {
  /** @nocollapse */
  static get className() {
    // Name matters for Python compatibility.
    // This is a getter instead of a property because when it's a property, it
    // prevents the entire class from being tree-shaken.
    return 'Adam';
  }
  /**
   * @param {number} learningRate Step size multiplied into every update.
   * @param {number} beta1 Decay factor for the first-moment estimate.
   * @param {number} beta2 Decay factor for the second-moment estimate.
   * @param {?number} [epsilon=null] Constant added to the denominator of the
   *     update. When null/undefined, `ENGINE.backend.epsilon()` is used instead.
   */
  constructor(learningRate, beta1, beta2, epsilon = null) {
    super();
    this.learningRate = learningRate;
    this.beta1 = beta1;
    this.beta2 = beta2;
    this.epsilon = epsilon;
    this.accumulatedFirstMoment = [];
    this.accumulatedSecondMoment = [];
    tidy(() => {
      // accB* will be updated by batch.
      this.accBeta1 = scalar(beta1).variable();
      this.accBeta2 = scalar(beta2).variable();
    });
    if (epsilon == null) {
      this.epsilon = ENGINE.backend.epsilon();
    }
  }
  /**
   * Applies one Adam step to every variable named in `variableGradients`.
   * Variables whose gradient is null/undefined are left unchanged.
   *
   * @param {Object<string, Tensor>|Array<{name: string, tensor: Tensor}>}
   *     variableGradients Either a map from registered variable name to its
   *     gradient, or a list of `{name, tensor}` pairs.
   */
  applyGradients(variableGradients) {
    const varNames = Array.isArray(variableGradients) ?
      variableGradients.map(v => v.name) :
      Object.keys(variableGradients);
    tidy(() => {
      // Bias-correction denominators for this step, shared by all variables.
      const oneMinusAccBeta1 = sub(1, this.accBeta1);
      const oneMinusAccBeta2 = sub(1, this.accBeta2);
      varNames.forEach((name, i) => {
        const value = ENGINE.registeredVariables[name];
        const trainable = false;
        // Lazily create zero-filled moment slots. This happens before the
        // null-gradient check below, so every index in varNames gets a slot.
        if (this.accumulatedFirstMoment[i] == null) {
          this.accumulatedFirstMoment[i] = {
            originalName: `${name}/m`,
            variable: tidy(() => zerosLike(value).variable(trainable))
          };
        }
        if (this.accumulatedSecondMoment[i] == null) {
          this.accumulatedSecondMoment[i] = {
            originalName: `${name}/v`,
            variable: tidy(() => zerosLike(value).variable(trainable))
          };
        }
        const gradient = Array.isArray(variableGradients) ?
          variableGradients[i].tensor :
          variableGradients[name];
        if (gradient == null) {
          return;
        }
        const firstMoment = this.accumulatedFirstMoment[i].variable;
        const secondMoment = this.accumulatedSecondMoment[i].variable;
        const newFirstMoment =
          add(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
        const newSecondMoment = add(
          mul(secondMoment, this.beta2),
          mul(square(gradient), 1 - this.beta2));
        const biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1);
        const biasCorrectedSecondMoment =
          div(newSecondMoment, oneMinusAccBeta2);
        firstMoment.assign(newFirstMoment);
        secondMoment.assign(newSecondMoment);
        const newValue = add(
          mul(
            div(
              biasCorrectedFirstMoment,
              add(sqrt(biasCorrectedSecondMoment), this.epsilon)),
            -this.learningRate),
          value);
        value.assign(newValue);
      });
      // Advance the running beta products for the next call.
      this.accBeta1.assign(mul(this.accBeta1, this.beta1));
      this.accBeta2.assign(mul(this.accBeta2, this.beta2));
    });
    this.incrementIterations();
  }
  /** Disposes the beta accumulators and every moment variable. */
  dispose() {
    this.accBeta1.dispose();
    this.accBeta2.dispose();
    if (this.accumulatedFirstMoment != null) {
      dispose(this.accumulatedFirstMoment.map(v => v.variable));
    }
    if (this.accumulatedSecondMoment != null) {
      dispose(this.accumulatedSecondMoment.map(v => v.variable));
    }
  }
  /**
   * Returns the optimizer state in this order:
   * - the record produced by `saveIterations()`;
   * - all first-moment slots;
   * - all second-moment slots.
   *
   * The returned tensors are the optimizer's live moment variables, not copies.
   */
  async getWeights() {
    // Order matters for Python compatibility.
    const variables = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
    return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
  }
  /**
   * Restores optimizer state from a list in the layout produced by
   * `getWeights`.
   *
   * `weightValues` is first passed to the inherited `extractIterations`. That
   * helper is expected to restore `iterations_` and return the remaining
   * entries. Of those entries, the first half becomes the first-moment slots
   * and the second half the second-moment slots. `accBeta1`/`accBeta2` are
   * reset to `beta^(iterations_ + 1)` to match the restored iteration count.
   *
   * Moment variables held before this call are disposed after being replaced,
   * except any that were themselves passed in `weightValues`.
   */
  async setWeights(weightValues) {
    weightValues = await this.extractIterations(weightValues);
    tidy(() => {
      this.accBeta1.assign(pow(this.beta1, this.iterations_ + 1));
      this.accBeta2.assign(pow(this.beta2, this.iterations_ + 1));
    });
    const variableCount = weightValues.length / 2;
    const trainable = false;
    const previousMoments =
      [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
    this.accumulatedFirstMoment = weightValues.slice(0, variableCount).map(v => ({
      originalName: v.name,
      variable: v.tensor.variable(trainable)
    }));
    this.accumulatedSecondMoment = weightValues.slice(variableCount, variableCount * 2)
      .map(v => ({
        originalName: v.name,
        variable: v.tensor.variable(trainable)
      }));
    // Release the replaced moment variables. Previously they were silently
    // dropped and leaked. Tensors the caller handed in (e.g. the live
    // variables from a getWeights() round trip) are left alone. Variables that
    // were already disposed are skipped so they are not disposed twice.
    const incoming = new Set(weightValues.map(v => v.tensor));
    const stale = previousMoments
      .map(m => m.variable)
      .filter(variable => !variable.isDisposed && !incoming.has(variable));
    dispose(stale);
  }
  /** Returns the constructor arguments needed to rebuild this optimizer. */
  getConfig() {
    return {
      'learningRate': this.learningRate,
      'beta1': this.beta1,
      'beta2': this.beta2,
      'epsilon': this.epsilon,
    };
  }
  /**
   * Builds an instance of `cls` from a dict returned by `getConfig`.
   * @nocollapse
   */
  static fromConfig(cls, config) {
    return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
  }
}
// NOTE(review): the inline source map that follows was generated for the
// unmodified compiled output; its mappings may not line up after edits here.
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYWRhbV9vcHRpbWl6ZXIuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wdGltaXplcnMvYWRhbV9vcHRpbWl6ZXIudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsT0FBTyxFQUFFLElBQUksRUFBQyxNQUFNLFlBQVksQ0FBQztBQUN6QyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFDL0IsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUMvQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxlQUFlLENBQUM7QUFDckMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUNqQyxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sZUFBZSxDQUFDO0FBQ3JDLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFDL0IsT0FBTyxFQUFDLFNBQVMsRUFBQyxNQUFNLG1CQUFtQixDQUFDO0FBSzVDLE9BQU8sRUFBQyxTQUFTLEVBQW9CLE1BQU0sYUFBYSxDQUFDO0FBRXpELE1BQU0sT0FBTyxhQUFjLFNBQVEsU0FBUztJQUMxQyxrQkFBa0I7SUFDbEIsTUFBTSxLQUFLLFNBQVM7UUFDbEIseUNBQXlDO1FBQ3pDLDBFQUEwRTtRQUMxRSxvREFBb0Q7UUFDcEQsT0FBTyxNQUFNLENBQUM7SUFDaEIsQ0FBQztJQU9ELFlBQ2MsWUFBb0IsRUFBWSxLQUFhLEVBQzdDLEtBQWEsRUFBWSxVQUFrQixJQUFJO1FBQzNELEtBQUssRUFBRSxDQUFDO1FBRkksaUJBQVksR0FBWixZQUFZLENBQVE7UUFBWSxVQUFLLEdBQUwsS0FBSyxDQUFRO1FBQzdDLFVBQUssR0FBTCxLQUFLLENBQVE7UUFBWSxZQUFPLEdBQVAsT0FBTyxDQUFlO1FBTHJELDJCQUFzQixHQUF3QixFQUFFLENBQUM7UUFDakQsNEJBQXVCLEdBQXdCLEVBQUUsQ0FBQztRQU14RCxJQUFJLENBQUMsR0FBRyxFQUFFO1lBQ1Isa0NBQWtDO1lBQ2xDLElBQUksQ0FBQyxRQUFRLEdBQUcsTUFBTSxDQUFDLEtBQUssQ0FBQyxDQUFDLFFBQVEsRUFBRSxDQUFDO1lBQ3pDLElBQUksQ0FBQyxRQUFRLEdBQUcsTUFBTSxDQUFDLEtBQUssQ0FBQyxDQUFDLFFBQVEsRUFBRSxDQUFDO1FBQzNDLENBQUMsQ0FBQyxDQUFDO1FBRUgsSUFBSSxPQUFPLElBQUksSUFBSSxFQUFFO1lBQ25CLElBQUksQ0FBQyxPQUFPLEdBQUcsTUFBTSxDQUFDLE9BQU8sQ0FBQyxPQUFPLEVBQUUsQ0FBQztTQUN6QztJQUNILENBQUM7SUFFRCxjQUFjLENBQUMsaUJBQWlEO1FBQzlELE1BQU0sUUFBUSxHQUFHLEtBQUssQ0FBQyxPQUFPLENBQUMsaUJBQWlCLENBQUMsQ0FBQyxDQUFDO1lBQy9DLGlCQUFpQixDQUFDLEdBQUcsQ0FBQyxDQU
FDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO1lBQ3BDLE1BQU0sQ0FBQyxJQUFJLENBQUMsaUJBQWlCLENBQUMsQ0FBQztRQUNuQyxJQUFJLENBQUMsR0FBRyxFQUFFO1lBQ1IsTUFBTSxnQkFBZ0IsR0FBRyxHQUFHLENBQUMsQ0FBQyxFQUFFLElBQUksQ0FBQyxRQUFRLENBQUMsQ0FBQztZQUMvQyxNQUFNLGdCQUFnQixHQUFHLEdBQUcsQ0FBQyxDQUFDLEVBQUUsSUFBSSxDQUFDLFFBQVEsQ0FBQyxDQUFDO1lBRS9DLFFBQVEsQ0FBQyxPQUFPLENBQUMsQ0FBQyxJQUFJLEVBQUUsQ0FBQyxFQUFFLEVBQUU7Z0JBQzNCLE1BQU0sS0FBSyxHQUFHLE1BQU0sQ0FBQyxtQkFBbUIsQ0FBQyxJQUFJLENBQUMsQ0FBQztnQkFDL0MsTUFBTSxTQUFTLEdBQUcsS0FBSyxDQUFDO2dCQUN4QixJQUFJLElBQUksQ0FBQyxzQkFBc0IsQ0FBQyxDQUFDLENBQUMsSUFBSSxJQUFJLEVBQUU7b0JBQzFDLElBQUksQ0FBQyxzQkFBc0IsQ0FBQyxDQUFDLENBQUMsR0FBRzt3QkFDL0IsWUFBWSxFQUFFLEdBQUcsSUFBSSxJQUFJO3dCQUN6QixRQUFRLEVBQUUsSUFBSSxDQUFDLEdBQUcsRUFBRSxDQUFDLFNBQVMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxRQUFRLENBQUMsU0FBUyxDQUFDLENBQUM7cUJBQzNELENBQUM7aUJBQ0g7Z0JBQ0QsSUFBSSxJQUFJLENBQUMsdUJBQXVCLENBQUMsQ0FBQyxDQUFDLElBQUksSUFBSSxFQUFFO29CQUMzQyxJQUFJLENBQUMsdUJBQXVCLENBQUMsQ0FBQyxDQUFDLEdBQUc7d0JBQ2hDLFlBQVksRUFBRSxHQUFHLElBQUksSUFBSTt3QkFDekIsUUFBUSxFQUFFLElBQUksQ0FBQyxHQUFHLEVBQUUsQ0FBQyxTQUFTLENBQUMsS0FBSyxDQUFDLENBQUMsUUFBUSxDQUFDLFNBQVMsQ0FBQyxDQUFDO3FCQUMzRCxDQUFDO2lCQUNIO2dCQUVELE1BQU0sUUFBUSxHQUFHLEtBQUssQ0FBQyxPQUFPLENBQUMsaUJBQWlCLENBQUMsQ0FBQyxDQUFDO29CQUMvQyxpQkFBaUIsQ0FBQyxDQUFDLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQztvQkFDN0IsaUJBQWlCLENBQUMsSUFBSSxDQUFDLENBQUM7Z0JBQzVCLElBQUksUUFBUSxJQUFJLElBQUksRUFBRTtvQkFDcEIsT0FBTztpQkFDUjtnQkFFRCxNQUFNLFdBQVcsR0FBRyxJQUFJLENBQUMsc0JBQXNCLENBQUMsQ0FBQyxDQUFDLENBQUMsUUFBUSxDQUFDO2dCQUM1RCxNQUFNLFlBQVksR0FBRyxJQUFJLENBQUMsdUJBQXVCLENBQUMsQ0FBQyxDQUFDLENBQUMsUUFBUSxDQUFDO2dCQUU5RCxNQUFNLGNBQWMsR0FDaEIsR0FBRyxDQUFDLEdBQUcsQ0FBQyxXQUFXLEVBQUUsSUFBSSxDQUFDLEtBQUssQ0FBQyxFQUFFLEdBQUcsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDO2dCQUNyRSxNQUFNLGVBQWUsR0FDakIsR0FBRyxDQUFDLEdBQUcsQ0FBQyxZQUFZLEVBQUUsSUFBSSxDQUFDLEtBQUssQ0FBQyxFQUM3QixHQUFHLENBQUMsTUFBTSxDQUFDLFFBQVEsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsS0FBSyxDQU
FDLENBQUMsQ0FBQztnQkFFL0MsTUFBTSx3QkFBd0IsR0FBRyxHQUFHLENBQUMsY0FBYyxFQUFFLGdCQUFnQixDQUFDLENBQUM7Z0JBQ3ZFLE1BQU0seUJBQXlCLEdBQzNCLEdBQUcsQ0FBQyxlQUFlLEVBQUUsZ0JBQWdCLENBQUMsQ0FBQztnQkFFM0MsV0FBVyxDQUFDLE1BQU0sQ0FBQyxjQUFjLENBQUMsQ0FBQztnQkFDbkMsWUFBWSxDQUFDLE1BQU0sQ0FBQyxlQUFlLENBQUMsQ0FBQztnQkFFckMsTUFBTSxRQUFRLEdBQ1YsR0FBRyxDQUFDLEdBQUcsQ0FBQyxHQUFHLENBQUMsd0JBQXdCLEVBQ3hCLEdBQUcsQ0FBQyxJQUFJLENBQUMseUJBQXlCLENBQUMsRUFBRSxJQUFJLENBQUMsT0FBTyxDQUFDLENBQUMsRUFDdkQsQ0FBQyxJQUFJLENBQUMsWUFBWSxDQUFDLEVBQ3ZCLEtBQUssQ0FBQyxDQUFDO2dCQUNmLEtBQUssQ0FBQyxNQUFNLENBQUMsUUFBUSxDQUFDLENBQUM7WUFDekIsQ0FBQyxDQUFDLENBQUM7WUFFSCxJQUFJLENBQUMsUUFBUSxDQUFDLE1BQU0sQ0FBQyxHQUFHLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQztZQUNyRCxJQUFJLENBQUMsUUFBUSxDQUFDLE1BQU0sQ0FBQyxHQUFHLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQztRQUN2RCxDQUFDLENBQUMsQ0FBQztRQUNILElBQUksQ0FBQyxtQkFBbUIsRUFBRSxDQUFDO0lBQzdCLENBQUM7SUFFUSxPQUFPO1FBQ2QsSUFBSSxDQUFDLFFBQVEsQ0FBQyxPQUFPLEVBQUUsQ0FBQztRQUN4QixJQUFJLENBQUMsUUFBUSxDQUFDLE9BQU8sRUFBRSxDQUFDO1FBRXhCLElBQUksSUFBSSxDQUFDLHNCQUFzQixJQUFJLElBQUksRUFBRTtZQUN2QyxPQUFPLENBQUMsSUFBSSxDQUFDLHNCQUFzQixDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxRQUFRLENBQUMsQ0FBQyxDQUFDO1NBQzNEO1FBQ0QsSUFBSSxJQUFJLENBQUMsdUJBQXVCLElBQUksSUFBSSxFQUFFO1lBQ3hDLE9BQU8sQ0FBQyxJQUFJLENBQUMsdUJBQXVCLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUM7U0FDNUQ7SUFDSCxDQUFDO0lBRVEsS0FBSyxDQUFDLFVBQVU7UUFDdkIsMENBQTBDO1FBQzFDLE1BQU0sU0FBUyxHQUNYLENBQUMsR0FBRyxJQUFJLENBQUMsc0JBQXNCLEVBQUUsR0FBRyxJQUFJLENBQUMsdUJBQXVCLENBQUMsQ0FBQztRQUN0RSxPQUFPLENBQUMsTUFBTSxJQUFJLENBQUMsY0FBYyxFQUFFLENBQUMsQ0FBQyxNQUFNLENBQ3ZDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUMsSUFBSSxFQUFFLENBQUMsQ0FBQyxZQUFZLEVBQUUsTUFBTSxFQUFFLENBQUMsQ0FBQyxRQUFRLEVBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN4RSxDQUFDO0lBRVEsS0FBSyxDQUFDLFVBQVUsQ0FBQyxZQUEyQjtRQUNuRCxZQUFZLEdBQUcsTUFBTSxJQUFJLENBQUMsaUJBQWlCLE
NBQUMsWUFBWSxDQUFDLENBQUM7UUFDMUQsSUFBSSxDQUFDLEdBQUcsRUFBRTtZQUNSLElBQUksQ0FBQyxRQUFRLENBQUMsTUFBTSxDQUFDLEdBQUcsQ0FBQyxJQUFJLENBQUMsS0FBSyxFQUFFLElBQUksQ0FBQyxXQUFXLEdBQUcsQ0FBQyxDQUFDLENBQUMsQ0FBQztZQUM1RCxJQUFJLENBQUMsUUFBUSxDQUFDLE1BQU0sQ0FBQyxHQUFHLENBQUMsSUFBSSxDQUFDLEtBQUssRUFBRSxJQUFJLENBQUMsV0FBVyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDOUQsQ0FBQyxDQUFDLENBQUM7UUFFSCxNQUFNLGFBQWEsR0FBRyxZQUFZLENBQUMsTUFBTSxHQUFHLENBQUMsQ0FBQztRQUM5QyxNQUFNLFNBQVMsR0FBRyxLQUFLLENBQUM7UUFDeEIsSUFBSSxDQUFDLHNCQUFzQjtZQUN2QixZQUFZLENBQUMsS0FBSyxDQUFDLENBQUMsRUFBRSxhQUFhLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDO2dCQUNKLFlBQVksRUFBRSxDQUFDLENBQUMsSUFBSTtnQkFDcEIsUUFBUSxFQUFFLENBQUMsQ0FBQyxNQUFNLENBQUMsUUFBUSxDQUN2QixTQUFTLENBQUM7YUFDZixDQUFDLENBQUMsQ0FBQztRQUNqRCxJQUFJLENBQUMsdUJBQXVCO1lBQ3hCLFlBQVksQ0FBQyxLQUFLLENBQUMsYUFBYSxFQUFFLGFBQWEsR0FBRyxDQUFDLENBQUM7aUJBQy9DLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUM7Z0JBQ0osWUFBWSxFQUFFLENBQUMsQ0FBQyxJQUFJO2dCQUNwQixRQUFRLEVBQUUsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxRQUFRLENBQUMsU0FBUyxDQUFDO2FBQ3ZDLENBQUMsQ0FBQyxDQUFDO0lBQ25CLENBQUM7SUFFRCxTQUFTO1FBQ1AsT0FBTztZQUNMLGNBQWMsRUFBRSxJQUFJLENBQUMsWUFBWTtZQUNqQyxPQUFPLEVBQUUsSUFBSSxDQUFDLEtBQUs7WUFDbkIsT0FBTyxFQUFFLElBQUksQ0FBQyxLQUFLO1lBQ25CLFNBQVMsRUFBRSxJQUFJLENBQUMsT0FBTztTQUN4QixDQUFDO0lBQ0osQ0FBQztJQUVELGtCQUFrQjtJQUNsQixNQUFNLENBQVUsVUFBVSxDQUN0QixHQUErQixFQUFFLE1BQWtCO1FBQ3JELE9BQU8sSUFBSSxHQUFHLENBQ1YsTUFBTSxDQUFDLGNBQWMsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxPQUFPLENBQUMsRUFBRSxNQUFNLENBQUMsT0FBTyxDQUFDLEVBQ3hELE1BQU0sQ0FBQyxTQUFTLENBQUMsQ0FBQyxDQUFDO0lBQ3pCLENBQUM7Q0FDRiIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE4IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0
VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtFTkdJTkV9IGZyb20gJy4uL2VuZ2luZSc7XG5pbXBvcnQge2Rpc3Bvc2UsIHRpZHl9IGZyb20gJy4uL2dsb2JhbHMnO1xuaW1wb3J0IHthZGR9IGZyb20gJy4uL29wcy9hZGQnO1xuaW1wb3J0IHtkaXZ9IGZyb20gJy4uL29wcy9kaXYnO1xuaW1wb3J0IHttdWx9IGZyb20gJy4uL29wcy9tdWwnO1xuaW1wb3J0IHtwb3d9IGZyb20gJy4uL29wcy9wb3cnO1xuaW1wb3J0IHtzY2FsYXJ9IGZyb20gJy4uL29wcy9zY2FsYXInO1xuaW1wb3J0IHtzcXJ0fSBmcm9tICcuLi9vcHMvc3FydCc7XG5pbXBvcnQge3NxdWFyZX0gZnJvbSAnLi4vb3BzL3NxdWFyZSc7XG5pbXBvcnQge3N1Yn0gZnJvbSAnLi4vb3BzL3N1Yic7XG5pbXBvcnQge3plcm9zTGlrZX0gZnJvbSAnLi4vb3BzL3plcm9zX2xpa2UnO1xuaW1wb3J0IHtDb25maWdEaWN0LCBTZXJpYWxpemFibGUsIFNlcmlhbGl6YWJsZUNvbnN0cnVjdG9yfSBmcm9tICcuLi9zZXJpYWxpemF0aW9uJztcbmltcG9ydCB7VmFyaWFibGV9IGZyb20gJy4uL3RlbnNvcic7XG5pbXBvcnQge05hbWVkVGVuc29yLCBOYW1lZFZhcmlhYmxlTWFwfSBmcm9tICcuLi90ZW5zb3JfdHlwZXMnO1xuXG5pbXBvcnQge09wdGltaXplciwgT3B0aW1pemVyVmFyaWFibGV9IGZyb20gJy4vb3B0aW1pemVyJztcblxuZXhwb3J0IGNsYXNzIEFkYW1PcHRpbWl6ZXIgZXh0ZW5kcyBPcHRpbWl6ZXIge1xuICAvKiogQG5vY29sbGFwc2UgKi9cbiAgc3RhdGljIGdldCBjbGFzc05hbWUoKSB7XG4gICAgLy8gTmFtZSBtYXR0ZXJzIGZvciBQeXRob24gY29tcGF0aWJpbGl0eS5cbiAgICAvLyBUaGlzIGlzIGEgZ2V0dGVyIGluc3RlYWQgb2YgYSBwcm9wZXJ0eSBiZWNhdXNlIHdoZW4gaXQncyBhIHByb3BlcnR5LCBpdFxuICAgIC8vIHByZXZlbnRzIHRoZSBlbnRpcmUgY2xhc3MgZnJvbSBiZWluZyB0cmVlLXNoYWtlbi5cbiAgICByZXR1cm4gJ0FkYW0nO1xuICB9XG4gIHByaXZhdGUgYWNjQmV0YTE6IFZhcmlhYmxlO1xuICBwcml2YXRlIGFjY0JldGEyOiBWYXJpYWJsZTtcblxuICBwcml2YXRlIGFjY3VtdWxhdGVkRmlyc3RNb21lbnQ6IE9wdGltaXplclZhcm
lhYmxlW10gPSBbXTtcbiAgcHJpdmF0ZSBhY2N1bXVsYXRlZFNlY29uZE1vbWVudDogT3B0aW1pemVyVmFyaWFibGVbXSA9IFtdO1xuXG4gIGNvbnN0cnVjdG9yKFxuICAgICAgcHJvdGVjdGVkIGxlYXJuaW5nUmF0ZTogbnVtYmVyLCBwcm90ZWN0ZWQgYmV0YTE6IG51bWJlcixcbiAgICAgIHByb3RlY3RlZCBiZXRhMjogbnVtYmVyLCBwcm90ZWN0ZWQgZXBzaWxvbjogbnVtYmVyID0gbnVsbCkge1xuICAgIHN1cGVyKCk7XG4gICAgdGlkeSgoKSA9PiB7XG4gICAgICAvLyBhY2NCKiB3aWxsIGJlIHVwZGF0ZWQgYnkgYmF0Y2guXG4gICAgICB0aGlzLmFjY0JldGExID0gc2NhbGFyKGJldGExKS52YXJpYWJsZSgpO1xuICAgICAgdGhpcy5hY2NCZXRhMiA9IHNjYWxhcihiZXRhMikudmFyaWFibGUoKTtcbiAgICB9KTtcblxuICAgIGlmIChlcHNpbG9uID09IG51bGwpIHtcbiAgICAgIHRoaXMuZXBzaWxvbiA9IEVOR0lORS5iYWNrZW5kLmVwc2lsb24oKTtcbiAgICB9XG4gIH1cblxuICBhcHBseUdyYWRpZW50cyh2YXJpYWJsZUdyYWRpZW50czogTmFtZWRWYXJpYWJsZU1hcHxOYW1lZFRlbnNvcltdKSB7XG4gICAgY29uc3QgdmFyTmFtZXMgPSBBcnJheS5pc0FycmF5KHZhcmlhYmxlR3JhZGllbnRzKSA/XG4gICAgICAgIHZhcmlhYmxlR3JhZGllbnRzLm1hcCh2ID0+IHYubmFtZSkgOlxuICAgICAgICBPYmplY3Qua2V5cyh2YXJpYWJsZUdyYWRpZW50cyk7XG4gICAgdGlkeSgoKSA9PiB7XG4gICAgICBjb25zdCBvbmVNaW51c0FjY0JldGExID0gc3ViKDEsIHRoaXMuYWNjQmV0YTEpO1xuICAgICAgY29uc3Qgb25lTWludXNBY2NCZXRhMiA9IHN1YigxLCB0aGlzLmFjY0JldGEyKTtcblxuICAgICAgdmFyTmFtZXMuZm9yRWFjaCgobmFtZSwgaSkgPT4ge1xuICAgICAgICBjb25zdCB2YWx1ZSA9IEVOR0lORS5yZWdpc3RlcmVkVmFyaWFibGVzW25hbWVdO1xuICAgICAgICBjb25zdCB0cmFpbmFibGUgPSBmYWxzZTtcbiAgICAgICAgaWYgKHRoaXMuYWNjdW11bGF0ZWRGaXJzdE1vbWVudFtpXSA9PSBudWxsKSB7XG4gICAgICAgICAgdGhpcy5hY2N1bXVsYXRlZEZpcnN0TW9tZW50W2ldID0ge1xuICAgICAgICAgICAgb3JpZ2luYWxOYW1lOiBgJHtuYW1lfS9tYCxcbiAgICAgICAgICAgIHZhcmlhYmxlOiB0aWR5KCgpID0+IHplcm9zTGlrZSh2YWx1ZSkudmFyaWFibGUodHJhaW5hYmxlKSlcbiAgICAgICAgICB9O1xuICAgICAgICB9XG4gICAgICAgIGlmICh0aGlzLmFjY3VtdWxhdGVkU2Vjb25kTW9tZW50W2ldID09IG51bGwpIHtcbiAgICAgICAgICB0aGlzLmFjY3VtdWxhdGVkU2Vjb25kTW9tZW50W2ldID0ge1xuICAgICAgICAgICAgb3JpZ2luYWxOYW1lOiBgJHtuYW1lfS92YCxcbiAgICAgICAgICAgIHZhcmlhYmxlOiB0aWR5KCgpID0+IHplcm9zTGlrZSh2YWx1ZSkudmFyaWFibGUodHJhaW5hYmxlKSlcbiAgICAgICAgICB9O1xuICAgICAgICB9XG5cbiAgICAgICAgY29uc3QgZ3JhZGllbnQgPSBBcnJheS5pc0FycmF5KHZhcmlhYm
xlR3JhZGllbnRzKSA/XG4gICAgICAgICAgICB2YXJpYWJsZUdyYWRpZW50c1tpXS50ZW5zb3IgOlxuICAgICAgICAgICAgdmFyaWFibGVHcmFkaWVudHNbbmFtZV07XG4gICAgICAgIGlmIChncmFkaWVudCA9PSBudWxsKSB7XG4gICAgICAgICAgcmV0dXJuO1xuICAgICAgICB9XG5cbiAgICAgICAgY29uc3QgZmlyc3RNb21lbnQgPSB0aGlzLmFjY3VtdWxhdGVkRmlyc3RNb21lbnRbaV0udmFyaWFibGU7XG4gICAgICAgIGNvbnN0IHNlY29uZE1vbWVudCA9IHRoaXMuYWNjdW11bGF0ZWRTZWNvbmRNb21lbnRbaV0udmFyaWFibGU7XG5cbiAgICAgICAgY29uc3QgbmV3Rmlyc3RNb21lbnQgPVxuICAgICAgICAgICAgYWRkKG11bChmaXJzdE1vbWVudCwgdGhpcy5iZXRhMSksIG11bChncmFkaWVudCwgMSAtIHRoaXMuYmV0YTEpKTtcbiAgICAgICAgY29uc3QgbmV3U2Vjb25kTW9tZW50ID1cbiAgICAgICAgICAgIGFkZChtdWwoc2Vjb25kTW9tZW50LCB0aGlzLmJldGEyKSxcbiAgICAgICAgICAgICAgICBtdWwoc3F1YXJlKGdyYWRpZW50KSwgMSAtIHRoaXMuYmV0YTIpKTtcblxuICAgICAgICBjb25zdCBiaWFzQ29ycmVjdGVkRmlyc3RNb21lbnQgPSBkaXYobmV3Rmlyc3RNb21lbnQsIG9uZU1pbnVzQWNjQmV0YTEpO1xuICAgICAgICBjb25zdCBiaWFzQ29ycmVjdGVkU2Vjb25kTW9tZW50ID1cbiAgICAgICAgICAgIGRpdihuZXdTZWNvbmRNb21lbnQsIG9uZU1pbnVzQWNjQmV0YTIpO1xuXG4gICAgICAgIGZpcnN0TW9tZW50LmFzc2lnbihuZXdGaXJzdE1vbWVudCk7XG4gICAgICAgIHNlY29uZE1vbWVudC5hc3NpZ24obmV3U2Vjb25kTW9tZW50KTtcblxuICAgICAgICBjb25zdCBuZXdWYWx1ZSA9XG4gICAgICAgICAgICBhZGQobXVsKGRpdihiaWFzQ29ycmVjdGVkRmlyc3RNb21lbnQsXG4gICAgICAgICAgICAgICAgICAgICAgICBhZGQoc3FydChiaWFzQ29ycmVjdGVkU2Vjb25kTW9tZW50KSwgdGhpcy5lcHNpbG9uKSksXG4gICAgICAgICAgICAgICAgICAgIC10aGlzLmxlYXJuaW5nUmF0ZSksXG4gICAgICAgICAgICAgICAgdmFsdWUpO1xuICAgICAgICB2YWx1ZS5hc3NpZ24obmV3VmFsdWUpO1xuICAgICAgfSk7XG5cbiAgICAgIHRoaXMuYWNjQmV0YTEuYXNzaWduKG11bCh0aGlzLmFjY0JldGExLCB0aGlzLmJldGExKSk7XG4gICAgICB0aGlzLmFjY0JldGEyLmFzc2lnbihtdWwodGhpcy5hY2NCZXRhMiwgdGhpcy5iZXRhMikpO1xuICAgIH0pO1xuICAgIHRoaXMuaW5jcmVtZW50SXRlcmF0aW9ucygpO1xuICB9XG5cbiAgb3ZlcnJpZGUgZGlzcG9zZSgpOiB2b2lkIHtcbiAgICB0aGlzLmFjY0JldGExLmRpc3Bvc2UoKTtcbiAgICB0aGlzLmFjY0JldGEyLmRpc3Bvc2UoKTtcblxuICAgIGlmICh0aGlzLmFjY3VtdWxhdGVkRmlyc3RNb21lbnQgIT0gbnVsbCkge1xuICAgICAgZGlzcG9zZSh0aGlzLmFjY3VtdWxhdGVkRmlyc3RNb21lbnQubWFwKHYgPT4gdi52YXJpYWJsZSkpO1xuICAgIH1cbiAgICBpZiAodGhpcy5hY2N1bX
VsYXRlZFNlY29uZE1vbWVudCAhPSBudWxsKSB7XG4gICAgICBkaXNwb3NlKHRoaXMuYWNjdW11bGF0ZWRTZWNvbmRNb21lbnQubWFwKHYgPT4gdi52YXJpYWJsZSkpO1xuICAgIH1cbiAgfVxuXG4gIG92ZXJyaWRlIGFzeW5jIGdldFdlaWdodHMoKTogUHJvbWlzZTxOYW1lZFRlbnNvcltdPiB7XG4gICAgLy8gT3JkZXIgbWF0dGVycyBmb3IgUHl0aG9uIGNvbXBhdGliaWxpdHkuXG4gICAgY29uc3QgdmFyaWFibGVzOiBPcHRpbWl6ZXJWYXJpYWJsZVtdID1cbiAgICAgICAgWy4uLnRoaXMuYWNjdW11bGF0ZWRGaXJzdE1vbWVudCwgLi4udGhpcy5hY2N1bXVsYXRlZFNlY29uZE1vbWVudF07XG4gICAgcmV0dXJuIFthd2FpdCB0aGlzLnNhdmVJdGVyYXRpb25zKCldLmNvbmNhdChcbiAgICAgICAgdmFyaWFibGVzLm1hcCh2ID0+ICh7bmFtZTogdi5vcmlnaW5hbE5hbWUsIHRlbnNvcjogdi52YXJpYWJsZX0pKSk7XG4gIH1cblxuICBvdmVycmlkZSBhc3luYyBzZXRXZWlnaHRzKHdlaWdodFZhbHVlczogTmFtZWRUZW5zb3JbXSk6IFByb21pc2U8dm9pZD4ge1xuICAgIHdlaWdodFZhbHVlcyA9IGF3YWl0IHRoaXMuZXh0cmFjdEl0ZXJhdGlvbnMod2VpZ2h0VmFsdWVzKTtcbiAgICB0aWR5KCgpID0+IHtcbiAgICAgIHRoaXMuYWNjQmV0YTEuYXNzaWduKHBvdyh0aGlzLmJldGExLCB0aGlzLml0ZXJhdGlvbnNfICsgMSkpO1xuICAgICAgdGhpcy5hY2NCZXRhMi5hc3NpZ24ocG93KHRoaXMuYmV0YTIsIHRoaXMuaXRlcmF0aW9uc18gKyAxKSk7XG4gICAgfSk7XG5cbiAgICBjb25zdCB2YXJpYWJsZUNvdW50ID0gd2VpZ2h0VmFsdWVzLmxlbmd0aCAvIDI7XG4gICAgY29uc3QgdHJhaW5hYmxlID0gZmFsc2U7XG4gICAgdGhpcy5hY2N1bXVsYXRlZEZpcnN0TW9tZW50ID1cbiAgICAgICAgd2VpZ2h0VmFsdWVzLnNsaWNlKDAsIHZhcmlhYmxlQ291bnQpLm1hcCh2ID0+ICh7XG4gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBvcmlnaW5hbE5hbWU6IHYubmFtZSxcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHZhcmlhYmxlOiB2LnRlbnNvci52YXJpYWJsZShcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0cmFpbmFibGUpXG4gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgfSkpO1xuICAgIHRoaXMuYWNjdW11bGF0ZWRTZWNvbmRNb21lbnQgPVxuICAgICAgICB3ZWlnaHRWYWx1ZXMuc2xpY2UodmFyaWFibGVDb3VudCwgdmFyaWFibGVDb3VudCAqIDIpXG4gICAgICAgICAgICAubWFwKHYgPT4gKHtcbiAgICAgICAgICAgICAgICAgICBvcmlnaW5hbE5hbWU6IHYubmFtZSxcbiAgICAgICAgICAgICAgICAgICB2YXJpYWJsZTogdi50ZW5zb3IudmFyaWFibGUodHJhaW5hYmxlKVxuICAgICAgICAgICAgICAgICB9KSk7XG4gIH1cblxuICBnZXRDb25maWcoKTogQ29uZm
lnRGljdCB7XG4gICAgcmV0dXJuIHtcbiAgICAgICdsZWFybmluZ1JhdGUnOiB0aGlzLmxlYXJuaW5nUmF0ZSxcbiAgICAgICdiZXRhMSc6IHRoaXMuYmV0YTEsXG4gICAgICAnYmV0YTInOiB0aGlzLmJldGEyLFxuICAgICAgJ2Vwc2lsb24nOiB0aGlzLmVwc2lsb24sXG4gICAgfTtcbiAgfVxuXG4gIC8qKiBAbm9jb2xsYXBzZSAqL1xuICBzdGF0aWMgb3ZlcnJpZGUgZnJvbUNvbmZpZzxUIGV4dGVuZHMgU2VyaWFsaXphYmxlPihcbiAgICAgIGNsczogU2VyaWFsaXphYmxlQ29uc3RydWN0b3I8VD4sIGNvbmZpZzogQ29uZmlnRGljdCk6IFQge1xuICAgIHJldHVybiBuZXcgY2xzKFxuICAgICAgICBjb25maWdbJ2xlYXJuaW5nUmF0ZSddLCBjb25maWdbJ2JldGExJ10sIGNvbmZpZ1snYmV0YTInXSxcbiAgICAgICAgY29uZmlnWydlcHNpbG9uJ10pO1xuICB9XG59XG4iXX0=