Skip to content

Commit f7e1bec

Browse files
author
Gin
committed
Change YOLO detector to use custom model path
1 parent fbeb7f6 commit f7e1bec

File tree

6 files changed

+80
-16
lines changed

6 files changed

+80
-16
lines changed

Common/Cpp/PrettyPrint.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ std::string tostr_padded(size_t digits, uint64_t x);
1717
std::string tostr_u_commas(int64_t x);
1818
std::string tostr_bytes(uint64_t bytes);
1919

20+
// Convert double to string using the default precision on ostream.
2021
std::string tostr_default(double x);
22+
// Convert double to string with fixed precision.
2123
std::string tostr_fixed(double x, int precision);
2224

2325
// Format current time to a string to be used as filenames.

SerialPrograms/Source/ML/Inference/ML_YOLOv5Detector.cpp

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66

77
#include <filesystem>
88
#include <iostream>
9+
#include <fstream>
910
#include <QMessageBox>
1011
#include <opencv2/imgproc.hpp>
1112
#include <opencv2/imgcodecs.hpp>
1213
#include "CommonFramework/ImageTypes/ImageViewRGB32.h"
1314
#include "CommonFramework/VideoPipeline/VideoOverlay.h"
1415
#include "CommonFramework/VideoPipeline/VideoOverlayScopes.h"
1516
#include "CommonFramework/Globals.h"
17+
#include "Common/Cpp/StringTools.h"
18+
#include "Common/Cpp/PrettyPrint.h"
1619
#include "ML_YOLOv5Detector.h"
1720

