Skip to content

Commit b505541

Browse files
Type check config values
1 parent c8315e4 commit b505541

File tree

1 file changed

+125
-88
lines changed

1 file changed

+125
-88
lines changed

rlbot/config.py

Lines changed: 125 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,58 @@ class ConfigParsingException(Exception):
1111
pass
1212

1313

14-
def __parse_enum(table: dict, key: str, enum: Any, default: int = 0) -> Any:
14+
def __enum(table: dict, key: str, enum: Any, default: int = 0) -> Any:
1515
if key not in table:
1616
return enum(default)
1717
try:
1818
for i in range(100000):
1919
if str(enum(i)).split('.')[-1].lower() == table[key].lower():
2020
return enum(i)
2121
except ValueError:
22-
raise ConfigParsingException(f"Invalid value \"{table[key]}\" for key \"{key}\".")
22+
raise ConfigParsingException(f"Invalid value {repr(table[key])} for key '{key}'.")
23+
24+
25+
def __str(table: dict, key: str, default: str = "") -> str:
26+
v = table.get(key, default)
27+
if isinstance(v, str):
28+
return v
29+
raise ConfigParsingException(f"'{key}' has value {repr(v)}. Expected a string.")
30+
31+
32+
def __bool(table: dict, key: str, default: bool = False) -> bool:
33+
v = table.get(key, default)
34+
if isinstance(v, bool):
35+
return v
36+
raise ConfigParsingException(f"'{key}' has value {repr(v)}. Expected a bool.")
37+
38+
39+
def __int(table: dict, key: str, default: int = 0) -> int:
40+
v = table.get(key, default)
41+
if isinstance(v, int):
42+
return v
43+
raise ConfigParsingException(f"'{key}' has value {repr(v)}. Expected an int.")
44+
45+
46+
def __table(table: dict, key: str) -> dict:
47+
v = table.get(key, dict())
48+
if isinstance(v, dict):
49+
return v
50+
raise ConfigParsingException(f"'{key}' has value {repr(v)}. Expected a table.")
51+
52+
53+
def __team(table: dict) -> int:
54+
if 'team' not in table:
55+
return 0
56+
v = table['team']
57+
if isinstance(v, str):
58+
if v.lower() == "blue":
59+
return 0
60+
if v.lower() == "orange":
61+
return 1
62+
if isinstance(v, int):
63+
if 0 <= v <= 1:
64+
return v
65+
raise ConfigParsingException(f"'team' has value {repr(v)}. Expected a 0, 1, \"blue\", or \"orange\".")
2366

2467

2568
def load_match_config(config_path: Path | str) -> flat.MatchConfiguration:
@@ -30,25 +73,18 @@ def load_match_config(config_path: Path | str) -> flat.MatchConfiguration:
3073
with open(config_path, "rb") as f:
3174
config = tomllib.load(f)
3275

33-
rlbot_table = config.get("rlbot", dict())
34-
match_table = config.get("match", dict())
35-
mutator_table = config.get("mutators", dict())
76+
rlbot_table = __table(config, "rlbot")
77+
match_table = __table(config, "match")
78+
mutator_table = __table(config, "mutators")
3679

3780
players = []
3881
for car_table in config.get("cars", []):
39-
car_config = car_table.get("config")
40-
name = car_table.get("name", "")
41-
team = car_table.get("team", 0)
42-
try:
43-
team = int(team)
44-
except ValueError:
45-
team = {"blue": 0, "orange": 1}.get(team.lower())
46-
if team is None or team not in [0, 1]:
47-
raise ConfigParsingException(f"Invalid team \"{car_table.get("team")}\" for player {len(players)}.")
48-
49-
loadout_file = car_table.get("loadout_file")
50-
skill = __parse_enum(car_table, "skill", flat.PsyonixSkill, int(flat.PsyonixSkill.AllStar))
51-
variant = car_table.get("type", "rlbot").lower()
82+
car_config = __str(car_table, "config")
83+
name = __str(car_table, "name")
84+
team = __team(car_table)
85+
loadout_file = __str(car_table, "loadout_file") or None
86+
skill = __enum(car_table, "skill", flat.PsyonixSkill, int(flat.PsyonixSkill.AllStar))
87+
variant = __str(car_table, "type", "rlbot").lower()
5288

