Skip to content

Commit c57b218

Browse files
committed
Gin's Dudunsparce experiment.
1 parent 8fd2a1b commit c57b218

File tree

3 files changed

+224
-0
lines changed

3 files changed

+224
-0
lines changed
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/* Test Dudunsparce Form Detector
2+
*
3+
* From: https://github.com/PokemonAutomation/Arduino-Source
4+
*
5+
*/
6+
7+
#include "3rdParty/ONNX/OnnxToolsPA.h"
8+
#include "Common/Cpp/Time.h"
9+
#include "ClientSource/Connection/BotBase.h"
10+
#include "CommonFramework/ImageTypes/ImageRGB32.h"
11+
#include "CommonFramework/ImageTypes/ImageViewRGB32.h"
12+
#include "CommonFramework/ImageTools/ImageBoxes.h"
13+
#include "CommonTools/Async/InferenceRoutines.h"
14+
#include "CommonTools/Async/InferenceSession.h"
15+
#include "CommonTools/InferenceCallbacks/VisualInferenceCallback.h"
16+
#include "CommonFramework/VideoPipeline/VideoOverlay.h"
17+
#include "TestDudunsparceFormDetector.h"
18+
19+
#include <opencv2/imgproc.hpp>
20+
#include <opencv2/imgcodecs.hpp>
21+
22+
#include <onnxruntime_cxx_api.h>
23+
24+
#include <vector>
25+
#include <iostream>
26+
using std::cout, std::endl;
27+
28+
namespace PokemonAutomation{
29+
namespace NintendoSwitch{
30+
31+
class DudunsparceFormDetector : public VisualInferenceCallback{
32+
public:
33+
DudunsparceFormDetector(VideoOverlay& overlay);
34+
35+
virtual void make_overlays(VideoOverlaySet& items) const override;
36+
virtual bool process_frame(const ImageViewRGB32& frame, WallClock timestamp) override final;
37+
38+
std::string get_label() const {
39+
int detected_label_id = m_detected.load(std::memory_order_acquire);
40+
return labels[detected_label_id];
41+
}
42+
43+
private:
44+
VideoOverlaySet m_overlay_set;
45+
std::string model_path;
46+
ImagePixelBox m_pixel_box_1080p;
47+
ImageFloatBox m_float_box;
48+
std::atomic<int> m_detected;
49+
50+
Ort::Env env;
51+
Ort::RunOptions runOptions;
52+
Ort::Session session;
53+
54+
const char* labels[3]{"none", "three", "two"};
55+
};
56+
57+
58+
DudunsparceFormDetector::DudunsparceFormDetector(VideoOverlay& overlay)
59+
: VisualInferenceCallback("DudunsparceFormDetector")
60+
, m_overlay_set(overlay)
61+
, model_path("../../PAMLExperiments/dudunsparce/dudunsparce_form_detector.onnx")
62+
, m_detected(0)
63+
, session(nullptr)
64+
{
65+
// The input data for the ML model is created by cropping an 1080P frame from Switch
66+
// The crop is at image[500:750, 1500:1750].
67+
m_pixel_box_1080p = ImagePixelBox(1500, 500, 1750, 750);
68+
m_float_box = pixelbox_to_floatbox(1920, 1080, m_pixel_box_1080p);
69+
70+
// learned from ONN Runtime example: https://github.com/cassiebreviu/cpp-onnxruntime-resnet-console-app/blob/main/OnnxRuntimeResNet/OnnxRuntimeResNet.cpp
71+
session = Ort::Session(env, str_to_onnx_str(model_path).c_str(), Ort::SessionOptions{nullptr});
72+
}
73+
74+
// No-op: this detector adds nothing to `items`.
void DudunsparceFormDetector::make_overlays(VideoOverlaySet& items) const
{
}
75+
76+
bool DudunsparceFormDetector::process_frame(const ImageViewRGB32& frame, WallClock timestamp){
77+
m_overlay_set.clear();
78+
79+
ImageViewRGB32 cropped_frame = (frame.height() == 1080) ? extract_box_reference(frame, m_pixel_box_1080p)
80+
: extract_box_reference(frame, m_float_box);
81+
cv::Mat cropped_mat = cropped_frame.to_opencv_Mat();
82+
cv::Mat resized_mat; // resize to the shape for the ML model input
83+
cv::resize(cropped_mat, resized_mat, cv::Size(25, 25));
84+
cv::Mat resized_mat_gray;
85+
cv::cvtColor(resized_mat, resized_mat_gray, cv::COLOR_BGRA2GRAY);
86+
87+
// cv::imwrite("./model_input.png", resized_mat_gray);
88+
89+
cv::Mat float_mat;
90+
resized_mat_gray.convertTo(float_mat, CV_32F, 1./255);
91+
92+
// ML stuff: prepare ML model input and output
93+
const std::array<int64_t, 3> inputShape = {1, 25, 25};
94+
const std::array<int64_t, 2> outputShape = {1, 3};
95+
96+
std::array<float, 25 * 25> modelInput;
97+
std::array<float, 3> modelOutput;
98+
auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
99+
auto inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, modelInput.data(), modelInput.size(), inputShape.data(), inputShape.size());
100+
auto outputTensor = Ort::Value::CreateTensor<float>(memoryInfo, modelOutput.data(), modelOutput.size(), outputShape.data(), outputShape.size());
101+
102+
for (int row = 0, p_loc=0; row < 25; row++){
103+
for(int col = 0; col < 25; col++){
104+
float p = float_mat.at<float>(row, col);
105+
modelInput[p_loc++] = p;
106+
}
107+
}
108+
109+
Ort::AllocatorWithDefaultOptions ort_alloc;
110+
Ort::AllocatedStringPtr inputName = session.GetInputNameAllocated(0, ort_alloc);
111+
Ort::AllocatedStringPtr outputName = session.GetOutputNameAllocated(0, ort_alloc);
112+
const std::array<const char*, 1> inputNames = {inputName.get()};
113+
const std::array<const char*, 1> outputNames = {outputName.get()};
114+
inputName.release();
115+
outputName.release();
116+
117+
session.Run(runOptions, inputNames.data(), &inputTensor, 1, outputNames.data(), &outputTensor, 1);
118+
119+
double max_value = -DBL_MAX;
120+
int max_value_label = 0;
121+
for (int i = 0; i < 3; i++){
122+
// cout << labels[i] << ": " << modelOutput[i] << ", ";
123+
if (modelOutput[i] > max_value){
124+
max_value = modelOutput[i];
125+
max_value_label = i;
126+
}
127+
}
128+
cout << "Detector dudunsparce form: " << labels[max_value_label] << endl;
129+
m_detected.store(max_value_label);
130+
131+
return false;
132+
}
133+
134+
135+
TestDudunsparceFormDetector_Descriptor::TestDudunsparceFormDetector_Descriptor()
136+
: SingleSwitchProgramDescriptor(
137+
"NintendoSwitch:DudunsparceFormDetector",
138+
"Nintendo Switch", "Test Dudunsparce Form Detector",
139+
"",
140+
"Test ML model on Dudunsparce form in SV box system",
141+
FeedbackType::NONE, AllowCommandsWhenRunning::ENABLE_COMMANDS,
142+
{ControllerFeature::NintendoSwitch_ProController}
143+
)
144+
{}
145+
146+
// The program has no options or state of its own to set up.
TestDudunsparceFormDetector::TestDudunsparceFormDetector() = default;
147+
148+
149+
// Runs the Dudunsparce form detector against the console's video feed and
// writes a line to the overlay log each time the detected label changes.
void TestDudunsparceFormDetector::program(SingleSwitchProgramEnvironment& env, ProControllerContext& context){
    DudunsparceFormDetector detector(env.console.overlay());

    // Alternative ways of driving the detector, kept for experimentation:
    //
    // ImageRGB32 test_image("../../datasets/Bidoof/images/test/im0005.png");
    // detector.process_frame(test_image, current_time());
    // return;
    //
    // InferenceSession session(
    //     context, env.console,
    //     {{detector, std::chrono::milliseconds(100)}}
    // );
    // context.wait_until_cancel();

    // Last label written to the overlay log; empty until the first report.
    std::string reported_label;

    run_until<ProControllerContext>(
        env.console, context,
        [&](ProControllerContext& context){
            // Poll the detector every 100 ms. The loop has no exit of its own;
            // presumably it ends when wait_for() is interrupted by
            // cancellation -- confirm against the framework.
            for (;;){
                const std::string current_label = detector.get_label();
                const bool label_changed = (current_label != reported_label);
                if (label_changed){
                    reported_label = current_label;
                    env.console.overlay().add_log("Detected " + reported_label);
                }
                context.wait_for(std::chrono::milliseconds(100));
            }
        },
        {detector},
        std::chrono::milliseconds(100)
    );

    cout << "ML detection test program finished." << endl;
}
182+
183+
184+
}
185+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Test Dudunsparce Form Detector
2+
*
3+
* From: https://github.com/PokemonAutomation/Arduino-Source
4+
*
5+
*/
6+
7+
#ifndef PokemonAutomation_NintendoSwitch_TestDudunsparceFormDetector_H
8+
#define PokemonAutomation_NintendoSwitch_TestDudunsparceFormDetector_H
9+
10+
#include "NintendoSwitch/NintendoSwitch_SingleSwitchProgram.h"
11+
12+
namespace PokemonAutomation{
13+
namespace NintendoSwitch{
14+
15+
16+
// Descriptor for the Dudunsparce form detector test program. The constructor
// supplies the program's metadata to SingleSwitchProgramDescriptor.
class TestDudunsparceFormDetector_Descriptor
    : public SingleSwitchProgramDescriptor
{
public:
    TestDudunsparceFormDetector_Descriptor();
};
20+
21+
22+
class TestDudunsparceFormDetector : public SingleSwitchProgramInstance{
23+
public:
24+
TestDudunsparceFormDetector();
25+
26+
virtual void program(SingleSwitchProgramEnvironment& env, ProControllerContext& context) override;
27+
28+
private:
29+
};
30+
31+
32+
33+
}
34+
}
35+
#endif // PokemonAutomation_NintendoSwitch_TestDudunsparceFormDetector_H
36+

