1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
| /**
| * @license
| * Copyright 2022 Google LLC.
| * Licensed under the Apache License, Version 2.0 (the "License");
| * you may not use this file except in compliance with the License.
| * You may obtain a copy of the License at
| *
| * http://www.apache.org/licenses/LICENSE-2.0
| *
| * Unless required by applicable law or agreed to in writing, software
| * distributed under the License is distributed on an "AS IS" BASIS,
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| * See the License for the specific language governing permissions and
| * limitations under the License.
| * =============================================================================
| */
| import * as tf from '../index';
| import { ALL_ENVS, describeWithFlags } from '../jasmine_util';
| import { expectArraysClose } from '../test_util';
/**
 * Runs `tf.raggedRange` on the given inputs, checks the structural invariants
 * of its two outputs, and downloads their contents.
 *
 * NOTE(review): despite its name this helper exercises `tf.raggedRange`, not
 * `raggedGather` (the name looks copied from a sibling test file). It is kept
 * as-is because every spec in this file calls it by this name.
 *
 * @param {tf.Tensor} starts - Range starts (scalar or vector).
 * @param {tf.Tensor} limits - Range limits (scalar or vector).
 * @param {tf.Tensor} deltas - Range steps (scalar or vector).
 * @returns {Promise<{rtNestedSplits: TypedArray, rtDenseValues: TypedArray,
 *     tensors: tf.Tensor[]}>} The downloaded row splits and dense values,
 *     plus the op's output tensors so the caller can dispose them.
 */
async function runRaggedGather(starts, limits, deltas) {
  const output = tf.raggedRange(starts, limits, deltas);
  const {rtNestedSplits, rtDenseValues} = output;

  // Row splits must be a 1-D int32 tensor whatever the input dtype is.
  expect(rtNestedSplits.dtype).toEqual('int32');
  expect(rtNestedSplits.shape.length).toEqual(1);

  // Dense values must be 1-D and keep the dtype of `starts`.
  expect(rtDenseValues.dtype).toEqual(starts.dtype);
  expect(rtDenseValues.shape.length).toEqual(1);

  // The two downloads do not depend on each other, so start them together.
  const [splitsData, valuesData] =
      await Promise.all([rtNestedSplits.data(), rtDenseValues.data()]);

  return {
    rtNestedSplits: splitsData,
    rtDenseValues: valuesData,
    tensors: Object.values(output),
  };
}
/**
 * `tf.raggedRange` test suite, run on every registered backend.
 *
 * Every spec except the final memory-leak check is table-driven. Each case
 * record holds:
 *   - `name`:   the spec name;
 *   - `inputs`: a factory returning `[starts, limits, deltas]`, invoked inside
 *               the spec body so the tensors are only built when it runs;
 *   - either `splits` and `values` (the expected row splits and flattened
 *     values) or `error` (the message the op must reject with).
 * Specs are registered in table order.
 */
