/**
|
* @license
|
* Copyright 2018 Google LLC
|
*
|
* Use of this source code is governed by an MIT-style
|
* license that can be found in the LICENSE file or at
|
* https://opensource.org/licenses/MIT.
|
* =============================================================================
|
*/
|
import { eye, linalg, mul, ones, randomUniform, scalar, serialization, tidy, truncatedNormal, util, zeros } from '@tensorflow/tfjs-core';
|
import * as K from './backend/tfjs_backend';
|
import { checkDataFormat } from './common';
|
import { NotImplementedError, ValueError } from './errors';
|
import { VALID_DISTRIBUTION_VALUES, VALID_FAN_MODE_VALUES } from './keras_format/initializer_config';
|
import { checkStringTypeUnionValue, deserializeKerasObject, serializeKerasObject } from './utils/generic_utils';
|
import { arrayProd } from './utils/math_utils';
|
/**
 * Validates a fan-mode value.
 *
 * Forwards `value` to `checkStringTypeUnionValue` together with
 * `VALID_FAN_MODE_VALUES` and the label 'FanMode'.
 *
 * @param value Candidate fan mode to validate.
 */
export function checkFanMode(value) {
    const label = 'FanMode';
    checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, label, value);
}
|
/**
 * Validates a distribution value.
 *
 * Forwards `value` to `checkStringTypeUnionValue` together with
 * `VALID_DISTRIBUTION_VALUES` and the label 'Distribution'.
 *
 * @param value Candidate distribution name to validate.
 */
export function checkDistribution(value) {
    const label = 'Distribution';
    checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, label, value);
}
|
/**
 * Initializer base class.
 *
 * Subclasses in this file add an `apply(shape, dtype)` method and, where
 * they hold settings, override `getConfig()`.
 *
 * @doc {
 * heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
 */
export class Initializer extends serialization.Serializable {
    /** The base class holds no settings, so its config is an empty object. */
    getConfig() {
        return {};
    }
    /** Always `false` on the base class. */
    fromConfigUsesCustomObjects() {
        return false;
    }
}
|
/**
 * Initializer whose `apply` returns `zeros(shape, dtype)`.
 */