5389
match variant:
5490
case "rlbot":
@@ -61,9 +97,9 @@ def load_match_config(config_path: Path | str) -> flat.MatchConfiguration:
6197
logger.warning("PartyMember player type is not supported yet.")
6298
variety, use_config = flat.PartyMember, False
6399
case t:
64-
raise ConfigParsingException(f"Invalid player type \"{t}\" for player {len(players)}.")
100+
raise ConfigParsingException(f"Invalid player type {repr(t)} for player {len(players)}.")
65101

66-
if use_config and car_config is not None:
102+
if use_config and car_config:
67103
abs_config_path = (config_path.parent / car_config).resolve()
68104
players.append(load_player_config(abs_config_path, variety, team, name, loadout_file))
69105
else:
@@ -72,50 +108,50 @@ def load_match_config(config_path: Path | str) -> flat.MatchConfiguration:
72108

73109
scripts = []
74110
for script_table in config.get("scripts", []):
75-
if script_config := script_table.get("config"):
111+
if script_config := __str(script_table, "config"):
76112
abs_config_path = (config_path.parent / script_config).resolve()
77113
scripts.append(load_script_config(abs_config_path))
78114
else:
79115
scripts.append(flat.ScriptConfiguration())
80116

81117
mutators = flat.MutatorSettings(
82-
match_length=__parse_enum(mutator_table, "match_length", flat.MatchLengthMutator),
83-
max_score=__parse_enum(mutator_table, "max_score", flat.MaxScoreMutator),
84-
multi_ball=__parse_enum(mutator_table, "multi_ball", flat.MultiBallMutator),
85-
overtime=__parse_enum(mutator_table, "overtime", flat.OvertimeMutator),
86-
series_length=__parse_enum(mutator_table, "series_length", flat.SeriesLengthMutator),
87-
game_speed=__parse_enum(mutator_table, "game_speed", flat.GameSpeedMutator),
88-
ball_max_speed=__parse_enum(mutator_table, "ball_max_speed", flat.BallMaxSpeedMutator),
89-
ball_type=__parse_enum(mutator_table, "ball_type", flat.BallTypeMutator),
90-
ball_weight=__parse_enum(mutator_table, "ball_weight", flat.BallWeightMutator),
91-
ball_size=__parse_enum(mutator_table, "ball_size", flat.BallSizeMutator),
92-
ball_bounciness=__parse_enum(mutator_table, "ball_bounciness", flat.BallBouncinessMutator),
93-
boost=__parse_enum(mutator_table, "boost_amount", flat.BoostMutator),
94-
rumble=__parse_enum(mutator_table, "rumble", flat.RumbleMutator),
95-
boost_strength=__parse_enum(mutator_table, "boost_strength", flat.BoostStrengthMutator),
96-
gravity=__parse_enum(mutator_table, "gravity", flat.GravityMutator),
97-
demolish=__parse_enum(mutator_table, "demolish", flat.DemolishMutator),
98-
respawn_time=__parse_enum(mutator_table, "respawn_time", flat.RespawnTimeMutator),
99-
max_time=__parse_enum(mutator_table, "max_time", flat.MaxTimeMutator),
100-
game_event=__parse_enum(mutator_table, "game_event", flat.GameEventMutator),
101-
audio=__parse_enum(mutator_table, "audio", flat.AudioMutator),
118+
match_length=__enum(mutator_table, "match_length", flat.MatchLengthMutator),
119+
max_score=__enum(mutator_table, "max_score", flat.MaxScoreMutator),
120+
multi_ball=__enum(mutator_table, "multi_ball", flat.MultiBallMutator),
121+
overtime=__enum(mutator_table, "overtime", flat.OvertimeMutator),
122+
series_length=__enum(mutator_table, "series_length", flat.SeriesLengthMutator),
123+
game_speed=__enum(mutator_table, "game_speed", flat.GameSpeedMutator),
124+
ball_max_speed=__enum(mutator_table, "ball_max_speed", flat.BallMaxSpeedMutator),
125+
ball_type=__enum(mutator_table, "ball_type", flat.BallTypeMutator),
126+
ball_weight=__enum(mutator_table, "ball_weight", flat.BallWeightMutator),
127+
ball_size=__enum(mutator_table, "ball_size", flat.BallSizeMutator),
128+
ball_bounciness=__enum(mutator_table, "ball_bounciness", flat.BallBouncinessMutator),
129+
boost=__enum(mutator_table, "boost_amount", flat.BoostMutator),
130+
rumble=__enum(mutator_table, "rumble", flat.RumbleMutator),
131+
boost_strength=__enum(mutator_table, "boost_strength", flat.BoostStrengthMutator),
132+
gravity=__enum(mutator_table, "gravity", flat.GravityMutator),
133+
demolish=__enum(mutator_table, "demolish", flat.DemolishMutator),
134+
respawn_time=__enum(mutator_table, "respawn_time", flat.RespawnTimeMutator),
135+
max_time=__enum(mutator_table, "max_time", flat.MaxTimeMutator),
136+
game_event=__enum(mutator_table, "game_event", flat.GameEventMutator),
137+
audio=__enum(mutator_table, "audio", flat.AudioMutator),
102138
)
103139

