|
| 1 | +/* Test Dudunsparce Form Detector |
| 2 | + * |
| 3 | + * From: https://github.com/PokemonAutomation/Arduino-Source |
| 4 | + * |
| 5 | + */ |
| 6 | + |
| 7 | +#include "3rdParty/ONNX/OnnxToolsPA.h" |
| 8 | +#include "Common/Cpp/Time.h" |
| 9 | +#include "ClientSource/Connection/BotBase.h" |
| 10 | +#include "CommonFramework/ImageTypes/ImageRGB32.h" |
| 11 | +#include "CommonFramework/ImageTypes/ImageViewRGB32.h" |
| 12 | +#include "CommonFramework/ImageTools/ImageBoxes.h" |
| 13 | +#include "CommonTools/Async/InferenceRoutines.h" |
| 14 | +#include "CommonTools/Async/InferenceSession.h" |
| 15 | +#include "CommonTools/InferenceCallbacks/VisualInferenceCallback.h" |
| 16 | +#include "CommonFramework/VideoPipeline/VideoOverlay.h" |
| 17 | +#include "TestDudunsparceFormDetector.h" |
| 18 | + |
| 19 | +#include <opencv2/imgproc.hpp> |
| 20 | +#include <opencv2/imgcodecs.hpp> |
| 21 | + |
| 22 | +#include <onnxruntime_cxx_api.h> |
| 23 | + |
| 24 | +#include <vector> |
| 25 | +#include <iostream> |
| 26 | +using std::cout, std::endl; |
| 27 | + |
| 28 | +namespace PokemonAutomation{ |
| 29 | +namespace NintendoSwitch{ |
| 30 | + |
| 31 | +class DudunsparceFormDetector : public VisualInferenceCallback{ |
| 32 | +public: |
| 33 | + DudunsparceFormDetector(VideoOverlay& overlay); |
| 34 | + |
| 35 | + virtual void make_overlays(VideoOverlaySet& items) const override; |
| 36 | + virtual bool process_frame(const ImageViewRGB32& frame, WallClock timestamp) override final; |
| 37 | + |
| 38 | + std::string get_label() const { |
| 39 | + int detected_label_id = m_detected.load(std::memory_order_acquire); |
| 40 | + return labels[detected_label_id]; |
| 41 | + } |
| 42 | + |
| 43 | +private: |
| 44 | + VideoOverlaySet m_overlay_set; |
| 45 | + std::string model_path; |
| 46 | + ImagePixelBox m_pixel_box_1080p; |
| 47 | + ImageFloatBox m_float_box; |
| 48 | + std::atomic<int> m_detected; |
| 49 | + |
| 50 | + Ort::Env env; |
| 51 | + Ort::RunOptions runOptions; |
| 52 | + Ort::Session session; |
| 53 | + |
| 54 | + const char* labels[3]{"none", "three", "two"}; |
| 55 | +}; |
| 56 | + |
| 57 | + |
| 58 | +DudunsparceFormDetector::DudunsparceFormDetector(VideoOverlay& overlay) |
| 59 | + : VisualInferenceCallback("DudunsparceFormDetector") |
| 60 | + , m_overlay_set(overlay) |
| 61 | + , model_path("../../PAMLExperiments/dudunsparce/dudunsparce_form_detector.onnx") |
| 62 | + , m_detected(0) |
| 63 | + , session(nullptr) |
| 64 | +{ |
| 65 | + // The input data for the ML model is created by cropping an 1080P frame from Switch |
| 66 | + // The crop is at image[500:750, 1500:1750]. |
| 67 | + m_pixel_box_1080p = ImagePixelBox(1500, 500, 1750, 750); |
| 68 | + m_float_box = pixelbox_to_floatbox(1920, 1080, m_pixel_box_1080p); |
| 69 | + |
| 70 | + // learned from ONN Runtime example: https://github.com/cassiebreviu/cpp-onnxruntime-resnet-console-app/blob/main/OnnxRuntimeResNet/OnnxRuntimeResNet.cpp |
| 71 | + session = Ort::Session(env, str_to_onnx_str(model_path).c_str(), Ort::SessionOptions{nullptr}); |
| 72 | +} |
| 73 | + |
| 74 | +void DudunsparceFormDetector::make_overlays(VideoOverlaySet& items) const {} |
| 75 | + |
| 76 | +bool DudunsparceFormDetector::process_frame(const ImageViewRGB32& frame, WallClock timestamp){ |
| 77 | + m_overlay_set.clear(); |
| 78 | + |
| 79 | + ImageViewRGB32 cropped_frame = (frame.height() == 1080) ? extract_box_reference(frame, m_pixel_box_1080p) |
| 80 | + : extract_box_reference(frame, m_float_box); |
| 81 | + cv::Mat cropped_mat = cropped_frame.to_opencv_Mat(); |
| 82 | + cv::Mat resized_mat; // resize to the shape for the ML model input |
| 83 | + cv::resize(cropped_mat, resized_mat, cv::Size(25, 25)); |
| 84 | + cv::Mat resized_mat_gray; |
| 85 | + cv::cvtColor(resized_mat, resized_mat_gray, cv::COLOR_BGRA2GRAY); |
| 86 | + |
| 87 | + // cv::imwrite("./model_input.png", resized_mat_gray); |
| 88 | + |
| 89 | + cv::Mat float_mat; |
| 90 | + resized_mat_gray.convertTo(float_mat, CV_32F, 1./255); |
| 91 | + |
| 92 | + // ML stuff: prepare ML model input and output |
| 93 | + const std::array<int64_t, 3> inputShape = {1, 25, 25}; |
| 94 | + const std::array<int64_t, 2> outputShape = {1, 3}; |
| 95 | + |
| 96 | + std::array<float, 25 * 25> modelInput; |
| 97 | + std::array<float, 3> modelOutput; |
| 98 | + auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); |
| 99 | + auto inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, modelInput.data(), modelInput.size(), inputShape.data(), inputShape.size()); |
| 100 | + auto outputTensor = Ort::Value::CreateTensor<float>(memoryInfo, modelOutput.data(), modelOutput.size(), outputShape.data(), outputShape.size()); |
| 101 | + |
| 102 | + for (int row = 0, p_loc=0; row < 25; row++){ |
| 103 | + for(int col = 0; col < 25; col++){ |
| 104 | + float p = float_mat.at<float>(row, col); |
| 105 | + modelInput[p_loc++] = p; |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + Ort::AllocatorWithDefaultOptions ort_alloc; |
| 110 | + Ort::AllocatedStringPtr inputName = session.GetInputNameAllocated(0, ort_alloc); |
| 111 | + Ort::AllocatedStringPtr outputName = session.GetOutputNameAllocated(0, ort_alloc); |
| 112 | + const std::array<const char*, 1> inputNames = {inputName.get()}; |
| 113 | + const std::array<const char*, 1> outputNames = {outputName.get()}; |
| 114 | + inputName.release(); |
| 115 | + outputName.release(); |
| 116 | + |
| 117 | + session.Run(runOptions, inputNames.data(), &inputTensor, 1, outputNames.data(), &outputTensor, 1); |
| 118 | + |
| 119 | + double max_value = -DBL_MAX; |
| 120 | + int max_value_label = 0; |
| 121 | + for (int i = 0; i < 3; i++){ |
| 122 | + // cout << labels[i] << ": " << modelOutput[i] << ", "; |
| 123 | + if (modelOutput[i] > max_value){ |
| 124 | + max_value = modelOutput[i]; |
| 125 | + max_value_label = i; |
| 126 | + } |
| 127 | + } |
| 128 | + cout << "Detector dudunsparce form: " << labels[max_value_label] << endl; |
| 129 | + m_detected.store(max_value_label); |
| 130 | + |
| 131 | + return false; |
| 132 | +} |
| 133 | + |
| 134 | + |
| 135 | +TestDudunsparceFormDetector_Descriptor::TestDudunsparceFormDetector_Descriptor() |
| 136 | + : SingleSwitchProgramDescriptor( |
| 137 | + "NintendoSwitch:DudunsparceFormDetector", |
| 138 | + "Nintendo Switch", "Test Dudunsparce Form Detector", |
| 139 | + "", |
| 140 | + "Test ML model on Dudunsparce form in SV box system", |
| 141 | + FeedbackType::NONE, AllowCommandsWhenRunning::ENABLE_COMMANDS, |
| 142 | + {ControllerFeature::NintendoSwitch_ProController} |
| 143 | + ) |
| 144 | +{} |
| 145 | + |
| 146 | +TestDudunsparceFormDetector::TestDudunsparceFormDetector(){} |
| 147 | + |
| 148 | + |
| 149 | +void TestDudunsparceFormDetector::program(SingleSwitchProgramEnvironment& env, ProControllerContext& context){ |
| 150 | + |
| 151 | + DudunsparceFormDetector detector(env.console.overlay()); |
| 152 | + |
| 153 | + // ImageRGB32 test_image("../../datasets/Bidoof/images/test/im0005.png"); |
| 154 | + // detector.process_frame(test_image, current_time()); |
| 155 | + |
| 156 | + // return; |
| 157 | + // InferenceSession session( |
| 158 | + // context, env.console, |
| 159 | + // {{detector, std::chrono::milliseconds(100)}} |
| 160 | + // ); |
| 161 | + // context.wait_until_cancel(); |
| 162 | + |
| 163 | + std::string last_label = ""; |
| 164 | + run_until<ProControllerContext>( |
| 165 | + env.console, context, [&](ProControllerContext& context){ |
| 166 | + while (true){ |
| 167 | + std::string cur_label = detector.get_label(); |
| 168 | + if (cur_label != last_label){ |
| 169 | + last_label = cur_label; |
| 170 | + env.console.overlay().add_log("Detected " + last_label); |
| 171 | + } |
| 172 | + context.wait_for(std::chrono::milliseconds(100)); |
| 173 | + } |
| 174 | + }, |
| 175 | + {detector}, |
| 176 | + std::chrono::milliseconds(100) |
| 177 | + ); |
| 178 | + |
| 179 | + |
| 180 | + std::cout << "ML detection test program finished." << std::endl; |
| 181 | +} |
| 182 | + |
| 183 | + |
| 184 | +} |
| 185 | +} |
0 commit comments