gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import { convertToTensor } from '../tensor_util_env';
import { add } from './add';
import { concat } from './concat';
import { matMul } from './mat_mul';
import { mul } from './mul';
import { op } from './operation';
import { sigmoid } from './sigmoid';
import { slice } from './slice';
import { tanh } from './tanh';
/**
 * Advances a BasicLSTMCell by one step, producing the new cell state and the
 * new cell output.
 *
 * Returns `[newC, newH]`.
 *
 * Derived from tf.contrib.rnn.BasicLSTMCell. The input and the previous
 * output are joined along axis 1, multiplied by `lstmKernel` and offset by
 * `lstmBias`. The resulting columns are cut into four equal-width blocks, in
 * order: input gate (i), new input (j), forget gate (f), output gate (o).
 * The step then computes
 *
 *     newC = sigmoid(i) * tanh(j) + c * sigmoid(forgetBias + f)
 *     newH = tanh(newC) * sigmoid(o)
 *
 * NOTE(review): the axis-1 width of the pre-activation is divided by 4 with
 * no check that it is actually a multiple of 4.
 *
 * @param forgetBias Forget bias for the cell; added to the forget-gate block
 *     before its sigmoid.
 * @param lstmKernel The weights for the cell.
 * @param lstmBias The bias for the cell.
 * @param data The input to the cell.
 * @param c Previous cell state.
 * @param h Previous cell output.
 *
 * @doc {heading: 'Operations', subheading: 'RNN'}
 */
