/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { convertToTensor } from '../tensor_util_env';
import { add } from './add';
import { concat } from './concat';
import { matMul } from './mat_mul';
import { mul } from './mul';
import { op } from './operation';
import { sigmoid } from './sigmoid';
import { slice } from './slice';
import { tanh } from './tanh';
/**
 * Computes the next state and output of a BasicLSTMCell.
 *
 * Returns `[newC, newH]`.
 *
 * Derived from tf.contrib.rnn.BasicLSTMCell.
 *
 * `data` and `h` are concatenated along axis 1, multiplied by `lstmKernel`
 * and offset by `lstmBias`. The result is split along axis 1 into four
 * equal-width column blocks, in this order: input gate `i`, new input `j`,
 * forget gate `f`, output gate `o`. Then:
 *
 *   newC = sigmoid(i) * tanh(j) + c * sigmoid(f + forgetBias)
 *   newH = tanh(newC) * sigmoid(o)
 *
 * @param forgetBias Forget bias for the cell (declared as a Scalar in the
 *     TypeScript source).
 * @param lstmKernel The weights for the cell (2D). Its column count, after
 *     adding `lstmBias`, must be divisible by 4.
 * @param lstmBias The bias for the cell (declared as a 1D tensor in the
 *     TypeScript source).
 * @param data The input to the cell (2D).
 * @param c Previous cell state (2D).
 * @param h Previous cell output (2D).
 * @returns `[newC, newH]`, the new cell state and the new cell output.
 * @throws {Error} If the gate pre-activations do not have a column count
 *     divisible by 4.
 *
 * @doc {heading: 'Operations', subheading: 'RNN'}
 */
