gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/callbacks" />
import { BaseCallback } from './base_callbacks';
import { Container } from './engine/container';
import { LayersModel } from './engine/training';
import { Logs } from './logs';
export declare abstract class Callback extends BaseCallback {
    /** Instance of `keras.models.Model`. Reference of the model being trained. */
    model: LayersModel;
    setModel(model: Container): void;
}
export interface EarlyStoppingCallbackArgs {
    /**
     * Quantity to be monitored.
     *
     * Defaults to 'val_loss'.
     */
    monitor?: string;
    /**
     * Minimum change in the monitored quantity to qualify as improvement,
     * i.e., an absolute change of less than `minDelta` will count as no
     * improvement.
     *
     * Defaults to 0.
     */
    minDelta?: number;
    /**
     * Number of epochs with no improvement after which training will be stopped.
     *
     * Defaults to 0.
     */
    patience?: number;
    /** Verbosity mode. */
    verbose?: number;
    /**
     * Mode: one of 'min', 'max', and 'auto'.
     * - In 'min' mode, training will be stopped when the quantity monitored has
     *   stopped decreasing.
     * - In 'max' mode, training will be stopped when the quantity monitored has
     *   stopped increasing.
     * - In 'auto' mode, the direction is inferred automatically from the name of
     *   the monitored quantity.
     *
     * Defaults to 'auto'.
     */
    mode?: 'auto' | 'min' | 'max';
    /**
     * Baseline value of the monitored quantity.
     *
     * If specified, training will be stopped if the model doesn't show
     * improvement over the baseline.
     */
    baseline?: number;
    /**
     * Whether to restore model weights from the epoch with the best value
     * of the monitored quantity. If `False`, the model weights obtained at the
     * last step of training are used.
     *
     * **`True` is not supported yet.**
     */
    restoreBestWeights?: boolean;
}
/**
 * A Callback that stops training when a monitored quantity has stopped
 * improving.
 */
export declare class EarlyStopping extends Callback {
    protected readonly monitor: string;
    protected readonly minDelta: number;
    protected readonly patience: number;
    protected readonly baseline: number;
    protected readonly verbose: number;
    protected readonly mode: 'auto' | 'min' | 'max';
    protected monitorFunc: (currVal: number, prevVal: number) => boolean;
    private wait;
    private stoppedEpoch;
    private best;
    constructor(args?: EarlyStoppingCallbackArgs);
    onTrainBegin(logs?: Logs): Promise<void>;
    onEpochEnd(epoch: number, logs?: Logs): Promise<void>;
    onTrainEnd(logs?: Logs): Promise<void>;
    private getMonitorValue;
}
/**
 * Factory function for a Callback that stops training when a monitored
 * quantity has stopped improving.
 *
 * Early stopping is a type of regularization, and protects model against
 * overfitting.
 *
 * The following example based on fake data illustrates how this callback
 * can be used during `tf.LayersModel.fit()`:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({
 *   units: 3,
 *   activation: 'softmax',
 *   kernelInitializer: 'ones',
 *   inputShape: [2]
 * }));
 * const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 * const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
 * const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
 * const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
 * model.compile(
 *     {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
 *
 * // Without the EarlyStopping callback, the val_acc value would be:
 * //   0.5, 0.5, 0.5, 0.5, ...
 * // With val_acc being monitored, training should stop after the 2nd epoch.
 * const history = await model.fit(xs, ys, {
 *   epochs: 10,
 *   validationData: [xsVal, ysVal],
 *   callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
 * });
 *
 * // Expect to see a length-2 array.
 * console.log(history.history.val_acc);
 * ```
 *
 * @doc {
 *   heading: 'Callbacks',
 *   namespace: 'callbacks'
 * }
 */
export declare function earlyStopping(args?: EarlyStoppingCallbackArgs): EarlyStopping;
export declare const callbacks: {
    earlyStopping: typeof earlyStopping;
};