"use strict";
/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Set the `__esModule` flag on the CommonJS exports object.
Object.defineProperty(exports, "__esModule", { value: true });
// Initialise both exported names to undefined; the functions themselves are
// assigned to `exports` after their definitions further down this file.
exports.getCustomConverterOpsModule = exports.getCustomModuleString = void 0;
// Project-local helpers; this file calls `getPreamble()` from it.
var util_1 = require("./util");
/**
 * Builds the source text of the two generated modules of a custom bundle.
 *
 * The `tfjs` text is assembled, in order, from:
 *   1. the preamble returned by `getPreamble()`;
 *   2. `moduleProvider.importCoreStr(forwardModeOnly)`;
 *   3. `moduleProvider.importConverterStr()`, only when `config.models` is
 *      non-empty;
 *   4. for every backend: a `//backend = <name>` marker, the backend import,
 *      and an import statement plus `registerKernel(...)` line per kernel;
 *   5. unless `forwardModeOnly` is truthy: a `//Gradients` marker and an
 *      import statement plus `registerGradient(...)` line per kernel.
 *
 * The `core` text contains only the preamble and the core import.
 *
 * Kernel and gradient imports whose `importPath` is rejected by
 * `moduleProvider.validateImportPath` are skipped after a console warning.
 *
 * @param {Object} config Bundle configuration. Reads `kernels`, `backends`,
 *     `forwardModeOnly` and `models`; a null/undefined list is treated as
 *     empty.
 * @param {Object} moduleProvider Supplies `importCoreStr`,
 *     `importConverterStr`, `importBackendStr`, `importKernelStr`,
 *     `importGradientConfigStr` and `validateImportPath`.
 * @returns {{core: string, tfjs: string}} Module texts, lines joined by '\n'.
 */
function getCustomModuleString(config, moduleProvider) {
    // Fall back to empty lists so an omitted field does not raise a TypeError
    // on `.length`.
    var kernels = config.kernels || [];
    var backends = config.backends || [];
    var forwardModeOnly = config.forwardModeOnly;
    var models = config.models || [];
    var tfjs = [(0, util_1.getPreamble)()];
    // A custom tfjs module
    addLine(tfjs, moduleProvider.importCoreStr(forwardModeOnly));
    if (models.length > 0) {
        // A model.json has been passed.
        addLine(tfjs, moduleProvider.importConverterStr());
    }
    for (var b = 0; b < backends.length; b++) {
        var backend = backends[b];
        addLine(tfjs, "\n//backend = ".concat(backend));
        addLine(tfjs, moduleProvider.importBackendStr(backend));
        for (var k = 0; k < kernels.length; k++) {
            var kernelImport = moduleProvider.importKernelStr(kernels[k], backend);
            if (!importPathResolves(moduleProvider, kernelImport.importPath)) {
                continue;
            }
            addLine(tfjs, kernelImport.importStatement);
            addLine(tfjs, registerKernelStr(kernelImport.kernelConfigId));
        }
    }
    if (!forwardModeOnly) {
        addLine(tfjs, "\n//Gradients");
        for (var g = 0; g < kernels.length; g++) {
            var gradImport = moduleProvider.importGradientConfigStr(kernels[g]);
            if (!importPathResolves(moduleProvider, gradImport.importPath)) {
                continue;
            }
            addLine(tfjs, gradImport.importStatement);
            addLine(tfjs, registerGradientConfigStr(gradImport.gradConfigId));
        }
    }
    // A custom tfjs core module for imports within tfjs packages
    var core = [(0, util_1.getPreamble)()];
    addLine(core, moduleProvider.importCoreStr(forwardModeOnly));
    return {
        core: core.join('\n'),
        tfjs: tfjs.join('\n'),
    };
}
/**
 * Private helper: reports whether `moduleProvider.validateImportPath` accepts
 * `importPath`, emitting the "cannot be resolved" console warning when it
 * does not.
 */
function importPathResolves(moduleProvider, importPath) {
    if (moduleProvider.validateImportPath(importPath)) {
        return true;
    }
    console.warn('WARNING:', "Import path '".concat(importPath, "' cannot be resolved. Skipping..."));
    return false;
}
// Publish getCustomModuleString on the CommonJS exports object.
exports.getCustomModuleString = getCustomModuleString;
/**
 * Builds the text of the autogenerated module that imports the ops used by
 * the converter.
 *
 * Symbols containing a '.' are treated as `<namespace>.<opName>` and grouped
 * by namespace; only the first two dot-separated segments are used. Each
 * namespace yields one entry from
 * `moduleProvider.importNamespacedOpsForConverterStr(namespace, opNames)`.
 * Every remaining symbol yields one entry from
 * `moduleProvider.importOpForConverterStr(opSymbol)`. Namespaced entries are
 * emitted before flat ones.
 *
 * @param {string[]} ops Op symbols, e.g. `'add'` or `'image.resizeBilinear'`.
 * @param {Object} moduleProvider Supplies
 *     `importNamespacedOpsForConverterStr` and `importOpForConverterStr`.
 * @returns {string} The header comment and generated entries joined by '\n'.
 */
function getCustomConverterOpsModule(ops, moduleProvider) {
    var result = ['// This file is autogenerated\n'];
    // Separate namespaced apis from non namespaced ones as they require a
    // different export pattern that treats each namespace as a whole.
    var flatOps = [];
    // Prototype-less dictionary: with a plain `{}` a namespace named like an
    // Object.prototype member (e.g. "constructor", "toString") would find the
    // inherited value instead of null and then fail on `.push`.
    var namespacedOps = Object.create(null);
    for (var i = 0; i < ops.length; i++) {
        var opSymbol = ops[i];
        if (opSymbol.indexOf('.') === -1) {
            flatOps.push(opSymbol);
            continue;
        }
        var parts = opSymbol.split('.');
        var namespace = parts[0];
        var opName = parts[1];
        if (namespacedOps[namespace] == null) {
            namespacedOps[namespace] = [];
        }
        namespacedOps[namespace].push(opName);
    }
    // Group the namespaced symbols by namespace
    var namespaces = Object.keys(namespacedOps);
    for (var n = 0; n < namespaces.length; n++) {
        var name = namespaces[n];
        result.push(moduleProvider.importNamespacedOpsForConverterStr(name, namespacedOps[name]));
    }
    for (var f = 0; f < flatOps.length; f++) {
        result.push(moduleProvider.importOpForConverterStr(flatOps[f]));
    }
    return result.join('\n');
}
// Publish getCustomConverterOpsModule on the CommonJS exports object.
exports.getCustomConverterOpsModule = getCustomConverterOpsModule;
/**
 * Appends `line` to the end of the `target` array, in place.
 *
 * @param {Array} target Array that collects the generated lines.
 * @param {*} line Value to append.
 */
function addLine(target, line) {
    target.push(line);
}
/**
 * Returns the statement text `registerKernel(<kernelConfigId>);`.
 *
 * @param {string} kernelConfigId Identifier placed between the parentheses.
 * @returns {string} The generated statement text.
 */
function registerKernelStr(kernelConfigId) {
    return "registerKernel(".concat(kernelConfigId, ");");
}
/**
 * Returns the statement text `registerGradient(<gradConfigId>);`.
 *
 * @param {string} gradConfigId Identifier placed between the parentheses.
 * @returns {string} The generated statement text.
 */
function registerGradientConfigStr(gradConfigId) {
    return "registerGradient(".concat(gradConfigId, ");");
}
//# sourceMappingURL=custom_module.js.map