function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
  const $forgetBias =
      convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
  const $lstmKernel =
      convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
  const $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
  const $data = convertToTensor(data, 'data', 'basicLSTMCell');
  const $c = convertToTensor(c, 'c', 'basicLSTMCell');
  const $h = convertToTensor(h, 'h', 'basicLSTMCell');

  // A single affine map over [data, h] produces the pre-activations of all
  // four gates side by side.
  const combined = concat([$data, $h], 1);
  const weighted = matMul(combined, $lstmKernel);
  const res = add(weighted, $lstmBias);

  // i = input_gate, j = new_input, f = forget_gate, o = output_gate; each is
  // a contiguous block of `sliceCols` columns, in that order.
  const batchSize = res.shape[0];
  const numCols = res.shape[1];
  if (numCols % 4 !== 0) {
    // Without this guard `numCols / 4` is fractional and the slices below get
    // a non-integer size instead of failing with a clear message.
    throw new Error(`basicLSTMCell: the column count of the gate pre-activations (lstmKernel / lstmBias) must be divisible by 4, but got ${numCols}.`);
  }
  const sliceCols = numCols / 4;
  const sliceSize = [batchSize, sliceCols];
  const i = slice(res, [0, 0], sliceSize);
  const j = slice(res, [0, sliceCols], sliceSize);
  const f = slice(res, [0, sliceCols * 2], sliceSize);
  const o = slice(res, [0, sliceCols * 3], sliceSize);

  // newC = sigmoid(i) * tanh(j) + c * sigmoid(forgetBias + f)
  const newC = add(
      mul(sigmoid(i), tanh(j)), mul($c, sigmoid(add($forgetBias, f))));
  // newH = tanh(newC) * sigmoid(o)
  const newH = mul(tanh(newC), sigmoid(o));
  return [newC, newH];
}
export const basicLSTMCell = /* @__PURE__ */ op({ basicLSTMCell_ });
//# 
sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmFzaWNfbHN0bV9jZWxsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYmFzaWNfbHN0bV9jZWxsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUdILE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUduRCxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxVQUFVLENBQUM7QUFDaEMsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDL0IsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFFNUI7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsU0FBUyxjQUFjLENBQ25CLFVBQTZCLEVBQUUsVUFBK0IsRUFDOUQsUUFBNkIsRUFBRSxJQUF5QixFQUN4RCxDQUFzQixFQUFFLENBQXNCO0lBQ2hELE1BQU0sV0FBVyxHQUNiLGVBQWUsQ0FBQyxVQUFVLEVBQUUsWUFBWSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQy9ELE1BQU0sV0FBVyxHQUNiLGVBQWUsQ0FBQyxVQUFVLEVBQUUsWUFBWSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQy9ELE1BQU0sU0FBUyxHQUFHLGVBQWUsQ0FBQyxRQUFRLEVBQUUsVUFBVSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQ3pFLE1BQU0sS0FBSyxHQUFHLGVBQWUsQ0FBQyxJQUFJLEVBQUUsTUFBTSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQzdELE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQ3BELE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBRXBELE1BQU0sUUFBUSxHQUFHLE1BQU0sQ0FBQyxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQztJQUN4QyxNQUFNLFFBQVEsR0FBRyxNQUFNLENBQUMsUUFBUSxFQUFFLFdBQVcsQ0FBQyxDQUFDO0lBQy9DLE1BQU0sR0FBRyxHQUFhLEdBQUcsQ0FBQyxRQUFRLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFFL0Msa0VBQWtFO0lBQ2xFLE1BQU0sU0FBUyxHQUFHLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDL0IsTUFBTSxTQUFTLEdBQUcsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7SUFDbkMsTUFBTSxTQUFTLEdBQXFCLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBQzNELE1BQU0sQ0FBQyxHQUFHLEtBQUssQ0FBQyxHQUFHLEVBQUUsQ0
FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFDeEMsTUFBTSxDQUFDLEdBQUcsS0FBSyxDQUFDLEdBQUcsRUFBRSxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUNoRCxNQUFNLENBQUMsR0FBRyxLQUFLLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLFNBQVMsR0FBRyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUNwRCxNQUFNLENBQUMsR0FBRyxLQUFLLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLFNBQVMsR0FBRyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUVwRCxNQUFNLElBQUksR0FDTixHQUFHLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsRUFBRSxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUMsRUFDeEIsR0FBRyxDQUFDLEVBQUUsRUFBRSxPQUFPLENBQUMsR0FBRyxDQUFDLFdBQVcsRUFBRSxDQUFDLENBQUMsQ0FBYSxDQUFDLENBQUMsQ0FBQztJQUMzRCxNQUFNLElBQUksR0FBYSxHQUFHLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ25ELE9BQU8sQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdEIsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGFBQWEsR0FBRyxlQUFlLENBQUMsRUFBRSxDQUFDLEVBQUMsY0FBYyxFQUFDLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtTY2FsYXIsIFRlbnNvcjFELCBUZW5zb3IyRH0gZnJvbSAnLi
4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5cbmltcG9ydCB7YWRkfSBmcm9tICcuL2FkZCc7XG5pbXBvcnQge2NvbmNhdH0gZnJvbSAnLi9jb25jYXQnO1xuaW1wb3J0IHttYXRNdWx9IGZyb20gJy4vbWF0X211bCc7XG5pbXBvcnQge211bH0gZnJvbSAnLi9tdWwnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtzaWdtb2lkfSBmcm9tICcuL3NpZ21vaWQnO1xuaW1wb3J0IHtzbGljZX0gZnJvbSAnLi9zbGljZSc7XG5pbXBvcnQge3Rhbmh9IGZyb20gJy4vdGFuaCc7XG5cbi8qKlxuICogQ29tcHV0ZXMgdGhlIG5leHQgc3RhdGUgYW5kIG91dHB1dCBvZiBhIEJhc2ljTFNUTUNlbGwuXG4gKlxuICogUmV0dXJucyBgW25ld0MsIG5ld0hdYC5cbiAqXG4gKiBEZXJpdmVkIGZyb20gdGYuY29udHJpYi5ybm4uQmFzaWNMU1RNQ2VsbC5cbiAqXG4gKiBAcGFyYW0gZm9yZ2V0QmlhcyBGb3JnZXQgYmlhcyBmb3IgdGhlIGNlbGwuXG4gKiBAcGFyYW0gbHN0bUtlcm5lbCBUaGUgd2VpZ2h0cyBmb3IgdGhlIGNlbGwuXG4gKiBAcGFyYW0gbHN0bUJpYXMgVGhlIGJpYXMgZm9yIHRoZSBjZWxsLlxuICogQHBhcmFtIGRhdGEgVGhlIGlucHV0IHRvIHRoZSBjZWxsLlxuICogQHBhcmFtIGMgUHJldmlvdXMgY2VsbCBzdGF0ZS5cbiAqIEBwYXJhbSBoIFByZXZpb3VzIGNlbGwgb3V0cHV0LlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdPcGVyYXRpb25zJywgc3ViaGVhZGluZzogJ1JOTid9XG4gKi9cbmZ1bmN0aW9uIGJhc2ljTFNUTUNlbGxfKFxuICAgIGZvcmdldEJpYXM6IFNjYWxhcnxUZW5zb3JMaWtlLCBsc3RtS2VybmVsOiBUZW5zb3IyRHxUZW5zb3JMaWtlLFxuICAgIGxzdG1CaWFzOiBUZW5zb3IxRHxUZW5zb3JMaWtlLCBkYXRhOiBUZW5zb3IyRHxUZW5zb3JMaWtlLFxuICAgIGM6IFRlbnNvcjJEfFRlbnNvckxpa2UsIGg6IFRlbnNvcjJEfFRlbnNvckxpa2UpOiBbVGVuc29yMkQsIFRlbnNvcjJEXSB7XG4gIGNvbnN0ICRmb3JnZXRCaWFzID1cbiAgICAgIGNvbnZlcnRUb1RlbnNvcihmb3JnZXRCaWFzLCAnZm9yZ2V0QmlhcycsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRsc3RtS2VybmVsID1cbiAgICAgIGNvbnZlcnRUb1RlbnNvcihsc3RtS2VybmVsLCAnbHN0bUtlcm5lbCcsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRsc3RtQmlhcyA9IGNvbnZlcnRUb1RlbnNvcihsc3RtQmlhcywgJ2xzdG1CaWFzJywgJ2Jhc2ljTFNUTUNlbGwnKTtcbiAgY29uc3QgJGRhdGEgPSBjb252ZXJ0VG9UZW5zb3IoZGF0YSwgJ2RhdGEnLCAnYmFzaWNMU1RNQ2VsbCcpO1xuICBjb25zdCAkYyA9IGNvbnZlcnRUb1RlbnNvcihjLCAnYycsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRoID0gY29udmVydFRvVGVuc29yKGgsICdoJywgJ2Jhc2ljTFNUTUNlbGwnKTtcblxuICBjb25zdC
Bjb21iaW5lZCA9IGNvbmNhdChbJGRhdGEsICRoXSwgMSk7XG4gIGNvbnN0IHdlaWdodGVkID0gbWF0TXVsKGNvbWJpbmVkLCAkbHN0bUtlcm5lbCk7XG4gIGNvbnN0IHJlczogVGVuc29yMkQgPSBhZGQod2VpZ2h0ZWQsICRsc3RtQmlhcyk7XG5cbiAgLy8gaSA9IGlucHV0X2dhdGUsIGogPSBuZXdfaW5wdXQsIGYgPSBmb3JnZXRfZ2F0ZSwgbyA9IG91dHB1dF9nYXRlXG4gIGNvbnN0IGJhdGNoU2l6ZSA9IHJlcy5zaGFwZVswXTtcbiAgY29uc3Qgc2xpY2VDb2xzID0gcmVzLnNoYXBlWzFdIC8gNDtcbiAgY29uc3Qgc2xpY2VTaXplOiBbbnVtYmVyLCBudW1iZXJdID0gW2JhdGNoU2l6ZSwgc2xpY2VDb2xzXTtcbiAgY29uc3QgaSA9IHNsaWNlKHJlcywgWzAsIDBdLCBzbGljZVNpemUpO1xuICBjb25zdCBqID0gc2xpY2UocmVzLCBbMCwgc2xpY2VDb2xzXSwgc2xpY2VTaXplKTtcbiAgY29uc3QgZiA9IHNsaWNlKHJlcywgWzAsIHNsaWNlQ29scyAqIDJdLCBzbGljZVNpemUpO1xuICBjb25zdCBvID0gc2xpY2UocmVzLCBbMCwgc2xpY2VDb2xzICogM10sIHNsaWNlU2l6ZSk7XG5cbiAgY29uc3QgbmV3QzogVGVuc29yMkQgPVxuICAgICAgYWRkKG11bChzaWdtb2lkKGkpLCB0YW5oKGopKSxcbiAgICAgICAgICBtdWwoJGMsIHNpZ21vaWQoYWRkKCRmb3JnZXRCaWFzLCBmKSkgYXMgVGVuc29yMkQpKTtcbiAgY29uc3QgbmV3SDogVGVuc29yMkQgPSBtdWwodGFuaChuZXdDKSwgc2lnbW9pZChvKSk7XG4gIHJldHVybiBbbmV3QywgbmV3SF07XG59XG5cbmV4cG9ydCBjb25zdCBiYXNpY0xTVE1DZWxsID0gLyogQF9fUFVSRV9fICovIG9wKHtiYXNpY0xTVE1DZWxsX30pO1xuIl19