gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import * as tf from '../../dist/tfjs.esm';
 
import { ConvParams, disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping } from '../common/index';
import { isTensor3D } from '../utils/index';
import { BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';
 
/**
 * Builds the extractors that read SSD-MobileNetV1 parameters out of a weight map.
 *
 * Every tensor lookup goes through the reader produced by `extractWeightEntryFactory`.
 * That factory is handed `paramMappings`, presumably to record each originalPath -> paramPath
 * pair it resolves; confirm in ../common. Lookups happen in a fixed order: object literal
 * properties are evaluated left to right. That order determines the order of any recorded
 * mappings and which missing entry is reported first.
 *
 * @param weightMap - tensors keyed by their original checkpoint path.
 * @param paramMappings - mapping accumulator shared with the weight-entry reader.
 * @returns `extractMobilenetV1Params` for the backbone and `extractPredictionLayerParams`
 *   for the prediction layer.
 */
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
  // The numeric argument passed to readEntry is the expected tensor rank
  // (assumed from extractWeightEntryFactory's contract; TODO confirm).
  const readEntry = extractWeightEntryFactory(weightMap, paramMappings);

  /** Reads `<scope>/Conv2d_<index>_pointwise/{weights,convolution_bn_offset}` into `<target>/...`. */
  const readPointwise = (scope: string, index: number, target: string): PointwiseConvParams => {
    const source = `${scope}/Conv2d_${index}_pointwise`;
    const filters = readEntry(`${source}/weights`, 4, `${target}/filters`);
    const batch_norm_offset = readEntry(`${source}/convolution_bn_offset`, 1, `${target}/batch_norm_offset`);
    return { filters, batch_norm_offset };
  };

  /**
   * Reads the depthwise + pointwise parameter pair for backbone layer `index`.
   * The depthwise part carries batch-norm scale/offset/mean/variance vectors.
   */
  const readConvPair = (index: number): MobileNetV1.ConvPairParams => {
    const source = `MobilenetV1/Conv2d_${index}_depthwise`;
    const target = `mobilenetv1/conv_${index}`;
    const depthwiseTarget = `${target}/depthwise_conv`;
    // Batch-norm vectors live under `<source>/BatchNorm/<tfName>` and map to `<depthwiseTarget>/<field>`.
    const readBatchNorm = (tfName: string, field: string) =>
      readEntry(`${source}/BatchNorm/${tfName}`, 1, `${depthwiseTarget}/${field}`);

    // Depthwise tensors are read first, then the pointwise pair, matching the historical lookup order.
    const depthwise_conv = {
      filters: readEntry(`${source}/depthwise_weights`, 4, `${depthwiseTarget}/filters`),
      batch_norm_scale: readBatchNorm('gamma', 'batch_norm_scale'),
      batch_norm_offset: readBatchNorm('beta', 'batch_norm_offset'),
      batch_norm_mean: readBatchNorm('moving_mean', 'batch_norm_mean'),
      batch_norm_variance: readBatchNorm('moving_variance', 'batch_norm_variance'),
    };
    const pointwise_conv = readPointwise('MobilenetV1', index, `${target}/pointwise_conv`);
    return { depthwise_conv, pointwise_conv };
  };

  /** Reads the backbone: pointwise conv 0, then depthwise/pointwise pairs 1..13. */
  const extractMobilenetV1Params = (): MobileNetV1.Params => ({
    conv_0: readPointwise('MobilenetV1', 0, 'mobilenetv1/conv_0'),
    conv_1: readConvPair(1),
    conv_2: readConvPair(2),
    conv_3: readConvPair(3),
    conv_4: readConvPair(4),
    conv_5: readConvPair(5),
    conv_6: readConvPair(6),
    conv_7: readConvPair(7),
    conv_8: readConvPair(8),
    conv_9: readConvPair(9),
    conv_10: readConvPair(10),
    conv_11: readConvPair(11),
    conv_12: readConvPair(12),
    conv_13: readConvPair(13),
  });

  /** Reads a conv with bias: `<source>/weights` -> `<target>/filters`, `<source>/biases` -> `<target>/bias`. */
  const readConv = (source: string, target: string): ConvParams => ({
    filters: readEntry(`${source}/weights`, 4, `${target}/filters`),
    bias: readEntry(`${source}/biases`, 1, `${target}/bias`),
  });

  /** Reads the box-encoding predictor, then the class predictor, of box predictor `index`. */
  const readBoxPredictor = (index: number): BoxPredictionParams => {
    const source = `Prediction/BoxPredictor_${index}`;
    const target = `prediction_layer/box_predictor_${index}`;
    const box_encoding_predictor = readConv(`${source}/BoxEncodingPredictor`, `${target}/box_encoding_predictor`);
    const class_predictor = readConv(`${source}/ClassPredictor`, `${target}/class_predictor`);
    return { box_encoding_predictor, class_predictor };
  };

  /** Reads the prediction layer: pointwise convs 0..7, then box predictors 0..5. */
  const extractPredictionLayerParams = (): PredictionLayerParams => ({
    conv_0: readPointwise('Prediction', 0, 'prediction_layer/conv_0'),
    conv_1: readPointwise('Prediction', 1, 'prediction_layer/conv_1'),
    conv_2: readPointwise('Prediction', 2, 'prediction_layer/conv_2'),
    conv_3: readPointwise('Prediction', 3, 'prediction_layer/conv_3'),
    conv_4: readPointwise('Prediction', 4, 'prediction_layer/conv_4'),
    conv_5: readPointwise('Prediction', 5, 'prediction_layer/conv_5'),
    conv_6: readPointwise('Prediction', 6, 'prediction_layer/conv_6'),
    conv_7: readPointwise('Prediction', 7, 'prediction_layer/conv_7'),
    box_predictor_0: readBoxPredictor(0),
    box_predictor_1: readBoxPredictor(1),
    box_predictor_2: readBoxPredictor(2),
    box_predictor_3: readBoxPredictor(3),
    box_predictor_4: readBoxPredictor(4),
    box_predictor_5: readBoxPredictor(5),
  });

  return { extractMobilenetV1Params, extractPredictionLayerParams };
}
 
