/**
|
* @license
|
* Copyright 2023 Google LLC.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/models/generative_task" />
|
/**
|
* Base class for Generative Task models.
|
*/
|
import { NamedTensorMap, Tensor } from '@tensorflow/tfjs-core';
|
import { ModelCompileArgs } from '../../../engine/training';
|
import { Task } from './task';
|
/**
 * Signature of a generation function, such as the one produced by
 * `GenerativeTask.makeGenerateFunction()`.
 *
 * It takes a map of named input tensors plus an optional end token id and
 * yields a map of named output tensors. The shape mirrors
 * `GenerativeTask.generateStep`, except that `endTokenId` is optional here.
 * NOTE(review): `endTokenId` presumably marks where generation may stop — confirm.
 */
export type GenerateFn =
  (inputs: NamedTensorMap, endTokenId?: number) => NamedTensorMap;
|
/**
 * Base class for Generative Task models.
 *
 * This declaration only fixes the signatures of the generation pipeline:
 * input normalization, per-batch generation steps, the generation function,
 * output normalization, and the public `generate()` entry point.
 */
export declare class GenerativeTask extends Task {
  /** @nocollapse */
  static className: string;

  /**
   * Generation function used by this task.
   * NOTE(review): presumably populated and reused by `makeGenerateFunction()`
   * (documented as "create or return") — confirm against the implementation.
   */
  protected generateFunction: GenerateFn;

  /**
   * Configure the model.
   *
   * The sampling method used by `generate()` is selected through this method.
   *
   * @param args Model compile arguments.
   */
  compile(args: ModelCompileArgs): void;

  /**
   * Run the generation on a single batch of input.
   *
   * @param inputs Named input tensors for one batch.
   * @param endTokenId Token id associated with the end of generation
   *   (presumably used to stop early — confirm).
   * @returns Named output tensors for the batch.
   */
  generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap;

  /**
   * Create or return the compiled generation function.
   *
   * @returns A function matching the `GenerateFn` signature.
   */
  makeGenerateFunction(): GenerateFn;

  /**
   * Normalize user input to the generate function.
   *
   * Converts the input to a tensor and adds a batch dimension if necessary.
   *
   * @param inputs Raw user input.
   * @returns A `[tensor, flag]` tuple: the normalized input tensor and a
   *   boolean. The flag presumably records whether a batch dimension was
   *   added (the `inputIsScalar` value later passed to
   *   `normalizeGenerateOutputs`) — confirm against the implementation.
   */
  protected normalizeGenerateInputs(inputs: Tensor): [Tensor, boolean];

  /**
   * Normalize output from the generate function.
   *
   * If a batch dimension was added to the input, it is removed from the
   * output, so a scalar input produces a scalar output.
   *
   * @param outputs Raw output of the generation function.
   * @param inputIsScalar Whether the original input had no batch dimension.
   * @returns The normalized output tensor.
   */
  protected normalizeGenerateOutputs(outputs: Tensor, inputIsScalar: boolean): Tensor;

  /**
   * Generate text given prompt `inputs`.
   *
   * This method generates text based on the given `inputs`. The sampling
   * method used for generation can be set via the `compile()` method.
   *
   * `inputs` will be handled as a single batch.
   *
   * If a `preprocessor` is attached to the model, `inputs` will be
   * preprocessed inside the `generate()` function. In that case they should
   * match the structure expected by the `preprocessor` layer (usually raw
   * strings). If no `preprocessor` is attached, `inputs` should match the
   * structure expected by the `backbone`.
   *
   * @param inputs Tensor data. If a `preprocessor` is attached to the model,
   *   `inputs` should match the structure expected by the `preprocessor`
   *   layer. Otherwise, `inputs` should match the structure expected by the
   *   `backbone` model.
   * @param maxLength Integer. The max length of the generated sequence.
   *   Defaults to the max configured `sequenceLength` of the `preprocessor`.
   *   If `preprocessor` is `null`, `inputs` should be padded to the desired
   *   maximum length and this argument will be ignored.
   *
   * NOTE(review): declared to return `void`, so this signature exposes no
   * generated output — confirm whether that is intentional.
   */
  generate(inputs: Tensor, maxLength?: number): void;
}
|