gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/initializers" />
import { DataType, serialization, Tensor } from '@tensorflow/tfjs-core';
import { Shape } from './keras_format/common';
import { Distribution, FanMode } from './keras_format/initializer_config';
export declare function checkFanMode(value?: string): void;
export declare function checkDistribution(value?: string): void;
/**
 * Initializer base class.
 *
 * @doc {
 *   heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
 */
export declare abstract class Initializer extends serialization.Serializable {
    fromConfigUsesCustomObjects(): boolean;
    /**
     * Generate an initial value.
     * @param shape
     * @param dtype
     * @return The init value.
     */
    abstract apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
export declare class Zeros extends Initializer {
    /** @nocollapse */
    static className: string;
    apply(shape: Shape, dtype?: DataType): Tensor;
}
export declare class Ones extends Initializer {
    /** @nocollapse */
    static className: string;
    apply(shape: Shape, dtype?: DataType): Tensor;
}
export interface ConstantArgs {
    /** The value for each element in the variable. */
    value: number;
}
export declare class Constant extends Initializer {
    /** @nocollapse */
    static className: string;
    private value;
    constructor(args: ConstantArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
export interface RandomUniformArgs {
    /** Lower bound of the range of random values to generate. */
    minval?: number;
    /** Upper bound of the range of random values to generate. */
    maxval?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
export declare class RandomUniform extends Initializer {
    /** @nocollapse */
    static className: string;
    readonly DEFAULT_MINVAL = -0.05;
    readonly DEFAULT_MAXVAL = 0.05;
    private minval;
    private maxval;
    private seed;
    constructor(args: RandomUniformArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
export interface RandomNormalArgs {
    /** Mean of the random values to generate. */
    mean?: number;
    /** Standard deviation of the random values to generate. */
    stddev?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
export declare class RandomNormal extends Initializer {
    /** @nocollapse */
    static className: string;
    readonly DEFAULT_MEAN = 0;
    readonly DEFAULT_STDDEV = 0.05;
    private mean;
    private stddev;
    private seed;
    constructor(args: RandomNormalArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
export interface TruncatedNormalArgs {
    /** Mean of the random values to generate. */
    mean?: number;
    /** Standard deviation of the random values to generate. */
    stddev?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
export declare class TruncatedNormal extends Initializer {
    /** @nocollapse */
    static className: string;
    readonly DEFAULT_MEAN = 0;
    readonly DEFAULT_STDDEV = 0.05;
    private mean;
    private stddev;
    private seed;
    constructor(args: TruncatedNormalArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
export interface IdentityArgs {
    /**
     * Multiplicative factor to apply to the identity matrix.
     */
    gain?: number;
}
export declare class Identity extends Initializer {
    /** @nocollapse */
    static className: string;
    private gain;
    constructor(args: IdentityArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
export interface VarianceScalingArgs {
    /** Scaling factor (positive float). */
    scale?: number;
    /** Fanning mode for inputs and outputs. */
    mode?: FanMode;
    /** Probabilistic distribution of the values. */
    distribution?: Distribution;
    /** Random number generator seed. */
    seed?: number;
}
export declare class VarianceScaling extends Initializer {
    /** @nocollapse */
    static className: string;
    private scale;
    private mode;
    private distribution;
    private seed;
    /**
     * Constructor of VarianceScaling.
     * @throws ValueError for invalid value in scale.
     */
    constructor(args: VarianceScalingArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
export interface SeedOnlyInitializerArgs {
    /** Random number generator seed. */
    seed?: number;
}
export declare class GlorotUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * Constructor of GlorotUniform
     * @param scale
     * @param mode
     * @param distribution
     * @param seed
     */
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
export declare class GlorotNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * Constructor of GlorotNormal.
     * @param scale
     * @param mode
     * @param distribution
     * @param seed
     */
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
export declare class HeNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
export declare class HeUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
export declare class LeCunNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
export declare class LeCunUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    constructor(args?: SeedOnlyInitializerArgs);
    getClassName(): string;
}
export interface OrthogonalArgs extends SeedOnlyInitializerArgs {
    /**
     * Multiplicative factor to apply to the orthogonal matrix. Defaults to 1.
     */
    gain?: number;
}
export declare class Orthogonal extends Initializer {
    /** @nocollapse */
    static className: string;
    readonly DEFAULT_GAIN = 1;
    readonly ELEMENTS_WARN_SLOW = 2000;
    protected readonly gain: number;
    protected readonly seed: number;
    constructor(args?: OrthogonalArgs);
    apply(shape: Shape, dtype?: DataType): Tensor;
    getConfig(): serialization.ConfigDict;
}
/** @docinline */
export type InitializerIdentifier = 'constant' | 'glorotNormal' | 'glorotUniform' | 'heNormal' | 'heUniform' | 'identity' | 'leCunNormal' | 'leCunUniform' | 'ones' | 'orthogonal' | 'randomNormal' | 'randomUniform' | 'truncatedNormal' | 'varianceScaling' | 'zeros' | string;
export declare const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP: {
    [identifier in InitializerIdentifier]: string;
};
export declare function serializeInitializer(initializer: Initializer): serialization.ConfigDictValue;
export declare function getInitializer(identifier: InitializerIdentifier | Initializer | serialization.ConfigDict): Initializer;