gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
 
import {ENGINE} from '../engine';
import {Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer} from '../tensor';
import {convertToTensor, convertToTensorArray} from '../tensor_util_env';
import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike, TensorLike4D} from '../types';
import * as util from '../util';
import {getAxesPermutation, getInnerMostAxes} from './axis_util';
import {concat} from './concat_split';
import {op} from './operation';
import {MPRandGauss, RandGamma, UniformRandom} from './rand';
import {zeros, zerosLike} from './tensor_ops';
 
/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until is has the same length as
 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile).
 *
 * @param input The tensor that is to be broadcasted.
 * @param shape The input is to be broadcast to this shape.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function broadcastTo_<R extends Rank>(
    x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor<R> {
  let input = convertToTensor(x, 'broadcastTo', 'x');
  const xShape = input.shape;
 
  if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
    throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
  }
 
  if (shape.length < input.rank) {
    throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${
        input.rank}.`);
  }
 
  if (shape.length > input.rank) {
    const newShape = input.shape.slice();
    while (newShape.length < shape.length) {
      newShape.unshift(1);
    }
    input = input.reshape(newShape);
  }
 
  const reps: number[] = Array.from(shape);
  for (let i = shape.length - 1; i >= 0; i--) {
    if (input.shape[i] === shape[i]) {
      reps[i] = 1;
    } else if (input.shape[i] !== 1) {
      throw new Error(
          `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
    }
  }
  const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);
 
  if (axes.length === 0) {
    return input.clone() as Tensor<R>;
  }
 
  return ENGINE.runKernelFunc(
             backend => backend.tile(input, reps), {input},
             (dy: Tensor) =>
                 ({input: () => dy.sum(axes, /*keepDims=*/true)})) as Tensor<R>;
}
 
/**
 * Creates a new tensor with the same values and shape as the specified
 * tensor.
 *
 * ```js
 * const x = tf.tensor([1, 2]);
 *
 * x.clone().print();
 * ```
 *
 * @param x The tensor to clone.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function clone_<T extends Tensor>(x: T|TensorLike): T {
  const $x = convertToTensor(x, 'x', 'clone', null);
  const der = (dy: T) => {
    return {$x: () => dy.toFloat()};
  };
  return ENGINE.runKernelFunc(
      () => ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype) as T,
      {$x}, der);
}
 
/**
 * Create an identity matrix.
 *
 * @param numRows Number of rows.
 * @param numColumns Number of columns. Defaults to `numRows`.
 * @param batchShape If provided, will add the batch shape to the beginning
 *   of the shape of the returned `tf.Tensor` by repeating the identity
 *   matrix.
 * @param dtype Data type.
 * @returns Identity matrix of the specified size and data type, possibly
 *   with batch repetition if `batchShape` is specified.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function eye_(
    numRows: number, numColumns?: number,
    batchShape?:
        [
          number
        ]|[number,
           number]|[number, number, number]|[number, number, number, number],
    dtype: DataType = 'float32'): Tensor2D {
  if (numColumns == null) {
    numColumns = numRows;
  }
  const buff = buffer([numRows, numColumns], dtype);
  const n = numRows <= numColumns ? numRows : numColumns;
  for (let i = 0; i < n; ++i) {
    buff.set(1, i, i);
  }
  const out = buff.toTensor().as2D(numRows, numColumns);
  if (batchShape == null) {
    return out;
  } else {
    if (batchShape.length === 1) {
      return tile(expandDims(out, 0), [batchShape[0], 1, 1]);
    } else if (batchShape.length === 2) {
      return tile(
          expandDims(expandDims(out, 0), 0),
          [batchShape[0], batchShape[1], 1, 1]);
    } else if (batchShape.length === 3) {
      return tile(
          expandDims(expandDims(expandDims(out, 0), 0), 0),
          [batchShape[0], batchShape[1], batchShape[2], 1, 1]);
    } else {
      throw new Error(
          `eye() currently supports only 1D and 2D ` +
          // tslint:disable-next-line:no-any
          `batchShapes, but received ${(batchShape as any).length}D.`);
    }
  }
}
 
/**
 * Creates a `tf.Tensor` with values sampled from a normal distribution.
 *
 * ```js
 * tf.randomNormal([2, 2]).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param mean The mean of the normal distribution.
 * @param stdDev The standard deviation of the normal distribution.
 * @param dtype The data type of the output.
 * @param seed The seed for the random number generator.
 */
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomNormal_<R extends Rank>(
    shape: ShapeMap[R], mean = 0, stdDev = 1, dtype?: 'float32'|'int32',
    seed?: number): Tensor<R> {
  if (dtype != null && (dtype as DataType) === 'bool') {
    throw new Error(`Unsupported data type ${dtype}`);
  }
  const randGauss =
      new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);
  const res = buffer(shape, dtype);
  for (let i = 0; i < res.values.length; i++) {
    res.values[i] = randGauss.nextValue();
  }
  return res.toTensor();
}
 
