Skip to content

Commit a60032b

Browse files
committed
AutoDA: Kill the program if it's likely that the user set the wrong language.
1 parent e0af2bf commit a60032b

21 files changed

+216
-35
lines changed

SerialPrograms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ file(GLOB MAIN_SOURCES
492492
Source/CommonTools/Audio/SpectrogramMatcher.cpp
493493
Source/CommonTools/Audio/SpectrogramMatcher.h
494494
Source/CommonTools/DetectionDebouncer.h
495+
Source/CommonTools/FailureWatchdog.h
495496
Source/CommonTools/ImageMatch/CroppedImageDictionaryMatcher.cpp
496497
Source/CommonTools/ImageMatch/CroppedImageDictionaryMatcher.h
497498
Source/CommonTools/ImageMatch/ExactImageDictionaryMatcher.cpp

SerialPrograms/SerialPrograms.pro

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,7 @@ HEADERS += \
13761376
Source/CommonTools/Audio/AudioTemplateCache.h \
13771377
Source/CommonTools/Audio/SpectrogramMatcher.h \
13781378
Source/CommonTools/DetectionDebouncer.h \
1379+
Source/CommonTools/FailureWatchdog.h \
13791380
Source/CommonTools/GlobalInferenceRunner.h \
13801381
Source/CommonTools/ImageMatch/CroppedImageDictionaryMatcher.h \
13811382
Source/CommonTools/ImageMatch/ExactImageDictionaryMatcher.h \
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/* Failure Watchdog
2+
*
3+
* From: https://github.com/PokemonAutomation/Arduino-Source
4+
*
5+
*/
6+
7+
#ifndef PokemonAutomation_CommonTools_FailureWatchdog_H
8+
#define PokemonAutomation_CommonTools_FailureWatchdog_H
9+
10+
#include "Common/Cpp/AbstractLogger.h"
11+
#include "Common/Cpp/Time.h"
12+
#include "Common/Cpp/Exceptions.h"
13+
14+
namespace PokemonAutomation{
15+
16+
17+
18+
19+
20+
class FailureWatchdog{
21+
public:
22+
FailureWatchdog(
23+
Logger& logger,
24+
std::string failure_message,
25+
uint64_t min_count = 5,
26+
double min_success_rate = 0.5,
27+
std::chrono::seconds time_limit = std::chrono::seconds(120)
28+
)
29+
: m_logger(logger)
30+
, m_failure_message(std::move(failure_message))
31+
, m_min_count(min_count)
32+
, m_min_success_rate(min_success_rate)
33+
, m_time_limit(time_limit)
34+
{
35+
restart();
36+
}
37+
void restart(){
38+
m_expiration = current_time() + m_time_limit;
39+
m_expired = false;
40+
m_successes = 0;
41+
m_total = 0;
42+
}
43+
44+
void push_result(bool success){
45+
m_successes += success ? 1 : 0;
46+
m_total++;
47+
if (success || m_expired){
48+
return;
49+
}
50+
51+
WallClock current = current_time();
52+
if (current >= m_expiration){
53+
m_expired = true;
54+
}
55+
56+
if (m_total < m_min_count){
57+
return;
58+
}
59+
60+
double threshold = (double)m_total * m_min_success_rate;
61+
if ((double)m_successes >= threshold){
62+
return;
63+
}
64+
65+
66+
throw UserSetupError(m_logger, m_failure_message);
67+
}
68+
69+
70+
private:
71+
Logger& m_logger;
72+
std::string m_failure_message;
73+
uint64_t m_min_count;
74+
double m_min_success_rate;
75+
WallDuration m_time_limit;
76+
WallClock m_expiration;
77+
bool m_expired;
78+
79+
uint64_t m_successes;
80+
uint64_t m_total;
81+
};
82+
83+
84+
85+
86+
87+
class OcrFailureWatchdog : public FailureWatchdog{
88+
public:
89+
OcrFailureWatchdog(
90+
Logger& logger,
91+
std::string failure_message = "Too many text recognition errors. Did you set the correct language?",
92+
uint64_t min_count = 5,
93+
double min_success_rate = 0.5,
94+
std::chrono::seconds time_limit = std::chrono::seconds(120)
95+
)
96+
: FailureWatchdog(
97+
logger,
98+
std::move(failure_message),
99+
min_count,
100+
min_success_rate,
101+
time_limit
102+
)
103+
{}
104+
};
105+
106+
107+
108+
109+
110+
111+
}
112+
#endif

SerialPrograms/Source/Controllers/KeyboardInput/KeyboardInput.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ void KeyboardInputController::thread_loop(){
137137
break;
138138
}
139139

140-
141140
// If state is neutral, just issue a stop.
142141
if (neutral){
143142
if (try_stop_commands()){

SerialPrograms/Source/NintendoSwitch/DevPrograms/TestProgramComputer.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,6 @@ class WatchdogTest1 : public WatchdogCallback{
219219

220220

221221

222-
223-
224-
225-
226222
void TestProgramComputer::program(ProgramEnvironment& env, CancellableScope& scope){
227223
using namespace Kernels;
228224
using namespace NintendoSwitch;

SerialPrograms/Source/PokemonSwSh/MaxLair/Framework/PokemonSwSh_MaxLair_StateMachine.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*
55
*/
66

7+
#include "Common/Cpp/Containers/FixedLimitVector.tpp"
78
#include "CommonFramework/ImageTypes/ImageViewRGB32.h"
89
#include "CommonFramework/Tools/ErrorDumper.h"
910
#include "CommonFramework/Notifications/ProgramInfo.h"
@@ -36,6 +37,35 @@ namespace PokemonSwSh{
3637
namespace MaxLairInternal{
3738

3839

40+
41+
AdventureRuntime::~AdventureRuntime() = default;
42+
AdventureRuntime::AdventureRuntime(
43+
FixedLimitVector<ConsoleHandle>& consoles,
44+
const size_t p_host_index,
45+
const Consoles& p_console_settings,
46+
const EndBattleDecider& p_actions,
47+
const bool p_go_home_when_done,
48+
HostingSettings& p_hosting_settings,
49+
EventNotificationOption& p_notification_status,
50+
EventNotificationOption& p_notification_shiny,
51+
Stats& p_session_stats
52+
)
53+
: host_index(p_host_index)
54+
, console_settings(p_console_settings)
55+
, actions(p_actions)
56+
, go_home_when_done(p_go_home_when_done)
57+
, hosting_settings(p_hosting_settings)
58+
, notification_status(p_notification_status)
59+
, notification_shiny(p_notification_shiny)
60+
, ocr_watchdog(p_console_settings.active_consoles())
61+
, session_stats(p_session_stats)
62+
{
63+
for (size_t c = 0; c < p_console_settings.active_consoles(); c++){
64+
ocr_watchdog.emplace_back(consoles[c].logger());
65+
}
66+
}
67+
68+
3969
StateMachineAction run_state_iteration(
4070
AdventureRuntime& runtime, size_t console_index,
4171
ProgramEnvironment& env,
@@ -94,6 +124,7 @@ StateMachineAction run_state_iteration(
94124
console_index,
95125
stream, context,
96126
global_state,
127+
runtime.ocr_watchdog[console_index],
97128
runtime.console_settings[console_index]
98129
);
99130
return StateMachineAction::KEEP_GOING;
@@ -124,7 +155,9 @@ StateMachineAction run_state_iteration(
124155
stream.log("Current State: Move Select");
125156
return run_move_select(
126157
env, console_index,
127-
stream, context, global_state,
158+
stream, context,
159+
runtime.ocr_watchdog[console_index],
160+
global_state,
128161
runtime.console_settings[console_index],
129162
battle_menu.dmaxed(),
130163
battle_menu.cheer()
@@ -136,6 +169,7 @@ StateMachineAction run_state_iteration(
136169
env, console_index,
137170
stream, context,
138171
runtime.console_settings[console_index].language,
172+
runtime.ocr_watchdog[console_index],
139173
global_state,
140174
decider
141175
);

SerialPrograms/Source/PokemonSwSh/MaxLair/Framework/PokemonSwSh_MaxLair_StateMachine.h

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
#ifndef PokemonAutomation_PokemonSwSh_MaxLair_StateMachine_H
88
#define PokemonAutomation_PokemonSwSh_MaxLair_StateMachine_H
99

10+
#include "Common/Cpp/Concurrency/SpinLock.h"
1011
#include "CommonFramework/Tools/VideoStream.h"
1112
#include "CommonFramework/Tools/ProgramEnvironment.h"
12-
#include "Common/Cpp/Concurrency/SpinLock.h"
13+
#include "CommonTools/FailureWatchdog.h"
1314
#include "NintendoSwitch/Controllers/NintendoSwitch_Controller.h"
15+
#include "NintendoSwitch/NintendoSwitch_ConsoleHandle.h"
1416
#include "PokemonSwSh/Inference/PokemonSwSh_QuantityReader.h"
1517
#include "PokemonSwSh/MaxLair/Options/PokemonSwSh_MaxLair_Options.h"
1618
#include "PokemonSwSh/MaxLair/Options/PokemonSwSh_MaxLair_Options_Consoles.h"
@@ -52,7 +54,9 @@ struct ConsoleRuntime{
5254
};
5355

5456
struct AdventureRuntime{
57+
~AdventureRuntime();
5558
AdventureRuntime(
59+
FixedLimitVector<ConsoleHandle>& consoles,
5660
const size_t p_host_index,
5761
const Consoles& p_console_settings,
5862
const EndBattleDecider& p_actions,
@@ -61,16 +65,7 @@ struct AdventureRuntime{
6165
EventNotificationOption& p_notification_status,
6266
EventNotificationOption& p_notification_shiny,
6367
Stats& p_session_stats
64-
)
65-
: host_index(p_host_index)
66-
, console_settings(p_console_settings)
67-
, actions(p_actions)
68-
, go_home_when_done(p_go_home_when_done)
69-
, hosting_settings(p_hosting_settings)
70-
, notification_status(p_notification_status)
71-
, notification_shiny(p_notification_shiny)
72-
, session_stats(p_session_stats)
73-
{}
68+
);
7469

7570
const size_t host_index;
7671
const Consoles& console_settings;
@@ -79,6 +74,9 @@ struct AdventureRuntime{
7974
HostingSettings& hosting_settings;
8075
EventNotificationOption& notification_status;
8176
EventNotificationOption& notification_shiny;
77+
78+
FixedLimitVector<OcrFailureWatchdog> ocr_watchdog;
79+
8280
Stats& session_stats;
8381

8482
PathStats path_stats;

SerialPrograms/Source/PokemonSwSh/MaxLair/Inference/PokemonSwSh_MaxLair_Detect_BattleMenu.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,13 @@ bool BattleMenuDetector::detect(const ImageViewRGB32& screen){
211211

212212

213213

214-
BattleMenuReader::BattleMenuReader(VideoOverlay& overlay, Language language)
214+
BattleMenuReader::BattleMenuReader(
215+
VideoOverlay& overlay,
216+
Language language,
217+
OcrFailureWatchdog& ocr_watchdog
218+
)
215219
: m_language(language)
220+
, m_ocr_watchdog(ocr_watchdog)
216221
, m_opponent_name(overlay, {0.3, 0.010, 0.4, 0.10}, COLOR_BLUE)
217222
, m_summary_opponent_name(overlay, {0.200, 0.100, 0.300, 0.065}, COLOR_BLUE)
218223
, m_summary_opponent_types(overlay, {0.200, 0.170, 0.300, 0.050}, COLOR_BLUE)
@@ -243,7 +248,7 @@ std::set<std::string> BattleMenuReader::read_opponent(
243248
for (size_t c = 0; c < 3; c++){
244249
screen = feed.snapshot();
245250
ImageViewRGB32 image = extract_box_reference(screen, m_opponent_name);
246-
result = read_pokemon_name(logger, m_language, image);
251+
result = read_pokemon_name(logger, m_language, m_ocr_watchdog, image);
247252
if (!result.empty()){
248253
return result;
249254
}
@@ -316,7 +321,7 @@ std::set<std::string> BattleMenuReader::read_opponent_in_summary(Logger& logger,
316321
ImageViewRGB32 name = extract_box_reference(screen, m_summary_opponent_name);
317322

318323
// We can use a weaker threshold here since we are cross-checking with the type.
319-
name_slugs = read_pokemon_name(logger, m_language, name, -1.0);
324+
name_slugs = read_pokemon_name(logger, m_language, m_ocr_watchdog, name, -1.0);
320325
}
321326

322327
// See if there's anything in common between the slugs that match the type
@@ -381,6 +386,7 @@ std::set<std::string> BattleMenuReader::read_opponent_in_summary(Logger& logger,
381386
std::string BattleMenuReader::read_own_mon(Logger& logger, const ImageViewRGB32& screen) const{
382387
return read_pokemon_name_sprite(
383388
logger,
389+
m_ocr_watchdog,
384390
screen,
385391
m_own_sprite,
386392
m_own_name, m_language,

SerialPrograms/Source/PokemonSwSh/MaxLair/Inference/PokemonSwSh_MaxLair_Detect_BattleMenu.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "CommonFramework/Language.h"
1111
#include "CommonFramework/Logging/Logger.h"
1212
#include "CommonFramework/VideoPipeline/VideoOverlayScopes.h"
13+
#include "CommonTools/FailureWatchdog.h"
1314
#include "CommonTools/InferenceCallbacks/VisualInferenceCallback.h"
1415
#include "PokemonSwSh/MaxLair/Framework/PokemonSwSh_MaxLair_State.h"
1516

@@ -56,7 +57,11 @@ class BattleMenuDetector : public VisualInferenceCallback{
5657

5758
class BattleMenuReader{
5859
public:
59-
BattleMenuReader(VideoOverlay& overlay, Language language);
60+
BattleMenuReader(
61+
VideoOverlay& overlay,
62+
Language language,
63+
OcrFailureWatchdog& ocr_watchdog
64+
);
6065

6166
std::set<std::string> read_opponent(
6267
Logger& logger, CancellableScope& scope,
@@ -74,6 +79,7 @@ class BattleMenuReader{
7479

7580
private:
7681
Language m_language;
82+
OcrFailureWatchdog& m_ocr_watchdog;
7783
OverlayBoxScope m_opponent_name;
7884
OverlayBoxScope m_summary_opponent_name;
7985
OverlayBoxScope m_summary_opponent_types;

SerialPrograms/Source/PokemonSwSh/MaxLair/Inference/PokemonSwSh_MaxLair_Detect_PokemonReader.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ std::string read_boss_sprite(VideoStream& stream){
108108

109109
std::set<std::string> read_pokemon_name(
110110
Logger& logger, Language language,
111+
OcrFailureWatchdog& ocr_watchdog,
111112
const ImageViewRGB32& image,
112113
double max_log10p
113114
){
@@ -123,8 +124,10 @@ std::set<std::string> read_pokemon_name(
123124
);
124125
// result.log(logger);
125126
if (result.results.empty()){
127+
ocr_watchdog.push_result(false);
126128
return {};
127129
}
130+
ocr_watchdog.push_result(true);
128131

129132
// Convert OCR slugs to MaxLair name slugs.
130133
std::set<std::string> ret;
@@ -301,6 +304,7 @@ std::string read_pokemon_sprite_with_item(
301304

302305
std::string read_pokemon_name_sprite(
303306
Logger& logger,
307+
OcrFailureWatchdog& ocr_watchdog,
304308
const ImageViewRGB32& screen,
305309
const ImageFloatBox& sprite_box,
306310
const ImageFloatBox& name_box, Language language,
@@ -313,7 +317,7 @@ std::string read_pokemon_name_sprite(
313317
ImageViewRGB32 image = extract_box_reference(screen, name_box);
314318

315319
std::set<std::string> ocr_slugs;
316-
for (const std::string& slug : read_pokemon_name(logger, language, image)){
320+
for (const std::string& slug : read_pokemon_name(logger, language, ocr_watchdog, image)){
317321
// Only include candidates that are valid rental Pokemon.
318322
auto iter = RENTALS.find(slug);
319323
if (iter != RENTALS.end()){

0 commit comments

Comments
 (0)