gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/models/generative_task" />
/**
 *  Base class for Generative Task models.
 */
import { NamedTensorMap, Tensor } from '@tensorflow/tfjs-core';
import { ModelCompileArgs } from '../../../engine/training';
import { Task } from './task';
export type GenerateFn = (inputs: NamedTensorMap, endTokenId?: number) => NamedTensorMap;
/**
 *  Base class for Generative Task models.
 */
export declare class GenerativeTask extends Task {
    /** @nocollapse */
    static className: string;
    protected generateFunction: GenerateFn;
    compile(args: ModelCompileArgs): void;
    /**
     * Run the generation on a single batch of input.
     */
    generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap;
    /**
     * Create or return the compiled generation function.
     */
    makeGenerateFunction(): GenerateFn;
    /**
     * Normalize user input to the generate function.
     *
     * This function converts all inputs to tensors, adds a batch dimension if
     * necessary, and returns a iterable "dataset like" object.
     */
    protected normalizeGenerateInputs(inputs: Tensor): [Tensor, boolean];
    /**
     * Normalize user output from the generate function.
     *
     * This function converts all output to numpy (for integer output), or
     * python strings (for string output). If a batch dimension was added to
     * the input, it is removed from the output (so generate can be string in,
     * string out).
     */
    protected normalizeGenerateOutputs(outputs: Tensor, inputIsScalar: boolean): Tensor;
    /**
     * Generate text given prompt `inputs`.
     *
     * This method generates text based on given `inputs`. The sampling method
     * used for generation can be set via the `compile()` method.
     *
     * `inputs` will be handled as a single batch.
     *
     * If a `preprocessor` is attached to the model, `inputs` will be
     * preprocessed inside the `generate()` function and should match the
     * structure expected by the `preprocessor` layer (usually raw strings).
     * If a `preprocessor` is not attached, inputs should match the structure
     * expected by the `backbone`. See the example usage above for a
     * demonstration of each.
     *
     * @param inputs tensor data. If a `preprocessor` is attached to the model,
     *  `inputs` should match the structure expected by the `preprocessor` layer.
     *  If a `preprocessor` is not attached, `inputs` should match the structure
     *  expected the the `backbone` model.
     * @param maxLength Integer. The max length of the generated sequence.
     *  Will default to the max configured `sequenceLength` of the
     *  `preprocessor`. If `preprocessor` is `null`, `inputs` should be
     *  should be padded to the desired maximum length and this argument
     *  will be ignored.
     */
    generate(inputs: Tensor, maxLength?: number): void;
}