/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
|
/// <amd-module name="@tensorflow/tfjs-layers/dist/base_callbacks" />
|
import { Tensor } from '@tensorflow/tfjs-core';
|
import { Container } from './engine/container';
|
import { Logs, UnresolvedLogs } from './logs';
|
/**
 * Verbosity logging level when fitting a model.
 *
 * Ambient numeric enum: `SILENT` is 0 and `VERBOSE` is 1.
 */
export declare enum ModelLoggingVerbosity {
  SILENT = 0,
  VERBOSE = 1
}
|
/** How often to yield to the main thread when training (in ms). */
export declare const DEFAULT_YIELD_EVERY_MS = 125;
|
/**
 * Dictionary of training parameters handed to callbacks via `setParams`.
 *
 * Keys are arbitrary strings; each value is a number, string or boolean, or
 * a homogeneous array of one of those primitive types.
 */
export type Params = {
  [key: string]: number | string | boolean | number[] | string[] | boolean[];
};
|
/**
 * Values accepted for the `yieldEvery` argument of the `CustomCallback`
 * constructor and of `standardizeCallbacks`: one of the named modes below,
 * or a number.
 *
 * NOTE(review): a numeric value is presumably an interval in milliseconds
 * (compare `DEFAULT_YIELD_EVERY_MS`) - confirm against the implementation.
 */
export type YieldEveryOptions =
  | 'auto'
  | 'batch'
  | 'epoch'
  | 'never'
  | number;
|
/**
 * Abstract base class used to build new callbacks.
 *
 * The `logs` dictionary that callback methods take as argument will contain
 * keys for quantities relevant to the current batch or epoch.
 *
 * Currently, the `.fit()` method of the `Sequential` model class
 * will include the following quantities in the `logs` that
 * it passes to its callbacks:
 *
 *   - onEpochEnd: Logs include `acc` and `loss`, and optionally include
 *     `valLoss` (if validation is enabled in `fit`), and `valAcc` (if
 *     validation and accuracy monitoring are enabled).
 *   - onBatchBegin: Logs include `size`, the number of samples in the current
 *     batch.
 *   - onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy
 *     monitoring is enabled).
 */
export declare abstract class BaseCallback {
  /** Validation data, as a single tensor or an array of tensors. */
  validationData: Tensor | Tensor[];
  /**
   * Training parameters (e.g. verbosity, batch size, number of epochs...).
   */
  params: Params;
  /**
   * Supplies the training parameters for this callback.
   * @param params Dictionary of training parameters.
   */
  setParams(params: Params): void;
  /**
   * Hook for the start of an epoch.
   * @param epoch Index of epoch.
   * @param logs Dictionary of logs.
   */
  onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the end of an epoch.
   * @param epoch Index of epoch.
   * @param logs Dictionary of logs.
   */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook invoked before a batch is processed.
   * @param batch Index of batch.
   * @param logs Dictionary of logs.
   */
  onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook invoked after a batch has been processed.
   * @param batch Index of batch.
   * @param logs Dictionary of logs.
   */
  onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the beginning of training.
   * @param logs Dictionary of logs.
   */
  onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the end of training.
   * @param logs Dictionary of logs.
   */
  onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
  /**
   * Supplies the model this callback is attached to.
   * @param model The model, typed as a `Container`.
   */
  setModel(model: Container): void;
}
|
/**
 * Container abstracting a list of callbacks.
 */
export declare class CallbackList {
  /** The callbacks held by this list. */
  callbacks: BaseCallback[];
  /**
   * Queue length for keeping running statistics over callback execution time.
   */
  queueLength: number;
  /**
   * Constructor of CallbackList.
   * @param callbacks Array of `Callback` instances.
   * @param queueLength Queue length for keeping running statistics over
   *   callback execution time.
   */
  constructor(callbacks?: BaseCallback[], queueLength?: number);
  /**
   * Adds a callback to this list.
   * @param callback The callback to add.
   */
  append(callback: BaseCallback): void;
  /**
   * Supplies training parameters.
   *
   * NOTE(review): presumably forwarded to every callback in `callbacks`;
   * the implementation is not visible in this declaration - confirm.
   * @param params Dictionary of training parameters.
   */
  setParams(params: Params): void;
  /**
   * Supplies the model being trained.
   *
   * NOTE(review): presumably forwarded to every callback in `callbacks`;
   * the implementation is not visible in this declaration - confirm.
   * @param model The model, typed as a `Container`.
   */
  setModel(model: Container): void;
  /**
   * Called at the start of an epoch.
   * @param epoch Index of epoch.
   * @param logs Dictionary of logs.
   */
  onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called at the end of an epoch.
   * @param epoch Index of epoch.
   * @param logs Dictionary of logs.
   */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called right before processing a batch.
   * @param batch Index of batch within the current epoch.
   * @param logs Dictionary of logs.
   */
  onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called at the end of a batch.
   * @param batch Index of batch within the current epoch.
   * @param logs Dictionary of logs.
   */
  onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called at the beginning of training.
   * @param logs Dictionary of logs.
   */
  onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
  /**
   * Called at the end of training.
   * @param logs Dictionary of logs.
   */
  onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
}
|
/**
 * Callback that accumulates epoch averages of metrics.
 *
 * This callback is automatically applied to every LayersModel.
 */
