gx
chenyc
2025-02-12 ea42ff3ebee1eeb3fb29423aa848a249441db81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"use strict";
/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
Object.defineProperty(exports, "__esModule", { value: true });
exports.getCustomConverterOpsModule = exports.getCustomModuleString = void 0;
var util_1 = require("./util");
function getCustomModuleString(config, moduleProvider) {
    var kernels = config.kernels, backends = config.backends, forwardModeOnly = config.forwardModeOnly, models = config.models;
    var tfjs = [(0, util_1.getPreamble)()];
    // A custom tfjs module
    addLine(tfjs, moduleProvider.importCoreStr(forwardModeOnly));
    if (models.length > 0) {
        // A model.json has been passed.
        addLine(tfjs, moduleProvider.importConverterStr());
    }
    for (var _i = 0, backends_1 = backends; _i < backends_1.length; _i++) {
        var backend = backends_1[_i];
        addLine(tfjs, "\n//backend = ".concat(backend));
        addLine(tfjs, moduleProvider.importBackendStr(backend));
        for (var _a = 0, kernels_1 = kernels; _a < kernels_1.length; _a++) {
            var kernelName = kernels_1[_a];
            var kernelImport = moduleProvider.importKernelStr(kernelName, backend);
            if (!moduleProvider.validateImportPath(kernelImport.importPath)) {
                console.warn('WARNING:', "Import path '".concat(kernelImport.importPath, "' cannot be resolved. Skipping..."));
                continue;
            }
            addLine(tfjs, kernelImport.importStatement);
            addLine(tfjs, registerKernelStr(kernelImport.kernelConfigId));
        }
    }
    if (!forwardModeOnly) {
        addLine(tfjs, "\n//Gradients");
        for (var _b = 0, kernels_2 = kernels; _b < kernels_2.length; _b++) {
            var kernelName = kernels_2[_b];
            var gradImport = moduleProvider.importGradientConfigStr(kernelName);
            if (!moduleProvider.validateImportPath(gradImport.importPath)) {
                console.warn('WARNING:', "Import path '".concat(gradImport.importPath, "' cannot be resolved. Skipping..."));
                continue;
            }
            addLine(tfjs, gradImport.importStatement);
            addLine(tfjs, registerGradientConfigStr(gradImport.gradConfigId));
        }
    }
    // A custom tfjs core module for imports within tfjs packages
    var core = [(0, util_1.getPreamble)()];
    addLine(core, moduleProvider.importCoreStr(forwardModeOnly));
    return {
        core: core.join('\n'),
        tfjs: tfjs.join('\n'),
    };
}
exports.getCustomModuleString = getCustomModuleString;
function getCustomConverterOpsModule(ops, moduleProvider) {
    var result = ['// This file is autogenerated\n'];
    // Separate namespaced apis from non namespaced ones as they require a
    // different export pattern that treats each namespace as a whole.
    var flatOps = [];
    var namespacedOps = {};
    for (var _i = 0, ops_1 = ops; _i < ops_1.length; _i++) {
        var opSymbol = ops_1[_i];
        if (opSymbol.match(/\./)) {
            var parts = opSymbol.split(/\./);
            var namespace = parts[0];
            var opName = parts[1];
            if (namespacedOps[namespace] == null) {
                namespacedOps[namespace] = [];
            }
            namespacedOps[namespace].push(opName);
        }
        else {
            flatOps.push(opSymbol);
        }
    }
    // Group the namespaced symbols by namespace
    for (var _a = 0, _b = Object.keys(namespacedOps); _a < _b.length; _a++) {
        var namespace = _b[_a];
        var opSymbols = namespacedOps[namespace];
        result.push(moduleProvider.importNamespacedOpsForConverterStr(namespace, opSymbols));
    }
    for (var _c = 0, flatOps_1 = flatOps; _c < flatOps_1.length; _c++) {
        var opSymbol = flatOps_1[_c];
        result.push(moduleProvider.importOpForConverterStr(opSymbol));
    }
    return result.join('\n');
}
exports.getCustomConverterOpsModule = getCustomConverterOpsModule;
function addLine(target, line) {
    target.push(line);
}
function registerKernelStr(kernelConfigId) {
    return "registerKernel(".concat(kernelConfigId, ");");
}
function registerGradientConfigStr(gradConfigId) {
    return "registerGradient(".concat(gradConfigId, ");");
}
//# sourceMappingURL=custom_module.js.map