/**
|
* @license
|
* Copyright 2023 Google LLC.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/**
|
* TFJS-based multi-head attention layer.
|
*/
|
/* Original source: keras/layers/attention/multi_head_attention.py */
|
import { einsum, linalg, logicalAnd, mul, ones, serialization, tidy, util } from '@tensorflow/tfjs-core';
|
import { cast, expandDims } from '../../backend/tfjs_backend';
|
import { getConstraint, serializeConstraint } from '../../constraints';
|
import { Layer } from '../../engine/topology';
|
import { ValueError } from '../../errors';
|
import { getInitializer, serializeInitializer } from '../../initializers';
|
import { getRegularizer, serializeRegularizer } from '../../regularizers';
|
import { Softmax } from '../advanced_activations';
|
import { Dropout } from '../core';
|
import { EinsumDense } from './einsum_dense';
|
/** Lower-case letters 'a'..'z', indexed by position; used as einsum subscripts. */
const _CHR_IDX = Array.from('abcdefghijklmnopqrstuvwxyz');
|
/**
 * Builds einsum equations for the attention computation.
 *
 * After projection, query, key and value are expected to have the layout
 * `(bs, <non-attention dims>, <attention dims>, numHeads, channels)`, where
 * `bs` and `<non-attention dims>` together act as `<batch dims>`.
 *
 * Two equations are produced:
 * (1) Query-key dot product:
 *     `(<batch dims>, <query attention dims>, numHeads, channels),
 *      (<batch dims>, <key attention dims>, numHeads, channels)
 *      -> (<batch dims>, numHeads, <query attention dims>, <key attention dims>)`
 * (2) Combination:
 *     `(<batch dims>, numHeads, <query attention dims>, <key attention dims>),
 *      (<batch dims>, <value attention dims>, numHeads, channels)
 *      -> (<batch dims>, <query attention dims>, numHeads, channels)`
 *
 * @param rank Rank of the projected query, key and value tensors.
 * @param attnAxes Axes that attention is applied over. Letters are looked up
 *     by indexing directly with these values, so they are expected to be
 *     non-negative axis indices.
 * @returns `[dotProductEquation, combineEquation, attnScoresRank]`.
 */
function buildAttentionEquation(rank, attnAxes) {
  const targetLetters = _CHR_IDX.slice(0, rank);
  const lastAxis = rank - 1;

  // Every axis that is neither an attention axis nor the channel (last) axis
  // behaves like a batch axis; this includes the heads axis.
  const batchDims = [...Array(rank).keys()].filter(
      axis => axis !== lastAxis && !attnAxes.includes(axis));

  // The key/value ("source") notation reuses the query letters on batch axes
  // and on the channel axis; each remaining axis gets a fresh letter numbered
  // from `rank` upward.
  let sourceNotation = '';
  let freshLetterIdx = rank;
  for (let axis = 0; axis < rank; axis++) {
    const sharesQueryLetter = axis === lastAxis || batchDims.includes(axis);
    sourceNotation +=
        sharesQueryLetter ? targetLetters[axis] : _CHR_IDX[freshLetterIdx++];
  }

  // Scores layout: batch axes, then query attention axes, then key attention axes.
  const productNotation = [
    ...batchDims.map(axis => targetLetters[axis]),
    ...attnAxes.map(axis => targetLetters[axis]),
    ...attnAxes.map(axis => sourceNotation[axis]),
  ].join('');
  const targetNotation = targetLetters.join('');

  return [
    `${sourceNotation},${targetNotation}->${productNotation}`,
    `${productNotation},${sourceNotation}->${targetNotation}`,
    productNotation.length,
  ];
}
|
/**
 * Builds an einsum equation for projections inside multi-head attention.
 *
 * Letters are handed out in three consecutive runs:
 * - `freeDims` letters shared by the input and the output,
 * - `boundDims` letters contracted between the input and the kernel,
 * - `outputDims` letters shared by the kernel and the output (these also name
 *   the bias axes).
 *
 * @param freeDims Number of leading input axes carried through unchanged.
 * @param boundDims Number of input axes contracted against the kernel.
 * @param outputDims Number of axes the kernel contributes to the output.
 * @returns `[equation, biasAxes, outputRank]`.
 */
function buildProjectionEquation(freeDims, boundDims, outputDims) {
  // Concatenates `count` consecutive letters starting at index `start`.
  const takeLetters = (start, count) => {
    let letters = '';
    for (let i = 0; i < count; i++) {
      letters += _CHR_IDX[start + i];
    }
    return letters;
  };

  const freeLetters = takeLetters(0, freeDims);
  const boundLetters = takeLetters(freeDims, boundDims);
  const outputLetters = takeLetters(freeDims + boundDims, outputDims);

  const inputStr = freeLetters + boundLetters;
  const kernelStr = boundLetters + outputLetters;
  const outputStr = freeLetters + outputLetters;

  return [`${inputStr},${kernelStr}->${outputStr}`, outputLetters, outputStr.length];
}
|
/**
 * Builds a shape of rank `outputRank` whose trailing entries are
 * `knownLastDims` and whose leading entries are `null` (unknown).
 *
 * @param outputRank Total rank of the resulting shape.
 * @param knownLastDims Known trailing dimensions.
 * @returns The padded shape array.
 */
function getOutputShape(outputRank, knownLastDims) {
  const unknownCount = outputRank - knownLastDims.length;
  const leadingUnknown = new Array(unknownCount).fill(null);
  return leadingUnknown.concat(knownLastDims);
}
|
/**
 * MultiHeadAttention layer.
 *
 * This is an implementation of multi-headed attention as described in the
 * paper "Attention Is All You Need" (Vaswani et al., 2017).
 * If `query`, `key`, `value` are the same, then
 * this is self-attention. Each timestep in `query` attends to the
 * corresponding sequence in `key`, and returns a fixed-width vector.
 *
 * This layer first projects `query`, `key` and `value`. These are
 * (effectively) a list of tensors of length `numAttentionHeads`, where the
 * corresponding shapes are `(batchSize, <query dimensions>, keyDim)`,
 * `(batchSize, <key/value dimensions>, keyDim)`,
 * `(batchSize, <key/value dimensions>, valueDim)`.
 *
 * Then, the query and key tensors are dot-producted and scaled. These are
 * softmaxed to obtain attention probabilities. The value tensors are then
 * interpolated by these probabilities, then concatenated back to a single
 * tensor.
 *
 * Finally, the result tensor with the last dimension as valueDim can take a
 * linear projection and return.
 *
 * When using `MultiHeadAttention` inside a custom layer, the custom layer must
 * implement its own `build()` method and call `MultiHeadAttention`'s
 * `buildFromSignature()` there.
 * This enables weights to be restored correctly when the model is loaded.
 *
 * Examples:
 *
 * Performs 1D cross-attention over two sequence inputs and also returns the
 * attention weights over heads.
 *
 * ```js
 * const layer = new MultiHeadAttention({numHeads: 2, keyDim: 2});
 * const target = tf.ones([1, 8, 16]);
 * const source = tf.ones([1, 4, 16]);
 * const [outputTensor, weights] = layer.callAndReturnAttentionScores(
 *     target, {value: source});
 * console.log(outputTensor.shape);  // [1, 8, 16]
 * console.log(weights.shape);  // [1, 2, 8, 4]
 * ```
 *
 * Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
 *
 * ```js
 * const layer = new MultiHeadAttention({
 *    numHeads: 2, keyDim: 2, attentionAxes: [2, 3]});
 * const inputTensor = tf.input({shape: [5, 3, 4, 16]});
 * const outputTensor = layer.apply(inputTensor, {value: inputTensor});
 * console.log(outputTensor.shape);  // [null, 5, 3, 4, 16]
 * ```
 *
 * Returns:
 *    attentionOutput: The result of the computation, of shape `(B, T, E)`,
 *        where `T` is for target sequence shapes and `E` is the query input
 *        last dimension if `outputShape` is not set. Otherwise, the
 *        multi-head outputs are projected to the shape specified by
 *        `outputShape`.
 *    attentionScores: multi-head attention coefficients over attention axes.
 */
class MultiHeadAttention extends Layer {
  constructor(args) {
    super(args);
    this.supportsMasking = true;
    this.numHeads = args.numHeads;
    this.keyDim = args.keyDim;
    // Unset (null/undefined) options fall back to their defaults.
    this.valueDim = args.valueDim != null ? args.valueDim : args.keyDim;
    this.dropout = args.dropout != null ? args.dropout : 0;
    this.useBias = args.useBias != null ? args.useBias : true;
    this._outputShape = args.outputShape;
    this.kernelInitializer = getInitializer(
        args.kernelInitializer != null ? args.kernelInitializer :
                                         'glorotUniform');
    this.biasInitializer = getInitializer(
        args.biasInitializer != null ? args.biasInitializer : 'zeros');
    this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
    this.biasRegularizer = getRegularizer(args.biasRegularizer);
    this.activityRegularizer = getRegularizer(args.activityRegularizer);
    this.kernelConstraint = getConstraint(args.kernelConstraint);
    this.biasConstraint = getConstraint(args.biasConstraint);
    // A single axis may be given as a bare number; normalize it to an array.
    if (args.attentionAxes != null && !Array.isArray(args.attentionAxes)) {
      this.attentionAxes = [args.attentionAxes];
    } else {
      this.attentionAxes = args.attentionAxes;
    }
    this.builtFromSignature = false;
    this.queryShape = null;
    this.keyShape = null;
    this.valueShape = null;
  }

  /**
   * Should be used for testing purposes only.
   */
  get _queryDense() {
    return this.queryDense;
  }

  /**
   * Should be used for testing purposes only.
   */
  get _keyDense() {
    return this.keyDense;
  }

  /**
   * Should be used for testing purposes only.
   */
  get _valueDense() {
    return this.valueDense;
  }

  /**
   * Should be used for testing purposes only.
   */
  get _outputDense() {
    return this.outputDense;
  }