export declare class BaseLogger extends BaseCallback {
  // Private accumulator state; the types are stripped from the declaration
  // output, so nothing further can be said about them from here.
  private seen;
  private totals;
  constructor();
  /**
   * Hook for the start of an epoch. Unlike the base class signature, this
   * override takes no `logs` argument.
   * @param epoch Index of epoch.
   */
  onEpochBegin(epoch: number): Promise<void>;
  /**
   * Hook for the end of a batch.
   * @param batch Index of batch.
   * @param logs Dictionary of logs.
   */
  onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the end of an epoch.
   * @param epoch Index of epoch.
   * @param logs Dictionary of logs.
   */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
}
|
/**
 * Callback that records events into a `History` object. This callback is
 * automatically applied to every TF.js Layers model. The `History` object
 * gets returned by the `fit` method of models.
 */
export declare class History extends BaseCallback {
  /**
   * Epoch indices.
   * NOTE(review): presumably one entry per recorded epoch - confirm.
   */
  epoch: number[];
  /**
   * Recorded values, keyed by name. Each entry is an array whose elements
   * are numbers or tensors.
   */
  history: {
    [key: string]: Array<number | Tensor>;
  };
  /**
   * Hook for the beginning of training.
   * @param logs Dictionary of logs.
   */
  onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the end of an epoch.
   * @param epoch Index of epoch.
   * @param logs Dictionary of logs.
   */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Await the values of all losses and metrics.
   */
  syncData(): Promise<void>;
}
|
/**
 * Optional handlers and hooks accepted by the `CustomCallback` constructor.
 *
 * Every field is optional. Each handler may return either nothing or a
 * promise.
 */
export interface CustomCallbackArgs {
  /** Handler for the beginning of training. */
  onTrainBegin?: (logs?: Logs) => void | Promise<void>;
  /** Handler for the end of training. */
  onTrainEnd?: (logs?: Logs) => void | Promise<void>;
  /** Handler for the start of an epoch. */
  onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;
  /** Handler for the end of an epoch. */
  onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;
  /** Handler invoked before a batch is processed. */
  onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;
  /** Handler invoked after a batch has been processed. */
  onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;
  /**
   * Handler that receives the epoch index, the batch index and the logs.
   * NOTE(review): presumably invoked whenever training yields to the main
   * thread - confirm against the implementation.
   */
  onYield?: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
  /**
   * NOTE(review): presumably an injectable clock function - confirm. The
   * loose `Function` type is kept as declared to leave the public interface
   * unchanged.
   */
  nowFunc?: Function;
  /**
   * NOTE(review): presumably an injectable "wait for next frame" function -
   * confirm. The loose `Function` type is kept as declared.
   */
  nextFrameFunc?: Function;
}
|
/**
 * Custom callback for training.
 *
 * Constructed from a `CustomCallbackArgs` object and an optional
 * `YieldEveryOptions` value.
 */
