gx
chenyc
2025-06-12 7b72ac13a83764a662159d4a49b7fffb90476ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
var common_1 = require("../common");
var utils_1 = require("../utils");
/**
 * Creates the per-layer parameter extractors used to read a network's
 * weights out of a flat weight map.
 *
 * Every lookup is delegated to the function produced by
 * `common_1.extractWeightEntryFactory(weightMap, paramMappings)`. The numeric
 * second argument passed to that function (4 for conv filters, 1 otherwise)
 * is presumably the expected tensor rank -- confirm against `../common`.
 *
 * @param {Object} weightMap - Weight entries keyed by path.
 * @param {Array} paramMappings - Forwarded unchanged to the entry-extractor factory.
 * @returns {{extractConvLayerParams: Function, extractResidualLayerParams: Function}}
 *   Extractors that take a path prefix and return nested layer params.
 */
function extractorsFactory(weightMap, paramMappings) {
    var readEntry = common_1.extractWeightEntryFactory(weightMap, paramMappings);

    /**
     * Reads the conv and scale entries found under `prefix`.
     * Entries are read in a fixed order: filters, bias, scale weights, scale biases.
     */
    function extractConvLayerParams(prefix) {
        var conv = {
            filters: readEntry(prefix + "/conv/filters", 4),
            bias: readEntry(prefix + "/conv/bias", 1)
        };
        var scale = {
            weights: readEntry(prefix + "/scale/weights", 1),
            biases: readEntry(prefix + "/scale/biases", 1)
        };
        return { conv: conv, scale: scale };
    }

    /**
     * Reads the two conv layers (`conv1`, then `conv2`) found under `prefix`.
     */
    function extractResidualLayerParams(prefix) {
        var first = extractConvLayerParams(prefix + "/conv1");
        var second = extractConvLayerParams(prefix + "/conv2");
        return { conv1: first, conv2: second };
    }

    var extractors = {};
    extractors.extractConvLayerParams = extractConvLayerParams;
    extractors.extractResidualLayerParams = extractResidualLayerParams;
    return extractors;
}
/**
 * Assembles the nested network parameter object from a flat weight map.
 *
 * One conv layer (`conv32_down`) and fourteen residual layers are extracted
 * through the extractors returned by `extractorsFactory`, then the `fc` entry
 * is taken straight from `weightMap` and validated with `utils_1.isTensor2D`.
 * Finally `common_1.disposeUnusedWeightTensors(weightMap, paramMappings)` is
 * invoked (by its name it releases tensors not covered by the mappings --
 * confirm against `../common`).
 *
 * @param {Object} weightMap - Weight entries keyed by path; must contain `fc`.
 * @returns {{params: Object, paramMappings: Array}} The assembled params and
 *   the list of path mappings collected while extracting them.
 * @throws {Error} When `weightMap['fc']` does not pass `utils_1.isTensor2D`.
 */
function extractParamsFromWeigthMap(weightMap) {
    var paramMappings = [];
    var extractors = extractorsFactory(weightMap, paramMappings);
    var readConvLayer = extractors.extractConvLayerParams;
    var readResidualLayer = extractors.extractResidualLayerParams;

    // The list order is the extraction order, and therefore also the order in
    // which the extractors get to append entries to paramMappings.
    var residualLayerNames = [
        'conv32_1', 'conv32_2', 'conv32_3',
        'conv64_down', 'conv64_1', 'conv64_2', 'conv64_3',
        'conv128_down', 'conv128_1', 'conv128_2',
        'conv256_down', 'conv256_1', 'conv256_2',
        'conv256_down_out'
    ];

    var params = {};
    params.conv32_down = readConvLayer('conv32_down');
    for (var i = 0; i < residualLayerNames.length; i += 1) {
        var layerName = residualLayerNames[i];
        params[layerName] = readResidualLayer(layerName);
    }

    // The fc mapping is recorded before the entry is validated.
    var fc = weightMap['fc'];
    paramMappings.push({ originalPath: 'fc', paramPath: 'fc' });
    if (utils_1.isTensor2D(fc)) {
        params.fc = fc;
        common_1.disposeUnusedWeightTensors(weightMap, paramMappings);
        return { params: params, paramMappings: paramMappings };
    }
    throw new Error("expected weightMap[fc] to be a Tensor2D, instead have " + fc);
}
exports.extractParamsFromWeigthMap = extractParamsFromWeigthMap;
//# sourceMappingURL=extractParamsFromWeigthMap.js.map