gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 *  Position embedding implementation based on `tf.layers.Layer`.
 */
/* Original source: keras_nlp/layers/modeling/position_embedding.py */
import { serialization, tidy } from '@tensorflow/tfjs-core';
import { Layer } from '../../../engine/topology';
import { ValueError } from '../../../errors';
import { getInitializer, serializeInitializer } from '../../../initializers';
import { getExactlyOneTensor } from '../../../utils/types_utils';
/**
 * A layer which learns a position embedding for input sequences.
 *
 * This class assumes that in the input tensor, the last dimension corresponds
 * to the features, and the dimension before the last corresponds to the
 * sequence.
 *
 * Examples:
 *
 * Called directly on input.
 * ```js
 * const layer = new PositionEmbedding({sequenceLength: 10});
 * layer.apply(tf.zeros([8, 10, 16]));
 * ```
 *
 * Combine with a token embedding.
 * ```js
 * const seqLength = 50;
 * const vocabSize = 5000;
 * const embedDim = 128;
 * const inputs = tf.input({shape: [seqLength]});
 * const tokenEmbeddings = tf.layers.embedding({
 *     inputDim: vocabSize, outputDim: embedDim
 * }).apply(inputs);
 * const positionEmbeddings = new PositionEmbedding({
 *     sequenceLength: seqLength
 * }).apply(tokenEmbeddings);
 * const outputs = tf.layers.add().apply([tokenEmbeddings, positionEmbeddings]);
 * ```
 *
 * Reference:
 *  - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)
 */
