gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import * as tf from '../../dist/tfjs.esm';
 
import { disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping } from '../common/index';
import { isTensor2D } from '../utils/index';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
 
/**
 * Creates the layer-level parameter readers bound to a single weight map.
 *
 * Every read goes through the entry reader returned by
 * `extractWeightEntryFactory(weightMap, paramMappings)`, so both arguments are
 * simply forwarded to that factory.
 *
 * NOTE(review): the numeric second argument passed to the entry reader (1 or 4)
 * is presumably the expected tensor rank — confirm against
 * `extractWeightEntryFactory`.
 *
 * @param weightMap Source of the named weights; forwarded untouched.
 * @param paramMappings Mapping list handed to the entry reader factory.
 * @returns The `extractConvLayerParams` and `extractResidualLayerParams` readers.
 */
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
  const readEntry = extractWeightEntryFactory(weightMap, paramMappings);

  // Reads `<prefix>/scale/weights` first, then `<prefix>/scale/biases`.
  const readScaleLayer = (prefix: string): ScaleLayerParams => ({
    weights: readEntry(`${prefix}/scale/weights`, 1),
    biases: readEntry(`${prefix}/scale/biases`, 1),
  });

  // Object-literal properties evaluate top to bottom, so the read order is:
  // conv filters, conv bias, scale weights, scale biases.
  const extractConvLayerParams = (prefix: string): ConvLayerParams => ({
    conv: {
      filters: readEntry(`${prefix}/conv/filters`, 4),
      bias: readEntry(`${prefix}/conv/bias`, 1),
    },
    scale: readScaleLayer(prefix),
  });

  // A residual layer is two conv layers stored under `<prefix>/conv1` and `<prefix>/conv2`.
  const extractResidualLayerParams = (prefix: string): ResidualLayerParams => ({
    conv1: extractConvLayerParams(`${prefix}/conv1`),
    conv2: extractConvLayerParams(`${prefix}/conv2`),
  });

  return { extractConvLayerParams, extractResidualLayerParams };
}
 
/**
 * Assembles the network parameters from a named tensor map.
 *
 * Reads one conv layer (`conv32_down`) and fourteen residual layers through the
 * readers produced by `extractorsFactory`, then takes the `fc` entry straight
 * from `weightMap` and records a mapping for it. Once `fc` has passed the
 * `isTensor2D` check, `disposeUnusedWeightTensors(weightMap, paramMappings)` is
 * called before returning.
 *
 * NOTE(review): `disposeUnusedWeightTensors` presumably releases tensors that
 * have no entry in `paramMappings` — confirm in `../common/index`.
 *
 * @param weightMap Named tensors to read the parameters from.
 * @returns The assembled `params` together with the collected `paramMappings`.
 * @throws Error when `weightMap.fc` does not pass `isTensor2D`.
 */
export function extractParamsFromWeightMap(
  weightMap: tf.NamedTensorMap,
): { params: NetParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = [];
  const extractors = extractorsFactory(weightMap, paramMappings);
  const residual = extractors.extractResidualLayerParams;

  // Properties are evaluated in source order, which fixes the order in which
  // weight entries are read.
  const convParams = {
    conv32_down: extractors.extractConvLayerParams('conv32_down'),
    conv32_1: residual('conv32_1'),
    conv32_2: residual('conv32_2'),
    conv32_3: residual('conv32_3'),
    conv64_down: residual('conv64_down'),
    conv64_1: residual('conv64_1'),
    conv64_2: residual('conv64_2'),
    conv64_3: residual('conv64_3'),
    conv128_down: residual('conv128_down'),
    conv128_1: residual('conv128_1'),
    conv128_2: residual('conv128_2'),
    conv256_down: residual('conv256_down'),
    conv256_1: residual('conv256_1'),
    conv256_2: residual('conv256_2'),
    conv256_down_out: residual('conv256_down_out'),
  };

  // `fc` is taken directly from the map, so its mapping is recorded by hand.
  const fc = weightMap.fc;
  paramMappings.push({ originalPath: 'fc', paramPath: 'fc' });

  if (isTensor2D(fc)) {
    const params = { ...convParams, fc };

    disposeUnusedWeightTensors(weightMap, paramMappings);

    return { params, paramMappings };
  }

  throw new Error(`expected weightMap[fc] to be a Tensor2D, instead have ${fc}`);
}