Checklist
- This feature will maintain backward compatibility with the current APIs in
areal/api/. If not, please raise a refactor issue first.
Background
"I'd like to support Lora VLM Training, vLLM and SGLang can not support vision Lora."
I need to use Pytorch or Unsloth to generate my rollout.
Potential Solution
Wrap the Unsloth/PyTorch/Transformers generation in something like the OpenAI API format, the way vLLM exposes it.
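For illustration, here is a minimal sketch of what such a wrapper could look like: a plain transformers model served behind an OpenAI-style /v1/chat/completions endpoint. Everything here is an assumption for the sake of the sketch, not existing AReaL code: the choice of FastAPI, the placeholder model name, and the request/response shapes (which just mimic the OpenAI chat-completions convention).

# Hypothetical sketch: expose a transformers model behind an
# OpenAI-style /v1/chat/completions endpoint. All names and the
# placeholder model are illustrative, not AReaL's actual API.
import time
import uuid

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype="auto")

app = FastAPI()

class ChatRequest(BaseModel):
    model: str
    messages: list[dict]
    max_tokens: int = 512
    temperature: float = 0.7

@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
    # Render the conversation with the model's chat template.
    text = tokenizer.apply_chat_template(
        req.messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=req.max_tokens,
            temperature=req.temperature,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    prompt_len = inputs.input_ids.shape[1]
    content = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    # Mimic the OpenAI chat-completions response shape.
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": content},
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_len,
            "completion_tokens": out.shape[1] - prompt_len,
            "total_tokens": out.shape[1],
        },
    }

With a wrapper like this, the trainer could talk to a PyTorch/Unsloth rollout the same way it already talks to a vLLM or SGLang server.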
Additional Information
def _init_rollout(
    self, rollout_config: InferenceEngineConfig, is_eval: bool = False
) -> InferenceEngine | RolloutController:
    # Create a working copy of config
    config = deepcopy(rollout_config)
    if is_eval:
        # NOTE: eval does not have any offpolicyness control
        config.max_head_offpolicyness = int(1e12)
    # Determine engine class and server args based on backend
    if self.allocation_mode.gen_backend == "sglang":
        engine_cls = RemoteSGLangEngine
        server_args = SGLangConfig.build_args(
            sglang_config=self.config.sglang,
            tp_size=self.allocation_mode.gen.tp_size,
            base_gpu_id=0,
        )
    elif self.allocation_mode.gen_backend == "vllm":
        engine_cls = RemotevLLMEngine
        server_args = vLLMConfig.build_args(
            vllm_config=self.config.vllm,
            tp_size=self.allocation_mode.gen.tp_size,
            pp_size=self.allocation_mode.gen.pp_size,
        )
    else:
        raise ValueError(
            f"Invalid backend: {self.allocation_mode.gen_backend}, expected sglang or vllm"
        )
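A Transformers-backed rollout could plug into this same dispatch as a third branch. A rough sketch, assuming a hypothetical RemoteTransformersEngine and TransformersConfig (neither exists in AReaL today):

    elif self.allocation_mode.gen_backend == "transformers":
        # Hypothetical branch: the engine class and its config are
        # assumptions for illustration, not existing AReaL code.
        engine_cls = RemoteTransformersEngine
        server_args = TransformersConfig.build_args(
            transformers_config=self.config.transformers,
            tp_size=self.allocation_mode.gen.tp_size,
        )

The generation call itself could then be wrapped behind the same interface. Something like: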
def _generate_response(
    self,
    conversation: List[Dict],
    verbose: bool,
) -> Tuple[Dict[str, Any], str, int, int]:
    """Generate a response from the model with OpenAI-format tool calling.

    Uses transformers' native chat template with tools support.

    Args:
        conversation: Conversation history
        verbose: Whether to print logs

    Returns:
        Tuple of (response_message_dict, raw_content, input_tokens, output_tokens)
        - response_message_dict: Contains 'content' and/or 'tool_calls'
        - raw_content: Raw generated text for debugging
        - input_tokens: Number of input tokens
        - output_tokens: Number of output tokens
    """
    # Format the conversation with tools using the chat template.
    # Many modern tokenizers (Llama 3, Qwen, etc.) support the tools parameter.
    text = self.tokenizer.apply_chat_template(
        conversation,
        tools=self.tools,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Generate
    inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
    input_tokens = inputs.input_ids.shape[1]
    temperature = 0.7
    repetition_penalty = 1.0
    outputs = self.model.generate(
        **inputs,
        max_new_tokens=self.policy_config.max_tokens,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=self.tokenizer.pad_token_id,
    )
    output_tokens = outputs.shape[1] - input_tokens
    # Decode only the newly generated tokens
    response = self.tokenizer.decode(
        outputs[0][input_tokens:],
        skip_special_tokens=True,
    )
    # Parse tool calls from the response if present
    response_message, _parsed_ok = self._parse_tool_calls_from_response(response)
    return response_message, response, input_tokens, output_tokens
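The issue references but does not show _parse_tool_calls_from_response. A minimal sketch of what such a method might do, assuming the model emits tool calls as JSON inside <tool_call>...</tool_call> tags (the Hermes/Qwen convention; other chat templates use different markup, so this parser is purely illustrative):

import json
import re
from typing import Any, Dict, Tuple

# Hypothetical helper for the class above; assumes Hermes/Qwen-style
# <tool_call>{...}</tool_call> markup in the generated text.
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def _parse_tool_calls_from_response(self, response: str) -> Tuple[Dict[str, Any], bool]:
    tool_calls = []
    for match in _TOOL_CALL_RE.finditer(response):
        try:
            call = json.loads(match.group(1))
        except json.JSONDecodeError:
            continue  # malformed JSON: fall through to plain content
        tool_calls.append({
            "type": "function",
            "function": {
                "name": call.get("name", ""),
                "arguments": json.dumps(call.get("arguments", {})),
            },
        })
    if tool_calls:
        # OpenAI-format assistant message carrying tool calls
        return {"role": "assistant", "content": None, "tool_calls": tool_calls}, True
    # No tool calls: return the raw text as ordinary content
    return {"role": "assistant", "content": response}, False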