/**
 * Creates a `tf.Tensor` with values sampled from a truncated normal
 * distribution.
 *
 * ```js
 * tf.truncatedNormal([2, 2]).print();
 * ```
 *
 * The generated values follow a normal distribution with specified mean and
 * standard deviation, except that values whose magnitude is more than 2
 * standard deviations from the mean are dropped and re-picked.
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param mean The mean of the normal distribution.
 * @param stdDev The standard deviation of the normal distribution.
 * @param dtype The data type of the output tensor.
 * @param seed The seed for the random number generator.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function truncatedNormal_<R extends Rank>(
    shape: ShapeMap[R], mean = 0, stdDev = 1, dtype?: 'float32'|'int32',
    seed?: number): Tensor<R> {
  if (dtype != null && (dtype as DataType) === 'bool') {
    throw new Error(`Unsupported data type ${dtype}`);
  }
  const randGauss =
      new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);
  const res = buffer(shape, dtype);
  for (let i = 0; i < res.values.length; i++) {
    res.values[i] = randGauss.nextValue();
  }
  return res.toTensor();
}
 
/**
 * Creates a `tf.Tensor` with values sampled from a gamma distribution.
 *
 * ```js
 * tf.randomGamma([2, 2], 1).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param alpha The shape parameter of the gamma distribution.
 * @param beta The inverse scale parameter of the gamma distribution. Defaults
 *     to 1.
 * @param dtype The data type of the output. Defaults to float32.
 * @param seed The seed for the random number generator.
 */
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomGamma_<R extends Rank>(
    shape: ShapeMap[R], alpha: number, beta = 1,
    dtype: 'float32'|'int32' = 'float32', seed?: number): Tensor<R> {
  if (beta == null) {
    beta = 1;
  }
  if (dtype == null) {
    dtype = 'float32';
  }
  if (dtype !== 'float32' && dtype !== 'int32') {
    throw new Error(`Unsupported data type ${dtype}`);
  }
  const rgamma = new RandGamma(alpha, beta, dtype, seed);
  const res = buffer(shape, dtype);
  for (let i = 0; i < res.values.length; i++) {
    res.values[i] = rgamma.nextValue();
  }
  return res.toTensor();
}
 
/**
 * Creates a `tf.Tensor` with values sampled from a uniform distribution.
 *
 * The generated values follow a uniform distribution in the range [minval,
 * maxval). The lower bound minval is included in the range, while the upper
 * bound maxval is excluded.
 *
 * ```js
 * tf.randomUniform([2, 2]).print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param minval The lower bound on the range of random values to generate.
 *   Defaults to 0.
 * @param maxval The upper bound on the range of random values to generate.
 *   Defaults to 1.
 * @param dtype The data type of the output tensor. Defaults to 'float32'.
 */
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomUniform_<R extends Rank>(
    shape: ShapeMap[R], minval = 0, maxval = 1, dtype: DataType = 'float32',
    seed?: number|string): Tensor<R> {
  const res = buffer(shape, dtype);
  const random = new UniformRandom(minval, maxval, null, seed);
  for (let i = 0; i < res.values.length; i++) {
    res.values[i] = random.nextValue();
  }
  return res.toTensor();
}
 
/**
 * Creates a `tf.Tensor` with values sampled from a random number generator
 * function defined by the user.
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param randFunction A random number generator function which is called
 * for each element in the output tensor.
 * @param dtype The data type of the output tensor. Defaults to 'float32'.
 */
function rand_<R extends Rank>(
    shape: ShapeMap[R], randFunction: () => number,
    dtype?: DataType): Tensor<R> {
  const size = util.sizeFromShape(shape);
 
  let values = null;
  if (dtype == null || dtype === 'float32') {
    values = new Float32Array(size);
  } else if (dtype === 'int32') {
    values = new Int32Array(size);
  } else if (dtype === 'bool') {
    values = new Uint8Array(size);
  } else {
    throw new Error(`Unknown data type ${dtype}`);
  }
 
  for (let i = 0; i < size; i++) {
    values[i] = randFunction();
  }
  return ENGINE.makeTensor(values, shape, dtype) as Tensor<R>;
}
 
