gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import * as tf from '../../dist/tfjs.esm';
 
import { Dimensions } from '../classes/Dimensions';
import { env } from '../env/index';
import { padToSquare } from '../ops/padToSquare';
import { computeReshapedDimensions, isTensor3D, isTensor4D, range } from '../utils/index';
import { createCanvasFromMedia } from './createCanvas';
import { imageToSquare } from './imageToSquare';
import { TResolvedNetInput } from './types';
 
export class NetInput {
  private _imageTensors: Array<tf.Tensor3D | tf.Tensor4D> = [];
 
  private _canvases: HTMLCanvasElement[] = [];
 
  private _batchSize: number;
 
  private _treatAsBatchInput = false;
 
  private _inputDimensions: number[][] = [];
 
  private _inputSize = 0;
 
  constructor(inputs: Array<TResolvedNetInput>, treatAsBatchInput = false) {
    if (!Array.isArray(inputs)) {
      throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`);
    }
 
    this._treatAsBatchInput = treatAsBatchInput;
    this._batchSize = inputs.length;
 
    inputs.forEach((input, idx) => {
      if (isTensor3D(input)) {
        this._imageTensors[idx] = input;
        this._inputDimensions[idx] = input.shape;
        return;
      }
 
      if (isTensor4D(input)) {
        const batchSize = (input as any).shape[0];
        if (batchSize !== 1) {
          throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`);
        }
 
        this._imageTensors[idx] = input;
        this._inputDimensions[idx] = (input as any).shape.slice(1);
        return;
      }
 
      // @ts-ignore
      const canvas = (input as any) instanceof env.getEnv().Canvas ? input : createCanvasFromMedia(input);
      this._canvases[idx] = canvas as HTMLCanvasElement;
      this._inputDimensions[idx] = [canvas.height, canvas.width, 3];
    });
  }
 
  public get imageTensors(): Array<tf.Tensor3D | tf.Tensor4D> {
    return this._imageTensors;
  }
 
  public get canvases(): HTMLCanvasElement[] {
    return this._canvases;
  }
 
  public get isBatchInput(): boolean {
    return this.batchSize > 1 || this._treatAsBatchInput;
  }
 
  public get batchSize(): number {
    return this._batchSize;
  }
 
  public get inputDimensions(): number[][] {
    return this._inputDimensions;
  }
 
  public get inputSize(): number | undefined {
    return this._inputSize;
  }
 
  public get reshapedInputDimensions(): Dimensions[] {
    return range(this.batchSize, 0, 1).map(
      (_, batchIdx) => this.getReshapedInputDimensions(batchIdx),
    );
  }
 
  public getInput(batchIdx: number): tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement {
    return this.canvases[batchIdx] || this.imageTensors[batchIdx];
  }
 
  public getInputDimensions(batchIdx: number): number[] {
    return this._inputDimensions[batchIdx];
  }
 
  public getInputHeight(batchIdx: number): number {
    return this._inputDimensions[batchIdx][0];
  }
 
  public getInputWidth(batchIdx: number): number {
    return this._inputDimensions[batchIdx][1];
  }
 
  public getReshapedInputDimensions(batchIdx: number): Dimensions {
    if (typeof this.inputSize !== 'number') {
      throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet');
    }
 
    const width = this.getInputWidth(batchIdx);
    const height = this.getInputHeight(batchIdx);
    return computeReshapedDimensions({ width, height }, this.inputSize);
  }
 
  /**
   * Create a batch tensor from all input canvases and tensors
   * with size [batchSize, inputSize, inputSize, 3].
   *
   * @param inputSize Height and width of the tensor.
   * @param isCenterImage (optional, default: false) If true, add an equal amount of padding on
   * both sides of the minor dimension oof the image.
   * @returns The batch tensor.
   */
  public toBatchTensor(inputSize: number, isCenterInputs = true): tf.Tensor4D {
    this._inputSize = inputSize;
 
    return tf.tidy(() => {
      const inputTensors = range(this.batchSize, 0, 1).map((batchIdx) => {
        const input = this.getInput(batchIdx);
 
        if (input instanceof tf.Tensor) {
          let imgTensor = isTensor4D(input) ? input : tf.expandDims(input);
          imgTensor = padToSquare(imgTensor as tf.Tensor4D, isCenterInputs);
 
          if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
            imgTensor = tf['image'].resizeBilinear(imgTensor as tf.Tensor4D, [inputSize, inputSize], false, false);
          }
 
          return imgTensor.as3D(inputSize, inputSize, 3);
        }
 
        if (input instanceof env.getEnv().Canvas) {
          return tf['browser'].fromPixels(imageToSquare(input, inputSize, isCenterInputs));
        }
 
        throw new Error(`toBatchTensor - at batchIdx ${batchIdx}, expected input to be instanceof tf.Tensor or instanceof HTMLCanvasElement, instead have ${input}`);
      });
 
      const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))).as4D(this.batchSize, inputSize, inputSize, 3);
      // const batchTensor = tf.stack(inputTensors.map((t) => tf.cast(t, 'float32'))) as tf.Tensor4D;
 
      return batchTensor;
    });
  }
}