import { convertToTensor, convertToTensorArray } from '../tensor_util_env';
|
import { op } from './operation';
|
/**
|
* Computes the next states and outputs of a stack of LSTMCells.
|
*
|
* Each cell output is used as input to the next cell.
|
*
|
* Returns `[cellStates, cellOutputs]`: the new state and the new output of each cell, in cell order.
|
*
|
* Derived from tf.contrib.rnn.MultiRNNCell.
|
*
|
* @param lstmCells Array of LSTMCell functions.
|
* @param data The input to the first cell in the stack.
|
* @param c Array of previous cell states.
|
* @param h Array of previous cell outputs.
|
*
|
* @doc {heading: 'Operations', subheading: 'RNN'}
|
*/
|
function multiRNNCell_(lstmCells, data, c, h) {
  // Normalize the inputs first, in the same order as the public signature.
  const dataTensor = convertToTensor(data, 'data', 'multiRNNCell');
  const prevStates = convertToTensorArray(c, 'c', 'multiRNNCell');
  const prevOutputs = convertToTensorArray(h, 'h', 'multiRNNCell');

  // Per-cell results, collected directly in cell order.
  const newC = [];
  const newH = [];

  // The first cell consumes `data`; every later cell consumes the output
  // (second element of the result pair) of the cell before it.
  let layerInput = dataTensor;
  for (let layer = 0; layer < lstmCells.length; layer += 1) {
    // Called through the array element (not a detached reference) so the
    // invocation form matches a plain `lstmCells[i](...)` call.
    const result = lstmCells[layer](layerInput, prevStates[layer], prevOutputs[layer]);
    newC.push(result[0]);
    const cellOutput = result[1];
    newH.push(cellOutput);
    layerInput = cellOutput;
  }

  return [newC, newH];
}
|
// Public op. `op` (from ./operation, not visible here) receives the raw
// implementation keyed by its underscore-suffixed name; presumably it derives
// the op name from that key and adds the standard op wrapping — confirm in
// ./operation. The `/* @__PURE__ */` annotation lets bundlers/minifiers drop
// this call if `multiRNNCell` is never used.
export const multiRNNCell = /* @__PURE__ */ op({ multiRNNCell_ });
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibXVsdGlfcm5uX2NlbGwuanMiLCJzb3VyY2VSb290IjoiIiwic291cmNlcyI6WyIuLi8uLi8uLi8uLi8uLi8uLi90ZmpzLWNvcmUvc3JjL29wcy9tdWx0aV9ybm5fY2VsbC50cyJdLCJuYW1lcyI6W10sIm1hcHBpbmdzIjoiQUFpQkEsT0FBTyxFQUFDLGVBQWUsRUFBRSxvQkFBb0IsRUFBQyxNQUFNLG9CQUFvQixDQUFDO0FBRXpFLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFTL0I7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsU0FBUyxhQUFhLENBQ2xCLFNBQXlCLEVBQUUsSUFBeUIsRUFDcEQsQ0FBNkIsRUFDN0IsQ0FBNkI7SUFDL0IsTUFBTSxLQUFLLEdBQUcsZUFBZSxDQUFDLElBQUksRUFBRSxNQUFNLEVBQUUsY0FBYyxDQUFDLENBQUM7SUFDNUQsTUFBTSxFQUFFLEdBQUcsb0JBQW9CLENBQUMsQ0FBQyxFQUFFLEdBQUcsRUFBRSxjQUFjLENBQUMsQ0FBQztJQUN4RCxNQUFNLEVBQUUsR0FBRyxvQkFBb0IsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGNBQWMsQ0FBQyxDQUFDO0lBRXhELElBQUksS0FBSyxHQUFHLEtBQUssQ0FBQztJQUNsQixNQUFNLFNBQVMsR0FBRyxFQUFFLENBQUM7SUFDckIsS0FBSyxJQUFJLENBQUMsR0FBRyxDQUFDLEVBQUUsQ0FBQyxHQUFHLFNBQVMsQ0FBQyxNQUFNLEVBQUUsQ0FBQyxFQUFFLEVBQUU7UUFDekMsTUFBTSxNQUFNLEdBQUcsU0FBUyxDQUFDLENBQUMsQ0FBQyxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDakQsU0FBUyxDQUFDLElBQUksQ0FBQyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUMxQixTQUFTLENBQUMsSUFBSSxDQUFDLE1BQU0sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQzFCLEtBQUssR0FBRyxNQUFNLENBQUMsQ0FBQyxDQUFDLENBQUM7S0FDbkI7SUFDRCxNQUFNLElBQUksR0FBZSxFQUFFLENBQUM7SUFDNUIsTUFBTSxJQUFJLEdBQWUsRUFBRSxDQUFDO0lBQzVCLEtBQUssSUFBSSxDQUFDLEdBQUcsQ0FBQyxFQUFFLENBQUMsR0FBRyxTQUFTLENBQUMsTUFBTSxFQUFFLENBQUMsSUFBSSxDQUFDLEVBQUU7UUFDNUMsSUFBSSxDQUFDLElBQUksQ0FBQyxTQUFTLENBQUMsQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN4QixJQUFJLENBQUMsSUFBSSxDQUFDLFNBQVMsQ0FBQyxDQUFDLEdBQUcsQ0FBQyxDQUFDLENBQUMsQ0FBQztLQUM3QjtJQUNELE9BQU8sQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdEIsQ0FBQztBQUNELE1BQU0sQ0FBQyxNQUFNLFlBQVksR0FBRyxlQUFlLENBQUMsRUFBRSxDQUFDLEVBQUMsYUFBYSxFQUFDLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2
hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cbmltcG9ydCB7VGVuc29yMkR9IGZyb20gJy4uL3RlbnNvcic7XG5pbXBvcnQge2NvbnZlcnRUb1RlbnNvciwgY29udmVydFRvVGVuc29yQXJyYXl9IGZyb20gJy4uL3RlbnNvcl91dGlsX2Vudic7XG5pbXBvcnQge1RlbnNvckxpa2V9IGZyb20gJy4uL3R5cGVzJztcbmltcG9ydCB7b3B9IGZyb20gJy4vb3BlcmF0aW9uJztcblxuLyoqXG4gKiBAZG9jYWxpYXMgKGRhdGE6IFRlbnNvcjJELCBjOiBUZW5zb3IyRCwgaDogVGVuc29yMkQpOiBbVGVuc29yMkQsIFRlbnNvcjJEXVxuICovXG5leHBvcnQgdHlwZSBMU1RNQ2VsbEZ1bmMgPSB7XG4gIChkYXRhOiBUZW5zb3IyRCwgYzogVGVuc29yMkQsIGg6IFRlbnNvcjJEKTogW1RlbnNvcjJELCBUZW5zb3IyRF07XG59O1xuXG4vKipcbiAqIENvbXB1dGVzIHRoZSBuZXh0IHN0YXRlcyBhbmQgb3V0cHV0cyBvZiBhIHN0YWNrIG9mIExTVE1DZWxscy5cbiAqXG4gKiBFYWNoIGNlbGwgb3V0cHV0IGlzIHVzZWQgYXMgaW5wdXQgdG8gdGhlIG5leHQgY2VsbC5cbiAqXG4gKiBSZXR1cm5zIGBbY2VsbFN0YXRlLCBjZWxsT3V0cHV0XWAuXG4gKlxuICogRGVyaXZlZCBmcm9tIHRmLmNvbnRyaWIucm4uTXVsdGlSTk5DZWxsLlxuICpcbiAqIEBwYXJhbSBsc3RtQ2VsbHMgQXJyYXkgb2YgTFNUTUNlbGwgZnVuY3Rpb25zLlxuICogQHBhcmFtIGRhdGEgVGhlIGlucHV0IHRvIHRoZSBjZWxsLlxuICogQHBhcmFtIGMgQXJyYXkgb2YgcHJldmlvdXMgY2VsbCBzdGF0ZXMuXG4gKiBAcGFyYW0gaCBBcnJheSBvZiBwcmV2aW91cyBjZWxsIG91dHB1dHMuXG4gKlxuICogQGRvYyB7aGVhZGluZzogJ09wZXJhdGlvbnMnLCBzdWJoZWFkaW5nOiAnUk5OJ31cbiAqL1xuZnVuY3Rpb24gbXVsdGlSTk5DZWxsXyhcbiAgIC
Bsc3RtQ2VsbHM6IExTVE1DZWxsRnVuY1tdLCBkYXRhOiBUZW5zb3IyRHxUZW5zb3JMaWtlLFxuICAgIGM6IEFycmF5PFRlbnNvcjJEfFRlbnNvckxpa2U+LFxuICAgIGg6IEFycmF5PFRlbnNvcjJEfFRlbnNvckxpa2U+KTogW1RlbnNvcjJEW10sIFRlbnNvcjJEW11dIHtcbiAgY29uc3QgJGRhdGEgPSBjb252ZXJ0VG9UZW5zb3IoZGF0YSwgJ2RhdGEnLCAnbXVsdGlSTk5DZWxsJyk7XG4gIGNvbnN0ICRjID0gY29udmVydFRvVGVuc29yQXJyYXkoYywgJ2MnLCAnbXVsdGlSTk5DZWxsJyk7XG4gIGNvbnN0ICRoID0gY29udmVydFRvVGVuc29yQXJyYXkoaCwgJ2gnLCAnbXVsdGlSTk5DZWxsJyk7XG5cbiAgbGV0IGlucHV0ID0gJGRhdGE7XG4gIGNvbnN0IG5ld1N0YXRlcyA9IFtdO1xuICBmb3IgKGxldCBpID0gMDsgaSA8IGxzdG1DZWxscy5sZW5ndGg7IGkrKykge1xuICAgIGNvbnN0IG91dHB1dCA9IGxzdG1DZWxsc1tpXShpbnB1dCwgJGNbaV0sICRoW2ldKTtcbiAgICBuZXdTdGF0ZXMucHVzaChvdXRwdXRbMF0pO1xuICAgIG5ld1N0YXRlcy5wdXNoKG91dHB1dFsxXSk7XG4gICAgaW5wdXQgPSBvdXRwdXRbMV07XG4gIH1cbiAgY29uc3QgbmV3QzogVGVuc29yMkRbXSA9IFtdO1xuICBjb25zdCBuZXdIOiBUZW5zb3IyRFtdID0gW107XG4gIGZvciAobGV0IGkgPSAwOyBpIDwgbmV3U3RhdGVzLmxlbmd0aDsgaSArPSAyKSB7XG4gICAgbmV3Qy5wdXNoKG5ld1N0YXRlc1tpXSk7XG4gICAgbmV3SC5wdXNoKG5ld1N0YXRlc1tpICsgMV0pO1xuICB9XG4gIHJldHVybiBbbmV3QywgbmV3SF07XG59XG5leHBvcnQgY29uc3QgbXVsdGlSTk5DZWxsID0gLyogQF9fUFVSRV9fICovIG9wKHttdWx0aVJOTkNlbGxffSk7XG4iXX0=
|