1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
| import * as tf from '../../dist/tfjs.esm';
|
| import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
|
/**
 * Creates a function that reads the parameters of one convolution layer
 * from a sequential weight source.
 *
 * @param extractWeights - Called with the number of values to consume; its
 *   result is handed directly to the tensor constructors.
 * @param paramMappings - Array that receives one entry per extracted tensor
 *   (mutated in place).
 * @returns An extractor taking `(channelsIn, channelsOut, filterSize,
 *   mappedPrefix)` and producing the layer's `{ filters, bias }`.
 */
export function extractConvParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[],
) {
  const extractConvParams = (
    channelsIn: number,
    channelsOut: number,
    filterSize: number,
    mappedPrefix: string,
  ): ConvParams => {
    const filterShape: [number, number, number, number] = [
      filterSize,
      filterSize,
      channelsIn,
      channelsOut,
    ];
    const filterWeightCount = channelsIn * channelsOut * filterSize * filterSize;

    // Order matters: the filter values are consumed before the bias values.
    const filterWeights = extractWeights(filterWeightCount);
    const filters = tf.tensor4d(filterWeights, filterShape);

    const biasWeights = extractWeights(channelsOut);
    const bias = tf.tensor1d(biasWeights);

    // Record the path of each tensor in the same order they were extracted.
    paramMappings.push({ paramPath: `${mappedPrefix}/filters` });
    paramMappings.push({ paramPath: `${mappedPrefix}/bias` });

    return { filters, bias };
  };

  return extractConvParams;
}
|
|