gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 *  Utility functions for `TransformerDecoder`.
 */
/* Original source: keras_nlp/layers/modeling/transformer_layer_utils.py */
import { add, expandDims, tensor, tidy } from '@tensorflow/tfjs-core';
import { ValueError } from '../../../errors';
function checkMasksShapes(inputs, paddingMask, attentionMask) {
    if (paddingMask != null) {
        if (paddingMask.shape.length !== 2) {
            throw new ValueError('`paddingMask` should have shape ' +
                `[batchSize, targetLength]. Received shape ${paddingMask.shape}.`);
        }
    }
    if (attentionMask != null) {
        if (attentionMask.shape.length !== 3) {
            throw new ValueError('`attentionMask` should have shape ' +
                `[batchSize, targetLength, sourceLength]. ` +
                `Received shape ${attentionMask.shape}.`);
        }
    }
}
/**
 * Compute a causal attention mask for a transformer decoder.
 *
 * @param batchSize batch size for the mask.
 * @param inputLength the length of key/value tensors in the attention layer.
 * @param outputLength the length of query tensor in the attention layer.
 * @param cacheIndex the current index for cached generation. If passed, the
 *  query sequence will be considered to start at `cacheIndex` rather than zero.
 *  For example, a casual mask with `outputLength=1` and `cacheIndex=5` would
 *  allow the query tensor to attend to the first five positions of the
 *  key/value tensors.
 *
 * @returns a causal attention mask with shape
 *  `[batchSize, outputLength, inputLength]` that can be passed to a attention
 *  layer.
 */
export function computeCausalMask(batchSize, inputLength, outputLength, cacheIndex = 0) {
    return tidy(() => {
        const i = add(expandDims(Array.from({ length: outputLength }, (_, i) => i), 1), cacheIndex);
        const j = tensor(Array.from({ length: inputLength }, (_, i) => i));
        const mask = i.greaterEqual(j).cast('int32').expandDims(0);
        return mask.broadcastTo([batchSize, outputLength, inputLength]);
    });
}
/**
 * Merge the padding mask with a customized attention mask.
 *
 * @param inputs the input sequence.
 * @param paddingMask the 1D padding mask, of shape
 *          [batchSize, sequenceLength].
 * @param attentionMask the 2D customized mask, of shape
 *          [batchSize, sequenceLength, sequence2_length].
 * @returns
 *  A merged 2D mask or null. If only `paddingMask` is provided, the
 *  returned mask is paddingMask with one additional axis.
 */