class PositionEmbedding extends Layer {
    /**
     * @param args Layer arguments. `args.sequenceLength` is required and sets
     *     the number of rows of the learned embedding table. `args.initializer`
     *     is optional and falls back to `'glorotUniform'` when falsy.
     * @throws ValueError if `args.sequenceLength` is null or undefined.
     */
    constructor(args) {
        super(args);
        if (args.sequenceLength == null) {
            throw new ValueError('`sequenceLength` must be an Integer, received `null`.');
        }
        this.sequenceLength = args.sequenceLength;
        this.initializer = getInitializer(args.initializer || 'glorotUniform');
    }
    /**
     * Returns the layer config: `sequenceLength` and the serialized
     * `initializer`, merged with the base `Layer` config. Base-config keys
     * are assigned last, so they win on a key collision.
     */
    getConfig() {
        const config = {
            'sequenceLength': this.sequenceLength,
            'initializer': serializeInitializer(this.initializer),
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    /**
     * Creates the `'embeddings'` weight of shape
     * `[sequenceLength, featureSize]`, where `featureSize` is the last
     * dimension of `inputShape`.
     */
    build(inputShape) {
        const featureSize = inputShape[inputShape.length - 1];
        this.positionEmbeddings = this.addWeight('embeddings', [this.sequenceLength, featureSize], null, this.initializer, null, true);
        super.build(inputShape);
    }
    /**
     * Returns position embeddings broadcast to the shape of `inputs`.
     *
     * The rows used are `[startIndex, startIndex + seqLen)` of the embedding
     * table, where `seqLen` is the second-to-last dimension of the input.
     * Assumes the input has rank >= 2 (features last, sequence before it);
     * TODO confirm callers never pass rank-1 input.
     *
     * @param inputs A tensor, or an array resolved to exactly one tensor via
     *     `getExactlyOneTensor`.
     * @param kwargs Optional options; `kwargs.startIndex` defaults to 0.
     *     The object is only read, never modified.
     */
    call(inputs, kwargs) {
        // `kwargs` is optional, so guard before reading `startIndex`. The
        // default is kept in a local so the caller's object is not modified.
        const startIndex = (kwargs == null || kwargs.startIndex == null) ? 0 : kwargs.startIndex;
        return tidy(() => {
            const shape = getExactlyOneTensor(inputs).shape;
            const featureLength = shape[shape.length - 1];
            const sequenceLength = shape[shape.length - 2];
            // Trim to match the length of the input sequence, which might be
            // less than the sequenceLength of the layer.
            const positionEmbeddings = this.positionEmbeddings.read().slice([startIndex, 0], [sequenceLength, featureLength]);
            return positionEmbeddings.broadcastTo(shape);
        });
    }
    /** The output has the same shape as the input. */
    computeOutputShape(inputShape) {
        return inputShape;
    }
}
/** @nocollapse */
PositionEmbedding.className = 'PositionEmbedding';
export { PositionEmbedding };
// Register the class with tfjs-core's `serialization` registry.
serialization.registerClass(PositionEmbedding);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"position_embedding.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/modeling/position_embedding.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,sEAAsE;AACtE,OAAO,EAAU,aAAa,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAGpE,OAAO,EAAE,KAAK,EAAa,MAAM,0BAA0B,CAAC;AAC5D,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAC7C,OAAO,EAAsC,cAAc,EAAE,oBAAoB,EAAE,MAAM,uBAAuB,CAAC;AACjH,OAAO,EAAE,mBAAmB,EAAE,MAAM,4BAA4B,CAAC;AAwBjE;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAgCG;AACH,MAAa,iBAAkB,SAAQ,KAAK;IAO1C,YAAY,IAA2B;QACrC,KAAK,CAAC,IAAI,CAAC,CAAC;QACZ,IAAI,IAAI,CAAC,cAAc,IAAI,IAAI,EAAE;YAC/B,MAAM,IAAI,UAAU,CAClB,uDAAuD,CAAC,CAAC;SAC5D;QACD,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC;QAC1C,IAAI,CAAC,WAAW,GAAG,cAAc,CAAC,IAAI,CAAC,WAAW,IAAI,eAAe,CAAC,CAAC;IACzE,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,gBAAgB,EAAE,IAAI,CAAC,cAAc;YACrC,aAAa,EAAE,oBAAoB,CAAC,IAAI,CAAC,WAAW,CAAC;SACtD,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;IAEQ,KAAK,CAAC,UAAiB;QAC9B,MAAM,WAAW,GAAG,UAAU,CAAC,UAAU,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;QACtD,IAAI,CAAC,kBAAkB,GAAG,IAAI,CAAC,SAAS,CACtC,YAAY,EACZ,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,CAAC,EAClC,IAAI,EACJ,IAAI,CAAC,WAAW,EAChB,IAAI,EACJ,IAAI,CACL,CAAC;QACF,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC1B,CAAC;IAEQ,IAAI,CACX,MAAuB,EACvB,MAAiC;QAEjC,OAAO,IAAI,CAAC,GAAG,EAAE;;YACf,MAAM,CAAC,UAAU,GAAG,MAAA,MAAM,CAAC,UAAU,mCAAI,CAAC,CAAC;YAC3C,MAAM,KAAK,GAAG,mBAAmB,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC;YAChD,MAAM,aAAa,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC9C,MAAM,cAAc,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;YAC/C,sEAAsE;YACtE,yCAAyC;YACzC,MAAM,kBAAkB,GAAG,IAAI,CAAC,kBAAkB,CAAC,IAAI,EAAE,CAAC,KAAK,CAC7D,CAAC,MAAM,CAAC,UAAU,EAAE,CAAC,CAAC,EAAE,CAAC,cAAc,EAAE,aAAa,CAAC,CAAC,CAAC;YAC3D,OAAO,kBAAkB,CAAC,WAAW,CAAC,KAAK,CAAC,CAAC;QAC/C,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,kBAAkB,CAAC,UAAiB;QAC3C,OAAO,UAA
U,CAAC;IACpB,CAAC;;AA1DD,kBAAkB;AACF,2BAAS,GAAG,mBAAmB,CAAC;SAFrC,iBAAiB;AA6D9B,aAAa,CAAC,aAAa,CAAC,iBAAiB,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Position embedding implementation based on `tf.layers.Layer`.\n */\n\n/* Original source: keras_nlp/layers/modeling/position_embedding.py */\nimport { Tensor, serialization, tidy } from '@tensorflow/tfjs-core';\n\nimport { Shape } from '../../../keras_format/common';\nimport { Layer, LayerArgs } from '../../../engine/topology';\nimport { ValueError } from '../../../errors';\nimport { Initializer, InitializerIdentifier, getInitializer, serializeInitializer } from '../../../initializers';\nimport { getExactlyOneTensor } from '../../../utils/types_utils';\nimport { LayerVariable } from '../../../variables';\n\nexport declare interface PositionEmbeddingArgs extends LayerArgs {\n  /**\n   * Integer. The maximum length of the dynamic sequence.\n   */\n  sequenceLength: number;\n\n  /**\n   * The initializer to use for the embedding weights.\n   * Defaults to `\"glorotUniform\"`.\n   */\n  initializer?: Initializer|InitializerIdentifier;\n}\n\nexport declare interface PositionEmbeddingOptions {\n  /**\n   * Integer. 
Index to start the position embeddings at.\n   * Defaults to 0.\n   */\n  startIndex?: number;\n}\n\n/**\n * A layer which learns a position embedding for input sequences.\n *\n * This class assumes that in the input tensor, the last dimension corresponds\n * to the features, and the dimension before the last corresponds to the\n * sequence.\n *\n * Examples:\n *\n * Called directly on input.\n * ```js\n * const layer = new PositionEmbedding({sequenceLength=10});\n * layer.call(tf.zeros([8, 10, 16]));\n * ```\n *\n * Combine with a token embedding.\n * ```js\n * const seqLength = 50;\n * const vocabSize = 5000;\n * const embedDim = 128;\n * const inputs = tf.input({shape: [seqLength]});\n * const tokenEmbeddings = tf.layers.embedding({\n *     inputDim=vocabSize, outputDim=embedDim\n * }).apply(inputs);\n * const positionEmbeddings = new PositionEmbedding({\n *     sequenceLength: seqLength\n * }).apply(tokenEmbeddings);\n * const outputs = tf.add(tokenEmbeddings, positionEmbeddings);\n * ```\n *\n * Reference:\n *  - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)\n */\nexport class PositionEmbedding extends Layer {\n  /** @nocollapse */\n  static readonly className = 'PositionEmbedding';\n  private sequenceLength: number;\n  private initializer: Initializer;\n  protected positionEmbeddings: LayerVariable;\n\n  constructor(args: PositionEmbeddingArgs) {\n    super(args);\n    if (args.sequenceLength == null) {\n      throw new ValueError(\n        '`sequenceLength` must be an Integer, received `null`.');\n    }\n    this.sequenceLength = args.sequenceLength;\n    this.initializer = getInitializer(args.initializer || 'glorotUniform');\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      'sequenceLength': this.sequenceLength,\n      'initializer': serializeInitializer(this.initializer),\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n\n  override 
build(inputShape: Shape): void {\n    const featureSize = inputShape[inputShape.length - 1];\n    this.positionEmbeddings = this.addWeight(\n      'embeddings',\n      [this.sequenceLength, featureSize],\n      null,\n      this.initializer,\n      null,\n      true\n    );\n    super.build(inputShape);\n  }\n\n  override call(\n    inputs: Tensor|Tensor[],\n    kwargs?: PositionEmbeddingOptions\n  ): Tensor {\n    return tidy(() => {\n      kwargs.startIndex = kwargs.startIndex ?? 0;\n      const shape = getExactlyOneTensor(inputs).shape;\n      const featureLength = shape[shape.length - 1];\n      const sequenceLength = shape[shape.length - 2];\n      // trim to match the length of the input sequence, which might be less\n      // than the sequence_length of the layer.\n      const positionEmbeddings = this.positionEmbeddings.read().slice(\n        [kwargs.startIndex, 0], [sequenceLength, featureLength]);\n      return positionEmbeddings.broadcastTo(shape);\n    });\n  }\n\n  override computeOutputShape(inputShape: Shape): Shape {\n    return inputShape;\n  }\n}\nserialization.registerClass(PositionEmbedding);\n"]}