gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import * as tf from '../../dist/tfjs.esm';
 
import { extractConvParamsFactory } from '../common/index';
import { extractSeparableConvParamsFactory } from '../common/extractSeparableConvParamsFactory';
import { extractWeightsFactory } from '../common/extractWeightsFactory';
import { ExtractWeightsFunction, ParamMapping } from '../common/types';
import { TinyYolov2Config } from './config';
import { BatchNorm, ConvWithBatchNorm, TinyYolov2NetParams } from './types';
 
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
  const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings);
 
  function extractBatchNormParams(size: number, mappedPrefix: string): BatchNorm {
    const sub = tf.tensor1d(extractWeights(size));
    const truediv = tf.tensor1d(extractWeights(size));
 
    paramMappings.push(
      { paramPath: `${mappedPrefix}/sub` },
      { paramPath: `${mappedPrefix}/truediv` },
    );
    return { sub, truediv };
  }
 
  function extractConvWithBatchNormParams(channelsIn: number, channelsOut: number, mappedPrefix: string): ConvWithBatchNorm {
    const conv = extractConvParams(channelsIn, channelsOut, 3, `${mappedPrefix}/conv`);
    const bn = extractBatchNormParams(channelsOut, `${mappedPrefix}/bn`);
    return { conv, bn };
  }
  const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings);
 
  return {
    extractConvParams,
    extractConvWithBatchNormParams,
    extractSeparableConvParams,
  };
}
 
export function extractParams(
  weights: Float32Array,
  config: TinyYolov2Config,
  boxEncodingSize: number,
  filterSizes: number[],
): { params: TinyYolov2NetParams, paramMappings: ParamMapping[] } {
  const {
    extractWeights,
    getRemainingWeights,
  } = extractWeightsFactory(weights);
 
  const paramMappings: ParamMapping[] = [];
  const {
    extractConvParams,
    extractConvWithBatchNormParams,
    extractSeparableConvParams,
  } = extractorsFactory(extractWeights, paramMappings);
  let params: TinyYolov2NetParams;
 
  if (config.withSeparableConvs) {
    const [s0, s1, s2, s3, s4, s5, s6, s7, s8] = filterSizes;
    const conv0 = config.isFirstLayerConv2d
      ? extractConvParams(s0, s1, 3, 'conv0')
      : extractSeparableConvParams(s0, s1, 'conv0');
    const conv1 = extractSeparableConvParams(s1, s2, 'conv1');
    const conv2 = extractSeparableConvParams(s2, s3, 'conv2');
    const conv3 = extractSeparableConvParams(s3, s4, 'conv3');
    const conv4 = extractSeparableConvParams(s4, s5, 'conv4');
    const conv5 = extractSeparableConvParams(s5, s6, 'conv5');
    const conv6 = s7 ? extractSeparableConvParams(s6, s7, 'conv6') : undefined;
    const conv7 = s8 ? extractSeparableConvParams(s7, s8, 'conv7') : undefined;
    const conv8 = extractConvParams(s8 || s7 || s6, 5 * boxEncodingSize, 1, 'conv8');
    params = {
      conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,
    };
  } else {
    const [s0, s1, s2, s3, s4, s5, s6, s7, s8] = filterSizes;
    const conv0 = extractConvWithBatchNormParams(s0, s1, 'conv0');
    const conv1 = extractConvWithBatchNormParams(s1, s2, 'conv1');
    const conv2 = extractConvWithBatchNormParams(s2, s3, 'conv2');
    const conv3 = extractConvWithBatchNormParams(s3, s4, 'conv3');
    const conv4 = extractConvWithBatchNormParams(s4, s5, 'conv4');
    const conv5 = extractConvWithBatchNormParams(s5, s6, 'conv5');
    const conv6 = extractConvWithBatchNormParams(s6, s7, 'conv6');
    const conv7 = extractConvWithBatchNormParams(s7, s8, 'conv7');
    const conv8 = extractConvParams(s8, 5 * boxEncodingSize, 1, 'conv8');
    params = {
      conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,
    };
  }
  if (getRemainingWeights().length !== 0) {
    throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
  }
  return { params, paramMappings };
}