gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/multihead_attention" />
/**
 *  TFJS-based multi-head attention layer.
 */
import { Tensor, serialization } from '@tensorflow/tfjs-core';
import { Constraint, ConstraintIdentifier } from '../../constraints';
import { Layer, LayerArgs, SymbolicTensor } from '../../engine/topology';
import { Initializer, InitializerIdentifier } from '../../initializers';
import { Shape } from '../../keras_format/common';
import { Regularizer, RegularizerIdentifier } from '../../regularizers';
import { Kwargs } from '../../types';
import { Softmax } from '../advanced_activations';
import { Dropout } from '../core';
import { EinsumDense } from './einsum_dense';
export declare interface MultiHeadAttentionArgs extends LayerArgs {
    /**
     * Integer. Number of attention heads.
     */
    numHeads: number;
    /**
     * Integer. Size of each attention head for query and key.
     */
    keyDim: number;
    /**
     * Integer. Size of each attention head for value.
     * Defaults to `keyDim`.
     */
    valueDim?: number;
    /**
     * Dropout probability.
     * Defaults to 0.0.
     */
    dropout?: number;
    /**
     * Whether the dense layers use bias vectors/matrices.
     * Defaults to true.
     */
    useBias?: boolean;
    /**
     * The expected shape of an output tensor, besides the batch
     * and sequence dims. If not specified, projects back to the query
     * feature dim (the query input's last dimension).
     */
    outputShape?: Shape;
    /**
     * Axes over which the attention is applied. `null` means attention over
     * all axes, but batch, heads, and features.
     */
    attentionAxes?: number[] | number;
    /**
     * Initializer for dense layer kernels.
     * Defaults to `"glorotUniform"`.
     */
    kernelInitializer?: Initializer | InitializerIdentifier;
    /**
     * Initializer for dense layer biases.
     * Defaults to `"zeros"`.
     */
    biasInitializer?: Initializer | InitializerIdentifier;
    /**
     * Regularizer for dense layer kernels.
     */
    kernelRegularizer?: Regularizer | RegularizerIdentifier;
    /**
     * Regularizer for dense layer biases.
     */
    biasRegularizer?: Regularizer | RegularizerIdentifier;
    /**
     * Regularizer for dense layer activity.
     */
    activityRegularizer?: Regularizer | RegularizerIdentifier;
    /**
     * Constraint for dense layer kernels.
     */
    kernelConstraint?: Constraint | ConstraintIdentifier;
    /**
     * Constraint for dense layer kernels.
     */
    biasConstraint?: Constraint | ConstraintIdentifier;
}
export declare interface MultiHeadAttentionOptions {
    /**
     * Query `Tensor` of shape `(B, T, dim)`.
     */
    /**
     * Value `Tensor` of shape `(B, S, dim)`.
     */
    value: Tensor;
    /**
     * Key `Tensor` of shape `(B, S, dim)`. If not given, will use `value` for
     * both `key` and `value`, which is the most common case.
     */
    key?: Tensor;
    /**
     * A boolean mask of shape `(B, T, S)`, that prevents
     * attention to certain positions. The boolean mask specifies which
     * query elements can attend to which key elements, 1 indicates
     * attention and 0 indicates no attention. Broadcasting can happen for
     * the missing batch dimensions and the head dimension.
     */
    attentionMask?: Tensor;
    /**
     * Indicates whether the layer should behave in training mode
     * (adding dropout) or in inference mode (no dropout).
     * Will go with either using the training mode of the parent
     * layer/model, or false (inference) if there is no parent layer.
     */
    training?: boolean;
    /**
     * Indicates whether to apply a causal mask to prevent tokens from attending
     * to future tokens (e.g., used in a decoder Transformer).
     * Defaults to false.
     */
    useCausalMask?: boolean;
}
/**
 * MultiHeadAttention layer.
 *
 * This is an implementation of multi-headed attention as described in the
 * paper "Attention is all you Need" (Vaswani et al., 2017).
 * If `query`, `key,` `value` are the same, then
 * this is self-attention. Each timestep in `query` attends to the
 * corresponding sequence in `key`, and returns a fixed-width vector.
 *
 * This layer first projects `query`, `key` and `value`. These are
 * (effectively) a list of tensors of length `numAttentionHeads`, where the
 * corresponding shapes are `(batchSize, <query dimensions>, keyDim)`,
 * `(batchSize, <key/value dimensions>, keyDim)`,
 * `(batchSize, <key/value dimensions>, valueDim)`.
 *
 * Then, the query and key tensors are dot-producted and scaled. These are
 * softmaxed to obtain attention probabilities. The value tensors are then
 * interpolated by these probabilities, then concatenated back to a single
 * tensor.
 *
 * Finally, the result tensor with the last dimension as valueDim can take an
 * linear projection and return.
 *
 * When using `MultiHeadAttention` inside a custom layer, the custom layer must
 * implement its own `build()` method and call `MultiHeadAttention`'s
 * `buildFromSignature()` there.
 * This enables weights to be restored correctly when the model is loaded.
 *
 * Examples:
 *
 * Performs 1D cross-attention over two sequence inputs with an attention mask.
 * Returns the additional attention weights over heads.
 *
 * ```js
 * const layer = new MultiHeadAttention({numHeads: 2, keyDim: 2});
 * const target = tf.input({shape: [8, 16]});
 * const source = tf.input({shape: [4, 16]});
 * const outputTensor, weights = layer.callAndReturnAttentionScores(
 *     target, {value: source});
 * console.log(outputTensor.shape);  // [null, 8, 16]
 * console.log(weights.shape);  // [null, 2, 8, 4]
 * ```
 *
 * Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
 *
 * ```js
 * const layer = new MultiHeadAttention({
 *    numHeads: 2, keyDim: 2, attentionAxes: [2, 3]});
 * const inputTensor = tf.input({shape: [5, 3, 4, 16]});
 * const outputTensor = layer.call(inputTensor, {value: inputTensor});
 * console.log(outputTensor.shape);  // [null, 5, 3, 4, 16]
 * ```
 *
 * Returns:
 *    attentionOutput: The result of the computation, of shape `(B, T, E)`,
 *        where `T` is for target sequence shapes and `E` is the query input
 *        last dimension if `outputShape` is `None`. Otherwise, the
 *        multi-head outputs are projected to the shape specified by
 *        `outputShape`.
 *    attentionScores: multi-head attention coefficients over attention axes.
 */