SerialPrograms/Source/NintendoSwitch/NintendoSwitch_Panels.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "Pokemon/Inference/Pokemon_TrainIVCheckerOCR.h"
2929
#include "Pokemon/Inference/Pokemon_TrainPokemonOCR.h"
3030

31+
#include "DevPrograms/TestDudunsparceFormDetector.h"
32+
3133
#ifdef PA_OFFICIAL
3234
#include "../../Internal/SerialPrograms/NintendoSwitch_TestPrograms.h"
3335
#endif
@@ -71,6 +73,7 @@ std::vector<PanelEntry> PanelListFactory::make_panels() const{
7173
ret.emplace_back(make_single_switch_program<JoyconProgram_Descriptor, JoyconProgram>());
7274
ret.emplace_back(make_computer_program<Pokemon::TrainIVCheckerOCR_Descriptor, Pokemon::TrainIVCheckerOCR>());
7375
ret.emplace_back(make_computer_program<Pokemon::TrainPokemonOCR_Descriptor, Pokemon::TrainPokemonOCR>());
76+
ret.emplace_back(make_single_switch_program<TestDudunsparceFormDetector_Descriptor, TestDudunsparceFormDetector>());
7477
#ifdef PA_OFFICIAL
7578
add_panels(ret);
7679
#endif

0 commit comments

Comments
 (0)