gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import * as tf from '../../dist/tfjs.esm';
 
import { FaceDetection, Point } from '../classes/index';
import { ParamMapping } from '../common/types';
import { TNetInput } from '../dom/types';
import {
  BOX_ANCHORS,
  BOX_ANCHORS_SEPARABLE,
  DEFAULT_MODEL_NAME,
  DEFAULT_MODEL_NAME_SEPARABLE_CONV,
  IOU_THRESHOLD,
  MEAN_RGB_SEPARABLE,
} from './const';
import { TinyYolov2Base } from './TinyYolov2Base';
import { ITinyYolov2Options } from './TinyYolov2Options';
import { TinyYolov2NetParams } from './types';
 
export class TinyYolov2 extends TinyYolov2Base {
  constructor(withSeparableConvs = true) {
    const config = {
      withSeparableConvs,
      iouThreshold: IOU_THRESHOLD,
      classes: ['face'],
      ...(withSeparableConvs
        ? {
          anchors: BOX_ANCHORS_SEPARABLE,
          meanRgb: MEAN_RGB_SEPARABLE,
        }
        : {
          anchors: BOX_ANCHORS,
          withClassScores: true,
        }),
    };
 
    super(config);
  }
 
  public get withSeparableConvs(): boolean {
    return this.config.withSeparableConvs;
  }
 
  public get anchors(): Point[] {
    return this.config.anchors;
  }
 
  public async locateFaces(input: TNetInput, forwardParams: ITinyYolov2Options): Promise<FaceDetection[]> {
    const objectDetections = await this.detect(input, forwardParams);
    return objectDetections.map((det) => new FaceDetection(det.score, det.relativeBox, { width: det.imageWidth, height: det.imageHeight }));
  }
 
  protected override getDefaultModelName(): string {
    return this.withSeparableConvs ? DEFAULT_MODEL_NAME_SEPARABLE_CONV : DEFAULT_MODEL_NAME;
  }
 
  protected override extractParamsFromWeightMap(weightMap: tf.NamedTensorMap): { params: TinyYolov2NetParams, paramMappings: ParamMapping[] } {
    return super.extractParamsFromWeightMap(weightMap);
  }
}