/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import { convertToTensor } from '../tensor_util_env';
|
import { add } from './add';
|
import { concat } from './concat';
|
import { matMul } from './mat_mul';
|
import { mul } from './mul';
|
import { op } from './operation';
|
import { sigmoid } from './sigmoid';
|
import { slice } from './slice';
|
import { tanh } from './tanh';
|
/**
 * Computes the next state and output of a BasicLSTMCell.
 *
 * Returns `[newC, newH]`.
 *
 * Derived from tf.contrib.rnn.BasicLSTMCell.
 *
 * The input `data` and the previous output `h` are concatenated along axis 1,
 * multiplied by `lstmKernel`, and `lstmBias` is added. The columns of that
 * result are split into four equal-width gate blocks, in this order: input
 * gate `i`, new input `j`, forget gate `f`, output gate `o`. Then:
 *
 *   newC = sigmoid(i) * tanh(j) + c * sigmoid(forgetBias + f)
 *   newH = tanh(newC) * sigmoid(o)
 *
 * @param forgetBias Forget bias for the cell.
 * @param lstmKernel The weights for the cell. Its rows must match the combined
 *     column count of `data` and `h`; the number of gate columns produced
 *     (after the bias is added) must be divisible by 4, one block per gate.
 * @param lstmBias The bias for the cell.
 * @param data The input to the cell.
 * @param c Previous cell state.
 * @param h Previous cell output.
 * @returns `[newC, newH]`: the new cell state and the new cell output.
 * @throws {Error} If the gate column count is not divisible by 4.
 *
 * @doc {heading: 'Operations', subheading: 'RNN'}
 */