function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
    const opName = 'basicLSTMCell';
    const $forgetBias = convertToTensor(forgetBias, 'forgetBias', opName);
    const $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', opName);
    const $lstmBias = convertToTensor(lstmBias, 'lstmBias', opName);
    const $data = convertToTensor(data, 'data', opName);
    const $c = convertToTensor(c, 'c', opName);
    const $h = convertToTensor(h, 'h', opName);
    // A single matMul yields the pre-activations of all four gates, laid out
    // side by side along axis 1.
    const joined = concat([$data, $h], 1);
    const gates = add(matMul(joined, $lstmKernel), $lstmBias);
    const [batchSize, totalCols] = gates.shape;
    const gateCols = totalCols / 4;
    const gateSize = [batchSize, gateCols];
    // Block k occupies columns [k * gateCols, (k + 1) * gateCols).
    const [inputGate, newInput, forgetGate, outputGate] = [0, 1, 2, 3].map(
        (k) => slice(gates, [0, gateCols * k], gateSize));
    // New state = admitted candidate + retained share of the old state.
    const admitted = mul(sigmoid(inputGate), tanh(newInput));
    const retained = mul($c, sigmoid(add($forgetBias, forgetGate)));
    const newC = add(admitted, retained);
    const newH = mul(tanh(newC), sigmoid(outputGate));
    return [newC, newH];
}
export const basicLSTMCell = /* @__PURE__ */ op({ basicLSTMCell_ });
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoiYmFzaWNfbHN0bV9jZWxsLmpzIiwic291cmNlUm9vdCI6IiIsInNvdXJjZXMiOlsiLi4vLi4vLi4vLi4vLi4vLi4vdGZqcy1jb3JlL3NyYy9vcHMvYmFzaWNfbHN0bV9jZWxsLnRzIl0sIm5hbWVzIjpbXSwibWFwcGluZ3MiOiJBQUFBOzs7Ozs7Ozs7Ozs7Ozs7R0FlRztBQUdILE9BQU8sRUFBQyxlQUFlLEVBQUMsTUFBTSxvQkFBb0IsQ0FBQztBQUduRCxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxNQUFNLEVBQUMsTUFBTSxVQUFVLENBQUM7QUFDaEMsT0FBTyxFQUFDLE1BQU0sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNqQyxPQUFPLEVBQUMsR0FBRyxFQUFDLE1BQU0sT0FBTyxDQUFDO0FBQzFCLE9BQU8sRUFBQyxFQUFFLEVBQUMsTUFBTSxhQUFhLENBQUM7QUFDL0IsT0FBTyxFQUFDLE9BQU8sRUFBQyxNQUFNLFdBQVcsQ0FBQztBQUNsQyxPQUFPLEVBQUMsS0FBSyxFQUFDLE1BQU0sU0FBUyxDQUFDO0FBQzlCLE9BQU8sRUFBQyxJQUFJLEVBQUMsTUFBTSxRQUFRLENBQUM7QUFFNUI7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBQ0gsU0FBUyxjQUFjLENBQ25CLFVBQTZCLEVBQUUsVUFBK0IsRUFDOUQsUUFBNkIsRUFBRSxJQUF5QixFQUN4RCxDQUFzQixFQUFFLENBQXNCO0lBQ2hELE1BQU0sV0FBVyxHQUNiLGVBQWUsQ0FBQyxVQUFVLEVBQUUsWUFBWSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQy9ELE1BQU0sV0FBVyxHQUNiLGVBQWUsQ0FBQyxVQUFVLEVBQUUsWUFBWSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQy9ELE1BQU0sU0FBUyxHQUFHLGVBQWUsQ0FBQyxRQUFRLEVBQUUsVUFBVSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQ3pFLE1BQU0sS0FBSyxHQUFHLGVBQWUsQ0FBQyxJQUFJLEVBQUUsTUFBTSxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQzdELE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBQ3BELE1BQU0sRUFBRSxHQUFHLGVBQWUsQ0FBQyxDQUFDLEVBQUUsR0FBRyxFQUFFLGVBQWUsQ0FBQyxDQUFDO0lBRXBELE1BQU0sUUFBUSxHQUFHLE1BQU0sQ0FBQyxDQUFDLEtBQUssRUFBRSxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQztJQUN4QyxNQUFNLFFBQVEsR0FBRyxNQUFNLENBQUMsUUFBUSxFQUFFLFdBQVcsQ0FBQyxDQUFDO0lBQy9DLE1BQU0sR0FBRyxHQUFhLEdBQUcsQ0FBQyxRQUFRLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFFL0Msa0VBQWtFO0lBQ2xFLE1BQU0sU0FBUyxHQUFHLEdBQUcsQ0FBQyxLQUFLLENBQUMsQ0FBQyxDQUFDLENBQUM7SUFDL0IsTUFBTSxTQUFTLEdBQUcsR0FBRyxDQUFDLEtBQUssQ0FBQyxDQUFDLENBQUMsR0FBRyxDQUFDLENBQUM7SUFDbkMsTUFBTSxTQUFTLEdBQXFCLENBQUMsU0FBUyxFQUFFLFNBQVMsQ0FBQyxDQUFDO0lBQzNELE1BQU0sQ0FBQyxHQUFHLEtBQUssQ0FBQyxHQUFHLEVBQU
UsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsU0FBUyxDQUFDLENBQUM7SUFDeEMsTUFBTSxDQUFDLEdBQUcsS0FBSyxDQUFDLEdBQUcsRUFBRSxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUNoRCxNQUFNLENBQUMsR0FBRyxLQUFLLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLFNBQVMsR0FBRyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUNwRCxNQUFNLENBQUMsR0FBRyxLQUFLLENBQUMsR0FBRyxFQUFFLENBQUMsQ0FBQyxFQUFFLFNBQVMsR0FBRyxDQUFDLENBQUMsRUFBRSxTQUFTLENBQUMsQ0FBQztJQUVwRCxNQUFNLElBQUksR0FDTixHQUFHLENBQUMsR0FBRyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsRUFBRSxJQUFJLENBQUMsQ0FBQyxDQUFDLENBQUMsRUFDeEIsR0FBRyxDQUFDLEVBQUUsRUFBRSxPQUFPLENBQUMsR0FBRyxDQUFDLFdBQVcsRUFBRSxDQUFDLENBQUMsQ0FBYSxDQUFDLENBQUMsQ0FBQztJQUMzRCxNQUFNLElBQUksR0FBYSxHQUFHLENBQUMsSUFBSSxDQUFDLElBQUksQ0FBQyxFQUFFLE9BQU8sQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO0lBQ25ELE9BQU8sQ0FBQyxJQUFJLEVBQUUsSUFBSSxDQUFDLENBQUM7QUFDdEIsQ0FBQztBQUVELE1BQU0sQ0FBQyxNQUFNLGFBQWEsR0FBRyxlQUFlLENBQUMsRUFBRSxDQUFDLEVBQUMsY0FBYyxFQUFDLENBQUMsQ0FBQyIsInNvdXJjZXNDb250ZW50IjpbIi8qKlxuICogQGxpY2Vuc2VcbiAqIENvcHlyaWdodCAyMDIwIEdvb2dsZSBMTEMuIEFsbCBSaWdodHMgUmVzZXJ2ZWQuXG4gKiBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgXCJMaWNlbnNlXCIpO1xuICogeW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLlxuICogWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0XG4gKlxuICogaHR0cDovL3d3dy5hcGFjaGUub3JnL2xpY2Vuc2VzL0xJQ0VOU0UtMi4wXG4gKlxuICogVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZVxuICogZGlzdHJpYnV0ZWQgdW5kZXIgdGhlIExpY2Vuc2UgaXMgZGlzdHJpYnV0ZWQgb24gYW4gXCJBUyBJU1wiIEJBU0lTLFxuICogV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuXG4gKiBTZWUgdGhlIExpY2Vuc2UgZm9yIHRoZSBzcGVjaWZpYyBsYW5ndWFnZSBnb3Zlcm5pbmcgcGVybWlzc2lvbnMgYW5kXG4gKiBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS5cbiAqID09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09XG4gKi9cblxuaW1wb3J0IHtTY2FsYXIsIFRlbnNvcjFELCBUZW5zb3IyRH0gZnJvbS
AnLi4vdGVuc29yJztcbmltcG9ydCB7Y29udmVydFRvVGVuc29yfSBmcm9tICcuLi90ZW5zb3JfdXRpbF9lbnYnO1xuaW1wb3J0IHtUZW5zb3JMaWtlfSBmcm9tICcuLi90eXBlcyc7XG5cbmltcG9ydCB7YWRkfSBmcm9tICcuL2FkZCc7XG5pbXBvcnQge2NvbmNhdH0gZnJvbSAnLi9jb25jYXQnO1xuaW1wb3J0IHttYXRNdWx9IGZyb20gJy4vbWF0X211bCc7XG5pbXBvcnQge211bH0gZnJvbSAnLi9tdWwnO1xuaW1wb3J0IHtvcH0gZnJvbSAnLi9vcGVyYXRpb24nO1xuaW1wb3J0IHtzaWdtb2lkfSBmcm9tICcuL3NpZ21vaWQnO1xuaW1wb3J0IHtzbGljZX0gZnJvbSAnLi9zbGljZSc7XG5pbXBvcnQge3Rhbmh9IGZyb20gJy4vdGFuaCc7XG5cbi8qKlxuICogQ29tcHV0ZXMgdGhlIG5leHQgc3RhdGUgYW5kIG91dHB1dCBvZiBhIEJhc2ljTFNUTUNlbGwuXG4gKlxuICogUmV0dXJucyBgW25ld0MsIG5ld0hdYC5cbiAqXG4gKiBEZXJpdmVkIGZyb20gdGYuY29udHJpYi5ybm4uQmFzaWNMU1RNQ2VsbC5cbiAqXG4gKiBAcGFyYW0gZm9yZ2V0QmlhcyBGb3JnZXQgYmlhcyBmb3IgdGhlIGNlbGwuXG4gKiBAcGFyYW0gbHN0bUtlcm5lbCBUaGUgd2VpZ2h0cyBmb3IgdGhlIGNlbGwuXG4gKiBAcGFyYW0gbHN0bUJpYXMgVGhlIGJpYXMgZm9yIHRoZSBjZWxsLlxuICogQHBhcmFtIGRhdGEgVGhlIGlucHV0IHRvIHRoZSBjZWxsLlxuICogQHBhcmFtIGMgUHJldmlvdXMgY2VsbCBzdGF0ZS5cbiAqIEBwYXJhbSBoIFByZXZpb3VzIGNlbGwgb3V0cHV0LlxuICpcbiAqIEBkb2Mge2hlYWRpbmc6ICdPcGVyYXRpb25zJywgc3ViaGVhZGluZzogJ1JOTid9XG4gKi9cbmZ1bmN0aW9uIGJhc2ljTFNUTUNlbGxfKFxuICAgIGZvcmdldEJpYXM6IFNjYWxhcnxUZW5zb3JMaWtlLCBsc3RtS2VybmVsOiBUZW5zb3IyRHxUZW5zb3JMaWtlLFxuICAgIGxzdG1CaWFzOiBUZW5zb3IxRHxUZW5zb3JMaWtlLCBkYXRhOiBUZW5zb3IyRHxUZW5zb3JMaWtlLFxuICAgIGM6IFRlbnNvcjJEfFRlbnNvckxpa2UsIGg6IFRlbnNvcjJEfFRlbnNvckxpa2UpOiBbVGVuc29yMkQsIFRlbnNvcjJEXSB7XG4gIGNvbnN0ICRmb3JnZXRCaWFzID1cbiAgICAgIGNvbnZlcnRUb1RlbnNvcihmb3JnZXRCaWFzLCAnZm9yZ2V0QmlhcycsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRsc3RtS2VybmVsID1cbiAgICAgIGNvbnZlcnRUb1RlbnNvcihsc3RtS2VybmVsLCAnbHN0bUtlcm5lbCcsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRsc3RtQmlhcyA9IGNvbnZlcnRUb1RlbnNvcihsc3RtQmlhcywgJ2xzdG1CaWFzJywgJ2Jhc2ljTFNUTUNlbGwnKTtcbiAgY29uc3QgJGRhdGEgPSBjb252ZXJ0VG9UZW5zb3IoZGF0YSwgJ2RhdGEnLCAnYmFzaWNMU1RNQ2VsbCcpO1xuICBjb25zdCAkYyA9IGNvbnZlcnRUb1RlbnNvcihjLCAnYycsICdiYXNpY0xTVE1DZWxsJyk7XG4gIGNvbnN0ICRoID0gY29udmVydFRvVGVuc29yKGgsICdoJywgJ2Jhc2ljTFNUTUNlbGwnKTtcblxuICBjb2
5zdCBjb21iaW5lZCA9IGNvbmNhdChbJGRhdGEsICRoXSwgMSk7XG4gIGNvbnN0IHdlaWdodGVkID0gbWF0TXVsKGNvbWJpbmVkLCAkbHN0bUtlcm5lbCk7XG4gIGNvbnN0IHJlczogVGVuc29yMkQgPSBhZGQod2VpZ2h0ZWQsICRsc3RtQmlhcyk7XG5cbiAgLy8gaSA9IGlucHV0X2dhdGUsIGogPSBuZXdfaW5wdXQsIGYgPSBmb3JnZXRfZ2F0ZSwgbyA9IG91dHB1dF9nYXRlXG4gIGNvbnN0IGJhdGNoU2l6ZSA9IHJlcy5zaGFwZVswXTtcbiAgY29uc3Qgc2xpY2VDb2xzID0gcmVzLnNoYXBlWzFdIC8gNDtcbiAgY29uc3Qgc2xpY2VTaXplOiBbbnVtYmVyLCBudW1iZXJdID0gW2JhdGNoU2l6ZSwgc2xpY2VDb2xzXTtcbiAgY29uc3QgaSA9IHNsaWNlKHJlcywgWzAsIDBdLCBzbGljZVNpemUpO1xuICBjb25zdCBqID0gc2xpY2UocmVzLCBbMCwgc2xpY2VDb2xzXSwgc2xpY2VTaXplKTtcbiAgY29uc3QgZiA9IHNsaWNlKHJlcywgWzAsIHNsaWNlQ29scyAqIDJdLCBzbGljZVNpemUpO1xuICBjb25zdCBvID0gc2xpY2UocmVzLCBbMCwgc2xpY2VDb2xzICogM10sIHNsaWNlU2l6ZSk7XG5cbiAgY29uc3QgbmV3QzogVGVuc29yMkQgPVxuICAgICAgYWRkKG11bChzaWdtb2lkKGkpLCB0YW5oKGopKSxcbiAgICAgICAgICBtdWwoJGMsIHNpZ21vaWQoYWRkKCRmb3JnZXRCaWFzLCBmKSkgYXMgVGVuc29yMkQpKTtcbiAgY29uc3QgbmV3SDogVGVuc29yMkQgPSBtdWwodGFuaChuZXdDKSwgc2lnbW9pZChvKSk7XG4gIHJldHVybiBbbmV3QywgbmV3SF07XG59XG5cbmV4cG9ydCBjb25zdCBiYXNpY0xTVE1DZWxsID0gLyogQF9fUFVSRV9fICovIG9wKHtiYXNpY0xTVE1DZWxsX30pO1xuIl19