
[Feature] Support for Pytorch/Unsloth InferenceEngine #756

@zhenweiwang1990

Description


Checklist

  • This feature will maintain backward compatibility with the current APIs in
    areal/api/. If not, please raise a refactor issue first.

Background

"I'd like to support LoRA VLM training; vLLM and SGLang cannot support vision LoRA."
I need to use PyTorch or Unsloth to generate my rollouts.

Potential Solution

Wrap Unsloth/PyTorch/Transformers generation behind an OpenAI-compatible API, similar to how vLLM is exposed.
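As a rough illustration of the wrapping idea, here is a minimal sketch that packages raw generated text into an OpenAI `/v1/chat/completions`-shaped response dict. The function name and its parameters are hypothetical, not part of AReaL's current API:

```python
import time
import uuid


def to_openai_chat_completion(
    model_name: str,
    generated_text: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> dict:
    """Wrap raw generated text in an OpenAI chat-completion response dict."""
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": generated_text},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
```

With a shim like this, the local Transformers/Unsloth engine could return responses that downstream rollout code already written against vLLM's OpenAI-style output would accept unchanged.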

Additional Information

    def _init_rollout(
        self, rollout_config: InferenceEngineConfig, is_eval: bool = False
    ) -> InferenceEngine | RolloutController:
        # Create a working copy of config
        config = deepcopy(rollout_config)
        if is_eval:
            # NOTE: eval does not have any offpolicyness control
            config.max_head_offpolicyness = int(1e12)

        # Determine engine class and server args based on backend
        if self.allocation_mode.gen_backend == "sglang":
            engine_cls = RemoteSGLangEngine
            server_args = SGLangConfig.build_args(
                sglang_config=self.config.sglang,
                tp_size=self.allocation_mode.gen.tp_size,
                base_gpu_id=0,
            )
        elif self.allocation_mode.gen_backend == "vllm":
            engine_cls = RemotevLLMEngine
            server_args = vLLMConfig.build_args(
                vllm_config=self.config.vllm,
                tp_size=self.allocation_mode.gen.tp_size,
                pp_size=self.allocation_mode.gen.pp_size,
            )
        else:
            raise ValueError(
                f"Invalid backend: {self.allocation_mode.gen_backend}, expected sglang or vllm"
            )

Something like:

    def _generate_response(
        self, 
        conversation: List[Dict], 
        verbose: bool
    ) -> Tuple[Dict[str, Any], str, int, int]:
        """Generate a response from the model with OpenAI-format tool calling.
        
        Uses transformers' native chat template with tools support.
        
        Args:
            conversation: Conversation history
            verbose: Whether to print logs
            
        Returns:
            Tuple of (response_message_dict, raw_content, input_tokens, output_tokens)
            - response_message_dict: Contains 'content' and/or 'tool_calls'
            - raw_content: Raw generated text for debugging
            - input_tokens: Number of input tokens
            - output_tokens: Number of output tokens
        """
        # Check if tokenizer supports tool calling via chat template
        # Format conversation with tools using chat template
        # Many modern tokenizers (Llama 3, Qwen, etc.) support tools parameter
        text = self.tokenizer.apply_chat_template(
            conversation,
            tools=self.tools,
            tokenize=False,
            add_generation_prompt=True,
        )
        
        # Generate
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        input_tokens = inputs.input_ids.shape[1]
        
        temperature = 0.7
        repetition_penalty = 1.0
        
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=self.policy_config.max_tokens,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        
        output_tokens = outputs.shape[1] - input_tokens
        
        # Decode response
        response = self.tokenizer.decode(
            outputs[0][input_tokens:],
            skip_special_tokens=True,
        )
        
        # Parse tool calls from response if present
        response_message, parsed_successfully = self._parse_tool_calls_from_response(response)
 
        return response_message, response, input_tokens, output_tokens
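The `_parse_tool_calls_from_response` helper referenced above is not shown; what it must parse depends on the model's chat template. As one hedged example, a parser for the Qwen-style `<tool_call>{...}</tool_call>` convention could look like this (the function and the generated call IDs are illustrative assumptions):

```python
import json
import re
import uuid

# Matches Qwen-style tool-call blocks emitted by the model.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)


def parse_tool_calls_from_response(text: str):
    """Convert <tool_call>{...}</tool_call> blocks into an OpenAI-format
    assistant message. Returns (message_dict, parsed_successfully)."""
    tool_calls = []
    for match in TOOL_CALL_RE.finditer(text):
        try:
            call = json.loads(match.group(1))
        except json.JSONDecodeError:
            continue  # malformed JSON: skip this block
        tool_calls.append({
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": call.get("name", ""),
                "arguments": json.dumps(call.get("arguments", {})),
            },
        })
    if tool_calls:
        # Keep any surrounding free text as content, or None if empty.
        content = TOOL_CALL_RE.sub("", text).strip() or None
        return {"role": "assistant", "content": content,
                "tool_calls": tool_calls}, True
    return {"role": "assistant", "content": text}, False
```

Other templates (Llama 3's JSON function calls, Hermes-style tags) would need their own regex or parser, so this piece likely has to be pluggable per model family.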
    
