gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import * as tf from '../../dist/tfjs.esm';
 
import { Rect } from '../classes/index';
import { FaceDetection } from '../classes/FaceDetection';
import { NetInput, TNetInput, toNetInput } from '../dom/index';
import { NeuralNetwork } from '../NeuralNetwork';
import { extractParams } from './extractParams';
import { extractParamsFromWeightMap } from './extractParamsFromWeightMap';
import { mobileNetV1 } from './mobileNetV1';
import { nonMaxSuppression } from './nonMaxSuppression';
import { outputLayer } from './outputLayer';
import { predictionLayer } from './predictionLayer';
import { ISsdMobilenetv1Options, SsdMobilenetv1Options } from './SsdMobilenetv1Options';
import { NetParams } from './types';
 
export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
  constructor() {
    super('SsdMobilenetv1');
  }
 
  public forwardInput(input: NetInput) {
    const { params } = this;
    if (!params) throw new Error('SsdMobilenetv1 - load model before inference');
    return tf.tidy(() => {
      const batchTensor = tf.cast(input.toBatchTensor(512, false), 'float32');
      const x = tf.sub(tf.div(batchTensor, 127.5), 1) as tf.Tensor4D; // input is normalized -1..1
      const features = mobileNetV1(x, params.mobilenetv1);
      const { boxPredictions, classPredictions } = predictionLayer(features.out, features.conv11, params.prediction_layer);
      return outputLayer(boxPredictions, classPredictions, params.output_layer);
    });
  }
 
  public async forward(input: TNetInput) {
    return this.forwardInput(await toNetInput(input));
  }
 
  public async locateFaces(input: TNetInput, options: ISsdMobilenetv1Options = {}): Promise<FaceDetection[]> {
    const { maxResults, minConfidence } = new SsdMobilenetv1Options(options);
    const netInput = await toNetInput(input);
    const { boxes: _boxes, scores: _scores } = this.forwardInput(netInput);
    const boxes = _boxes[0];
    const scores = _scores[0];
    for (let i = 1; i < _boxes.length; i++) {
      _boxes[i].dispose();
      _scores[i].dispose();
    }
    const scoresData = Array.from(scores.dataSync());
    const iouThreshold = 0.5;
    const indices = nonMaxSuppression(boxes, scoresData as number[], maxResults, iouThreshold, minConfidence);
    const reshapedDims = netInput.getReshapedInputDimensions(0);
    const inputSize = netInput.inputSize as number;
    const padX = inputSize / reshapedDims.width;
    const padY = inputSize / reshapedDims.height;
    const boxesData = boxes.arraySync();
    const results = indices
      .map((idx) => {
        const [top, bottom] = [
          Math.max(0, boxesData[idx][0]),
          Math.min(1.0, boxesData[idx][2]),
        ].map((val) => val * padY);
        const [left, right] = [
          Math.max(0, boxesData[idx][1]),
          Math.min(1.0, boxesData[idx][3]),
        ].map((val) => val * padX);
        return new FaceDetection(
          scoresData[idx] as number,
          new Rect(left, top, right - left, bottom - top),
          { height: netInput.getInputHeight(0), width: netInput.getInputWidth(0) },
        );
      });
    boxes.dispose();
    scores.dispose();
    return results;
  }
 
  protected getDefaultModelName(): string {
    return 'ssd_mobilenetv1_model';
  }
 
  protected extractParamsFromWeightMap(weightMap: tf.NamedTensorMap) {
    return extractParamsFromWeightMap(weightMap);
  }
 
  protected extractParams(weights: Float32Array) {
    return extractParams(weights);
  }
}