/**
 * Creates a `tf.Tensor` with values drawn from a multinomial distribution.
 *
 * ```js
 * const probs = tf.tensor([.75, .25]);
 * tf.multinomial(probs, 3).print();
 * ```
 *
 * @param logits 1D array with unnormalized log-probabilities, or
 *     2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
 *     parameter.
 * @param numSamples Number of samples to draw for each row slice.
 * @param seed The seed number.
 * @param normalized Whether the provided `logits` are normalized true
 *     probabilities (sum to 1). Defaults to false.
 * @return 1D array of shape `[numSamples]`, or 2D array of shape
 *     `[batchSize, numSamples]`, depending on the rank of the input.
 */
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function multinomial_(
    logits: Tensor1D|Tensor2D|TensorLike, numSamples: number, seed?: number,
    normalized = false): Tensor1D|Tensor2D {
  const $logits = convertToTensor(logits, 'logits', 'multinomial');
  const numOutcomes = $logits.size;
  const origRank = $logits.rank;
  if (numOutcomes < 2) {
    throw new Error(
        `Error in multinomial: you need at least 2 outcomes, but got ` +
        `${numOutcomes}.`);
  }
  if (origRank > 2) {
    throw new Error(`Rank of probabilities must be 1 or 2, but is ${origRank}`);
  }
  seed = seed || Math.random();
  const logits2D = origRank === 1 ? $logits.as2D(1, -1) : $logits as Tensor2D;
  const res = ENGINE.runKernelFunc(
      backend => backend.multinomial(logits2D, normalized, numSamples, seed),
      {logits2D});
 
  return origRank === 1 ? res.as1D() : res;
}
 
