1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
| import * as tf from '../../dist/tfjs.esm';
|
| import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
|
| export function extractFCParamsFactory(
| extractWeights: ExtractWeightsFunction,
| paramMappings: ParamMapping[],
| ) {
| return (
| channelsIn: number,
| channelsOut: number,
| mappedPrefix: string,
| ): FCParams => {
| const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut]);
| const fc_bias = tf.tensor1d(extractWeights(channelsOut));
|
| paramMappings.push(
| { paramPath: `${mappedPrefix}/weights` },
| { paramPath: `${mappedPrefix}/bias` },
| );
|
| return {
| weights: fc_weights,
| bias: fc_bias,
| };
| };
| }
|
|