gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import * as tf from '../../dist/tfjs.esm';
 
import { ExtractWeightsFunction, ParamMapping, SeparableConvParams } from './types';
 
/**
 * Builds a reader that materialises the parameters of one separable
 * convolution layer from a flat weight source.
 *
 * @param extractWeights - Called once per tensor with the number of values to
 *   take. NOTE(review): presumably this consumes values sequentially from a
 *   shared buffer, so the call order below must not change — confirm.
 * @param paramMappings - Receives one `{ paramPath }` entry per created tensor,
 *   appended in the same order the tensors are created.
 * @returns A function `(channelsIn, channelsOut, mappedPrefix)` that creates:
 *   - a depthwise filter of shape `[3, 3, channelsIn, 1]`,
 *   - a pointwise filter of shape `[1, 1, channelsIn, channelsOut]`,
 *   - a bias vector of length `channelsOut`,
 *   and wraps them in a `SeparableConvParams`.
 */
export function extractSeparableConvParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[],
) {
  return (channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams => {
    // Each weight slice is pulled immediately before its tensor is built,
    // preserving the original interleaving of extract/allocate calls.
    const depthwiseWeights = extractWeights(3 * 3 * channelsIn);
    const depthwiseFilter = tf.tensor4d(depthwiseWeights, [3, 3, channelsIn, 1]);

    const pointwiseWeights = extractWeights(channelsIn * channelsOut);
    const pointwiseFilter = tf.tensor4d(pointwiseWeights, [1, 1, channelsIn, channelsOut]);

    const biasWeights = extractWeights(channelsOut);
    const bias = tf.tensor1d(biasWeights);

    // Record where each tensor lives, in creation order.
    const paramPaths = [
      `${mappedPrefix}/depthwise_filter`,
      `${mappedPrefix}/pointwise_filter`,
      `${mappedPrefix}/bias`,
    ];
    paramMappings.push(...paramPaths.map((paramPath) => ({ paramPath })));

    return new SeparableConvParams(depthwiseFilter, pointwiseFilter, bias);
  };
}
 
/**
 * Builds a loader that reassembles a separable convolution layer from
 * previously stored, path-addressed weight entries.
 *
 * @param extractWeightEntry - Looks up the entry stored under `originalPath`;
 *   `paramRank` is the tensor rank the caller expects for that entry.
 * @returns A function `(prefix)` that fetches, in this order,
 *   `${prefix}/depthwise_filter` (rank 4), `${prefix}/pointwise_filter`
 *   (rank 4) and `${prefix}/bias` (rank 1), and wraps them in a
 *   `SeparableConvParams`.
 */
export function loadSeparableConvParamsFactory(
  // eslint-disable-next-line no-unused-vars
  extractWeightEntry: <T>(originalPath: string, paramRank: number) => T,
) {
  // Constructor arguments are evaluated left to right, so the three lookups
  // still happen in depthwise -> pointwise -> bias order.
  return (prefix: string): SeparableConvParams => new SeparableConvParams(
    extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4),
    extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4),
    extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1),
  );
}