gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/wrappers" />
import { serialization, Tensor } from '@tensorflow/tfjs-core';
import { Layer, LayerArgs, SymbolicTensor } from '../engine/topology';
import { BidirectionalMergeMode, Shape } from '../keras_format/common';
import { Kwargs } from '../types';
import { RegularizerFn } from '../types';
import { LayerVariable } from '../variables';
import { RNN } from './recurrent';
export declare interface WrapperLayerArgs extends LayerArgs {
    /**
     * The layer to be wrapped.
     */
    layer: Layer;
}
/**
 * Abstract wrapper base class.
 *
 * Wrappers take another layer and augment it in various ways.
 * Do not use this class as a layer, it is only an abstract base class.
 * Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
 */
export declare abstract class Wrapper extends Layer {
    readonly layer: Layer;
    constructor(args: WrapperLayerArgs);
    build(inputShape: Shape | Shape[]): void;
    get trainable(): boolean;
    set trainable(value: boolean);
    get trainableWeights(): LayerVariable[];
    get nonTrainableWeights(): LayerVariable[];
    get updates(): Tensor[];
    get losses(): RegularizerFn[];
    getWeights(): Tensor[];
    setWeights(weights: Tensor[]): void;
    getConfig(): serialization.ConfigDict;
    setFastWeightInitDuringBuild(value: boolean): void;
    /** @nocollapse */
    static fromConfig<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>, config: serialization.ConfigDict, customObjects?: serialization.ConfigDict): T;
}
export declare class TimeDistributed extends Wrapper {
    /** @nocollapse */
    static className: string;
    constructor(args: WrapperLayerArgs);
    build(inputShape: Shape | Shape[]): void;
    computeOutputShape(inputShape: Shape | Shape[]): Shape | Shape[];
    call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
}
export declare function checkBidirectionalMergeMode(value?: string): void;
export declare interface BidirectionalLayerArgs extends WrapperLayerArgs {
    /**
     * The instance of an `RNN` layer to be wrapped.
     */
    layer: RNN;
    /**
     * Mode by which outputs of the forward and backward RNNs are
     * combined. If `null` or `undefined`, the output will not be
     * combined, they will be returned as an `Array`.
     *
     * If `undefined` (i.e., not provided), defaults to `'concat'`.
     */
    mergeMode?: BidirectionalMergeMode;
}
export declare class Bidirectional extends Wrapper {
    /** @nocollapse */
    static className: string;
    mergeMode: BidirectionalMergeMode;
    private forwardLayer;
    private backwardLayer;
    private returnSequences;
    private returnState;
    private numConstants?;
    private _trainable;
    constructor(args: BidirectionalLayerArgs);
    get trainable(): boolean;
    set trainable(value: boolean);
    getWeights(): Tensor[];
    setWeights(weights: Tensor[]): void;
    computeOutputShape(inputShape: Shape | Shape[]): Shape | Shape[];
    apply(inputs: Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[], kwargs?: Kwargs): Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[];
    call(inputs: Tensor | Tensor[], kwargs: Kwargs): Tensor | Tensor[];
    resetStates(states?: Tensor | Tensor[]): void;
    build(inputShape: Shape | Shape[]): void;
    computeMask(inputs: Tensor | Tensor[], mask?: Tensor | Tensor[]): Tensor | Tensor[];
    get trainableWeights(): LayerVariable[];
    get nonTrainableWeights(): LayerVariable[];
    setFastWeightInitDuringBuild(value: boolean): void;
    getConfig(): serialization.ConfigDict;
    /** @nocollapse */
    static fromConfig<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>, config: serialization.ConfigDict): T;
}