gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import * as tf from '../../dist/tfjs.esm';
 
import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
 
/**
 * Builds an extractor for a single convolution layer's parameters.
 *
 * Each call to the returned extractor:
 * 1. requests the filter weights from `extractWeights`;
 * 2. then requests the bias weights;
 * 3. records both parameter paths in `paramMappings`.
 *
 * NOTE(review): `extractWeights` presumably reads sequentially from a shared
 * weight buffer, in which case the filter-then-bias request order below is
 * significant — confirm against the ExtractWeightsFunction implementation.
 *
 * @param extractWeights - Supplies the requested number of weight values.
 * @param paramMappings - Array that receives one entry per extracted tensor
 *   (mutated in place).
 * @returns A function producing the filters/bias tensors for one conv layer.
 */
export function extractConvParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[],
) {
  /**
   * Extracts one conv layer's parameters.
   *
   * @param channelsIn - Number of input channels.
   * @param channelsOut - Number of output channels (also the bias length).
   * @param filterSize - Spatial kernel size; the kernel is square.
   * @param mappedPrefix - Prefix for the recorded `/filters` and `/bias` paths.
   * @returns The filter tensor, shaped
   *   [filterSize, filterSize, channelsIn, channelsOut], and the 1-D bias tensor.
   */
  function extractConvParams(
    channelsIn: number,
    channelsOut: number,
    filterSize: number,
    mappedPrefix: string,
  ): ConvParams {
    const filterCount = channelsIn * channelsOut * filterSize * filterSize;
    const filterValues = extractWeights(filterCount);
    const filters = tf.tensor4d(filterValues, [
      filterSize,
      filterSize,
      channelsIn,
      channelsOut,
    ]);

    // Bias weights are requested only after the filter tensor has been built.
    const biasValues = extractWeights(channelsOut);
    const bias = tf.tensor1d(biasValues);

    const filtersPath = `${mappedPrefix}/filters`;
    const biasPath = `${mappedPrefix}/bias`;
    paramMappings.push({ paramPath: filtersPath }, { paramPath: biasPath });

    const params: ConvParams = { filters, bias };
    return params;
  }

  return extractConvParams;
}