/**
 * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
 * value `onValue` (defaults to 1), while all other locations take value
 * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank
 * `R+1` with the last axis of size `depth`.
 *
 * ```js
 * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
 * ```
 *
 * @param indices `tf.Tensor` of indices with dtype `int32`.
 * @param depth The depth of the one hot dimension.
 * @param onValue A number used to fill in the output when the index matches
 * the location.
 * @param offValue A number used to fill in the output when the index does
 *     not match the location.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function oneHot_(
    indices: Tensor|TensorLike, depth: number, onValue = 1,
    offValue = 0): Tensor {
  if (depth < 2) {
    throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
  }
  let $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
  const outShape = [...$indices.shape, depth];
  $indices = $indices.flatten();
 
  const grad = (dy: Tensor2D) => {
    return {$indices: () => zeros($indices.shape, 'float32')};
  };
  const result = ENGINE.runKernelFunc(
      backend => backend.oneHot($indices as Tensor1D, depth, onValue, offValue),
      {$indices}, grad);
  return result.reshape(outShape);
}
 
/**
 * Reshapes a `tf.Tensor` to a given shape.
 *
 * Given an input tensor, returns a new tensor with the same values as the
 * input tensor with shape `shape`.
 *
 * If one component of shape is the special value -1, the size of that
 * dimension is computed so that the total size remains constant. In
 * particular, a shape of [-1] flattens into 1-D. At most one component of
 * shape can be -1.
 *
 * If shape is 1-D or higher, then the operation returns a tensor with shape
 * shape filled with the values of tensor. In this case, the number of
 * elements implied by shape must be the same as the number of elements in
 * tensor.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 * x.reshape([2, 2]).print();
 * ```
 *
 * @param x The input tensor to be reshaped.
 * @param shape An array of integers defining the output tensor shape.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function reshape_<R2 extends Rank>(
    x: Tensor|TensorLike, shape: ShapeMap[R2]): Tensor<R2> {
  const $x = convertToTensor(x, 'x', 'reshape', null);
  shape = util.inferFromImplicitShape(shape, $x.size) as ShapeMap[R2];
  util.assert(
      $x.size === util.sizeFromShape(shape),
      () => 'new shape and old shape must have the same number of elements.');
 
  const grad = (dy: Tensor<R2>) => {
    return {x: () => dy.reshape($x.shape)};
  };
  const attrs = {shape};
  return ENGINE.runKernelFunc(
      backend => backend.reshape($x, shape), {x: $x}, grad, 'Reshape', attrs);
}
 
/**
 * Removes dimensions of size 1 from the shape of a `tf.Tensor`.
 *
 * ```js
 * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);
 * x.squeeze().print();
 * ```
 *
 * @param x The input tensor to be squeezed.
 * @param axis An optional list of numbers. If specified, only
 *     squeezes the dimensions listed. The dimension index starts at 0. It
 * is an error to squeeze a dimension that is not 1.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function squeeze_<T extends Tensor>(x: Tensor|TensorLike, axis?: number[]): T {
  const $x = convertToTensor(x, 'x', 'squeeze');
  return reshape($x, util.squeezeShape($x.shape, axis).newShape) as T;
}
 
/**
 * Casts a `tf.Tensor` to a new dtype.
 *
 * ```js
 * const x = tf.tensor1d([1.5, 2.5, 3]);
 * tf.cast(x, 'int32').print();
 * ```
 * @param x The input tensor to be casted.
 * @param dtype The dtype to cast the input tensor to.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function cast_<T extends Tensor>(x: T|TensorLike, dtype: DataType): T {
  const $x = convertToTensor(x, 'x', 'cast');
 
  // Sanity checks.
  if (!util.isValidDtype(dtype)) {
    throw new Error(`Failed to cast to unknown dtype ${dtype}`);
  }
  if (dtype === 'string' && $x.dtype !== 'string' ||
      dtype !== 'string' && $x.dtype === 'string') {
    throw new Error('Only strings can be casted to strings');
  }
 
  const grad = (dy: T) => {
    return {x: () => dy.clone()};
  };
  const attrs = {dtype};
  return ENGINE.runKernelFunc(
      backend => backend.cast($x, dtype), {x: $x}, grad, 'Cast', attrs);
}
 
/**
 * Construct a tensor by repeating it the number of times given by reps.
 *
 * This operation creates a new tensor by replicating `input` `reps`
 * times. The output tensor's i'th dimension has `input.shape[i] *
 * reps[i]` elements, and the values of `input` are replicated
 * `reps[i]` times along the i'th dimension. For example, tiling
 * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 *
 * a.tile([2]).print();    // or a.tile([2])
 * ```
 *
 * ```js
 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * a.tile([1, 2]).print();  // or a.tile([1, 2])
 * ```
 * @param x The tensor to tile.
 * @param reps Determines the number of replications per dimension.
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function tile_<T extends Tensor>(x: T|TensorLike, reps: number[]): T {
  const parseAs: DataType = null;
  const $x = convertToTensor(x, 'x', 'tile', parseAs);
 
  util.assert(
      $x.rank === reps.length,
      () => `Error in transpose: rank of input ${$x.rank} ` +
          `must match length of reps ${reps}.`);
  const grad = (dy: T, saved: Tensor[]) => {
    const [$x] = saved;
    const derX = () => {
      let xGrad = zerosLike($x);
      // TODO(cais): Maybe reduce memory footprint by avoiding repeated
      // slicing.
      if ($x.rank === 1) {
        for (let i = 0; i < reps[0]; ++i) {
          xGrad = xGrad.add(dy.slice([i * $x.shape[0]], [$x.shape[0]]));
        }
      } else if ($x.rank === 2) {
        for (let i = 0; i < reps[0]; ++i) {
          for (let j = 0; j < reps[1]; ++j) {
            xGrad = xGrad.add(dy.slice(
                [i * $x.shape[0], j * $x.shape[1]],
                [$x.shape[0], $x.shape[1]]));
          }
        }
      } else if ($x.rank === 3) {
        for (let i = 0; i < reps[0]; ++i) {
          for (let j = 0; j < reps[1]; ++j) {
            for (let k = 0; k < reps[2]; ++k) {
              xGrad = xGrad.add(dy.slice(
                  [i * $x.shape[0], j * $x.shape[1], k * $x.shape[2]],
                  [$x.shape[0], $x.shape[1], $x.shape[2]]));
            }
          }
        }
      } else if ($x.rank === 4) {
        for (let i = 0; i < reps[0]; ++i) {
          for (let j = 0; j < reps[1]; ++j) {
            for (let k = 0; k < reps[2]; ++k) {
              for (let l = 0; l < reps[3]; ++l) {
                xGrad = xGrad.add(dy.slice(
                    [
                      i * $x.shape[0], j * $x.shape[1], k * $x.shape[2],
                      l * $x.shape[3]
                    ],
                    [$x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]));
              }
            }
          }
        }
      } else {
        throw new Error(
            `Gradient for tile operation is not implemented for rank-` +
            `${$x.rank} tensors yet.`);
      }
      return xGrad as T;
    };
    return {x: derX};
  };
  const inputsToSave = [$x];
  const attrs = {reps};
  return ENGINE.runKernelFunc((backend, save) => {
    const res = backend.tile($x, reps);
    save([$x]);
    return res;
  }, {x: $x}, grad, 'Tile', attrs, inputsToSave);
}
 
/**
 * Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details.
 */
function pad1d_(
    x: Tensor1D|TensorLike, paddings: [number, number],
    constantValue = 0): Tensor1D {
  util.assert(
      paddings.length === 2,
      () => 'Invalid number of paddings. Must be length of 2.');
  return pad(x, [paddings], constantValue);
}
 
/**
 * Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details.
 */
function pad2d_(
    x: Tensor2D|TensorLike, paddings: [[number, number], [number, number]],
    constantValue = 0): Tensor2D {
  util.assert(
      paddings.length === 2 && paddings[0].length === 2 &&
          paddings[1].length === 2,
      () => 'Invalid number of paddings. Must be length of 2 each.');
  return pad(x, paddings, constantValue);
}
 
/**
 * Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details.
 */
function pad3d_(
    x: Tensor3D|TensorLike,
    paddings: [[number, number], [number, number], [number, number]],
    constantValue = 0): Tensor3D {
  util.assert(
      paddings.length === 3 && paddings[0].length === 2 &&
          paddings[1].length === 2 && paddings[2].length === 2,
      () => 'Invalid number of paddings. Must be length of 2 each.');
  return pad(x, paddings, constantValue);
}
 
