gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 *  Start End Packer implementation based on `tf.layers.Layer`.
 */
/* Original source: keras-nlp/start_end_packer.py */
import { Tensor, concat, serialization, stack, tensor, tidy } from '@tensorflow/tfjs-core';
import { Layer } from '../../../engine/topology';
import { ValueError } from '../../../errors';
/**
 * Adds start and end tokens to a sequence and pads to a fixed length.
 *
 *  This layer is useful when tokenizing inputs for tasks like translation,
 *  where each sequence should include a start and end marker. It should
 *  be called after tokenization. The layer prepends the start value, trims
 *  the sequence to leave room for the end value and appends it, and finally
 *  pads or truncates every sequence to exactly `sequenceLength` entries.
 *
 *  Input must be either a single rank-1 `tf.Tensor` or an array of rank-1
 *  `tf.Tensor`s. Any other rank raises a `ValueError`.
 */
class StartEndPacker extends Layer {
    /**
     * @param args Layer arguments. Reads `sequenceLength` (desired output
     *     length) and the optional `startValue`, `endValue` and `padValue`;
     *     the whole object is also forwarded to the `Layer` constructor.
     */
    constructor(args) {
        super(args);
        this.sequenceLength = args.sequenceLength;
        this.startValue = args.startValue;
        this.endValue = args.endValue;
        this.padValue = args.padValue;
    }
    /**
     * Packs `inputs` and returns only the packed tensor (the first element of
     * `callAndReturnPaddingMask`).
     *
     * @param inputs A rank-1 tensor or an array of rank-1 tensors.
     * @param kwargs Optional `sequenceLength`, `addStartValue`, `addEndValue`.
     * @returns The packed tensor: rank-1 for a single tensor input, otherwise
     *     the per-sequence results stacked into one tensor.
     */
    call(inputs, kwargs = { addStartValue: true, addEndValue: true }) {
        return this.callAndReturnPaddingMask(inputs, kwargs)[0];
    }
    /**
     * Exactly like `call` except it also returns a boolean mask with the same
     * shape as the output. The mask is `true` at positions holding a token
     * (including start/end values) and `false` at positions filled by padding.
     *
     * @param inputs A rank-1 tensor or an array of rank-1 tensors.
     * @param kwargs Per-call options:
     *     - `sequenceLength`: overrides the configured length when not
     *       null/undefined.
     *     - `addStartValue` / `addEndValue`: each defaults to `true` when the
     *       key is left undefined, including when only some keys are given.
     * @returns `[outputs, mask]`.
     * @throws ValueError if any input tensor is not rank 1.
     */
    callAndReturnPaddingMask(inputs, kwargs = { addStartValue: true, addEndValue: true }) {
        return tidy(() => {
            var _a;
            const options = kwargs == null ? {} : kwargs;
            // Treat a single tensor as a batch of one; `inputIs1d` remembers to
            // unwrap the result again at the end.
            let x = inputs instanceof Tensor ? [inputs] : inputs;
            const inputIs1d = inputs instanceof Tensor && inputs.rank === 1;
            if (x.some(t => t.rank !== 1)) {
                throw new ValueError('Input must either be a rank 1 Tensor or an array of rank 1 Tensors.');
            }
            const sequenceLength = (_a = options.sequenceLength) !== null && _a !== void 0 ? _a : this.sequenceLength;
            // Resolve each flag on its own so that passing only some options
            // (e.g. just `sequenceLength`) does not silently disable the tokens.
            const addStartValue = options.addStartValue === undefined ? true : options.addStartValue;
            const addEndValue = options.addEndValue === undefined ? true : options.addEndValue;
            // Concatenate start and end tokens.
            if (addStartValue && this.startValue != null) {
                const startTokenIdTensor = tensor([this.startValue]);
                x = x.map(t => concat([startTokenIdTensor, t]));
            }
            if (addEndValue && this.endValue != null) {
                const endTokenIdTensor = tensor([this.endValue]);
                // Trim to leave room for end token.
                x = x.map(t => {
                    const sliced = t.slice(0, Math.min(t.shape[0], sequenceLength - 1));
                    return concat([sliced, endTokenIdTensor]);
                });
            }
            // Forces a rank-1 tensor to exactly `length` entries: pads short
            // inputs with `padValue` and truncates long ones.
            // tf.pad does not allow padding on Tensors with dtype='string', so
            // string padding goes through plain arrays instead.
            function ensureLength(input, length, padValue) {
                if (padValue === undefined) {
                    padValue = input.dtype === 'string' ? '' : 0;
                }
                if (typeof padValue === 'number') {
                    // Without an end token nothing has trimmed the sequence yet,
                    // so it can still be too long; a negative pad amount is not
                    // valid, hence the explicit truncation.
                    if (input.size > length) {
                        return input.slice(0, length);
                    }
                    return input.pad([[0, length - input.size]], padValue);
                }
                const strInput = input.arraySync();
                if (strInput.length <= length) {
                    const pads = Array(length - strInput.length).fill(padValue);
                    return tensor(strInput.concat(pads));
                }
                // Keep the first `length` entries.
                return tensor(strInput.slice(0, length));
            }
            const paddedMask = x.map(t => {
                // `onesLike` not used since it does not support string tensors.
                const ones = tensor(Array(t.shape[0]).fill(1));
                return ensureLength(ones, sequenceLength, 0).cast('bool');
            });
            const mask = inputIs1d ?
                paddedMask[0]
                : stack(paddedMask);
            const paddedTensors = x.map(t => ensureLength(t, sequenceLength, this.padValue));
            const outputs = inputIs1d ?
                paddedTensors[0]
                : stack(paddedTensors);
            return [outputs, mask];
        });
    }
    /**
     * Returns the layer configuration: this layer's four settings merged with
     * the base `Layer` config (base entries win on key collisions because they
     * are assigned last).
     */
    getConfig() {
        const config = {
            sequenceLength: this.sequenceLength,
            startValue: this.startValue,
            endValue: this.endValue,
            padValue: this.padValue,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}
// Static class name, assigned before the class is passed to
// `serialization.registerClass` below.
/** @nocollapse */
StartEndPacker.className = 'StartEndPacker';
export { StartEndPacker };
// Hand the class to the tfjs serialization registry at module load time
// (presumably keyed by `className` — confirm against `registerClass`).
serialization.registerClass(StartEndPacker);
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"start_end_packer.js","sourceRoot":"","sources":["../../../../../../../../tfjs-layers/src/layers/nlp/preprocessing/start_end_packer.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH;;GAEG;AAEH,oDAAoD;AACpD,OAAO,EAAE,MAAM,EAAsB,MAAM,EAAE,aAAa,EAAE,KAAK,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,uBAAuB,CAAC;AAE/G,OAAO,EAAE,KAAK,EAAa,MAAM,0BAA0B,CAAC;AAC5D,OAAO,EAAE,UAAU,EAAE,MAAM,iBAAiB,CAAC;AAiD7C;;;;;;;;;;GAUG;AACH,MAAa,cAAe,SAAQ,KAAK;IASvC,YAAY,IAAwB;QAClC,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,cAAc,GAAG,IAAI,CAAC,cAAc,CAAC;QAC1C,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,UAAU,CAAC;QAClC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAC9B,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;IAChC,CAAC;IAEQ,IAAI,CACX,MAAuB,EACvB,SAA8B,EAAC,aAAa,EAAE,IAAI,EAAE,WAAW,EAAE,IAAI,EAAC;QAEtE,OAAO,IAAI,CAAC,wBAAwB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;IAC1D,CAAC;IAED;;;OAGG;IACH,wBAAwB,CACtB,MAAuB,EACvB,SAA8B,EAAC,aAAa,EAAE,IAAI,EAAE,WAAW,EAAE,IAAI,EAAC;QAEtE,OAAO,IAAI,CAAC,GAAG,EAAE;;YACf,6CAA6C;YAC7C,IAAI,CAAC,GAAG,MAAM,YAAY,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;YAErD,MAAM,SAAS,GAAG,MAAM,YAAY,MAAM,IAAI,MAAM,CAAC,IAAI,KAAK,CAAC,CAAC;YAEhE,IAAI,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,KAAK,CAAC,CAAC,EAAE;gBAC7B,MAAM,IAAI,UAAU,CAClB,qEAAqE,CACtE,CAAC;aACH;YACD,MAAM,cAAc,GAAG,MAAA,MAAM,CAAC,cAAc,mCAAI,IAAI,CAAC,cAAc,CAAC;YAEpE,oCAAoC;YACpC,IAAI,MAAM,CAAC,aAAa,IAAI,IAAI,CAAC,UAAU,IAAI,IAAI,EAAE;gBACnD,MAAM,kBAAkB,GAAG,MAAM,CAAC,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC;gBACrD,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,kBAAkB,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;aACjD;YACD,IAAI,MAAM,CAAC,WAAW,IAAI,IAAI,CAAC,QAAQ,IAAI,IAAI,EAAE;gBAC/C,MAAM,gBAAgB,GAAG,MAAM,CAAC,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;gBACjD,oCAAoC;gBACpC,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE;oBACZ,MAAM,MAAM,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,cAAc,GAAG,CAAC,CAAC,CAAC,CAAC;oB
ACpE,MAAM,MAAM,GAAG,MAAM,CAAC,CAAC,MAAM,EAAE,gBAAgB,CAAC,CAAC,CAAC;oBAClD,OAAO,MAAM,CAAC;gBAChB,CAAC,CAAC,CAAC;aACJ;YAED,+DAA+D;YAC/D,SAAS,YAAY,CACnB,KAAa,EAAE,MAAc,EAAE,QAAwB;gBACvD,IAAI,QAAQ,KAAK,SAAS,EAAE;oBAC1B,QAAQ,GAAG,KAAK,CAAC,KAAK,KAAK,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;iBAC9C;gBACD,IAAI,OAAO,QAAQ,KAAK,QAAQ,EAAE;oBAChC,OAAO,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,CAAC,IAAI,CAAC,CAAC,EAAE,QAAQ,CAAC,CAAC;iBACxD;gBAED,MAAM,QAAQ,GAAG,KAAK,CAAC,SAAS,EAAyB,CAAC;gBAE1D,IAAI,QAAQ,CAAC,MAAM,IAAI,MAAM,EAAE;oBAC7B,MAAM,IAAI,GAAG,KAAK,CAAC,MAAM,GAAG,QAAQ,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;oBAC5D,OAAO,MAAM,CAAC,QAAQ,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC;iBACtC;gBAED,OAAO,MAAM,CAAC,QAAQ,CAAC,KAAK,CAAC,CAAC,EAAE,QAAQ,CAAC,MAAM,GAAG,MAAM,CAAC,CAAC,CAAC;YAC7D,CAAC;YAED,MAAM,UAAU,GAAa,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE;gBACrC,gEAAgE;gBAChE,MAAM,IAAI,GAAG,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC/C,OAAO,YAAY,CAAC,IAAI,EAAE,cAAc,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAC5D,CAAC,CAAC,CAAC;YACH,MAAM,IAAI,GAAG,SAAS,CAAC,CAAC;gBACtB,UAAU,CAAC,CAAC,CAAa;gBACzB,CAAC,CAAC,KAAK,CAAC,UAAU,CAAa,CAAC;YAElC,MAAM,aAAa,GACjB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,YAAY,CAAC,CAAC,EAAE,cAAc,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAC7D,MAAM,OAAO,GAAG,SAAS,CAAC,CAAC;gBACzB,aAAa,CAAC,CAAC,CAAa;gBAC5B,CAAC,CAAC,KAAK,CAAC,aAAa,CAAa,CAAC;YAErC,OAAO,CAAC,OAAO,EAAE,IAAI,CAAC,CAAC;QACzB,CAAC,CAAC,CAAC;IACL,CAAC;IAEQ,SAAS;QAChB,MAAM,MAAM,GAAG;YACb,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;SACxB,CAAC;QACF,MAAM,UAAU,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QACrC,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;QAClC,OAAO,MAAM,CAAC;IAChB,CAAC;;AA7GD,kBAAkB;AACF,wBAAS,GAAG,gBAAgB,CAAC;SAFlC,cAAc;AAgH3B,aAAa,CAAC,aAAa,CAAC,cAAc,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2023 Google LLC.\n * Licensed under the Apache License, Version 2.0 
(the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/**\n *  Start End Packer implementation based on `tf.layers.Layer`.\n */\n\n/* Original source: keras-nlp/start_end_packer.py */\nimport { Tensor, Tensor1D, Tensor2D, concat, serialization, stack, tensor, tidy } from '@tensorflow/tfjs-core';\n\nimport { Layer, LayerArgs } from '../../../engine/topology';\nimport { ValueError } from '../../../errors';\n\nexport declare interface StartEndPackerArgs extends LayerArgs {\n  /**\n   * Integer. The desired output length.\n   */\n  sequenceLength: number;\n\n  /**\n   * Integer or string. The ID or token that is to be placed at the start of\n   * each sequence. The dtype must match the dtype of the input tensors to the\n   * layer. If undefined, no start value will be added.\n   */\n  startValue?: number|string;\n\n  /**\n   * Integer or string. The ID or token that is to be placed at the end of each\n   * input segment. The dtype must match the dtype of the input tensors to the\n   * layer. If undefined, no end value will be added.\n   */\n  endValue?: number|string;\n\n  /**\n   * Integer or string. The ID or token that is to be placed into the unused\n   * positions after the last segment in the sequence. 
If undefined, 0 or ''\n   * will be added depending on the dtype of the input tensor.\n   */\n  padValue?: number|string;\n}\n\nexport declare interface StartEndPackerOptions {\n  /**\n   * Pass to override the configured `sequenceLength` of the layer.\n   */\n  sequenceLength?: number;\n\n  /**\n   * Pass `false` to not append a start value for this input.\n   * Defaults to true.\n   */\n  addStartValue?: boolean;\n\n  /**\n   * Pass `false` to not append an end value for this input.\n   * Defaults to true.\n   */\n  addEndValue?: boolean;\n}\n\n/**\n * Adds start and end tokens to a sequence and pads to a fixed length.\n *\n *  This layer is useful when tokenizing inputs for tasks like translation,\n *  where each sequence should include a start and end marker. It should\n *  be called after tokenization. The layer will first trim inputs to fit, then\n *  add start/end tokens, and finally pad, if necessary, to `sequence_length`.\n *\n *  Input should be either a `tf.Tensor[]` or a dense `tf.Tensor`, and\n *  either rank-1 or rank-2.\n */\nexport class StartEndPacker extends Layer {\n  /** @nocollapse */\n  static readonly className = 'StartEndPacker';\n\n  private sequenceLength: number;\n  private startValue?: number|string;\n  private endValue?: number|string;\n  private padValue?: number|string;\n\n  constructor(args: StartEndPackerArgs) {\n    super(args);\n\n    this.sequenceLength = args.sequenceLength;\n    this.startValue = args.startValue;\n    this.endValue = args.endValue;\n    this.padValue = args.padValue;\n  }\n\n  override call(\n    inputs: Tensor|Tensor[],\n    kwargs: StartEndPackerOptions={addStartValue: true, addEndValue: true}\n  ): Tensor|Tensor2D {\n    return this.callAndReturnPaddingMask(inputs, kwargs)[0];\n  }\n\n  /**\n   * Exactly like `call` except also returns a boolean padding mask of all\n   * locations that are filled in with the `padValue`.\n   */\n  callAndReturnPaddingMask(\n    inputs: Tensor|Tensor[],\n    kwargs: 
StartEndPackerOptions={addStartValue: true, addEndValue: true}\n  ): [Tensor1D|Tensor2D, Tensor1D|Tensor2D] {\n    return tidy(() => {\n      // Add a new axis at the beginning if needed.\n      let x = inputs instanceof Tensor ? [inputs] : inputs;\n\n      const inputIs1d = inputs instanceof Tensor && inputs.rank === 1;\n\n      if (x.some(t => t.rank !== 1)) {\n        throw new ValueError(\n          'Input must either be a rank 1 Tensor or an array of rank 1 Tensors.'\n        );\n      }\n      const sequenceLength = kwargs.sequenceLength ?? this.sequenceLength;\n\n      // Concatenate start and end tokens.\n      if (kwargs.addStartValue && this.startValue != null) {\n        const startTokenIdTensor = tensor([this.startValue]);\n        x = x.map(t => concat([startTokenIdTensor, t]));\n      }\n      if (kwargs.addEndValue && this.endValue != null) {\n        const endTokenIdTensor = tensor([this.endValue]);\n        // Trim to leave room for end token.\n        x = x.map(t => {\n          const sliced = t.slice(0, Math.min(t.shape[0], sequenceLength - 1));\n          const padded = concat([sliced, endTokenIdTensor]);\n          return padded;\n        });\n      }\n\n      // tf.pad does not allow padding on Tensors with dtype='string'\n      function ensureLength(\n        input: Tensor, length: number, padValue?: string|number) {\n        if (padValue === undefined) {\n          padValue = input.dtype === 'string' ? 
'' : 0;\n        }\n        if (typeof padValue === 'number') {\n          return input.pad([[0, length - input.size]], padValue);\n        }\n\n        const strInput = input.arraySync() as unknown as string[];\n\n        if (strInput.length <= length) {\n          const pads = Array(length - strInput.length).fill(padValue);\n          return tensor(strInput.concat(pads));\n        }\n\n        return tensor(strInput.slice(0, strInput.length - length));\n      }\n\n      const paddedMask: Tensor[] = x.map(t => {\n        // `onesLike` not used since it does not support string tensors.\n        const ones = tensor(Array(t.shape[0]).fill(1));\n        return ensureLength(ones, sequenceLength, 0).cast('bool');\n      });\n      const mask = inputIs1d ?\n        paddedMask[0] as Tensor1D\n        : stack(paddedMask) as Tensor2D;\n\n      const paddedTensors: Tensor[] =\n        x.map(t => ensureLength(t, sequenceLength, this.padValue));\n      const outputs = inputIs1d ?\n        paddedTensors[0] as Tensor1D\n        : stack(paddedTensors) as Tensor2D;\n\n      return [outputs, mask];\n    });\n  }\n\n  override getConfig(): serialization.ConfigDict {\n    const config = {\n      sequenceLength: this.sequenceLength,\n      startValue: this.startValue,\n      endValue: this.endValue,\n      padValue: this.padValue,\n    };\n    const baseConfig = super.getConfig();\n    Object.assign(config, baseConfig);\n    return config;\n  }\n}\nserialization.registerClass(StartEndPacker);\n"]}