class Zeros extends Initializer {
    /**
     * @param shape Shape passed through to `zeros`.
     * @param dtype Dtype passed through to `zeros`.
     * @return The result of `zeros(shape, dtype)`.
     */
    apply(shape, dtype) {
        const filled = zeros(shape, dtype);
        return filled;
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
Zeros.className = 'Zeros';
serialization.registerClass(Zeros);
export { Zeros };
|
/**
 * Initializer whose `apply` returns `ones(shape, dtype)`.
 */
class Ones extends Initializer {
    /**
     * @param shape Shape passed through to `ones`.
     * @param dtype Dtype passed through to `ones`.
     * @return The result of `ones(shape, dtype)`.
     */
    apply(shape, dtype) {
        const filled = ones(shape, dtype);
        return filled;
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
Ones.className = 'Ones';
serialization.registerClass(Ones);
export { Ones };
|
/**
 * Initializer that produces a tensor filled with one constant value.
 */
class Constant extends Initializer {
    /**
     * @param args Config object; `args.value` is required.
     * @throws ValueError if `args` is not a non-null object, or if
     *   `args.value` is undefined.
     */
    constructor(args) {
        super();
        // `typeof null` is 'object', so null has to be rejected explicitly;
        // otherwise reading `args.value` below raises a raw TypeError instead
        // of the intended ValueError.
        if (args === null || typeof args !== 'object') {
            throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`);
        }
        if (args.value === undefined) {
            throw new ValueError(`config must have value set but got ${args}`);
        }
        this.value = args.value;
    }
    /**
     * @param shape Shape of the tensor to create.
     * @param dtype Dtype passed to `ones`.
     * @return `ones(shape, dtype)` multiplied by the scalar `this.value`.
     */
    apply(shape, dtype) {
        // tidy() wraps the computation that creates the intermediate scalar
        // and ones tensors.
        return tidy(() => mul(scalar(this.value), ones(shape, dtype)));
    }
    /** @return Config object holding the constant `value`. */
    getConfig() {
        return {
            value: this.value,
        };
    }
}
|
/** @nocollapse */
|
Constant.className = 'Constant';
|
export { Constant };
|
serialization.registerClass(Constant);
|
/**
 * Initializer that draws values via `randomUniform` between `minval` and
 * `maxval`.
 */
class RandomUniform extends Initializer {
    /**
     * @param args Config object with optional `minval` (default -0.05),
     *   `maxval` (default 0.05) and `seed`.
     */
    constructor(args) {
        super();
        this.DEFAULT_MINVAL = -0.05;
        this.DEFAULT_MAXVAL = 0.05;
        // Fall back to the defaults only when a bound is absent
        // (null/undefined). Using `||` here would also discard an explicit 0,
        // turning e.g. a requested [0, 1) range into [-0.05, 1).
        this.minval = args.minval == null ? this.DEFAULT_MINVAL : args.minval;
        this.maxval = args.maxval == null ? this.DEFAULT_MAXVAL : args.maxval;
        this.seed = args.seed;
    }
    /**
     * @param shape Shape of the tensor to create.
     * @param dtype Dtype passed through to `randomUniform`.
     * @return The result of
     *   `randomUniform(shape, minval, maxval, dtype, seed)`.
     */
    apply(shape, dtype) {
        return randomUniform(shape, this.minval, this.maxval, dtype, this.seed);
    }
    /** @return Config object with `minval`, `maxval` and `seed`. */
    getConfig() {
        return { minval: this.minval, maxval: this.maxval, seed: this.seed };
    }
}
|
/** @nocollapse */
|
RandomUniform.className = 'RandomUniform';
|
export { RandomUniform };
|
serialization.registerClass(RandomUniform);
|
/**
 * Initializer that draws values via `K.randomNormal` with the configured
 * mean and standard deviation.
 */
class RandomNormal extends Initializer {
    /**
     * @param args Config object with optional `mean` (default 0),
     *   `stddev` (default 0.05) and `seed`.
     */
    constructor(args) {
        super();
        this.DEFAULT_MEAN = 0.;
        this.DEFAULT_STDDEV = 0.05;
        // Fall back to the defaults only when a value is absent
        // (null/undefined). Using `||` here would also replace an explicit
        // `stddev: 0` with 0.05.
        this.mean = args.mean == null ? this.DEFAULT_MEAN : args.mean;
        this.stddev = args.stddev == null ? this.DEFAULT_STDDEV : args.stddev;
        this.seed = args.seed;
    }
    /**
     * @param shape Shape of the tensor to create.
     * @param dtype 'float32' or 'int32'; a falsy value is treated as
     *   'float32'.
     * @return The result of
     *   `K.randomNormal(shape, mean, stddev, dtype, seed)`.
     * @throws NotImplementedError for any other dtype.
     */
    apply(shape, dtype) {
        const resolvedDtype = dtype || 'float32';
        if (resolvedDtype !== 'float32' && resolvedDtype !== 'int32') {
            throw new NotImplementedError(`randomNormal does not support dType ${resolvedDtype}.`);
        }
        return K.randomNormal(shape, this.mean, this.stddev, resolvedDtype, this.seed);
    }
    /** @return Config object with `mean`, `stddev` and `seed`. */
    getConfig() {
        return { mean: this.mean, stddev: this.stddev, seed: this.seed };
    }
}
|
/** @nocollapse */
|
RandomNormal.className = 'RandomNormal';
|
export { RandomNormal };
|
serialization.registerClass(RandomNormal);
|
/**
 * Initializer that draws values via `truncatedNormal` with the configured
 * mean and standard deviation.
 */
class TruncatedNormal extends Initializer {
    /**
     * @param args Config object with optional `mean` (default 0),
     *   `stddev` (default 0.05) and `seed`.
     */
    constructor(args) {
        super();
        this.DEFAULT_MEAN = 0.;
        this.DEFAULT_STDDEV = 0.05;
        // Fall back to the defaults only when a value is absent
        // (null/undefined). Using `||` here would also replace an explicit
        // `stddev: 0` with 0.05.
        this.mean = args.mean == null ? this.DEFAULT_MEAN : args.mean;
        this.stddev = args.stddev == null ? this.DEFAULT_STDDEV : args.stddev;
        this.seed = args.seed;
    }
    /**
     * @param shape Shape of the tensor to create.
     * @param dtype 'float32' or 'int32'; a falsy value is treated as
     *   'float32'.
     * @return The result of
     *   `truncatedNormal(shape, mean, stddev, dtype, seed)`.
     * @throws NotImplementedError for any other dtype.
     */
    apply(shape, dtype) {
        const resolvedDtype = dtype || 'float32';
        if (resolvedDtype !== 'float32' && resolvedDtype !== 'int32') {
            throw new NotImplementedError(`truncatedNormal does not support dType ${resolvedDtype}.`);
        }
        return truncatedNormal(shape, this.mean, this.stddev, resolvedDtype, this.seed);
    }
    /** @return Config object with `mean`, `stddev` and `seed`. */
    getConfig() {
        return { mean: this.mean, stddev: this.stddev, seed: this.seed };
    }
}
|
/** @nocollapse */
|
TruncatedNormal.className = 'TruncatedNormal';
|
export { TruncatedNormal };
|
serialization.registerClass(TruncatedNormal);
|
/**
 * Initializer that returns `eye(n)` multiplied by a gain, for square 2D
 * shapes only.
 */
class Identity extends Initializer {
    /**
     * @param args Config object; `args.gain` defaults to 1.0 when it is
     *   null or undefined.
     */
    constructor(args) {
        super();
        this.gain = args.gain == null ? 1.0 : args.gain;
    }
    /**
     * @param shape Must have length 2 with equal entries.
     * @param dtype Not read by this method.
     * @return `mul(this.gain, eye(shape[0]))`.
     * @throws ValueError if `shape` is not a square 2D shape.
     */
    apply(shape, dtype) {
        return tidy(() => {
            const isSquareMatrix = shape.length === 2 && shape[0] === shape[1];
            if (!isSquareMatrix) {
                throw new ValueError('Identity matrix initializer can only be used for' +
                    ' 2D square matrices.');
            }
            return mul(this.gain, eye(shape[0]));
        });
    }
    /** @return Config object holding `gain`. */
    getConfig() {
        return { gain: this.gain };
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
Identity.className = 'Identity';
serialization.registerClass(Identity);
export { Identity };
|
/**
 * Computes the number of input and output units for a weight shape.
 *
 * - Length-2 shapes: `[shape[0], shape[1]]`.
 * - Length-3/4/5 shapes: the two channel dimensions, each multiplied by
 *   the product of the remaining dimensions. Which dimensions count as
 *   channels depends on `dataFormat`.
 * - Any other length: both values are the square root of the product of
 *   all dimensions.
 *
 * @param shape Shape of weight.
 * @param dataFormat Data format to use for convolution kernels; checked
 *   with `checkDataFormat`. Defaults to 'channelsLast'.
 * @return A length-2 array: [fanIn, fanOut].
 */
function computeFans(shape, dataFormat = 'channelsLast') {
    checkDataFormat(dataFormat);
    const rank = shape.length;
    if (rank === 2) {
        return [shape[0], shape[1]];
    }
    if (rank === 3 || rank === 4 || rank === 5) {
        if (dataFormat === 'channelsFirst') {
            // Channels lead; the remaining dimensions start at index 2.
            const receptiveField = arrayProd(shape, 2);
            return [shape[1] * receptiveField, shape[0] * receptiveField];
        }
        if (dataFormat === 'channelsLast') {
            // Channels trail; the remaining dimensions are all but the last two.
            const receptiveField = arrayProd(shape, 0, rank - 2);
            return [shape[rank - 2] * receptiveField, shape[rank - 1] * receptiveField];
        }
        // Neither branch matched: both fans stay undefined.
        return [undefined, undefined];
    }
    const root = Math.sqrt(arrayProd(shape));
    return [root, root];
}
|
/**
 * Initializer whose sampling spread is derived from `scale` and the fan
 * values that `computeFans` returns for the weight shape.
 */
class VarianceScaling extends Initializer {
    /**
     * Constructor of VarianceScaling.
     *
     * @param args Config object with optional `scale` (default 1.0),
     *   `mode` (default 'fanIn'), `distribution` (default 'normal') and
     *   `seed`. `mode` and `distribution` are validated with
     *   `checkFanMode` / `checkDistribution`.
     * @throws ValueError if `args.scale` is negative. A scale of exactly 0
     *   passes this check.
     */
    constructor(args) {
        super();
        if (args.scale < 0.0) {
            throw new ValueError(`scale must be a positive float. Got: ${args.scale}`);
        }
        this.scale = args.scale != null ? args.scale : 1.0;
        this.mode = args.mode != null ? args.mode : 'fanIn';
        checkFanMode(this.mode);
        this.distribution = args.distribution != null ? args.distribution : 'normal';
        checkDistribution(this.distribution);
        this.seed = args.seed;
    }
    /**
     * @param shape Shape of the tensor to create.
     * @param dtype For the 'normal' distribution, 'float32' or 'int32'
     *   (a falsy value is treated as 'float32'). Otherwise it is passed
     *   through to `randomUniform` unchanged.
     * @return A `truncatedNormal` sample when the distribution is 'normal',
     *   otherwise a `randomUniform` sample.
     * @throws NotImplementedError for an unsupported dtype in the 'normal'
     *   branch.
     */
    apply(shape, dtype) {
        const [fanIn, fanOut] = computeFans(shape);
        let divisor;
        switch (this.mode) {
            case 'fanIn':
                divisor = fanIn;
                break;
            case 'fanOut':
                divisor = fanOut;
                break;
            default:
                // Any other mode averages the two fans.
                divisor = (fanIn + fanOut) / 2;
        }
        // The divisor is clamped to at least 1.
        const scaled = this.scale / Math.max(1, divisor);
        if (this.distribution !== 'normal') {
            const bound = Math.sqrt(3 * scaled);
            return randomUniform(shape, -bound, bound, dtype, this.seed);
        }
        const stddev = Math.sqrt(scaled);
        const resolvedDtype = dtype || 'float32';
        if (resolvedDtype !== 'float32' && resolvedDtype !== 'int32') {
            throw new NotImplementedError(`${this.getClassName()} does not support dType ${resolvedDtype}.`);
        }
        return truncatedNormal(shape, 0, stddev, resolvedDtype, this.seed);
    }
    /** @return Config object with `scale`, `mode`, `distribution`, `seed`. */
    getConfig() {
        const { scale, mode, distribution, seed } = this;
        return { scale, mode, distribution, seed };
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
VarianceScaling.className = 'VarianceScaling';
serialization.registerClass(VarianceScaling);
export { VarianceScaling };
|
/**
 * VarianceScaling preset: scale 1.0, mode 'fanAvg', distribution 'uniform'.
 */
class GlorotUniform extends VarianceScaling {
    /**
     * Constructor of GlorotUniform.
     *
     * @param args Optional config; only `args.seed` is read. The seed is
     *   null when `args` is null or undefined.
     */
    constructor(args) {
        const seed = args == null ? null : args.seed;
        super({ scale: 1.0, mode: 'fanAvg', distribution: 'uniform', seed });
    }
    /**
     * In Python Keras, GlorotUniform is not a class but a helper method
     * that creates a VarianceScaling object, so 'VarianceScaling' is
     * reported as the class name to stay compatible with that.
     */
    getClassName() {
        return VarianceScaling.className;
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
GlorotUniform.className = 'GlorotUniform';
serialization.registerClass(GlorotUniform);
export { GlorotUniform };
|
/**
 * VarianceScaling preset: scale 1.0, mode 'fanAvg', distribution 'normal'.
 */
class GlorotNormal extends VarianceScaling {
    /**
     * Constructor of GlorotNormal.
     *
     * @param args Optional config; only `args.seed` is read. The seed is
     *   null when `args` is null or undefined.
     */
    constructor(args) {
        const seed = args == null ? null : args.seed;
        super({ scale: 1.0, mode: 'fanAvg', distribution: 'normal', seed });
    }
    /**
     * In Python Keras, GlorotNormal is not a class but a helper method
     * that creates a VarianceScaling object, so 'VarianceScaling' is
     * reported as the class name to stay compatible with that.
     */
    getClassName() {
        return VarianceScaling.className;
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
GlorotNormal.className = 'GlorotNormal';
serialization.registerClass(GlorotNormal);
export { GlorotNormal };
|
/**
 * VarianceScaling preset: scale 2.0, mode 'fanIn', distribution 'normal'.
 */
class HeNormal extends VarianceScaling {
    /**
     * @param args Optional config; only `args.seed` is read. The seed is
     *   null when `args` is null or undefined.
     */
    constructor(args) {
        const seed = args == null ? null : args.seed;
        super({ scale: 2.0, mode: 'fanIn', distribution: 'normal', seed });
    }
    /**
     * In Python Keras, HeNormal is not a class but a helper method that
     * creates a VarianceScaling object, so 'VarianceScaling' is reported
     * as the class name to stay compatible with that.
     */
    getClassName() {
        return VarianceScaling.className;
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
HeNormal.className = 'HeNormal';
serialization.registerClass(HeNormal);
export { HeNormal };
|
/**
 * VarianceScaling preset: scale 2.0, mode 'fanIn', distribution 'uniform'.
 */
class HeUniform extends VarianceScaling {
    /**
     * @param args Optional config; only `args.seed` is read. The seed is
     *   null when `args` is null or undefined.
     */
    constructor(args) {
        const seed = args == null ? null : args.seed;
        super({ scale: 2.0, mode: 'fanIn', distribution: 'uniform', seed });
    }
    /**
     * In Python Keras, HeUniform is not a class but a helper method that
     * creates a VarianceScaling object, so 'VarianceScaling' is reported
     * as the class name to stay compatible with that.
     */
    getClassName() {
        return VarianceScaling.className;
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
HeUniform.className = 'HeUniform';
serialization.registerClass(HeUniform);
export { HeUniform };
|
/**
 * VarianceScaling preset: scale 1.0, mode 'fanIn', distribution 'normal'.
 */
class LeCunNormal extends VarianceScaling {
    /**
     * @param args Optional config; only `args.seed` is read. The seed is
     *   null when `args` is null or undefined.
     */
    constructor(args) {
        const seed = args == null ? null : args.seed;
        super({ scale: 1.0, mode: 'fanIn', distribution: 'normal', seed });
    }
    /**
     * In Python Keras, LeCunNormal is not a class but a helper method that
     * creates a VarianceScaling object, so 'VarianceScaling' is reported
     * as the class name to stay compatible with that.
     */
    getClassName() {
        return VarianceScaling.className;
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
LeCunNormal.className = 'LeCunNormal';
serialization.registerClass(LeCunNormal);
export { LeCunNormal };
|
/**
 * VarianceScaling preset: scale 1.0, mode 'fanIn', distribution 'uniform'.
 */
class LeCunUniform extends VarianceScaling {
    /**
     * @param args Optional config; only `args.seed` is read. The seed is
     *   null when `args` is null or undefined.
     */
    constructor(args) {
        const seed = args == null ? null : args.seed;
        super({ scale: 1.0, mode: 'fanIn', distribution: 'uniform', seed });
    }
    /**
     * In Python Keras, LeCunUniform is not a class but a helper method
     * that creates a VarianceScaling object, so 'VarianceScaling' is
     * reported as the class name to stay compatible with that.
     */
    getClassName() {
        return VarianceScaling.className;
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
LeCunUniform.className = 'LeCunUniform';
serialization.registerClass(LeCunUniform);
export { LeCunUniform };
|
/**
 * Initializer built from the Q factor of a QR factorization of a random
 * normal matrix, scaled by `gain`.
 */
class Orthogonal extends Initializer {
    /**
     * @param args Config object with optional `gain` (default 1) and
     *   `seed`.
     */
    constructor(args) {
        super();
        this.DEFAULT_GAIN = 1;
        this.ELEMENTS_WARN_SLOW = 2000;
        this.gain = args.gain != null ? args.gain : this.DEFAULT_GAIN;
        this.seed = args.seed;
    }
    /**
     * @param shape Target shape; must have at least 2 dimensions.
     * @param dtype 'int32', 'float32' or undefined.
     * @return The sign-adjusted Q factor, reshaped to `shape` and
     *   multiplied by the scalar `gain`.
     * @throws NotImplementedError if `shape` has fewer than 2 dimensions.
     * @throws TypeError for any other dtype.
     */
    apply(shape, dtype) {
        return tidy(() => {
            if (shape.length < 2) {
                throw new NotImplementedError('Shape must be at least 2D.');
            }
            const dtypeSupported = dtype === 'int32' || dtype === 'float32' || dtype === undefined;
            if (!dtypeSupported) {
                throw new TypeError(`Unsupported data type ${dtype}.`);
            }
            // Flatten every dimension except the last into the row count, so
            // the last dimension keeps its original size (works for conv2d).
            const numRows = util.sizeFromShape(shape.slice(0, -1));
            const numCols = shape[shape.length - 1];
            const numElements = numRows * numCols;
            if (numElements > this.ELEMENTS_WARN_SLOW) {
                console.warn(`Orthogonal initializer is being called on a matrix with more ` +
                    `than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
                    `Slowness may result.`);
            }
            // The random matrix is laid out with the larger dimension first.
            const longSide = Math.max(numCols, numRows);
            const shortSide = Math.min(numCols, numRows);
            const randNormalMat = K.randomNormal([longSide, shortSide], 0, 1, dtype, this.seed);
            // QR factorization of the random matrix.
            const [qFactor, rFactor] = linalg.qr(randNormalMat, false);
            // Take every (shortSide + 1)-th entry of the flattened R, i.e. its
            // main diagonal, and multiply Q by the signs ("make Q uniform").
            const diag = rFactor.flatten().stridedSlice([0], [shortSide * shortSide], [shortSide + 1]);
            let oriented = mul(qFactor, diag.sign());
            if (numRows < numCols) {
                oriented = oriented.transpose();
            }
            return mul(scalar(this.gain), oriented.reshape(shape));
        });
    }
    /** @return Config object with `gain` and `seed`. */
    getConfig() {
        const { gain, seed } = this;
        return { gain, seed };
    }
}
// The class name must be assigned before the class is registered.
/** @nocollapse */
Orthogonal.className = 'Orthogonal';
serialization.registerClass(Orthogonal);
export { Orthogonal };
|
// Lookup table from the JavaScript-style (camelCase) initializer identifiers
// to the corresponding registry symbols (class names).
export const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
    constant: 'Constant',
    glorotNormal: 'GlorotNormal',
    glorotUniform: 'GlorotUniform',
    heNormal: 'HeNormal',
    heUniform: 'HeUniform',
    identity: 'Identity',
    leCunNormal: 'LeCunNormal',
    leCunUniform: 'LeCunUniform',
    ones: 'Ones',
    orthogonal: 'Orthogonal',
    randomNormal: 'RandomNormal',
    randomUniform: 'RandomUniform',
    truncatedNormal: 'TruncatedNormal',
    varianceScaling: 'VarianceScaling',
    zeros: 'Zeros',
};
|
/**
 * Builds an initializer from a serialized config.
 *
 * Delegates to `deserializeKerasObject`, passing the class-name map of
 * the serialization registry and the label 'initializer'.
 *
 * @param config Serialized config to deserialize.
 * @param customObjects Extra objects forwarded to `deserializeKerasObject`;
 *   defaults to an empty object.
 * @return Whatever `deserializeKerasObject` returns.
 */
