/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
|
import * as tf from '../index';
|
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
|
import { expectArraysClose } from '../test_util';
|
// Jasmine suite for tf.logSoftmax, registered through describeWithFlags for
// every environment listed in ALL_ENVS.
describeWithFlags('logSoftmax', ALL_ENVS, () => {
  // Expected output for the 1D logits [2, 1, 3].
  const EXPECTED_1D = [-1.407606, -2.4076061, -0.407606];

  // Expected flattened output for the 2x3 logits [[2, 1, 3], [1, 3, 2]].
  const EXPECTED_2D = [
    -1.407606, -2.4076061, -0.407606,
    -2.4076061, -0.4076061, -1.4076061,
  ];

  // Builds a fresh 2x3 logits tensor; called inside each test so the tensor
  // is created while that test is running.
  const make2dLogits = () => tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]);

  it('regular test', async () => {
    const logits = tf.tensor1d([2, 1, 3]);
    const result = tf.logSoftmax(logits);

    const values = await result.data();
    expectArraysClose(values, EXPECTED_1D);
  });

  it('Huge difference', async () => {
    // Logits that are 2000 apart; the expected output is [-2000, 0].
    const logits = tf.tensor1d([-1000, +1000]);
    const result = tf.logSoftmax(logits);

    const values = await result.data();
    expectArraysClose(values, [-2000, 0]);
  });

  it('Propagates NaNs', async () => {
    // A single NaN logit is expected to turn every output element into NaN.
    const logits = tf.tensor1d([2, 1, NaN]);
    const result = tf.logSoftmax(logits);

    const values = await result.data();
    expectArraysClose(values, [NaN, NaN, NaN]);
  });

  it('2D, axis=1', async () => {
    const result = tf.logSoftmax(make2dLogits(), 1);

    expect(result.rank).toBe(2);
    const values = await result.data();
    expectArraysClose(values, EXPECTED_2D);
  });

  it('2D, implicit axis=1', async () => {
    // The axis argument is omitted; the expected values are the same as in
    // the explicit axis=1 test.
    const result = tf.logSoftmax(make2dLogits());

    expect(result.rank).toBe(2);
    const values = await result.data();
    expectArraysClose(values, EXPECTED_2D);
  });

  it('1D gradient', async () => {
    const x = tf.tensor1d([1, 2, 10]);
    const dy = tf.tensor1d([1, 2, 3]);

    // Gradient of logSoftmax with respect to its input, with dy passed as
    // the second argument of the function returned by tf.grad.
    const gradOfLogSoftmax = tf.grad((logits) => logits.logSoftmax());
    const dx = gradOfLogSoftmax(x, dy);

    expect(dx.shape).toEqual(x.shape);
    const values = await dx.data();
    expectArraysClose(values, [0.9992599, 1.9979881, -2.9972477]);
  });

  it('2D, axis=0 throws error', () => {
    // The tensor is built inside the closure so that everything the call
    // does happens under the toThrowError expectation.
    const callWithAxisZero = () => {
      tf.logSoftmax(make2dLogits(), 0);
    };

    expect(callWithAxisZero).toThrowError();
  });

  it('throws when passed a non-tensor', () => {
    const callWithPlainObject = () => tf.logSoftmax({});

    expect(callWithPlainObject).toThrowError(
        /Argument 'logits' passed to 'logSoftmax' must be a Tensor/);
  });

  it('accepts a tensor-like object', async () => {
    // A plain number array is passed instead of a tf.Tensor.
    const result = tf.logSoftmax([2, 1, 3]);

    const values = await result.data();
    expectArraysClose(values, EXPECTED_1D);
  });
});
|
//# sourceMappingURL=data:application/json;base64,eyJ2ZXJzaW9uIjozLCJmaWxlIjoibG9nX3NvZnRtYXhfdGVzdC5qcyIsInNvdXJjZVJvb3QiOiIiLCJzb3VyY2VzIjpbIi4uLy4uLy4uLy4uLy4uLy4uL3RmanMtY29yZS9zcmMvb3BzL2xvZ19zb2Z0bWF4X3Rlc3QudHMiXSwibmFtZXMiOltdLCJtYXBwaW5ncyI6IkFBQUE7Ozs7Ozs7Ozs7Ozs7OztHQWVHO0FBRUgsT0FBTyxLQUFLLEVBQUUsTUFBTSxVQUFVLENBQUM7QUFDL0IsT0FBTyxFQUFDLFFBQVEsRUFBRSxpQkFBaUIsRUFBQyxNQUFNLGlCQUFpQixDQUFDO0FBQzVELE9BQU8sRUFBQyxpQkFBaUIsRUFBQyxNQUFNLGNBQWMsQ0FBQztBQUUvQyxpQkFBaUIsQ0FBQyxZQUFZLEVBQUUsUUFBUSxFQUFFLEdBQUcsRUFBRTtJQUM3QyxFQUFFLENBQUMsY0FBYyxFQUFFLEtBQUssSUFBSSxFQUFFO1FBQzVCLE1BQU0sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxVQUFVLENBQUMsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBRWhELGlCQUFpQixDQUFDLE1BQU0sQ0FBQyxDQUFDLElBQUksRUFBRSxFQUFFLENBQUMsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxRQUFRLENBQUMsQ0FBQyxDQUFDO0lBQ3hFLENBQUMsQ0FBQyxDQUFDO0lBRUgsRUFBRSxDQUFDLGlCQUFpQixFQUFFLEtBQUssSUFBSSxFQUFFO1FBQy9CLE1BQU0sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxVQUFVLENBQUMsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsSUFBSSxFQUFFLENBQUMsSUFBSSxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBRXJELGlCQUFpQixDQUFDLE1BQU0sQ0FBQyxDQUFDLElBQUksRUFBRSxFQUFFLENBQUMsQ0FBQyxJQUFJLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQztJQUNoRCxDQUFDLENBQUMsQ0FBQztJQUVILEVBQUUsQ0FBQyxpQkFBaUIsRUFBRSxLQUFLLElBQUksRUFBRTtRQUMvQixNQUFNLENBQUMsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxHQUFHLENBQUMsQ0FBQyxDQUFDO1FBQ25DLE1BQU0sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxVQUFVLENBQUMsQ0FBQyxDQUFDLENBQUM7UUFDM0IsaUJBQWlCLENBQUMsTUFBTSxDQUFDLENBQUMsSUFBSSxFQUFFLEVBQUUsQ0FBQyxHQUFHLEVBQUUsR0FBRyxFQUFFLEdBQUcsQ0FBQyxDQUFDLENBQUM7SUFDckQsQ0FBQyxDQUFDLENBQUM7SUFFSCxFQUFFLENBQUMsWUFBWSxFQUFFLEtBQUssSUFBSSxFQUFFO1FBQzFCLE1BQU0sQ0FBQyxHQUFHLEVBQUUsQ0FBQyxVQUFVLENBQUMsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDO1FBQ3hFLE1BQU0sUUFBUSxHQUNWLENBQU
MsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxTQUFTLENBQUMsQ0FBQztRQUMzRSxNQUFNLENBQUMsQ0FBQyxDQUFDLElBQUksQ0FBQyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN2QixpQkFBaUIsQ0FBQyxNQUFNLENBQUMsQ0FBQyxJQUFJLEVBQUUsRUFBRSxRQUFRLENBQUMsQ0FBQztJQUM5QyxDQUFDLENBQUMsQ0FBQztJQUVILEVBQUUsQ0FBQyxxQkFBcUIsRUFBRSxLQUFLLElBQUksRUFBRTtRQUNuQyxNQUFNLENBQUMsR0FBRyxFQUFFLENBQUMsVUFBVSxDQUFDLEVBQUUsQ0FBQyxRQUFRLENBQUMsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQ3JFLE1BQU0sUUFBUSxHQUNWLENBQUMsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxTQUFTLENBQUMsQ0FBQztRQUMzRSxNQUFNLENBQUMsQ0FBQyxDQUFDLElBQUksQ0FBQyxDQUFDLElBQUksQ0FBQyxDQUFDLENBQUMsQ0FBQztRQUN2QixpQkFBaUIsQ0FBQyxNQUFNLENBQUMsQ0FBQyxJQUFJLEVBQUUsRUFBRSxRQUFRLENBQUMsQ0FBQztJQUM5QyxDQUFDLENBQUMsQ0FBQztJQUVILEVBQUUsQ0FBQyxhQUFhLEVBQUUsS0FBSyxJQUFJLEVBQUU7UUFDM0IsTUFBTSxDQUFDLEdBQUcsRUFBRSxDQUFDLFFBQVEsQ0FBQyxDQUFDLENBQUMsRUFBRSxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsQ0FBQztRQUNsQyxNQUFNLEVBQUUsR0FBRyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBQ2xDLE1BQU0sRUFBRSxHQUFHLEVBQUUsQ0FBQyxJQUFJLENBQUMsQ0FBQyxDQUFDLEVBQUUsRUFBRSxDQUFDLENBQUMsQ0FBQyxVQUFVLEVBQUUsQ0FBQyxDQUFDLENBQUMsRUFBRSxFQUFFLENBQUMsQ0FBQztRQUVqRCxNQUFNLENBQUMsRUFBRSxDQUFDLEtBQUssQ0FBQyxDQUFDLE9BQU8sQ0FBQyxDQUFDLENBQUMsS0FBSyxDQUFDLENBQUM7UUFDbEMsaUJBQWlCLENBQUMsTUFBTSxFQUFFLENBQUMsSUFBSSxFQUFFLEVBQUUsQ0FBQyxTQUFTLEVBQUUsU0FBUyxFQUFFLENBQUMsU0FBUyxDQUFDLENBQUMsQ0FBQztJQUN6RSxDQUFDLENBQUMsQ0FBQztJQUVILEVBQUUsQ0FBQyx5QkFBeUIsRUFBRSxHQUFHLEVBQUU7UUFDakMsTUFBTSxDQUFDLEdBQUcsR0FBRyxFQUFFO1lBQ2IsRUFBRSxDQUFDLFVBQVUsQ0FBQyxFQUFFLENBQUMsUUFBUSxDQUFDLENBQUMsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsQ0FBQyxFQUFFLENBQU
MsQ0FBQyxDQUFDLEVBQUUsQ0FBQyxDQUFDLENBQUM7UUFDaEUsQ0FBQyxDQUFDO1FBQ0YsTUFBTSxDQUFDLENBQUMsQ0FBQyxDQUFDLFlBQVksRUFBRSxDQUFDO0lBQzNCLENBQUMsQ0FBQyxDQUFDO0lBRUgsRUFBRSxDQUFDLGlDQUFpQyxFQUFFLEdBQUcsRUFBRTtRQUN6QyxNQUFNLENBQUMsR0FBRyxFQUFFLENBQUMsRUFBRSxDQUFDLFVBQVUsQ0FBQyxFQUFlLENBQUMsQ0FBQzthQUN2QyxZQUFZLENBQ1QsMkRBQTJELENBQUMsQ0FBQztJQUN2RSxDQUFDLENBQUMsQ0FBQztJQUVILEVBQUUsQ0FBQyw4QkFBOEIsRUFBRSxLQUFLLElBQUksRUFBRTtRQUM1QyxNQUFNLENBQUMsR0FBRyxFQUFFLENBQUMsVUFBVSxDQUFDLENBQUMsQ0FBQyxFQUFFLENBQUMsRUFBRSxDQUFDLENBQUMsQ0FBQyxDQUFDO1FBRW5DLGlCQUFpQixDQUFDLE1BQU0sQ0FBQyxDQUFDLElBQUksRUFBRSxFQUFFLENBQUMsQ0FBQyxRQUFRLEVBQUUsQ0FBQyxTQUFTLEVBQUUsQ0FBQyxRQUFRLENBQUMsQ0FBQyxDQUFDO0lBQ3hFLENBQUMsQ0FBQyxDQUFDO0FBQ0wsQ0FBQyxDQUFDLENBQUMiLCJzb3VyY2VzQ29udGVudCI6WyIvKipcbiAqIEBsaWNlbnNlXG4gKiBDb3B5cmlnaHQgMjAyMCBHb29nbGUgTExDLiBBbGwgUmlnaHRzIFJlc2VydmVkLlxuICogTGljZW5zZWQgdW5kZXIgdGhlIEFwYWNoZSBMaWNlbnNlLCBWZXJzaW9uIDIuMCAodGhlIFwiTGljZW5zZVwiKTtcbiAqIHlvdSBtYXkgbm90IHVzZSB0aGlzIGZpbGUgZXhjZXB0IGluIGNvbXBsaWFuY2Ugd2l0aCB0aGUgTGljZW5zZS5cbiAqIFlvdSBtYXkgb2J0YWluIGEgY29weSBvZiB0aGUgTGljZW5zZSBhdFxuICpcbiAqIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMFxuICpcbiAqIFVubGVzcyByZXF1aXJlZCBieSBhcHBsaWNhYmxlIGxhdyBvciBhZ3JlZWQgdG8gaW4gd3JpdGluZywgc29mdHdhcmVcbiAqIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuIFwiQVMgSVNcIiBCQVNJUyxcbiAqIFdJVEhPVVQgV0FSUkFOVElFUyBPUiBDT05ESVRJT05TIE9GIEFOWSBLSU5ELCBlaXRoZXIgZXhwcmVzcyBvciBpbXBsaWVkLlxuICogU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZFxuICogbGltaXRhdGlvbnMgdW5kZXIgdGhlIExpY2Vuc2UuXG4gKiA9PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PT09PVxuICovXG5cbmltcG9ydCAqIGFzIHRmIGZyb20gJy4uL2luZGV4JztcbmltcG9ydCB7QUxMX0VOVlMsIGRlc2NyaWJlV2l0aEZsYWdzfSBmcm9tICcuLi9qYXNtaW5lX3V0aWwnO1xuaW1wb3J0IHtleHBlY3RBcnJheXNDbG9zZX0gZnJvbSAnLi4vdGVzdF91dGlsJztcblxuZGVzY3JpYmVXaXRoRmxhZ3MoJ2xvZ1NvZnRtYXgnLCBBTExfRU5WUywgKCkgPT4ge1xuICBpdCgncmVndW
xhciB0ZXN0JywgYXN5bmMgKCkgPT4ge1xuICAgIGNvbnN0IHkgPSB0Zi5sb2dTb2Z0bWF4KHRmLnRlbnNvcjFkKFsyLCAxLCAzXSkpO1xuXG4gICAgZXhwZWN0QXJyYXlzQ2xvc2UoYXdhaXQgeS5kYXRhKCksIFstMS40MDc2MDYsIC0yLjQwNzYwNjEsIC0wLjQwNzYwNl0pO1xuICB9KTtcblxuICBpdCgnSHVnZSBkaWZmZXJlbmNlJywgYXN5bmMgKCkgPT4ge1xuICAgIGNvbnN0IHkgPSB0Zi5sb2dTb2Z0bWF4KHRmLnRlbnNvcjFkKFstMTAwMCwgKzEwMDBdKSk7XG5cbiAgICBleHBlY3RBcnJheXNDbG9zZShhd2FpdCB5LmRhdGEoKSwgWy0yMDAwLCAwXSk7XG4gIH0pO1xuXG4gIGl0KCdQcm9wYWdhdGVzIE5hTnMnLCBhc3luYyAoKSA9PiB7XG4gICAgY29uc3QgYSA9IHRmLnRlbnNvcjFkKFsyLCAxLCBOYU5dKTtcbiAgICBjb25zdCB5ID0gdGYubG9nU29mdG1heChhKTtcbiAgICBleHBlY3RBcnJheXNDbG9zZShhd2FpdCB5LmRhdGEoKSwgW05hTiwgTmFOLCBOYU5dKTtcbiAgfSk7XG5cbiAgaXQoJzJELCBheGlzPTEnLCBhc3luYyAoKSA9PiB7XG4gICAgY29uc3QgeSA9IHRmLmxvZ1NvZnRtYXgodGYudGVuc29yMmQoW1syLCAxLCAzXSwgWzEsIDMsIDJdXSwgWzIsIDNdKSwgMSk7XG4gICAgY29uc3QgZXhwZWN0ZWQgPVxuICAgICAgICBbLTEuNDA3NjA2LCAtMi40MDc2MDYxLCAtMC40MDc2MDYsIC0yLjQwNzYwNjEsIC0wLjQwNzYwNjEsIC0xLjQwNzYwNjFdO1xuICAgIGV4cGVjdCh5LnJhbmspLnRvQmUoMik7XG4gICAgZXhwZWN0QXJyYXlzQ2xvc2UoYXdhaXQgeS5kYXRhKCksIGV4cGVjdGVkKTtcbiAgfSk7XG5cbiAgaXQoJzJELCBpbXBsaWNpdCBheGlzPTEnLCBhc3luYyAoKSA9PiB7XG4gICAgY29uc3QgeSA9IHRmLmxvZ1NvZnRtYXgodGYudGVuc29yMmQoW1syLCAxLCAzXSwgWzEsIDMsIDJdXSwgWzIsIDNdKSk7XG4gICAgY29uc3QgZXhwZWN0ZWQgPVxuICAgICAgICBbLTEuNDA3NjA2LCAtMi40MDc2MDYxLCAtMC40MDc2MDYsIC0yLjQwNzYwNjEsIC0wLjQwNzYwNjEsIC0xLjQwNzYwNjFdO1xuICAgIGV4cGVjdCh5LnJhbmspLnRvQmUoMik7XG4gICAgZXhwZWN0QXJyYXlzQ2xvc2UoYXdhaXQgeS5kYXRhKCksIGV4cGVjdGVkKTtcbiAgfSk7XG5cbiAgaXQoJzFEIGdyYWRpZW50JywgYXN5bmMgKCkgPT4ge1xuICAgIGNvbnN0IHggPSB0Zi50ZW5zb3IxZChbMSwgMiwgMTBdKTtcbiAgICBjb25zdCBkeSA9IHRmLnRlbnNvcjFkKFsxLCAyLCAzXSk7XG4gICAgY29uc3QgZHggPSB0Zi5ncmFkKCh4KSA9PiB4LmxvZ1NvZnRtYXgoKSkoeCwgZHkpO1xuXG4gICAgZXhwZWN0KGR4LnNoYXBlKS50b0VxdWFsKHguc2hhcGUpO1xuICAgIGV4cGVjdEFycmF5c0Nsb3NlKGF3YWl0IGR4LmRhdGEoKSwgWzAuOTk5MjU5OSwgMS45OTc5ODgxLCAtMi45OTcyNDc3XSk7XG4gIH0pO1xuXG4gIGl0KCcyRCwgYXhpcz0wIHRocm93cyBlcnJvcicsICgpID0+IHtcbiAgICBjb25zdCBmID0gKCkgPT4ge1xuICAgICAgdGYubG9nU29mdG1heC
h0Zi50ZW5zb3IyZChbWzIsIDEsIDNdLCBbMSwgMywgMl1dLCBbMiwgM10pLCAwKTtcbiAgICB9O1xuICAgIGV4cGVjdChmKS50b1Rocm93RXJyb3IoKTtcbiAgfSk7XG5cbiAgaXQoJ3Rocm93cyB3aGVuIHBhc3NlZCBhIG5vbi10ZW5zb3InLCAoKSA9PiB7XG4gICAgZXhwZWN0KCgpID0+IHRmLmxvZ1NvZnRtYXgoe30gYXMgdGYuVGVuc29yKSlcbiAgICAgICAgLnRvVGhyb3dFcnJvcihcbiAgICAgICAgICAgIC9Bcmd1bWVudCAnbG9naXRzJyBwYXNzZWQgdG8gJ2xvZ1NvZnRtYXgnIG11c3QgYmUgYSBUZW5zb3IvKTtcbiAgfSk7XG5cbiAgaXQoJ2FjY2VwdHMgYSB0ZW5zb3ItbGlrZSBvYmplY3QnLCBhc3luYyAoKSA9PiB7XG4gICAgY29uc3QgeSA9IHRmLmxvZ1NvZnRtYXgoWzIsIDEsIDNdKTtcblxuICAgIGV4cGVjdEFycmF5c0Nsb3NlKGF3YWl0IHkuZGF0YSgpLCBbLTEuNDA3NjA2LCAtMi40MDc2MDYxLCAtMC40MDc2MDZdKTtcbiAgfSk7XG59KTtcbiJdfQ==
|