/**
|
* @license
|
* Copyright 2018 Google LLC
|
*
|
* Use of this source code is governed by an MIT-style
|
* license that can be found in the LICENSE file or at
|
* https://opensource.org/licenses/MIT.
|
* =============================================================================
|
*/
|
/// <amd-module name="@tensorflow/tfjs-layers/dist/constraints" />
|
import { serialization, Tensor } from '@tensorflow/tfjs-core';
|
/**
|
* Base class for functions that impose constraints on weight values
|
*
|
* Subclasses implement `apply`, which takes a weight tensor and returns a
|
* tensor. The class extends `serialization.Serializable`, and exposes its
|
* settings through `getConfig()`.
|
*
|
* @doc {
|
* heading: 'Constraints',
|
* subheading: 'Classes',
|
* namespace: 'constraints'
|
* }
|
*/
|
export declare abstract class Constraint extends serialization.Serializable {
|
/**
|
* Returns the constrained counterpart of the weight tensor `w`.
|
*/
abstract apply(w: Tensor): Tensor;
|
/** Returns this constraint's configuration as a `serialization.ConfigDict`. */
getConfig(): serialization.ConfigDict;
|
}
|
/** Arguments accepted by the `MaxNorm` constructor. All fields are optional. */
export interface MaxNormArgs {
|
/**
|
* Maximum norm for incoming weights
|
*
|
* NOTE(review): `MaxNorm` declares a `defaultMaxValue` field, presumably
|
* used when this is omitted — its value is not visible here.
|
*/
|
maxValue?: number;
|
/**
|
* Axis along which to calculate norms.
|
*
|
* For instance, in a `Dense` layer the weight matrix
|
* has shape `[inputDim, outputDim]`,
|
* set `axis` to `0` to constrain each weight vector
|
* of length `[inputDim,]`.
|
* In a `Conv2D` layer with `dataFormat="channels_last"`,
|
* the weight tensor has shape
|
* `[rows, cols, inputDepth, outputDepth]`,
|
* set `axis` to `[0, 1, 2]`
|
* to constrain the weights of each filter tensor of size
|
* `[rows, cols, inputDepth]`.
|
*
|
* NOTE(review): `axis` is declared as a single `number`, so an array such
|
* as `[0, 1, 2]` does not type-check against this declaration. Confirm
|
* whether the implementation accepts arrays before widening the type.
|
*/
|
axis?: number;
|
}
|
/**
|
* Constraint configured by `MaxNormArgs`: a maximum norm (`maxValue`) for
|
* incoming weights, with norms calculated along `axis`.
|
*/
export declare class MaxNorm extends Constraint {
|
/** @nocollapse */
|
static readonly className = "MaxNorm";
|
private maxValue;
|
private axis;
|
/** Fallback for `maxValue`, presumably used when `args.maxValue` is omitted. */
private readonly defaultMaxValue;
|
/** Fallback for `axis`, presumably used when `args.axis` is omitted. */
private readonly defaultAxis;
|
/** Creates the constraint from the (all-optional) fields of `args`. */
constructor(args: MaxNormArgs);
|
/** Applies the max-norm constraint to the weight tensor `w`. */
apply(w: Tensor): Tensor;
|
/** Returns the configuration; presumably includes `maxValue` and `axis` — TODO confirm. */
getConfig(): serialization.ConfigDict;
|
}
|
/** Arguments accepted by the `UnitNorm` constructor. All fields are optional. */
export interface UnitNormArgs {
|
/**
|
* Axis along which to calculate norms.
|
*
|
* For instance, in a `Dense` layer the weight matrix
|
* has shape `[inputDim, outputDim]`,
|
* set `axis` to `0` to constrain each weight vector
|
* of length `[inputDim,]`.
|
* In a `Conv2D` layer with `dataFormat="channels_last"`,
|
* the weight tensor has shape
|
* `[rows, cols, inputDepth, outputDepth]`,
|
* set `axis` to `[0, 1, 2]`
|
* to constrain the weights of each filter tensor of size
|
* `[rows, cols, inputDepth]`.
|
*
|
* NOTE(review): `axis` is declared as a single `number`, so an array such
|
* as `[0, 1, 2]` does not type-check against this declaration. Confirm
|
* whether the implementation accepts arrays before widening the type.
|
*/
|
axis?: number;
|
}
|
/**
|
* Constraint configured by `UnitNormArgs`, with norms calculated along
|
* `axis`. Judging by the name, weights are presumably rescaled to unit
|
* norm — confirm against the implementation.
|
*/
export declare class UnitNorm extends Constraint {
|
/** @nocollapse */
|
static readonly className = "UnitNorm";
|
private axis;
|
/** Fallback for `axis`, presumably used when `args.axis` is omitted. */
private readonly defaultAxis;
|
/** Creates the constraint from the (optional) `axis` in `args`. */
constructor(args: UnitNormArgs);
|
/** Applies the unit-norm constraint to the weight tensor `w`. */
apply(w: Tensor): Tensor;
|
/** Returns the configuration; presumably includes `axis` — TODO confirm. */
getConfig(): serialization.ConfigDict;
|
}
|
/**
|
* Constraint with no configuration of its own. It declares no constructor
|
* and does not override `getConfig`, so it relies on `Constraint`'s
|
* members for both. Judging by the name, `apply` presumably forces weights
|
* to be non-negative — confirm against the implementation.
|
*/
export declare class NonNeg extends Constraint {
|
/** @nocollapse */
|
static readonly className = "NonNeg";
|
/** Applies the non-negativity constraint to the weight tensor `w`. */
apply(w: Tensor): Tensor;
|
}
|
/** Arguments accepted by the `MinMaxNorm` constructor. All fields are optional. */
export interface MinMaxNormArgs {
|
/**
|
* Minimum norm for incoming weights
|
*/
|
minValue?: number;
|
/**
|
* Maximum norm for incoming weights
|
*/
|
maxValue?: number;
|
/**
|
* Axis along which to calculate norms.
|
*
|
* For instance, in a `Dense` layer the weight matrix
|
* has shape `[inputDim, outputDim]`,
|
* set `axis` to `0` to constrain each weight vector
|
* of length `[inputDim,]`.
|
* In a `Conv2D` layer with `dataFormat="channels_last"`,
|
* the weight tensor has shape
|
* `[rows, cols, inputDepth, outputDepth]`,
|
* set `axis` to `[0, 1, 2]`
|
* to constrain the weights of each filter tensor of size
|
* `[rows, cols, inputDepth]`.
|
*
|
* NOTE(review): `axis` is declared as a single `number`, so an array such
|
* as `[0, 1, 2]` does not type-check against this declaration. Confirm
|
* whether the implementation accepts arrays before widening the type.
|
*/
|
axis?: number;
|
/**
|
* Rate for enforcing the constraint: weights will be rescaled to yield:
|
* `(1 - rate) * norm + rate * norm.clip(minValue, maxValue)`.
|
* Effectively, this means that rate=1.0 stands for strict
|
* enforcement of the constraint, while rate<1.0 means that
|
* weights will be rescaled at each step to slowly move
|
* towards a value inside the desired interval.
|
*/
|
rate?: number;
|
}
|
/**
|
* Constraint configured by `MinMaxNormArgs`: keeps the norm of incoming
|
* weights between `minValue` and `maxValue` (norms calculated along `axis`),
|
* with `rate` controlling how strictly the interval is enforced.
|
*/
export declare class MinMaxNorm extends Constraint {
|
/** @nocollapse */
|
static readonly className = "MinMaxNorm";
|
private minValue;
|
private maxValue;
|
private rate;
|
private axis;
|
/** Fallback for `minValue`, presumably used when `args.minValue` is omitted. */
private readonly defaultMinValue;
|
/** Fallback for `maxValue`, presumably used when `args.maxValue` is omitted. */
private readonly defaultMaxValue;
|
/** Fallback for `rate`, presumably used when `args.rate` is omitted. */
private readonly defaultRate;
|
/** Fallback for `axis`, presumably used when `args.axis` is omitted. */
private readonly defaultAxis;
|
/** Creates the constraint from the (all-optional) fields of `args`. */
constructor(args: MinMaxNormArgs);
|
/** Applies the min/max-norm constraint to the weight tensor `w`. */
apply(w: Tensor): Tensor;
|
/** Returns the configuration; presumably includes minValue, maxValue, rate and axis — TODO confirm. */
getConfig(): serialization.ConfigDict;
|
}
|
/*
|
* Identifier for a built-in constraint ('maxNorm', 'minMaxNorm', 'nonNeg',
|
* 'unitNorm') or any other string.
|
* NOTE(review): because the union includes `string`, TypeScript reduces the
|
* whole type to plain `string`; the literal members do not restrict what
|
* callers may pass and give no compile-time typo checking.
|
*/
/** @docinline */
|
export type ConstraintIdentifier = 'maxNorm' | 'minMaxNorm' | 'nonNeg' | 'unitNorm' | string;
|
/**
|
* Maps constraint identifiers to strings. Presumably the values are the
|
* registered class names (e.g. a `className` such as "MaxNorm") used when
|
* (de)serializing — confirm against the implementation.
|
* Because `ConstraintIdentifier` reduces to `string`, this mapped type is
|
* effectively a `string`-keyed index signature.
|
*/
export declare const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP: {
|
[identifier in ConstraintIdentifier]: string;
|
};
|
/**
|
* Serializes `constraint` into a `serialization.ConfigDictValue`.
|
*/
export declare function serializeConstraint(constraint: Constraint): serialization.ConfigDictValue;
|
/**
|
* Builds a `Constraint` from the serialized `config`.
|
* `customObjects` is optional; presumably it supplies user-defined
|
* constraint classes for lookup — confirm against the implementation.
|
*/
export declare function deserializeConstraint(config: serialization.ConfigDict, customObjects?: serialization.ConfigDict): Constraint;
|
/**
|
* Resolves `identifier` to a `Constraint`. Accepts a `ConstraintIdentifier`
|
* string, a `serialization.ConfigDict`, or an existing `Constraint`.
|
* NOTE(review): how each case is handled (e.g. whether an existing
|
* `Constraint` is returned unchanged) is not visible in this declaration.
|
*/
export declare function getConstraint(identifier: ConstraintIdentifier | serialization.ConfigDict | Constraint): Constraint;
|