/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/modeling/cached_multihead_attention" />
/**
 * Cached MHA layer based on `MultiHeadAttention`.
 */
import { Tensor } from '@tensorflow/tfjs-core';
import { MultiHeadAttention } from '../multihead_attention';
export declare interface CachedMultiHeadAttentionOptions {
    /**
     * Query `Tensor` of shape `(B, T, dim)`.
     */
    /**
     * Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*`
     * must equal `S` and match the shape of `attentionMask`. If `cache` is
     * not `null`, `S*` can be any length less than `S`, and the computed
     * value will be spliced into `cache` at `cacheUpdateIndex`.
     */
    value: Tensor;
    /**
     * Key `Tensor` of shape `(B, S*, dim)`. If `cache` is `null`, `S*` must
     * equal `S` and match the shape of `attentionMask`. If `cache` is not
     * `null`, `S*` can be any length less than `S`, and the computed value
     * will be spliced into `cache` at `cacheUpdateIndex`.
     */
    key?: Tensor;
    /**
     * A boolean mask of shape `(B, T, S)`. `attentionMask` prevents
     * attention to certain positions. The boolean mask specifies which
     * query elements can attend to which key elements: 1 indicates
     * attention and 0 indicates no attention. Broadcasting can happen for
     * the missing batch dimensions and the head dimension.
     */
    attentionMask?: Tensor;
    /**
     * A dense float `Tensor`. The key/value cache, of shape
     * `[B, 2, S, numHeads, keyDims]`, where `S` must agree with the
     * `attentionMask` shape. This argument is intended for use during
     * generation to avoid recomputing intermediate state.
     */
    cache?: Tensor;
    /**
     * Integer or integer `Tensor`. The index at which to update `cache`
     * (usually the index of the current token being processed when running
     * generation). If `cacheUpdateIndex` is `null` while `cache` is set, the
     * cache will not be updated.
     */
    cacheUpdateIndex?: number;
}
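// Note: a minimal usage sketch (illustrative only, not part of the published
// declarations). It shows how the options above might be assembled for a
// single cached decoding step; the shapes, sizes, and variable names below
// are assumptions chosen for the example, not values required by the API.
//
//   import * as tf from '@tensorflow/tfjs-core';
//
//   const batchSize = 2;
//   const maxLength = 16;             // S: sequence length held in the cache
//   const numHeads = 4;
//   const keyDims = 8;
//   const dim = numHeads * keyDims;   // feature dimension of the query/value
//
//   // Empty key/value cache of shape [B, 2, S, numHeads, keyDims].
//   const cache = tf.zeros([batchSize, 2, maxLength, numHeads, keyDims]);
//
//   // Decode one new token (S* = 1) and splice its projections in at index 3.
//   const options: CachedMultiHeadAttentionOptions = {
//     value: tf.randomNormal([batchSize, 1, dim]),
//     cache,
//     cacheUpdateIndex: 3,
//   };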
/**
 * MultiHeadAttention layer with cache support.
 *
 * This layer is suitable for use in autoregressive decoding. It can be used
 * to cache decoder self-attention and cross-attention. The forward pass
 * can happen in one of three modes:
 * - No cache, same as regular multi-head attention.
 * - Static cache (`cacheUpdateIndex` is `null`). In this case, the
 *   cached key/value projections will be used and the input values will
 *   be ignored.
 * - Updated cache (`cacheUpdateIndex` is not `null`). In this case, new
 *   key/value projections are computed using the input, and spliced into
 *   the cache at the specified index.
 *
 * Note that caching is useful only during inference and should not be used
 * during training.
 *
 * We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
 * `T` is the target sequence length, and `S` is the source sequence length.
 * Note that during generative decoding, `T` is usually 1 (you are
 * generating a target sequence of length one to predict the next token).
 *
 * Returns:
 *   An `(attentionOutput, cache)` tuple. `attentionOutput` is the result
 *   of the computation, of shape `(B, T, dim)`, where `T` is the target
 *   sequence length and `dim` is the query input's last dimension if
 *   `outputShape` is `null`. Otherwise, the multi-head outputs are
 *   projected to the shape specified by `outputShape`. `cache` is the
 *   updated cache.
 */
export declare class CachedMultiHeadAttention extends MultiHeadAttention {
    call(query: Tensor, kwargs: CachedMultiHeadAttentionOptions): Tensor;
    /**
     * Exactly like `call` except also returns the updated cache.
     */
    callAndReturnCache(query: Tensor, { value, key, attentionMask, cache, cacheUpdateIndex }: CachedMultiHeadAttentionOptions): [Tensor, Tensor];
}
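// Note: a hedged end-to-end sketch (illustrative only). It assumes a
// CachedMultiHeadAttention instance built with hypothetical hyperparameters
// and walks one token at a time through the "updated cache" mode described
// above; the constructor arguments, shapes, and loop structure are
// assumptions for the example, not a prescribed recipe.
//
//   import * as tf from '@tensorflow/tfjs-core';
//
//   const numHeads = 4;
//   const keyDims = 8;
//   const batchSize = 2;
//   const maxLength = 16;
//
//   // Hypothetical construction; see MultiHeadAttention for the actual args.
//   const layer = new CachedMultiHeadAttention({numHeads, keyDim: keyDims});
//
//   let cache = tf.zeros([batchSize, 2, maxLength, numHeads, keyDims]);
//   let token = tf.randomNormal([batchSize, 1, numHeads * keyDims]);
//
//   for (let i = 0; i < maxLength; i++) {
//     // Self-attention over a single new token: T = S* = 1.
//     const [attentionOutput, updatedCache] = layer.callAndReturnCache(token, {
//       value: token,
//       cache,
//       cacheUpdateIndex: i,
//     });
//     cache = updatedCache;
//     token = attentionOutput;  // placeholder for the rest of a decoder block
//   }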