/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// NOTE: This is compiled JavaScript (confusion_matrix_test.js). Per the inline
// source map at the end of this file it was generated from
// tfjs-core/src/ops/confusion_matrix_test.ts -- make changes in that
// TypeScript source and rebuild instead of editing this output by hand.
import * as tf from '../index';
import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
import { expectArraysEqual } from '../test_util';
/**
 * Unit tests for tf.math.confusionMatrix().
 *
 * Every non-error case checks three things: the flattened matrix values, the
 * output dtype ('int32' in every case, including when the inputs are
 * float32), and the [numClasses, numClasses] output shape.
 *
 * The expected arrays read row-major with rows indexed by label and columns
 * by prediction: entry [label][prediction] counts how many positions have
 * that (label, prediction) pair. E.g. in the 5x5 case labels [0, 4] and
 * predictions [4, 0] put a 1 at flat index 0*5+4 = 4 and 4*5+0 = 20.
 *
 * The Python reference snippets below use TensorFlow 1.x APIs
 * (tf.enable_eager_execution, tf.confusion_matrix); in TensorFlow 2.x eager
 * mode is the default and the op lives at tf.math.confusion_matrix.
 */
describeWithFlags('confusionMatrix', ALL_ENVS, () => {
    // Reference (Python) TensorFlow code:
    //
    // ```py
    // import tensorflow as tf
    //
    // tf.enable_eager_execution()
    //
    // labels = tf.constant([0, 1, 2, 1, 0])
    // predictions = tf.constant([0, 2, 2, 1, 0])
    // out = tf.confusion_matrix(labels, predictions, 3)
    //
    // print(out)
    // ```
    // Every class index 0..2 appears in both inputs; one label-1 sample is
    // mispredicted as 2, giving the off-diagonal count at [1][2].
    it('3x3 all cases present in both labels and predictions', async () => {
        const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
        const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
        const numClasses = 3;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        expectArraysEqual(await out.data(), [2, 0, 0, 0, 1, 1, 0, 0, 1]);
        expect(out.dtype).toBe('int32');
        expect(out.shape).toEqual([3, 3]);
    });
    // Same data as the previous case, but as float32 tensors: the result must
    // be identical and still have dtype 'int32'.
    it('float32 arguments are accepted', async () => {
        const labels = tf.tensor1d([0, 1, 2, 1, 0], 'float32');
        const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'float32');
        const numClasses = 3;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        expectArraysEqual(await out.data(), [2, 0, 0, 0, 1, 1, 0, 0, 1]);
        expect(out.dtype).toBe('int32');
        expect(out.shape).toEqual([3, 3]);
    });
    // Reference (Python) TensorFlow code:
    //
    // ```py
    // import tensorflow as tf
    //
    // tf.enable_eager_execution()
    //
    // labels = tf.constant([3, 3, 2, 2, 1, 1, 0, 0])
    // predictions = tf.constant([2, 2, 2, 2, 0, 0, 0, 0])
    // out = tf.confusion_matrix(labels, predictions, 4)
    //
    // print(out)
    // ```
    // Predictions never contain classes 1 or 3, so columns 1 and 3 of the
    // result are all zeros.
    it('4x4 all cases present in labels, but not predictions', async () => {
        const labels = tf.tensor1d([3, 3, 2, 2, 1, 1, 0, 0], 'int32');
        const predictions = tf.tensor1d([2, 2, 2, 2, 0, 0, 0, 0], 'int32');
        const numClasses = 4;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        expectArraysEqual(await out.data(), [2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0]);
        expect(out.dtype).toBe('int32');
        expect(out.shape).toEqual([4, 4]);
    });
    // Mirror of the previous case (labels and predictions swapped): labels
    // never contain classes 1 or 3, so rows 1 and 3 are all zeros and the
    // expected matrix is the transpose of the one above.
    it('4x4 all cases present in predictions, but not labels', async () => {
        const labels = tf.tensor1d([2, 2, 2, 2, 0, 0, 0, 0], 'int32');
        const predictions = tf.tensor1d([3, 3, 2, 2, 1, 1, 0, 0], 'int32');
        const numClasses = 4;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        expectArraysEqual(await out.data(), [2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0]);
        expect(out.dtype).toBe('int32');
        expect(out.shape).toEqual([4, 4]);
    });
    // Plain JS number arrays are accepted in place of tensors and produce the
    // same result as the 'labels, but not predictions' tensor case above.
    it('Plain arrays as inputs', async () => {
        const labels = [3, 3, 2, 2, 1, 1, 0, 0];
        const predictions = [2, 2, 2, 2, 0, 0, 0, 0];
        const numClasses = 4;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        expectArraysEqual(await out.data(), [2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0]);
        expect(out.dtype).toBe('int32');
        expect(out.shape).toEqual([4, 4]);
    });
    // Same data again, passed as Int32Array typed arrays.
    it('Int32Arrays as inputs', async () => {
        const labels = new Int32Array([3, 3, 2, 2, 1, 1, 0, 0]);
        const predictions = new Int32Array([2, 2, 2, 2, 0, 0, 0, 0]);
        const numClasses = 4;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        expectArraysEqual(await out.data(), [2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0]);
        expect(out.dtype).toBe('int32');
        expect(out.shape).toEqual([4, 4]);
    });
    // Reference (Python) TensorFlow code:
    //
    // ```py
    // import tensorflow as tf
    //
    // tf.enable_eager_execution()
    //
    // labels = tf.constant([0, 4])
    // predictions = tf.constant([4, 0])
    // out = tf.confusion_matrix(labels, predictions, 5)
    //
    // print(out)
    // ```
    // numClasses (5) is larger than the number of distinct values present;
    // classes 1..3 appear in neither input and their rows/columns stay zero.
    it('5x5 predictions and labels both missing some cases', async () => {
        const labels = tf.tensor1d([0, 4], 'int32');
        const predictions = tf.tensor1d([4, 0], 'int32');
        const numClasses = 5;
        const out = tf.math.confusionMatrix(labels, predictions, numClasses);
        expectArraysEqual(await out.data(), [
            0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0
        ]);
        expect(out.dtype).toBe('int32');
        expect(out.shape).toEqual([5, 5]);
    });
    // A non-integer numClasses must throw; the regex expects the message to
    // mention 'numClasses', 'positive integer' and the offending value 2.5.
    it('Invalid numClasses leads to Error', () => {
        expect(() => tf.math.confusionMatrix(tf.tensor1d([0, 1]), tf.tensor1d([1, 0]), 2.5))
            .toThrowError(/numClasses .* positive integer.* got 2\.5/);
    });
    // Both inputs must be rank 1: a rank-0 pair, and a rank-2 tensor in
    // either argument position, must each throw a rank error that reports the
    // actual rank. (The `as any` casts that the tslint suppressions below
    // applied to in the TypeScript source are erased in this compiled output.)
    it('Incorrect tensor rank leads to Error', () => {
        expect(() => tf.math.confusionMatrix(
        // tslint:disable-next-line:no-any
        tf.scalar(0), tf.scalar(0), 1))
            .toThrowError(/rank .* 1.*got 0/);
        expect(() =>
        // tslint:disable-next-line:no-any
        tf.math.confusionMatrix(tf.zeros([3, 3]), tf.zeros([9]), 2))
            .toThrowError(/rank .* 1.*got 2/);
        expect(() =>
        // tslint:disable-next-line:no-any
        tf.math.confusionMatrix(tf.zeros([9]), tf.zeros([3, 3]), 2))
            .toThrowError(/rank .* 1.*got 2/);
    });
    // Labels and predictions of different lengths (3 vs 9) must throw an
    // error naming both lengths.
    it('Mismatch in lengths leads to Error', () => {
        expect(
        // tslint:disable-next-line:no-any
        () => tf.math.confusionMatrix(tf.zeros([3]), tf.zeros([9]), 2))
            .toThrowError(/Mismatch .* 3 vs.* 9/);
    });
});
// Inline source map (generated; maps this file back to the .ts source).
//# 
sourceMappingURL=data:application/json;base64,{"version":3,"file":"confusion_matrix_test.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/confusion_matrix_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAC5D,OAAO,EAAC,iBAAiB,EAAC,MAAM,cAAc,CAAC;AAE/C;;GAEG;AAEH,iBAAiB,CAAC,iBAAiB,EAAE,QAAQ,EAAE,GAAG,EAAE;IAClD,sCAAsC;IACtC,EAAE;IACF,QAAQ;IACR,0BAA0B;IAC1B,EAAE;IACF,8BAA8B;IAC9B,EAAE;IACF,wCAAwC;IACxC,6CAA6C;IAC7C,oDAAoD;IACpD,EAAE;IACF,aAAa;IACb,MAAM;IACN,EAAE,CAAC,sDAAsD,EAAE,KAAK,IAAI,EAAE;QACpE,MAAM,MAAM,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACrD,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAC1D,MAAM,UAAU,GAAG,CAAC,CAAC;QACrB,MAAM,GAAG,GAAG,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,MAAM,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;QACrE,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjE,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QAChC,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,gCAAgC,EAAE,KAAK,IAAI,EAAE;QAC9C,MAAM,MAAM,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC;QACvD,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,CAAC;QAC5D,MAAM,UAAU,GAAG,CAAC,CAAC;QACrB,MAAM,GAAG,GAAG,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,MAAM,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;QACrE,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjE,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QAChC,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpC,CAAC,CA
AC,CAAC;IAEH,sCAAsC;IACtC,EAAE;IACF,QAAQ;IACR,0BAA0B;IAC1B,EAAE;IACF,8BAA8B;IAC9B,EAAE;IACF,iDAAiD;IACjD,sDAAsD;IACtD,oDAAoD;IACpD,EAAE;IACF,aAAa;IACb,MAAM;IACN,EAAE,CAAC,sDAAsD,EAAE,KAAK,IAAI,EAAE;QACpE,MAAM,MAAM,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAC9D,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACnE,MAAM,UAAU,GAAG,CAAC,CAAC;QACrB,MAAM,GAAG,GAAG,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,MAAM,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;QACrE,iBAAiB,CACb,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACxE,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QAChC,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sDAAsD,EAAE,KAAK,IAAI,EAAE;QACpE,MAAM,MAAM,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAC9D,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACnE,MAAM,UAAU,GAAG,CAAC,CAAC;QACrB,MAAM,GAAG,GAAG,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,MAAM,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;QACrE,iBAAiB,CACb,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACxE,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QAChC,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,wBAAwB,EAAE,KAAK,IAAI,EAAE;QACtC,MAAM,MAAM,GAAa,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CA
AC;QAClD,MAAM,WAAW,GAAa,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QACvD,MAAM,UAAU,GAAG,CAAC,CAAC;QACrB,MAAM,GAAG,GAAG,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,MAAM,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;QACrE,iBAAiB,CACb,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACxE,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QAChC,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,uBAAuB,EAAE,KAAK,IAAI,EAAE;QACrC,MAAM,MAAM,GAAG,IAAI,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACxD,MAAM,WAAW,GAAG,IAAI,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC7D,MAAM,UAAU,GAAG,CAAC,CAAC;QACrB,MAAM,GAAG,GAAG,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,MAAM,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;QACrE,iBAAiB,CACb,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACxE,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QAChC,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,sCAAsC;IACtC,EAAE;IACF,QAAQ;IACR,0BAA0B;IAC1B,EAAE;IACF,8BAA8B;IAC9B,EAAE;IACF,+BAA+B;IAC/B,oCAAoC;IACpC,oDAAoD;IACpD,EAAE;IACF,aAAa;IACb,MAAM;IACN,EAAE,CAAC,oDAAoD,EAAE,KAAK,IAAI,EAAE;QAClE,MAAM,MAAM,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAC5C,MAAM,WAAW,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACjD,MAAM,UAAU,GAAG,CAAC,CAAC;QACrB,MAAM,GAAG,GAAG,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,MAAM,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;QACrE,iBAAiB,CAAC,MAAM,GAAG,CAAC,IAAI,EAAE,EAAE;YAClC,CAAC,EAAE,CAAC,EAAE,CAA
C,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;YACrC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;SACnC,CAAC,CAAC;QACH,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;QAChC,MAAM,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mCAAmC,EAAE,GAAG,EAAE;QAC3C,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,eAAe,CACzB,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;aAClD,YAAY,CAAC,2CAA2C,CAAC,CAAC;IACjE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sCAAsC,EAAE,GAAG,EAAE;QAC9C,MAAM,CACF,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,eAAe;QACzB,kCAAkC;QAClC,EAAE,CAAC,MAAM,CAAC,CAAC,CAAQ,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,CAAQ,EAAE,CAAC,CAAC,CAAC;aAChD,YAAY,CAAC,kBAAkB,CAAC,CAAC;QACtC,MAAM,CACF,GAAG,EAAE;QACD,kCAAkC;QACtC,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAQ,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;aAClE,YAAY,CAAC,kBAAkB,CAAC,CAAC;QACtC,MAAM,CACF,GAAG,EAAE;QACD,kCAAkC;QACtC,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAQ,EAAE,CAAC,CAAC,CAAC;aAClE,YAAY,CAAC,kBAAkB,CAAC,CAAC;IACxC,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC5C,MAAM;QACF,kCAAkC;QAClC,GAAG,EAAE,CAAC,EAAE,CAAC,IAAI,CAAC,eAAe,CAAC,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAQ,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;aACrE,YAAY,CAAC,sBAAsB,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2018 Google LLC. 
All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {expectArraysEqual} from '../test_util';\n\n/**\n * Unit tests for confusionMatrix().\n */\n\ndescribeWithFlags('confusionMatrix', ALL_ENVS, () => {\n  // Reference (Python) TensorFlow code:\n  //\n  // ```py\n  // import tensorflow as tf\n  //\n  // tf.enable_eager_execution()\n  //\n  // labels = tf.constant([0, 1, 2, 1, 0])\n  // predictions = tf.constant([0, 2, 2, 1, 0])\n  // out = tf.confusion_matrix(labels, predictions, 3)\n  //\n  // print(out)\n  // ```\n  it('3x3 all cases present in both labels and predictions', async () => {\n    const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');\n    const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');\n    const numClasses = 3;\n    const out = tf.math.confusionMatrix(labels, predictions, numClasses);\n    expectArraysEqual(await out.data(), [2, 0, 0, 0, 1, 1, 0, 0, 1]);\n    expect(out.dtype).toBe('int32');\n    expect(out.shape).toEqual([3, 3]);\n  });\n\n  it('float32 arguments are accepted', async () => {\n    const labels = tf.tensor1d([0, 1, 2, 1, 0], 'float32');\n    const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'float32');\n    const numClasses = 3;\n    const out = tf.math.confusionMatrix(labels, predictions, numClasses);\n    
expectArraysEqual(await out.data(), [2, 0, 0, 0, 1, 1, 0, 0, 1]);\n    expect(out.dtype).toBe('int32');\n    expect(out.shape).toEqual([3, 3]);\n  });\n\n  // Reference (Python) TensorFlow code:\n  //\n  // ```py\n  // import tensorflow as tf\n  //\n  // tf.enable_eager_execution()\n  //\n  // labels = tf.constant([3, 3, 2, 2, 1, 1, 0, 0])\n  // predictions = tf.constant([2, 2, 2, 2, 0, 0, 0, 0])\n  // out = tf.confusion_matrix(labels, predictions, 4)\n  //\n  // print(out)\n  // ```\n  it('4x4 all cases present in labels, but not predictions', async () => {\n    const labels = tf.tensor1d([3, 3, 2, 2, 1, 1, 0, 0], 'int32');\n    const predictions = tf.tensor1d([2, 2, 2, 2, 0, 0, 0, 0], 'int32');\n    const numClasses = 4;\n    const out = tf.math.confusionMatrix(labels, predictions, numClasses);\n    expectArraysEqual(\n        await out.data(), [2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0]);\n    expect(out.dtype).toBe('int32');\n    expect(out.shape).toEqual([4, 4]);\n  });\n\n  it('4x4 all cases present in predictions, but not labels', async () => {\n    const labels = tf.tensor1d([2, 2, 2, 2, 0, 0, 0, 0], 'int32');\n    const predictions = tf.tensor1d([3, 3, 2, 2, 1, 1, 0, 0], 'int32');\n    const numClasses = 4;\n    const out = tf.math.confusionMatrix(labels, predictions, numClasses);\n    expectArraysEqual(\n        await out.data(), [2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0]);\n    expect(out.dtype).toBe('int32');\n    expect(out.shape).toEqual([4, 4]);\n  });\n\n  it('Plain arrays as inputs', async () => {\n    const labels: number[] = [3, 3, 2, 2, 1, 1, 0, 0];\n    const predictions: number[] = [2, 2, 2, 2, 0, 0, 0, 0];\n    const numClasses = 4;\n    const out = tf.math.confusionMatrix(labels, predictions, numClasses);\n    expectArraysEqual(\n        await out.data(), [2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0]);\n    expect(out.dtype).toBe('int32');\n    expect(out.shape).toEqual([4, 4]);\n  });\n\n  it('Int32Arrays as inputs', async () 
=> {\n    const labels = new Int32Array([3, 3, 2, 2, 1, 1, 0, 0]);\n    const predictions = new Int32Array([2, 2, 2, 2, 0, 0, 0, 0]);\n    const numClasses = 4;\n    const out = tf.math.confusionMatrix(labels, predictions, numClasses);\n    expectArraysEqual(\n        await out.data(), [2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0]);\n    expect(out.dtype).toBe('int32');\n    expect(out.shape).toEqual([4, 4]);\n  });\n\n  // Reference (Python) TensorFlow code:\n  //\n  // ```py\n  // import tensorflow as tf\n  //\n  // tf.enable_eager_execution()\n  //\n  // labels = tf.constant([0, 4])\n  // predictions = tf.constant([4, 0])\n  // out = tf.confusion_matrix(labels, predictions, 5)\n  //\n  // print(out)\n  // ```\n  it('5x5 predictions and labels both missing some cases', async () => {\n    const labels = tf.tensor1d([0, 4], 'int32');\n    const predictions = tf.tensor1d([4, 0], 'int32');\n    const numClasses = 5;\n    const out = tf.math.confusionMatrix(labels, predictions, numClasses);\n    expectArraysEqual(await out.data(), [\n      0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n      0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0\n    ]);\n    expect(out.dtype).toBe('int32');\n    expect(out.shape).toEqual([5, 5]);\n  });\n\n  it('Invalid numClasses leads to Error', () => {\n    expect(\n        () => tf.math.confusionMatrix(\n            tf.tensor1d([0, 1]), tf.tensor1d([1, 0]), 2.5))\n        .toThrowError(/numClasses .* positive integer.* got 2\\.5/);\n  });\n\n  it('Incorrect tensor rank leads to Error', () => {\n    expect(\n        () => tf.math.confusionMatrix(\n            // tslint:disable-next-line:no-any\n            tf.scalar(0) as any, tf.scalar(0) as any, 1))\n        .toThrowError(/rank .* 1.*got 0/);\n    expect(\n        () =>\n            // tslint:disable-next-line:no-any\n        tf.math.confusionMatrix(tf.zeros([3, 3]) as any, tf.zeros([9]), 2))\n        .toThrowError(/rank .* 1.*got 2/);\n    expect(\n        () =>\n            // 
tslint:disable-next-line:no-any\n        tf.math.confusionMatrix(tf.zeros([9]), tf.zeros([3, 3]) as any, 2))\n        .toThrowError(/rank .* 1.*got 2/);\n  });\n\n  it('Mismatch in lengths leads to Error', () => {\n    expect(\n        // tslint:disable-next-line:no-any\n        () => tf.math.confusionMatrix(tf.zeros([3]) as any, tf.zeros([9]), 2))\n        .toThrowError(/Mismatch .* 3 vs.* 9/);\n  });\n});\n"]}