describeWithFlags('raggedRange ', ALL_ENVS, () => {
  const cases = [
    {
      name: 'IntValues',
      // Expected: [[0, 2, 4, 6], [5, 6], [], [5, 4, 3, 2]]
      inputs: () => [
        tf.tensor1d([0, 5, 8, 5], 'int32'),
        tf.tensor1d([8, 7, 8, 1], 'int32'),
        tf.tensor1d([2, 1, 1, -1], 'int32'),
      ],
      splits: [0, 4, 6, 6, 10],
      values: [0, 2, 4, 6, 5, 6, 5, 4, 3, 2],
    },
    {
      name: 'FloatValues',
      // Expected: [[0, 2, 4, 6], [5, 6], [], [5, 4, 3, 2]]
      inputs: () => [
        tf.tensor1d([0, 5, 8, 5], 'float32'),
        tf.tensor1d([8, 7, 8, 1], 'float32'),
        tf.tensor1d([2, 1, 1, -1], 'float32'),
      ],
      splits: [0, 4, 6, 6, 10],
      values: [0, 2, 4, 6, 5, 6, 5, 4, 3, 2],
    },
    {
      name: 'RangeSizeOverflow',
      // (1e10 - 0.1) / 1e-10 is far above the 2147483647 bound named in the
      // expected error message.
      inputs: () => [
        tf.tensor1d([1.1, 0.1], 'float32'),
        tf.tensor1d([10, 1e10], 'float32'),
        tf.tensor1d([1, 1e-10], 'float32'),
      ],
      error: 'Requires ((limit - start) / delta) <= 2147483647',
    },
    {
      name: 'BroadcastDeltas',
      // Expected: [[0, 1, 2, 3, 4, 5, 6, 7], [5, 6], []]
      inputs: () => [
        tf.tensor1d([0, 5, 8], 'int32'),
        tf.tensor1d([8, 7, 8], 'int32'),
        tf.scalar(1, 'int32'),
      ],
      splits: [0, 8, 10, 10],
      values: [0, 1, 2, 3, 4, 5, 6, 7, 5, 6],
    },
    {
      // NOTE(review): in this case `starts` and `deltas` are the scalars
      // (`limits` is a vector); the spec name is kept unchanged.
      name: 'BroadcastLimitsAndDeltas',
      // Expected: [[0, 1, 2], [], [0, 1]]
      inputs: () => [
        tf.scalar(0, 'int32'),
        tf.tensor1d([3, 0, 2], 'int32'),
        tf.scalar(1, 'int32'),
      ],
      splits: [0, 3, 3, 5],
      values: [0, 1, 2, 0, 1],
    },
    {
      name: 'BroadcastStartsAndLimits',
      // Expected: [[0, 3, 6, 9], [0, 4, 8], [0, 5, 10]]
      inputs: () => [
        tf.scalar(0, 'int32'),
        tf.scalar(12, 'int32'),
        tf.tensor1d([3, 4, 5], 'int32'),
      ],
      splits: [0, 4, 7, 10],
      values: [0, 3, 6, 9, 0, 4, 8, 0, 5, 10],
    },
    {
      name: 'AllScalarInputs',
      // Expected: [[0, 1, 2, 3, 4]]
      inputs: () => [
        tf.scalar(0, 'int32'),
        tf.scalar(5, 'int32'),
        tf.scalar(1, 'int32'),
      ],
      splits: [0, 5],
      values: [0, 1, 2, 3, 4],
    },
    {
      name: 'InvalidArgsStarts',
      inputs: () => [
        tf.tensor2d([0, 5, 8, 5], [4, 1], 'int32'),
        tf.tensor1d([8, 7, 8, 1], 'int32'),
        tf.tensor1d([2, 1, 1, -1], 'int32'),
      ],
      error: 'starts must be a scalar or vector',
    },
    {
      name: 'InvalidArgsLimits',
      inputs: () => [
        tf.tensor1d([0, 5, 8, 5], 'int32'),
        tf.tensor2d([8, 7, 8, 1], [4, 1], 'int32'),
        tf.tensor1d([2, 1, 1, -1], 'int32'),
      ],
      error: 'limits must be a scalar or vector',
    },
    {
      name: 'InvalidArgsDeltas',
      inputs: () => [
        tf.tensor1d([0, 5, 8, 5], 'int32'),
        tf.tensor1d([8, 7, 8, 1], 'int32'),
        tf.tensor2d([2, 1, 1, -1], [4, 1], 'int32'),
      ],
      error: 'deltas must be a scalar or vector',
    },
    {
      name: 'InvalidArgsShapeMismatch',
      inputs: () => [
        tf.tensor1d([0, 5, 8, 5], 'int32'),
        tf.tensor1d([7, 8, 1], 'int32'),
        tf.tensor1d([2, 1, 1, -1], 'int32'),
      ],
      error: 'starts, limits, and deltas must have the same shape',
    },
    {
      name: 'InvalidArgsZeroDelta',
      inputs: () => [
        tf.tensor1d([0, 5, 8, 5], 'int32'),
        tf.tensor1d([7, 8, 8, 1], 'int32'),
        tf.tensor1d([2, 1, 0, -1], 'int32'),
      ],
      error: 'Requires delta != 0',
    },
    {
      name: 'EmptyRangePositiveDelta',
      // Expected: [[0, 2, 4], []]
      inputs: () => [
        tf.tensor1d([0, 5], 'int32'),
        tf.tensor1d([5, 0], 'int32'),
        tf.scalar(2, 'int32'),
      ],
      splits: [0, 3, 3],
      values: [0, 2, 4],
    },
    {
      name: 'EmptyRangeNegativeDelta',
      // Expected: [[], [5, 3, 1]]
      inputs: () => [
        tf.tensor1d([0, 5], 'int32'),
        tf.tensor1d([5, 0], 'int32'),
        tf.scalar(-2, 'int32'),
      ],
      splits: [0, 0, 3],
      values: [5, 3, 1],
    },
  ];

  for (const {name, inputs, splits, values, error} of cases) {
    // Cases without an `error` check outputs; the rest expect a rejection.
    const spec = error === undefined ?
        async () => {
          const result = await runRaggedGather(...inputs());
          expectArraysClose(result.rtNestedSplits, splits);
          expectArraysClose(result.rtDenseValues, values);
        } :
        async () => {
          await expectAsync(runRaggedGather(...inputs()))
              .toBeRejectedWithError(error);
        };
    it(name, spec);
  }

  it('does not have memory leak.', async () => {
    const numDataIds = () => tf.engine().backend.numDataIds();
    const baseline = numDataIds();

    const inputs = [
      tf.tensor1d([0, 5, 8, 5], 'int32'),
      tf.tensor1d([8, 7, 8, 1], 'int32'),
      tf.tensor1d([2, 1, 1, -1], 'int32'),
    ];
    const {tensors} = await runRaggedGather(...inputs);

    // The 3 inputs plus the 2 op outputs (row splits and dense values)
    // should account for exactly 5 new data ids.
    expect(numDataIds()).toEqual(baseline + 5);

    tf.dispose(inputs);
    tf.dispose(tensors);

    expect(numDataIds()).toEqual(baseline);
  });
});
| //# sourceMappingURL=data:application/json;base64,{"version":3,"file":"ragged_range_test.js","sourceRoot":"","sources":["../../../../../../tfjs-core/src/ops/ragged_range_test.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;;;;;;GAeG;AAEH,OAAO,KAAK,EAAE,MAAM,UAAU,CAAC;AAC/B,OAAO,EAAC,QAAQ,EAAE,iBAAiB,EAAC,MAAM,iBAAiB,CAAC;AAC5D,OAAO,EAAC,iBAAiB,EAAC,MAAM,cAAc,CAAC;AAE/C,KAAK,UAAU,eAAe,CAC1B,MAAiB,EAAE,MAAiB,EAAE,MAAiB;IACzD,MAAM,MAAM,GAAG,EAAE,CAAC,WAAW,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,CAAC,CAAC;IAEtD,MAAM,CAAC,MAAM,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,OAAO,CAAC,CAAC;IACrD,MAAM,CAAC,MAAM,CAAC,cAAc,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAEtD,MAAM,CAAC,MAAM,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;IACzD,MAAM,CAAC,MAAM,CAAC,aAAa,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAErD,OAAO;QACL,cAAc,EAAE,MAAM,MAAM,CAAC,cAAc,CAAC,IAAI,EAAE;QAClD,aAAa,EAAE,MAAM,MAAM,CAAC,aAAa,CAAC,IAAI,EAAE;QAChD,OAAO,EAAE,MAAM,CAAC,MAAM,CAAC,MAAM,CAAC;KAC/B,CAAC;AACJ,CAAC;AAED,iBAAiB,CAAC,cAAc,EAAE,QAAQ,EAAE,GAAG,EAAE;IAC/C,EAAE,CAAC,WAAW,EAAE,KAAK,IAAI,EAAE;QACzB,MAAM,MAAM,GAAG,MAAM,eAAe,CAChC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EACtE,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;QAEzC,qDAAqD;QACrD,iBAAiB,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3D,iBAAiB,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC1E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,aAAa,EAAE,KAAK,IAAI,EAAE;QAC3B,MAAM,MAAM,GAAG,MAAM,eAAe,CAChC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,EACpC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,SAAS,CAAC,EACpC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC
,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;QAE3C,qDAAqD;QACrD,iBAAiB,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAC3D,iBAAiB,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC1E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,WAAW,CAAC,eAAe,CACX,EAAE,CAAC,QAAQ,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,SAAS,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,EAAE,EAAE,IAAI,CAAC,EAAE,SAAS,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,KAAK,CAAC,EAAE,SAAS,CAAC,CAAC,CAAC;aACrD,qBAAqB,CAClB,kDAAkD,CAAC,CAAC;IAC9D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iBAAiB,EAAE,KAAK,IAAI,EAAE;QAC/B,MAAM,MAAM,GAAG,MAAM,eAAe,CAChC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAChE,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;QAE3B,mDAAmD;QACnD,iBAAiB,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QACzD,iBAAiB,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC1E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,0BAA0B,EAAE,KAAK,IAAI,EAAE;QACxC,MAAM,MAAM,GAAG,MAAM,eAAe,CAChC,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EACtD,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;QAE3B,oCAAoC;QACpC,iBAAiB,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACvD,iBAAiB,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,0BAA0B,EAAE,KAAK,IAAI,EAAE;QACxC,MAAM,MAAM,GAAG,MAAM,eAAe,CAChC,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,EAAE,EAAE,OAAO,CAAC,EAC7C,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;QAErC,kDAAkD;QAClD,iBAAiB,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,C
AAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QACxD,iBAAiB,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;IAC3E,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,iBAAiB,EAAE,KAAK,IAAI,EAAE;QAC/B,MAAM,MAAM,GAAG,MAAM,eAAe,CAChC,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;QAEzE,8BAA8B;QAC9B,iBAAiB,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACjD,iBAAiB,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,WAAW,CAAC,eAAe,CACX,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAC1C,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;aACtD,qBAAqB,CAAC,mCAAmC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,WAAW,CAAC,eAAe,CACX,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAC1C,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;aACtD,qBAAqB,CAAC,mCAAmC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,mBAAmB,EAAE,KAAK,IAAI,EAAE;QACjC,MAAM,WAAW,CAAC,eAAe,CACX,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;aAC9D,qBAAqB,CAAC,mCAAmC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,0BAA0B,EAAE,KAAK,IAAI,EAAE;QACxC,MAAM,WAAW,CAAC,eAAe,CA
CX,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAC/B,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;aACtD,qBAAqB,CAClB,qDAAqD,CAAC,CAAC;IACjE,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,WAAW,CAAC,eAAe,CACX,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAClC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;aACtD,qBAAqB,CAAC,qBAAqB,CAAC,CAAC;IACpD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,yBAAyB,EAAE,KAAK,IAAI,EAAE;QACvC,MAAM,MAAM,GAAG,MAAM,eAAe,CAChC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAC1D,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;QAE3B,4BAA4B;QAC5B,iBAAiB,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,iBAAiB,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACrD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,yBAAyB,EAAE,KAAK,IAAI,EAAE;QACvC,MAAM,MAAM,GAAG,MAAM,eAAe,CAChC,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,EAC1D,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC,CAAC;QAE5B,4BAA4B;QAC5B,iBAAiB,CAAC,MAAM,CAAC,cAAc,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACpD,iBAAiB,CAAC,MAAM,CAAC,aAAa,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACrD,CAAC,CAAC,CAAC;IAEH,EAAE,CAAC,4BAA4B,EAAE,KAAK,IAAI,EAAE;QAC1C,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,OAAO,CAAC,UAAU,EAAE,CAAC;QAEvD,MAAM,MAAM,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAClD,MAAM,MAAM,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QAClD,MAAM,MAAM,GAAG,EAAE,CAAC,QAAQ,C
AAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,OAAO,CAAC,CAAC;QACnD,MAAM,MAAM,GAAG,MAAM,eAAe,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,CAAC,CAAC;QAE7D,MAAM,eAAe,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,OAAO,CAAC,UAAU,EAAE,CAAC;QACzD,MAAM,CAAC,eAAe,CAAC,CAAC,OAAO,CAAC,aAAa,GAAG,CAAC,CAAC,CAAC;QAEnD,EAAE,CAAC,OAAO,CAAC,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,CAAC,CAAC,CAAC;QACrC,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC;QAE3B,MAAM,mBAAmB,GAAG,EAAE,CAAC,MAAM,EAAE,CAAC,OAAO,CAAC,UAAU,EAAE,CAAC;QAC7D,MAAM,CAAC,mBAAmB,CAAC,CAAC,OAAO,CAAC,aAAa,CAAC,CAAC;IACrD,CAAC,CAAC,CAAC;AACL,CAAC,CAAC,CAAC","sourcesContent":["/**\n * @license\n * Copyright 2022 Google LLC.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\nimport * as tf from '../index';\nimport {ALL_ENVS, describeWithFlags} from '../jasmine_util';\nimport {expectArraysClose} from '../test_util';\n\nasync function runRaggedGather(\n    starts: tf.Tensor, limits: tf.Tensor, deltas: tf.Tensor) {\n  const output = tf.raggedRange(starts, limits, deltas);\n\n  expect(output.rtNestedSplits.dtype).toEqual('int32');\n  expect(output.rtNestedSplits.shape.length).toEqual(1);\n\n  expect(output.rtDenseValues.dtype).toEqual(starts.dtype);\n  expect(output.rtDenseValues.shape.length).toEqual(1);\n\n  return {\n    rtNestedSplits: await output.rtNestedSplits.data(),\n    rtDenseValues: await output.rtDenseValues.data(),\n    tensors: 
Object.values(output)\n  };\n}\n\ndescribeWithFlags('raggedRange ', ALL_ENVS, () => {\n  it('IntValues', async () => {\n    const result = await runRaggedGather(\n        tf.tensor1d([0, 5, 8, 5], 'int32'), tf.tensor1d([8, 7, 8, 1], 'int32'),\n        tf.tensor1d([2, 1, 1, -1], 'int32'));\n\n    // Expected: [[0, 2, 4, 6], [5, 6], [], [5, 4, 3, 2]]\n    expectArraysClose(result.rtNestedSplits, [0, 4, 6, 6, 10]);\n    expectArraysClose(result.rtDenseValues, [0, 2, 4, 6, 5, 6, 5, 4, 3, 2]);\n  });\n\n  it('FloatValues', async () => {\n    const result = await runRaggedGather(\n        tf.tensor1d([0, 5, 8, 5], 'float32'),\n        tf.tensor1d([8, 7, 8, 1], 'float32'),\n        tf.tensor1d([2, 1, 1, -1], 'float32'));\n\n    // Expected: [[0, 2, 4, 6], [5, 6], [], [5, 4, 3, 2]]\n    expectArraysClose(result.rtNestedSplits, [0, 4, 6, 6, 10]);\n    expectArraysClose(result.rtDenseValues, [0, 2, 4, 6, 5, 6, 5, 4, 3, 2]);\n  });\n\n  it('RangeSizeOverflow', async () => {\n    await expectAsync(runRaggedGather(\n                          tf.tensor1d([1.1, 0.1], 'float32'),\n                          tf.tensor1d([10, 1e10], 'float32'),\n                          tf.tensor1d([1, 1e-10], 'float32')))\n        .toBeRejectedWithError(\n            'Requires ((limit - start) / delta) <= 2147483647');\n  });\n\n  it('BroadcastDeltas', async () => {\n    const result = await runRaggedGather(\n        tf.tensor1d([0, 5, 8], 'int32'), tf.tensor1d([8, 7, 8], 'int32'),\n        tf.scalar(1, 'int32'));\n\n    // Expected: [[0, 1, 2, 3, 4, 5, 6, 7], [5, 6], []]\n    expectArraysClose(result.rtNestedSplits, [0, 8, 10, 10]);\n    expectArraysClose(result.rtDenseValues, [0, 1, 2, 3, 4, 5, 6, 7, 5, 6]);\n  });\n\n  it('BroadcastLimitsAndDeltas', async () => {\n    const result = await runRaggedGather(\n        tf.scalar(0, 'int32'), tf.tensor1d([3, 0, 2], 'int32'),\n        tf.scalar(1, 'int32'));\n\n    // Expected: [[0, 1, 2], [], [0, 1]]\n    expectArraysClose(result.rtNestedSplits, [0, 
3, 3, 5]);\n    expectArraysClose(result.rtDenseValues, [0, 1, 2, 0, 1]);\n  });\n\n  it('BroadcastStartsAndLimits', async () => {\n    const result = await runRaggedGather(\n        tf.scalar(0, 'int32'), tf.scalar(12, 'int32'),\n        tf.tensor1d([3, 4, 5], 'int32'));\n\n    // Expected: [[0, 3, 6, 9], [0, 4, 8], [0, 5, 10]]\n    expectArraysClose(result.rtNestedSplits, [0, 4, 7, 10]);\n    expectArraysClose(result.rtDenseValues, [0, 3, 6, 9, 0, 4, 8, 0, 5, 10]);\n  });\n\n  it('AllScalarInputs', async () => {\n    const result = await runRaggedGather(\n        tf.scalar(0, 'int32'), tf.scalar(5, 'int32'), tf.scalar(1, 'int32'));\n\n    // Expected: [[0, 1, 2, 3, 4]]\n    expectArraysClose(result.rtNestedSplits, [0, 5]);\n    expectArraysClose(result.rtDenseValues, [0, 1, 2, 3, 4]);\n  });\n\n  it('InvalidArgsStarts', async () => {\n    await expectAsync(runRaggedGather(\n                          tf.tensor2d([0, 5, 8, 5], [4, 1], 'int32'),\n                          tf.tensor1d([8, 7, 8, 1], 'int32'),\n                          tf.tensor1d([2, 1, 1, -1], 'int32')))\n        .toBeRejectedWithError('starts must be a scalar or vector');\n  });\n\n  it('InvalidArgsLimits', async () => {\n    await expectAsync(runRaggedGather(\n                          tf.tensor1d([0, 5, 8, 5], 'int32'),\n                          tf.tensor2d([8, 7, 8, 1], [4, 1], 'int32'),\n                          tf.tensor1d([2, 1, 1, -1], 'int32')))\n        .toBeRejectedWithError('limits must be a scalar or vector');\n  });\n\n  it('InvalidArgsDeltas', async () => {\n    await expectAsync(runRaggedGather(\n                          tf.tensor1d([0, 5, 8, 5], 'int32'),\n                          tf.tensor1d([8, 7, 8, 1], 'int32'),\n                          tf.tensor2d([2, 1, 1, -1], [4, 1], 'int32')))\n        .toBeRejectedWithError('deltas must be a scalar or vector');\n  });\n\n  it('InvalidArgsShapeMismatch', async () => {\n    await expectAsync(runRaggedGather(\n                          
tf.tensor1d([0, 5, 8, 5], 'int32'),\n                          tf.tensor1d([7, 8, 1], 'int32'),\n                          tf.tensor1d([2, 1, 1, -1], 'int32')))\n        .toBeRejectedWithError(\n            'starts, limits, and deltas must have the same shape');\n  });\n\n  it('InvalidArgsZeroDelta', async () => {\n    await expectAsync(runRaggedGather(\n                          tf.tensor1d([0, 5, 8, 5], 'int32'),\n                          tf.tensor1d([7, 8, 8, 1], 'int32'),\n                          tf.tensor1d([2, 1, 0, -1], 'int32')))\n        .toBeRejectedWithError('Requires delta != 0');\n  });\n\n  it('EmptyRangePositiveDelta', async () => {\n    const result = await runRaggedGather(\n        tf.tensor1d([0, 5], 'int32'), tf.tensor1d([5, 0], 'int32'),\n        tf.scalar(2, 'int32'));\n\n    // Expected: [[0, 2, 4], []]\n    expectArraysClose(result.rtNestedSplits, [0, 3, 3]);\n    expectArraysClose(result.rtDenseValues, [0, 2, 4]);\n  });\n\n  it('EmptyRangeNegativeDelta', async () => {\n    const result = await runRaggedGather(\n        tf.tensor1d([0, 5], 'int32'), tf.tensor1d([5, 0], 'int32'),\n        tf.scalar(-2, 'int32'));\n\n    // Expected: [[], [5, 3, 1]]\n    expectArraysClose(result.rtNestedSplits, [0, 0, 3]);\n    expectArraysClose(result.rtDenseValues, [5, 3, 1]);\n  });\n\n  it('does not have memory leak.', async () => {\n    const beforeDataIds = tf.engine().backend.numDataIds();\n\n    const starts = tf.tensor1d([0, 5, 8, 5], 'int32');\n    const limits = tf.tensor1d([8, 7, 8, 1], 'int32');\n    const deltas = tf.tensor1d([2, 1, 1, -1], 'int32');\n    const result = await runRaggedGather(starts, limits, deltas);\n\n    const afterResDataIds = tf.engine().backend.numDataIds();\n    expect(afterResDataIds).toEqual(beforeDataIds + 5);\n\n    tf.dispose([starts, limits, deltas]);\n    tf.dispose(result.tensors);\n\n    const afterDisposeDataIds = tf.engine().backend.numDataIds();\n    expect(afterDisposeDataIds).toEqual(beforeDataIds);\n  
});\n});\n"]}
|
|