/**
 * Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details.
 */
function pad4d_(
    x: Tensor4D|TensorLike,
    paddings:
        [
          [number, number], [number, number], [number, number], [number, number]
        ],
    constantValue = 0): Tensor4D {
  util.assert(
      paddings.length === 4 && paddings[0].length === 2 &&
          paddings[1].length === 2 && paddings[2].length === 2 &&
          paddings[3].length === 2,
      () => 'Invalid number of paddings. Must be length of 2 each.');
  return pad(x, paddings, constantValue);
}
 
/**
 * Pads a `tf.Tensor` with a given value and paddings.
 *
 * This operation currently only implements the `CONSTANT` mode.
 *
 * Also available are stricter rank-specific methods with the same signature
 * as this method that assert that `paddings` is of given length.
 *   - `tf.pad1d`
 *   - `tf.pad2d`
 *   - `tf.pad3d`
 *   - `tf.pad4d`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 * x.pad([[1, 2]]).print();
 * ```
 * @param x The tensor to pad.
 * @param paddings An array of length `R` (the rank of the tensor), where
 * each element is a length-2 tuple of ints `[padBefore, padAfter]`,
 * specifying how much to pad along each dimension of the tensor.
 * @param constantValue The pad value to use. Defaults to 0.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function pad_<T extends Tensor>(
    x: T|TensorLike, paddings: Array<[number, number]>, constantValue = 0): T {
  const $x = convertToTensor(x, 'x', 'pad');
 
  if ($x.rank === 0) {
    throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
  }
 
  const grad = (dy: T) => {
    // Pad introduces values around the original tensor, so the gradient
    // slices the original shape out of the gradient.
    const begin = paddings.map(p => p[0]);
    return {x: () => dy.slice(begin, $x.shape)};
  };
  const attrs = {paddings, constantValue};
  return ENGINE.runKernelFunc(
      backend => backend.pad($x, paddings, constantValue), {x: $x}, grad,
      'PadV2', attrs);
}
 
/**
 * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * const c = tf.tensor1d([5, 6]);
 * tf.stack([a, b, c]).print();
 * ```
 *
 * @param tensors A list of tensor objects with the same shape and dtype.
 * @param axis The axis to stack along. Defaults to 0 (the first dim).
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function stack_<T extends Tensor>(
    tensors: Array<T|TensorLike>, axis = 0): Tensor {
  const $tensors = convertToTensorArray(tensors, 'tensors', 'stack');
 
  util.assert(
      $tensors.length >= 1, () => 'Pass at least one tensor to tf.stack');
  if ($tensors.length === 1) {
    return $tensors[0].expandDims(axis);
  }
  const rank = $tensors[0].rank;
  const shape = $tensors[0].shape;
  const dtype = $tensors[0].dtype;
 
  util.assert(axis <= rank, () => 'Axis must be <= rank of the tensor');
 
  $tensors.forEach(t => {
    util.assertShapesMatch(
        shape, t.shape,
        'All tensors passed to stack must have matching shapes');
  });
 
  $tensors.forEach(t => {
    util.assert(
        dtype === t.dtype,
        () => 'All tensors passed to stack must have matching dtypes');
  });
  const expandedTensors = $tensors.map(t => t.expandDims(axis));
  return concat(expandedTensors, axis);
}
 
/**
 * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
 * shape `blockShape + [batch]`, interleaves these blocks back into the grid
 * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
 * the same rank as the input. The spatial dimensions of this intermediate
 * result are then optionally cropped according to `crops` to produce the
 * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
 * description.
 *
 * ```js
 * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
 * const blockShape = [2, 2];
 * const crops = [[0, 0], [0, 0]];
 *
 * x.batchToSpaceND(blockShape, crops).print();
 * ```
 *
 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
 * remainingShape`, where spatialShape has `M` dimensions.
 * @param blockShape A 1-D array. Must have shape `[M]`, all values must
 * be >= 1.
 * @param crops A 2-D array.  Must have shape `[M, 2]`, all values must be >= 0.
 * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
 * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
 * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
 *
 * This operation is equivalent to the following steps:
 *
 * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
 * blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
 * x.shape[N-1]]`
 *
 * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch /
 * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M],
 * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
 *
 * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
 * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] *
 * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
 *
 * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
 * according to `crops` to produce the output of shape: `[batch /
 * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
 * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
 * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]`
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function batchToSpaceND_<T extends Tensor>(
    x: T|TensorLike, blockShape: number[], crops: number[][]): T {
  const $x = convertToTensor(x, 'x', 'batchToSpaceND');
  const prod = blockShape.reduce((a, b) => a * b);
 
  util.assert(
      $x.rank >= 1 + blockShape.length,
      () => `input rank is ${$x.rank} but should be > than blockShape.length ${
          blockShape.length}`);
 
  util.assert(
      crops.length === blockShape.length,
      () => `crops.length is ${
          crops.length} but should be equal to blockShape.length  ${
          blockShape.length}`);
 
  util.assert(
      $x.shape[0] % prod === 0,
      () => `input tensor batch is ${
                $x.shape[0]} but is not divisible by the product of ` +
          `the elements of blockShape ${blockShape.join(' * ')} === ${prod}`);
 
  const grad = (dy: T) => {
    return {$x: () => dy.spaceToBatchND(blockShape, crops)};
  };
 
  return ENGINE.runKernelFunc(
      backend => backend.batchToSpaceND($x, blockShape, crops), {$x}, grad);
}
 
/**
 * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
 * a grid of blocks of shape `blockShape`, and interleaves these blocks with
 * the "batch" dimension (0) such that in the output, the spatial
 * dimensions `[1, ..., M]` correspond to the position within the grid,
 * and the batch dimension combines both the position within a spatial block
 * and the original batch position. Prior to division into blocks,
 * the spatial dimensions of the input are optionally zero padded
 * according to `paddings`. See below for a precise description.
 *
 * ```js
 * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * const blockShape = [2, 2];
 * const paddings = [[0, 0], [0, 0]];
 *
 * x.spaceToBatchND(blockShape, paddings).print();
 * ```
 *
 * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
 * remainingShape`, where spatialShape has `M` dimensions.
 * @param blockShape A 1-D array. Must have shape `[M]`, all values must
 * be >= 1.
 * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
 *     0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
 * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
 * is required that
 * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
 *
 * This operation is equivalent to the following steps:
 *
 * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
 * according to `paddings` to produce `padded` of shape paddedShape.
 *
 * 2. Reshape `padded` to `reshapedPadded` of shape:
 * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
 * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
 *
 * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
 * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
 * paddedShape[M] / blockShape[M-1]] + remainingShape`
 *
 * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
 * batch dimension, producing an output tensor of shape:
 * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
 * paddedShape[M] / blockShape[M-1]] + remainingShape`
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function spaceToBatchND_<T extends Tensor>(
    x: T|TensorLike, blockShape: number[], paddings: number[][]): T {
  const $x = convertToTensor(x, 'x', 'spaceToBatchND');
 
  util.assert(
      $x.rank >= 1 + blockShape.length,
      () => `input rank ${$x.rank} should be > than [blockShape] ${
          blockShape.length}`);
 
  util.assert(
      paddings.length === blockShape.length,
      () => `paddings.shape[0] ${
          paddings.length} must be equal to [blockShape] ${blockShape.length}`);
 
  util.assert(
      $x.shape.reduce(
          (a, b, i) => {
            if (i > 0 && i <= blockShape.length) {
              return a &&
                  ((b + paddings[i - 1][0] + paddings[i - 1][1]) %
                       blockShape[i - 1] ===
                   0);
            }
            return a;
          },
          true),
      () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${
          paddings.toString()} must be divisible by blockShapes ${
          blockShape.toString()}`);
 
  const grad = (dy: T) => {
    return {$x: () => dy.batchToSpaceND(blockShape, paddings)};
  };
 
  return ENGINE.runKernelFunc(
      backend => backend.spaceToBatchND($x, blockShape, paddings), {$x}, grad);
}
 
/**
 * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
 *
 * ```js
 * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
 *
 * tf.unstack(a).forEach(tensor => tensor.print());
 * ```
 *
 * @param x A tensor object.
 * @param axis The axis to unstack along. Defaults to 0 (the first dim).
 */
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
function unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] {
  axis = axis || 0;
  const $x = convertToTensor(x, 'x', 'unstack');
  util.assert(
      axis >= -$x.shape.length && axis < $x.shape.length,
      () =>
          `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`);
  if (axis < 0) {
    axis += $x.shape.length;
  }
  const grad = (dy: Tensor[]) => {
    return {x: () => stack(dy, axis)};
  };
  const attrs = {axis};
  return ENGINE.runKernelFunc(
      backend => backend.unstack($x, axis), {x: $x}, grad, 'Unpack', attrs);
}
 
