/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysClose } from '../test_util';

// NOTE(review): this file is compiled output of
// tfjs-core/src/ops/softmax_test.ts (named by the inline source map at the end
// of this file). It is presumably regenerated by the build, and any hand edit
// here leaves that source map stale; make lasting changes in the .ts source.

/**
 * Tests for `tf.softmax` under the ALL_ENVS flag set: forward values,
 * extreme / NaN logits, rank-2 inputs along the last axis, gradients, and
 * argument validation.
 */
describeWithFlags('softmax', ALL_ENVS, () => {
  // Factories return fresh arrays/tensors so no test can observe another's.
  // softmax([2, 1, 3]):
  const probsOf213 = () => [0.24472847, 0.09003057, 0.66524095];
  // Row-wise softmax of [[2, 1, 3], [1, 3, 2]], flattened row-major.
  const probsOf2x3 = () => [
    0.24472847, 0.09003057, 0.66524095, 0.09003057, 0.66524095, 0.24472847
  ];
  // The 2x3 logits shared by the rank-2 forward tests.
  const logits2x3 = () => tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]);

  // Expected softmax gradient checked by the tests below:
  //   dx[i] = (dy[i] - sum(dy * y)) * y[i]
  const expectedGrad1d = (dyVals, dot, yVals) =>
      dyVals.map((d, i) => (d - dot) * yVals[i]);
  // Row-wise form of the same formula for nested rank-2 values, where
  // `dots[r]` is the dy*y sum of row r; the result is flattened row-major.
  const expectedGrad2d = (dyRows, dots, yRows) => [].concat(...dyRows.map(
      (row, r) => row.map((d, c) => (d - dots[r]) * yRows[r][c])));

  it('regular test', async () => {
    const probs = tf.softmax(tf.tensor1d([2, 1, 3]));

    expectArraysClose(await probs.data(), probsOf213());
    // The outputs are also required to sum to 1.
    expectArraysClose(await probs.sum().data(), 1);
  });

  it('overflow', async () => {
    const probs = tf.softmax(tf.tensor1d([100, 100]));

    expectArraysClose(await probs.data(), [0.5, 0.5]);
  });

  it('underflow', async () => {
    const probs = tf.softmax(tf.tensor1d([-100, -100]));

    expectArraysClose(await probs.data(), [0.5, 0.5]);
  });

  it('odd number of inputs', async () => {
    const probs =
        tf.softmax(tf.tensor1d([-400, -400, 0, -400, -400, -400, -400]));

    expectArraysClose(await probs.data(), [0, 0, 1, 0, 0, 0, 0]);
  });

  it('Huge difference between probabilities', async () => {
    const probs = tf.softmax(tf.tensor1d([-1000, +1000]));

    expectArraysClose(await probs.data(), [0, 1]);
  });

  it('Propagates NaNs', async () => {
    // One NaN logit is expected to make every output NaN.
    const probs = tf.softmax(tf.tensor1d([2, 1, NaN]));
    expectArraysClose(await probs.data(), [NaN, NaN, NaN]);
  });

  it('2D, dim=1', async () => {
    const probs = tf.softmax(logits2x3(), 1);
    expect(probs.rank).toBe(2);
    expectArraysClose(await probs.data(), probsOf2x3());
  });

  it('2D, implicit dim=1', async () => {
    // Omitting `dim` must give the same result as dim=1 for this rank-2 input.
    const probs = tf.softmax(logits2x3());
    expect(probs.rank).toBe(2);
    expectArraysClose(await probs.data(), probsOf2x3());
  });

  it('2D, dim=0 throws error', () => {
    // Softmax along axis 0 of a rank-2 input is expected to throw.
    const run = () => {
      tf.softmax(logits2x3(), 0);
    };
    expect(run).toThrowError();
  });

  it('1D gradient', async () => {
    const logits = tf.tensor1d([10, 0, -1]);
    const probs = tf.softmax(logits);
    const upstream = tf.tensor1d([1, 2, 3]);
    const grad = tf.grad((t) => t.softmax())(logits, upstream);

    // Scalar sum(dy * y) used by the reference formula.
    const dot = tf.sum(tf.mul(upstream, probs));

    const upstreamVals = await upstream.array();
    const dotVal = await dot.array();
    const probVals = await probs.array();
    expect(grad.shape).toEqual(logits.shape);
    expectArraysClose(
        await grad.data(), expectedGrad1d(upstreamVals, dotVal, probVals));
  });

  it('gradient with clones', () => {
    const logits = tf.tensor1d([10, 0, -1]);
    const grad = tf.grad((t) => t.clone().softmax().clone())(logits);
    expect(grad.shape).toEqual(logits.shape);
    expect(grad.dtype).toBe('float32');
  });

  it('2D gradient', async () => {
    const logits = tf.tensor2d([10, 0, -1, 5, 4, 3], [2, 3]);
    const probs = tf.softmax(logits);
    const upstream = tf.tensor2d([3, 2, 1, 1, 2, 3], [2, 3]);
    const grad = tf.grad((t) => t.softmax())(logits, upstream);

    // Per-row sum(dy * y), reduced over the last axis (-1).
    const rowDots = tf.sum(tf.mul(upstream, probs), -1);

    const upstreamVals = await upstream.array();
    const rowDotVals = await rowDots.array();
    const probVals = await probs.array();

    expect(grad.shape).toEqual(logits.shape);
    expectArraysClose(
        await grad.data(), expectedGrad2d(upstreamVals, rowDotVals, probVals));
  });

  it('throws when passed a non-tensor', () => {
    expect(() => tf.softmax({}))
        .toThrowError(/Argument 'logits' passed to 'softmax' must be a Tensor/);
  });

  it('throws when passed an int32 tensor', async () => {
    expect(() => tf.softmax(tf.tensor1d([2, 1, 3], 'int32')))
        .toThrowError(/Argument 'logits' passed to 'softmax' must be float32/);
  });

  it('accepts a tensor-like object', async () => {
    // A plain JS array is accepted in place of a Tensor.
    const probs = tf.softmax([2, 1, 3]);

    expectArraysClose(await probs.data(), probsOf213());
    expectArraysClose(await probs.sum().data(), 1);
  });
});
//# sourceMappingURL=data:application/json;base64,{"version":3,"file":"softmax_test.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/softmax_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAC5D,OAAO,EAAC,iBAAiB,EAAC,MAAM,cAAc,CAAC;AAE/C,iBAAiB,CAAC,SAAS,EAAE,QAAQ,EAAE,GAAG,EAAE;IAC1C,EAAE,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QAE7C,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,UAAU,EAAE,UAAU,EAAE,UAAU,CAAC,CAAC,CAAC;QACxE,iBAAiB,CAAC,MAAM,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC;IAC7C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,UAAU,EAAE,KAAK,IAAI,EAAE;QACxB,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC;QAE9C,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IAChD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,WAAW,EAAE,KAAK,IAAI,EAAE;QACzB,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;QAEhD,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IAChD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;QAE3E,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAA
E,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,uCAAuC,EAAE,KAAK,IAAI,EAAE;QACrD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,IAAI,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAElD,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iBAAiB,EAAE,KAAK,IAAI,EAAE;QAC/B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC;QACnC,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;QACxB,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,GAAG,EAAE,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IACrD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,WAAW,EAAE,KAAK,IAAI,EAAE;QACzB,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QACrE,MAAM,QAAQ,GAAG;YACf,UAAU,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU;SACvE,CAAC;QACF,MAAM,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QACvB,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,QAAQ,CAAC,CAAC;IAC9C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QAClE,MAAM,QAAQ,GAAG;YACf,UAAU,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU,EAAE,UAAU;SACvE,CAAC;QACF,MAAM,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QACvB,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,QAAQ,CAAC,CAAC;IAC9C,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,wBAAwB,EAAE,GAAG,EAAE;QAChC,MAAM,CAAC,GAAG,GAAG,EAAE;YACb,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAC7D,CAAC,CAAC;QACF,MAAM,CAAC,CAAC,CAAC,CAAC,YAAY,EAAE,CAAC;IAC3B,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,aAAa,EAAE,KAAK,IAAI,EAAE;QAC3B,
MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;QACxB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClC,MAAM,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE9C,MAAM,QAAQ,GAAc,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,CAAC;QAElD,MAAM,MAAM,GAAG,MAAM,EAAE,CAAC,KAAK,EAAE,CAAC;QAChC,MAAM,OAAO,GAAG,MAAM,QAAQ,CAAC,KAAK,EAAE,CAAC;QACvC,MAAM,KAAK,GAAG,MAAM,CAAC,CAAC,KAAK,EAAE,CAAC;QAC9B,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QAClC,iBAAiB,CAAC,MAAM,EAAE,CAAC,IAAI,EAAE,EAAE;YACjC,CAAC,MAAM,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC;YAChC,CAAC,MAAM,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC;YAChC,CAAC,MAAM,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC;SACjC,CAAC,CAAC;IACL,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sBAAsB,EAAE,GAAG,EAAE;QAC9B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACnC,MAAM,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,KAAK,EAAE,CAAC,OAAO,EAAE,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QAC1D,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QAClC,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IACnC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,aAAa,EAAE,KAAK,IAAI,EAAE;QAC3B,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;QACxB,MAAM,EAAE,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnD,MAAM,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC;QAE9C,MAAM,IAAI,GAAG,CAAC,CAAC,CAAC;QAChB,MAAM,QAAQ,GAAgB,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,G
AAG,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;QAE1D,MAAM,MAAM,GAAG,MAAM,EAAE,CAAC,KAAK,EAAE,CAAC;QAChC,MAAM,OAAO,GAAG,MAAM,QAAQ,CAAC,KAAK,EAAE,CAAC;QACvC,MAAM,KAAK,GAAG,MAAM,CAAC,CAAC,KAAK,EAAE,CAAC;QAE9B,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QAClC,iBAAiB,CAAC,MAAM,EAAE,CAAC,IAAI,EAAE,EAAE;YACjC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACzC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACzC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACzC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACzC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACzC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;SAC1C,CAAC,CAAC;IACL,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACzC,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,OAAO,CAAC,EAAe,CAAC,CAAC;aACpC,YAAY,CAAC,wDAAwD,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oCAAoC,EAAE,KAAK,IAAI,EAAE;QAClD,MAAM,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;aACpD,YAAY,CAAC,uDAAuD,CAAC,CAAC;IAC7E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,8BAA8B,EAAE,KAAK,IAAI,EAAE;QAC5C,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAEhC,iBAAiB,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,EAAE,CAAC,UAAU,EAAE,UAAU,EAAE,UAAU,CAAC,CAAC,CAAC;QACxE,iBAAiB,CAAC,MAAM,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,CAAC;IAC7C,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2017 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {expectArraysClose} from '../test_util';\n\ndescribeWithFlags('softmax', ALL_ENVS, () => {\n  it('regular test', async () => {\n    const y = tf.softmax(tf.tensor1d([2, 1, 3]));\n\n    expectArraysClose(await y.data(), [0.24472847, 0.09003057, 0.66524095]);\n    expectArraysClose(await y.sum().data(), 1);\n  });\n\n  it('overflow', async () => {\n    const y = tf.softmax(tf.tensor1d([100, 100]));\n\n    expectArraysClose(await y.data(), [0.5, 0.5]);\n  });\n\n  it('underflow', async () => {\n    const y = tf.softmax(tf.tensor1d([-100, -100]));\n\n    expectArraysClose(await y.data(), [0.5, 0.5]);\n  });\n\n  it('odd number of inputs', async () => {\n    const y = tf.softmax(tf.tensor1d([-400, -400, 0, -400, -400, -400, -400]));\n\n    expectArraysClose(await y.data(), [0, 0, 1, 0, 0, 0, 0]);\n  });\n\n  it('Huge difference between probabilities', async () => {\n    const y = tf.softmax(tf.tensor1d([-1000, +1000]));\n\n    expectArraysClose(await y.data(), [0, 1]);\n  });\n\n  it('Propagates NaNs', async () => {\n    const a = tf.tensor1d([2, 1, NaN]);\n    const y = tf.softmax(a);\n    expectArraysClose(await y.data(), [NaN, NaN, NaN]);\n  });\n\n  it('2D, dim=1', async () => {\n    const y = 
tf.softmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]), 1);\n    const expected = [\n      0.24472847, 0.09003057, 0.66524095, 0.09003057, 0.66524095, 0.24472847\n    ];\n    expect(y.rank).toBe(2);\n    expectArraysClose(await y.data(), expected);\n  });\n\n  it('2D, implicit dim=1', async () => {\n    const y = tf.softmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]));\n    const expected = [\n      0.24472847, 0.09003057, 0.66524095, 0.09003057, 0.66524095, 0.24472847\n    ];\n    expect(y.rank).toBe(2);\n    expectArraysClose(await y.data(), expected);\n  });\n\n  it('2D, dim=0 throws error', () => {\n    const f = () => {\n      tf.softmax(tf.tensor2d([[2, 1, 3], [1, 3, 2]], [2, 3]), 0);\n    };\n    expect(f).toThrowError();\n  });\n\n  it('1D gradient', async () => {\n    const x = tf.tensor1d([10, 0, -1]);\n    const y = tf.softmax(x);\n    const dy = tf.tensor1d([1, 2, 3]);\n    const dx = tf.grad((x) => x.softmax())(x, dy);\n\n    const totalSum: tf.Scalar = tf.sum(tf.mul(dy, y));\n\n    const dyVals = await dy.array();\n    const sumVals = await totalSum.array();\n    const yVals = await y.array();\n    expect(dx.shape).toEqual(x.shape);\n    expectArraysClose(await dx.data(), [\n      (dyVals[0] - sumVals) * yVals[0],\n      (dyVals[1] - sumVals) * yVals[1],\n      (dyVals[2] - sumVals) * yVals[2],\n    ]);\n  });\n\n  it('gradient with clones', () => {\n    const x = tf.tensor1d([10, 0, -1]);\n    const dx = tf.grad((x) => x.clone().softmax().clone())(x);\n    expect(dx.shape).toEqual(x.shape);\n    expect(dx.dtype).toBe('float32');\n  });\n\n  it('2D gradient', async () => {\n    const x = tf.tensor2d([10, 0, -1, 5, 4, 3], [2, 3]);\n    const y = tf.softmax(x);\n    const dy = tf.tensor2d([3, 2, 1, 1, 2, 3], [2, 3]);\n    const dx = tf.grad((x) => x.softmax())(x, dy);\n\n    const axis = -1;\n    const totalSum: tf.Tensor1D = tf.sum(tf.mul(dy, y), axis);\n\n    const dyVals = await dy.array();\n    const sumVals = await totalSum.array();\n    const 
yVals = await y.array();\n\n    expect(dx.shape).toEqual(x.shape);\n    expectArraysClose(await dx.data(), [\n      (dyVals[0][0] - sumVals[0]) * yVals[0][0],\n      (dyVals[0][1] - sumVals[0]) * yVals[0][1],\n      (dyVals[0][2] - sumVals[0]) * yVals[0][2],\n      (dyVals[1][0] - sumVals[1]) * yVals[1][0],\n      (dyVals[1][1] - sumVals[1]) * yVals[1][1],\n      (dyVals[1][2] - sumVals[1]) * yVals[1][2]\n    ]);\n  });\n\n  it('throws when passed a non-tensor', () => {\n    expect(() => tf.softmax({} as tf.Tensor))\n        .toThrowError(/Argument 'logits' passed to 'softmax' must be a Tensor/);\n  });\n\n  it('throws when passed an int32 tensor', async () => {\n    expect(() => tf.softmax(tf.tensor1d([2, 1, 3], 'int32')))\n        .toThrowError(/Argument 'logits' passed to 'softmax' must be float32/);\n  });\n\n  it('accepts a tensor-like object', async () => {\n    const y = tf.softmax([2, 1, 3]);\n\n    expectArraysClose(await y.data(), [0.24472847, 0.09003057, 0.66524095]);\n    expectArraysClose(await y.sum().data(), 1);\n  });\n});\n"]}