Skip to content

Commit 671f47e

Browse files
committed
Add a preliminary global parallelizer for heavy inference. Parallelize the silhouette matcher.
1 parent b2d69c8 commit 671f47e

File tree

8 files changed

+116
-6
lines changed

8 files changed

+116
-6
lines changed

Common/Cpp/Concurrency/ParallelTaskRunner.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#include "Common/Cpp/PanicDump.h"
1111
#include "ParallelTaskRunner.h"
1212

13+
//#include <iostream>
14+
//using std::cout;
15+
//using std::endl;
16+
1317
namespace PokemonAutomation{
1418

1519

@@ -72,6 +76,37 @@ std::shared_ptr<AsyncTask> ParallelTaskRunner::dispatch(std::function<void()>&&
7276
}
7377

7478

79+
void ParallelTaskRunner::run_in_parallel(
80+
const std::function<void(size_t index)>& func,
81+
size_t start, size_t end,
82+
size_t block_size
83+
){
84+
if (start >= end){
85+
return;
86+
}
87+
size_t total = end - start;
88+
size_t blocks = (total + block_size - 1) / block_size;
89+
90+
std::vector<std::shared_ptr<AsyncTask>> tasks;
91+
for (size_t c = 0; c < blocks; c++){
92+
tasks.emplace_back(dispatch([=, &func]{
93+
size_t s = start + c * block_size;
94+
size_t e = std::min(s + block_size, end);
95+
// cout << "Running: [" << s << "," << e << ")" << endl;
96+
for (; s < e; s++){
97+
func(s);
98+
}
99+
}));
100+
}
101+
102+
for (std::shared_ptr<AsyncTask>& task : tasks){
103+
task->wait_and_rethrow_exceptions();
104+
}
105+
}
106+
107+
108+
109+
75110
void ParallelTaskRunner::thread_loop(){
76111
//#if _WIN32
77112
// SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_IDLE);

Common/Cpp/Concurrency/ParallelTaskRunner.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ class ParallelTaskRunner{
2525

2626
std::shared_ptr<AsyncTask> dispatch(std::function<void()>&& func);
2727

28+
//  Call "func(index)" once for every index in [start, end).
//
//  The definition splits the range into blocks of "block_size" consecutive
//  indices, submits each block through dispatch(), and then waits on every
//  returned task with wait_and_rethrow_exceptions() before returning.
//  Nothing is run when start >= end.
//
//  NOTE(review): the definition divides by "block_size" when computing the
//  number of blocks, so callers should pass a non-zero value unless the
//  definition handles the default of 0 explicitly.
void run_in_parallel(
29+
const std::function<void(size_t index)>& func,
30+
size_t start, size_t end,
31+
size_t block_size = 0
32+
);
33+
2834

2935
private:
3036
// void dispatch_task(AsyncTask& task);

SerialPrograms/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,7 @@ file(GLOB MAIN_SOURCES
173173
../Common/Microcontroller/DeviceRoutines.h
174174
../Common/Microcontroller/MessageProtocol.h
175175
../Common/NintendoSwitch/NintendoSwitch_ControllerDefs.h
176-
../Common/NintendoSwitch/NintendoSwitch_Protocol_DigitEntry.h
177176
../Common/NintendoSwitch/NintendoSwitch_Protocol_PushButtons.h
178-
../Common/NintendoSwitch/NintendoSwitch_Protocol_Routines.h
179177
../Common/NintendoSwitch/NintendoSwitch_Protocol_Superscalar.h
180178
../Common/NintendoSwitch/NintendoSwitch_SlotDatabase.h
181179
../Common/PokemonSwSh/PokemonProgramIDs.h
@@ -535,6 +533,8 @@ file(GLOB MAIN_SOURCES
535533
Source/CommonTools/InferencePivots/AudioInferencePivot.h
536534
Source/CommonTools/InferencePivots/VisualInferencePivot.cpp
537535
Source/CommonTools/InferencePivots/VisualInferencePivot.h
536+
Source/CommonTools/GlobalInferenceRunner.cpp
537+
Source/CommonTools/GlobalInferenceRunner.h
538538
Source/CommonTools/InferenceThrottler.h
539539
Source/CommonTools/OCR/OCR_DictionaryMatcher.cpp
540540
Source/CommonTools/OCR/OCR_DictionaryMatcher.h

SerialPrograms/SerialPrograms.pro

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ SOURCES += \
254254
Source/CommonTools/Audio/AudioPerSpectrumDetectorBase.cpp \
255255
Source/CommonTools/Audio/AudioTemplateCache.cpp \
256256
Source/CommonTools/Audio/SpectrogramMatcher.cpp \
257+
Source/CommonTools/GlobalInferenceRunner.cpp \
257258
Source/CommonTools/ImageMatch/CroppedImageDictionaryMatcher.cpp \
258259
Source/CommonTools/ImageMatch/ExactImageDictionaryMatcher.cpp \
259260
Source/CommonTools/ImageMatch/ExactImageMatcher.cpp \
@@ -1184,9 +1185,7 @@ HEADERS += \
11841185
../Common/Microcontroller/DeviceRoutines.h \
11851186
../Common/Microcontroller/MessageProtocol.h \
11861187
../Common/NintendoSwitch/NintendoSwitch_ControllerDefs.h \
1187-
../Common/NintendoSwitch/NintendoSwitch_Protocol_DigitEntry.h \
11881188
../Common/NintendoSwitch/NintendoSwitch_Protocol_PushButtons.h \
1189-
../Common/NintendoSwitch/NintendoSwitch_Protocol_Routines.h \
11901189
../Common/NintendoSwitch/NintendoSwitch_Protocol_Superscalar.h \
11911190
../Common/NintendoSwitch/NintendoSwitch_SlotDatabase.h \
11921191
../Common/PokemonSwSh/PokemonProgramIDs.h \
@@ -1362,6 +1361,7 @@ HEADERS += \
13621361
Source/CommonTools/Audio/AudioTemplateCache.h \
13631362
Source/CommonTools/Audio/SpectrogramMatcher.h \
13641363
Source/CommonTools/DetectionDebouncer.h \
1364+
Source/CommonTools/GlobalInferenceRunner.h \
13651365
Source/CommonTools/ImageMatch/CroppedImageDictionaryMatcher.h \
13661366
Source/CommonTools/ImageMatch/ExactImageDictionaryMatcher.h \
13671367
Source/CommonTools/ImageMatch/ExactImageMatcher.h \
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Global Inference Runner
2+
*
3+
* From: https://github.com/PokemonAutomation/Arduino-Source
4+
*
5+
*/
6+
7+
#include "CommonFramework/GlobalSettingsPanel.h"
8+
#include "CommonFramework/Options/Environment/PerformanceOptions.h"
9+
#include "GlobalInferenceRunner.h"
10+
11+
namespace PokemonAutomation{
12+
13+
14+
15+
//  Accessor for the single ParallelTaskRunner shared by inference code.
//
//  The runner is a function-local static: it is constructed the first time
//  this function is called and the same instance is returned on every call.
ParallelTaskRunner& global_inference_runner(){
    //  Callback handed to the runner. It applies the configured inference
    //  priority to whichever thread invokes it.
    //  NOTE(review): presumably the runner calls this on each worker thread it
    //  starts - confirm against the ParallelTaskRunner constructor.
    const auto apply_inference_priority = []{
        GlobalSettings::instance().PERFORMANCE->INFERENCE_PRIORITY.set_on_this_thread();
    };

    static ParallelTaskRunner shared_runner(
        apply_inference_priority,
        0,
        std::thread::hardware_concurrency()
    );
    return shared_runner;
}
24+
25+
26+
27+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/* Global Inference Runner
2+
*
3+
* From: https://github.com/PokemonAutomation/Arduino-Source
4+
*
5+
*/
6+
7+
#ifndef PokemonAutomation_CommonTools_GlobalInferenceRunner_H
8+
#define PokemonAutomation_CommonTools_GlobalInferenceRunner_H
9+
10+
#include "Common/Cpp/Concurrency/ParallelTaskRunner.h"
11+
12+
namespace PokemonAutomation{
13+
14+
15+
16+
//  Returns the shared ParallelTaskRunner used for heavy inference work.
//  The definition keeps it as a function-local static, so every call returns
//  a reference to the same runner, constructed on first use.
ParallelTaskRunner& global_inference_runner();
17+
18+
19+
20+
}
21+
#endif

SerialPrograms/Source/CommonTools/ImageMatch/SilhouetteDictionaryMatcher.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
*/
66

77
#include "Common/Cpp/Exceptions.h"
8+
#include "Common/Cpp/Concurrency/SpinLock.h"
89
#include "CommonFramework/ImageTypes/ImageViewRGB32.h"
910
#include "CommonFramework/ImageTools/ImageDiff.h"
11+
#include "CommonTools/GlobalInferenceRunner.h"
1012
#include "ImageCropper.h"
1113
#include "SilhouetteDictionaryMatcher.h"
1214

@@ -27,11 +29,12 @@ void SilhouetteDictionaryMatcher::add(const std::string& slug, const ImageViewRG
2729
throw InternalProgramError(nullptr, PA_CURRENT_FUNCTION, "Duplicate slug: " + slug);
2830
}
2931

30-
m_database.emplace(
32+
iter = m_database.emplace(
3133
std::piecewise_construct,
3234
std::forward_as_tuple(slug),
3335
std::forward_as_tuple(trim_image_alpha(image).copy())
34-
);
36+
).first;
37+
m_database_vector.emplace_back(&*iter);
3538
}
3639

3740

@@ -45,14 +48,31 @@ ImageMatchResult SilhouetteDictionaryMatcher::match(
4548
return results;
4649
}
4750

51+
SpinLock lock;
52+
global_inference_runner().run_in_parallel(
53+
[&](size_t index){
54+
const auto& matcher = *m_database_vector[index];
55+
double alpha = matcher.second.rmsd_masked(image);
56+
WriteSpinLock lg(lock);
57+
results.add(alpha, matcher.first);
58+
results.clear_beyond_spread(alpha_spread);
59+
},
60+
0, m_database_vector.size(),
61+
100
62+
);
63+
64+
65+
#if 0
4866
for (const auto& item : m_database){
4967
// if (item.first != "solosis"){
5068
// continue;
5169
// }
5270
double alpha = item.second.rmsd_masked(image);
71+
// WriteSpinLock lg(lock);
5372
results.add(alpha, item.first);
5473
results.clear_beyond_spread(alpha_spread);
5574
}
75+
#endif
5676

5777
return results;
5878
}

SerialPrograms/Source/CommonTools/ImageMatch/SilhouetteDictionaryMatcher.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class SilhouetteDictionaryMatcher{
4747

4848
private:
4949
std::map<std::string, ExactImageMatcher> m_database;
50+
std::vector<const std::pair<const std::string, ExactImageMatcher>*> m_database_vector;
5051
};
5152

5253

0 commit comments

Comments
 (0)