@@ -211,6 +211,7 @@ def __init__(self, model_cfg: ModelConfig):
211211 self .vllm_params = model_cfg .vllm_params
212212 self .gpu_memory_utilization = model_cfg .vllm_params .get ("gpu_memory_utilization" , 0.9 )
213213 self .save_local = model_cfg .get ("save_local" , False )
214+ self .save_path = model_cfg .get ("save_path" , "./" )
214215
215216 # 优化后的细粒度锁机制
216217 self .instances_lock = asyncio .Lock () # 保护service_instances字典
@@ -798,7 +799,7 @@ async def generate(self, messages: List[Dict[str, Any]], **kwargs) -> str:
798799 task_id = kwargs .get ("task_id" )
799800 trace_id = kwargs .get ("trace_id" )
800801 step = kwargs .get ("step" )
801- save_dir = os .path .join (f"./ { task_id } _trace-{ trace_id } " )
802+ save_dir = os .path .join (self . save_path , f" { task_id } _trace-{ trace_id } " )
802803 os .makedirs (save_dir , exist_ok = True )
803804 save_path = os .path .join (save_dir , f"image_{ int (step ) - 1 } .png" )
804805 with open (save_path , "wb" ) as f :
@@ -847,7 +848,7 @@ async def generate(self, messages: List[Dict[str, Any]], **kwargs) -> str:
847848 task_id = kwargs .get ("task_id" )
848849 trace_id = kwargs .get ("trace_id" )
849850 step = kwargs .get ("step" )
850- save_dir = os .path .join (f"./ { task_id } _trace-{ trace_id } " )
851+ save_dir = os .path .join (self . save_path , f" { task_id } _trace-{ trace_id } " )
851852 os .makedirs (save_dir , exist_ok = True )
852853 save_path = os .path .join (save_dir , f"data_for_step_{ int (step )} .pt" )
853854
@@ -878,7 +879,7 @@ async def save(self, messages: List[Dict], reward: float, task_id: str, trace_id
878879 return {"status" : "skipped" }
879880
880881 try :
881- save_dir = os .path .join (os . getcwd () , f"{ task_id } _trace-{ trace_id } " )
882+ save_dir = os .path .join (self . save_path , f"{ task_id } _trace-{ trace_id } " )
882883 os .makedirs (save_dir , exist_ok = True )
883884
884885 # 保存 messages
@@ -1122,7 +1123,8 @@ def main(cfg: DictConfig):
11221123 base_port = cfg .model .base_port ,
11231124 replicas = cfg .model .replicas ,
11241125 vllm_params = OmegaConf .to_container (cfg .model .vllm_params ),
1125- save_local = cfg .model .save_local
1126+ save_local = cfg .model .save_local ,
1127+ save_path = cfg .storage .root
11261128 )
11271129
11281130 # 定义并创建 lifespan 函数的闭包
0 commit comments