/**
|
* @license
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { ENGINE } from '../engine';
|
import { dispose, tidy } from '../globals';
|
import { add } from '../ops/add';
|
import { div } from '../ops/div';
|
import { mul } from '../ops/mul';
|
import { pow } from '../ops/pow';
|
import { scalar } from '../ops/scalar';
|
import { sqrt } from '../ops/sqrt';
|
import { square } from '../ops/square';
|
import { sub } from '../ops/sub';
|
import { zerosLike } from '../ops/zeros_like';
|
import { Optimizer } from './optimizer';
|
/**
 * Adam optimizer (adaptive moment estimation).
 *
 * For every variable it updates, the optimizer keeps a non-trainable
 * first-moment estimate (saved as `<name>/m`) and second-moment estimate
 * (saved as `<name>/v`). It also keeps the running powers `beta1^t` and
 * `beta2^t` used for bias correction.
 */
export class AdamOptimizer extends Optimizer {
  /** @nocollapse */
  static get className() {
    // Name matters for Python compatibility.
    // This is a getter instead of a property because when it's a property, it
    // prevents the entire class from being tree-shaken.
    return 'Adam';
  }

  /**
   * @param {number} learningRate Step size multiplied into every update.
   * @param {number} beta1 Decay rate of the first-moment estimate.
   * @param {number} beta2 Decay rate of the second-moment estimate.
   * @param {?number} [epsilon=null] Added to the update denominator; when
   *     null/undefined, `ENGINE.backend.epsilon()` is used instead.
   */
  constructor(learningRate, beta1, beta2, epsilon = null) {
    super();
    this.learningRate = learningRate;
    this.beta1 = beta1;
    this.beta2 = beta2;
    this.epsilon = epsilon;
    // Moment slots are indexed by the position of a variable within the
    // `variableGradients` passed to applyGradients, not by variable name.
    this.accumulatedFirstMoment = [];
    this.accumulatedSecondMoment = [];
    tidy(() => {
      // accBeta1/accBeta2 hold beta1^t / beta2^t for the upcoming step t
      // (t = 1 here); applyGradients multiplies them by beta after each step.
      this.accBeta1 = scalar(beta1).variable();
      this.accBeta2 = scalar(beta2).variable();
    });
    if (epsilon == null) {
      this.epsilon = ENGINE.backend.epsilon();
    }
  }

  /**
   * Performs one Adam step on every variable named in `variableGradients`.
   *
   * @param {Object|Array<{name: string, tensor: *}>} variableGradients Either
   *     an object mapping registered variable names to gradients, or an array
   *     of `{name, tensor}` entries.
   */
  applyGradients(variableGradients) {
    const varNames = Array.isArray(variableGradients) ?
        variableGradients.map(v => v.name) :
        Object.keys(variableGradients);
    tidy(() => {
      // Bias-correction denominators 1 - beta^t, shared by all variables.
      const oneMinusAccBeta1 = sub(1, this.accBeta1);
      const oneMinusAccBeta2 = sub(1, this.accBeta2);

      varNames.forEach((name, i) => {
        const value = ENGINE.registeredVariables[name];
        const trainable = false;
        // Lazily create zero-initialised moment slots the first time index i
        // is seen; the inner tidy keeps only the resulting variable.
        if (this.accumulatedFirstMoment[i] == null) {
          this.accumulatedFirstMoment[i] = {
            originalName: `${name}/m`,
            variable: tidy(() => zerosLike(value).variable(trainable))
          };
        }
        if (this.accumulatedSecondMoment[i] == null) {
          this.accumulatedSecondMoment[i] = {
            originalName: `${name}/v`,
            variable: tidy(() => zerosLike(value).variable(trainable))
          };
        }

        const gradient = Array.isArray(variableGradients) ?
            variableGradients[i].tensor :
            variableGradients[name];
        // No gradient: leave this variable and its moments untouched.
        if (gradient == null) {
          return;
        }

        const firstMoment = this.accumulatedFirstMoment[i].variable;
        const secondMoment = this.accumulatedSecondMoment[i].variable;

        // m_t = beta1 * m_{t-1} + (1 - beta1) * g
        const newFirstMoment =
            add(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
        // v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
        const newSecondMoment = add(
            mul(secondMoment, this.beta2),
            mul(square(gradient), 1 - this.beta2));

        const biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1);
        const biasCorrectedSecondMoment =
            div(newSecondMoment, oneMinusAccBeta2);

        firstMoment.assign(newFirstMoment);
        secondMoment.assign(newSecondMoment);

        // value <- value - learningRate * m_hat / (sqrt(v_hat) + epsilon)
        const newValue = add(
            mul(div(biasCorrectedFirstMoment,
                    add(sqrt(biasCorrectedSecondMoment), this.epsilon)),
                -this.learningRate),
            value);
        value.assign(newValue);
      });

      // Advance beta^t to beta^(t+1) for the next call.
      this.accBeta1.assign(mul(this.accBeta1, this.beta1));
      this.accBeta2.assign(mul(this.accBeta2, this.beta2));
    });
    this.incrementIterations();
  }