/**
 * Converts a TensorFlow SSD-MobileNetV1 weight map into `NetParams`.
 *
 * Parameters are read by checkpoint path. The `Output/extra_dim` mapping is recorded here
 * explicitly. The backbone and prediction-layer mappings are accumulated in the same
 * `paramMappings` array through the extractors. After extraction,
 * `disposeUnusedWeightTensors(weightMap, paramMappings)` is called. Presumably it frees
 * tensors that were not mapped; confirm in ../common.
 *
 * @param weightMap - tensors keyed by checkpoint path (e.g. `Output/extra_dim`).
 * @returns the assembled params and the originalPath -> paramPath mappings.
 * @throws Error if `weightMap['Output/extra_dim']` does not satisfy `isTensor3D`.
 *   Failures for other entries presumably surface from the weight-entry extractors.
 */
export function extractParamsFromWeightMap(
  weightMap: tf.NamedTensorMap,
): { params: NetParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = [];
  const extractors = extractorsFactory(weightMap, paramMappings);

  // extra_dim is read straight from the map, so its mapping is pushed by hand, ahead of validation.
  const extraDim = weightMap['Output/extra_dim'];
  paramMappings.push({ originalPath: 'Output/extra_dim', paramPath: 'output_layer/extra_dim' });
  if (!isTensor3D(extraDim)) {
    throw new Error(`expected weightMap['Output/extra_dim'] to be a Tensor3D, instead have ${extraDim}`);
  }

  // Backbone first, then prediction layer: this fixes the order of the accumulated mappings.
  const mobilenetv1 = extractors.extractMobilenetV1Params();
  const prediction_layer = extractors.extractPredictionLayerParams();
  const params = {
    mobilenetv1,
    prediction_layer,
    output_layer: { extra_dim: extraDim },
  };

  disposeUnusedWeightTensors(weightMap, paramMappings);
  return { params, paramMappings };
}