  getConfig() {
    const config = {
      numHeads: this.numHeads,
      keyDim: this.keyDim,
      valueDim: this.valueDim,
      dropout: this.dropout,
      useBias: this.useBias,
      outputShape: this._outputShape,
      attentionAxes: this.attentionAxes,
      kernelInitializer: serializeInitializer(this.kernelInitializer),
      biasInitializer: serializeInitializer(this.biasInitializer),
      kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
      biasRegularizer: serializeRegularizer(this.biasRegularizer),
      activityRegularizer: serializeRegularizer(this.activityRegularizer),
      kernelConstraint: serializeConstraint(this.kernelConstraint),
      biasConstraint: serializeConstraint(this.biasConstraint),
      // Memorized input shapes let `fromConfig` rebuild the weights.
      queryShape: this.queryShape,
      keyShape: this.keyShape,
      valueShape: this.valueShape,
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  /**
   * Recreates a layer from a config produced by `getConfig`.
   *
   * The memorized `queryShape`/`keyShape`/`valueShape` entries are not
   * constructor arguments. They are stripped from a copy of `config`, so the
   * caller's object is left untouched. When all three are present, the
   * sublayers and weights are built immediately.
   */
  static fromConfig(cls, config) {
    const layerConfig = Object.assign({}, config);
    const queryShape = layerConfig['queryShape'];
    const keyShape = layerConfig['keyShape'];
    const valueShape = layerConfig['valueShape'];
    delete layerConfig['queryShape'];
    delete layerConfig['keyShape'];
    delete layerConfig['valueShape'];
    const layer = new cls(layerConfig);
    // If the layer has a different build() function from the default,
    // we need to trigger the customized build to create weights.
    // `== null` catches both a null entry and a key missing from the config.
    if ([queryShape, keyShape, valueShape].some(shape => shape == null)) {
      console.warn(
          'One of dimensions of the input shape is missing. It ' +
          'should have been memorized when the layer was serialized. ' +
          `${cls.className} is created without weights.`);
    } else {
      layer.buildFromSignature(queryShape, valueShape, keyShape);
    }
    return layer;
  }

  /**
   * Builds layers and variables.
   *
   * Once the method is called, this.builtFromSignature will be set to true.
   *
   * @param queryShape Shape of the query input.
   * @param valueShape Shape of the value input.
   * @param keyShape Shape of the key input; defaults to `valueShape`.
   */
  buildFromSignature(queryShape, valueShape, keyShape) {
    this.builtFromSignature = true;
    if (keyShape == null) {
      keyShape = valueShape;
    }
    this.queryShape = queryShape;
    this.valueShape = valueShape;
    this.keyShape = keyShape;
    // Not using SymbolicTensors since tf.input() adds a batch dimension to the
    // given shape, therefore giving the tensor the wrong rank.
    const queryRank = queryShape.length;
    const valueRank = valueShape.length;
    const keyRank = keyShape.length;
    const freeDims = queryRank - 1;

    let [einsumEquation, biasAxes, outputRank] =
        buildProjectionEquation(freeDims, 1, 2);
    this.queryDense = new EinsumDense(Object.assign({
      equation: einsumEquation,
      outputShape: getOutputShape(outputRank - 1, [this.numHeads, this.keyDim]),
      biasAxes: this.useBias ? biasAxes : null,
      name: 'query',
    }, this.getCommonKwargsForSublayer()));

    [einsumEquation, biasAxes, outputRank] =
        buildProjectionEquation(keyRank - 1, 1, 2);
    this.keyDense = new EinsumDense(Object.assign({
      equation: einsumEquation,
      outputShape: getOutputShape(outputRank - 1, [this.numHeads, this.keyDim]),
      biasAxes: this.useBias ? biasAxes : null,
      name: 'key',
    }, this.getCommonKwargsForSublayer()));

    [einsumEquation, biasAxes, outputRank] =
        buildProjectionEquation(valueRank - 1, 1, 2);
    this.valueDense = new EinsumDense(Object.assign({
      equation: einsumEquation,
      outputShape:
          getOutputShape(outputRank - 1, [this.numHeads, this.valueDim]),
      biasAxes: this.useBias ? biasAxes : null,
      name: 'value',
    }, this.getCommonKwargsForSublayer()));

    // Builds the attention computations for multi-head dot product attention.
    this.buildAttention(outputRank);
    this.outputDense = this.makeOutputDense(
        freeDims, this.getCommonKwargsForSublayer(), 'attentionOutput');
  }

  getCommonKwargsForSublayer() {
    // Create new clone of kernel/bias initializer, so that we don't reuse
    // the initializer instance, which could lead to same init value since
    // initializer is stateless.
    const kernelInitializer = getInitializer({
      className: this.kernelInitializer.getClassName(),
      config: this.kernelInitializer.getConfig(),
    });
    const biasInitializer = getInitializer({
      className: this.biasInitializer.getClassName(),
      config: this.biasInitializer.getConfig(),
    });
    return {
      kernelInitializer,
      biasInitializer,
      kernelRegularizer: this.kernelRegularizer,
      biasRegularizer: this.biasRegularizer,
      activityRegularizer: this.activityRegularizer,
      kernelConstraint: this.kernelConstraint,
      biasConstraint: this.biasConstraint,
    };
  }

  /**
   * Builds the output projection matrix.
   *
   * @param freeDims Number of free dimensions for einsum equation building.
   * @param commonKwargs Common keyword arguments for einsum layer.
   * @param name Name for the projection layer.
   * @returns Projection layer.
   */
  makeOutputDense(freeDims, commonKwargs, name) {
    let outputShape;
    if (this._outputShape) {
      outputShape = Array.isArray(this._outputShape) ? this._outputShape :
                                                       [this._outputShape];
    } else {
      // Without an explicit outputShape, project back to the query's last dim.
      outputShape = [this.queryShape[this.queryShape.length - 1]];
    }
    const [einsumEquation, biasAxes, outputRank] =
        buildProjectionEquation(freeDims, 2, outputShape.length);
    return new EinsumDense(Object.assign({
      equation: einsumEquation,
      outputShape: getOutputShape(outputRank - 1, outputShape),
      biasAxes: this.useBias ? biasAxes : null,
      name,
    }, commonKwargs));
  }

  /**
   * Builds multi-head dot-product attention computations.
   *
   * This function builds attributes necessary for `computeAttention` to
   * customize attention computation to replace the default dot-product
   * attention.
   *
   * @param rank The rank of query, key, value tensors.
   */
  buildAttention(rank) {
    if (this.attentionAxes == null) {
      // Default: attend over every axis between the batch axis and the
      // trailing (numHeads, channels) axes.
      this.attentionAxes = [];
      for (let i = 1; i < rank - 2; i++) {
        this.attentionAxes.push(i);
      }
    } else {
      this.attentionAxes = [...this.attentionAxes];
    }
    const [dotProductEquation, combineEquation, attnScoresRank] =
        buildAttentionEquation(rank, this.attentionAxes);
    this.dotProductEquation = dotProductEquation;
    this.combineEquation = combineEquation;
    // Softmax normalizes over the trailing key-attention axes of the scores.
    const normAxes = [];
    const startIdx = attnScoresRank - this.attentionAxes.length;
    for (let i = startIdx; i < attnScoresRank; i++) {
      normAxes.push(i);
    }
    this.softmax = new Softmax({axis: normAxes});
    this.dropoutLayer = new Dropout({rate: this.dropout});
  }

  maskedSoftmax(attentionScores, attentionMask) {
    return tidy(() => {
      // Normalize the attention scores to probabilities.
      // `attentionScores` = [B, N, T, S]
      if (attentionMask != null) {
        // The expand dim happens starting from the `numHeads` dimension,
        // (<batchDims>, numHeads, <queryAttentionDims, keyAttentionDims>)
        const maskExpansionAxis = -this.attentionAxes.length * 2 - 1;
        const endIdx =
            attentionScores.shape.length - attentionMask.shape.length;
        for (let _ = 0; _ < endIdx; _++) {
          attentionMask = expandDims(attentionMask, maskExpansionAxis);
        }
      }
      return this.softmax.apply(attentionScores, {mask: attentionMask});
    });
  }

  /**
   * Applies Dot-product attention with query, key, value tensors.
   *
   * This function defines the computation inside `call` with projected
   * multi-head Q, K, V inputs. Users can override this function for
   * customized attention implementation.
   *
   * @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`.
   * @param key Projected key `Tensor` of shape `(B, S, N, keyDim)`.
   * @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`.
   * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents
   *    attention to certain positions. It is generally not needed if
   *    the `query` and `value` (and/or `key`) are masked.
   * @param training Boolean indicating whether the layer should behave
   *    in training mode (adding dropout) or in inference mode (doing
   *    nothing).
   * @returns `[attentionOutput, attentionScores]`: the multi-headed outputs
   *    of the attention computation and the multi-headed attention weights.
   */
  computeAttention(query, key, value, attentionMask, training) {
    return tidy(() => {
      // Note: Applying scalar multiply at the smaller end of einsum improves
      // XLA performance, but may introduce slight numeric differences in
      // the Transformer attention head.
      query = mul(query, 1.0 / Math.sqrt(this.keyDim));
      // Take the dot product between "query" and "key" to get the raw
      // attention scores.
      let attentionScores = einsum(this.dotProductEquation, key, query);
      attentionScores = this.maskedSoftmax(attentionScores, attentionMask);
      // This is actually dropping out entire tokens to attend to, which might
      // seem a bit unusual, but is taken from the original Transformer paper.
      const attentionScoresDropout =
          this.dropoutLayer.apply(attentionScores, {training});
      // `contextLayer` = [B, T, N, H]
      const attentionOutput =
          einsum(this.combineEquation, attentionScoresDropout, value);
      return [attentionOutput, attentionScores];
    });
  }

  /**
   * Applies the layer. `kwargs.value` is required; `kwargs.key` is optional.
   *
   * @throws ValueError if `kwargs.value` is missing.
   */
  apply(inputs, kwargs) {
    if (!kwargs || !kwargs['value']) {
      throw new ValueError('Must pass in `value` argument in `kwargs.`');
    }
    // All inputs are passed to the base class so that symbolic shape
    // inference (`computeOutputShape`) sees [query, value, key?].
    const newInputs = [inputs, kwargs['value']].concat(
        kwargs['key'] != null ? kwargs['key'] : []);
    // TODO(pforderique): Support mask propagation.
    return super.apply(newInputs, kwargs);
  }

  call(query, kwargs) {
    return tidy(() => {
      // `apply()` hands `[query, value, key?]` to the base `Layer.apply`, which
      // (assumption: as in the tfjs base Layer) forwards that list unchanged
      // to `call()`. Unwrap the query so that direct `call(query, ...)` and
      // `apply(query, ...)` both work; `value`/`key` are read from `kwargs`.
      const queryTensor = Array.isArray(query) ? query[0] : query;
      return this.callAndReturnAttentionScores(queryTensor, kwargs)[0];
    });
  }

  /**
   * Exactly like `call` except also returns the attention scores.
   *
   * @returns `[attentionOutput, attentionScores]`.
   */
  callAndReturnAttentionScores(
      query, {value, key, useCausalMask, attentionMask, training}) {
    return tidy(() => {
      if (!this.builtFromSignature) {
        this.buildFromSignature(
            query.shape, value.shape, key ? key.shape : null);
      }
      if (key == null) {
        key = value;
      }

      // TODO(pforderique): Support RaggedTensor inputs.

      attentionMask = this.computeAttentionMask(
          query, value, attentionMask, useCausalMask, key);

      //   N = `numAttentionHeads`
      //   H = `sizePerHead`
      // `query` = [B, T, N ,H]
      query = this.queryDense.apply(query);
      // `key` = [B, S, N, H]
      key = this.keyDense.apply(key);
      // `value` = [B, S, N, H]
      value = this.valueDense.apply(value);

      const [attentionOutputPreDense, attentionScores] = this.computeAttention(
          query, key, value, attentionMask, training);
      const attentionOutput = this.outputDense.apply(attentionOutputPreDense);
      return [attentionOutput, attentionScores];
    });
  }

  /**
   * Computes the attention mask.
   *
   * * The `query`'s mask is reshaped from [B, T] to [B, T, 1].
   * * The `value`'s mask is reshaped from [B, S] to [B, 1, S].
   * * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. The `key`'s
   *   mask is ignored if `key` is not given or if `key` is the same tensor
   *   as `value`.
   * * If `useCausalMask=true`, then the causal mask is computed. Its shape
   *   is [1, T, S].
   *
   * All defined masks are cast to bool and merged using a logical AND
   * operation.
   *
   * In general, if the `query` and `value` are masked, then there is no need
   * to define the `attentionMask`.
   *
   * @param query Query input `Tensor` (before projection); its `kerasMask`,
   *    if any, is read.
   * @param value Value input `Tensor` (before projection); its `kerasMask`,
   *    if any, is read.
   * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents
   *    attention to certain positions.
   * @param useCausalMask A boolean to indicate whether to apply a causal
   *    mask to prevent tokens from attending to future tokens (e.g.,
   *    used in a decoder Transformer).
   * @param key Optional key input `Tensor` (before projection); its
   *    `kerasMask`, if any, is read.
   * @returns attentionMask: A boolean mask of shape `(B, T, S)`, that prevents
   *    attention to certain positions, based on the Keras masks of the
   *    `query`, `key`, `value`, and `attentionMask` tensors, and the
   *    causal mask if `useCausalMask=true`. If no mask of any kind is present,
   *    the given `attentionMask` (possibly undefined) is returned unchanged.
   */
  computeAttentionMask(
      query, value, attentionMask, useCausalMask = false, key) {
    return tidy(() => {
      let autoMask;
      const mergeIntoAutoMask = (mask) => {
        autoMask = autoMask != null ? logicalAnd(autoMask, mask) : mask;
      };

      const queryMask = query.kerasMask;
      const valueMask = value.kerasMask;
      const keyMask =
          (key != null && key !== value) ? key.kerasMask : null;

      // Defensive casts: tf.logicalAnd only accepts bool tensors.
      if (queryMask != null) {
        mergeIntoAutoMask(expandDims(cast(queryMask, 'bool'), 2));  // [B, T, 1]
      }
      if (valueMask != null) {
        mergeIntoAutoMask(expandDims(cast(valueMask, 'bool'), 1));  // [B, 1, S]
      }
      if (keyMask != null) {
        mergeIntoAutoMask(expandDims(cast(keyMask, 'bool'), 1));  // [B, 1, S]
      }
      if (useCausalMask) {
        // The shape of the causal mask is [1, T, S].
        mergeIntoAutoMask(this.computeCausalMask(query, value));
      }
      if (autoMask != null) {
        // Merge attentionMask & automatic mask, to shape [B, T, S].
        attentionMask = attentionMask != null ?
            logicalAnd(cast(attentionMask, 'bool'), autoMask) :
            autoMask;
      }
      return attentionMask;
    });
  }

  /**
   * Computes a causal mask (e.g., for masked self-attention layers).
   *
   * For example, if query and value both contain sequences of length 4,
   * this function returns a boolean `Tensor` equal to:
   *
   * ```
   * [[[true,  false, false, false],
   *   [true,  true,  false, false],
   *   [true,  true,  true,  false],
   *   [true,  true,  true,  true]]]
   * ```
   *
   * @param query query `Tensor` of shape `(B, T, ...)`.
   * @param value value `Tensor` of shape `(B, S, ...)` (defaults to query).
   * @returns mask: A boolean `Tensor` of shape [1, T, S] containing a lower
   *    triangular matrix of shape [T, S].
   */
  computeCausalMask(query, value) {
    return tidy(() => {
      const qSeqLength = query.shape[1];
      const vSeqLength = value ? value.shape[1] : qSeqLength;
      // Create a lower triangular matrix.
      return linalg.bandPart(
          ones([1, qSeqLength, vSeqLength], 'bool'), -1, 0);
    });
  }

  /**
   * Computes the layer's output shape.
   *
   * @param inputShapes A list of [queryShape, valueShape] or
   *    [queryShape, valueShape, keyShape]. If no keyShape provided, valueShape
   *    is assumed as the keyShape.
   * @throws ValueError if the query and value last dimensions differ, or if
   *    value and key differ in any dimension other than the last.
   */
  computeOutputShape(inputShapes) {
    const [queryShape, valueShape, maybeKeyShape] = inputShapes;
    const keyShape = maybeKeyShape != null ? maybeKeyShape : valueShape;
    const queryLastDim = queryShape.slice(-1)[0];
    const valueLastDim = valueShape.slice(-1)[0];

    if (queryLastDim !== valueLastDim) {
      throw new ValueError(
          `The last dimension of 'queryShape' and 'valueShape' must be equal, ` +
          `but are ${queryLastDim}, ${valueLastDim}. ` +
          `Received: queryShape=${queryShape}, valueShape=${valueShape}`);
    }

    if (!util.arraysEqual(valueShape.slice(1, -1), keyShape.slice(1, -1))) {
      throw new ValueError(
          `All dimensions of 'value' and 'key', except the last one, must be ` +
          `equal. Received ${valueShape} and ${keyShape}`);
    }

    if (this._outputShape) {
      return queryShape.slice(0, -1).concat(this._outputShape);
    }

    return queryShape;
  }
}
|
/** @nocollapse */
|
MultiHeadAttention.className = 'MultiHeadAttention';
|
export { MultiHeadAttention };
|
serialization.registerClass(MultiHeadAttention);
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"multihead_attention.js","sourceRoot":"","sources":["../../../../../../../tfjs-layers/src/layers/nlp/multihead_attention.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,qEAAqE;AACrE,OAAO,EAAU,MAAM,EAAE,MAAM,EAAE,UAAU,EAAE,GAAG,EAAE,IAAI,EAAE,aAAa,EAAE,IAAI,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAEjH,OAAO,EAAE,IAAI,EAAE,UAAU,EAAE,MAAM,4BAA4B,CAAC;AAC9D,OAAO,EAAoC,aAAa,EAAE,mBAAmB,EAAE,MAAM,mBAAmB,CAAC;AACzG,OAAO,EAAE,KAAK,EAA6B,MAAM,uBAAuB,CAAC;AACzE,OAAO,EAAE,UAAU,EAAE,MAAM,cAAc,CAAC;AAC1C,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,oBAAoB,CAAC;AAE9G,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,oBAAoB,CAAC;AAE9G,OAAO,EAAE,OAAO,EAAE,MAAM,yBAAyB,CAAC;AAClD,OAAO,EAAE,OAAO,EAAE,MAAM,SAAS,CAAC;AAClC,OAAO,EAAE,WAAW,EAAE,MAAM,gBAAgB,CAAC;AAE7C,MAAM,QAAQ,GAAG,4BAA4B,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;AACxD;;;;;;;;;;;;;;;;;;;;;GAqBG;AACH,SAAS,sBAAsB,CAC7B,IAAY,EAAE,QAAkB;IAEhC,MAAM,iBAAiB,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;IAClD,qCAAqC;IACrC,MAAM,cAAc,GAAG,CAAC,GAAG,QAAQ,EAAE,IAAI,GAAG,CAAC,CAAC,CAAC;IAC/C,MAAM,SAAS,GAAG,EAAE,CAAC;IACrB,KAAK,MAAM,CAAC,IAAI,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,EAAE,EAAE;QAClC,IAAI,CAAC,cAAc,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE;YAC/B,SAAS,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SACnB;KACF;IACD,IAAI,YAAY,GAAG,IAAI,CAAC;IACxB,IAAI,cAAc,GAAG,EAAE,CAAC;IACxB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,EAAE;QAC7B,IAAI,SAAS,CAAC,QAAQ,CAAC,CAAC,CAAC,IAAI,CAAC,KAAK,IAAI,GAAG,CAAC,EAAE;YAC3C,cAAc,IAAI,iBAAiB,CAAC,CAAC,CAAC,CAAC;SACxC;aAAM;YACL,cAAc,IAAI,QAAQ,CAAC,YAAY,CAAC,CAAC;YACzC,YAAY,EAAE,CAAC;SAChB;KACF;IAED,MAAM,eAAe,GACnB,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,iBAAiB,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAC/C,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,iBAAiB,CAAC,CAAC,CAAC,CAAC,EACvC,QAAQ,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,cAAc,CAAC,CAAC,CAAC,CAAC,CACrC,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IACX,MAAM,cAAc,GAAG,iBAAiB,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IAElD,MAAM,kBAAkB,GACtB,GAAG,cAAc,IAAI,cAAc
,KAAK,eAAe,EAAE,CAAC;IAC5D,MAAM,cAAc,GAAG,eAAe,CAAC,MAAM,CAAC;IAC9C,MAAM,eAAe,GACnB,GAAG,eAAe,IAAI,cAAc,KAAK,cAAc,EAAE,CAAC;IAE5D,OAAO,CAAC,kBAAkB,EAAE,eAAe,EAAE,cAAc,CAAC,CAAC;AAC/D,CAAC;AAED;;GAEG;AACH,SAAS,uBAAuB,CAC9B,QAAgB,EAAE,SAAiB,EAAE,UAAkB;IAEvD,IAAI,QAAQ,GAAG,EAAE,CAAC;IAClB,IAAI,SAAS,GAAG,EAAE,CAAC;IACnB,IAAI,SAAS,GAAG,EAAE,CAAC;IACnB,IAAI,QAAQ,GAAG,EAAE,CAAC;IAClB,IAAI,YAAY,GAAG,CAAC,CAAC;IAErB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,EAAE,CAAC,EAAE,EAAE;QACjC,MAAM,IAAI,GAAG,QAAQ,CAAC,CAAC,GAAG,YAAY,CAAC,CAAC;QACxC,QAAQ,IAAI,IAAI,CAAC;QACjB,SAAS,IAAI,IAAI,CAAC;KACnB;IAED,YAAY,IAAI,QAAQ,CAAC;IACzB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,EAAE,CAAC,EAAE,EAAE;QAClC,MAAM,IAAI,GAAG,QAAQ,CAAC,CAAC,GAAG,YAAY,CAAC,CAAC;QACxC,QAAQ,IAAI,IAAI,CAAC;QACjB,SAAS,IAAI,IAAI,CAAC;KACnB;IAED,YAAY,IAAI,SAAS,CAAC;IAC1B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE;QACnC,MAAM,IAAI,GAAG,QAAQ,CAAC,CAAC,GAAG,YAAY,CAAC,CAAC;QACxC,SAAS,IAAI,IAAI,CAAC;QAClB,SAAS,IAAI,IAAI,CAAC;QAClB,QAAQ,IAAI,IAAI,CAAC;KAClB;IAED,MAAM,QAAQ,GAAG,GAAG,QAAQ,IAAI,SAAS,KAAK,SAAS,EAAE,CAAC;IAC1D,OAAO,CAAC,QAAQ,EAAE,QAAQ,EAAE,SAAS,CAAC,MAAM,CAAC,CAAC;AAChD,CAAC;AAED,SAAS,cAAc,CACrB,UAAkB,EAAE,aAAuB;IAE3C,MAAM,WAAW,GACf,KAAK,CAAC,UAAU,GAAG,aAAa,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,MAAM,CAAC,aAAa,CAAC,CAAC;IAC5E,OAAO,WAAW,CAAC;AACrB,CAAC;AA2HD;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GA4DG;AACH,MAAa,kBAAmB,SAAQ,KAAK;IA8B3C,YAAY,IAA4B;;QACtC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;QAC5B,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,MAAM,CAAC;QAC1B,IAAI,CAAC,QAAQ,GAAG,MAAA,IAAI,CAAC,QAAQ,mCAAI,IAAI,CAAC,MAAM,CAAC;QAC7C,IAAI,CAAC,OAAO,GAAG,MAAA,IAAI,CAAC,OAAO,mCAAI,CAAC,CAAC;QACjC,IAAI,CAAC,OAAO,GAAG,MAAA,IAAI,CAAC,OAAO,mCAAI,IAAI,CAAC;QACpC,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,WAAW,CAAC;QACrC,IAAI,CAAC,iBAAiB,GAAG,cAAc,CACrC,MAAA,IAAI,CAAC,iBAAiB,mCAAI,eAAe,CAAC,CAAC;QAC7C,IAAI,CAAC,eAAe,GAAG,cAAc,CAAC,MAAA,IAAI,C
AAC,eAAe,mCAAI,OAAO,CAAC,CAAC;QACvE,IAAI,CAAC,iBAAiB,GAAG,cAAc,CAAC,IAAI,CAAC,iBAAiB,CAAC,CAAC;QAChE,IAAI,CAAC,eAAe,GAAG,cAAc,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;QAC5D,IAAI,CAAC,mBAAmB,GAAG,cAAc,CAAC,IAAI,CAAC,mBAAmB,CAAC,CAAC;QACpE,IAAI,CAAC,gBAAgB,GAAG,aAAa,CAAC,IAAI,CAAC,gBAAgB,CAAC,CAAC;QAC7D,IAAI,CAAC,cAAc,GAAG,aAAa,CAAC,IAAI,CAAC,cAAc,CAAC,CAAC;QACzD,IAAI,IAAI,CAAC,aAAa,IAAI,IAAI,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,aAAa,CAAC,EAAE;YACpE,IAAI,CAAC,aAAa,GAAG,CAAC,IAAI,CAAC,aAAa,CAAC,CAAC;SAC3C;aAAM;YACL,IAAI,CAAC,aAAa,GAAG,IAAI,CAAC,aAAyB,CAAC;SACrD;QACD,IAAI,CAAC,kBAAkB,GAAG,KAAK,CAAC;QAChC,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC;QACvB,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC;QACrB,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC;IACzB,CAAC;IAED;;OAEG;IACH,IAAI,WAAW;QACb,OAAO,IAAI,CAAC,UAAU,CAAC;IACzB,CAAC;IAED;;OAEG;IACH,IAAI,SAAS;QACX,OAAO,IAAI,CAAC,QAAQ,CAAC;IACvB,CAAC;IAED;;OAEG;IACH,IAAI,WAAW;QACb,OAAO,IAAI,CAAC,UAAU,CAAC;IACzB,CAAC;IAED;;OAEG;IACH,IAAI,YAAY;QACd,OAAO,IAAI,CAAC,WAAW,CAAC;IAC1B,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,WAAW,EAAE,IAAI,CAAC,YAAY;YAC9B,aAAa,EAAE,IAAI,CAAC,aAAa;YACjC,iBAAiB,EAAE,oBAAoB,CAAC,IAAI,CAAC,iBAAiB,CAAC;YAC/D,eAAe,EAAE,oBAAoB,CAAC,IAAI,CAAC,eAAe,CAAC;YAC3D,iBAAiB,EAAE,oBAAoB,CAAC,IAAI,CAAC,iBAAiB,CAAC;YAC/D,eAAe,EAAE,oBAAoB,CAAC,IAAI,CAAC,eAAe,CAAC;YAC3D,mBAAmB,EAAE,oBAAoB,CAAC,IAAI,CAAC,mBAAmB,CAAC;YACnE,gBAAgB,EAAE,mBAAmB,CAAC,IAAI,CAAC,gBAAgB,CAAC;YAC5D,cAAc,EAAE,mBAAmB,CAAC,IAAI,CAAC,cAAc,CAAC;YACxD,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,UAAU,EAAE,IAAI,CAAC,UAAU;SAC5B,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAED,MAAM,CAAU,UAAU,CACxB,GAA6C,EAC7C,MAAgC;QAEhC,kEAAkE;QAClE,6DAA6D;QAC7D,MAAM,UAAU,GAAG,MAAM,CAAC,YAAY,CAAU,CAAC;QACjD,MAAM,QAAQ,GAAG,MAAM,CAAC,UAAU,CAAU,CAAC;QAC7C,MAAM,UAAU,GAAG,MAAM,
CAAC,YAAY,CAAU,CAAC;QACjD,OAAO,MAAM,CAAC,YAAY,CAAC,CAAC;QAC5B,OAAO,MAAM,CAAC,UAAU,CAAC,CAAC;QAC1B,OAAO,MAAM,CAAC,YAAY,CAAC,CAAC;QAE5B,MAAM,KAAK,GAAG,IAAI,GAAG,CAAC,MAAM,CAAC,CAAC;QAC9B,IAAI,CAAC,UAAU,EAAE,QAAQ,EAAE,UAAU,CAAC,CAAC,QAAQ,CAAC,IAAI,CAAC,EAAE;YACnD,OAAO,CAAC,IAAI,CACR,sDAAsD;gBACtD,4DAA4D;gBAC5D,GAAG,GAAG,CAAC,QAAQ,EAAE,8BAA8B,CAClD,CAAC;SACL;aAAM;YACJ,KAAuC,CAAC,kBAAkB,CACzD,UAAU,EAAE,UAAU,EAAE,QAAQ,CAAC,CAAC;SACrC;QACD,OAAO,KAAK,CAAC;IACf,CAAC;IAED;;;;OAIG;IACH,kBAAkB,CAChB,UAAiB,EACjB,UAAiB,EACjB,QAAgB;QAEhB,IAAI,CAAC,kBAAkB,GAAG,IAAI,CAAC;QAE/B,IAAI,QAAQ,IAAI,IAAI,EAAE;YACpB,QAAQ,GAAG,UAAU,CAAC;SACvB;QAED,IAAI,CAAC,UAAU,GAAG,UAAU,CAAC;QAC7B,IAAI,CAAC,UAAU,GAAG,UAAU,CAAC;QAC7B,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QAEzB,2EAA2E;QAC3E,2DAA2D;QAC3D,MAAM,SAAS,GAAG,UAAU,CAAC,MAAM,CAAC;QACpC,MAAM,SAAS,GAAG,UAAU,CAAC,MAAM,CAAC;QACpC,MAAM,OAAO,GAAG,QAAQ,CAAC,MAAM,CAAC;QAEhC,MAAM,QAAQ,GAAG,SAAS,GAAG,CAAC,CAAC;QAC/B,IAAI,CAAC,cAAc,EAAE,QAAQ,EAAE,UAAU,CAAC,GACxC,uBAAuB,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QAC1C,IAAI,CAAC,UAAU,GAAG,IAAI,WAAW,iBAC/B,QAAQ,EAAE,cAAc,EACxB,WAAW,EAAE,cAAc,CAAC,UAAU,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,CAAC,CAAC,EACzE,QAAQ,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,IAAI,EACxC,IAAI,EAAE,OAAO,IACV,IAAI,CAAC,0BAA0B,EAAE,EACpC,CAAC;QAEH,CAAC,cAAc,EAAE,QAAQ,EAAE,UAAU,CAAC;YACpC,uBAAuB,CAAC,OAAO,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QAC7C,IAAI,CAAC,QAAQ,GAAG,IAAI,WAAW,iBAC7B,QAAQ,EAAE,cAAc,EACxB,WAAW,EAAE,cAAc,CAAC,UAAU,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,CAAC,CAAC,EACzE,QAAQ,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,IAAI,EACxC,IAAI,EAAE,KAAK,IACR,IAAI,CAAC,0BAA0B,EAAE,EACpC,CAAC;QAEH,CAAC,cAAc,EAAE,QAAQ,EAAE,UAAU,CAAC;YACpC,uBAAuB,CAAC,SAAS,GAAG,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QAC/C,IAAI,CAAC,UAAU,GAAG,IAAI,WAAW,iBAC/B,QAAQ,EAAE,cAAc,EACxB,WAAW,EAAE,cAAc,CACzB,UAAU,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,EACjD,QAAQ,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,QAAQ,
CAAC,CAAC,CAAC,IAAI,EACxC,IAAI,EAAE,OAAO,IACV,IAAI,CAAC,0BAA0B,EAAE,EACpC,CAAC;QAEH,0EAA0E;QAC1E,IAAI,CAAC,cAAc,CAAC,UAAU,CAAC,CAAC;QAChC,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,eAAe,CACrC,QAAQ,EACR,IAAI,CAAC,0BAA0B,EAAE,EACjC,iBAAiB,CAClB,CAAC;IACJ,CAAC;IAEO,0BAA0B;QAChC,sEAAsE;QACtE,sEAAsE;QACtE,4BAA4B;QAC5B,MAAM,iBAAiB,GAAG,cAAc,CAAC;YACvC,SAAS,EAAE,IAAI,CAAC,iBAAiB,CAAC,YAAY,EAAE;YAChD,MAAM,EAAE,IAAI,CAAC,iBAAiB,CAAC,SAAS,EAAE;SAC3C,CAAC,CAAC;QACH,MAAM,eAAe,GAAG,cAAc,CAAC;YACrC,SAAS,EAAE,IAAI,CAAC,eAAe,CAAC,YAAY,EAAE;YAC9C,MAAM,EAAE,IAAI,CAAC,eAAe,CAAC,SAAS,EAAE;SACzC,CAAC,CAAC;QAEH,MAAM,YAAY,GAAG;YACnB,iBAAiB;YACjB,eAAe;YACf,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,eAAe,EAAE,IAAI,CAAC,eAAe;YACrC,mBAAmB,EAAE,IAAI,CAAC,mBAAmB;YAC7C,gBAAgB,EAAE,IAAI,CAAC,gBAAgB;YACvC,cAAc,EAAE,IAAI,CAAC,cAAc;SACpC,CAAC;QACF,OAAO,YAAY,CAAC;IACtB,CAAC;IAED;;;;;;;OAOG;IACK,eAAe,CACrB,QAAgB,EAAE,YAAoB,EAAE,IAAa;QAErD,IAAI,WAAkB,CAAC;QACvB,IAAI,IAAI,CAAC,YAAY,EAAE;YACrB,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,YAAY,CAAC,EAAE;gBACrC,WAAW,GAAG,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC;aACnC;iBAAM;gBACL,WAAW,GAAG,IAAI,CAAC,YAAY,CAAC;aACjC;SACF;aAAM;YACL,WAAW,GAAG,CAAC,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC,CAAC;SAC7D;QAED,MAAM,CAAC,cAAc,EAAE,QAAQ,EAAE,UAAU,CAAC,GAC1C,uBAAuB,CAAC,QAAQ,EAAE,CAAC,EAAE,WAAW,CAAC,MAAM,CAAC,CAAC;QAE3D,OAAO,IAAI,WAAW,iBACpB,QAAQ,EAAE,cAAc,EACxB,WAAW,EAAE,cAAc,CAAC,UAAU,GAAG,CAAC,EAAE,WAAW,CAAC,EACxD,QAAQ,EAAE,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,QAAQ,CAAC,CAAC,CAAC,IAAI,EACxC,IAAI,IACD,YAAY,EACf,CAAC;IACL,CAAC;IAED;;;;;;;;OAQG;IACO,cAAc,CAAC,IAAY;QACnC,IAAI,IAAI,CAAC,aAAa,IAAI,IAAI,EAAE;YAC9B,IAAI,CAAC,aAAa,GAAG,EAAE,CAAC;YACxB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE;gBACjC,IAAI,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;aAC5B;SACF;aAAM;YACL,IAAI,CAAC,aAAa,GAAG,CAAC,GAAG,IAAI,CAAC,aAAa,CAAC,CAAC;SAC9C;QAED,MAAM,CAAC,kBAAkB,EAAE,eAAe,EAAE,cAAc,CAAC,GACzD,sBAAsB,CAAC,IAAI,EAAE,IAAI,CAAC,aAAa,CAAC,CAAC;QACnD,IAAI,CAAC,kBAAkB,GAAG,kBAAkB,CAAC;Q
AC7C,IAAI,CAAC,eAAe,GAAG,eAAe,CAAC;QAEvC,MAAM,QAAQ,GAAa,EAAE,CAAC;QAC9B,MAAM,QAAQ,GAAG,cAAc,GAAG,IAAI,CAAC,aAAa,CAAC,MAAM,CAAC;QAC5D,KAAK,IAAI,CAAC,GAAG,QAAQ,EAAE,CAAC,GAAG,cAAc,EAAE,CAAC,EAAE,EAAE;YAC9C,QAAQ,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;SAClB;QACD,IAAI,CAAC,OAAO,GAAG,IAAI,OAAO,CAAC,EAAC,IAAI,EAAE,QAAQ,EAAC,CAAC,CAAC;QAC7C,IAAI,CAAC,YAAY,GAAG,IAAI,OAAO,CAAC,EAAC,IAAI,EAAE,IAAI,CAAC,OAAO,EAAC,CAAC,CAAC;IACxD,CAAC;IAES,aAAa,CACrB,eAAuB,EAAE,aAAsB;QAE/C,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,mDAAmD;YACnD,mCAAmC;YACnC,IAAI,aAAa,IAAI,IAAI,EAAE;gBACzB,iEAAiE;gBACjE,kEAAkE;gBAClE,MAAM,iBAAiB,GAAG,CAAC,IAAI,CAAC,aAAa,CAAC,MAAM,GAAG,CAAC,GAAG,CAAC,CAAC;gBAC7D,MAAM,MAAM,GACV,eAAe,CAAC,KAAK,CAAC,MAAM,GAAG,aAAa,CAAC,KAAK,CAAC,MAAM,CAAC;gBAC5D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,MAAM,EAAE,CAAC,EAAE,EAAE;oBAC/B,aAAa,GAAG,UAAU,CAAC,aAAa,EAAE,iBAAiB,CAAC,CAAC;iBAC9D;aACF;YACD,OAAO,IAAI,CAAC,OAAO,CAAC,KAAK,CACvB,eAAe,EAAE,EAAC,IAAI,EAAE,aAAa,EAAC,CAAW,CAAC;QACtD,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;;;;;;;;;;;;;;;;OAkBG;IACO,gBAAgB,CACxB,KAAa,EACb,GAAW,EACX,KAAa,EACb,aAAsB,EACtB,QAAkB;QAElB,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,uEAAuE;YACvE,mEAAmE;YACnE,kCAAkC;YAClC,KAAK,GAAG,GAAG,CAAC,KAAK,EAAE,GAAG,GAAG,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC;YAEjD,gEAAgE;YAChE,oBAAoB;YACpB,IAAI,eAAe,GAAG,MAAM,CAAC,IAAI,CAAC,kBAAkB,EAAE,GAAG,EAAE,KAAK,CAAC,CAAC;YAElE,eAAe,GAAG,IAAI,CAAC,aAAa,CAAC,eAAe,EAAE,aAAa,CAAC,CAAC;YAErE,wEAAwE;YACxE,wEAAwE;YACxE,MAAM,sBAAsB,GAC1B,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,eAAe,EAAE,EAAC,QAAQ,EAAC,CAAW,CAAC;YAEjE,gCAAgC;YAChC,MAAM,eAAe,GACnB,MAAM,CAAC,IAAI,CAAC,eAAe,EAAE,sBAAsB,EAAE,KAAK,CAAC,CAAC;YAE9D,OAAO,CAAC,eAAe,EAAE,eAAe,CAAC,CAAC;QAC5C,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,KAAK,CACZ,MAA+B,EAC/B,MAAe;;QAEf,IAAI,CAAC,MAAM,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE;YAC/B,MAAM,IAAI,UAAU,CAAC,4CAA4C,CAAC,CAAC;SACpE;QACD,IAAI,SAAoC,CAAC;QAEzC,SAAS,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC,MAAM,CAAC,MAAA,MAAM,CAAC,KAAK,CAAC,mCAAI,EAAE,CAAC,CAAC;QAElE,+CAA+C;QAC/C,OAAO,KAAK,CAAC,KAAK,CA
AC,SAAS,EAAE,MAAM,CAAC,CAAC;IACxC,CAAC;IAEQ,IAAI,CACX,KAAa,EAAE,MAAiC;QAEhD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,OAAO,IAAI,CAAC,4BAA4B,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;QAC7D,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;OAEG;IACH,4BAA4B,CAC1B,KAAa,EACb,EACE,KAAK,EACL,GAAG,EACH,aAAa,EACb,aAAa,EACb,QAAQ,EACkB;QAE5B,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,CAAC,IAAI,CAAC,kBAAkB,EAAE;gBAC5B,IAAI,CAAC,kBAAkB,CACrB,KAAK,CAAC,KAAK,EACX,KAAK,CAAC,KAAK,EACX,GAAG,CAAC,CAAC,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CACvB,CAAC;aACH;YACD,IAAI,GAAG,IAAI,IAAI,EAAE;gBACf,GAAG,GAAG,KAAK,CAAC;aACb;YAED,kDAAkD;YAElD,aAAa,GAAG,IAAI,CAAC,oBAAoB,CACvC,KAAK,EACL,KAAK,EACL,aAAa,EACb,aAAa,CACd,CAAC;YAEF,4BAA4B;YAC5B,sBAAsB;YACtB,yBAAyB;YACzB,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;YAE/C,uBAAuB;YACvB,GAAG,GAAG,IAAI,CAAC,QAAQ,CAAC,KAAK,CAAC,GAAG,CAAW,CAAC;YAEzC,yBAAyB;YACzB,KAAK,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAW,CAAC;YAE/C,MAAM,CAAC,uBAAuB,EAAE,eAAe,CAAC,GAAG,IAAI,CAAC,gBAAgB,CACtE,KAAK,EACL,GAAG,EACH,KAAK,EACL,aAAa,EACb,QAAQ,CACT,CAAC;YACF,MAAM,eAAe,GACnB,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,uBAAuB,CAAW,CAAC;YAE5D,OAAO,CAAC,eAAe,EAAE,eAAe,CAAC,CAAC;QAC5C,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;;;;;;;;;;;;;;;;;;;;;;;;;OA2BG;IACK,oBAAoB,CAC1B,KAAa,EACb,KAAa,EACb,aAAsB,EACtB,aAAa,GAAG,KAAK;QAErB,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,IAAI,QAAgB,CAAC;YAErB,MAAM,SAAS,GAAG,KAAK,CAAC,SAAS,CAAC;YAClC,MAAM,SAAS,GAAG,KAAK,CAAC,SAAS,CAAC;YAClC,IAAI,SAAS,IAAI,IAAI,EAAE;gBACrB,QAAQ,GAAG,SAAS,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,qBAAqB;aAC1D;YACD,IAAI,SAAS,IAAI,IAAI,EAAE;gBACrB,MAAM,IAAI,GAAG,SAAS,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC,qBAAqB;gBAC3D,QAAQ,GAAG,QAAQ,CAAC,CAAC,CAAC,UAAU,CAAC,QAAQ,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;aACzD;YACD,IAAI,aAAa,EAAE;gBACjB,4CAA4C;gBAC5C,MAAM,IAAI,GAAG,IAAI,CAAC,iBAAiB,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC;gBAClD,QAAQ,GAAG,QAAQ,CAAC,CAAC,CAAC,UAAU,CAAC,QAAQ,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC;aACzD;YACD,IAAI,QAAQ,IAAI,IAAI,EAAE;gBACpB,2DAA2D;gBAC3D,aAAa,GAAG,aAAa,CAAC,CAAC;oBAC7B,IAAI,CAAC
,aAAa,EAAE,MAAM,CAAC,CAAC,UAAU,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,QAAQ,CAAC;aAC/D;YAED,OAAO,aAAa,CAAC;QACvB,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;;;;;;;;;;;;;;;OAiBG;IACK,iBAAiB,CAAC,KAAa,EAAE,KAAc;QACrD,OAAO,IAAI,CAAC,GAAG,EAAE;YACf,MAAM,UAAU,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAClC,MAAM,UAAU,GAAG,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;YACvD,oCAAoC;YACpC,OAAO,MAAM,CAAC,QAAQ,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,UAAU,EAAE,UAAU,CAAC,EAAE,MAAM,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAC3E,CAAC,CAAC,CAAC;IACL,CAAC;IAED;;;;;OAKG;IACM,kBAAkB,CAAC,WAAuC;QACjE,MAAM,CAAC,UAAU,EAAE,UAAU,EAAE,aAAa,CAAC,GAAG,WAAW,CAAC;QAC5D,MAAM,QAAQ,GAAG,aAAa,aAAb,aAAa,cAAb,aAAa,GAAI,UAAU,CAAC;QAE7C,IAAI,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE;YACvD,MAAM,IAAI,UAAU,CAClB,qEAAqE;gBACrE,WAAW,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI;gBAClE,wBAAwB,UAAU,gBAAgB,UAAU,EAAE,CAC/D,CAAC;SACH;QAED,IAAI,CAAC,IAAI,CAAC,WAAW,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,QAAQ,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,EAAE;YACrE,MAAM,IAAI,KAAK,CACb,oEAAoE;gBACpE,mBAAmB,UAAU,QAAQ,QAAQ,EAAE,CAChD,CAAC;SACH;QAED,IAAI,IAAI,CAAC,YAAY,EAAE;YACrB,OAAO,UAAU,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,IAAI,CAAC,YAAY,CAAC,CAAC;SAC1D;QAED,OAAO,UAAU,CAAC;IACpB,CAAC;;AAxjBD,kBAAkB;AACF,4BAAS,GAAG,oBAAoB,CAAC;SAFtC,kBAAkB;AA2jB/B,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  TFJS-based multi-head attention layer.\n */\n\n/* Original source: keras/layers/attention/multi_head_attention.py */\nimport { Tensor, einsum, linalg, logicalAnd, mul, ones, serialization, tidy, util } from '@tensorflow/tfjs-core';\n\nimport { cast, expandDims } from '../../backend/tfjs_backend';\nimport { Constraint, ConstraintIdentifier, getConstraint, serializeConstraint } from '../../constraints';\nimport { Layer, LayerArgs, SymbolicTensor } from '../../engine/topology';\nimport { ValueError } from '../../errors';\nimport { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../initializers';\nimport { Shape } from '../../keras_format/common';\nimport { Regularizer, RegularizerIdentifier, getRegularizer, serializeRegularizer } from '../../regularizers';\nimport { Kwargs } from '../../types';\nimport { Softmax } from '../advanced_activations';\nimport { Dropout } from '../core';\nimport { EinsumDense } from './einsum_dense';\n\nconst _CHR_IDX = 'abcdefghijklmnopqrstuvwxyz'.split('');\n/**\n * Builds einsum equations for the attention computation.\n *\n * Query, key, value inputs after projection are expected to have the shape as:\n * `(bs, <non-attention dims>, <attention dims>, numHeads, channels)`.\n * `bs` and `<non-attention dims>` are treated as `<batch dims>`.\n *\n * The attention operations can be generalized:\n * (1) Query-key dot product:\n * `(<batch dims>, <query attention dims>, numHeads, channels), (<batch dims>,\n * <key attention dims>, numHeads, channels) -> (<batch dims>,\n * numHeads, <query attention dims>, <key attention dims>)`\n * (2) Combination:\n * `(<batch dims>, numHeads, <query attention dims>, <key attention dims>),\n * (<batch dims>, <value 
attention dims>, numHeads, channels) -> (<batch\n * dims>, <query attention dims>, numHeads, channels)`\n *\n * @param rank Rank of query, key, value tensors.\n * @param attnAxes Array of axes, `[-1, rank)`,\n *    that attention will be applied to.\n * @returns Einsum equations.\n */\nfunction buildAttentionEquation(\n  rank: number, attnAxes: number[]\n): [string, string, number] {\n  const targetNotationArr = _CHR_IDX.slice(0, rank);\n  // `batchDims` includes the head dim.\n  const excludeIndices = [...attnAxes, rank - 1];\n  const batchDims = [];\n  for (const e of Array(rank).keys()) {\n    if (!excludeIndices.includes(e)) {\n      batchDims.push(e);\n    }\n  }\n  let letterOffset = rank;\n  let sourceNotation = '';\n  for (let i = 0; i < rank; i++) {\n    if (batchDims.includes(i) || i === rank - 1) {\n      sourceNotation += targetNotationArr[i];\n    } else {\n      sourceNotation += _CHR_IDX[letterOffset];\n      letterOffset++;\n    }\n  }\n\n  const productNotation =\n    batchDims.map(i => targetNotationArr[i]).concat(\n    attnAxes.map(i => targetNotationArr[i]),\n    attnAxes.map(i => sourceNotation[i]),\n  ).join('');\n  const targetNotation = targetNotationArr.join('');\n\n  const dotProductEquation =\n    `${sourceNotation},${targetNotation}->${productNotation}`;\n  const attnScoresRank = productNotation.length;\n  const combineEquation =\n    `${productNotation},${sourceNotation}->${targetNotation}`;\n\n  return [dotProductEquation, combineEquation, attnScoresRank];\n}\n\n/**\n * Builds an einsum equation for projections inside multi-head attention.\n */\nfunction buildProjectionEquation(\n  freeDims: number, boundDims: number, outputDims: number\n): [string, string, number] {\n  let inputStr = '';\n  let kernelStr = '';\n  let outputStr = '';\n  let biasAxes = '';\n  let letterOffset = 0;\n\n  for (let i = 0; i < freeDims; i++) {\n    const char = _CHR_IDX[i + letterOffset];\n    inputStr += char;\n    outputStr += char;\n  }\n\n  letterOffset 
+= freeDims;\n  for (let i = 0; i < boundDims; i++) {\n    const char = _CHR_IDX[i + letterOffset];\n    inputStr += char;\n    kernelStr += char;\n  }\n\n  letterOffset += boundDims;\n  for (let i = 0; i < outputDims; i++) {\n    const char = _CHR_IDX[i + letterOffset];\n    kernelStr += char;\n    outputStr += char;\n    biasAxes += char;\n  }\n\n  const equation = `${inputStr},${kernelStr}->${outputStr}`;\n  return [equation, biasAxes, outputStr.length];\n}\n\nfunction getOutputShape(\n  outputRank: number, knownLastDims: number[]\n): Shape {\n  const outputShape =\n    Array(outputRank - knownLastDims.length).fill(null).concat(knownLastDims);\n  return outputShape;\n}\n\nexport declare interface MultiHeadAttentionArgs extends LayerArgs {\n  /**\n   * Integer. Number of attention heads.\n   */\n  numHeads: number;\n\n  /**\n   * Integer. Size of each attention head for query and key.\n   */\n  keyDim: number;\n\n  /**\n   * Integer. Size of each attention head for value.\n   * Defaults to `keyDim`.\n   */\n  valueDim?: number;\n\n  /**\n   * Dropout probability.\n   * Defaults to 0.0.\n   */\n  dropout?: number;\n\n  /**\n   * Whether the dense layers use bias vectors/matrices.\n   * Defaults to true.\n   */\n  useBias?: boolean;\n\n  /**\n   * The expected shape of an output tensor, besides the batch\n   * and sequence dims. If not specified, projects back to the query\n   * feature dim (the query input's last dimension).\n   */\n  outputShape?: Shape;\n\n  /**\n   * Axes over which the attention is applied. 
`null` means attention over\n   * all axes, but batch, heads, and features.\n   */\n  attentionAxes?: number[]|number;\n\n  /**\n   * Initializer for dense layer kernels.\n   * Defaults to `\"glorotUniform\"`.\n   */\n  kernelInitializer?: Initializer|InitializerIdentifier;\n\n  /**\n   * Initializer for dense layer biases.\n   * Defaults to `\"zeros\"`.\n   */\n  biasInitializer?: Initializer|InitializerIdentifier;\n\n  /**\n   * Regularizer for dense layer kernels.\n   */\n  kernelRegularizer?: Regularizer|RegularizerIdentifier;\n\n  /**\n   * Regularizer for dense layer biases.\n   */\n  biasRegularizer?: Regularizer|RegularizerIdentifier;\n\n  /**\n   * Regularizer for dense layer activity.\n   */\n  activityRegularizer?: Regularizer|RegularizerIdentifier;\n\n  /**\n   * Constraint for dense layer kernels.\n   */\n  kernelConstraint?: Constraint|ConstraintIdentifier;\n\n  /**\n   * Constraint for dense layer kernels.\n   */\n  biasConstraint?: Constraint|ConstraintIdentifier;\n}\n\nexport declare interface MultiHeadAttentionOptions {\n  /**\n   * Query `Tensor` of shape `(B, T, dim)`.\n   */\n\n  /**\n   * Value `Tensor` of shape `(B, S, dim)`.\n   */\n  value: Tensor;\n\n  /**\n   * Key `Tensor` of shape `(B, S, dim)`. If not given, will use `value` for\n   * both `key` and `value`, which is the most common case.\n   */\n  key?: Tensor;\n\n  /**\n   * A boolean mask of shape `(B, T, S)`, that prevents\n   * attention to certain positions. The boolean mask specifies which\n   * query elements can attend to which key elements, 1 indicates\n   * attention and 0 indicates no attention. 
Broadcasting can happen for\n   * the missing batch dimensions and the head dimension.\n   */\n  attentionMask?: Tensor;\n\n  /**\n   * Indicates whether the layer should behave in training mode\n   * (adding dropout) or in inference mode (no dropout).\n   * Will go with either using the training mode of the parent\n   * layer/model, or false (inference) if there is no parent layer.\n   */\n  training?: boolean;\n\n  /**\n   * Indicates whether to apply a causal mask to prevent tokens from attending\n   * to future tokens (e.g., used in a decoder Transformer).\n   * Defaults to false.\n   */\n  useCausalMask?: boolean;\n}\n\n/**\n * MultiHeadAttention layer.\n *\n * This is an implementation of multi-headed attention as described in the\n * paper \"Attention is all you Need\" (Vaswani et al., 2017).\n * If `query`, `key,` `value` are the same, then\n * this is self-attention. Each timestep in `query` attends to the\n * corresponding sequence in `key`, and returns a fixed-width vector.\n *\n * This layer first projects `query`, `key` and `value`. These are\n * (effectively) a list of tensors of length `numAttentionHeads`, where the\n * corresponding shapes are `(batchSize, <query dimensions>, keyDim)`,\n * `(batchSize, <key/value dimensions>, keyDim)`,\n * `(batchSize, <key/value dimensions>, valueDim)`.\n *\n * Then, the query and key tensors are dot-producted and scaled. These are\n * softmaxed to obtain attention probabilities. 
The value tensors are then\n * interpolated by these probabilities, then concatenated back to a single\n * tensor.\n *\n * Finally, the result tensor with the last dimension as valueDim can take an\n * linear projection and return.\n *\n * When using `MultiHeadAttention` inside a custom layer, the custom layer must\n * implement its own `build()` method and call `MultiHeadAttention`'s\n * `buildFromSignature()` there.\n * This enables weights to be restored correctly when the model is loaded.\n *\n * Examples:\n *\n * Performs 1D cross-attention over two sequence inputs with an attention mask.\n * Returns the additional attention weights over heads.\n *\n * ```js\n * const layer = new MultiHeadAttention({numHeads: 2, keyDim: 2});\n * const target = tf.input({shape: [8, 16]});\n * const source = tf.input({shape: [4, 16]});\n * const outputTensor, weights = layer.callAndReturnAttentionScores(\n *     target, {value: source});\n * console.log(outputTensor.shape);  // [null, 8, 16]\n * console.log(weights.shape);  // [null, 2, 8, 4]\n * ```\n *\n * Performs 2D self-attention over a 5D input tensor on axes 2 and 3.\n *\n * ```js\n * const layer = new MultiHeadAttention({\n *    numHeads: 2, keyDim: 2, attentionAxes: [2, 3]});\n * const inputTensor = tf.input({shape: [5, 3, 4, 16]});\n * const outputTensor = layer.call(inputTensor, {value: inputTensor});\n * console.log(outputTensor.shape);  // [null, 5, 3, 4, 16]\n * ```\n *\n * Returns:\n *    attentionOutput: The result of the computation, of shape `(B, T, E)`,\n *        where `T` is for target sequence shapes and `E` is the query input\n *        last dimension if `outputShape` is `None`. 
Otherwise, the\n *        multi-head outputs are projected to the shape specified by\n *        `outputShape`.\n *    attentionScores: multi-head attention coefficients over attention axes.\n */\nexport class MultiHeadAttention extends Layer {\n  /** @nocollapse */\n  static readonly className = 'MultiHeadAttention';\n\n  protected readonly numHeads: number;\n  protected readonly keyDim: number;\n  protected readonly valueDim: number;\n  protected readonly dropout: number;\n  protected readonly useBias: boolean;\n  protected readonly _outputShape: Shape;\n  protected readonly kernelInitializer: Initializer;\n  protected readonly biasInitializer: Initializer;\n  protected readonly kernelRegularizer: Regularizer;\n  protected readonly biasRegularizer: Regularizer;\n  protected readonly kernelConstraint: Constraint;\n  protected readonly biasConstraint: Constraint;\n  protected dotProductEquation: string;\n  protected combineEquation: string;\n  protected attentionAxes: number[];\n  protected builtFromSignature: boolean;\n  protected softmax: Softmax;\n  protected dropoutLayer: Dropout;\n  protected queryShape: Shape;\n  protected keyShape: Shape;\n  protected valueShape: Shape;\n  protected queryDense: EinsumDense;\n  protected keyDense: EinsumDense;\n  protected valueDense: EinsumDense;\n  protected outputDense: EinsumDense;\n\n  constructor(args: MultiHeadAttentionArgs) {\n    super(args);\n    this.supportsMasking = true;\n    this.numHeads = args.numHeads;\n    this.keyDim = args.keyDim;\n    this.valueDim = args.valueDim ?? args.keyDim;\n    this.dropout = args.dropout ?? 0;\n    this.useBias = args.useBias ?? true;\n    this._outputShape = args.outputShape;\n    this.kernelInitializer = getInitializer(\n      args.kernelInitializer ?? 'glorotUniform');\n    this.biasInitializer = getInitializer(args.biasInitializer ?? 
'zeros');\n    this.kernelRegularizer = getRegularizer(args.kernelRegularizer);\n    this.biasRegularizer = getRegularizer(args.biasRegularizer);\n    this.activityRegularizer = getRegularizer(args.activityRegularizer);\n    this.kernelConstraint = getConstraint(args.kernelConstraint);\n    this.biasConstraint = getConstraint(args.biasConstraint);\n    if (args.attentionAxes != null && !Array.isArray(args.attentionAxes)) {\n      this.attentionAxes = [args.attentionAxes];\n    } else {\n      this.attentionAxes = args.attentionAxes as number[];\n    }\n    this.builtFromSignature = false;\n    this.queryShape = null;\n    this.keyShape = null;\n    this.valueShape = null;\n  }\n\n  /**\n   * Should be used for testing purposes only.\n   */\n  get _queryDense() {\n    return this.queryDense;\n  }\n\n  /**\n   * Should be used for testing purposes only.\n   */\n  get _keyDense() {\n    return this.keyDense;\n  }\n\n  /**\n   * Should be used for testing purposes only.\n   */\n  get _valueDense() {\n    return this.valueDense;\n  }\n\n  /**\n   * Should be used for testing purposes only.\n   */\n  get _outputDense() {\n    return this.outputDense;\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      numHeads: this.numHeads,\n      keyDim: this.keyDim,\n      valueDim: this.valueDim,\n      dropout: this.dropout,\n      useBias: this.useBias,\n      outputShape: this._outputShape,\n      attentionAxes: this.attentionAxes,\n      kernelInitializer: serializeInitializer(this.kernelInitializer),\n      biasInitializer: serializeInitializer(this.biasInitializer),\n      kernelRegularizer: serializeRegularizer(this.kernelRegularizer),\n      biasRegularizer: serializeRegularizer(this.biasRegularizer),\n      activityRegularizer: serializeRegularizer(this.activityRegularizer),\n      kernelConstraint: serializeConstraint(this.kernelConstraint),\n      biasConstraint: serializeConstraint(this.biasConstraint),\n      queryShape: 
this.queryShape,\n      keyShape: this.keyShape,\n      valueShape: this.valueShape,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  static override fromConfig<T extends serialization.Serializable>(\n    cls: serialization.SerializableConstructor<T>,\n    config: serialization.ConfigDict\n  ): T {\n    // If the layer has a different build() function from the default,\n    // we need to trigger the customized build to create weights.\n    const queryShape = config['queryShape'] as Shape;\n    const keyShape = config['keyShape'] as Shape;\n    const valueShape = config['valueShape'] as Shape;\n    delete config['queryShape'];\n    delete config['keyShape'];\n    delete config['valueShape'];\n\n    const layer = new cls(config);\n    if ([queryShape, keyShape, valueShape].includes(null)) {\n        console.warn(\n            'One of dimensions of the input shape is missing. It ' +\n            'should have been memorized when the layer was serialized. 
' +\n            `${cls.toString()} is created without weights.`\n        );\n    } else {\n      (layer as unknown as MultiHeadAttention).buildFromSignature(\n        queryShape, valueShape, keyShape);\n    }\n    return layer;\n  }\n\n  /**\n   * Builds layers and variables.\n   *\n   * Once the method is called, this.builtFromSignature will be set to true.\n   */\n  buildFromSignature(\n    queryShape: Shape,\n    valueShape: Shape,\n    keyShape?: Shape\n  ) {\n    this.builtFromSignature = true;\n\n    if (keyShape == null) {\n      keyShape = valueShape;\n    }\n\n    this.queryShape = queryShape;\n    this.valueShape = valueShape;\n    this.keyShape = keyShape;\n\n    // Not using SymbolicTensors since tf.input() adds a batch dimension to the\n    // given shape, therefore giving the tensor the wrong rank.\n    const queryRank = queryShape.length;\n    const valueRank = valueShape.length;\n    const keyRank = keyShape.length;\n\n    const freeDims = queryRank - 1;\n    let [einsumEquation, biasAxes, outputRank] =\n      buildProjectionEquation(freeDims, 1, 2);\n    this.queryDense = new EinsumDense({\n      equation: einsumEquation,\n      outputShape: getOutputShape(outputRank - 1, [this.numHeads, this.keyDim]),\n      biasAxes: this.useBias ? biasAxes : null,\n      name: 'query',\n      ...this.getCommonKwargsForSublayer(),\n    });\n\n    [einsumEquation, biasAxes, outputRank] =\n      buildProjectionEquation(keyRank - 1, 1, 2);\n    this.keyDense = new EinsumDense({\n      equation: einsumEquation,\n      outputShape: getOutputShape(outputRank - 1, [this.numHeads, this.keyDim]),\n      biasAxes: this.useBias ? 
biasAxes : null,\n      name: 'key',\n      ...this.getCommonKwargsForSublayer(),\n    });\n\n    [einsumEquation, biasAxes, outputRank] =\n      buildProjectionEquation(valueRank - 1, 1, 2);\n    this.valueDense = new EinsumDense({\n      equation: einsumEquation,\n      outputShape: getOutputShape(\n        outputRank - 1, [this.numHeads, this.valueDim]),\n      biasAxes: this.useBias ? biasAxes : null,\n      name: 'value',\n      ...this.getCommonKwargsForSublayer(),\n    });\n\n    // Builds the attention computations for multi-head dot product attention.\n    this.buildAttention(outputRank);\n    this.outputDense = this.makeOutputDense(\n      freeDims,\n      this.getCommonKwargsForSublayer(),\n      'attentionOutput'\n    );\n  }\n\n  private getCommonKwargsForSublayer(): Kwargs {\n    // Create new clone of kernel/bias initializer, so that we don't reuse\n    // the initializer instance, which could lead to same init value since\n    // initializer is stateless.\n    const kernelInitializer = getInitializer({\n      className: this.kernelInitializer.getClassName(),\n      config: this.kernelInitializer.getConfig(),\n    });\n    const biasInitializer = getInitializer({\n      className: this.biasInitializer.getClassName(),\n      config: this.biasInitializer.getConfig(),\n    });\n\n    const commonKwargs = {\n      kernelInitializer,\n      biasInitializer,\n      kernelRegularizer: this.kernelRegularizer,\n      biasRegularizer: this.biasRegularizer,\n      activityRegularizer: this.activityRegularizer,\n      kernelConstraint: this.kernelConstraint,\n      biasConstraint: this.biasConstraint,\n    };\n    return commonKwargs;\n  }\n\n  /**\n   * Builds the output projection matrix.\n   *\n   * @param freeDims Number of free dimensions for einsum equation building.\n   * @param commonKwargs Common keyword arguments for einsum layer.\n   * @param name Name for the projection layer.\n   * @returns Projection layer.\n   */\n  private makeOutputDense(\n    
freeDims: number, commonKwargs: Kwargs, name?: string\n  ): EinsumDense {\n    let outputShape: Shape;\n    if (this._outputShape) {\n      if (!Array.isArray(this._outputShape)) {\n        outputShape = [this._outputShape];\n      } else {\n        outputShape = this._outputShape;\n      }\n    } else {\n      outputShape = [this.queryShape[this.queryShape.length - 1]];\n    }\n\n    const [einsumEquation, biasAxes, outputRank] =\n      buildProjectionEquation(freeDims, 2, outputShape.length);\n\n    return new EinsumDense({\n      equation: einsumEquation,\n      outputShape: getOutputShape(outputRank - 1, outputShape),\n      biasAxes: this.useBias ? biasAxes : null,\n      name,\n      ...commonKwargs,\n    });\n  }\n\n  /**\n   * Builds multi-head dot-product attention computations.\n   *\n   * This function builds attributes necessary for `computeAttention` to\n   * customize attention computation to replace the default dot-product\n   * attention.\n   *\n   * @param rank The rank of query, key, value tensors.\n   */\n  protected buildAttention(rank: number) {\n    if (this.attentionAxes == null) {\n      this.attentionAxes = [];\n      for (let i = 1; i < rank - 2; i++) {\n        this.attentionAxes.push(i);\n      }\n    } else {\n      this.attentionAxes = [...this.attentionAxes];\n    }\n\n    const [dotProductEquation, combineEquation, attnScoresRank] =\n      buildAttentionEquation(rank, this.attentionAxes);\n    this.dotProductEquation = dotProductEquation;\n    this.combineEquation = combineEquation;\n\n    const normAxes: number[] = [];\n    const startIdx = attnScoresRank - this.attentionAxes.length;\n    for (let i = startIdx; i < attnScoresRank; i++) {\n      normAxes.push(i);\n    }\n    this.softmax = new Softmax({axis: normAxes});\n    this.dropoutLayer = new Dropout({rate: this.dropout});\n  }\n\n  protected maskedSoftmax(\n    attentionScores: Tensor, attentionMask?: Tensor\n  ): Tensor {\n    return tidy(() => {\n      // Normalize the 
attention scores to probabilities.\n      // `attentionScores` = [B, N, T, S]\n      if (attentionMask != null) {\n        // The expand dim happens starting from the `numHeads` dimension,\n        // (<batchDims>, numHeads, <queryAttentionDims, keyAttentionDims>)\n        const maskExpansionAxis = -this.attentionAxes.length * 2 - 1;\n        const endIdx =\n          attentionScores.shape.length - attentionMask.shape.length;\n        for (let _ = 0; _ < endIdx; _++) {\n          attentionMask = expandDims(attentionMask, maskExpansionAxis);\n        }\n      }\n      return this.softmax.apply(\n        attentionScores, {mask: attentionMask}) as Tensor;\n    });\n  }\n\n  /**\n   * Applies Dot-product attention with query, key, value tensors.\n   *\n   * This function defines the computation inside `call` with projected\n   * multi-head Q, K, V inputs. Users can override this function for\n   * customized attention implementation.\n   *\n   * @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`.\n   * @param key  Projected key `Tensor` of shape `(B, S, N, keyDim)`.\n   * @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`.\n   * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents\n   *    attention to certain positions. 
It is generally not needed if\n   *    the `query` and `value` (and/or `key`) are masked.\n   * @param training Boolean indicating whether the layer should behave\n   *    in training mode (adding dropout) or in inference mode (doing\n   *    nothing).\n   * @returns attentionOutput: Multi-headed outputs of attention computation.\n   * @returns attentionScores: Multi-headed attention weights.\n   */\n  protected computeAttention(\n    query: Tensor,\n    key: Tensor,\n    value: Tensor,\n    attentionMask?: Tensor,\n    training?: boolean\n  ): [Tensor, Tensor] {\n    return tidy(() => {\n      // Note: Applying scalar multiply at the smaller end of einsum improves\n      // XLA performance, but may introduce slight numeric differences in\n      // the Transformer attention head.\n      query = mul(query, 1.0 / Math.sqrt(this.keyDim));\n\n      // Take the dot product between \"query\" and \"key\" to get the raw\n      // attention scores.\n      let attentionScores = einsum(this.dotProductEquation, key, query);\n\n      attentionScores = this.maskedSoftmax(attentionScores, attentionMask);\n\n      // This is actually dropping out entire tokens to attend to, which might\n      // seem a bit unusual, but is taken from the original Transformer paper.\n      const attentionScoresDropout =\n        this.dropoutLayer.apply(attentionScores, {training}) as Tensor;\n\n      // `contextLayer` = [B, T, N, H]\n      const attentionOutput =\n        einsum(this.combineEquation, attentionScoresDropout, value);\n\n      return [attentionOutput, attentionScores];\n    });\n  }\n\n  override apply(\n    inputs: Tensor | SymbolicTensor,\n    kwargs?: Kwargs\n  ): Tensor | Tensor[] | SymbolicTensor | SymbolicTensor[] {\n    if (!kwargs || !kwargs['value']) {\n      throw new ValueError('Must pass in `value` argument in `kwargs.`');\n    }\n    let newInputs: Tensor[]|SymbolicTensor[];\n\n    newInputs = [inputs, kwargs['value']].concat(kwargs['key'] ?? 
[]);\n\n    // TODO(pforderique): Support mask propogation.\n    return super.apply(newInputs, kwargs);\n  }\n\n  override call(\n    query: Tensor, kwargs: MultiHeadAttentionOptions\n  ): Tensor {\n    return tidy(() => {\n      return this.callAndReturnAttentionScores(query, kwargs)[0];\n    });\n  }\n\n  /**\n   * Exactly like `call` except also returns the attention scores.\n   */\n  callAndReturnAttentionScores(\n    query: Tensor,\n    {\n      value,\n      key,\n      useCausalMask,\n      attentionMask,\n      training\n    }: MultiHeadAttentionOptions\n  ): [Tensor, Tensor] {\n    return tidy(() => {\n      if (!this.builtFromSignature) {\n        this.buildFromSignature(\n          query.shape,\n          value.shape,\n          key ? key.shape : null\n        );\n      }\n      if (key == null) {\n        key = value;\n      }\n\n      // TODO(pforderique): Support RaggedTensor inputs.\n\n      attentionMask = this.computeAttentionMask(\n        query,\n        value,\n        attentionMask,\n        useCausalMask,\n      );\n\n      //   N = `numAttentionHeads`\n      //   H = `sizePerHead`\n      // `query` = [B, T, N ,H]\n      query = this.queryDense.apply(query) as Tensor;\n\n      // `key` = [B, S, N, H]\n      key = this.keyDense.apply(key) as Tensor;\n\n      // `value` = [B, S, N, H]\n      value = this.valueDense.apply(value) as Tensor;\n\n      const [attentionOutputPreDense, attentionScores] = this.computeAttention(\n        query,\n        key,\n        value,\n        attentionMask,\n        training\n      );\n      const attentionOutput =\n        this.outputDense.apply(attentionOutputPreDense) as Tensor;\n\n      return [attentionOutput, attentionScores];\n    });\n  }\n\n  /**\n   * Computes the attention mask.\n   *\n   * * The `query`'s mask is reshaped from [B, T] to [B, T, 1].\n   * * The `value`'s mask is reshaped from [B, S] to [B, 1, S].\n   * * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. 
The `key`'s\n   *   mask is ignored if `key` is `None` or if `key is value`.\n   * * If `useCausalMask=true`, then the causal mask is computed. Its shape\n   *   is [1, T, S].\n   *\n   * All defined masks are merged using a logical AND operation (`&`).\n   *\n   * In general, if the `query` and `value` are masked, then there is no need\n   * to define the `attentionMask`.\n   *\n   * @param query Projected query `Tensor` of shape `(B, T, N, keyDim)`.\n   * @param key  Projected key `Tensor` of shape `(B, S, N, keyDim)`.\n   * @param value Projected value `Tensor` of shape `(B, S, N, valueDim)`.\n   * @param attentionMask A boolean mask of shape `(B, T, S)`, that prevents\n   *    attention to certain positions.\n   * @param useCausalMask  A boolean to indicate whether to apply a causal\n   *    mask to prevent tokens from attending to future tokens (e.g.,\n   *    used in a decoder Transformer).\n   * @returns attentionMask: A boolean mask of shape `(B, T, S)`, that prevents\n   *    attention to certain positions, based on the Keras masks of the\n   *    `query`, `key`, `value`, and `attentionMask` tensors, and the\n   *    causal mask if `useCausalMask=true`.\n   */\n  private computeAttentionMask(\n    query: Tensor,\n    value: Tensor,\n    attentionMask?: Tensor,\n    useCausalMask = false\n  ): Tensor {\n    return tidy(() => {\n      let autoMask: Tensor;\n\n      const queryMask = query.kerasMask;\n      const valueMask = value.kerasMask;\n      if (queryMask != null) {\n        autoMask = queryMask.expandDims(2); // Shape is [B, T, 1]\n      }\n      if (valueMask != null) {\n        const mask = valueMask.expandDims(1); // Shape is [B, 1, S]\n        autoMask = autoMask ? logicalAnd(autoMask, mask) : mask;\n      }\n      if (useCausalMask) {\n        // the shape of the causal mask is [1, T, S]\n        const mask = this.computeCausalMask(query, value);\n        autoMask = autoMask ? 
logicalAnd(autoMask, mask) : mask;\n      }\n      if (autoMask != null) {\n        // Merge attentionMask & automatic mask, to shape [B, T, S]\n        attentionMask = attentionMask ?\n          cast(attentionMask, 'bool').logicalAnd(autoMask) : autoMask;\n      }\n\n      return attentionMask;\n    });\n  }\n\n  /**\n   * Computes a causal mask (e.g., for masked self-attention layers).\n   *\n   * For example, if query and value both contain sequences of length 4,\n   * this function returns a boolean `Tensor` equal to:\n   *\n   * ```\n   * [[[true,  false, false, false],\n   *   [true,  true,  false, false],\n   *   [true,  true,  true,  false],\n   *   [true,  true,  true,  true]]]\n   * ```\n   *\n   * @param query query `Tensor` of shape `(B, T, ...)`.\n   * @param value value `Tensor` of shape `(B, S, ...)` (defaults to query).\n   * @returns mask: A boolean `Tensor` of shape [1, T, S] containing a lower\n   *    triangular matrix of shape [T, S].\n   */\n  private computeCausalMask(query: Tensor, value?: Tensor): Tensor {\n    return tidy(() => {\n      const qSeqLength = query.shape[1];\n      const vSeqLength = value ? value.shape[1] : qSeqLength;\n      // Create a lower triangular matrix.\n      return linalg.bandPart(ones([1, qSeqLength, vSeqLength], 'bool'), -1, 0);\n    });\n  }\n\n  /**\n   *\n   * @param inputShapes A list of [queryShape, valueShape] or\n   *    [queryShape, valueShape, keyShape]. If no keyShape provided, valueShape\n   *    is assumed as the keyShape.\n   */\n  override computeOutputShape(inputShapes: [Shape, Shape, Shape|null]): Shape {\n    const [queryShape, valueShape, maybeKeyShape] = inputShapes;\n    const keyShape = maybeKeyShape ?? valueShape;\n\n    if (queryShape.slice(-1)[0] !== valueShape.slice(-1)[0]) {\n      throw new ValueError(\n        `The last dimension of 'queryShape' and 'valueShape' must be equal, ` +\n        `but are ${queryShape.slice(-1)[0]}, ${valueShape.slice(-1)[0]}. 
` +\n        `Received: queryShape=${queryShape}, valueShape=${valueShape}`\n      );\n    }\n\n    if (!util.arraysEqual(valueShape.slice(1, -1), keyShape.slice(1, -1))) {\n      throw new Error(\n        `All dimensions of 'value' and 'key', except the last one, must be ` +\n        `equal. Received ${valueShape} and ${keyShape}`\n      );\n    }\n\n    if (this._outputShape) {\n      return queryShape.slice(0, -1).concat(this._outputShape);\n    }\n\n    return queryShape;\n  }\n}\nserialization.registerClass(MultiHeadAttention);\n"]}
|