Skip to content

Commit 077178c

Browse files
committed
feat: improve CLI for replay bot
1 parent 5c41d17 commit 077178c

File tree

1 file changed

+116
-56
lines changed

1 file changed

+116
-56
lines changed

bots/replay.py

Lines changed: 116 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import logging
66
import sys
7+
import tempfile
78
import time
89
from pathlib import Path
910

@@ -14,21 +15,21 @@
1415
logging.basicConfig(level=logging.INFO)
1516

1617

17-
def get_most_recent_jsonl() -> Path:
18-
"""Find the most recent JSONL file in the runs directory."""
19-
runs_dir = Path("runs")
20-
if not runs_dir.exists():
21-
logger.error("Runs directory not found")
22-
sys.exit(1)
18+
def format_function_call(function_name: str, arguments: dict) -> str:
19+
"""Format function call in Python syntax for dry run mode."""
20+
args_str = json.dumps(arguments, indent=None, separators=(",", ": "))
21+
return f"{function_name}({args_str})"
2322

24-
jsonl_files = list(runs_dir.glob("*.jsonl"))
25-
if not jsonl_files:
26-
logger.error("No JSONL files found in runs directory")
27-
sys.exit(1)
2823

29-
# Sort by modification time, most recent first
30-
most_recent = max(jsonl_files, key=lambda f: f.stat().st_mtime)
31-
return most_recent
24+
def determine_output_path(output_arg: Path | None, input_path: Path) -> Path:
25+
"""Determine the final output path based on input and output arguments."""
26+
if output_arg is None:
27+
return input_path
28+
29+
if output_arg.is_dir():
30+
return output_arg / input_path.name
31+
else:
32+
return output_arg
3233

3334

3435
def load_steps_from_jsonl(jsonl_path: Path) -> list[dict]:
@@ -49,62 +50,121 @@ def load_steps_from_jsonl(jsonl_path: Path) -> list[dict]:
4950

5051
def main():
5152
"""Main replay function."""
52-
parser = argparse.ArgumentParser(description="Replay actions from a JSONL run file")
53+
parser = argparse.ArgumentParser(
54+
description="Replay actions from a JSONL run file",
55+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
56+
)
57+
parser.add_argument(
58+
"--input",
59+
"-i",
60+
type=Path,
61+
required=True,
62+
help="Input JSONL file to replay",
63+
)
64+
parser.add_argument(
65+
"--output",
66+
"-o",
67+
type=Path,
68+
help="Output path for generated run log (directory or .jsonl file). "
69+
"If directory, uses original filename. If not specified, overwrites input.",
70+
)
71+
parser.add_argument(
72+
"--port",
73+
"-p",
74+
type=int,
75+
default=12346,
76+
help="Port to connect to BalatroBot API",
77+
)
5378
parser.add_argument(
5479
"--delay",
55-
"-d",
5680
type=float,
5781
default=0.0,
58-
help="Delay between played moves in seconds (default: 0.0)",
82+
help="Delay between played moves in seconds",
5983
)
6084
parser.add_argument(
61-
"--path",
62-
"-p",
63-
type=Path,
64-
help="Path to JSONL run file (default: most recent file in runs/)",
85+
"--dry",
86+
"-d",
87+
action="store_true",
88+
help="Dry run mode: print function calls without executing them",
6589
)
6690

6791
args = parser.parse_args()
6892

69-
# Determine the path to use
70-
if args.path:
71-
jsonl_path = args.path
72-
else:
73-
jsonl_path = get_most_recent_jsonl()
74-
logger.info(f"Using most recent file: {jsonl_path}")
75-
76-
steps = load_steps_from_jsonl(jsonl_path)
77-
78-
try:
79-
with BalatroClient() as client:
80-
logger.info("Connected to BalatroBot API")
81-
82-
# Replay each step
83-
for i, step in enumerate(steps):
84-
function_name = step["function"]["name"]
85-
arguments = step["function"]["arguments"]
86-
logger.info(f"Step {i + 1}/{len(steps)}: {function_name}({arguments})")
87-
time.sleep(args.delay)
88-
89-
try:
90-
response = client.send_message(function_name, arguments)
91-
logger.debug(f"Response: {response}")
92-
except BalatroError as e:
93-
logger.error(f"API error in step {i + 1}: {e}")
94-
sys.exit(1)
95-
96-
logger.info("Replay completed successfully!")
97-
98-
except ConnectionFailedError as e:
99-
logger.error(f"Failed to connect to BalatroBot API: {e}")
93+
if not args.input.exists():
94+
logger.error(f"Input file not found: {args.input}")
10095
sys.exit(1)
101-
except KeyboardInterrupt:
102-
logger.info("Replay interrupted by user")
103-
sys.exit(0)
104-
except Exception as e:
105-
logger.error(f"Unexpected error during replay: {e}")
96+
97+
if not args.input.suffix == ".jsonl":
98+
logger.error(f"Input file must be a .jsonl file: {args.input}")
10699
sys.exit(1)
107100

101+
steps = load_steps_from_jsonl(args.input)
102+
final_output_path = determine_output_path(args.output, args.input)
103+
if args.dry:
104+
logger.info(
105+
f"Dry run mode: printing {len(steps)} function calls from {args.input}"
106+
)
107+
for i, step in enumerate(steps):
108+
function_name = step["function"]["name"]
109+
arguments = step["function"]["arguments"]
110+
print(format_function_call(function_name, arguments))
111+
time.sleep(args.delay)
112+
logger.info("Dry run completed")
113+
return
114+
115+
with tempfile.TemporaryDirectory() as temp_dir:
116+
temp_output_path = Path(temp_dir) / final_output_path.name
117+
118+
try:
119+
with BalatroClient(port=args.port) as client:
120+
logger.info(f"Connected to BalatroBot API on port {args.port}")
121+
logger.info(f"Replaying {len(steps)} steps from {args.input}")
122+
if final_output_path != args.input:
123+
logger.info(f"Output will be saved to: {final_output_path}")
124+
125+
for i, step in enumerate(steps):
126+
function_name = step["function"]["name"]
127+
arguments = step["function"]["arguments"]
128+
129+
if function_name == "start_run":
130+
arguments = arguments.copy()
131+
arguments["log_path"] = str(temp_output_path)
132+
133+
logger.info(
134+
f"Step {i + 1}/{len(steps)}: {format_function_call(function_name, arguments)}"
135+
)
136+
time.sleep(args.delay)
137+
138+
try:
139+
response = client.send_message(function_name, arguments)
140+
logger.debug(f"Response: {response}")
141+
except BalatroError as e:
142+
logger.error(f"API error in step {i + 1}: {e}")
143+
sys.exit(1)
144+
145+
logger.info("Replay completed successfully!")
146+
147+
if temp_output_path.exists():
148+
final_output_path.parent.mkdir(parents=True, exist_ok=True)
149+
temp_output_path.rename(final_output_path)
150+
logger.info(f"Output saved to: {final_output_path}")
151+
elif final_output_path != args.input:
152+
logger.warning(
153+
f"No output file was generated at {temp_output_path}"
154+
)
155+
156+
except ConnectionFailedError as e:
157+
logger.error(
158+
f"Failed to connect to BalatroBot API on port {args.port}: {e}"
159+
)
160+
sys.exit(1)
161+
except KeyboardInterrupt:
162+
logger.info("Replay interrupted by user")
163+
sys.exit(0)
164+
except Exception as e:
165+
logger.error(f"Unexpected error during replay: {e}")
166+
sys.exit(1)
167+
108168

109169
if __name__ == "__main__":
110170
main()

0 commit comments

Comments
 (0)