Skip to content

Commit f2aed77

Browse files
committed
actually enable CPU fallback for YoloSession.
1 parent 37bc98a commit f2aed77

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

SerialPrograms/Source/ML/Inference/ML_YOLOv5Detector.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ YOLOv5Detector::YOLOv5Detector(const std::string& model_path)
7878
}
7979
label_file.close();
8080

81-
m_yolo_session = std::make_unique<YOLOv5Session>(m_model_path, std::move(labels));
81+
m_yolo_session = std::make_unique<YOLOv5Session>(m_model_path, std::move(labels), m_use_gpu);
8282
}
8383

8484
bool YOLOv5Detector::detect(const ImageViewRGB32& screen){
@@ -104,7 +104,7 @@ bool YOLOv5Detector::detect(const ImageViewRGB32& screen){
104104
std::cerr << "Warning: YOLO session failed using the GPU. Will reattempt with the CPU.\n" << e.what() << std::endl;
105105
m_use_gpu = false;
106106
std::vector<std::string> labels = m_yolo_session->get_label_names();
107-
m_yolo_session = std::make_unique<YOLOv5Session>(m_model_path, std::move(labels));
107+
m_yolo_session = std::make_unique<YOLOv5Session>(m_model_path, std::move(labels), m_use_gpu);
108108
}else{
109109
std::cerr << "Error: YOLO session failed even when using the CPU.\n" << e.what() << std::endl;
110110
throw InternalProgramError(nullptr, PA_CURRENT_FUNCTION, "Error: YOLO session failed.");

SerialPrograms/Source/ML/Models/ML_YOLOv5Model.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ std::tuple<int, int, double, double> resize_image_with_border(
5858
}
5959

6060

61-
YOLOv5Session::YOLOv5Session(const std::string& model_path, std::vector<std::string> label_names)
61+
YOLOv5Session::YOLOv5Session(const std::string& model_path, std::vector<std::string> label_names, bool use_gpu)
6262
: m_label_names(std::move(label_names))
63-
, m_session_options(create_session_options(ML_MODEL_CACHE_PATH() + "YOLOv5", true))
63+
, m_session_options(create_session_options(ML_MODEL_CACHE_PATH() + "YOLOv5", use_gpu))
6464
, m_session{create_session(m_env, m_session_options, model_path, ML_MODEL_CACHE_PATH() + "YOLOv5")}
6565
, m_memory_info{Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)}
6666
, m_input_names{m_session.GetInputNames()}

SerialPrograms/Source/ML/Models/ML_YOLOv5Model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class YOLOv5Session{
2828
size_t label_idx;
2929
};
3030

31-
YOLOv5Session(const std::string& model_path, std::vector<std::string> label_names);
31+
YOLOv5Session(const std::string& model_path, std::vector<std::string> label_names, bool use_gpu);
3232

3333
void run(const cv::Mat& input_image, std::vector<DetectionBox>& detections);
3434

0 commit comments

Comments
 (0)