export declare class MultiHeadAttention extends Layer {
    /** @nocollapse */
    static readonly className = "MultiHeadAttention";
    protected readonly numHeads: number;
    protected readonly keyDim: number;
    protected readonly valueDim: number;
    protected readonly dropout: number;
    protected readonly useBias: boolean;
    protected readonly _outputShape: Shape;
    protected readonly kernelInitializer: Initializer;
    protected readonly biasInitializer: Initializer;
    protected readonly kernelRegularizer: Regularizer;
    protected readonly biasRegularizer: Regularizer;
    protected readonly kernelConstraint: Constraint;
    protected readonly biasConstraint: Constraint;
    protected dotProductEquation: string;
    protected combineEquation: string;
    protected attentionAxes: number[];
    protected builtFromSignature: boolean;
    protected softmax: Softmax;
    protected dropoutLayer: Dropout;
    protected queryShape: Shape;
    protected keyShape: Shape;
    protected valueShape: Shape;
    protected queryDense: EinsumDense;
    protected keyDense: EinsumDense;
    protected valueDense: EinsumDense;
    protected outputDense: EinsumDense;
    constructor(args: MultiHeadAttentionArgs);
    /**
     * Should be used for testing purposes only.
     */
    get _queryDense(): EinsumDense;
    /**
     * Should be used for testing purposes only.
     */
    get _keyDense(): EinsumDense;
    /**
     * Should be used for testing purposes only.
     */
    get _valueDense(): EinsumDense;
    /**
     * Should be used for testing purposes only.
     */
    get _outputDense(): EinsumDense;
    getConfig(): serialization.ConfigDict;
    static fromConfig<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>, config: serialization.ConfigDict): T;
    /**
     * Builds layers and variables.
     *
     * Once the method is called, this.builtFromSignature will be set to true.
     */
    buildFromSignature(queryShape: Shape, valueShape: Shape, keyShape?: Shape): void;
    private getCommonKwargsForSublayer;
    /**
     * Builds the output projection matrix.
     *
     * @param freeDims Number of free dimensions for einsum equation building.
     * @param commonKwargs Common keyword arguments for einsum layer.
     * @param name Name for the projection layer.
     * @returns Projection layer.
     */
    private makeOutputDense;
    /**
     * Builds multi-head dot-product attention computations.
     *
     * This function builds attributes necessary for `computeAttention` to
     * customize attention computation to replace the default dot-product
     * attention.
     *
     * @param rank The rank of query, key, value tensors.
     */
    protected buildAttention(rank: number): void;
    protected maskedSoftmax(attentionScores: Tensor, attentionMask?: Tensor): Tensor;
    /**
     * Applies Dot-product attention with query, key, value tensors.
     *
     * This function defines the computation inside `call` with projected
     * multi-head Q, K, V inputs. Users can override this function for
     * customized attention implementation.
     *
     * @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`.
     * @param key  Projected key `Tensor` of shape `(B, S, N, keyDim)`.
     * @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`.
     * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents
     *    attention to certain positions. It is generally not needed if
     *    the `query` and `value` (and/or `key`) are masked.
     * @param training Boolean indicating whether the layer should behave
     *    in training mode (adding dropout) or in inference mode (doing
     *    nothing).
     * @returns attentionOutput: Multi-headed outputs of attention computation.
     * @returns attentionScores: Multi-headed attention weights.
     */
    protected computeAttention(query: Tensor, key: Tensor, value: Tensor, attentionMask?: Tensor, training?: boolean): [Tensor, Tensor];
    apply(inputs: Tensor | SymbolicTensor, kwargs?: Kwargs): Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[];
    call(query: Tensor, kwargs: MultiHeadAttentionOptions): Tensor;
    /**
     * Exactly like `call` except also returns the attention scores.
     */
    callAndReturnAttentionScores(query: Tensor, { value, key, useCausalMask, attentionMask, training }: MultiHeadAttentionOptions): [Tensor, Tensor];
    /**
     * Computes the attention mask.
     *
     * * The `query`'s mask is reshaped from [B, T] to [B, T, 1].
     * * The `value`'s mask is reshaped from [B, S] to [B, 1, S].
     * * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
     *   mask is ignored if `key` is `None` or if `key is value`.
     * * If `useCausalMask=true`, then the causal mask is computed. Its shape
     *   is [1, T, S].
     *
     * All defined masks are merged using a logical AND operation (`&`).
     *
     * In general, if the `query` and `value` are masked, then there is no need
     * to define the `attentionMask`.
     *
     * @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`.
     * @param key  Projected key `Tensor` of shape `(B, S, N, keyDim)`.
     * @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`.
     * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents
     *    attention to certain positions.
     * @param useCausalMask  A boolean to indicate whether to apply a causal
     *    mask to prevent tokens from attending to future tokens (e.g.,
     *    used in a decoder Transformer).
     * @returns attentionMask: A boolean mask of shape `(B, T, S)`, that prevents
     *    attention to certain positions, based on the Keras masks of the
     *    `query`, `key`, `value`, and `attentionMask` tensors, and the
     *    causal mask if `useCausalMask=true`.
     */
    private computeAttentionMask;
    /**
     * Computes a causal mask (e.g., for masked self-attention layers).
     *
     * For example, if query and value both contain sequences of length 4,
     * this function returns a boolean `Tensor` equal to:
     *
     * ```
     * [[[true,  false, false, false],
     *   [true,  true,  false, false],
     *   [true,  true,  true,  false],
     *   [true,  true,  true,  true]]]
     * ```
     *
     * @param query query `Tensor` of shape `(B, T, ...)`.
     * @param value value `Tensor` of shape `(B, S, ...)` (defaults to query).
     * @returns mask: A boolean `Tensor` of shape [1, T, S] containing a lower
     *    triangular matrix of shape [T, S].
     */
    private computeCausalMask;
    /**
     *
     * @param inputShapes A list of [queryShape, valueShape] or
     *    [queryShape, valueShape, keyShape]. If no keyShape provided, valueShape
     *    is assumed as the keyShape.
     */
    computeOutputShape(inputShapes: [Shape, Shape, Shape | null]): Shape;
}