Skip to content

Commit bc95fbc

Browse files
fix: add fms-recommender wrapper
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent facbce8 commit bc95fbc

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
"simpleeval>=0.9.13,<2.0",
4040
"pillow>=11.0.0,<12.0",
4141
"kernels<=0.9.0",
42+
"recommender @ git+https://github.com/YashasviChaurasia/tuning-config-recommender.git@fms-adapter-csv",
4243
]
4344

4445
[project.optional-dependencies]

tuning/fms-recommender.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#!/usr/bin/env python3
2+
# Standard
3+
from pathlib import Path
4+
import argparse
5+
import json
6+
import os
7+
import shlex
8+
import subprocess
9+
10+
# Third Party
11+
from recommender.adapters import FMSAdapter
12+
import yaml
13+
14+
ACCEL_NESTED_PREFIXES = {
15+
"fsdp_": "fsdp_config",
16+
}
17+
18+
DATA_KEYS = {
19+
"training_data_path",
20+
"validation_data_path",
21+
"dataset",
22+
}
23+
24+
25+
def grab_flags(tokens, start, end):
26+
cfg, i = {}, start
27+
while i < end:
28+
t = tokens[i]
29+
if t.startswith("--"):
30+
k, v = t[2:], True
31+
if "=" in t:
32+
k, v = k.split("=", 1)
33+
v = v.strip('"')
34+
elif i + 1 < end and not tokens[i + 1].startswith("--"):
35+
v = tokens[i + 1].strip('"')
36+
i += 1
37+
cfg[k] = v
38+
i += 1
39+
return cfg
40+
41+
42+
def load_yaml(path):
43+
if path and os.path.exists(path):
44+
try:
45+
with open(path, "r") as f:
46+
y = yaml.safe_load(f)
47+
return y if isinstance(y, dict) else {}
48+
except (OSError, yaml.YAMLError):
49+
return {}
50+
return {}
51+
52+
53+
def nest_accelerate_flags(flat_dist):
54+
nested = {section: {} for section in ACCEL_NESTED_PREFIXES.values()}
55+
remaining = {}
56+
57+
for k, v in flat_dist.items():
58+
matched = False
59+
for prefix, section in ACCEL_NESTED_PREFIXES.items():
60+
if k.startswith(prefix):
61+
nested[section][k] = v
62+
matched = True
63+
break
64+
if not matched:
65+
remaining[k] = v
66+
67+
for sec in list(nested.keys()):
68+
if not nested[sec]:
69+
nested.pop(sec)
70+
71+
return {**remaining, **nested}
72+
73+
74+
def parse(cmd: str):
75+
tokens = shlex.split(cmd)
76+
has_m = "-m" in tokens
77+
is_accel = "accelerate" in tokens and "launch" in tokens
78+
if is_accel and has_m:
79+
m = tokens.index("-m")
80+
dist_flat = grab_flags(tokens, 0, m)
81+
train = grab_flags(tokens, m + 2, len(tokens))
82+
83+
elif has_m:
84+
m = tokens.index("-m")
85+
dist_flat = {}
86+
train = grab_flags(tokens, m + 2, len(tokens))
87+
else:
88+
dist_flat = {}
89+
train = grab_flags(tokens, 0, len(tokens))
90+
91+
yaml_path = train.pop("data_config", None)
92+
if yaml_path:
93+
data = load_yaml(yaml_path)
94+
else:
95+
data = {}
96+
accel_yaml_path = dist_flat.pop("config_file", None)
97+
accel_yaml = load_yaml(accel_yaml_path) if accel_yaml_path else {}
98+
dist_nested = nest_accelerate_flags(dist_flat)
99+
dist = {**accel_yaml, **dist_nested}
100+
train.pop("config_file", None)
101+
102+
return train, dist, data
103+
104+
105+
def main():
106+
parser = argparse.ArgumentParser()
107+
parser.add_argument(
108+
"--debug",
109+
action="store_true",
110+
help="Print parsed configs and exit (no adapter, no execution).",
111+
)
112+
parser.add_argument(
113+
"--preview",
114+
action="store_true",
115+
help="Run adapter and show launch command but DO NOT execute it.",
116+
)
117+
parser.add_argument("command", nargs=argparse.REMAINDER)
118+
args = parser.parse_args()
119+
if not args.command:
120+
print("Error: No command provided.")
121+
return
122+
123+
cmd = " ".join(args.command)
124+
train_cfg, dist_cfg, data_cfg = parse(cmd)
125+
train_cfg.pop("config_file", None)
126+
dist_cfg.pop("config_file", None)
127+
128+
if args.debug:
129+
print("\n[dist_config]\n", json.dumps(dist_cfg, indent=2))
130+
print("\n[train_config]\n", json.dumps(train_cfg, indent=2))
131+
print("\n[data_config]\n", json.dumps(data_cfg, indent=2))
132+
return
133+
134+
adapter = FMSAdapter(base_dir=Path("fms_recommender_ouput/final"))
135+
ir, patches = adapter.execute(
136+
train_config=train_cfg,
137+
dist_config=dist_cfg,
138+
compute_config={},
139+
data_config=data_cfg,
140+
unique_tag="fms-recommender",
141+
)
142+
out = adapter._to_target(ir, patches, tag="fms-recommender")
143+
launch_cmd = out["launch_command"]
144+
145+
if args.preview:
146+
print("\n[LAUNCH COMMAND — PREVIEW ONLY]\n")
147+
print(launch_cmd)
148+
return
149+
150+
print("\n[EXECUTING launch command]\n")
151+
print(launch_cmd)
152+
subprocess.run(launch_cmd, shell=True, check=True)
153+
154+
155+
if __name__ == "__main__":
156+
main()

0 commit comments

Comments
 (0)