/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/engine/training_tensors" />
import { Tensor, Tensor1D } from '@tensorflow/tfjs-core';
import { BaseCallback, CustomCallbackArgs, ModelLoggingVerbosity, YieldEveryOptions } from '../base_callbacks';
import { ClassWeight, ClassWeightMap } from './training_utils';
/**
 * Configuration for training a model on data supplied as `tf.Tensor`s.
 *
 * Every field is optional.
 */
export interface ModelFitArgs {
  /**
   * Number of samples per gradient update. If unspecified, it
   * will default to 32.
   */
  batchSize?: number;

  /**
   * Integer number of times to iterate over the training data arrays.
   */
  epochs?: number;

  /**
   * Verbosity level.
   *
   * Expected to be 0, 1, or 2. Default: 1.
   *
   * 0 - No printed message during fit() call.
   * 1 - In Node.js (tfjs-node), prints the progress bar, together with
   *     real-time updates of loss and metric values and training speed.
   *     In the browser: no action. This is the default.
   * 2 - Not implemented yet.
   */
  verbose?: ModelLoggingVerbosity | 2;

  /**
   * List of callbacks to be called during training.
   * Can have one or more of the following callbacks:
   *   - `onTrainBegin(logs)`: called when training starts.
   *   - `onTrainEnd(logs)`: called when training ends.
   *   - `onEpochBegin(epoch, logs)`: called at the start of every epoch.
   *   - `onEpochEnd(epoch, logs)`: called at the end of every epoch.
   *   - `onBatchBegin(batch, logs)`: called at the start of every batch.
   *   - `onBatchEnd(batch, logs)`: called at the end of every batch.
   *   - `onYield(epoch, batch, logs)`: called every `yieldEvery` milliseconds
   *      with the current epoch, batch and logs. The logs are the same
   *      as in `onBatchEnd()`. Note that `onYield` can skip batches or
   *      epochs. See also docs for `yieldEvery` below.
   */
  callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[];

  /**
   * Float between 0 and 1: fraction of the training data
   * to be used as validation data. The model will set apart this fraction of
   * the training data, will not train on it, and will evaluate the loss and
   * any model metrics on this data at the end of each epoch.
   * The validation data is selected from the last samples in the `x` and `y`
   * data provided, before shuffling.
   */
  validationSplit?: number;

  /**
   * Data on which to evaluate the loss and any model
   * metrics at the end of each epoch. The model will not be trained on this
   * data. This could be a tuple [xVal, yVal] or a tuple [xVal, yVal,
   * valSampleWeights].
   * `validationData` will override `validationSplit`.
   */
  validationData?: [
    Tensor | Tensor[],
    Tensor | Tensor[]
  ] | [Tensor | Tensor[], Tensor | Tensor[], Tensor | Tensor[]];

  /**
   * Whether to shuffle the training data before each epoch. Has
   * no effect when `stepsPerEpoch` is not `null`.
   */
  shuffle?: boolean;

  /**
   * Optional object mapping class indices (integers) to
   * a weight (float) to apply to the model's loss for the samples from this
   * class during training. This can be useful to tell the model to "pay more
   * attention" to samples from an under-represented class.
   *
   * If the model has multiple outputs, a class weight can be specified for
   * each of the outputs by setting this field to an array of weight objects
   * or to an object that maps model output names (e.g.,
   * `model.outputNames[0]`) to weight objects.
   */
  classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap;

  /**
   * Optional array of the same length as x, containing
   * weights to apply to the model's loss for each sample. In the case of
   * temporal data, you can pass a 2D array with shape (samples,
   * sequenceLength), to apply a different weight to every timestep of every
   * sample. In this case you should make sure to specify
   * sampleWeightMode="temporal" in compile().
   */
  sampleWeight?: Tensor;

  /**
   * Epoch at which to start training (useful for resuming a previous training
   * run). When this is used, `epochs` is the index of the "final epoch".
   * The model is not trained for a number of iterations given by `epochs`,
   * but merely until the epoch of index `epochs` is reached.
   */
  initialEpoch?: number;

