Skip to content

Commit eac4c83

Browse files
author
Gin
committed
Working on YOLOv5 inference
1 parent acb51ec commit eac4c83

File tree

11 files changed

+201
-49
lines changed

11 files changed

+201
-49
lines changed

SerialPrograms/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,8 @@ file(GLOB MAIN_SOURCES
913913
Source/ML/DataLabeling/ML_SegmentAnythingModel.cpp
914914
Source/ML/DataLabeling/ML_SegmentAnythingModel.h
915915
Source/ML/DataLabeling/ML_SegmentAnythingModelConstants.h
916+
Source/ML/Inference/ML_YOLOv5Detector.cpp
917+
Source/ML/Inference/ML_YOLOv5Detector.h
916918
Source/ML/ML_Panels.cpp
917919
Source/ML/ML_Panels.h
918920
Source/ML/Models/ML_ONNXRuntimeHelpers.cpp
@@ -1845,7 +1847,7 @@ file(GLOB MAIN_SOURCES
18451847
Source/PokemonSV/Programs/Farming/PokemonSV_BlueberryQuests.cpp
18461848
Source/PokemonSV/Programs/Farming/PokemonSV_BlueberryQuests.h
18471849
Source/PokemonSV/Programs/Farming/PokemonSV_ClaimMysteryGift.cpp
1848-
Source/PokemonSV/Programs/Farming/PokemonSV_ClaimMysteryGift.h
1850+
Source/PokemonSV/Programs/Farming/PokemonSV_ClaimMysteryGift.h
18491851
Source/PokemonSV/Programs/Farming/PokemonSV_ESPTraining.cpp
18501852
Source/PokemonSV/Programs/Farming/PokemonSV_ESPTraining.h
18511853
Source/PokemonSV/Programs/Farming/PokemonSV_FlyingTrialFarmer.cpp
@@ -2572,7 +2574,7 @@ else() # macOS and Linux
25722574
target_include_directories(SerialPrograms PRIVATE ${HOMEBREW_PREFIX}/include/onnxruntime)
25732575
target_link_libraries(SerialPrograms PRIVATE ${HOMEBREW_PREFIX}/lib/libonnxruntime.dylib)
25742576
endif()
2575-
2577+
25762578
else() # Linux
25772579
# Instructions took from https://github.com/microsoft/onnxruntime/discussions/6489
25782580
# Step 1: Download the file

SerialPrograms/Scripts/CodeTemplates/VisualDetector/GameName_ObjectNameDetector.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
#include "Common/Cpp/Containers/FixedLimitVector.h"
1313
#include "CommonFramework/ImageTools/ImageBoxes.h"
1414
#include "CommonFramework/VideoPipeline/VideoOverlayScopes.h"
15-
#include "CommonFramework/InferenceInfra/VisualInferenceCallback.h"
16-
#include "CommonFramework/Inference/VisualDetector.h"
15+
#include "CommonTools/InferenceCallbacks/VisualInferenceCallback.h"
16+
#include "CommonTools/VisualDetector.h"
1717

1818
namespace PokemonAutomation{
1919

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import onnx
2+
import hashlib
3+
import sys
4+
5+
# You can use any other hash algorithms to ensure the model and its hash-value is a one-one mapping.
6+
def hash_file(file_path:str, algorithm:str='sha256', chunk_size:int=8192):
7+
hash_func = hashlib.new(algorithm)
8+
with open(file_path, 'rb') as file:
9+
while chunk := file.read(chunk_size):
10+
hash_func.update(chunk)
11+
return hash_func.hexdigest()
12+
13+
CACHE_KEY_NAME = "CACHE_KEY"
14+
15+
16+
if len(sys.argv) == 1:
17+
print(f"Usage: {sys.argv[0]} <model_path>")
18+
exit(0)
19+
20+
model_path = sys.argv[1]
21+
print(f"Adding a cache value to the metadata_props of the model: {model_path}")
22+
23+
if not model_path.endswith(".onnx"):
24+
print(f"Error: model path is not an onnx file")
25+
26+
m = onnx.load(model_path)
27+
28+
cache_key = m.metadata_props.add()
29+
cache_key.key = CACHE_KEY_NAME
30+
cache_key.value = str(hash_file(model_path))
31+
32+
print(f"Added key {CACHE_KEY_NAME} and value {cache_key.value}")
33+
34+
onnx.save_model(m, model_path)
35+
36+
print(f"Saved model to the same path, {model_path}")

SerialPrograms/Source/ML/DataLabeling/ML_SegmentAnythingModel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace ML{
2424

2525

2626
SAMEmbedderSession::SAMEmbedderSession(const std::string& model_path)
27-
: session_options(create_session_option())
27+
: session_options(create_session_option("SAMEmbedder"))
2828
, session{env, str_to_onnx_str(model_path).c_str(), session_options}
2929
, memory_info{Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)}
3030
, input_names{session.GetInputNames()}
@@ -64,7 +64,7 @@ void SAMEmbedderSession::run(cv::Mat& input_image, std::vector<float>& model_out
6464

6565

6666
SAMSession::SAMSession(const std::string& model_path)
67-
: session_options(create_session_option())
67+
: session_options(create_session_option("SAM"))
6868
, session{env, str_to_onnx_str(model_path).c_str(), session_options}
6969
, memory_info{Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)}
7070
, input_names{session.GetInputNames()}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/* YOLOv5 Detector
2+
*
3+
* From: https://github.com/PokemonAutomation/Arduino-Source
4+
*
5+
*/
6+
7+
#include <filesystem>
8+
#include <iostream>
9+
#include <QMessageBox>
10+
#include <opencv2/imgproc.hpp>
11+
#include <opencv2/imgcodecs.hpp>
12+
#include "CommonFramework/ImageTypes/ImageViewRGB32.h"
13+
#include "CommonFramework/VideoPipeline/VideoOverlay.h"
14+
#include "CommonFramework/VideoPipeline/VideoOverlayScopes.h"
15+
#include "CommonFramework/Globals.h"
16+
#include "ML_YOLOv5Detector.h"
17+
18+
//#include <iostream>
19+
//using std::cout;
20+
//using std::endl;
21+
22+
namespace PokemonAutomation{
23+
namespace ML{
24+
25+
26+
YOLOv5Detector::~YOLOv5Detector() = default;
27+
28+
YOLOv5Detector::YOLOv5Detector()
29+
{
30+
const std::string sam_model_path = RESOURCE_PATH() + "ML/yolov5.onnx";
31+
std::vector<std::string> labels = {"Bidoof"};
32+
if (std::filesystem::exists(sam_model_path)){
33+
m_yolo_session = std::make_unique<YOLOv5Session>(sam_model_path, std::move(labels));
34+
} else{
35+
std::cerr << "Error: no such YOLOv5 model path " << sam_model_path << "." << std::endl;
36+
QMessageBox box;
37+
box.critical(nullptr, "YOLOv5 Model Does Not Exist",
38+
QString::fromStdString("YOLOv5 model path" + sam_model_path + " does not exist."));
39+
}
40+
}
41+
42+
bool YOLOv5Detector::detect(const ImageViewRGB32& screen){
43+
if (!m_yolo_session){
44+
return false;
45+
}
46+
47+
cv::Mat frame_mat_bgra = screen.to_opencv_Mat();
48+
cv::Mat frame_mat_rgb;
49+
cv::cvtColor(frame_mat_bgra, frame_mat_rgb, cv::COLOR_BGRA2RGB);
50+
51+
m_output_boxes.clear();
52+
m_yolo_session->run(frame_mat_rgb, m_output_boxes);
53+
54+
return m_output_boxes.size() > 0;
55+
}
56+
57+
58+
YOLOv5Watcher::YOLOv5Watcher(VideoOverlay& overlay)
59+
: VisualInferenceCallback("YOLOv5")
60+
, m_overlay_set(overlay)
61+
{
62+
}
63+
64+
bool YOLOv5Watcher::process_frame(const ImageViewRGB32& frame, WallClock timestamp){
65+
if (!m_detector.session()){
66+
return false;
67+
}
68+
69+
m_detector.detect(frame);
70+
71+
m_overlay_set.clear();
72+
for(const auto& box : m_detector.detected_boxes()){
73+
m_overlay_set.add(COLOR_RED, box.box, m_detector.session()->label_name(box.label_idx));
74+
}
75+
return false;
76+
}
77+
78+
79+
80+
81+
82+
}
83+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/* YOLOv5 Detector
2+
*
3+
* From: https://github.com/PokemonAutomation/Arduino-Source
4+
*
5+
*/
6+
7+
#ifndef PokemonAutomation_ML_YOLOv5Detector_H
8+
#define PokemonAutomation_ML_YOLOv5Detector_H
9+
10+
#include <vector>
11+
#include "Common/Cpp/Color.h"
12+
#include "CommonFramework/VideoPipeline/VideoOverlayScopes.h"
13+
#include "CommonTools/InferenceCallbacks/VisualInferenceCallback.h"
14+
#include "CommonTools/VisualDetector.h"
15+
#include "ML/Models/ML_YOLOv5Model.h"
16+
17+
namespace PokemonAutomation{
18+
19+
class VideoOverlay;
20+
21+
namespace ML{
22+
23+
24+
class YOLOv5Detector : public StaticScreenDetector{
25+
public:
26+
YOLOv5Detector();
27+
virtual ~YOLOv5Detector();
28+
29+
virtual void make_overlays(VideoOverlaySet& items) const override {}
30+
virtual bool detect(const ImageViewRGB32& screen) override;
31+
32+
const std::vector<YOLOv5Session::DetectionBox>& detected_boxes() const { return m_output_boxes; }
33+
34+
const std::unique_ptr<YOLOv5Session>& session() const { return m_yolo_session; }
35+
36+
protected:
37+
std::unique_ptr<YOLOv5Session> m_yolo_session;
38+
std::vector<YOLOv5Session::DetectionBox> m_output_boxes;
39+
};
40+
41+
42+
43+
class YOLOv5Watcher : public VisualInferenceCallback{
44+
public:
45+
YOLOv5Watcher(VideoOverlay& overlay);
46+
virtual ~YOLOv5Watcher() {}
47+
48+
virtual void make_overlays(VideoOverlaySet& items) const override {}
49+
virtual bool process_frame(const ImageViewRGB32& frame, WallClock timestamp) override;
50+
51+
52+
protected:
53+
VideoOverlaySet m_overlay_set;
54+
YOLOv5Detector m_detector;
55+
};
56+
57+
58+
59+
}
60+
}
61+
#endif

SerialPrograms/Source/ML/Models/ML_ONNXRuntimeHelpers.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
namespace PokemonAutomation{
1313
namespace ML{
1414

15-
Ort::SessionOptions create_session_option(){
15+
Ort::SessionOptions create_session_option(const std::string& cache_folder_name){
1616
Ort::SessionOptions so;
1717

1818
#if __APPLE__
@@ -21,7 +21,8 @@ Ort::SessionOptions create_session_option(){
2121
// See for provider options: https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html
2222
// "NeuralNetwork" is a faster ModelFormat than "MLProgram".
2323
provider_options["ModelFormat"] = std::string("NeuralNetwork");
24-
provider_options["ModelCacheDirectory"] = "./ModelCache/";
24+
// TODO: need to make sure the cache works
25+
provider_options["ModelCacheDirectory"] = "./ModelCache/" + cache_folder_name;
2526
// provider_options["MLComputeUnits"] = "ALL";
2627
// provider_options["RequireStaticInputShapes"] = "0";
2728
// provider_options["EnableOnSubgraphs"] = "0";

SerialPrograms/Source/ML/Models/ML_ONNXRuntimeHelpers.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ namespace ML{
2121
// If on macOS, will use CoreML as the backend.
2222
// Otherwise, use CPU to run the model.
2323
// TODO: add Cuda backend for Windows machine.
24-
Ort::SessionOptions create_session_option();
24+
// cache_folder_name: the folder name in under ./ModelCache/ to store model caches. This name is better
25+
// to be unique for each model for easier file management.
26+
Ort::SessionOptions create_session_option(const std::string& cache_folder_name);
2527

2628
// Handy function to create an ONNX Runtime tensor view class from a vector-like `buffer` object holding
2729
// the tensor data and an array-like `shape` object that represents the dimension of the tensor.

SerialPrograms/Source/ML/Models/ML_YOLOv5Model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ std::tuple<int, int, double, double> resize_image_with_border(
5858

5959
YOLOv5Session::YOLOv5Session(const std::string& model_path, std::vector<std::string> label_names)
6060
: m_label_names(std::move(label_names))
61-
, m_session_options(create_session_option())
61+
, m_session_options(create_session_option("YOLOv5"))
6262
, m_session{m_env, str_to_onnx_str(model_path).c_str(), m_session_options}
6363
, m_memory_info{Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)}
6464
, m_input_names{m_session.GetInputNames()}

SerialPrograms/Source/ML/Programs/ML_RunYOLO.cpp

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
#include <opencv2/imgcodecs.hpp>
1313
#include "Common/Cpp/PrettyPrint.h"
1414
#include "CommonFramework/Globals.h"
15+
#include "CommonTools/Async/InferenceRoutines.h"
1516
#include "CommonFramework/VideoPipeline/VideoOverlayScopes.h"
1617
#include "CommonFramework/VideoPipeline/VideoFeed.h"
18+
#include "ML/Inference/ML_YOLOv5Detector.h"
1719
#include "ML_RunYOLO.h"
1820

1921
namespace PokemonAutomation{
@@ -34,43 +36,13 @@ RunYOLO_Descriptor::RunYOLO_Descriptor()
3436

3537

3638

37-
RunYOLO::RunYOLO()
38-
{
39-
const std::string sam_model_path = RESOURCE_PATH() + "ML/yolov5_cpu.onnx";
40-
std::vector<std::string> labels = {"Bidoof"};
41-
if (std::filesystem::exists(sam_model_path)){
42-
m_yolo_session = std::make_unique<YOLOv5Session>(sam_model_path, labels);
43-
} else{
44-
std::cerr << "Error: no such YOLOv5 model path " << sam_model_path << "." << std::endl;
45-
QMessageBox box;
46-
box.critical(nullptr, "YOLOv5 Model Does Not Exist",
47-
QString::fromStdString("YOLOv5 model path" + sam_model_path + " does not exist."));
48-
}
49-
50-
}
39+
RunYOLO::RunYOLO() {}
5140

5241
void RunYOLO::program(NintendoSwitch::SingleSwitchProgramEnvironment& env, NintendoSwitch::ProControllerContext& context){
53-
if (!m_yolo_session){
54-
return;
55-
}
56-
57-
VideoOverlaySet overlay_set(env.console.overlay());
42+
43+
YOLOv5Watcher watcher(env.console.overlay());
5844

59-
std::vector<YOLOv5Session::DetectionBox> output_boxes;
60-
while (true){
61-
VideoSnapshot last = env.console.video().snapshot();
62-
cv::Mat frame_mat_bgra = last.frame->to_opencv_Mat();
63-
cv::Mat frame_mat_rgb;
64-
cv::cvtColor(frame_mat_bgra, frame_mat_rgb, cv::COLOR_BGRA2RGB);
65-
66-
output_boxes.clear();
67-
m_yolo_session->run(frame_mat_rgb, output_boxes);
68-
overlay_set.clear();
69-
for(const auto& box : output_boxes){
70-
overlay_set.add(COLOR_RED, box.box, m_yolo_session->label_name(box.label_idx));
71-
}
72-
// context.wait_until(last.timestamp + std::chrono::milliseconds(PERIOD_MILLISECONDS));
73-
}
45+
wait_until(env.console, context, WallClock::max(), {watcher});
7446
}
7547

7648

0 commit comments

Comments
 (0)