export function mergePaddingAndAttentionMask(inputs, paddingMask, attentionMask) {
    return tidy(() => {
        checkMasksShapes(inputs, paddingMask, attentionMask);
        let mask;
        if (paddingMask != null) {
            // Add an axis for broadcasting, the attention mask should be 2D
            // (not including the batch axis).
            mask = paddingMask.expandDims(1).cast('int32');
        }
        if (attentionMask != null) {
            attentionMask = attentionMask.cast('int32');
            if (mask == null) {
                return attentionMask;
            }
            else {
                return mask.minimum(attentionMask);
            }
        }
        return mask;
    });
}
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"transformer_layer_utils.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/transformer_layer_utils.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,2EAA2E;AAC3E,OAAO,EAAU,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAE9E,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAE7C,SAAS,gBAAgB,CACrB,MAAc,EAAE,WAAmB,EAAE,aAAqB;IAC5D,IAAI,WAAW,IAAI,IAAI,EAAE;QACvB,IAAI,WAAW,CAAC,KAAK,CAAC,MAAM,KAAI,CAAC,EAAE;YACjC,MAAM,IAAI,UAAU,CAClB,kCAAkC;gBAClC,6CAA6C,WAAW,CAAC,KAAK,GAAG,CAClE,CAAC;SACH;KACF;IACD,IAAI,aAAa,IAAI,IAAI,EAAE;QACzB,IAAI,aAAa,CAAC,KAAK,CAAC,MAAM,KAAK,CAAC,EAAE;YACpC,MAAM,IAAI,UAAU,CAClB,oCAAoC;gBACpC,2CAA2C;gBAC3C,kBAAkB,aAAa,CAAC,KAAK,GAAG,CACzC,CAAC;SACH;KACF;AACH,CAAC;AAED;;;;;;;;;;;;;;;GAeG;AACH,MAAM,UAAU,iBAAiB,CAC7B,SAAiB,EACjB,WAAmB,EACnB,YAAoB,EACpB,UAAU,GAAG,CAAC;IAEhB,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,MAAM,CAAC,GAAG,GAAG,CACX,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,EAAC,MAAM,EAAE,YAAY,EAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAC9D,UAAU,CACX,CAAC;QACF,MAAM,CAAC,GAAG,MAAM,CAAC,KAAK,CAAC,IAAI,CAAC,EAAC,MAAM,EAAE,WAAW,EAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACjE,MAAM,IAAI,GAAG,CAAC,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;QAC3D,OAAO,IAAI,CAAC,WAAW,CAAC,CAAC,SAAS,EAAE,YAAY,EAAE,WAAW,CAAC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAC;AACL,CAAC;AAED;;;;;;;;;;;GAWG;AACH,MAAM,UAAU,4BAA4B,CACxC,MAAc,EAAE,WAAmB,EAAE,aAAqB;IAC5D,OAAO,IAAI,CAAC,GAAG,EAAE;QACf,gBAAgB,CAAC,MAAM,EAAE,WAAW,EAAE,aAAa,CAAC,CAAC;QACrD,IAAI,IAAY,CAAC;QACjB,IAAI,WAAW,IAAI,IAAI,EAAE;YACvB,gEAAgE;YAChE,kCAAkC;YAClC,IAAI,GAAG,WAAW,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;SAChD;QACD,IAAI,aAAa,IAAI,IAAI,EAAE;YACzB,aAAa,GAAG,aAAa,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAC5C,IAAI,IAAI,IAAI,IAAI,EAAE;gBAChB,OAAO,aAAa,CAAC;aACtB;iBAAM;gBACL,OAAO,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,CAAC;aACpC;SACF;QACD,OAAO,IAAI,CAAC;IACd,CAAC,CAAC,CAAC;AACL,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Utility functions for `TransformerDecoder`.\n */\n\n/* Original source: keras_nlp/layers/modeling/transformer_layer_utils.py */\nimport { Tensor, add, expandDims, tensor, tidy } from '@tensorflow/tfjs-core';\n\nimport { ValueError } from '../../../errors';\n\nfunction checkMasksShapes(\n    inputs: Tensor, paddingMask: Tensor, attentionMask: Tensor): void {\n  if (paddingMask != null) {\n    if (paddingMask.shape.length !==2) {\n      throw new ValueError(\n        '`paddingMask` should have shape ' +\n        `[batchSize, targetLength]. Received shape ${paddingMask.shape}.`\n      );\n    }\n  }\n  if (attentionMask != null) {\n    if (attentionMask.shape.length !== 3) {\n      throw new ValueError(\n        '`attentionMask` should have shape ' +\n        `[batchSize, targetLength, sourceLength]. ` +\n        `Received shape ${attentionMask.shape}.`\n      );\n    }\n  }\n}\n\n/**\n * Compute a causal attention mask for a transformer decoder.\n *\n * @param batchSize batch size for the mask.\n * @param inputLength the length of key/value tensors in the attention layer.\n * @param outputLength the length of query tensor in the attention layer.\n * @param cacheIndex the current index for cached generation. If passed, the\n *  query sequence will be considered to start at `cacheIndex` rather than zero.\n *  For example, a casual mask with `outputLength=1` and `cacheIndex=5` would\n *  allow the query tensor to attend to the first five positions of the\n *  key/value tensors.\n *\n * @returns a causal attention mask with shape\n *  `[batchSize, outputLength, inputLength]` that can be passed to a attention\n *  layer.\n */\nexport function computeCausalMask(\n    batchSize: number,\n    inputLength: number,\n    outputLength: number,\n    cacheIndex = 0\n  ): Tensor {\n  return tidy(() => {\n    const i = add(\n      expandDims(Array.from({length: outputLength}, (_, i) => i), 1),\n      cacheIndex,\n    );\n    const j = tensor(Array.from({length: inputLength}, (_, i) => i));\n    const mask = i.greaterEqual(j).cast('int32').expandDims(0);\n    return mask.broadcastTo([batchSize, outputLength, inputLength]);\n  });\n}\n\n/**\n * Merge the padding mask with a customized attention mask.\n *\n * @param inputs the input sequence.\n * @param paddingMask the 1D padding mask, of shape\n *          [batchSize, sequenceLength].\n * @param attentionMask the 2D customized mask, of shape\n *          [batchSize, sequenceLength, sequence2_length].\n * @returns\n *  A merged 2D mask or null. If only `paddingMask` is provided, the\n *  returned mask is paddingMask with one additional axis.\n */\nexport function mergePaddingAndAttentionMask(\n    inputs: Tensor, paddingMask: Tensor, attentionMask: Tensor): Tensor {\n  return tidy(() => {\n    checkMasksShapes(inputs, paddingMask, attentionMask);\n    let mask: Tensor;\n    if (paddingMask != null) {\n      // Add an axis for broadcasting, the attention mask should be 2D\n      // (not including the batch axis).\n      mask = paddingMask.expandDims(1).cast('int32');\n    }\n    if (attentionMask != null) {\n      attentionMask = attentionMask.cast('int32');\n      if (mask == null) {\n        return attentionMask;\n      } else {\n        return mask.minimum(attentionMask);\n      }\n    }\n    return mask;\n  });\n}\n"]}