/** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ #ifndef TF_NODEJS_TFJS_BACKEND_H_ #define TF_NODEJS_TFJS_BACKEND_H_ #include #include #include #include #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" namespace tfnodejs { class TFJSBackend { public: ~TFJSBackend(); // Creates, initializes, and returns a TFJSBackend instance. If initialization // fails, a nullptr is returned. static TFJSBackend *Create(napi_env env); // Creates a new Tensor with given shape and data and returns an ID that // refernces the new Tensor. // - shape_value (number[]) // - dtype_value (number) // - array_value (TypedArray|Array) napi_value CreateTensor(napi_env env, napi_value shape_value, napi_value dtype_value, napi_value array_value); // Deletes a created Tensor. // - tensor_id_value (number) void DeleteTensor(napi_env env, napi_value tensor_id_value); // Returns a typed-array as a `napi_value` with the data associated with the // TF/TFE pointers. // - tensor_id_value (number) napi_value GetTensorData(napi_env env, napi_value tensor_id_value); // Executes a TFE Op and returns an array of objects containing tensor // attributes (id, dtype, shape). // - op_name_value (string) // - op_attr_inputs (array of TFE Op attributes) // - input_tensor_ids (array of input tensor IDs) // - num_output_values (number) napi_value ExecuteOp(napi_env env, napi_value op_name_value, napi_value op_attr_inputs, napi_value input_tensor_ids, napi_value num_output_values); // Load a SavedModel from a path: // - export_dir (string) // - tags_value (string) napi_value LoadSavedModel(napi_env env, napi_value export_dir, napi_value tags_value); // Delete the SavedModel corresponding TF_Session and TF_Graph // - saved_model_id (number) void DeleteSavedModel(napi_env env, napi_value saved_model_id); // Execute a session from SavedModel with the provided inputs: // - saved_model_id (number) // - input_tensor_ids (array of input tensor IDs) // - input_op_names (array of input op names) // - output_op_names (array of output op names) napi_value RunSavedModel(napi_env env, napi_value saved_model_id, napi_value input_tensor_ids, napi_value input_op_names, napi_value output_op_names); // Get number of loaded SavedModel in the backend: napi_value GetNumOfSavedModels(napi_env env); // Get number of tensor handles in the backend: napi_value GetNumOfTensors(napi_env env); private: TFJSBackend(napi_env env); int32_t InsertHandle(TFE_TensorHandle *tfe_handle); int32_t InsertSavedModel(TF_Session *tf_session, TF_Graph *tf_graph); napi_value GenerateOutputTensorInfo(napi_env env, TFE_TensorHandle *handle); TFE_Context *tfe_context_; std::unordered_map tfe_handle_map_; std::unordered_map> tf_savedmodel_map_; int32_t next_tensor_id_; int32_t next_savedmodel_id_; std::string device_name; public: bool is_gpu_device; }; } // namespace tfnodejs #endif // TF_NODEJS_TFJS_BACKEND_H_