function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
  const $forgetBias = convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
  const $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
  const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
  const $data = convertToTensor(data, 'data', 'basicLSTMCell');
  const $c = convertToTensor(c, 'c', 'basicLSTMCell');
  const $h = convertToTensor(h, 'h', 'basicLSTMCell');

  const combined = concat([$data, $h], 1);
  const weighted = matMul(combined, $lstmKernel);
  const res = add(weighted, $lstmBias);

  // i = input_gate, j = new_input, f = forget_gate, o = output_gate
  const [batchSize, gateWidth] = res.shape;
  // The four gates are equal-width column blocks of `res`. A width that is
  // not a multiple of 4 would give fractional slice sizes below, so fail
  // early with a clear message instead of an obscure error inside `slice`.
  if (gateWidth % 4 !== 0) {
    throw new Error(
        `basicLSTMCell: the gate column count (${gateWidth}) must be divisible by 4 (input, new input, forget and output gates).`);
  }
  const sliceCols = gateWidth / 4;
  const sliceSize = [batchSize, sliceCols];
  const i = slice(res, [0, 0], sliceSize);
  const j = slice(res, [0, sliceCols], sliceSize);
  const f = slice(res, [0, sliceCols * 2], sliceSize);
  const o = slice(res, [0, sliceCols * 3], sliceSize);

  const newC = add(mul(sigmoid(i), tanh(j)), mul($c, sigmoid(add($forgetBias, f))));
  const newH = mul(tanh(newC), sigmoid(o));
  return [newC, newH];
}
|
// Public entry point: basicLSTMCell_ passed through `op` from './operation'
// (presumably op() names the op after the single object key and wraps the
// implementation — confirm in ./operation). The `/* @__PURE__ */` annotation
// tells bundlers this call has no side effects so the export can be
// tree-shaken when unused; it must stay immediately before the call.
export const basicLSTMCell = /* @__PURE__ */ op({ basicLSTMCell_ });
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmFzaWNfbHN0bV9jZWxsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYmFzaWNfbHN0bV9jZWxsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUdILE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUduRCxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxVQUFVLENBQUM7QUFDaEMsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDL0IsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFFNUI7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsU0FBUyxjQUFjLENBQ25CLFVBQTZCLEVBQUUsVUFBK0IsRUFDOUQsUUFBNkIsRUFBRSxJQUF5QixFQUN4RCxDQUFzQixFQUFFLENBQXNCO0lBQ2hELE1BQU0sV0FBVyxHQUNiLGVBQWUsQ0FBQyxVQUFVLEVBQUUsWUFBWSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQy9ELE1BQU0sV0FBVyxHQUNiLGVBQWUsQ0FBQyxVQUFVLEVBQUUsWUFBWSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQy9ELE1BQU0sU0FBUyxHQUFHLGVBQWUsQ0FBQyxRQUFRLEVBQUUsVUFBVSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQ3pFLE1BQU0sS0FBSyxHQUFHLGVBQWUsQ0FBQyxJQUFJLEVBQUUsTUFBTSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQzdELE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQ3BELE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBRXBELE1BQU0sUUFBUSxHQUFHLE1BQU0sQ0FBQyxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQztJQUN4QyxNQUFNLFFBQVEsR0FBRyxNQUFNLENBQUMsUUFBUSxFQUFFLFdBQVcsQ0FBQyxDQUFDO0lBQy9DLE1BQU0sR0FBRyxHQUFhLEdBQUcsQ0FBQyxRQUFRLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFFL0Msa0VBQWtFO0lBQ2xFLE1BQU0sU0FBUyxHQUFHLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDL0IsTUFBTSxTQUFTLEdBQUcsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7SUFDbkMsTUFBTSxTQUFTLEdBQXFCLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBQzNELE1BQU0sQ0FBQyxHQUFHLEtBQUssQ0FBQyxHQUFHLEVBQU
UsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFDeEMsTUFBTSxDQUFDLEdBQUcsS0FBSyxDQUFDLEdBQUcsRUFBRSxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUNoRCxNQUFNLENBQUMsR0FBRyxLQUFLLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLFNBQVMsR0FBRyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUNwRCxNQUFNLENBQUMsR0FBRyxLQUFLLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLFNBQVMsR0FBRyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUVwRCxNQUFNLElBQUksR0FDTixHQUFHLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsRUFBRSxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUMsRUFDeEIsR0FBRyxDQUFDLEVBQUUsRUFBRSxPQUFPLENBQUMsR0FBRyxDQUFDLFdBQVcsRUFBRSxDQUFDLENBQUMsQ0FBYSxDQUFDLENBQUMsQ0FBQztJQUMzRCxNQUFNLElBQUksR0FBYSxHQUFHLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ25ELE9BQU8sQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdEIsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGFBQWEsR0FBRyxlQUFlLENBQUMsRUFBRSxDQUFDLEVBQUMsY0FBYyxFQUFDLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtTY2FsYXIsIFRlbnNvcjFELCBUZW5zb3IyRH0gZnJvbS
AnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5cbmltcG9ydCB7YWRkfSBmcm9tICcuL2FkZCc7XG5pbXBvcnQge2NvbmNhdH0gZnJvbSAnLi9jb25jYXQnO1xuaW1wb3J0IHttYXRNdWx9IGZyb20gJy4vbWF0X211bCc7XG5pbXBvcnQge211bH0gZnJvbSAnLi9tdWwnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtzaWdtb2lkfSBmcm9tICcuL3NpZ21vaWQnO1xuaW1wb3J0IHtzbGljZX0gZnJvbSAnLi9zbGljZSc7XG5pbXBvcnQge3Rhbmh9IGZyb20gJy4vdGFuaCc7XG5cbi8qKlxuICogQ29tcHV0ZXMgdGhlIG5leHQgc3RhdGUgYW5kIG91dHB1dCBvZiBhIEJhc2ljTFNUTUNlbGwuXG4gKlxuICogUmV0dXJucyBgW25ld0MsIG5ld0hdYC5cbiAqXG4gKiBEZXJpdmVkIGZyb20gdGYuY29udHJpYi5ybm4uQmFzaWNMU1RNQ2VsbC5cbiAqXG4gKiBAcGFyYW0gZm9yZ2V0QmlhcyBGb3JnZXQgYmlhcyBmb3IgdGhlIGNlbGwuXG4gKiBAcGFyYW0gbHN0bUtlcm5lbCBUaGUgd2VpZ2h0cyBmb3IgdGhlIGNlbGwuXG4gKiBAcGFyYW0gbHN0bUJpYXMgVGhlIGJpYXMgZm9yIHRoZSBjZWxsLlxuICogQHBhcmFtIGRhdGEgVGhlIGlucHV0IHRvIHRoZSBjZWxsLlxuICogQHBhcmFtIGMgUHJldmlvdXMgY2VsbCBzdGF0ZS5cbiAqIEBwYXJhbSBoIFByZXZpb3VzIGNlbGwgb3V0cHV0LlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdPcGVyYXRpb25zJywgc3ViaGVhZGluZzogJ1JOTid9XG4gKi9cbmZ1bmN0aW9uIGJhc2ljTFNUTUNlbGxfKFxuICAgIGZvcmdldEJpYXM6IFNjYWxhcnxUZW5zb3JMaWtlLCBsc3RtS2VybmVsOiBUZW5zb3IyRHxUZW5zb3JMaWtlLFxuICAgIGxzdG1CaWFzOiBUZW5zb3IxRHxUZW5zb3JMaWtlLCBkYXRhOiBUZW5zb3IyRHxUZW5zb3JMaWtlLFxuICAgIGM6IFRlbnNvcjJEfFRlbnNvckxpa2UsIGg6IFRlbnNvcjJEfFRlbnNvckxpa2UpOiBbVGVuc29yMkQsIFRlbnNvcjJEXSB7XG4gIGNvbnN0ICRmb3JnZXRCaWFzID1cbiAgICAgIGNvbnZlcnRUb1RlbnNvcihmb3JnZXRCaWFzLCAnZm9yZ2V0QmlhcycsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRsc3RtS2VybmVsID1cbiAgICAgIGNvbnZlcnRUb1RlbnNvcihsc3RtS2VybmVsLCAnbHN0bUtlcm5lbCcsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRsc3RtQmlhcyA9IGNvbnZlcnRUb1RlbnNvcihsc3RtQmlhcywgJ2xzdG1CaWFzJywgJ2Jhc2ljTFNUTUNlbGwnKTtcbiAgY29uc3QgJGRhdGEgPSBjb252ZXJ0VG9UZW5zb3IoZGF0YSwgJ2RhdGEnLCAnYmFzaWNMU1RNQ2VsbCcpO1xuICBjb25zdCAkYyA9IGNvbnZlcnRUb1RlbnNvcihjLCAnYycsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRoID0gY29udmVydFRvVGVuc29yKGgsICdoJywgJ2Jhc2ljTFNUTUNlbGwnKTtcblxuICBjb2
5zdCBjb21iaW5lZCA9IGNvbmNhdChbJGRhdGEsICRoXSwgMSk7XG4gIGNvbnN0IHdlaWdodGVkID0gbWF0TXVsKGNvbWJpbmVkLCAkbHN0bUtlcm5lbCk7XG4gIGNvbnN0IHJlczogVGVuc29yMkQgPSBhZGQod2VpZ2h0ZWQsICRsc3RtQmlhcyk7XG5cbiAgLy8gaSA9IGlucHV0X2dhdGUsIGogPSBuZXdfaW5wdXQsIGYgPSBmb3JnZXRfZ2F0ZSwgbyA9IG91dHB1dF9nYXRlXG4gIGNvbnN0IGJhdGNoU2l6ZSA9IHJlcy5zaGFwZVswXTtcbiAgY29uc3Qgc2xpY2VDb2xzID0gcmVzLnNoYXBlWzFdIC8gNDtcbiAgY29uc3Qgc2xpY2VTaXplOiBbbnVtYmVyLCBudW1iZXJdID0gW2JhdGNoU2l6ZSwgc2xpY2VDb2xzXTtcbiAgY29uc3QgaSA9IHNsaWNlKHJlcywgWzAsIDBdLCBzbGljZVNpemUpO1xuICBjb25zdCBqID0gc2xpY2UocmVzLCBbMCwgc2xpY2VDb2xzXSwgc2xpY2VTaXplKTtcbiAgY29uc3QgZiA9IHNsaWNlKHJlcywgWzAsIHNsaWNlQ29scyAqIDJdLCBzbGljZVNpemUpO1xuICBjb25zdCBvID0gc2xpY2UocmVzLCBbMCwgc2xpY2VDb2xzICogM10sIHNsaWNlU2l6ZSk7XG5cbiAgY29uc3QgbmV3QzogVGVuc29yMkQgPVxuICAgICAgYWRkKG11bChzaWdtb2lkKGkpLCB0YW5oKGopKSxcbiAgICAgICAgICBtdWwoJGMsIHNpZ21vaWQoYWRkKCRmb3JnZXRCaWFzLCBmKSkgYXMgVGVuc29yMkQpKTtcbiAgY29uc3QgbmV3SDogVGVuc29yMkQgPSBtdWwodGFuaChuZXdDKSwgc2lnbW9pZChvKSk7XG4gIHJldHVybiBbbmV3QywgbmV3SF07XG59XG5cbmV4cG9ydCBjb25zdCBiYXNpY0xTVE1DZWxsID0gLyogQF9fUFVSRV9fICovIG9wKHtiYXNpY0xTVE1DZWxsX30pO1xuIl19
|