104140
return flat.MatchConfiguration(
105-
launcher=__parse_enum(rlbot_table, "launcher", flat.Launcher),
106-
launcher_arg=rlbot_table.get("launcher_arg", ""),
107-
auto_start_bots=rlbot_table.get("auto_start_bots", True),
108-
game_map_upk=match_table.get("game_map_upk", ""),
141+
launcher=__enum(rlbot_table, "launcher", flat.Launcher),
142+
launcher_arg=__str(rlbot_table, "launcher_arg"),
143+
auto_start_bots=__bool(rlbot_table, "auto_start_bots", True),
144+
game_map_upk=__str(match_table, "game_map_upk"),
109145
player_configurations=players,
110146
script_configurations=scripts,
111-
game_mode=__parse_enum(match_table, "game_mode", flat.GameMode),
112-
skip_replays=match_table.get("skip_replays", False),
113-
instant_start=match_table.get("instant_start", False),
147+
game_mode=__enum(match_table, "game_mode", flat.GameMode),
148+
skip_replays=__bool(match_table, "skip_replays"),
149+
instant_start=__bool(match_table, "instant_start"),
114150
mutators=mutators,
115-
existing_match_behavior=__parse_enum(match_table, "existing_match_behavior", flat.ExistingMatchBehavior),
116-
enable_rendering=match_table.get("enable_rendering", False),
117-
enable_state_setting=match_table.get("enable_state_setting", False),
118-
freeplay=match_table.get("freeplay", False),
151+
existing_match_behavior=__enum(match_table, "existing_match_behavior", flat.ExistingMatchBehavior),
152+
enable_rendering=__bool(match_table, "enable_rendering"),
153+
enable_state_setting=__bool(match_table, "enable_state_setting"),
154+
freeplay=__bool(match_table, "freeplay"),
119155
)
120156

121157

@@ -126,34 +162,35 @@ def load_player_loadout(path: Path | str, team: int) -> flat.PlayerLoadout:
126162
with open(path, "rb") as f:
127163
config = tomllib.load(f)
128164

129-
loadout = config["blue_loadout"] if team == 0 else config["orange_loadout"]
165+
table_name = "blue_loadout" if team == 0 else "orange_loadout"
166+
loadout = __table(config, table_name)
130167
paint = None
131-
if paint_table := loadout.get("paint", None):
168+
if paint_table := __table(loadout, "paint"):
132169
paint = flat.LoadoutPaint(
133-
car_paint_id=paint_table.get("car_paint_id", 0),
134-
decal_paint_id=paint_table.get("decal_paint_id", 0),
135-
wheels_paint_id=paint_table.get("wheels_paint_id", 0),
136-
boost_paint_id=paint_table.get("boost_paint_id", 0),
137-
antenna_paint_id=paint_table.get("antenna_paint_id", 0),
138-
hat_paint_id=paint_table.get("hat_paint_id", 0),
139-
trails_paint_id=paint_table.get("trails_paint_id", 0),
140-
goal_explosion_paint_id=paint_table.get("goal_explosion_paint_id", 0),
170+
car_paint_id=__int(paint_table, "car_paint_id"),
171+
decal_paint_id=__int(paint_table, "decal_paint_id"),
172+
wheels_paint_id=__int(paint_table, "wheels_paint_id"),
173+
boost_paint_id=__int(paint_table, "boost_paint_id"),
174+
antenna_paint_id=__int(paint_table, "antenna_paint_id"),
175+
hat_paint_id=__int(paint_table, "hat_paint_id"),
176+
trails_paint_id=__int(paint_table, "trails_paint_id"),
177+
goal_explosion_paint_id=__int(paint_table, "goal_explosion_paint_id"),
141178
)
142179

