11 changes: 11 additions & 0 deletions docs/user_guide.md
@@ -55,6 +55,17 @@ To overwrite default `vllm serve` arguments, you can specify the arguments in a
vec-inf launch Meta-Llama-3.1-8B-Instruct --vllm-args '--max-model-len=65536,--compilation-config=3'
```

To download models directly from the HuggingFace Hub without needing local weights, use `--hf-model`:

```bash
vec-inf launch Qwen2.5-3B-Instruct \
--hf-model Qwen/Qwen2.5-3B-Instruct \
--env 'HF_HOME=/path/to/cache' \
--vllm-args '--max-model-len=4096'
```

Set `HF_HOME` via `--env` to control where models are cached. If local weights exist, they take priority over `--hf-model`.
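The precedence can be pictured with a few lines of Python (a minimal sketch for illustration; `resolve_model_source` is a hypothetical helper, not part of the vec-inf API):

```python
from pathlib import Path
from typing import Optional


def resolve_model_source(
    model_name: str, weights_parent_dir: str, hf_model: Optional[str]
) -> str:
    """Pick what `vllm serve` receives: local weights > --hf-model > model name."""
    local_weights = Path(weights_parent_dir, model_name)
    if local_weights.exists():
        return str(local_weights)  # local weights always win
    if hf_model:
        return hf_model  # full HF model id, downloaded at launch time
    return model_name  # fall back to the short model name
```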

The full list of `vllm serve` arguments can be found [here](https://docs.vllm.ai/en/stable/serving/engine_args.html); make sure you select the correct vLLM version.

#### Custom models
91 changes: 91 additions & 0 deletions tests/vec_inf/client/test_slurm_script_generator.py
@@ -12,6 +12,14 @@
)


@pytest.fixture(autouse=True)
def patch_model_weights_exists(monkeypatch):
"""Ensure model weights directory existence checks default to True."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists", lambda self: True
)


class TestSlurmScriptGenerator:
"""Tests for SlurmScriptGenerator class."""

@@ -168,6 +176,21 @@ def test_generate_server_setup_singularity(self, singularity_params):
"module load " in setup
)  # module name omitted since it differs between clusters

def test_generate_server_setup_singularity_no_weights(
self, singularity_params, monkeypatch
):
"""Test server setup when model weights don't exist."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists",
lambda self: False,
)

generator = SlurmScriptGenerator(singularity_params)
setup = generator._generate_server_setup()

assert "ray stop" in setup
assert "/path/to/model_weights/test-model" not in setup

def test_generate_launch_cmd_venv(self, basic_params):
"""Test launch command generation with virtual environment."""
generator = SlurmScriptGenerator(basic_params)
@@ -179,6 +202,21 @@ def test_generate_launch_cmd_venv(self, basic_params):
assert "--max-model-len 8192" in launch_cmd
assert "--enforce-eager" in launch_cmd

def test_generate_launch_cmd_with_hf_model_override(
self, basic_params, monkeypatch
):
"""Test launch command uses hf_model when local weights don't exist."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists", lambda self: False
)
params = basic_params.copy()
params["hf_model"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"
generator = SlurmScriptGenerator(params)
launch_cmd = generator._generate_launch_cmd()

assert "vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct" in launch_cmd
assert "vllm serve /path/to/model_weights/test-model" not in launch_cmd

def test_generate_launch_cmd_singularity(self, singularity_params):
"""Test launch command generation with Singularity."""
generator = SlurmScriptGenerator(singularity_params)
@@ -187,6 +225,22 @@ def test_generate_launch_cmd_singularity(self, singularity_params):
assert "apptainer exec --nv" in launch_cmd
assert "source" not in launch_cmd

def test_generate_launch_cmd_singularity_no_local_weights(
self, singularity_params, monkeypatch
):
"""Test container launch when model weights directory is missing."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists",
lambda self: False,
)

generator = SlurmScriptGenerator(singularity_params)
launch_cmd = generator._generate_launch_cmd()

assert "exec --nv" in launch_cmd
assert "--bind /path/to/model_weights/test-model" not in launch_cmd
assert "vllm serve test-model" in launch_cmd

def test_generate_launch_cmd_boolean_args(self, basic_params):
"""Test launch command with boolean vLLM arguments."""
params = basic_params.copy()
@@ -377,6 +431,25 @@ def test_generate_model_launch_script_basic(
mock_touch.assert_called_once()
mock_write_text.assert_called_once()

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_with_hf_model_override(
self, mock_write_text, mock_touch, batch_params, monkeypatch
):
"""Test batch launch script uses hf_model when local weights don't exist."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists", lambda self: False
)
params = batch_params.copy()
params["models"] = {k: v.copy() for k, v in batch_params["models"].items()}
params["models"]["model1"]["hf_model"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"

generator = BatchSlurmScriptGenerator(params)
generator._generate_model_launch_script("model1")

call_args = mock_write_text.call_args[0][0]
assert "vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct" in call_args

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_singularity(
@@ -391,6 +464,24 @@ def test_generate_model_launch_script_singularity(
mock_touch.assert_called_once()
mock_write_text.assert_called_once()

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_singularity_no_weights(
self, mock_write_text, mock_touch, batch_singularity_params, monkeypatch
):
"""Test batch model launch script when model weights don't exist."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists",
lambda self: False,
)

generator = BatchSlurmScriptGenerator(batch_singularity_params)
script_path = generator._generate_model_launch_script("model1")

assert script_path.name == "launch_model1.sh"
call_args = mock_write_text.call_args[0][0]
assert "/path/to/model_weights/model1" not in call_args

@patch("vec_inf.client._slurm_script_generator.datetime")
@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
11 changes: 11 additions & 0 deletions vec_inf/cli/_cli.py
@@ -132,6 +132,15 @@ def cli() -> None:
type=str,
help="Path to parent directory containing model weights",
)
@click.option(
"--hf-model",
type=str,
help=(
"Full HuggingFace model id/path to use for vLLM serve (e.g. "
"'meta-llama/Meta-Llama-3.1-8B-Instruct'). "
"Keeps model-name as the short identifier for config/logs/job naming."
),
)
@click.option(
"--vllm-args",
type=str,
@@ -200,6 +209,8 @@ def launch(
Path to SLURM log directory
- model_weights_parent_dir : str, optional
Path to model weights directory
- hf_model : str, optional
Full HuggingFace model id/path to use for vLLM serve
- vllm_args : str, optional
vLLM engine arguments
- env : str, optional
4 changes: 4 additions & 0 deletions vec_inf/client/_helper.py
@@ -204,6 +204,10 @@ def _apply_cli_overrides(self, params: dict[str, Any]) -> None:
params : dict[str, Any]
Dictionary of launch parameters to override
"""
if self.kwargs.get("hf_model"):
params["hf_model"] = self.kwargs["hf_model"]
del self.kwargs["hf_model"]

if self.kwargs.get("vllm_args"):
vllm_args = self._process_vllm_args(self.kwargs["vllm_args"])
for key, value in vllm_args.items():
86 changes: 69 additions & 17 deletions vec_inf/client/_slurm_script_generator.py
@@ -15,6 +15,7 @@
SLURM_SCRIPT_TEMPLATE,
)
from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME
from vec_inf.client._utils import check_and_warn_hf_cache


class SlurmScriptGenerator:
@@ -37,8 +38,22 @@ def __init__(self, params: dict[str, Any]):
self.additional_binds = (
f",{self.params['bind']}" if self.params.get("bind") else ""
)
self.model_weights_path = str(
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
model_weights_path = Path(
self.params["model_weights_parent_dir"], self.params["model_name"]
)
self.model_weights_exists = model_weights_path.exists()
self.model_weights_path = str(model_weights_path)
# Determine model source: local weights > hf_model > model name
if self.model_weights_exists:
self.model_source = self.model_weights_path
elif self.params.get("hf_model"):
self.model_source = self.params["hf_model"]
else:
self.model_source = self.params["model_name"]
check_and_warn_hf_cache(
self.model_weights_exists,
self.model_weights_path,
self.params.get("env", {}),
)
self.env_str = self._generate_env_str()
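
The `check_and_warn_hf_cache` helper imported above is not shown in this diff; judging from its call sites, a plausible shape is the following sketch (assumed behavior inferred from this PR, not the actual implementation in `vec_inf/client/_utils.py`):

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def check_and_warn_hf_cache(
    model_weights_exists: bool,
    model_weights_path: str,
    env: dict[str, str],
    model_name: Optional[str] = None,
) -> None:
    """Warn when weights will be fetched from HuggingFace without HF_HOME set.

    Sketch inferred from the call sites in this PR; the real helper may differ.
    """
    if model_weights_exists:
        return  # local weights found; nothing will be downloaded
    if "HF_HOME" not in env:
        prefix = f"[{model_name}] " if model_name else ""
        logger.warning(
            "%sno local weights at %s and HF_HOME is not set; downloads "
            "will go to the default HuggingFace cache",
            prefix,
            model_weights_path,
        )
```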

@@ -111,7 +126,9 @@ def _generate_server_setup(self) -> str:
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
server_script.append(
SLURM_SCRIPT_TEMPLATE["bind_path"].format(
model_weights_path=self.model_weights_path,
model_weights_path=f",{self.model_weights_path}"
if self.model_weights_exists
else "",  # leading comma included here since the template no longer provides one
additional_binds=self.additional_binds,
)
)
@@ -131,7 +148,6 @@ def _generate_server_setup(self) -> str:
server_setup_str = server_setup_str.replace(
"CONTAINER_PLACEHOLDER",
SLURM_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=self.model_weights_path,
env_str=self.env_str,
),
)
@@ -165,22 +181,27 @@ def _generate_launch_cmd(self) -> str:
Server launch command.
"""
launcher_script = ["\n"]

vllm_args_copy = self.params["vllm_args"].copy()
> **Contributor:** Not sure if this is necessary; the model name should be parsed with the launch command rather than passed as part of `--vllm-args`.
>
> **Contributor (author):** Currently, `model_name` is the short name used for config lookup, log directories, and job naming (e.g., `llama-3`). Allowing `--model` in `vllm_args` lets users specify the full HF path when downloading from HuggingFace. I'm open to alternative approaches if you have a preference, such as:
> 1. A dedicated CLI option (e.g., `--hf-model`): keeps `model_name` as the short identifier and adds an explicit option for the full HF path.
> 2. Reusing the existing `model_name`: allow full HF paths directly, but adjust config lookups, log directory structure, etc. to handle paths containing `/`.
>
> **Contributor:** Ah, thanks for the clarification. I think having a dedicated CLI option keeps things clean and means minimal changes. However, the code base has changed pretty significantly, and there are quite a few conflicts to resolve before merging; if you give me access to your branch, I can help resolve them.
>
> **Contributor (author):** I have updated the PR to address your suggestion. Since my fork now lives under an org, I didn't have a simple way to grant access to the branch, so I have also invited you to the repository; please feel free to make changes.

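# An explicit --model passed via vLLM args overrides the resolved model source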
model_source = self.model_source
if "--model" in vllm_args_copy:
model_source = vllm_args_copy.pop("--model")

if self.use_container:
launcher_script.append(
SLURM_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=self.model_weights_path,
env_str=self.env_str,
)
)

launcher_script.append(
"\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
model_weights_path=self.model_weights_path,
model_source=model_source,
model_name=self.params["model_name"],
)
)

for arg, value in self.params["vllm_args"].items():
for arg, value in vllm_args_copy.items():
if isinstance(value, bool):
launcher_script.append(f" {arg} \\")
else:
@@ -225,11 +246,34 @@ def __init__(self, params: dict[str, Any]):
if self.params["models"][model_name].get("bind")
else ""
)
self.params["models"][model_name]["model_weights_path"] = str(
Path(
self.params["models"][model_name]["model_weights_parent_dir"],
model_name,
model_weights_path = Path(
self.params["models"][model_name]["model_weights_parent_dir"],
model_name,
)
model_weights_exists = model_weights_path.exists()
model_weights_path_str = str(model_weights_path)
self.params["models"][model_name]["model_weights_path"] = (
model_weights_path_str
)
self.params["models"][model_name]["model_weights_exists"] = (
model_weights_exists
)
# Determine model source: local weights > hf_model > model name
if model_weights_exists:
self.params["models"][model_name]["model_source"] = (
model_weights_path_str
)
elif self.params["models"][model_name].get("hf_model"):
self.params["models"][model_name]["model_source"] = self.params[
"models"
][model_name]["hf_model"]
else:
self.params["models"][model_name]["model_source"] = model_name
check_and_warn_hf_cache(
model_weights_exists,
model_weights_path_str,
self.params["models"][model_name].get("env", {}),
model_name,
)

def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path:
@@ -266,7 +310,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
model_weights_path=model_params["model_weights_path"],
model_weights_path=model_params["model_weights_path"]
if model_params.get("model_weights_exists", True)
else "",
additional_binds=model_params["additional_binds"],
)
)
@@ -283,19 +329,25 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
model_name=model_name,
)
)
vllm_args_copy = model_params["vllm_args"].copy()
model_source = model_params.get(
"model_source", model_params["model_weights_path"]
)
if "--model" in vllm_args_copy:
model_source = vllm_args_copy.pop("--model")

if self.use_container:
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=model_params["model_weights_path"],
)
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"]
)
script_content.append(
"\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format(
model_weights_path=model_params["model_weights_path"],
model_source=model_source,
model_name=model_name,
)
)
for arg, value in model_params["vllm_args"].items():

for arg, value in vllm_args_copy.items():
if isinstance(value, bool):
script_content.append(f" {arg} \\")
else:
6 changes: 3 additions & 3 deletions vec_inf/client/_slurm_templates.py
@@ -98,7 +98,7 @@ class SlurmScriptTemplate(TypedDict):
f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
],
"imports": "source {src_dir}/find_port.sh",
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}",
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp{{model_weights_path}}{{additional_binds}}",
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {IMAGE_PATH} \\",
"activate_venv": "source {venv}/bin/activate",
"server_setup": {
@@ -164,7 +164,7 @@ class SlurmScriptTemplate(TypedDict):
' && mv temp.json "$json_path"',
],
"launch_cmd": [
"vllm serve {model_weights_path} \\",
"vllm serve {model_source} \\",
" --served-model-name {model_name} \\",
' --host "0.0.0.0" \\',
" --port $vllm_port_number \\",
@@ -255,7 +255,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
],
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\",
"launch_cmd": [
"vllm serve {model_weights_path} \\",
"vllm serve {model_source} \\",
" --served-model-name {model_name} \\",
' --host "0.0.0.0" \\',
" --port $vllm_port_number \\",
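
To make the two template changes concrete, here is a minimal rendering sketch (abbreviated template copies, assuming `CONTAINER_MODULE_NAME` is `apptainer`; the real templates also emit host, port, and logging lines):

```python
# Abbreviated copies of the two updated template entries, for illustration only.
BIND_PATH = (
    "export APPTAINER_BINDPATH=$APPTAINER_BINDPATH"
    ",/dev,/tmp{model_weights_path}{additional_binds}"
)
LAUNCH_CMD = "vllm serve {model_source} \\\n    --served-model-name {model_name} \\"

# Local weights present: the weights dir is bound (with its leading comma)
# and served directly from disk.
print(BIND_PATH.format(model_weights_path=",/weights/test-model", additional_binds=""))
print(LAUNCH_CMD.format(model_source="/weights/test-model", model_name="test-model"))

# No local weights: the bind segment collapses to nothing and the HF model id
# is served instead, so vLLM downloads it at launch.
print(BIND_PATH.format(model_weights_path="", additional_binds=""))
print(
    LAUNCH_CMD.format(
        model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
        model_name="test-model",
    )
)
```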