1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
| import { extractConvParamsFactory, extractSeparableConvParamsFactory, ExtractWeightsFunction, ParamMapping } from '../common/index';
| import { DenseBlock3Params, DenseBlock4Params } from './types';
|
| export function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
| const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings);
| const extractSeparableConvParams = extractSeparableConvParamsFactory(extractWeights, paramMappings);
|
| function extractDenseBlock3Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer = false): DenseBlock3Params {
| const conv0 = isFirstLayer
| ? extractConvParams(channelsIn, channelsOut, 3, `${mappedPrefix}/conv0`)
| : extractSeparableConvParams(channelsIn, channelsOut, `${mappedPrefix}/conv0`);
| const conv1 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv1`);
| const conv2 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv2`);
|
| return { conv0, conv1, conv2 };
| }
|
| function extractDenseBlock4Params(channelsIn: number, channelsOut: number, mappedPrefix: string, isFirstLayer = false): DenseBlock4Params {
| const { conv0, conv1, conv2 } = extractDenseBlock3Params(channelsIn, channelsOut, mappedPrefix, isFirstLayer);
| const conv3 = extractSeparableConvParams(channelsOut, channelsOut, `${mappedPrefix}/conv3`);
|
| return {
| conv0, conv1, conv2, conv3,
| };
| }
|
| return {
| extractDenseBlock3Params,
| extractDenseBlock4Params,
| };
| }
|
|