143180
return flat.PlayerLoadout(
144-
team_color_id=loadout.get("team_color_id", 0),
145-
custom_color_id=loadout.get("custom_color_id", 0),
146-
car_id=loadout.get("car_id", 0),
147-
decal_id=loadout.get("decal_id", 0),
148-
wheels_id=loadout.get("wheels_id", 0),
149-
boost_id=loadout.get("boost_id", 0),
150-
antenna_id=loadout.get("antenna_id", 0),
151-
hat_id=loadout.get("hat_id", 0),
152-
paint_finish_id=loadout.get("paint_finish_id", 0),
153-
custom_finish_id=loadout.get("custom_finish_id", 0),
154-
engine_audio_id=loadout.get("engine_audio_id", 0),
155-
trails_id=loadout.get("trails_id", 0),
156-
goal_explosion_id=loadout.get("goal_explosion_id", 0),
181+
team_color_id=__int(loadout, "team_color_id"),
182+
custom_color_id=__int(loadout, "custom_color_id"),
183+
car_id=__int(loadout, "car_id"),
184+
decal_id=__int(loadout, "decal_id"),
185+
wheels_id=__int(loadout, "wheels_id"),
186+
boost_id=__int(loadout, "boost_id"),
187+
antenna_id=__int(loadout, "antenna_id"),
188+
hat_id=__int(loadout, "hat_id"),
189+
paint_finish_id=__int(loadout, "paint_finish_id"),
190+
custom_finish_id=__int(loadout, "custom_finish_id"),
191+
engine_audio_id=__int(loadout, "engine_audio_id"),
192+
trails_id=__int(loadout, "trails_id"),
193+
goal_explosion_id=__int(loadout, "goal_explosion_id"),
157194
loadout_paint=paint,
158195
)
159196

@@ -170,30 +207,30 @@ def load_player_config(
170207
with open(path, "rb") as f:
171208
config = tomllib.load(f)
172209

173-
settings: dict[str, Any] = config["settings"]
210+
settings = __table(config, "settings")
174211

175212
root_dir = path.parent.absolute()
176213
if "root_dir" in settings:
177-
root_dir /= Path(settings["root_dir"])
214+
root_dir /= Path(__str(settings, "root_dir"))
178215

179-
run_command = settings.get("run_command", "")
216+
run_command = __str(settings, "run_command")
180217
if CURRENT_OS == OS.LINUX and "run_command_linux" in settings:
181-
run_command = settings.get("run_command_linux", "")
218+
run_command = __str(settings, "run_command_linux")
182219

183-
loadout_path = path.parent / Path(settings["loadout_file"]) if "loadout_file" in settings else None
220+
loadout_path = path.parent / Path(__str(settings, "loadout_file")) if "loadout_file" in settings else None
184221
loadout_path = loadout_override or loadout_path
185222
loadout = load_player_loadout(loadout_path, team) if loadout_path is not None else None
186223

187224
return flat.PlayerConfiguration(
188225
type,
189-
name_override or settings.get("name", ""),
226+
name_override or __str(settings, "name"),
190227
team,
191228
str(root_dir),
192229
run_command,
193230
loadout,
194231
0,
195-
settings.get("agent_id", ""),
196-
settings.get("hivemind", False),
232+
__str(settings, "agent_id"),
233+
__bool(settings, "hivemind"),
197234
)
198235

199236

@@ -209,16 +246,16 @@ def load_script_config(path: Path | str) -> flat.ScriptConfiguration:
209246

210247
root_dir = path.parent
211248
if "root_dir" in settings:
212-
root_dir /= Path(settings["root_dir"])
249+
root_dir /= Path(__str(settings, "root_dir"))
213250

214-
run_command = settings.get("run_command", "")
251+
run_command = __str(settings, "run_command")
215252
if CURRENT_OS == OS.LINUX and "run_command_linux" in settings:
216-
run_command = settings["run_command_linux"]
253+
run_command = __str(settings, "run_command_linux")
217254

218255
return flat.ScriptConfiguration(
219-
settings.get("name", ""),
256+
__str(settings, "name"),
220257
str(root_dir),
221258
run_command,
222259
0,
223-
settings.get("agent_id", ""),
260+
__str(settings, "agent_id"),
224261
)

0 commit comments

Comments
 (0)