  /** Disposes the beta accumulators and every moment variable. */
  dispose() {
    this.accBeta1.dispose();
    this.accBeta2.dispose();

    if (this.accumulatedFirstMoment != null) {
      dispose(this.accumulatedFirstMoment.map(v => v.variable));
    }
    if (this.accumulatedSecondMoment != null) {
      dispose(this.accumulatedSecondMoment.map(v => v.variable));
    }
  }

  /**
   * Returns the entry produced by `saveIterations()`, then all first moments,
   * then all second moments. The moment tensors are the optimizer's own
   * variables, not copies.
   */
  async getWeights() {
    // Order matters for Python compatibility.
    const variables =
        [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
    return [await this.saveIterations()].concat(
        variables.map(v => ({name: v.originalName, tensor: v.variable})));
  }

  /**
   * Restores state from the layout produced by getWeights(): an iterations
   * entry (consumed by `extractIterations`), then the first moments, then the
   * second moments.
   *
   * New non-trainable variables are created from the provided tensors, and
   * the moment variables they replace are disposed so they do not leak.
   *
   * @param {Array<{name: string, tensor: *}>} weightValues
   */
  async setWeights(weightValues) {
    const momentValues = await this.extractIterations(weightValues);
    tidy(() => {
      // Resume bias correction at step iterations_ + 1.
      this.accBeta1.assign(pow(this.beta1, this.iterations_ + 1));
      this.accBeta2.assign(pow(this.beta2, this.iterations_ + 1));
    });

    const variableCount = momentValues.length / 2;
    const trainable = false;
    const toOptimizerVariable = v => ({
      originalName: v.name,
      variable: v.tensor.variable(trainable)
    });

    const previousVariables = [
      ...(this.accumulatedFirstMoment || []),
      ...(this.accumulatedSecondMoment || []),
    ].map(v => v.variable);

    this.accumulatedFirstMoment =
        momentValues.slice(0, variableCount).map(toOptimizerVariable);
    this.accumulatedSecondMoment =
        momentValues.slice(variableCount, variableCount * 2)
            .map(toOptimizerVariable);

    // Release the superseded moments only after the replacements exist:
    // getWeights() hands out these very variables, so a caller may pass them
    // back in and they must still be live while being copied above.
    dispose(previousVariables);
  }

  /** Returns the constructor arguments needed to rebuild this optimizer. */
  getConfig() {
    return {
      'learningRate': this.learningRate,
      'beta1': this.beta1,
      'beta2': this.beta2,
      'epsilon': this.epsilon,
    };
  }

  /** @nocollapse */
  static fromConfig(cls, config) {
    return new cls(
        config['learningRate'], config['beta1'], config['beta2'],
        config['epsilon']);
  }
}
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYWRhbV9vcHRpbWl6ZXIuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wdGltaXplcnMvYWRhbV9vcHRpbWl6ZXIudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsT0FBTyxFQUFFLElBQUksRUFBQyxNQUFNLFlBQVksQ0FBQztBQUN6QyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFDL0IsT0FBTyxFQUFDLEdBQUcsRUFBQyxNQUFNLFlBQVksQ0FBQztBQUMvQixPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sWUFBWSxDQUFDO0FBQy9CLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxlQUFlLENBQUM7QUFDckMsT0FBTyxFQUFDLElBQUksRUFBQyxNQUFNLGFBQWEsQ0FBQztBQUNqQyxPQUFPLEVBQUMsTUFBTSxFQUFDLE1BQU0sZUFBZSxDQUFDO0FBQ3JDLE9BQU8sRUFBQyxHQUFHLEVBQUMsTUFBTSxZQUFZLENBQUM7QUFDL0IsT0FBTyxFQUFDLFNBQVMsRUFBQyxNQUFNLG1CQUFtQixDQUFDO0FBSzVDLE9BQU8sRUFBQyxTQUFTLEVBQW9CLE1BQU0sYUFBYSxDQUFDO0FBRXpELE1BQU0sT0FBTyxhQUFjLFNBQVEsU0FBUztJQUMxQyxrQkFBa0I7SUFDbEIsTUFBTSxLQUFLLFNBQVM7UUFDbEIseUNBQXlDO1FBQ3pDLDBFQUEwRTtRQUMxRSxvREFBb0Q7UUFDcEQsT0FBTyxNQUFNLENBQUM7SUFDaEIsQ0FBQztJQU9ELFlBQ2MsWUFBb0IsRUFBWSxLQUFhLEVBQzdDLEtBQWEsRUFBWSxVQUFrQixJQUFJO1FBQzNELEtBQUssRUFBRSxDQUFDO1FBRkksaUJBQVksR0FBWixZQUFZLENBQVE7UUFBWSxVQUFLLEdBQUwsS0FBSyxDQUFRO1FBQzdDLFVBQUssR0FBTCxLQUFLLENBQVE7UUFBWSxZQUFPLEdBQVAsT0FBTyxDQUFlO1FBTHJELDJCQUFzQixHQUF3QixFQUFFLENBQUM7UUFDakQsNEJBQXVCLEdBQXdCLEVBQUUsQ0FBQztRQU14RCxJQUFJLENBQUMsR0FBRyxFQUFFO1lBQ1Isa0NBQWtDO1lBQ2xDLElBQUksQ0FBQyxRQUFRLEdBQUcsTUFBTSxDQUFDLEtBQUssQ0FBQyxDQUFDLFFBQVEsRUFBRSxDQUFDO1lBQ3pDLElBQUksQ0FBQyxRQUFRLEdBQUcsTUFBTSxDQUFDLEtBQUssQ0FBQyxDQUFDLFFBQVEsRUFBRSxDQUFDO1FBQzNDLENBQUMsQ0FBQyxDQUFDO1FBRUgsSUFBSSxPQUFPLElBQUksSUFBSSxFQUFFO1lBQ25CLElBQUksQ0FBQyxPQUFPLEdBQUcsTUFBTSxDQUFDLE9BQU8sQ0FBQyxPQUFPLEVBQUUsQ0FBQztTQUN6QztJQUNILENBQUM7SUFFRCxjQUFjLENBQUMsaUJBQWlEO1FBQzlELE1BQU0sUUFBUSxHQUFHLEtBQUssQ0FBQyxPQUFPLENBQUMsaUJBQWlCLENBQUMsQ0FBQyxDQUFDO1lBQy9DLGlCQUFpQixDQUFDLEdBQUcsQ0FBQy
xDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDO1lBQ3BDLE1BQU0sQ0FBQyxJQUFJLENBQUMsaUJBQWlCLENBQUMsQ0FBQztRQUNuQyxJQUFJLENBQUMsR0FBRyxFQUFFO1lBQ1IsTUFBTSxnQkFBZ0IsR0FBRyxHQUFHLENBQUMsQ0FBQyxFQUFFLElBQUksQ0FBQyxRQUFRLENBQUMsQ0FBQztZQUMvQyxNQUFNLGdCQUFnQixHQUFHLEdBQUcsQ0FBQyxDQUFDLEVBQUUsSUFBSSxDQUFDLFFBQVEsQ0FBQyxDQUFDO1lBRS9DLFFBQVEsQ0FBQyxPQUFPLENBQUMsQ0FBQyxJQUFJLEVBQUUsQ0FBQyxFQUFFLEVBQUU7Z0JBQzNCLE1BQU0sS0FBSyxHQUFHLE1BQU0sQ0FBQyxtQkFBbUIsQ0FBQyxJQUFJLENBQUMsQ0FBQztnQkFDL0MsTUFBTSxTQUFTLEdBQUcsS0FBSyxDQUFDO2dCQUN4QixJQUFJLElBQUksQ0FBQyxzQkFBc0IsQ0FBQyxDQUFDLENBQUMsSUFBSSxJQUFJLEVBQUU7b0JBQzFDLElBQUksQ0FBQyxzQkFBc0IsQ0FBQyxDQUFDLENBQUMsR0FBRzt3QkFDL0IsWUFBWSxFQUFFLEdBQUcsSUFBSSxJQUFJO3dCQUN6QixRQUFRLEVBQUUsSUFBSSxDQUFDLEdBQUcsRUFBRSxDQUFDLFNBQVMsQ0FBQyxLQUFLLENBQUMsQ0FBQyxRQUFRLENBQUMsU0FBUyxDQUFDLENBQUM7cUJBQzNELENBQUM7aUJBQ0g7Z0JBQ0QsSUFBSSxJQUFJLENBQUMsdUJBQXVCLENBQUMsQ0FBQyxDQUFDLElBQUksSUFBSSxFQUFFO29CQUMzQyxJQUFJLENBQUMsdUJBQXVCLENBQUMsQ0FBQyxDQUFDLEdBQUc7d0JBQ2hDLFlBQVksRUFBRSxHQUFHLElBQUksSUFBSTt3QkFDekIsUUFBUSxFQUFFLElBQUksQ0FBQyxHQUFHLEVBQUUsQ0FBQyxTQUFTLENBQUMsS0FBSyxDQUFDLENBQUMsUUFBUSxDQUFDLFNBQVMsQ0FBQyxDQUFDO3FCQUMzRCxDQUFDO2lCQUNIO2dCQUVELE1BQU0sUUFBUSxHQUFHLEtBQUssQ0FBQyxPQUFPLENBQUMsaUJBQWlCLENBQUMsQ0FBQyxDQUFDO29CQUMvQyxpQkFBaUIsQ0FBQyxDQUFDLENBQUMsQ0FBQyxNQUFNLENBQUMsQ0FBQztvQkFDN0IsaUJBQWlCLENBQUMsSUFBSSxDQUFDLENBQUM7Z0JBQzVCLElBQUksUUFBUSxJQUFJLElBQUksRUFBRTtvQkFDcEIsT0FBTztpQkFDUjtnQkFFRCxNQUFNLFdBQVcsR0FBRyxJQUFJLENBQUMsc0JBQXNCLENBQUMsQ0FBQyxDQUFDLENBQUMsUUFBUSxDQUFDO2dCQUM1RCxNQUFNLFlBQVksR0FBRyxJQUFJLENBQUMsdUJBQXVCLENBQUMsQ0FBQyxDQUFDLENBQUMsUUFBUSxDQUFDO2dCQUU5RCxNQUFNLGNBQWMsR0FDaEIsR0FBRyxDQUFDLEdBQUcsQ0FBQyxXQUFXLEVBQUUsSUFBSSxDQUFDLEtBQUssQ0FBQyxFQUFFLEdBQUcsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxHQUFHLElBQUksQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDO2dCQUNyRSxNQUFNLGVBQWUsR0FDakIsR0FBRyxDQUFDLEdBQUcsQ0FBQyxZQUFZLEVBQUUsSUFBSSxDQUFDLEtBQUssQ0FBQyxFQUM3QixHQUFHLENBQUMsTUFBTSxDQUFDLFFBQVEsQ0FBQyxFQUFFLENBQUMsR0FBRyxJQUFJLENBQUMsS0FBSy
xDQUFDLENBQUMsQ0FBQztnQkFFL0MsTUFBTSx3QkFBd0IsR0FBRyxHQUFHLENBQUMsY0FBYyxFQUFFLGdCQUFnQixDQUFDLENBQUM7Z0JBQ3ZFLE1BQU0seUJBQXlCLEdBQzNCLEdBQUcsQ0FBQyxlQUFlLEVBQUUsZ0JBQWdCLENBQUMsQ0FBQztnQkFFM0MsV0FBVyxDQUFDLE1BQU0sQ0FBQyxjQUFjLENBQUMsQ0FBQztnQkFDbkMsWUFBWSxDQUFDLE1BQU0sQ0FBQyxlQUFlLENBQUMsQ0FBQztnQkFFckMsTUFBTSxRQUFRLEdBQ1YsR0FBRyxDQUFDLEdBQUcsQ0FBQyxHQUFHLENBQUMsd0JBQXdCLEVBQ3hCLEdBQUcsQ0FBQyxJQUFJLENBQUMseUJBQXlCLENBQUMsRUFBRSxJQUFJLENBQUMsT0FBTyxDQUFDLENBQUMsRUFDdkQsQ0FBQyxJQUFJLENBQUMsWUFBWSxDQUFDLEVBQ3ZCLEtBQUssQ0FBQyxDQUFDO2dCQUNmLEtBQUssQ0FBQyxNQUFNLENBQUMsUUFBUSxDQUFDLENBQUM7WUFDekIsQ0FBQyxDQUFDLENBQUM7WUFFSCxJQUFJLENBQUMsUUFBUSxDQUFDLE1BQU0sQ0FBQyxHQUFHLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQztZQUNyRCxJQUFJLENBQUMsUUFBUSxDQUFDLE1BQU0sQ0FBQyxHQUFHLENBQUMsSUFBSSxDQUFDLFFBQVEsRUFBRSxJQUFJLENBQUMsS0FBSyxDQUFDLENBQUMsQ0FBQztRQUN2RCxDQUFDLENBQUMsQ0FBQztRQUNILElBQUksQ0FBQyxtQkFBbUIsRUFBRSxDQUFDO0lBQzdCLENBQUM7SUFFUSxPQUFPO1FBQ2QsSUFBSSxDQUFDLFFBQVEsQ0FBQyxPQUFPLEVBQUUsQ0FBQztRQUN4QixJQUFJLENBQUMsUUFBUSxDQUFDLE9BQU8sRUFBRSxDQUFDO1FBRXhCLElBQUksSUFBSSxDQUFDLHNCQUFzQixJQUFJLElBQUksRUFBRTtZQUN2QyxPQUFPLENBQUMsSUFBSSxDQUFDLHNCQUFzQixDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxRQUFRLENBQUMsQ0FBQyxDQUFDO1NBQzNEO1FBQ0QsSUFBSSxJQUFJLENBQUMsdUJBQXVCLElBQUksSUFBSSxFQUFFO1lBQ3hDLE9BQU8sQ0FBQyxJQUFJLENBQUMsdUJBQXVCLENBQUMsR0FBRyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUM7U0FDNUQ7SUFDSCxDQUFDO0lBRVEsS0FBSyxDQUFDLFVBQVU7UUFDdkIsMENBQTBDO1FBQzFDLE1BQU0sU0FBUyxHQUNYLENBQUMsR0FBRyxJQUFJLENBQUMsc0JBQXNCLEVBQUUsR0FBRyxJQUFJLENBQUMsdUJBQXVCLENBQUMsQ0FBQztRQUN0RSxPQUFPLENBQUMsTUFBTSxJQUFJLENBQUMsY0FBYyxFQUFFLENBQUMsQ0FBQyxNQUFNLENBQ3ZDLFNBQVMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUMsSUFBSSxFQUFFLENBQUMsQ0FBQyxZQUFZLEVBQUUsTUFBTSxFQUFFLENBQUMsQ0FBQyxRQUFRLEVBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUN4RSxDQUFDO0lBRVEsS0FBSyxDQUFDLFVBQVUsQ0FBQyxZQUEyQjtRQUNuRCxZQUFZLEdBQUcsTUFBTSxJQUFJLENBQUMsaUJBQW
lCLENBQUMsWUFBWSxDQUFDLENBQUM7UUFDMUQsSUFBSSxDQUFDLEdBQUcsRUFBRTtZQUNSLElBQUksQ0FBQyxRQUFRLENBQUMsTUFBTSxDQUFDLEdBQUcsQ0FBQyxJQUFJLENBQUMsS0FBSyxFQUFFLElBQUksQ0FBQyxXQUFXLEdBQUcsQ0FBQyxDQUFDLENBQUMsQ0FBQztZQUM1RCxJQUFJLENBQUMsUUFBUSxDQUFDLE1BQU0sQ0FBQyxHQUFHLENBQUMsSUFBSSxDQUFDLEtBQUssRUFBRSxJQUFJLENBQUMsV0FBVyxHQUFHLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDOUQsQ0FBQyxDQUFDLENBQUM7UUFFSCxNQUFNLGFBQWEsR0FBRyxZQUFZLENBQUMsTUFBTSxHQUFHLENBQUMsQ0FBQztRQUM5QyxNQUFNLFNBQVMsR0FBRyxLQUFLLENBQUM7UUFDeEIsSUFBSSxDQUFDLHNCQUFzQjtZQUN2QixZQUFZLENBQUMsS0FBSyxDQUFDLENBQUMsRUFBRSxhQUFhLENBQUMsQ0FBQyxHQUFHLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDO2dCQUNKLFlBQVksRUFBRSxDQUFDLENBQUMsSUFBSTtnQkFDcEIsUUFBUSxFQUFFLENBQUMsQ0FBQyxNQUFNLENBQUMsUUFBUSxDQUN2QixTQUFTLENBQUM7YUFDZixDQUFDLENBQUMsQ0FBQztRQUNqRCxJQUFJLENBQUMsdUJBQXVCO1lBQ3hCLFlBQVksQ0FBQyxLQUFLLENBQUMsYUFBYSxFQUFFLGFBQWEsR0FBRyxDQUFDLENBQUM7aUJBQy9DLEdBQUcsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUM7Z0JBQ0osWUFBWSxFQUFFLENBQUMsQ0FBQyxJQUFJO2dCQUNwQixRQUFRLEVBQUUsQ0FBQyxDQUFDLE1BQU0sQ0FBQyxRQUFRLENBQUMsU0FBUyxDQUFDO2FBQ3ZDLENBQUMsQ0FBQyxDQUFDO0lBQ25CLENBQUM7SUFFRCxTQUFTO1FBQ1AsT0FBTztZQUNMLGNBQWMsRUFBRSxJQUFJLENBQUMsWUFBWTtZQUNqQyxPQUFPLEVBQUUsSUFBSSxDQUFDLEtBQUs7WUFDbkIsT0FBTyxFQUFFLElBQUksQ0FBQyxLQUFLO1lBQ25CLFNBQVMsRUFBRSxJQUFJLENBQUMsT0FBTztTQUN4QixDQUFDO0lBQ0osQ0FBQztJQUVELGtCQUFrQjtJQUNsQixNQUFNLENBQVUsVUFBVSxDQUN0QixHQUErQixFQUFFLE1BQWtCO1FBQ3JELE9BQU8sSUFBSSxHQUFHLENBQ1YsTUFBTSxDQUFDLGNBQWMsQ0FBQyxFQUFFLE1BQU0sQ0FBQyxPQUFPLENBQUMsRUFBRSxNQUFNLENBQUMsT0FBTyxDQUFDLEVBQ3hELE1BQU0sQ0FBQyxTQUFTLENBQUMsQ0FBQyxDQUFDO0lBQ3pCLENBQUM7Q0FDRiIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDE4IEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0
xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtFTkdJTkV9IGZyb20gJy4uL2VuZ2luZSc7XG5pbXBvcnQge2Rpc3Bvc2UsIHRpZHl9IGZyb20gJy4uL2dsb2JhbHMnO1xuaW1wb3J0IHthZGR9IGZyb20gJy4uL29wcy9hZGQnO1xuaW1wb3J0IHtkaXZ9IGZyb20gJy4uL29wcy9kaXYnO1xuaW1wb3J0IHttdWx9IGZyb20gJy4uL29wcy9tdWwnO1xuaW1wb3J0IHtwb3d9IGZyb20gJy4uL29wcy9wb3cnO1xuaW1wb3J0IHtzY2FsYXJ9IGZyb20gJy4uL29wcy9zY2FsYXInO1xuaW1wb3J0IHtzcXJ0fSBmcm9tICcuLi9vcHMvc3FydCc7XG5pbXBvcnQge3NxdWFyZX0gZnJvbSAnLi4vb3BzL3NxdWFyZSc7XG5pbXBvcnQge3N1Yn0gZnJvbSAnLi4vb3BzL3N1Yic7XG5pbXBvcnQge3plcm9zTGlrZX0gZnJvbSAnLi4vb3BzL3plcm9zX2xpa2UnO1xuaW1wb3J0IHtDb25maWdEaWN0LCBTZXJpYWxpemFibGUsIFNlcmlhbGl6YWJsZUNvbnN0cnVjdG9yfSBmcm9tICcuLi9zZXJpYWxpemF0aW9uJztcbmltcG9ydCB7VmFyaWFibGV9IGZyb20gJy4uL3RlbnNvcic7XG5pbXBvcnQge05hbWVkVGVuc29yLCBOYW1lZFZhcmlhYmxlTWFwfSBmcm9tICcuLi90ZW5zb3JfdHlwZXMnO1xuXG5pbXBvcnQge09wdGltaXplciwgT3B0aW1pemVyVmFyaWFibGV9IGZyb20gJy4vb3B0aW1pemVyJztcblxuZXhwb3J0IGNsYXNzIEFkYW1PcHRpbWl6ZXIgZXh0ZW5kcyBPcHRpbWl6ZXIge1xuICAvKiogQG5vY29sbGFwc2UgKi9cbiAgc3RhdGljIGdldCBjbGFzc05hbWUoKSB7XG4gICAgLy8gTmFtZSBtYXR0ZXJzIGZvciBQeXRob24gY29tcGF0aWJpbGl0eS5cbiAgICAvLyBUaGlzIGlzIGEgZ2V0dGVyIGluc3RlYWQgb2YgYSBwcm9wZXJ0eSBiZWNhdXNlIHdoZW4gaXQncyBhIHByb3BlcnR5LCBpdFxuICAgIC8vIHByZXZlbnRzIHRoZSBlbnRpcmUgY2xhc3MgZnJvbSBiZWluZyB0cmVlLXNoYWtlbi5cbiAgICByZXR1cm4gJ0FkYW0nO1xuICB9XG4gIHByaXZhdGUgYWNjQmV0YTE6IFZhcmlhYmxlO1xuICBwcml2YXRlIGFjY0JldGEyOiBWYXJpYWJsZTtcblxuICBwcml2YXRlIGFjY3VtdWxhdGVkRmlyc3RNb21lbnQ6IE9wdGltaXplcl
ZhcmlhYmxlW10gPSBbXTtcbiAgcHJpdmF0ZSBhY2N1bXVsYXRlZFNlY29uZE1vbWVudDogT3B0aW1pemVyVmFyaWFibGVbXSA9IFtdO1xuXG4gIGNvbnN0cnVjdG9yKFxuICAgICAgcHJvdGVjdGVkIGxlYXJuaW5nUmF0ZTogbnVtYmVyLCBwcm90ZWN0ZWQgYmV0YTE6IG51bWJlcixcbiAgICAgIHByb3RlY3RlZCBiZXRhMjogbnVtYmVyLCBwcm90ZWN0ZWQgZXBzaWxvbjogbnVtYmVyID0gbnVsbCkge1xuICAgIHN1cGVyKCk7XG4gICAgdGlkeSgoKSA9PiB7XG4gICAgICAvLyBhY2NCKiB3aWxsIGJlIHVwZGF0ZWQgYnkgYmF0Y2guXG4gICAgICB0aGlzLmFjY0JldGExID0gc2NhbGFyKGJldGExKS52YXJpYWJsZSgpO1xuICAgICAgdGhpcy5hY2NCZXRhMiA9IHNjYWxhcihiZXRhMikudmFyaWFibGUoKTtcbiAgICB9KTtcblxuICAgIGlmIChlcHNpbG9uID09IG51bGwpIHtcbiAgICAgIHRoaXMuZXBzaWxvbiA9IEVOR0lORS5iYWNrZW5kLmVwc2lsb24oKTtcbiAgICB9XG4gIH1cblxuICBhcHBseUdyYWRpZW50cyh2YXJpYWJsZUdyYWRpZW50czogTmFtZWRWYXJpYWJsZU1hcHxOYW1lZFRlbnNvcltdKSB7XG4gICAgY29uc3QgdmFyTmFtZXMgPSBBcnJheS5pc0FycmF5KHZhcmlhYmxlR3JhZGllbnRzKSA/XG4gICAgICAgIHZhcmlhYmxlR3JhZGllbnRzLm1hcCh2ID0+IHYubmFtZSkgOlxuICAgICAgICBPYmplY3Qua2V5cyh2YXJpYWJsZUdyYWRpZW50cyk7XG4gICAgdGlkeSgoKSA9PiB7XG4gICAgICBjb25zdCBvbmVNaW51c0FjY0JldGExID0gc3ViKDEsIHRoaXMuYWNjQmV0YTEpO1xuICAgICAgY29uc3Qgb25lTWludXNBY2NCZXRhMiA9IHN1YigxLCB0aGlzLmFjY0JldGEyKTtcblxuICAgICAgdmFyTmFtZXMuZm9yRWFjaCgobmFtZSwgaSkgPT4ge1xuICAgICAgICBjb25zdCB2YWx1ZSA9IEVOR0lORS5yZWdpc3RlcmVkVmFyaWFibGVzW25hbWVdO1xuICAgICAgICBjb25zdCB0cmFpbmFibGUgPSBmYWxzZTtcbiAgICAgICAgaWYgKHRoaXMuYWNjdW11bGF0ZWRGaXJzdE1vbWVudFtpXSA9PSBudWxsKSB7XG4gICAgICAgICAgdGhpcy5hY2N1bXVsYXRlZEZpcnN0TW9tZW50W2ldID0ge1xuICAgICAgICAgICAgb3JpZ2luYWxOYW1lOiBgJHtuYW1lfS9tYCxcbiAgICAgICAgICAgIHZhcmlhYmxlOiB0aWR5KCgpID0+IHplcm9zTGlrZSh2YWx1ZSkudmFyaWFibGUodHJhaW5hYmxlKSlcbiAgICAgICAgICB9O1xuICAgICAgICB9XG4gICAgICAgIGlmICh0aGlzLmFjY3VtdWxhdGVkU2Vjb25kTW9tZW50W2ldID09IG51bGwpIHtcbiAgICAgICAgICB0aGlzLmFjY3VtdWxhdGVkU2Vjb25kTW9tZW50W2ldID0ge1xuICAgICAgICAgICAgb3JpZ2luYWxOYW1lOiBgJHtuYW1lfS92YCxcbiAgICAgICAgICAgIHZhcmlhYmxlOiB0aWR5KCgpID0+IHplcm9zTGlrZSh2YWx1ZSkudmFyaWFibGUodHJhaW5hYmxlKSlcbiAgICAgICAgICB9O1xuICAgICAgICB9XG5cbiAgICAgICAgY29uc3QgZ3JhZGllbnQgPSBBcnJheS5pc0FycmF5KHZhcm
lhYmxlR3JhZGllbnRzKSA/XG4gICAgICAgICAgICB2YXJpYWJsZUdyYWRpZW50c1tpXS50ZW5zb3IgOlxuICAgICAgICAgICAgdmFyaWFibGVHcmFkaWVudHNbbmFtZV07XG4gICAgICAgIGlmIChncmFkaWVudCA9PSBudWxsKSB7XG4gICAgICAgICAgcmV0dXJuO1xuICAgICAgICB9XG5cbiAgICAgICAgY29uc3QgZmlyc3RNb21lbnQgPSB0aGlzLmFjY3VtdWxhdGVkRmlyc3RNb21lbnRbaV0udmFyaWFibGU7XG4gICAgICAgIGNvbnN0IHNlY29uZE1vbWVudCA9IHRoaXMuYWNjdW11bGF0ZWRTZWNvbmRNb21lbnRbaV0udmFyaWFibGU7XG5cbiAgICAgICAgY29uc3QgbmV3Rmlyc3RNb21lbnQgPVxuICAgICAgICAgICAgYWRkKG11bChmaXJzdE1vbWVudCwgdGhpcy5iZXRhMSksIG11bChncmFkaWVudCwgMSAtIHRoaXMuYmV0YTEpKTtcbiAgICAgICAgY29uc3QgbmV3U2Vjb25kTW9tZW50ID1cbiAgICAgICAgICAgIGFkZChtdWwoc2Vjb25kTW9tZW50LCB0aGlzLmJldGEyKSxcbiAgICAgICAgICAgICAgICBtdWwoc3F1YXJlKGdyYWRpZW50KSwgMSAtIHRoaXMuYmV0YTIpKTtcblxuICAgICAgICBjb25zdCBiaWFzQ29ycmVjdGVkRmlyc3RNb21lbnQgPSBkaXYobmV3Rmlyc3RNb21lbnQsIG9uZU1pbnVzQWNjQmV0YTEpO1xuICAgICAgICBjb25zdCBiaWFzQ29ycmVjdGVkU2Vjb25kTW9tZW50ID1cbiAgICAgICAgICAgIGRpdihuZXdTZWNvbmRNb21lbnQsIG9uZU1pbnVzQWNjQmV0YTIpO1xuXG4gICAgICAgIGZpcnN0TW9tZW50LmFzc2lnbihuZXdGaXJzdE1vbWVudCk7XG4gICAgICAgIHNlY29uZE1vbWVudC5hc3NpZ24obmV3U2Vjb25kTW9tZW50KTtcblxuICAgICAgICBjb25zdCBuZXdWYWx1ZSA9XG4gICAgICAgICAgICBhZGQobXVsKGRpdihiaWFzQ29ycmVjdGVkRmlyc3RNb21lbnQsXG4gICAgICAgICAgICAgICAgICAgICAgICBhZGQoc3FydChiaWFzQ29ycmVjdGVkU2Vjb25kTW9tZW50KSwgdGhpcy5lcHNpbG9uKSksXG4gICAgICAgICAgICAgICAgICAgIC10aGlzLmxlYXJuaW5nUmF0ZSksXG4gICAgICAgICAgICAgICAgdmFsdWUpO1xuICAgICAgICB2YWx1ZS5hc3NpZ24obmV3VmFsdWUpO1xuICAgICAgfSk7XG5cbiAgICAgIHRoaXMuYWNjQmV0YTEuYXNzaWduKG11bCh0aGlzLmFjY0JldGExLCB0aGlzLmJldGExKSk7XG4gICAgICB0aGlzLmFjY0JldGEyLmFzc2lnbihtdWwodGhpcy5hY2NCZXRhMiwgdGhpcy5iZXRhMikpO1xuICAgIH0pO1xuICAgIHRoaXMuaW5jcmVtZW50SXRlcmF0aW9ucygpO1xuICB9XG5cbiAgb3ZlcnJpZGUgZGlzcG9zZSgpOiB2b2lkIHtcbiAgICB0aGlzLmFjY0JldGExLmRpc3Bvc2UoKTtcbiAgICB0aGlzLmFjY0JldGEyLmRpc3Bvc2UoKTtcblxuICAgIGlmICh0aGlzLmFjY3VtdWxhdGVkRmlyc3RNb21lbnQgIT0gbnVsbCkge1xuICAgICAgZGlzcG9zZSh0aGlzLmFjY3VtdWxhdGVkRmlyc3RNb21lbnQubWFwKHYgPT4gdi52YXJpYWJsZSkpO1xuICAgIH1cbiAgICBpZiAodGhpcy5hY2
N1bXVsYXRlZFNlY29uZE1vbWVudCAhPSBudWxsKSB7XG4gICAgICBkaXNwb3NlKHRoaXMuYWNjdW11bGF0ZWRTZWNvbmRNb21lbnQubWFwKHYgPT4gdi52YXJpYWJsZSkpO1xuICAgIH1cbiAgfVxuXG4gIG92ZXJyaWRlIGFzeW5jIGdldFdlaWdodHMoKTogUHJvbWlzZTxOYW1lZFRlbnNvcltdPiB7XG4gICAgLy8gT3JkZXIgbWF0dGVycyBmb3IgUHl0aG9uIGNvbXBhdGliaWxpdHkuXG4gICAgY29uc3QgdmFyaWFibGVzOiBPcHRpbWl6ZXJWYXJpYWJsZVtdID1cbiAgICAgICAgWy4uLnRoaXMuYWNjdW11bGF0ZWRGaXJzdE1vbWVudCwgLi4udGhpcy5hY2N1bXVsYXRlZFNlY29uZE1vbWVudF07XG4gICAgcmV0dXJuIFthd2FpdCB0aGlzLnNhdmVJdGVyYXRpb25zKCldLmNvbmNhdChcbiAgICAgICAgdmFyaWFibGVzLm1hcCh2ID0+ICh7bmFtZTogdi5vcmlnaW5hbE5hbWUsIHRlbnNvcjogdi52YXJpYWJsZX0pKSk7XG4gIH1cblxuICBvdmVycmlkZSBhc3luYyBzZXRXZWlnaHRzKHdlaWdodFZhbHVlczogTmFtZWRUZW5zb3JbXSk6IFByb21pc2U8dm9pZD4ge1xuICAgIHdlaWdodFZhbHVlcyA9IGF3YWl0IHRoaXMuZXh0cmFjdEl0ZXJhdGlvbnMod2VpZ2h0VmFsdWVzKTtcbiAgICB0aWR5KCgpID0+IHtcbiAgICAgIHRoaXMuYWNjQmV0YTEuYXNzaWduKHBvdyh0aGlzLmJldGExLCB0aGlzLml0ZXJhdGlvbnNfICsgMSkpO1xuICAgICAgdGhpcy5hY2NCZXRhMi5hc3NpZ24ocG93KHRoaXMuYmV0YTIsIHRoaXMuaXRlcmF0aW9uc18gKyAxKSk7XG4gICAgfSk7XG5cbiAgICBjb25zdCB2YXJpYWJsZUNvdW50ID0gd2VpZ2h0VmFsdWVzLmxlbmd0aCAvIDI7XG4gICAgY29uc3QgdHJhaW5hYmxlID0gZmFsc2U7XG4gICAgdGhpcy5hY2N1bXVsYXRlZEZpcnN0TW9tZW50ID1cbiAgICAgICAgd2VpZ2h0VmFsdWVzLnNsaWNlKDAsIHZhcmlhYmxlQ291bnQpLm1hcCh2ID0+ICh7XG4gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBvcmlnaW5hbE5hbWU6IHYubmFtZSxcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgIHZhcmlhYmxlOiB2LnRlbnNvci52YXJpYWJsZShcbiAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICB0cmFpbmFibGUpXG4gICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgfSkpO1xuICAgIHRoaXMuYWNjdW11bGF0ZWRTZWNvbmRNb21lbnQgPVxuICAgICAgICB3ZWlnaHRWYWx1ZXMuc2xpY2UodmFyaWFibGVDb3VudCwgdmFyaWFibGVDb3VudCAqIDIpXG4gICAgICAgICAgICAubWFwKHYgPT4gKHtcbiAgICAgICAgICAgICAgICAgICBvcmlnaW5hbE5hbWU6IHYubmFtZSxcbiAgICAgICAgICAgICAgICAgICB2YXJpYWJsZTogdi50ZW5zb3IudmFyaWFibGUodHJhaW5hYmxlKVxuICAgICAgICAgICAgICAgICB9KSk7XG4gIH1cblxuICBnZXRDb25maWcoKTogQ2
9uZmlnRGljdCB7XG4gICAgcmV0dXJuIHtcbiAgICAgICdsZWFybmluZ1JhdGUnOiB0aGlzLmxlYXJuaW5nUmF0ZSxcbiAgICAgICdiZXRhMSc6IHRoaXMuYmV0YTEsXG4gICAgICAnYmV0YTInOiB0aGlzLmJldGEyLFxuICAgICAgJ2Vwc2lsb24nOiB0aGlzLmVwc2lsb24sXG4gICAgfTtcbiAgfVxuXG4gIC8qKiBAbm9jb2xsYXBzZSAqL1xuICBzdGF0aWMgb3ZlcnJpZGUgZnJvbUNvbmZpZzxUIGV4dGVuZHMgU2VyaWFsaXphYmxlPihcbiAgICAgIGNsczogU2VyaWFsaXphYmxlQ29uc3RydWN0b3I8VD4sIGNvbmZpZzogQ29uZmlnRGljdCk6IFQge1xuICAgIHJldHVybiBuZXcgY2xzKFxuICAgICAgICBjb25maWdbJ2xlYXJuaW5nUmF0ZSddLCBjb25maWdbJ2JldGExJ10sIGNvbmZpZ1snYmV0YTInXSxcbiAgICAgICAgY29uZmlnWydlcHNpbG9uJ10pO1xuICB9XG59XG4iXX0=
|