/**
 * Computes the cumulative sum of a `tf.Tensor` along `axis`.
 *
 * ```js
 * const x = tf.tensor([1, 2, 3, 4]);
 * x.cumsum().print();
 * ```
 * ```js
 * const x = tf.tensor([[1, 2], [3, 4]]);
 * x.cumsum().print();
 * ```
 *
 * @param x The input tensor to be summed.
 * @param axis The axis along which to sum. Optional. Defaults to 0.
 * @param exclusive Whether to perform exclusive cumulative sum. Optional.
 *     Defaults to false. If set to true then the sum of each tensor entry
 *     does not include its own value, but only the values previous to it
 *     along the specified axis.
 * @param reverse Whether to sum in the opposite direction. Optional.
 *     Defaults to false.
 */
/** @doc {heading: 'Operations', subheading: 'Scan'} */
function cumsum_<T extends Tensor>(
    x: Tensor|TensorLike, axis = 0, exclusive = false, reverse = false): T {
  const $x = convertToTensor(x, 'x', 'cumsum');
 
  axis = axis | 0;
  const permutation = getAxesPermutation([axis], $x.rank);
  let permutedX = $x;
  if (permutation != null) {
    permutedX = $x.transpose(permutation);
  }
  const permutedAxis = getInnerMostAxes(1, $x.rank)[0];
 
  const grad = (dy: T) => {
    return {permutedX: () => dy.cumsum(axis, exclusive, !reverse)};
  };
  let value = ENGINE.runKernelFunc(
                  backend => backend.cumsum(
                      permutedX, permutedAxis, exclusive, reverse),
                  {permutedX}, grad) as T;
 
  if (permutation != null) {
    value = value.transpose(permutation);
  }
  return value;
}
 