1821
//#include <iostream>
@@ -25,18 +28,45 @@ namespace ML{
2528

2629
YOLOv5Detector::~YOLOv5Detector() = default;
2730

28-
YOLOv5Detector::YOLOv5Detector()
31+
32+
YOLOv5Detector::YOLOv5Detector(const std::string& model_path)
2933
{
30-
const std::string sam_model_path = RESOURCE_PATH() + "ML/yolov5.onnx";
31-
std::vector<std::string> labels = {"Bidoof"};
32-
if (std::filesystem::exists(sam_model_path)){
33-
m_yolo_session = std::make_unique<YOLOv5Session>(sam_model_path, std::move(labels));
34-
} else{
35-
std::cerr << "Error: no such YOLOv5 model path " << sam_model_path << "." << std::endl;
34+
if (!model_path.ends_with(".onnx")){
35+
std::cerr << "Error: wrong model path extension: " << model_path << ". It must be .onnx" << std::endl;
36+
QMessageBox box;
37+
box.critical(nullptr, "Wrong Model Extension",
38+
QString::fromStdString("YOLOv5 model path must end with .onnx. But got " + model_path + "."));
39+
return;
40+
}
41+
42+
std::string label_file_path = model_path.substr(0, model_path.size() - 5) + "_label.txt";
43+
std::vector<std::string> labels;
44+
if (!std::filesystem::exists(label_file_path)){
45+
std::cerr << "Error: no such YOLOv5 label file path " << label_file_path << "." << std::endl;
3646
QMessageBox box;
37-
box.critical(nullptr, "YOLOv5 Model Does Not Exist",
38-
QString::fromStdString("YOLOv5 model path" + sam_model_path + " does not exist."));
47+
box.critical(nullptr, "YOLOv5 Label File Does Not Exist",
48+
QString::fromStdString("YOLOv5 label file path " + label_file_path + " does not exist."));
49+
return;
3950
}
51+
std::ifstream label_file(label_file_path);
52+
if (!label_file.is_open()){
53+
std::cerr << "Error: failed to open YOLOv5 label file " << label_file_path << "." << std::endl;
54+
QMessageBox box;
55+
box.critical(nullptr, "Cannot Open YOLOv5 Label File",
56+
QString::fromStdString("YOLOv5 label file " + label_file_path + " cannot be opened."));
57+
return;
58+
}
59+
std::string line;
60+
while (std::getline(label_file, line)){
61+
line = StringTools::strip(line);
62+
if (line.empty() || line[0] == '#'){
63+
continue;
64+
}
65+
labels.push_back(line);
66+
}
67+
label_file.close();
68+
69+
m_yolo_session = std::make_unique<YOLOv5Session>(model_path, std::move(labels));
4070
}
4171

4272
bool YOLOv5Detector::detect(const ImageViewRGB32& screen){
@@ -55,22 +85,24 @@ bool YOLOv5Detector::detect(const ImageViewRGB32& screen){
5585
}
5686

5787

58-
YOLOv5Watcher::YOLOv5Watcher(VideoOverlay& overlay)
88+
YOLOv5Watcher::YOLOv5Watcher(VideoOverlay& overlay, const std::string& model_path)
5989
: VisualInferenceCallback("YOLOv5")
6090
, m_overlay_set(overlay)
91+
, m_detector(model_path)
6192
{
6293
}
6394

6495
bool YOLOv5Watcher::process_frame(const ImageViewRGB32& frame, WallClock timestamp){
65-
if (!m_detector.session()){
96+
if (!m_detector.model_loaded()){
6697
return false;
6798
}
6899

69100
m_detector.detect(frame);
70101

71102
m_overlay_set.clear();
72103
for(const auto& box : m_detector.detected_boxes()){
73-
m_overlay_set.add(COLOR_RED, box.box, m_detector.session()->label_name(box.label_idx));
104+
std::string text = m_detector.session()->label_name(box.label_idx) + ": " + tostr_fixed(box.score, 2);
105+
m_overlay_set.add(COLOR_RED, box.box, text);
74106
}
75107
return false;
76108
}

SerialPrograms/Source/ML/Inference/ML_YOLOv5Detector.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,17 @@ namespace ML{
2323

2424
class YOLOv5Detector : public StaticScreenDetector{
2525
public:
26-
YOLOv5Detector();
26+
// - model_path: path to the onnx model file. The label name file should be the same
27+
// file path and basename and with _label.txt suffix.
28+
// e.g. .../yolo.onnx, .../yolo_label.txt
29+
// If model loading fails, no exception is thrown but you can call `model_loaded()` to
30+
// check.
31+
YOLOv5Detector(const std::string& model_path);
2732
virtual ~YOLOv5Detector();
2833

34+
// If it loads the model successfully
35+
bool model_loaded() const { return m_yolo_session != nullptr; }
36+
2937
virtual void make_overlays(VideoOverlaySet& items) const override {}
3038
virtual bool detect(const ImageViewRGB32& screen) override;
3139

@@ -42,7 +50,12 @@ class YOLOv5Detector : public StaticScreenDetector{
4250

4351
class YOLOv5Watcher : public VisualInferenceCallback{
4452
public:
45-
YOLOv5Watcher(VideoOverlay& overlay);
53+
// - model_path: path to the onnx model file. The label name file should be the same
54+
// file path and basename and with _label.txt suffix.
55+
// e.g. .../yolo.onnx, .../yolo_label.txt
56+
// If model loading fails, no exception is thrown but you can call `model_loaded()` to
57+
// check.
58+
YOLOv5Watcher(VideoOverlay& overlay, const std::string& model_path);
4659
virtual ~YOLOv5Watcher() {}
4760

4861
virtual void make_overlays(VideoOverlaySet& items) const override {}

SerialPrograms/Source/ML/Models/ML_YOLOv5Model.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ namespace ML{
1919
class YOLOv5Session{
2020
public:
2121
struct DetectionBox{
22+
// Confidence value the model predicts on the detection. Range: [0.0, 1.0].
23+
// The higher the value, the more confident the model thinks the prediction is.
2224
double score;
25+
// Bounding box of the detected object.
2326
ImageFloatBox box;
27+
// Object label ID.
2428
size_t label_idx;
2529
};
2630

SerialPrograms/Source/ML/Programs/ML_RunYOLO.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,21 @@ RunYOLO_Descriptor::RunYOLO_Descriptor()
3737

3838

3939

40-
RunYOLO::RunYOLO() {}
40+
RunYOLO::RunYOLO()
41+
: MODEL_PATH(
42+
"<b>YOLO Model Path:</b>",
43+
LockMode::UNLOCK_WHILE_RUNNING,
44+
RESOURCE_PATH() + "ML/yolov5.onnx",
45+
"*.onnx",
46+
"Path to YOLO .onnx model file"
47+
)
48+
{
49+
PA_ADD_OPTION(MODEL_PATH);
50+
}
4151

4252
void RunYOLO::program(NintendoSwitch::SingleSwitchProgramEnvironment& env, NintendoSwitch::ProControllerContext& context){
43-
YOLOv5Watcher watcher(env.console.overlay());
53+
std::string model_path = MODEL_PATH;
54+
YOLOv5Watcher watcher(env.console.overlay(), model_path);
4455

4556
wait_until(env.console, context, WallClock::max(), {watcher});
4657
}

SerialPrograms/Source/ML/Programs/ML_RunYOLO.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#ifndef PokemonAutomation_ML_RunYOLO_H
99
#define PokemonAutomation_ML_RunYOLO_H
1010

11+
#include "Common/Cpp/Options/PathOption.h"
1112
#include "NintendoSwitch/NintendoSwitch_SingleSwitchProgram.h"
1213

1314
namespace PokemonAutomation{
@@ -27,6 +28,7 @@ class RunYOLO : public NintendoSwitch::SingleSwitchProgramInstance{
2728
virtual void program(NintendoSwitch::SingleSwitchProgramEnvironment& env, NintendoSwitch::ProControllerContext& context) override;
2829

2930
private:
31+
PathOption MODEL_PATH;
3032
};
3133

3234

0 commit comments

Comments
 (0)