gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/modeling/cached_multihead_attention" />
/**
 *  Cached MHA layer based on `MultiHeadAttention`.
 */
import { Tensor } from '@tensorflow/tfjs-core';
import { MultiHeadAttention } from '../multihead_attention';
export declare interface CachedMultiHeadAttentionOptions {
    /**
     * Query `Tensor` of shape `(B, T, dim)`.
     */
    /**
     * Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*`
     * must equal `S` and match the shape of `attentionMask`. If `cache` is
     * not `null`, `S*` can be any length less than `S`, and the computed
     * value will be spliced into `cache` at `cacheUpdateIndex`.
     */
    value: Tensor;
    /**
     * Key `Tensor` of shape `(B, S*, dim)`.  If `cache` is `null`, `S*` must
     * equal `S` and match the shape of `attentionMask`. If `cache` is not `null`,
     * `S*` can be any length less than `S`, and the computed value will be
     * spliced into `cache` at `cacheUpdateIndex`.
     */
    key?: Tensor;
    /**
     * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents
     * attention to certain positions. The boolean mask specifies which
     * query elements can attend to which key elements, 1 indicates
     * attention and 0 indicates no attention. Broadcasting can happen for
     * the missing batch dimensions and the head dimension.
     */
    attentionMask?: Tensor;
    /**
     * A dense float Tensor. The key/value cache, of shape
     * `[B, 2, S, numHeads, keyDims]`, where `S` must agree with the
     * `attentionMask` shape. This argument is intended for use during
     * generation to avoid recomputing intermediate state.
     */
    cache?: Tensor;
    /**
     * Integer or Integer `Tensor`. The index at which to update `cache`
     * (usually the index of the current token being processed when running
     * generation). If `cacheUpdateIndex=null` while `cache` is set, the cache
     * will not be updated.
     */
    cacheUpdateIndex?: number;
}
/**
 * MultiHeadAttention layer with cache support.
 *
 * This layer is suitable for use in autoregressive decoding. It can be use
 * to cache decoder self-attention and cross-attention. The forward pass
 * can happen in one of three modes:
 * - No cache, same as regular multi-head attention.
 * - Static cache (`cacheUpdateIndex` is None). In this case, the
 *     cached key/value projections will be used and the input values will
 *     be ignored.
 * - Updated cache (`cacheUpdateIndex` is not None). In this case, new
 *     key/value projections are computed using the input, and spliced into
 *     the cache at the specified index.
 *
 * Note that caching is useful only during inference and should not be used
 * during training.
 *
 * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
 * `T` is the target sequence length, and `S` in the source sequence length.
 * Note that during generative decoding, `T` is usually 1 (you are
 * generating a target sequence of length one to predict the next token).
 *
 * Returns:
 *     An `(attentionOutput, cache)` tuple. `attentionOutput` is the result
 *     of the computation, of shape `(B, T, dim)`, where `T` is for target
 *     sequence shapes and `dim` is the query input last dimension if
 *     `outputShape` is `null`. Otherwise, the multi-head outputs are
 *     projected to the shape specified by `outputShape`. `cache` is the
 *     updated cache.
 */
export declare class CachedMultiHeadAttention extends MultiHeadAttention {
    call(query: Tensor, kwargs: CachedMultiHeadAttentionOptions): Tensor;
    /**
     * Exactly like `call` except also returns the updated cache.
     */
    callAndReturnCache(query: Tensor, { value, key, attentionMask, cache, cacheUpdateIndex }: CachedMultiHeadAttentionOptions): [Tensor, Tensor];
}