summaryrefslogtreecommitdiff
path: root/thermal/virtualtemp_estimator/virtualtemp_estimator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'thermal/virtualtemp_estimator/virtualtemp_estimator.cpp')
-rw-r--r--thermal/virtualtemp_estimator/virtualtemp_estimator.cpp158
1 files changed, 158 insertions, 0 deletions
diff --git a/thermal/virtualtemp_estimator/virtualtemp_estimator.cpp b/thermal/virtualtemp_estimator/virtualtemp_estimator.cpp
new file mode 100644
index 00000000..2dc2185c
--- /dev/null
+++ b/thermal/virtualtemp_estimator/virtualtemp_estimator.cpp
@@ -0,0 +1,158 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "virtualtemp_estimator.h"
+
+#include <android-base/logging.h>
+#include <dlfcn.h>
+
+#include <vector>
+
+namespace thermal {
+namespace vtestimator {
+
+void VirtualTempEstimator::LoadTFLiteWrapper() {
+ if (!data_) {
+ LOG(ERROR) << "data_ is nullptr during LoadTFLiteWrapper";
+ return;
+ }
+
+ std::unique_lock<std::mutex> lock(data_->tflite_methods.mutex_);
+
+ void *mLibHandle = dlopen("/vendor/lib64/libthermal_tflite_wrapper.so", 0);
+ if (mLibHandle == nullptr) {
+ LOG(ERROR) << "Could not load libthermal_tflite_wrapper library with error: " << dlerror();
+ return;
+ }
+
+ data_->tflite_methods.create =
+ reinterpret_cast<tflitewrapper_create>(dlsym(mLibHandle, "Create"));
+ if (!data_->tflite_methods.create) {
+ LOG(ERROR) << "Could not link and cast tflitewrapper_create with error: " << dlerror();
+ }
+
+ data_->tflite_methods.init = reinterpret_cast<tflitewrapper_init>(dlsym(mLibHandle, "Init"));
+ if (!data_->tflite_methods.init) {
+ LOG(ERROR) << "Could not link and cast tflitewrapper_init with error: " << dlerror();
+ }
+
+ data_->tflite_methods.invoke =
+ reinterpret_cast<tflitewrapper_invoke>(dlsym(mLibHandle, "Invoke"));
+ if (!data_->tflite_methods.invoke) {
+ LOG(ERROR) << "Could not link and cast tflitewrapper_invoke with error: " << dlerror();
+ }
+
+ data_->tflite_methods.destroy =
+ reinterpret_cast<tflitewrapper_destroy>(dlsym(mLibHandle, "Destroy"));
+ if (!data_->tflite_methods.destroy) {
+ LOG(ERROR) << "Could not link and cast tflitewrapper_destroy with error: " << dlerror();
+ }
+}
+
+VirtualTempEstimator::VirtualTempEstimator(size_t num_input_samples) {
+ data_ = std::make_unique<VirtualTempEstimatorTFLiteData>(num_input_samples);
+ LoadTFLiteWrapper();
+}
+
+VirtualTempEstimator::~VirtualTempEstimator() {
+ LOG(INFO) << "VirtualTempEstimator destructor";
+}
+
+VtEstimatorStatus VirtualTempEstimator::Initialize(const char *model_path) {
+ LOG(INFO) << "Initialize VirtualTempEstimator\n";
+
+ if (!data_) {
+ LOG(ERROR) << "data_ is nullptr during Initialize\n";
+ return kVtEstimatorInitFailed;
+ }
+
+ std::unique_lock<std::mutex> lock(data_->tflite_methods.mutex_);
+
+ if (!model_path) {
+ LOG(ERROR) << "Invalid model_path:" << model_path;
+ return kVtEstimatorInvalidArgs;
+ }
+
+ if (!data_->input_buffer || !data_->input_buffer_size) {
+ LOG(ERROR) << "Invalid data_ members " << model_path
+ << " input_buffer: " << data_->input_buffer
+ << " input_buffer_size: " << data_->input_buffer_size;
+ return kVtEstimatorInitFailed;
+ }
+
+ if (!data_->tflite_methods.create || !data_->tflite_methods.init ||
+ !data_->tflite_methods.invoke || !data_->tflite_methods.destroy) {
+ LOG(ERROR) << "Invalid tflite methods";
+ return kVtEstimatorInitFailed;
+ }
+
+ data_->tflite_wrapper = data_->tflite_methods.create(kNumInputTensors, kNumOutputTensors);
+ if (!data_->tflite_wrapper) {
+ LOG(ERROR) << "Failed to create tflite wrapper";
+ return kVtEstimatorInitFailed;
+ }
+
+ int ret = data_->tflite_methods.init(data_->tflite_wrapper, model_path);
+ if (ret) {
+ LOG(ERROR) << "Failed to Init tflite_wrapper for " << model_path << " (ret: )" << ret
+ << ")";
+ return kVtEstimatorInitFailed;
+ }
+
+ data_->is_initialized = true;
+ data_->model_path = model_path;
+
+ LOG(INFO) << "Successfully initialized VirtualTempEstimator for " << model_path;
+ return kVtEstimatorOk;
+}
+
+VtEstimatorStatus VirtualTempEstimator::Estimate(const std::vector<float> &thermistors,
+ float *output) {
+ if (!data_) {
+ LOG(ERROR) << "data_ is nullptr during Estimate\n";
+ return kVtEstimatorInitFailed;
+ }
+
+ std::unique_lock<std::mutex> lock(data_->tflite_methods.mutex_);
+
+ if (!data_->is_initialized) {
+ LOG(ERROR) << "data_ not initialized for " << data_->model_path;
+ return kVtEstimatorInitFailed;
+ }
+
+ if ((thermistors.size() != data_->input_buffer_size) || (!output)) {
+ LOG(ERROR) << "Invalid args for " << data_->model_path
+ << " thermistors.size(): " << thermistors.size()
+ << " input_buffer_size: " << data_->input_buffer_size << " output: " << output;
+ return kVtEstimatorInvalidArgs;
+ }
+
+ // copy input data into input tensors
+ for (size_t i = 0; i < data_->input_buffer_size; ++i) {
+ data_->input_buffer[i] = thermistors[i];
+ }
+
+ int ret = data_->tflite_methods.invoke(data_->tflite_wrapper, data_->input_buffer,
+ data_->input_buffer_size, output, 1);
+ if (ret) {
+ LOG(ERROR) << "Failed to Invoke for " << data_->model_path << " (ret: " << ret << ")";
+ return kVtEstimatorInvokeFailed;
+ }
+
+ return kVtEstimatorOk;
+}
+
+} // namespace vtestimator
+} // namespace thermal