function deserializeInitializer(config, customObjects = {}) {
    const classNameMap = serialization.SerializationMap.getMap().classNameMap;
    return deserializeKerasObject(config, classNameMap, customObjects, 'initializer');
}
|
/**
 * Serializes an initializer by delegating to `serializeKerasObject`.
 *
 * @param initializer Initializer to serialize.
 * @return Whatever `serializeKerasObject` returns.
 */
export function serializeInitializer(initializer) {
    const serialized = serializeKerasObject(initializer);
    return serialized;
}
|
/**
 * Resolves an initializer from a string identifier, an existing
 * `Initializer` instance, or a serialized config.
 *
 * - String: mapped through `INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP`
 *   when it is a key there, otherwise used as the class name as-is.
 * - `Initializer` instance: returned unchanged.
 * - Anything else: passed to `deserializeInitializer`.
 *
 * @param identifier String name, `Initializer` instance or config.
 * @return The resolved initializer.
 */
export function getInitializer(identifier) {
    if (typeof identifier !== 'string') {
        return identifier instanceof Initializer ?
            identifier :
            deserializeInitializer(identifier);
    }
    const className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
        INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
        identifier;
    // The six 'helper' classes below all report 'VarianceScaling' as their
    // class name, so they are constructed directly instead of going through
    // the deserializeInitializer pathway.
    switch (className) {
        case 'GlorotNormal':
            return new GlorotNormal();
        case 'GlorotUniform':
            return new GlorotUniform();
        case 'HeNormal':
            return new HeNormal();
        case 'HeUniform':
            return new HeUniform();
        case 'LeCunNormal':
            return new LeCunNormal();
        case 'LeCunUniform':
            return new LeCunUniform();
        default:
            // Every other name is deserialized with an empty config.
            return deserializeInitializer({ className, config: {} });
    }
}
|
//# sourceMappingURL=data:application/json;base64,
|