/**
|
* @license
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
* you may not use this file except in compliance with the License.
|
* You may obtain a copy of the License at
|
*
|
* http://www.apache.org/licenses/LICENSE-2.0
|
*
|
* Unless required by applicable law or agreed to in writing, software
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* See the License for the specific language governing permissions and
|
* limitations under the License.
|
* =============================================================================
|
*/
|
import * as tf from '../index';
|
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
describeWithFlags('basicLSTMCell', ALL_ENVS, () => {
  // Duplicates a single-row 2-D tensor (or nested-array tensor-like) along
  // axis 0 so the op sees a batch of two identical rows.
  const stackTwice = (row) => tf.concat2d([row, row], 0);

  // Reads back the new cell state and new hidden state, then checks that the
  // two batch rows produced identical values (the inputs were identical).
  const expectBatchRowsMatch = async (newC, newH) => {
    const cRows = await newC.array();
    const hRows = await newH.array();
    expect(cRows[0][0]).toEqual(cRows[1][0]);
    expect(hRows[0][0]).toEqual(hRows[1][0]);
  };

  it('basicLSTMCell with batch=2', async () => {
    const lstmKernel = tf.randomNormal([3, 4]);
    const lstmBias = tf.randomNormal([4]);
    const forgetBias = tf.scalar(1.0);

    // Each single-row random tensor is stacked into a batch of two.
    const batchedData = stackTwice(tf.randomNormal([1, 2])); // 2x2
    const batchedC = stackTwice(tf.randomNormal([1, 1]));    // 2x1
    const batchedH = stackTwice(tf.randomNormal([1, 1]));    // 2x1

    const outputs = tf.basicLSTMCell(
        forgetBias, lstmKernel, lstmBias, batchedData, batchedC, batchedH);
    await expectBatchRowsMatch(outputs[0], outputs[1]);
  });

  it('basicLSTMCell accepts a tensor-like object', async () => {
    // Only the kernel is a real tensor; every other argument is a plain
    // number or nested array.
    const lstmKernel = tf.randomNormal([3, 4]);
    const lstmBias = [0, 0, 0, 0];
    const forgetBias = 1;

    const batchedData = stackTwice([[0, 0]]); // 1x2 -> 2x2
    const batchedC = stackTwice([[0]]);       // 1x1 -> 2x1
    const batchedH = stackTwice([[0]]);       // 1x1 -> 2x1

    const outputs = tf.basicLSTMCell(
        forgetBias, lstmKernel, lstmBias, batchedData, batchedC, batchedH);
    await expectBatchRowsMatch(outputs[0], outputs[1]);
  });
});
|
describeWithFlags('basicLSTMCell throws with non-tensor', ALL_ENVS, () => {
  // Duplicates a single-row 2-D tensor along axis 0 to form a batch of two.
  const stackTwice = (row) => tf.concat2d([row, row], 0);

  // Each case below supplies valid tensors for every argument except one,
  // which is replaced by a plain `{}`. The op is expected to reject it with
  // an error naming that argument.

  it('input: forgetBias', () => {
    const lstmKernel = tf.randomNormal([3, 4]);
    const lstmBias = tf.randomNormal([4]);
    const batchedData = stackTwice(tf.randomNormal([1, 2])); // 2x2
    const batchedC = stackTwice(tf.randomNormal([1, 1]));    // 2x1
    const batchedH = stackTwice(tf.randomNormal([1, 1]));    // 2x1

    const callWithBadArg = () => tf.basicLSTMCell(
        {}, lstmKernel, lstmBias, batchedData, batchedC, batchedH);
    expect(callWithBadArg)
        .toThrowError(/Argument 'forgetBias' passed to 'basicLSTMCell' must be a Tensor/);
  });

  it('input: lstmKernel', () => {
    const lstmBias = tf.randomNormal([4]);
    const forgetBias = tf.scalar(1.0);
    const batchedData = stackTwice(tf.randomNormal([1, 2])); // 2x2
    const batchedC = stackTwice(tf.randomNormal([1, 1]));    // 2x1
    const batchedH = stackTwice(tf.randomNormal([1, 1]));    // 2x1

    const callWithBadArg = () => tf.basicLSTMCell(
        forgetBias, {}, lstmBias, batchedData, batchedC, batchedH);
    expect(callWithBadArg)
        .toThrowError(/Argument 'lstmKernel' passed to 'basicLSTMCell' must be a Tensor/);
  });

  it('input: lstmBias', () => {
    const lstmKernel = tf.randomNormal([3, 4]);
    const forgetBias = tf.scalar(1.0);
    const batchedData = stackTwice(tf.randomNormal([1, 2])); // 2x2
    const batchedC = stackTwice(tf.randomNormal([1, 1]));    // 2x1
    const batchedH = stackTwice(tf.randomNormal([1, 1]));    // 2x1

    const callWithBadArg = () => tf.basicLSTMCell(
        forgetBias, lstmKernel, {}, batchedData, batchedC, batchedH);
    expect(callWithBadArg)
        .toThrowError(/Argument 'lstmBias' passed to 'basicLSTMCell' must be a Tensor/);
  });

  it('input: data', () => {
    const lstmKernel = tf.randomNormal([3, 4]);
    const lstmBias = tf.randomNormal([4]);
    const forgetBias = tf.scalar(1.0);
    const batchedC = stackTwice(tf.randomNormal([1, 1]));    // 2x1
    const batchedH = stackTwice(tf.randomNormal([1, 1]));    // 2x1

    const callWithBadArg = () => tf.basicLSTMCell(
        forgetBias, lstmKernel, lstmBias, {}, batchedC, batchedH);
    expect(callWithBadArg)
        .toThrowError(/Argument 'data' passed to 'basicLSTMCell' must be a Tensor/);
  });

  it('input: c', () => {
    const lstmKernel = tf.randomNormal([3, 4]);
    const lstmBias = tf.randomNormal([4]);
    const forgetBias = tf.scalar(1.0);
    const batchedData = stackTwice(tf.randomNormal([1, 2])); // 2x2
    const batchedH = stackTwice(tf.randomNormal([1, 1]));    // 2x1

    const callWithBadArg = () => tf.basicLSTMCell(
        forgetBias, lstmKernel, lstmBias, batchedData, {}, batchedH);
    expect(callWithBadArg)
        .toThrowError(/Argument 'c' passed to 'basicLSTMCell' must be a Tensor/);
  });

  it('input: h', () => {
    const lstmKernel = tf.randomNormal([3, 4]);
    const lstmBias = tf.randomNormal([4]);
    const forgetBias = tf.scalar(1.0);
    const batchedData = stackTwice(tf.randomNormal([1, 2])); // 2x2
    const batchedC = stackTwice(tf.randomNormal([1, 1]));    // 2x1

    const callWithBadArg = () => tf.basicLSTMCell(
        forgetBias, lstmKernel, lstmBias, batchedData, batchedC, {});
    expect(callWithBadArg)
        .toThrowError(/Argument 'h' passed to 'basicLSTMCell' must be a Tensor/);
  });
});
|
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"basic_lstm_cell_test.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/basic_lstm_cell_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAG5D,iBAAiB,CAAC,eAAe,EAAE,QAAQ,EAAE,GAAG,EAAE;IAChD,EAAE,CAAC,4BAA4B,EAAE,KAAK,IAAI,EAAE;QAC1C,MAAM,UAAU,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,MAAM,QAAQ,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;QAElC,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CAAC,IAAI,EAAE,IAAI,CAAC,GAAG,EAAE,CAAC,aAAa,CACjC,UAAU,EAAE,UAAU,EAAE,QAAQ,EAAE,WAAW,EAAE,QAAQ,EAAE,QAAQ,CAAC,CAAC;QACvE,MAAM,QAAQ,GAAG,MAAM,IAAI,CAAC,KAAK,EAAE,CAAC;QACpC,MAAM,QAAQ,GAAG,MAAM,IAAI,CAAC,KAAK,EAAE,CAAC;QACpC,MAAM,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/C,MAAM,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,4CAA4C,EAAE,KAAK,IAAI,EAAE;QAC1D,MAAM,UAAU,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,MAAM,QAAQ,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QAC9B,MAAM,UAAU,GAAG,CAAC,CAAC;QAErB,MAAM,IAAI,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAA6B,MAAM;QACzD,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QA
CzD,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAmC,MAAM;QACzD,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAW,MAAM;QACzD,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAmC,MAAM;QACzD,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAW,MAAM;QACzD,MAAM,CAAC,IAAI,EAAE,IAAI,CAAC,GAAG,EAAE,CAAC,aAAa,CACjC,UAAU,EAAE,UAAU,EAAE,QAAQ,EAAE,WAAW,EAAE,QAAQ,EAAE,QAAQ,CAAC,CAAC;QACvE,MAAM,QAAQ,GAAG,MAAM,IAAI,CAAC,KAAK,EAAE,CAAC;QACpC,MAAM,QAAQ,GAAG,MAAM,IAAI,CAAC,KAAK,EAAE,CAAC;QACpC,MAAM,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/C,MAAM,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACjD,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC;AACH,iBAAiB,CAAC,sCAAsC,EAAE,QAAQ,EAAE,GAAG,EAAE;IACvE,EAAE,CAAC,mBAAmB,EAAE,GAAG,EAAE;QAC3B,MAAM,UAAU,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,MAAM,QAAQ,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,CAAC,CAAC,CAAC;QAE/C,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,aAAa,CAClB,EAAe,EAAE,UAAU,EAAE,QAAQ,EAAE,WAAW,EAAE,QAAQ,EAC5D,QAAQ,CAAC,CAAC;aACb,YAAY,CACT,kEAAkE,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAC;IACH,EAAE,CAAC,mBAAmB,EAAE,GAAG,EAAE;QAC3B,MAAM,QAAQ,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;QAElC,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,WAAW,GAAG,EAAE,CAAC,QA
AQ,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,aAAa,CAClB,UAAU,EAAE,EAAiB,EAAE,QAAQ,EAAE,WAAW,EAAE,QAAQ,EAC9D,QAAQ,CAAC,CAAC;aACb,YAAY,CACT,kEAAkE,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAC;IACH,EAAE,CAAC,iBAAiB,EAAE,GAAG,EAAE;QACzB,MAAM,UAAU,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,MAAM,UAAU,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;QAElC,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,aAAa,CAClB,UAAU,EAAE,UAAU,EAAE,EAAiB,EAAE,WAAW,EAAE,QAAQ,EAChE,QAAQ,CAAC,CAAC;aACb,YAAY,CACT,gEAAgE,CAAC,CAAC;IAC5E,CAAC,CAAC,CAAC;IACH,EAAE,CAAC,aAAa,EAAE,GAAG,EAAE;QACrB,MAAM,UAAU,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,MAAM,QAAQ,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;QAElC,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CA
AC,CAAE,MAAM;QAChD,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,aAAa,CAClB,UAAU,EAAE,UAAU,EAAE,QAAQ,EAAE,EAAiB,EAAE,QAAQ,EAC7D,QAAQ,CAAC,CAAC;aACb,YAAY,CACT,4DAA4D,CAAC,CAAC;IACxE,CAAC,CAAC,CAAC;IACH,EAAE,CAAC,UAAU,EAAE,GAAG,EAAE;QAClB,MAAM,UAAU,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,MAAM,QAAQ,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;QAElC,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,aAAa,CAClB,UAAU,EAAE,UAAU,EAAE,QAAQ,EAAE,WAAW,EAAE,EAAiB,EAChE,QAAQ,CAAC,CAAC;aACb,YAAY,CACT,yDAAyD,CAAC,CAAC;IACrE,CAAC,CAAC,CAAC;IACH,EAAE,CAAC,UAAU,EAAE,GAAG,EAAE;QAClB,MAAM,UAAU,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,MAAM,QAAQ,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,CAAC,CAAC,CAAC;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;QAElC,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC9C,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,YAAY,CAAU,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC3C,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,MAAM;QAChD,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,aAAa,CAClB,UAAU,EAAE,UAAU,EAAE,QAAQ,EAAE,WAAW,EAAE,QAAQ,EACvD,EAAiB,CAAC,CAAC;aACtB,YAAY,CACT,yDAAyD,CAAC,CAAC;IACrE,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2020 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {Rank} from '../types';\n\ndescribeWithFlags('basicLSTMCell', ALL_ENVS, () => {\n  it('basicLSTMCell with batch=2', async () => {\n    const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);\n    const lstmBias = tf.randomNormal<Rank.R1>([4]);\n    const forgetBias = tf.scalar(1.0);\n\n    const data = tf.randomNormal<Rank.R2>([1, 2]);\n    const batchedData = tf.concat2d([data, data], 0);  // 2x2\n    const c = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedC = tf.concat2d([c, c], 0);  // 2x1\n    const h = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedH = tf.concat2d([h, h], 0);  // 2x1\n    const [newC, newH] = tf.basicLSTMCell(\n        forgetBias, lstmKernel, lstmBias, batchedData, batchedC, batchedH);\n    const newCVals = await newC.array();\n    const newHVals = await newH.array();\n    expect(newCVals[0][0]).toEqual(newCVals[1][0]);\n    expect(newHVals[0][0]).toEqual(newHVals[1][0]);\n  });\n\n  it('basicLSTMCell accepts a tensor-like object', async () => {\n    const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);\n    const lstmBias = [0, 0, 0, 0];\n    const forgetBias = 1;\n\n    const data = [[0, 0]];                             // 1x2\n    const batchedData = tf.concat2d([data, 
data], 0);  // 2x2\n    const c = [[0]];                                   // 1x1\n    const batchedC = tf.concat2d([c, c], 0);           // 2x1\n    const h = [[0]];                                   // 1x1\n    const batchedH = tf.concat2d([h, h], 0);           // 2x1\n    const [newC, newH] = tf.basicLSTMCell(\n        forgetBias, lstmKernel, lstmBias, batchedData, batchedC, batchedH);\n    const newCVals = await newC.array();\n    const newHVals = await newH.array();\n    expect(newCVals[0][0]).toEqual(newCVals[1][0]);\n    expect(newHVals[0][0]).toEqual(newHVals[1][0]);\n  });\n});\ndescribeWithFlags('basicLSTMCell throws with non-tensor', ALL_ENVS, () => {\n  it('input: forgetBias', () => {\n    const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);\n    const lstmBias = tf.randomNormal<Rank.R1>([4]);\n\n    const data = tf.randomNormal<Rank.R2>([1, 2]);\n    const batchedData = tf.concat2d([data, data], 0);  // 2x2\n    const c = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedC = tf.concat2d([c, c], 0);  // 2x1\n    const h = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedH = tf.concat2d([h, h], 0);  // 2x1\n    expect(\n        () => tf.basicLSTMCell(\n            {} as tf.Scalar, lstmKernel, lstmBias, batchedData, batchedC,\n            batchedH))\n        .toThrowError(\n            /Argument 'forgetBias' passed to 'basicLSTMCell' must be a Tensor/);\n  });\n  it('input: lstmKernel', () => {\n    const lstmBias = tf.randomNormal<Rank.R1>([4]);\n    const forgetBias = tf.scalar(1.0);\n\n    const data = tf.randomNormal<Rank.R2>([1, 2]);\n    const batchedData = tf.concat2d([data, data], 0);  // 2x2\n    const c = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedC = tf.concat2d([c, c], 0);  // 2x1\n    const h = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedH = tf.concat2d([h, h], 0);  // 2x1\n    expect(\n        () => tf.basicLSTMCell(\n            forgetBias, {} as tf.Tensor2D, lstmBias, batchedData, batchedC,\n            batchedH))\n   
     .toThrowError(\n            /Argument 'lstmKernel' passed to 'basicLSTMCell' must be a Tensor/);\n  });\n  it('input: lstmBias', () => {\n    const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);\n    const forgetBias = tf.scalar(1.0);\n\n    const data = tf.randomNormal<Rank.R2>([1, 2]);\n    const batchedData = tf.concat2d([data, data], 0);  // 2x2\n    const c = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedC = tf.concat2d([c, c], 0);  // 2x1\n    const h = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedH = tf.concat2d([h, h], 0);  // 2x1\n    expect(\n        () => tf.basicLSTMCell(\n            forgetBias, lstmKernel, {} as tf.Tensor1D, batchedData, batchedC,\n            batchedH))\n        .toThrowError(\n            /Argument 'lstmBias' passed to 'basicLSTMCell' must be a Tensor/);\n  });\n  it('input: data', () => {\n    const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);\n    const lstmBias = tf.randomNormal<Rank.R1>([4]);\n    const forgetBias = tf.scalar(1.0);\n\n    const c = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedC = tf.concat2d([c, c], 0);  // 2x1\n    const h = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedH = tf.concat2d([h, h], 0);  // 2x1\n    expect(\n        () => tf.basicLSTMCell(\n            forgetBias, lstmKernel, lstmBias, {} as tf.Tensor2D, batchedC,\n            batchedH))\n        .toThrowError(\n            /Argument 'data' passed to 'basicLSTMCell' must be a Tensor/);\n  });\n  it('input: c', () => {\n    const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);\n    const lstmBias = tf.randomNormal<Rank.R1>([4]);\n    const forgetBias = tf.scalar(1.0);\n\n    const data = tf.randomNormal<Rank.R2>([1, 2]);\n    const batchedData = tf.concat2d([data, data], 0);  // 2x2\n    const h = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedH = tf.concat2d([h, h], 0);  // 2x1\n    expect(\n        () => tf.basicLSTMCell(\n            forgetBias, lstmKernel, lstmBias, batchedData, {} as tf.Tensor2D,\n            
batchedH))\n        .toThrowError(\n            /Argument 'c' passed to 'basicLSTMCell' must be a Tensor/);\n  });\n  it('input: h', () => {\n    const lstmKernel = tf.randomNormal<Rank.R2>([3, 4]);\n    const lstmBias = tf.randomNormal<Rank.R1>([4]);\n    const forgetBias = tf.scalar(1.0);\n\n    const data = tf.randomNormal<Rank.R2>([1, 2]);\n    const batchedData = tf.concat2d([data, data], 0);  // 2x2\n    const c = tf.randomNormal<Rank.R2>([1, 1]);\n    const batchedC = tf.concat2d([c, c], 0);  // 2x1\n    expect(\n        () => tf.basicLSTMCell(\n            forgetBias, lstmKernel, lstmBias, batchedData, batchedC,\n            {} as tf.Tensor2D))\n        .toThrowError(\n            /Argument 'h' passed to 'basicLSTMCell' must be a Tensor/);\n  });\n});\n"]}
|