  /**
   * Total number of steps (batches of samples) before
   * declaring one epoch finished and starting the next epoch. When training
   * with Input Tensors such as TensorFlow data tensors, the default `null` is
   * equal to the number of unique samples in your dataset divided by the
   * batch size, or 1 if that cannot be determined.
   */
  stepsPerEpoch?: number;

  /**
   * Only relevant if `stepsPerEpoch` is specified. Total number of steps
   * (batches of samples) to validate before stopping.
   */
  validationSteps?: number;

  /**
   * Configures the frequency of yielding the main thread to other tasks.
   *
   * In the browser environment, yielding the main thread can improve the
   * responsiveness of the page during training. In the Node.js environment,
   * it can ensure tasks queued in the event loop can be handled in a timely
   * manner.
   *
   * The value can be one of the following:
   *   - `'auto'`: The yielding happens at a certain frame rate (currently set
   *               at 125ms). This is the default.
   *   - `'batch'`: yield every batch.
   *   - `'epoch'`: yield every epoch.
   *   - any `number`: yield every `number` milliseconds.
   *   - `'never'`: never yield. (yielding can still happen through `await
   *      nextFrame()` calls in custom callbacks.)
   */
  yieldEvery?: YieldEveryOptions;
}
/**
 * Check a `batchSize` value.
 *
 * NOTE(review): only the signature is visible in this declaration file.
 * Presumably this throws when `batchSize` is not a positive integer; confirm
 * against the implementation.
 *
 * @param batchSize The batch size to check.
 */
export declare function checkBatchSize(batchSize: number): void;
/**
 * Slice a Tensor or an Array of Tensors, by start and stop indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 *   function and `sliceArraysByIndices()` together.
 *
 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
 * @param start The starting index (inclusive).
 * @param stop The stopping index (exclusive).
 * @returns The result of the slicing. If `arrays` is an `Array` of
 *   `tf.Tensor`s, the slicing will be applied to all elements of the `Array`
 *   in the same way.
 */
export declare function sliceArrays(arrays: Tensor | Tensor[], start: number, stop: number): Tensor | Tensor[];
/**
 * Slice a Tensor or an Array of Tensors, by random-order indices.
 *
 * Porting Note: The `_slice_arrays` function in PyKeras is covered by this
 *   function and `sliceArrays()` together.
 *
 * @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
 *   If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the
 *   same fashion.
 * @param indices The indices to use for slicing along the first (batch)
 *   dimension.
 * @returns Result(s) of the slicing.
 */
export declare function sliceArraysByIndices(arrays: Tensor | Tensor[], indices: Tensor1D): Tensor | Tensor[];
/**
 * Returns a list of batch indices (tuples of indices).
 *
 * @param size Integer, total size of the data to slice into batches.
 * @param batchSize Integer, batch size.
 * @returns An Array of [batchStart, batchEnd] tuples. batchStart is
 *   inclusive; batchEnd is exclusive. I.e., each batch consists of indices x
 *   that satisfy batchStart <= x < batchEnd.
 */
export declare function makeBatches(size: number, batchSize: number): Array<[number, number]>;
/**
 * Ensure tensors all have a rank of at least 2.
 *
 * If a tensor has a rank of 1, it is dimension-expanded to rank 2.
 * If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.
 *
 * @param tensors A single `tf.Tensor` or an `Array` of `tf.Tensor`s to check.
 * @returns An `Array` of `tf.Tensor`s, each with a rank of at least 2.
 * @throws If any of the tensors is a scalar (rank 0).
 */
export declare function ensureTensorsRank2OrHigher(tensors: Tensor | Tensor[]): Tensor[];
/**
 * Compare a set of tensors with a reference (old) set, discard the ones
 * in the new set that are not present in the reference set.
 *
 * This method is used for memory cleanup during calls such as
 * LayersModel.fit().
 *
 * Each argument may be a single `tf.Tensor`, an `Array` of `tf.Tensor`s, or
 * an object mapping input names to `tf.Tensor`s.
 *
 * @param tensors New set which may contain Tensors not present in
 *   `refTensors`.
 * @param refTensors Reference Tensor set.
 */
export declare function disposeNewTensors(tensors: Tensor | Tensor[] | {
  [inputName: string]: Tensor;
}, refTensors: Tensor | Tensor[] | {
  [inputName: string]: Tensor;
}): void;