2 changes: 2 additions & 0 deletions modelopt/torch/opt/plugins/megatron.py
@@ -99,6 +99,8 @@ def _modelopt_set_extra_state(self, state: Any):
return

if isinstance(state, torch.Tensor):
+        if state.numel() == 0:
+            return
# Default format: byte tensor with pickled data
#
# TODO: possible deserialization improvement
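Not part of the diff: a minimal standalone sketch of why the new numel() guard is needed, assuming the default extra-state format described in the comment above (a byte tensor holding pickled data). The function name here is hypothetical.

import pickle

import torch


def set_extra_state_sketch(state):
    """Mimic the guarded deserialization path for a byte tensor of pickled data."""
    if isinstance(state, torch.Tensor):
        if state.numel() == 0:
            # An empty byte tensor carries no serialized payload (e.g. a checkpoint
            # saved without modelopt state), so skip restoration instead of unpickling.
            return None
        return pickle.loads(state.cpu().numpy().tobytes())
    return state


# Without the numel() check, pickle.loads(b"") would raise EOFError here.
assert set_extra_state_sketch(torch.empty(0, dtype=torch.uint8)) is None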
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/model_quant.py
@@ -518,8 +518,8 @@ def print_quant_summary(model: nn.Module):
print(f"{count} TensorQuantizers found in model")


-def fold_weight(model: nn.Module):
+def fold_weight(model: nn.Module, keep_attrs: bool = False):
"""Fold weight quantizer for fast evaluation."""
for name, module in model.named_modules():
if isinstance(module, QuantModule):
-            module.fold_weight()
+            module.fold_weight(keep_attrs)
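Not part of the diff: a hedged usage sketch of the new keep_attrs flag. The fold_weight import follows this file; mtq.quantize and mtq.INT8_DEFAULT_CFG are assumptions about the public modelopt quantization API, and the toy model and calibration data are illustrative.

import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.model_quant import fold_weight

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
calib_data = [torch.randn(4, 16) for _ in range(8)]


def forward_loop(m):
    # Max calibration collects amax statistics from a few representative batches.
    for batch in calib_data:
        m(batch)


model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# Default behavior folds the quantized weights and deletes quantizer attributes
# such as amax and pre_quant_scale; keep_attrs=True now retains them for later
# inspection or export.
fold_weight(model, keep_attrs=True)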
19 changes: 10 additions & 9 deletions modelopt/torch/quantization/nn/modules/quant_linear.py
@@ -162,9 +162,9 @@ def forward(self, input, *args, **kwargs):
output = super().forward(input, *args, **kwargs)
return output

-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
"""Fold the weight for faster eval."""
-        super().fold_weight()
+        super().fold_weight(keep_attrs)
if (
hasattr(self, "weight_quantizer")
and hasattr(self, "weight")
@@ -179,13 +179,14 @@ def fold_weight(self):
self.weight
+ self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
)
-            _attrs = [
-                "_svdquant_lora_a",
-                "_svdquant_lora_b",
-            ]
-            for attr in _attrs:
-                if hasattr(self.weight_quantizer, attr):
-                    delattr(self.weight_quantizer, attr)
+            if not keep_attrs:
+                _attrs = [
+                    "_svdquant_lora_a",
+                    "_svdquant_lora_b",
+                ]
+                for attr in _attrs:
+                    if hasattr(self.weight_quantizer, attr):
+                        delattr(self.weight_quantizer, attr)


class RealQuantLinear(QuantModule):
17 changes: 9 additions & 8 deletions modelopt/torch/quantization/nn/modules/quant_module.py
@@ -68,7 +68,7 @@ def modelopt_post_restore(self, prefix: str = ""):
if isinstance(module, TensorQuantizer):
module.to(non_tq_param_or_buffer.device)

-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
"""Fold the weight for faster eval."""
# Handle all attributes that end with _weight_quantizer
for name in dir(self):
@@ -87,13 +87,14 @@ def fold_weight(self):
weight = getattr(self, weight_name)
weight.data.copy_(attr(weight.float()).to(weight.dtype))
attr.disable()
-                _attrs = [
-                    "_pre_quant_scale",
-                    "_amax",
-                ]
-                for attr_name in _attrs:
-                    if hasattr(attr, attr_name):
-                        delattr(attr, attr_name)
+                if not keep_attrs:
+                    _attrs = [
+                        "_pre_quant_scale",
+                        "_amax",
+                    ]
+                    for attr_name in _attrs:
+                        if hasattr(attr, attr_name):
+                            delattr(attr, attr_name)


QuantModuleRegistry = _DMRegistryCls("Quant", QuantModule)
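Not part of the diff: a self-contained sketch of the control flow added above, using a hypothetical DummyQuantizer in place of TensorQuantizer. It shows that folding still disables the quantizer, while keep_attrs=True preserves its calibration attributes.

class DummyQuantizer:
    """Stand-in carrying only the attributes that fold_weight touches."""

    def __init__(self):
        self._amax = 4.0
        self._pre_quant_scale = 0.5
        self.enabled = True

    def disable(self):
        self.enabled = False


def fold_sketch(quantizer, keep_attrs: bool = False):
    quantizer.disable()
    if not keep_attrs:
        for attr_name in ("_pre_quant_scale", "_amax"):
            if hasattr(quantizer, attr_name):
                delattr(quantizer, attr_name)


q = DummyQuantizer()
fold_sketch(q, keep_attrs=True)
assert not q.enabled and q._amax == 4.0  # disabled, but amax survives folding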
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -351,9 +351,9 @@ class HFRowParallelLinear(HFParallelLinear):
class _QuantHFParallelLinear(_ParallelLinear):
_functionals_to_replace = [(torch.nn.functional, "linear")]

-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
with self.enable_weight_access_and_writeback():
-            super().fold_weight()
+            super().fold_weight(keep_attrs)

@contextmanager
def enable_weight_access_and_writeback(self):
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/plugins/vllm.py
@@ -228,7 +228,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
)

@torch.no_grad()
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
# the MoE weights can be super large, it consumes too much memory, so we need to fold the weight one by one
for i in range(self.w13_weight.shape[0]):
self.w13_weight[i].copy_(
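Not part of the diff: a self-contained sketch of the expert-by-expert folding pattern that the comment above describes, with fake_quant as a hypothetical stand-in for the weight quantizer's quantize-dequantize forward. Copying in place per expert keeps only one folded expert weight materialized at a time.

import torch

num_experts, out_features, in_features = 8, 32, 16
w13_weight = torch.randn(num_experts, out_features, in_features)


def fake_quant(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical per-tensor int8 fake quantization (quantize then dequantize).
    scale = w.abs().amax() / 127.0
    return (w / scale).round().clamp_(-128, 127) * scale


with torch.no_grad():
    for i in range(w13_weight.shape[0]):
        w13_weight[i].copy_(fake_quant(w13_weight[i]))  # in-place, one expert at a time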
51 changes: 49 additions & 2 deletions modelopt/torch/utils/dataset_utils.py
@@ -16,6 +16,7 @@
"""Utility functions for getting samples and forward loop function for different datasets."""

import copy
+import json
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from warnings import warn
@@ -102,16 +103,61 @@
]


+def _get_jsonl_text_samples(jsonl_path: str, num_samples: int) -> list[str]:
+    """Load up to ``num_samples`` entries from a JSONL file using the ``text`` field.
+
+    Each non-empty line must be a JSON object containing a ``text`` field.
+    """
+    if num_samples <= 0:
+        return []
+
+    samples: list[str] = []
+
+    with open(jsonl_path, encoding="utf-8") as f:
+        for line_idx, line in enumerate(f, start=1):
+            if len(samples) >= num_samples:
+                break
+            line = line.strip()
+            if not line:
+                continue
+
+            try:
+                obj = json.loads(line)
+            except json.JSONDecodeError as e:
+                raise ValueError(
+                    f"Invalid JSON in JSONL file {jsonl_path} at line {line_idx}: {e}"
+                ) from e
+
+            if not isinstance(obj, dict):
+                raise ValueError(
+                    f"Expected a JSON object in JSONL file {jsonl_path} at line {line_idx}, "
+                    f"got {type(obj)}."
+                )
+
+            if "text" not in obj:
+                raise ValueError(
+                    f"Missing required field 'text' in JSONL file {jsonl_path} at line {line_idx}."
+                )
+
+            samples.append(str(obj["text"]))
+
+    return samples
+
+
def _get_dataset_samples(dataset_name: str, num_samples: int) -> list[str]:
"""Load a portion of train dataset with the dataset name and a given size.

Args:
-        dataset_name: Name of the dataset to load.
+        dataset_name: Name of the dataset to load, or a path to a ``.jsonl``/``.jsonl.gz`` file.
num_samples: Number of samples to load from the dataset.

Returns:
Samples: The list of samples.
"""
+    # Local JSONL file path support (each line is a JSON object with a `text` field).
+    if dataset_name.endswith((".jsonl", ".jsonl.gz")):
+        return _get_jsonl_text_samples(dataset_name, num_samples)
+
# Load the dataset
if dataset_name not in SUPPORTED_DATASET_CONFIG:
raise NotImplementedError(
@@ -179,7 +225,8 @@ def get_dataset_dataloader(
"""Get a dataloader with the dataset name and toknizer of the target model.

Args:
-        dataset_name: Name of the dataset to load.
+        dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file.
+            If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field.
        tokenizer: Instance of Hugging Face tokenizer.
batch_size: Batch size of the returned dataloader.
num_samples: Number of samples from the dataset.
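Not part of the diff: a hedged usage sketch of calibrating from a local JSONL file. The file layout follows the new helper above (one JSON object per line with a "text" field); the get_dataset_dataloader keyword names come from the docstring in this diff, while the tokenizer setup is an assumption.

import json

from transformers import AutoTokenizer

from modelopt.torch.utils.dataset_utils import get_dataset_dataloader

# Write a tiny calibration file in the expected JSONL layout.
with open("calib.jsonl", "w", encoding="utf-8") as f:
    for text in ["Hello world.", "Quantization calibration sample."]:
        f.write(json.dumps({"text": text}) + "\n")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # assumption: the loader needs a pad token

dataloader = get_dataset_dataloader(
    dataset_name="calib.jsonl",  # a path ending in .jsonl triggers the new code path
    tokenizer=tokenizer,
    batch_size=2,
    num_samples=2,
)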
4 changes: 3 additions & 1 deletion modelopt/torch/utils/plugins/megatron_generate.py
@@ -44,6 +44,7 @@ def megatron_prefill(
pixel_values: torch.FloatTensor | None = None,
image_grid_thw: torch.LongTensor | None = None,
image_sizes: torch.LongTensor | None = None,
+    skip_return_logits: bool = False,
) -> torch.Tensor:
"""A simple prefill function for Megatron Core V(LM) models."""
if not isinstance(model, MegatronModule):
@@ -110,6 +111,8 @@ def _forward_step_func(data, model):
forward_only=True,
collect_non_loss_data=True,
)
+    if skip_return_logits:
+        return None

if mpu.is_pipeline_last_stage():
logits = list_of_logits[0][:, :seq_length, :].detach()
@@ -122,7 +125,6 @@ def _forward_step_func(data, model):
logits_dtype = torch.float16
else:
logits_dtype = torch.float32
-
logits = broadcast_from_last_pipeline_stage(
[max_batch_size, seq_length, model.vocab_size], logits_dtype, logits
)
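Not part of the diff: a self-contained sketch (plain PyTorch, not Megatron) of the control flow that skip_return_logits adds. The forward pass still runs, but the large [batch, seq, vocab] logits tensor is never gathered or returned when the caller does not need it.

import torch


def prefill_sketch(model, tokens: torch.Tensor, skip_return_logits: bool = False):
    logits = model(tokens)  # the forward pass always runs (e.g. for calibration hooks)
    if skip_return_logits:
        return None  # skip materializing/broadcasting the logits
    return logits.detach()


model = torch.nn.Linear(8, 100)  # stand-in "LM head"
tokens = torch.randn(2, 4, 8)
assert prefill_sketch(model, tokens, skip_return_logits=True) is None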