gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import * as tf from '../../dist/tfjs.esm';
 
import { ConvParams } from '../common/index';
import { disposeUnusedWeightTensors } from '../common/disposeUnusedWeightTensors';
import { loadSeparableConvParamsFactory } from '../common/extractSeparableConvParamsFactory';
import { extractWeightEntryFactory } from '../common/extractWeightEntryFactory';
import { ParamMapping } from '../common/types';
import { TinyYolov2Config } from './config';
import { BatchNorm, ConvWithBatchNorm, TinyYolov2NetParams } from './types';
 
/**
 * Creates the set of parameter extractors used to read TinyYolov2 weights.
 *
 * All extractors share a single weight-entry reader obtained from
 * `extractWeightEntryFactory(weightMap, paramMappings)`.
 *
 * @param weightMap - Source of the weights, handed to the weight-entry reader as is.
 * @param paramMappings - Mapping list handed to the weight-entry reader as is.
 * @returns The `extractConvParams`, `extractConvWithBatchNormParams` and
 *   `extractSeparableConvParams` extractors.
 */
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
  const readEntry = extractWeightEntryFactory(weightMap, paramMappings);

  // Batch norm entries live under `<prefix>/sub` and `<prefix>/truediv`.
  // NOTE(review): the second argument is presumably the expected tensor rank - confirm.
  const batchNormAt = (prefix: string): BatchNorm => ({
    sub: readEntry(`${prefix}/sub`, 1),
    truediv: readEntry(`${prefix}/truediv`, 1),
  });

  // Convolution entries live under `<prefix>/filters` and `<prefix>/bias`.
  const extractConvParams = (prefix: string): ConvParams => ({
    filters: readEntry(`${prefix}/filters`, 4),
    bias: readEntry(`${prefix}/bias`, 1),
  });

  // A conv + batch norm layer keeps its parts under `<prefix>/conv` and `<prefix>/bn`;
  // the conv part is read first, then the batch norm part.
  const extractConvWithBatchNormParams = (prefix: string): ConvWithBatchNorm => ({
    conv: extractConvParams(`${prefix}/conv`),
    bn: batchNormAt(`${prefix}/bn`),
  });

  return {
    extractConvParams,
    extractConvWithBatchNormParams,
    extractSeparableConvParams: loadSeparableConvParamsFactory(readEntry),
  };
}
 
/**
 * Reads the TinyYolov2 network parameters from a named tensor map.
 *
 * Two layouts are supported, selected by `config.withSeparableConvs`:
 * - separable convolutions for `conv0`..`conv7` (with `conv0` optionally a plain
 *   convolution when `config.isFirstLayerConv2d` is set, and `conv6` / `conv7`
 *   only present when there are more than 7 / 8 filter sizes),
 * - convolutions with batch norm for `conv0`..`conv7`.
 * In both layouts `conv8` is a plain convolution.
 *
 * After extraction, `disposeUnusedWeightTensors(weightMap, paramMappings)` is invoked.
 *
 * @param weightMap - Named tensors to read the parameters from.
 * @param config - Network configuration selecting the layout.
 * @returns The extracted parameters together with the collected parameter mappings.
 */
export function extractParamsFromWeightMap(
  weightMap: tf.NamedTensorMap,
  config: TinyYolov2Config,
): { params: TinyYolov2NetParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = [];

  const {
    extractConvParams: plainConv,
    extractConvWithBatchNormParams: convWithBatchNorm,
    extractSeparableConvParams: separableConv,
  } = extractorsFactory(weightMap, paramMappings);

  // Layout built from separable convolutions.
  const readSeparableLayout = (): TinyYolov2NetParams => {
    const sizes = config.filterSizes;
    // Falls back to 9 when the filter sizes are missing or empty.
    const filterCount = (sizes && sizes.length) || 9;
    const firstLayer = config.isFirstLayerConv2d ? plainConv('conv0') : separableConv('conv0');
    return {
      conv0: firstLayer,
      conv1: separableConv('conv1'),
      conv2: separableConv('conv2'),
      conv3: separableConv('conv3'),
      conv4: separableConv('conv4'),
      conv5: separableConv('conv5'),
      // The two deepest separable layers are skipped when fewer filter sizes are given.
      conv6: filterCount > 7 ? separableConv('conv6') : undefined,
      conv7: filterCount > 8 ? separableConv('conv7') : undefined,
      conv8: plainConv('conv8'),
    };
  };

  // Layout built from convolutions followed by batch norm.
  const readBatchNormLayout = (): TinyYolov2NetParams => ({
    conv0: convWithBatchNorm('conv0'),
    conv1: convWithBatchNorm('conv1'),
    conv2: convWithBatchNorm('conv2'),
    conv3: convWithBatchNorm('conv3'),
    conv4: convWithBatchNorm('conv4'),
    conv5: convWithBatchNorm('conv5'),
    conv6: convWithBatchNorm('conv6'),
    conv7: convWithBatchNorm('conv7'),
    conv8: plainConv('conv8'),
  });

  const params = config.withSeparableConvs ? readSeparableLayout() : readBatchNormLayout();

  disposeUnusedWeightTensors(weightMap, paramMappings);
  return { params, paramMappings };
}