gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
 
import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';
import {Environment, setEnvironmentGlobal} from './environment';
import {getGradient, getKernel, getKernelsForBackend, NamedAttrMap, TensorInfo} from './kernel_registry';
import {Profiler} from './profiler';
import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode} from './tape';
import {DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
import {BackendValues, DataType, DataValues} from './types';
import * as util from './util';
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';
 
/**
 * A function that computes an output. The save function is for saving tensors
 * computed in the forward pass, that we need in the backward pass.
 */
export type ForwardFunc<T> = (backend: KernelBackend, save?: GradSaveFunc) => T;
 
/**
 * @docalias (a: Tensor, b: Tensor,..., save?: Function) => {
 *   value: Tensor,
 *   gradFunc: (dy: Tensor, saved?: NamedTensorMap) => Tensor | Tensor[]
 * }
 */
export type CustomGradientFunc<T extends Tensor> =
    (...inputs: Array<Tensor|GradSaveFunc>) => {
      value: T;
      gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[];
    };
 
export type MemoryInfo = {
  numTensors: number; numDataBuffers: number; numBytes: number;
  unreliable?: boolean; reasons: string[];
};
 
type KernelProfile = {
  name: string; bytesAdded: number; totalBytesSnapshot: number;
  tensorsAdded: number;
  totalTensorsSnapshot: number;
  inputShapes: number[][];
  outputShapes: number[][];
};
 
export type ProfileInfo = {
  newBytes: number; newTensors: number; peakBytes: number;
  kernels: KernelProfile[];
  result: TensorContainer;
};
 
export interface TimingInfo extends BackendTimingInfo {
  wallMs: number;
}
 
/** @docalias Function */
export type ScopeFn<T extends TensorContainer> = () => T;
 
interface ScopeState {
  track: Tensor[];
  name: string;
  id: number;
}
 
class EngineState {
  // Public since optimizers will use it.
  registeredVariables: NamedVariableMap = {};
 
  nextTapeNodeId = 0;
  numBytes = 0;
  numTensors = 0;
  numStringTensors = 0;
  numDataBuffers = 0;
 
  activeTape: TapeNode[];
  // Number of nested tf.grad() statements when computing higher-order
  // gradients. E.g. `1` for first-order gradients and `2` for second-order
  // gradients. Used to track if the tape should be removed after a backprop.
  gradientDepth = 0;
  // Number of nested kernel calls. When kernel depth is greater than 1, we turn
  // off the tape.
  kernelDepth = 0;
 
  // Keep Tensors that parallel the tapes.
  activeScope: ScopeState;
  scopeStack: ScopeState[] = [];
  /**
   * Keeps track of the number of data moves during a kernel execution. We
   * maintain a stack since kernels can call other kernels, recursively.
   */
  numDataMovesStack: number[] = [];
  nextScopeId = 0;
 
  tensorInfo = new WeakMap<DataId, {
    backend: KernelBackend,
    bytes: number,
    dtype: DataType,
    shape: number[],
    refCount: number
  }>();
 
  profiling = false;
  activeProfile: ProfileInfo =
      {newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null};
 
  dispose() {
    for (const variableName in this.registeredVariables) {
      this.registeredVariables[variableName].dispose();
    }
  }
}
 
export class Engine implements TensorTracker, DataMover {
  state: EngineState;
  backendName: string;
  registry: {[id: string]: KernelBackend} = {};
  registryFactory: {
    [id: string]: {
      factory: () => KernelBackend | Promise<KernelBackend>,
      priority: number
    }
  } = {};
 
  private profiler: Profiler;
  private backendInstance: KernelBackend;
  private pendingBackendInit: Promise<boolean>;
  private pendingBackendInitId = 0;
 
  constructor(public ENV: Environment) {
    this.state = new EngineState();
  }
 
  async ready(): Promise<void> {
    if (this.pendingBackendInit != null) {
      return this.pendingBackendInit.then(() => {});
    }
    if (this.backendInstance != null) {
      return;
    }
    const sortedBackends = this.getSortedBackends();
 
    for (let i = 0; i < sortedBackends.length; i++) {
      const backendName = sortedBackends[i];
      const success = await this.initializeBackend(backendName).success;
      if (success) {
        await this.setBackend(backendName);
        return;
      }
    }
 
    throw new Error(
        `Could not initialize any backends, all backend initializations ` +
        `failed.`);
  }
 
  get backend(): KernelBackend {
    if (this.pendingBackendInit != null) {
      throw new Error(
          `Backend '${this.backendName}' has not yet been initialized. Make ` +
          `sure to await tf.ready() or await tf.setBackend() before calling ` +
          `other methods`);
    }
    if (this.backendInstance == null) {
      const {name, asyncInit} = this.initializeBackendsAndReturnBest();
      if (asyncInit) {
        throw new Error(
            `The highest priority backend '${name}' has not yet been ` +
            `initialized. Make sure to await tf.ready() or ` +
            `await tf.setBackend() before calling other methods`);
      }
      this.setBackend(name);
    }
    return this.backendInstance;
  }
 
  backendNames(): string[] {
    return Object.keys(this.registryFactory);
  }
 
  findBackend(backendName: string): KernelBackend {
    if (!(backendName in this.registry)) {
      // If the backend hasn't been initialized but we have a registry entry for
      // it, initialize it and return it.
      if (backendName in this.registryFactory) {
        const {asyncInit} = this.initializeBackend(backendName);
        if (asyncInit) {
          // Backend is not ready yet.
          return null;
        }
      } else {
        return null;
      }
    }
    return this.registry[backendName];
  }
 
  findBackendFactory(backendName: string):
      () => KernelBackend | Promise<KernelBackend> {
    if (!(backendName in this.registryFactory)) {
      return null;
    }
    return this.registryFactory[backendName].factory;
  }
 
  registerBackend(
      backendName: string,
      factory: () => KernelBackend | Promise<KernelBackend>,
      priority = 1): boolean {
    if (backendName in this.registryFactory) {
      console.warn(
          `${backendName} backend was already registered. ` +
          `Reusing existing backend factory.`);
      return false;
    }
    this.registryFactory[backendName] = {factory, priority};
    return true;
  }
 
  async setBackend(backendName: string): Promise<boolean> {
    if (this.registryFactory[backendName] == null) {
      throw new Error(`Backend name '${backendName}' not found in registry`);
    }
    this.backendName = backendName;
    if (this.registry[backendName] == null) {
      this.backendInstance = null;
      const {success, asyncInit} = this.initializeBackend(backendName);
      const result = asyncInit ? await success : success;
      if (!result) {
        return false;
      }
    }
    this.backendInstance = this.registry[backendName];
    this.setupRegisteredKernels();
    // Reset the profiler.
    this.profiler = new Profiler(this.backendInstance);
 
    return true;
  }
 
  private setupRegisteredKernels(): void {
    const kernels = getKernelsForBackend(this.backendName);
    kernels.forEach(kernel => {
      if (kernel.setupFunc != null) {
        kernel.setupFunc(this.backendInstance);
      }
    });
  }
 
  private disposeRegisteredKernels(backendName: string): void {
    const kernels = getKernelsForBackend(backendName);
    kernels.forEach(kernel => {
      if (kernel.disposeFunc != null) {
        kernel.disposeFunc(this.registry[backendName]);
      }
    });
  }
 
  /**
   * Initializes a backend by looking up the backend name in the factory
   * registry and calling the factory method. Returns a boolean representing
   * whether the initialization of the backend suceeded. Throws an error if
   * there is no backend in the factory registry.
   */
  private initializeBackend(backendName: string):
      {success: boolean|Promise<boolean>, asyncInit: boolean} {
    const registryFactoryEntry = this.registryFactory[backendName];
    if (registryFactoryEntry == null) {
      throw new Error(
          `Cannot initialize backend ${backendName}, no registration found.`);
    }
 
    try {
      const backend = registryFactoryEntry.factory();
      // Test if the factory returns a promise.
      if (Promise.resolve(backend) === backend) {
        const promiseId = ++this.pendingBackendInitId;
        const success =
            backend
                .then(backendInstance => {
                  // Outdated promise. Another backend was set in the meantime.
                  if (promiseId < this.pendingBackendInitId) {
                    return false;
                  }
                  this.registry[backendName] = backendInstance;
                  this.pendingBackendInit = null;
                  return true;
                })
                .catch(err => {
                  // Outdated promise. Another backend was set in the meantime.
                  if (promiseId < this.pendingBackendInitId) {
                    return false;
                  }
                  this.pendingBackendInit = null;
                  console.warn(
                      `Initialization of backend ${backendName} failed`);
                  console.warn(err.stack || err.message);
                  return false;
                });
        this.pendingBackendInit = success;
        return {success, asyncInit: true};
      } else {
        this.registry[backendName] = backend as KernelBackend;
        return {success: true, asyncInit: false};
      }
    } catch (err) {
      console.warn(`Initialization of backend ${backendName} failed`);
      console.warn(err.stack || err.message);
      return {success: false, asyncInit: false};
    }
  }
 
  removeBackend(backendName: string): void {
    if (!(backendName in this.registryFactory)) {
      throw new Error(`${backendName} backend not found in registry`);
    }
    if (this.backendName === backendName && this.pendingBackendInit != null) {
      // There is a pending promise of the backend we want to remove. Make it
      // obsolete.
      this.pendingBackendInitId++;
    }
 
    if (backendName in this.registry) {
      this.disposeRegisteredKernels(backendName);
      this.registry[backendName].dispose();
      delete this.registry[backendName];
    }
 
    delete this.registryFactory[backendName];
 
    // Unset the backend if it is active.
    if (this.backendName === backendName) {
      this.pendingBackendInit = null;
      this.backendName = null;
      this.backendInstance = null;
    }
  }
 
  private getSortedBackends(): string[] {
    if (Object.keys(this.registryFactory).length === 0) {
      throw new Error('No backend found in registry.');
    }
    return Object.keys(this.registryFactory).sort((a: string, b: string) => {
      // Highest priority comes first.
      return this.registryFactory[b].priority -
          this.registryFactory[a].priority;
    });
  }
 
  private initializeBackendsAndReturnBest():
      {name: string, asyncInit: boolean} {
    const sortedBackends = this.getSortedBackends();
 
    for (let i = 0; i < sortedBackends.length; i++) {
      const backendName = sortedBackends[i];
      const {success, asyncInit} = this.initializeBackend(backendName);
      if (asyncInit || success) {
        return {name: backendName, asyncInit};
      }
    }
    throw new Error(
        `Could not initialize any backends, all backend initializations ` +
        `failed.`);
  }
 
  moveData(destBackend: KernelBackend, dataId: DataId) {
    const info = this.state.tensorInfo.get(dataId);
    const srcBackend = info.backend;
    const values = this.readSync(dataId);
    // Delete the tensor from the old backend and move it to the new
    // backend.
    srcBackend.disposeData(dataId);
    info.backend = destBackend;
    destBackend.move(dataId, values, info.shape, info.dtype);
    if (this.shouldCheckForMemLeaks()) {
      // Track the number of moves during a kernel execution to correctly
      // detect memory leaks.
      this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
    }
  }
 
  tidy<T extends TensorContainer>(nameOrFn: string|ScopeFn<T>, fn?: ScopeFn<T>):
      T {
    let name: string = null;
    if (fn == null) {
      // Called with only 1 argument.
      if (typeof nameOrFn !== 'function') {
        throw new Error('Please provide a function to tidy()');
      }
      fn = nameOrFn;
    } else {
      // Called with 2 arguments.
      if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
        throw new Error(
            'When calling with two arguments, the first argument ' +
            'to tidy() must be a string');
      }
      if (typeof fn !== 'function') {
        throw new Error(
            'When calling with two arguments, the 2nd argument ' +
            'to tidy() must be a function');
      }
      name = nameOrFn as string;
      // TODO(nsthorat,smilkov): Do operation logging and performance
      // profiling.
    }
    let result: T;
    return this.scopedRun(
        () => this.startScope(name), () => this.endScope(result), () => {
          result = fn();
          if (result instanceof Promise) {
            console.error('Cannot return a Promise inside of tidy.');
          }
          return result;
        });
  }
 
  private scopedRun<T>(start: () => void, end: () => void, f: () => T): T {
    start();
    try {
      const res = f();
      end();
      return res;
    } catch (ex) {
      end();
      throw ex;
    }
  }
 
  private static nextTensorId = 0;
  private nextTensorId(): number {
    return Engine.nextTensorId++;
  }
 
  private static nextVariableId = 0;
  private nextVariableId(): number {
    return Engine.nextVariableId++;
  }
 
  /**
   * This method is called instead of the public-facing tensor.clone() when
   * saving a tensor for backwards pass. It makes sure to add the clone
   * operation to the tape regardless of being called inside a kernel
   * execution.
   *
   * This method will go away once all kernels are modularized since we won't
   * need to turn off the tape inside runKernel().
   */
  private clone(x: Tensor): Tensor {
    const y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
    const inputs = {x};
    const grad = (dy: Tensor) => ({x: () => dy.toFloat()});
    const saved: Tensor[] = [];
    this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved);
    return y;
  }
 
  /**
   * Execute a kernel with the given name and return the output tensor.
   *
   * @param kernelName The name of the kernel to execute.
   * @param inputs A map of input names to tensors.
   * @param attrs A map of attribute names to their values. An attribute is a
   *     primitive (non-tensor) input to the kernel.
   * @param inputsToSave A list of tensors, inputs to save for the backprop
   *     computation.
   * @param outputsToSave A list of booleans, specifying which output to save
   *     for the backprop computation. These are booleans since the output
   * tensors are not visible to the user.
   */
  runKernel(
      kernelName: string, inputs: NamedTensorMap, attrs: NamedAttrMap,
      inputsToSave?: Tensor[], outputsToSave?: boolean[]): Tensor|Tensor[] {
    const forwardFunc: null = null;
    const backwardsFunc: null = null;
    // Call runKernel as a stop-gap until we modularize all kernels.
    // Once we modularize all kernels, we will remove the existing
    // `runKernelFunc`.
    return this.runKernelFunc(
        forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave,
        outputsToSave);
  }
 
  private shouldCheckForMemLeaks(): boolean {
    return this.ENV.getBool('IS_TEST');
  }
 
  private checkKernelForMemLeak(
      kernelName: string, numDataIdsBefore: number,
      outInfos: TensorInfo[]): void {
    const numDataIdsAfter = this.backend.numDataIds();
 
    // Count the number of data ids associated with the result of the kernel.
    let numOutputDataIds = 0;
    outInfos.forEach(info => {
      // Complex numbers allocate 3 data ids, one for 'real', one for
      // 'imaginary', and one for the container that holds the former two.
      numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
    });
 
    // Account for the number of moves during kernel execution. A "data move"
    // can happen in the middle of a kernel execution, placing a new (key,value)
    // pair in the data storage. Since data moves have net zero effect (we
    // always remove the data from the old backend), we have to cancel them out
    // when detecting memory leaks.
    const numMoves =
        this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
    const dataIdsLeaked =
        numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
    if (dataIdsLeaked > 0) {
      throw new Error(
          `Backend '${this.backendName}' has an internal memory leak ` +
          `(${dataIdsLeaked} data ids) after running '${kernelName}'`);
    }
  }
 
  /**
   * @deprecated Use `runKernel` for newly added kernels. Keep using this method
   *     only for kernels that are not yet fully modularized.
   */
  runKernelFunc<T extends Tensor|Tensor[], I extends NamedTensorMap>(
      forwardFunc: ForwardFunc<T>, inputs: I,
      backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},
      kernelName?: string, attrs?: NamedAttrMap, inputsToSave: Tensor[] = [],
      outputsToSave: boolean[] = []): T {
    let outputs: Tensor[];
    let saved: Tensor[] = [];
    const isTapeOn = this.isTapeOn();
    if (kernelName == null) {
      kernelName =
          this.state.activeScope != null ? this.state.activeScope.name : '';
    }
    const saveFunc: GradSaveFunc = (tensors) => {
      // Do not save unless we are recording to the tape. Otherwise it would
      // cause a mem leak since we would never run backprop, which disposes
      // the kept tensors.
      if (!isTapeOn) {
        return;
      }
      saved = tensors.map(tensor => this.keep(this.clone(tensor)));
    };
 
    const startingBytecount = this.state.numBytes;
    const startingNumTensors = this.state.numTensors;
 
    if (this.shouldCheckForMemLeaks()) {
      this.state.numDataMovesStack.push(0);
    }
 
    let kernelFunc: () => Tensor[];
    const kernel = getKernel(kernelName, this.backendName);
    let out: TensorInfo|TensorInfo[];
    if (kernel != null) {
      kernelFunc = () => {
        const numDataIdsBefore = this.backend.numDataIds();
        out = kernel.kernelFunc({inputs, attrs, backend: this.backend});
        const outInfos = Array.isArray(out) ? out : [out];
        if (this.shouldCheckForMemLeaks()) {
          this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
        }
        const outTensors = outInfos.map(
            ({dataId, shape, dtype}) =>
                this.makeTensorFromDataId(dataId, shape, dtype));
        const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
        // Save the inputs and outputs.
        saveFunc((inputsToSave || []).slice().concat(outsToSave));
        return outTensors;
      };
    } else {
      kernelFunc = () => {
        const numDataIdsBefore = this.backend.numDataIds();
        out = this.tidy(() => forwardFunc(this.backend, saveFunc));
        const outs = (Array.isArray(out) ? out : [out]) as Tensor[];
        if (this.shouldCheckForMemLeaks()) {
          this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
        }
        return outs;
      };
    }
 
    // Stop recording to a tape when running a kernel.
    this.scopedRun(
        () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
          if (!this.ENV.getBool('DEBUG')) {
            outputs = kernelFunc();
          } else {
            outputs = this.profiler.profileKernel(
                kernelName, inputs, () => kernelFunc());
          }
        });
 
    if (isTapeOn) {
      this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved);
    }
 
    if (this.state.profiling) {
      this.state.activeProfile.kernels.push({
        name: kernelName,
        bytesAdded: this.state.numBytes - startingBytecount,
        totalBytesSnapshot: this.state.numBytes,
        tensorsAdded: this.state.numTensors - startingNumTensors,
        totalTensorsSnapshot: this.state.numTensors,
        inputShapes: Object.keys(inputs).map(key => inputs[key].shape),
        outputShapes: outputs.map(item => item.shape)
      });
    }
    return (Array.isArray(out) ? outputs : outputs[0]) as T;
  }
 
  /**
   * Internal method used by public APIs for tensor creation. Makes a new
   * tensor with the provided shape, dtype and values. It always
   * creates a new data id and writes the values to the underlying backend.
   */
  makeTensor(
      values: DataValues, shape: number[], dtype: DataType,
      backend?: KernelBackend): Tensor {
    if (values == null) {
      throw new Error('Values passed to engine.makeTensor() are null');
    }
    dtype = dtype || 'float32';
    backend = backend || this.backend;
    let backendVals = values as BackendValues;
    if (dtype === 'string' && util.isString(values[0])) {
      backendVals = (values as string[]).map(d => util.encodeString(d));
    }
    const dataId = backend.write(backendVals, shape, dtype);
    const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
    this.incRef(t, backend);
 
    // Count bytes for string tensors.
    if (dtype === 'string') {
      const info = this.state.tensorInfo.get(dataId);
      const newBytes = bytesFromStringArray(backendVals as Uint8Array[]);
      this.state.numBytes += newBytes - info.bytes;
      info.bytes = newBytes;
    }
    return t;
  }
 
  /**
   * Internal method used by backends. Makes a new tensor
   * that is a wrapper around an existing data id. It doesn't create
   * a new data id, only increments the ref count used in memory tracking.
   */
  makeTensorFromDataId(
      dataId: DataId, shape: number[], dtype: DataType,
      backend?: KernelBackend): Tensor {
    dtype = dtype || 'float32';
    const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
    this.incRef(t, backend);
    return t;
  }
 
  makeVariable(
      initialValue: Tensor, trainable = true, name?: string,
      dtype?: DataType): Variable {
    name = name || this.nextVariableId().toString();
    if (dtype != null && dtype !== initialValue.dtype) {
      initialValue = initialValue.asType(dtype);
    }
    const v = new Variable(initialValue, trainable, name, this.nextTensorId());
    if (this.state.registeredVariables[v.name] != null) {
      throw new Error(`Variable with name ${v.name} was already registered`);
    }
    this.state.registeredVariables[v.name] = v;
    this.incRef(v, this.backend);
    return v;
  }
 
  incRef(a: Tensor, backend: KernelBackend): void {
    const refCount = this.state.tensorInfo.has(a.dataId) ?
        this.state.tensorInfo.get(a.dataId).refCount :
        0;
    this.state.numTensors++;
    if (a.dtype === 'string') {
      this.state.numStringTensors++;
    }
    if (refCount === 0) {
      this.state.numDataBuffers++;
 
      // Bytes for complex numbers are counted by their components. Bytes for
      // string tensors are counted when writing values.
      let bytes = 0;
      if (a.dtype !== 'complex64' && a.dtype !== 'string') {
        bytes = a.size * util.bytesPerElement(a.dtype);
      }
      this.state.tensorInfo.set(a.dataId, {
        backend: backend || this.backend,
        dtype: a.dtype,
        shape: a.shape,
        bytes,
        refCount: 0
      });
      this.state.numBytes += bytes;
    }
    this.state.tensorInfo.get(a.dataId).refCount++;
    if (!(a instanceof Variable)) {
      this.track(a);
    }
  }
 
  disposeTensor(a: Tensor): void {
    if (!this.state.tensorInfo.has(a.dataId)) {
      return;
    }
 
    this.state.numTensors--;
    if (a.dtype === 'string') {
      this.state.numStringTensors--;
    }
    const info = this.state.tensorInfo.get(a.dataId);
    const refCount = info.refCount;
    if (refCount <= 1) {
      // Don't count bytes for complex numbers as they are counted by their
      // components.
      if (a.dtype !== 'complex64') {
        this.state.numBytes -= info.bytes;
      }
      this.state.numDataBuffers--;
      info.backend.disposeData(a.dataId);
      this.state.tensorInfo.delete(a.dataId);
    } else {
      this.state.tensorInfo.get(a.dataId).refCount--;
    }
    // TODO(nsthorat): Construct an error and save the stack trace for
    // debugging when in debug mode. Creating a stack trace is too expensive
    // to do unconditionally.
  }
 
  disposeVariables(): void {
    for (const varName in this.state.registeredVariables) {
      const v = this.state.registeredVariables[varName];
      this.disposeVariable(v);
    }
  }
 
  disposeVariable(v: Variable): void {
    this.disposeTensor(v);
    if (this.state.registeredVariables[v.name] != null) {
      delete this.state.registeredVariables[v.name];
    }
  }
 
  memory(): MemoryInfo {
    const info = this.backend.memory() as MemoryInfo;
    info.numTensors = this.state.numTensors;
    info.numDataBuffers = this.state.numDataBuffers;
    info.numBytes = this.state.numBytes;
    if (this.state.numStringTensors > 0) {
      info.unreliable = true;
      if (info.reasons == null) {
        info.reasons = [];
      }
      info.reasons.push(
          'Memory usage by string tensors is approximate ' +
          '(2 bytes per character)');
    }
    return info;
  }
 
  async profile(query: () => TensorContainer): Promise<ProfileInfo> {
    this.state.profiling = true;
 
    const startBytes = this.state.numBytes;
    const startNumTensors = this.state.numTensors;
 
    this.state.activeProfile.kernels = [];
    this.state.activeProfile.result = query();
 
    this.state.profiling = false;
 
    this.state.activeProfile.peakBytes = Math.max(
        ...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
    this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
    this.state.activeProfile.newTensors =
        this.state.numTensors - startNumTensors;
    return this.state.activeProfile;
  }
 
  isTapeOn(): boolean {
    return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
  }
 
  private addTapeNode(
      kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
      gradientsFunc: (dy: Tensor|Tensor[], saved: Tensor[]) => NamedGradientMap,
      saved: Tensor[]): void {
    const tapeNode: TapeNode =
        {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};
 
    const gradConfig = getGradient(kernelName);
    if (gradConfig != null) {
      gradientsFunc = gradConfig.gradFunc;
    }
    if (gradientsFunc != null) {
      tapeNode.gradient = (dys: Tensor[]) => {
        // TODO(smilkov): To optimize back-prop, pass dys that are not used in
        // the backprop graph to the user as null instead of zeros
        dys = dys.map((dy, i) => {
          if (dy == null) {
            const output = outputs[i];
            const vals = util.makeZerosTypedArray(output.size, output.dtype);
            return this.makeTensor(vals, output.shape, output.dtype);
          }
          return dy;
        });
        // Grad functions of ops with single outputs expect a dy, while ops
        // with multiple outputs expect dys (array of dy).
        return gradientsFunc(dys.length > 1 ? dys : dys[0], saved);
      };
    }
    this.state.activeTape.push(tapeNode);
  }
 
  keep<T extends Tensor>(result: T): T {
    result.kept = true;
    return result;
  }
 
  private startTape() {
    if (this.state.gradientDepth === 0) {
      this.state.activeTape = [];
    }
    this.state.gradientDepth++;
  }
 
  private endTape() {
    this.state.gradientDepth--;
  }
 
  /**
   * Start a scope. Use this with endScope() to achieve the same functionality
   * as scope() without the need for a function closure.
   */
  startScope(name?: string) {
    const scopeInfo: ScopeState = {
      track: [],
      name: 'unnamed scope',
      id: this.state.nextScopeId++
    };
    if (name) {
      scopeInfo.name = name;
    }
    this.state.scopeStack.push(scopeInfo);
    this.state.activeScope = scopeInfo;
  }
 
  /**
   * End a scope. Use this with startScope() to achieve the same functionality
   * as scope() without the need for a function closure.
   */
  endScope(result?: TensorContainer) {
    const tensorsToTrackInParent = getTensorsInContainer(result);
    const tensorsToTrackInParentSet =
        new Set(tensorsToTrackInParent.map(t => t.id));
 
    // Dispose the arrays tracked in this scope.
    for (let i = 0; i < this.state.activeScope.track.length; i++) {
      const tensor = this.state.activeScope.track[i];
      if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
        tensor.dispose();
      }
    }
 
    const oldScope = this.state.scopeStack.pop();
    this.state.activeScope = this.state.scopeStack.length === 0 ?
        null :
        this.state.scopeStack[this.state.scopeStack.length - 1];
 
    // Track the current result in the parent scope.
    tensorsToTrackInParent.forEach(tensor => {
      // Only track the tensor if was allocated in the inner scope and is not
      // globally kept.
      if (!tensor.kept && tensor.scopeId === oldScope.id) {
        this.track(tensor);
      }
    });
  }
 
  /**
   * Returns gradients of `f` with respect to each of the `xs`. The gradients
   * returned are of the same length as `xs`, but some might be null if `f`
   * was not a function of that `x`. It also takes optional dy to multiply the
   * gradient, which defaults to `1`.
   */
  gradients<T extends Tensor>(
      f: () => T, xs: Tensor[], dy?: T,
      allowNoGradients = false): {value: T, grads: Tensor[]} {
    util.assert(
        xs.length > 0, () => 'gradients() received an empty list of xs.');
    if (dy != null && dy.dtype !== 'float32') {
      throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
    }
 
    const y = this.scopedRun(
        () => this.startTape(), () => this.endTape(),
        () => this.tidy('forward', f));
 
    util.assert(
        y instanceof Tensor,
        () => 'The result y returned by f() must be a tensor.');
    // Filter out the nodes that don't connect x => y.
    const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
    if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
      throw new Error(
          'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
          'that the f you passed encloses all operations that lead from x ' +
          'to y.');
    }
 
    return this.tidy('backward', () => {
      const accumulatedGradientMap: {[tensorId: number]: Tensor} = {};
      accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;
 
      // Backprop gradients through the filtered nodes.
      backpropagateGradients(
          accumulatedGradientMap, filteredTape,
          // Pass the tidy function to avoid circular dep with `tape.ts`.
          f => this.tidy(f as ScopeFn<Tensor>));
      const grads = xs.map(x => accumulatedGradientMap[x.id]);
 
      if (this.state.gradientDepth === 0) {
        // This means that we are not computing higher-order gradients
        // and can clean up the tape.
        this.state.activeTape.forEach(node => {
          for (const tensor of node.saved) {
            tensor.dispose();
          }
        });
        this.state.activeTape = null;
      }
      return {value: y, grads};
    });
  }
 
  customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
      (...args: Array<Tensor|GradSaveFunc>) => T {
    util.assert(
        util.isFunction(f),
        () => 'The f passed in customGrad(f) must be a function.');
    return (...inputs: Tensor[]): T => {
      util.assert(
          inputs.every(t => t instanceof Tensor),
          () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
              'tensors');
 
      let res: {
        value: T,
        gradFunc: (dy: T, saved: Tensor[]) => Tensor | Tensor[],
      };
      const inputMap: NamedTensorMap = {};
      inputs.forEach((input, i) => {
        inputMap[i] = input;
      });
      return this.runKernelFunc(
          (_, save) => {
            res = f(...[...inputs, save]);
            util.assert(
                res.value instanceof Tensor,
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.value` is a tensor');
            util.assert(
                util.isFunction(res.gradFunc),
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function.');
            return res.value;
          },
          inputMap,
          (dy: T, saved: Tensor[]) => {
            const gradRes = res.gradFunc(dy, saved);
            const grads: Tensor[] =
                Array.isArray(gradRes) ? gradRes : [gradRes];
            util.assert(
                grads.length === inputs.length,
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'the same number of tensors as inputs passed to f(...).');
            util.assert(
                grads.every(t => t instanceof Tensor),
                () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'a list of only tensors.');
            const gradMap: {[key: string]: () => Tensor} = {};
            grads.forEach((grad, i) => {
              gradMap[i] = () => grad;
            });
            return gradMap;
          });
    };
  }
 
  readSync(dataId: DataId): BackendValues {
    // Route the read to the correct backend.
    const info = this.state.tensorInfo.get(dataId);
    return info.backend.readSync(dataId);
  }
  read(dataId: DataId): Promise<BackendValues> {
    // Route the read to the correct backend.
    const info = this.state.tensorInfo.get(dataId);
    return info.backend.read(dataId);
  }
 
  async time(query: () => void): Promise<TimingInfo> {
    const start = now();
    const timingInfo = await this.backend.time(query) as TimingInfo;
    timingInfo.wallMs = now() - start;
    return timingInfo;
  }
 
  /**
   * Tracks a Tensor in the current scope to be automatically cleaned up
   * when the current scope ends, and returns the value.
   *
   * @param result The Tensor to track in the current scope.
   */
  private track<T extends Tensor>(result: T): T {
    if (this.state.activeScope != null) {
      result.scopeId = this.state.activeScope.id;
      this.state.activeScope.track.push(result);
    }
 
    return result;
  }
 
  get registeredVariables(): NamedVariableMap {
    return this.state.registeredVariables;
  }
 
  /**
   * Resets the engine state. Removes all backends but does not remove
   * registered backend factories.
   */
  reset(): void {
    // Make any pending promise obsolete.
    this.pendingBackendInitId++;
 
    this.state.dispose();
    this.ENV.reset();
    this.state = new EngineState();
 
    for (const backendName in this.registry) {
      this.disposeRegisteredKernels(backendName);
      this.registry[backendName].dispose();
      delete this.registry[backendName];
    }
    this.backendName = null;
    this.backendInstance = null;
    this.pendingBackendInit = null;
  }
}
 
function ones(shape: number[]): Tensor {
  const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
  return ENGINE.makeTensor(values, shape, 'float32');
}
 
let GLOBAL: {_tfengine: Engine};
function getGlobalNamespace(): {_tfengine: Engine} {
  if (GLOBAL == null) {
    // tslint:disable-next-line:no-any
    let ns: any;
    if (typeof (window) !== 'undefined') {
      ns = window;
    } else if (typeof (global) !== 'undefined') {
      ns = global;
    } else if (typeof (process) !== 'undefined') {
      ns = process;
    } else if (typeof (self) !== 'undefined') {
      ns = self;
    } else {
      throw new Error('Could not find a global object');
    }
    GLOBAL = ns;
  }
  return GLOBAL;
}
 
function getOrMakeEngine(): Engine {
  const ns = getGlobalNamespace();
  if (ns._tfengine == null) {
    const environment = new Environment(ns);
    ns._tfengine = new Engine(environment);
  }
  setEnvironmentGlobal(ns._tfengine.ENV);
 
  // Tell the current tensor interface that the global engine is responsible
  // for tracking.
  setTensorTracker(() => ns._tfengine);
  return ns._tfengine;
}
 
export const ENGINE = getOrMakeEngine();