export declare class CustomCallback extends BaseCallback {
  // Handler fields. Their signatures mirror the `on*` fields of
  // `CustomCallbackArgs`; NOTE(review): presumably populated from those
  // fields by the constructor - confirm against the implementation.
  protected readonly trainBegin: (logs?: Logs) => void | Promise<void>;
  protected readonly trainEnd: (logs?: Logs) => void | Promise<void>;
  protected readonly epochBegin: (epoch: number, logs?: Logs) => void | Promise<void>;
  protected readonly epochEnd: (epoch: number, logs?: Logs) => void | Promise<void>;
  protected readonly batchBegin: (batch: number, logs?: Logs) => void | Promise<void>;
  protected readonly batchEnd: (batch: number, logs?: Logs) => void | Promise<void>;
  protected readonly yield: (epoch: number, batch: number, logs: Logs) => void | Promise<void>;
  // Private state; the types are stripped from the declaration output.
  private yieldEvery;
  private currentEpoch;
  /** Same name and type as `CustomCallbackArgs.nowFunc`. */
  nowFunc: Function;
  /** Same name and type as `CustomCallbackArgs.nextFrameFunc`. */
  nextFrameFunc: Function;
  /**
   * @param args Handlers and hooks for this callback.
   * @param yieldEvery Optional yield setting.
   */
  constructor(args: CustomCallbackArgs, yieldEvery?: YieldEveryOptions);
  /**
   * NOTE(review): the implementation is not visible in this declaration;
   * presumably waits (yields) when the `yieldEvery` setting calls for it -
   * confirm.
   * @param epoch Index of epoch.
   * @param batch Index of batch.
   * @param logs Dictionary of logs.
   */
  maybeWait(epoch: number, batch: number, logs: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the start of an epoch.
   * @param epoch Index of epoch.
   * @param logs Dictionary of logs.
   */
  onEpochBegin(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the end of an epoch.
   * @param epoch Index of epoch.
   * @param logs Dictionary of logs.
   */
  onEpochEnd(epoch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook invoked before a batch is processed.
   * @param batch Index of batch.
   * @param logs Dictionary of logs.
   */
  onBatchBegin(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook invoked after a batch has been processed.
   * @param batch Index of batch.
   * @param logs Dictionary of logs.
   */
  onBatchEnd(batch: number, logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the beginning of training.
   * @param logs Dictionary of logs.
   */
  onTrainBegin(logs?: UnresolvedLogs): Promise<void>;
  /**
   * Hook for the end of training.
   * @param logs Dictionary of logs.
   */
  onTrainEnd(logs?: UnresolvedLogs): Promise<void>;
}
|
/**
 * Standardize callbacks or configurations of them to an Array of callbacks.
 *
 * @param callbacks A single callback, an array of callbacks, a single
 *   `CustomCallbackArgs` configuration, or an array of such configurations.
 * @param yieldEvery Yield setting.
 *   NOTE(review): presumably applied to callbacks created from
 *   `CustomCallbackArgs` - confirm against the implementation.
 * @returns An array of `BaseCallback` instances.
 */
export declare function standardizeCallbacks(callbacks: BaseCallback | BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[], yieldEvery: YieldEveryOptions): BaseCallback[];
|
/**
 * A no-argument constructor that produces a `BaseCallback`.
 *
 * This is the constructor type accepted by
 * `CallbackConstructorRegistry.registerCallbackConstructor`.
 */
export declare type BaseCallbackConstructor = {
  new (): BaseCallback;
};
|
/**
 * A global registry for callback constructors to be used during
 * LayersModel.fit().
 */
export declare class CallbackConstructorRegistry {
  // Static store of registered constructors; the type is stripped from the
  // declaration output.
  private static constructors;
  /**
   * Blocks public access to constructor.
   */
  private constructor();
  /**
   * Register a tf.LayersModel.fit() callback constructor.
   *
   * The registered callback constructor will be used to instantiate
   * callbacks for every tf.LayersModel.fit() call afterwards.
   *
   * @param verbosityLevel Level of verbosity at which the `callbackConstructor`
   *   is to be registered.
   * @param callbackConstructor A no-arg constructor for `tf.Callback`.
   * @throws Error, if the same callbackConstructor has been registered before,
   *   either at the same or a different `verbosityLevel`.
   */
  static registerCallbackConstructor(verbosityLevel: number, callbackConstructor: BaseCallbackConstructor): void;
  // Private static helper; its signature is stripped from the declaration
  // output.
  private static checkForDuplicate;
  /**
   * Clear all registered callback constructors.
   */
  protected static clear(): void;
  /**
   * Create callbacks using the registered callback constructors.
   *
   * Given `verbosityLevel`, all constructors registered at that level or above
   * will be called and the instantiated callbacks will be used.
   *
   * @param verbosityLevel Level of verbosity.
   * @returns The instantiated callbacks.
   */
  static createCallbacks(verbosityLevel: number): BaseCallback[];
}
|
/**
 * Produces the `CallbackList` and `History` for a training run from the
 * supplied callbacks and training settings.
 *
 * NOTE(review): the function body is not part of this declaration file; how
 * each argument is used should be confirmed against the implementation.
 *
 * @param callbacks Callbacks to configure.
 * @param verbose Logging verbosity.
 * @param epochs Number of epochs.
 * @param initialEpoch Initial epoch index.
 * @param numTrainSamples Number of training samples.
 * @param stepsPerEpoch Number of steps per epoch.
 * @param batchSize Batch size.
 * @param doValidation Whether validation is performed.
 * @param callbackMetrics Names of metrics for the callbacks.
 * @returns An object holding the `callbackList` and the `history`.
 */
export declare function configureCallbacks(callbacks: BaseCallback[], verbose: ModelLoggingVerbosity, epochs: number, initialEpoch: number, numTrainSamples: number, stepsPerEpoch: number, batchSize: number, doValidation: boolean, callbackMetrics: string[]): {
  callbackList: CallbackList;
  history: History;
};
|