1 Star 29 Fork 11

Quard/QFacefusion

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
onnxbase.cpp 6.70 KB
一键复制 编辑 原始数据 按行查看 历史
#include "onnxbase.h"
#include <atomic>
using namespace Ort;
void onnxBaseLogger(OnnxBase* instance, OrtLoggingLevel severity,
const char* category, const char* logid, const char* code_location, const char* message) {
auto logLevel = instance->ortLogLevel2OnnxBaseLogLevel(severity);
instance->logger(logLevel, category, logid, code_location, message);
}
// Builds the ONNX Runtime environment and session for `model_path` and caches
// the model's input/output tensor names and shapes for later inference calls.
// All ORT log output (both environment- and session-level) is funneled through
// the supplied LoggerCallback via the onnxBaseLogger bridge.
// @param model_path filesystem path to the .onnx model file
// @param logger     callback receiving (level, category, logid, location, message)
// @param user_data  opaque pointer stored for use by the logger callback
OnnxBase::OnnxBase(std::string model_path, LoggerCallback logger, void* user_data) {
m_LoggerCallback = logger;
m_UserData = user_data;
// `this` is handed to ORT as the opaque logging parameter so the C-style
// bridge can reach the instance. NOTE(review): the (OrtLoggingFunction) cast
// changes the first parameter type from OnnxBase* to void* — this works on
// common ABIs but is formally undefined behavior; consider having
// onnxBaseLogger take void* and cast internally.
env = Ort::Env(ORT_LOGGING_LEVEL_VERBOSE, model_path.c_str(), (OrtLoggingFunction)onnxBaseLogger, this);
Ort::GetApi().SetUserLoggingFunction(sessionOptions, (OrtLoggingFunction)onnxBaseLogger, this);
#if defined(CUDA_FACEFUSION_BUILD)
try {
OrtCUDAProviderOptions cuda_options;
cuda_options.device_id = 0; // GPU device ID
cuda_options.arena_extend_strategy = 0; // default memory-arena growth strategy
cuda_options.gpu_mem_limit = SIZE_MAX; // effectively no GPU memory cap
cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::EXHAUSTIVE; // search for the fastest conv algorithm
cuda_options.do_copy_in_default_stream = 1; // copy tensors on the default CUDA stream
sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
} catch (const Ort::Exception& e) {
// CUDA EP is best-effort: on failure ORT falls back to the remaining providers.
std::cerr << "Error appending CUDA execution provider: " << e.what() << std::endl;
}
#endif
#if defined(COREML_FACEFUSION_BUILD)
OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions,COREML_FLAG_ENABLE_ON_SUBGRAPH);
#endif
sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
#if defined(WINDOWS_FACEFUSION_BUILD)
// ORT on Windows takes a wide-char model path. NOTE(review): widening via
// begin()/end() is only correct for ASCII paths; non-ASCII paths get mangled.
std::wstring widestr = std::wstring(model_path.begin(), model_path.end());
ort_session = new Session(env, widestr.c_str(), sessionOptions);
#endif
#if defined(LINUX_FACEFUSION_BUILD) || defined(MACOS_FACEFUSION_BUILD)
ort_session = new Session(env, model_path.c_str(), sessionOptions);
#endif
size_t numInputNodes = ort_session->GetInputCount();
size_t numOutputNodes = ort_session->GetOutputCount();
AllocatorWithDefaultOptions allocator;
// Cache input names and shapes. The allocated name strings are owned by
// input_names_ptrs; input_names holds raw (non-owning) views into them, which
// stay valid across vector reallocation because moving the smart pointer
// does not move the underlying character buffer.
for (size_t i = 0; i < numInputNodes; i++)
{
input_names_ptrs.push_back(ort_session->GetInputNameAllocated(i, allocator));
input_names.push_back(input_names_ptrs.back().get());
Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
auto input_dims = input_tensor_info.GetShape();
input_node_dims.push_back(input_dims);
}
// Same caching scheme for the model outputs.
for (size_t i = 0; i < numOutputNodes; i++)
{
output_names_ptrs.push_back(ort_session->GetOutputNameAllocated(i, allocator));
output_names.push_back(output_names_ptrs.back().get());
Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
auto output_dims = output_tensor_info.GetShape();
output_node_dims.push_back(output_dims);
}
}
// Releases the ONNX Runtime session allocated with `new` in the constructor.
// NOTE(review): ort_session is a raw owning pointer; if the header allows,
// std::unique_ptr<Ort::Session> would make this destructor unnecessary.
OnnxBase::~OnnxBase() {
delete ort_session;
}
// Runs Session::Run while optionally reporting *estimated* progress.
// A monitor thread extrapolates progress from the rolling average of the last
// up-to-10 run durations, capping the displayed value at 95% until the real
// Run() returns; afterwards the measured duration is folded back into the
// rolling average. The monitor thread is always joined before this function
// returns or rethrows, so it never outlives the call.
// NOTE(review): this is a const member that assigns m_execution_times and
// m_average_time — these are presumably declared mutable in the header;
// confirm. The progress callback is invoked from both the monitor thread and
// the calling thread, so it must be thread-safe.
// @param runOptions    ORT run options forwarded to Session::Run
// @param input_names   array of model input names (size input_count)
// @param input         array of input tensors (size input_count)
// @param output_names  array of requested output names (size output_count)
// @return the output tensors produced by Session::Run
// @throws rethrows whatever Session::Run throws
std::vector<Ort::Value> OnnxBase::runInferenceWithProgress(Ort::RunOptions& runOptions,
const char* const* input_names, const Ort::Value* input, size_t input_count,
const char* const* output_names, size_t output_count) const {
auto start_time = std::chrono::steady_clock::now();
// If a progress callback is registered, start the progress-monitor thread.
std::thread progress_thread;
std::atomic<bool> inference_finished{false};
if (m_ProgressCallback && m_average_time > 0.1) { // only monitor runs expected to take > 0.1 s
progress_thread = std::thread([this, start_time, &inference_finished]() {
while (!inference_finished) {
auto current_time = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration<double>(current_time - start_time).count();
float progress = std::min(0.95f, static_cast<float>(elapsed / m_average_time)); // display at most 95%
std::ostringstream msg;
msg << "推理进行中... " << std::fixed << std::setprecision(1)
<< (progress * 100.0f) << "% (预计剩余: "
<< std::max(0.0, m_average_time - elapsed) << "s)";
m_ProgressCallback(progress, msg.str(), m_ProgressUserData);
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // update progress every 100 ms
}
});
}
// Execute the inference.
std::vector<Ort::Value> result;
try {
if (m_ProgressCallback) {
m_ProgressCallback(0.0f, "开始推理...", m_ProgressUserData);
}
result = this->ort_session->Run(runOptions, input_names, input, input_count, output_names, output_count);
// Inference done: signal and join the monitor thread before reporting 100%.
inference_finished = true;
if (progress_thread.joinable()) {
progress_thread.join();
}
if (m_ProgressCallback) {
m_ProgressCallback(1.0f, "推理完成", m_ProgressUserData);
}
} catch (...) {
// Also stop/join the monitor thread on failure so it never dangles.
inference_finished = true;
if (progress_thread.joinable()) {
progress_thread.join();
}
throw; // rethrow the original exception
}
// Record the actual duration and update the history.
auto end_time = std::chrono::steady_clock::now();
double actual_duration = std::chrono::duration<double>(end_time - start_time).count();
// Append to the execution-time history.
m_execution_times.push_back(actual_duration);
// Keep only the 10 most recent durations.
if (m_execution_times.size() > 10) {
m_execution_times.erase(m_execution_times.begin());
}
// Recompute the rolling average used for the next run's estimate.
if (!m_execution_times.empty()) {
double sum = 0.0;
for (double time : m_execution_times) {
sum += time;
}
m_average_time = sum / m_execution_times.size();
}
return result;
}
// Maps an ONNX Runtime logging severity onto the library's own
// OnnxBaseLogLevel enumeration. Any severity not covered by the explicit
// cases is reported as UNKNOWN.
OnnxBase::OnnxBaseLogLevel OnnxBase::ortLogLevel2OnnxBaseLogLevel(OrtLoggingLevel level) {
    switch (level) {
    case ORT_LOGGING_LEVEL_VERBOSE: return VERBOSE;
    case ORT_LOGGING_LEVEL_INFO:    return INFO;
    case ORT_LOGGING_LEVEL_WARNING: return WARNING;
    case ORT_LOGGING_LEVEL_ERROR:   return ERROR;
    case ORT_LOGGING_LEVEL_FATAL:   return FATAL;
    default:                        return UNKNOWN;
    }
}
// Registers (or clears, when `callback` is null) the progress callback used
// by runInferenceWithProgress(), along with the opaque pointer that will be
// passed back to it on every invocation.
void OnnxBase::setProgressCallback(ProgressCallback callback, void* user_data) {
    m_ProgressUserData = user_data;
    m_ProgressCallback = callback;
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
C
1
https://gitee.com/QQxiaoming/QFacefusion.git
git@gitee.com:QQxiaoming/QFacefusion.git
QQxiaoming
QFacefusion
QFacefusion
main

搜索帮助