gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import * as tf from '../../dist/tfjs.esm';
 
import { NetInput, TNetInput, toNetInput } from '../dom/index';
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
import { FaceExpressions } from './FaceExpressions';
 
/**
 * Face expression classifier.
 *
 * Registers itself with `FaceProcessor` under the name `'FaceExpressionNet'`
 * and turns the softmax output of the network into `FaceExpressions`
 * instances.
 */
export class FaceExpressionNet extends FaceProcessor<FaceFeatureExtractorParams> {
  /**
   * @param faceFeatureExtractor - Feature extractor handed to the `FaceProcessor`
   *   base class. A new `FaceFeatureExtractor` is created when omitted.
   */
  constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
    super('FaceExpressionNet', faceFeatureExtractor);
  }

  /**
   * Runs the network on an already-prepared input and applies softmax.
   *
   * All intermediate tensors are released by `tf.tidy`. The returned tensor is
   * owned by the caller, who must dispose it.
   *
   * @param input - Prepared `NetInput` or a 4D tensor batch.
   * @returns Softmax probabilities. Presumably shaped `[batch, 7]`: one row per
   *   input and one column per classifier output channel (TODO: confirm against
   *   `FaceProcessor.runNet`).
   */
  public forwardInput(input: NetInput | tf.Tensor4D): tf.Tensor2D {
    return tf.tidy(() => tf.softmax(this.runNet(input)));
  }

  /**
   * Converts `input` with `toNetInput` and then runs `forwardInput` on it.
   *
   * @param input - Any input accepted by `toNetInput`.
   * @returns Softmax probability tensor. The caller must dispose it.
   */
  public async forward(input: TNetInput): Promise<tf.Tensor2D> {
    return this.forwardInput(await toNetInput(input));
  }

  /**
   * Predicts expression probabilities for each input image.
   *
   * @param input - Any input accepted by `toNetInput` (a single input or a batch).
   * @returns An array with one `FaceExpressions` per batch entry when the
   *   prepared input is a batch; otherwise the `FaceExpressions` for the single
   *   input.
   */
  public async predictExpressions(input: TNetInput) {
    const netInput = await toNetInput(input);
    // forwardInput is synchronous, so there is nothing to await here.
    const out = this.forwardInput(netInput);
    // One tensor per batch entry (one row of `out` each).
    const rows = tf.unstack(out);

    try {
      // Use asynchronous readback (`data`) rather than `dataSync`, so a
      // GPU-backed backend does not block the calling thread while values download.
      const probabilitiesByBatch = await Promise.all(rows.map((row) => row.data()));
      const predictionsByBatch = probabilitiesByBatch
        .map((probabilities) => new FaceExpressions(probabilities as Float32Array));

      return netInput.isBatchInput
        ? predictionsByBatch
        : predictionsByBatch[0];
    } finally {
      // Release every tensor we own, even if a readback rejected.
      rows.forEach((row) => row.dispose());
      out.dispose();
    }
  }

  /** Default model name; presumably used by the base class to locate weight files (TODO: confirm). */
  protected getDefaultModelName(): string {
    return 'face_expression_model';
  }

  /** Number of input channels to the classifier head. */
  protected getClassifierChannelsIn(): number {
    return 256;
  }

  /** Number of classifier outputs; presumably one per expression category in `FaceExpressions`. */
  protected getClassifierChannelsOut(): number {
    return 7;
  }
}