Skip to content

Commit e6fa170

Browse files
committed
Init Upload
1 parent 2d2a31f commit e6fa170

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

rollouter/model_service.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def __init__(self, model_cfg: ModelConfig):
211211
self.vllm_params = model_cfg.vllm_params
212212
self.gpu_memory_utilization = model_cfg.vllm_params.get("gpu_memory_utilization", 0.9)
213213
self.save_local = model_cfg.get("save_local", False)
214+
self.save_path = model_cfg.get("save_path", "./")
214215

215216
# 优化后的细粒度锁机制
216217
self.instances_lock = asyncio.Lock() # 保护service_instances字典
@@ -798,7 +799,7 @@ async def generate(self, messages: List[Dict[str, Any]], **kwargs) -> str:
798799
task_id = kwargs.get("task_id")
799800
trace_id = kwargs.get("trace_id")
800801
step = kwargs.get("step")
801-
save_dir = os.path.join(f"./{task_id}_trace-{trace_id}")
802+
save_dir = os.path.join(self.save_path, f"{task_id}_trace-{trace_id}")
802803
os.makedirs(save_dir, exist_ok=True)
803804
save_path = os.path.join(save_dir, f"image_{int(step) - 1}.png")
804805
with open(save_path, "wb") as f:
@@ -847,7 +848,7 @@ async def generate(self, messages: List[Dict[str, Any]], **kwargs) -> str:
847848
task_id = kwargs.get("task_id")
848849
trace_id = kwargs.get("trace_id")
849850
step = kwargs.get("step")
850-
save_dir = os.path.join(f"./{task_id}_trace-{trace_id}")
851+
save_dir = os.path.join(self.save_path, f"{task_id}_trace-{trace_id}")
851852
os.makedirs(save_dir, exist_ok=True)
852853
save_path = os.path.join(save_dir, f"data_for_step_{int(step)}.pt")
853854

@@ -878,7 +879,7 @@ async def save(self, messages: List[Dict], reward: float, task_id: str, trace_id
878879
return {"status": "skipped"}
879880

880881
try:
881-
save_dir = os.path.join(os.getcwd(), f"{task_id}_trace-{trace_id}")
882+
save_dir = os.path.join(self.save_path, f"{task_id}_trace-{trace_id}")
882883
os.makedirs(save_dir, exist_ok=True)
883884

884885
# 保存 messages
@@ -1122,7 +1123,8 @@ def main(cfg: DictConfig):
11221123
base_port=cfg.model.base_port,
11231124
replicas=cfg.model.replicas,
11241125
vllm_params=OmegaConf.to_container(cfg.model.vllm_params),
1125-
save_local=cfg.model.save_local
1126+
save_local=cfg.model.save_local,
1127+
save_path=cfg.storage.root
11261128
)
11271129

11281130
# 定义并创建 lifespan 函数的闭包

0 commit comments

Comments
 (0)