/**
 * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
 * into the tensor's shape.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 * const axis = 1;
 * x.expandDims(axis).print();
 * ```
 *
 * @param x The input tensor whose dimensions to be expanded.
 * @param axis The dimension index at which to insert shape of `1`. Defaults
 *     to 0 (the first dimension).
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function expandDims_<R2 extends Rank>(
    x: Tensor|TensorLike, axis = 0): Tensor<R2> {
  const parseAs: DataType = null;
  const $x = convertToTensor(x, 'x', 'expandDims', parseAs);
 
  util.assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor');
  const newShape = $x.shape.slice();
  if (axis < 0) {
    // Negative value is counted from the tail of rank.
    util.assert(
        -($x.rank + 1) <= axis,
        () => `Axis must be in the interval [${- ($x.rank + 1)}, ${$x.rank}]`);
    axis = $x.rank + axis + 1;
  }
  newShape.splice(axis, 0, 1);
  return reshape($x, newShape as ShapeMap[R2]);
}
 
/**
 * Rearranges data from depth into blocks of spatial data. More specifically,
 * this op outputs a copy of the input tensor where values from the `depth`
 * dimension are moved in spatial blocks to the `height` and `width` dimensions.
 * The attr `blockSize` indicates the input block size and how the data is
 * moved.
 *
 *  - Chunks of data of size `blockSize * blockSize` from depth are rearranged
 * into non-overlapping blocks of size `blockSize x blockSize`
 *
 *  - The width the output tensor is `inputWidth * blockSize`, whereas the
 * height is `inputHeight * blockSize`
 *
 *  - The Y, X coordinates within each block of the output image are determined
 * by the high order component of the input channel index
 *
 *  - The depth of the input tensor must be divisible by `blockSize *
 * blockSize`
 *
 * The `dataFormat` attr specifies the layout of the input and output tensors
 * with the following options: "NHWC": [ `batch, height, width, channels` ]
 * "NCHW": [ `batch, channels, height, width` ]
 *
 * ```js
 * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);
 * const blockSize = 2;
 * const dataFormat = "NHWC";
 *
 * tf.depthToSpace(x, blockSize, dataFormat).print();
 * ```
 *
 * @param x The input tensor of rank 4
 * @param blockSIze  An `int` that is `>= 2`. The size of the spatial block
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC"
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function depthToSpace_(
    x: Tensor4D|TensorLike4D, blockSize: number,
    dataFormat: 'NHWC'|'NCHW' = 'NHWC'): Tensor4D {
  const $x = convertToTensor(x, 'x', 'depthToSpace') as Tensor4D;
 
  const inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2];
  const inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3];
  const inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1];
 
  util.assert(
      inputHeight * blockSize >= 0,
      () => `Negative dimension size caused by overflow when multiplying
      ${inputHeight} and ${blockSize}  for depthToSpace with input shape
      ${$x.shape}`);
 
  util.assert(
      inputWidth * blockSize >= 0,
      () => `Negative dimension size caused by overflow when multiplying
      ${inputWidth} and ${blockSize} for depthToSpace with input shape
          ${$x.shape}`);
 
  util.assert(
      (inputDepth % (blockSize * blockSize) === 0),
      () => `Dimension size must be evenly divisible by ${
          blockSize * blockSize} but is ${
          inputDepth} for depthToSpace with input shape ${$x.shape}`);
 
  return ENGINE.runKernelFunc(
      backend => backend.depthToSpace($x, blockSize, dataFormat), {$x});
}
 
/**
 * Computes the difference between two lists of numbers.
 *
 * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
 * that represents all values that are in `x` but not in `y`. The returned
 * Tensor `out` is sorted in the same order that the numbers appear in `x`
 * (duplicates are preserved). This operation also returns a Tensor indices that
 * represents the position of each out element in `x`. In other words:
 *
 * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
 *
 * ```js
 * const x = [1, 2, 3, 4, 5, 6];
 * const y = [1, 3, 5];
 *
 * const [out, indices] = await tf.setdiff1dAsync(x, y);
 * out.print(); // [2, 4, 6]
 * indices.print(); // [1, 3, 5]
 * ```
 *
 * @param x 1-D Tensor. Values to keep.
 * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the
 *     output.
 * @returns Promise of Tensor tuple [out, indices].
 *  out: Tensor with the same type as x.
 *  indices: A Tensor of type int32.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
async function setdiff1dAsync_(
    x: Tensor|TensorLike, y: Tensor|TensorLike): Promise<[Tensor, Tensor]> {
  const $x = convertToTensor(x, 'x', 'setdiff1d');
  const $y = convertToTensor(y, 'y', 'setdiff1d');
 
  util.assert(
      $x.dtype === $y.dtype,
      () => `x and y should have the same dtype, but got x (${
          $x.dtype}) and y (${$y.dtype}).`);
 
  util.assert(
      $x.rank === 1, () => `x should be 1D tensor, but got x (${$x.shape}).`);
 
  util.assert(
      $y.rank === 1, () => `y should be 1D tensor, but got y (${$y.shape}).`);
 
  const xVals = await $x.data();
  const yVals = await $y.data();
  const ySet = new Set(yVals);
 
  let outputSize = 0;
  for (let i = 0; i < xVals.length; i++) {
    if (!ySet.has(xVals[i])) {
      outputSize++;
    }
  }
 
  const buffer = new TensorBuffer([outputSize], $x.dtype);
  const indices = new TensorBuffer([outputSize], 'int32');
  for (let i = 0, p = 0; i < xVals.length; i++) {
    if (!ySet.has(xVals[i])) {
      buffer.values[p] = xVals[i];
      indices.values[p] = i;
      p++;
    }
  }
  return [buffer.toTensor(), indices.toTensor()];
}
 
/**
 * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
 *
 * The values are stored in CPU as `TypedArray`. Fill the buffer using
 * `buffer.set()`, or by modifying directly `buffer.values`.
 *
 * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with
 * those values.
 *
 * ```js
 * // Create a buffer and set values at particular indices.
 * const buffer = tf.buffer([2, 2]);
 * buffer.set(3, 0, 0);
 * buffer.set(5, 1, 0);
 *
 * // Convert the buffer back to a tensor.
 * buffer.toTensor().print();
 * ```
 *
 * @param shape An array of integers defining the output tensor shape.
 * @param dtype The dtype of the buffer. Defaults to 'float32'.
 * @param values The values of the buffer as `TypedArray`. Defaults to
 * zeros.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function buffer<R extends Rank, D extends DataType = 'float32'>(
    shape: ShapeMap[R], dtype: D = 'float32' as D,
    values?: DataTypeMap[D]): TensorBuffer<R, D> {
  dtype = dtype || 'float32' as D;
  util.assertNonNegativeIntegerDimensions(shape);
  return new TensorBuffer<R, D>(shape, dtype, values);
}
 
/**
 * Prints information about the `tf.Tensor` including its data.
 *
 * ```js
 * const verbose = true;
 * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
 * ```
 * @param x The tensor to be printed.
 * @param verbose Whether to print verbose information about the ` Tensor`,
 * including dtype and size.
 */
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function print<T extends Tensor>(x: T, verbose = false): void {
  console.log(x.toString(verbose));
}
 
export {
  buffer,  // Not wrapped in op() since no tensors.
  print    // Not wrapped in op() since no need to increase stack trace.
};
 
export const batchToSpaceND = op({batchToSpaceND_});
export const broadcastTo = op({broadcastTo_});
export const cast = op({cast_});
export const clone = op({clone_});
export const cumsum = op({cumsum_});
export const depthToSpace = op({depthToSpace_});
export const expandDims = op({expandDims_});
export const eye = op({eye_});
export const multinomial = op({multinomial_});
export const oneHot = op({oneHot_});
export const pad = op({pad_});
export const pad1d = op({pad1d_});
export const pad2d = op({pad2d_});
export const pad3d = op({pad3d_});
export const pad4d = op({pad4d_});
export const rand = op({rand_});
export const randomNormal = op({randomNormal_});
export const randomGamma = op({randomGamma_});
export const randomUniform = op({randomUniform_});
export const reshape = op({reshape_});
export const spaceToBatchND = op({spaceToBatchND_});
export const squeeze = op({squeeze_});
export const stack = op({stack_});
export const tile = op({tile_});
export const truncatedNormal = op({truncatedNormal_});
export const unstack = op({unstack_});
export const setdiff1dAsync = setdiff1dAsync_;