20 changes: 20 additions & 0 deletions README.md
@@ -98,6 +98,26 @@ cd fms-model-optimizer
pip install -e .
```

#### Optional Dependencies
The following optional dependencies are available:
- `fp8`: `llmcompressor` package for FP8 quantization
- `gptq`: `GPTQModel` package for W4A16 quantization
- `mx`: `microxcaling` package for MX quantization
- `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
- `torchvision`: `torchvision` package for image recognition training and inference
- `visualize`: Dependencies for visualizing models and performance data
- `test`: Dependencies needed for unit testing
- `dev`: Dependencies needed for development

To install optional dependencies, append a bracketed, comma-separated list of these names to the `pip install` commands above. The example below installs `llmcompressor` and `torchvision` alongside FMS Model Optimizer:

```shell
pip install fms-model-optimizer[fp8,torchvision]

pip install -e .[fp8,torchvision]
```
If you have already installed FMS Model Optimizer, then only the optional packages will be installed.
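
The `opt` shortcut bundles the quantization extras in one step. A sketch, assuming the extras listed above (quote the brackets so shells like zsh don't glob them):

```shell
pip install -e ".[opt]"
```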

### Try It Out!

To help you get up and running as quickly as possible with the FMS Model Optimizer framework, check out the following resources which demonstrate how to use the framework with different quantization techniques:
57 changes: 35 additions & 22 deletions fms_mo/fx/dynamo_utils.py
@@ -29,6 +29,7 @@
get_target_op_from_mod_or_str,
get_target_op_from_node,
)
from fms_mo.utils.import_utils import available_packages

logger = logging.getLogger(__name__)

@@ -1133,7 +1134,6 @@ def cus_backend_model_analyzer(
from functools import partial

# Third Party
from torchvision.models import VisionTransformer
from transformers import PreTrainedModel

if issubclass(type(model), torch.nn.Module):
@@ -1145,7 +1145,16 @@
model_to_be_traced = model
model_param_size = 999

is_transformers = issubclass(type(model), (PreTrainedModel, VisionTransformer))
transformer_model_classes = (PreTrainedModel,)

if available_packages["torchvision"]:
# Third Party
# pylint: disable = import-error
from torchvision.models import VisionTransformer

transformer_model_classes += (VisionTransformer,)

is_transformers = issubclass(type(model), transformer_model_classes)
if model_param_size > 1:
# Standard
import sys
@@ -1188,11 +1197,13 @@ def call_seq_hook(mod, *_args, **_kwargs):

# only add last layer
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][-1]]
# unless it's a ViT, skip first Conv as well
if issubclass(type(model), VisionTransformer) and isinstance(
model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
):
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]

if available_packages["torchvision"]:
# unless it's a ViT, skip first Conv as well
if issubclass(type(model), VisionTransformer) and isinstance(
model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
):
qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]

with torch.no_grad():
model_opt = torch.compile(
@@ -1271,21 +1282,23 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
# c) identify RPN/FPN
# TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()

# Third Party
from torchvision.models.detection.rpn import RegionProposalNetwork
from torchvision.ops import FeaturePyramidNetwork

rpnfpn_prefix = []
rpnfpn_convs = []
for n, m in model.named_modules():
if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)):
rpnfpn_prefix.append(n)
if isinstance(m, torch.nn.Conv2d) and any(
n.startswith(p) for p in rpnfpn_prefix
):
rpnfpn_convs.append(n)
if n not in qcfg["qskip_layer_name"]:
qcfg["qskip_layer_name"].append(n)
if available_packages["torchvision"]:
# Third Party
# pylint: disable = import-error
from torchvision.models.detection.rpn import RegionProposalNetwork
from torchvision.ops import FeaturePyramidNetwork

rpnfpn_prefix = []
rpnfpn_convs = []
for n, m in model.named_modules():
if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)):
rpnfpn_prefix.append(n)
if isinstance(m, torch.nn.Conv2d) and any(
n.startswith(p) for p in rpnfpn_prefix
):
rpnfpn_convs.append(n)
if n not in qcfg["qskip_layer_name"]:
qcfg["qskip_layer_name"].append(n)

if qcfg["N_backend_called"] > 1:
logger.warning(
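
The recurring change in this file is a guarded optional import: consult `available_packages` before touching `torchvision`, and widen the tuple of classes used in `issubclass`/`isinstance` checks only when the import succeeds. A minimal self-contained sketch of the pattern (the `is_vision_transformer` helper is illustrative, not part of the PR):

```python
# Sketch of the guarded-import pattern applied in dynamo_utils.py.
# `available_packages` mirrors fms_mo.utils.import_utils; rebuilt here
# so the snippet stands alone.
import importlib.util

available_packages = {
    "torchvision": importlib.util.find_spec("torchvision") is not None
}

# Start from the classes that are always importable.
transformer_model_classes: tuple = ()

if available_packages["torchvision"]:
    # Imported only when torchvision is installed, so environments
    # without the optional dependency never hit an ImportError.
    # pylint: disable = import-error
    from torchvision.models import VisionTransformer

    transformer_model_classes += (VisionTransformer,)


def is_vision_transformer(model) -> bool:
    """True if `model` is one of the optionally registered classes."""
    # isinstance against an empty tuple is simply False, so the check
    # degrades gracefully when torchvision is absent.
    return isinstance(model, transformer_model_classes)
```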
1 change: 1 addition & 0 deletions fms_mo/utils/import_utils.py
@@ -31,6 +31,7 @@
"pygraphviz",
"fms",
"triton",
"torchvision",
]

available_packages = {}
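
For context, `available_packages` is a module-level dict keyed by package name. The loop that fills it is outside this hunk, but it presumably probes each candidate with `importlib.util.find_spec`, along these lines (a hypothetical reconstruction, showing only the visible tail of the candidate list):

```python
# Hypothetical reconstruction — the actual loop in
# fms_mo/utils/import_utils.py is not shown in this diff.
import importlib.util

candidate_packages = ["pygraphviz", "fms", "triton", "torchvision"]

available_packages = {
    # find_spec returns None when a package cannot be imported,
    # so each value is a simple availability flag.
    name: importlib.util.find_spec(name) is not None
    for name in candidate_packages
}
```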
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -32,7 +32,6 @@ dependencies = [
"ninja>=1.11.1.1,<2.0",
"tensorboard",
"notebook",
"torchvision>=0.17",
"evaluate",
"huggingface_hub",
"pandas",
@@ -42,13 +41,15 @@
]

[project.optional-dependencies]
dev = ["pre-commit>=3.0.4,<5.0"]
fp8 = ["llmcompressor"]
gptq = ["Cython", "gptqmodel>=1.7.3"]
mx = ["microxcaling>=1.1"]
visualize = ["matplotlib", "graphviz", "pygraphviz"]
opt = ["fms-model-optimizer[fp8, gptq, mx]"]
torchvision = ["torchvision>=0.17"]
flash-attn = ["flash-attn>=2.5.3,<3.0"]
opt = ["fms-model-optimizer[fp8, gptq]"]
visualize = ["matplotlib", "graphviz", "pygraphviz"]
dev = ["pre-commit>=3.0.4,<5.0"]
test = ["pytest", "pillow"]

[project.urls]
homepage = "https://github.com/foundation-model-stack/fms-model-optimizer"
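
One detail worth noting in the table above: `opt` is self-referential, so pip resolves it by re-reading this project's own extras and installing each extra the `opt` line names. A quick way to preview that expansion without installing anything:

```shell
# pip re-reads fms-model-optimizer's own extras table to resolve `opt`;
# --dry-run (pip >= 22.2) previews the resolution without installing.
pip install --dry-run -e ".[opt]"
```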
148 changes: 115 additions & 33 deletions tests/models/conftest.py
@@ -22,10 +22,11 @@
import os

# Third Party
from PIL import Image # pylint: disable=import-error
from torch.utils.data import DataLoader, TensorDataset
from torchvision.io import read_image
from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
BertConfig,
BertModel,
BertTokenizer,
@@ -43,6 +44,7 @@
# fms_mo imports
from fms_mo import qconfig_init
from fms_mo.modules import QLSTM, QBmm, QConv2d, QConvTranspose2d, QLinear
from fms_mo.utils.import_utils import available_packages
from fms_mo.utils.qconfig_utils import get_mx_specs_defaults, set_mx_specs

########################
@@ -1123,75 +1125,155 @@ def required_pair(request):
# Vision Model Fixtures #
#########################

# Create img
# downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory)
img = read_image(

if available_packages["torchvision"]:
# Third Party
# pylint: disable = import-error
from torchvision.io import read_image
from torchvision.models import (
ResNet50_Weights,
ViT_B_16_Weights,
resnet50,
vit_b_16,
)

# Create img
# downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory)
img_tv = read_image(
os.path.realpath(
os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
)
)

# Create resnet/vitbatch fixtures from weights
def prepocess_img(image, weights):
"""
    Preprocess an image w/ a weights.transforms()

Args:
        image (torch.FloatTensor): Image data
weights (torchvision.models): Weight object

Returns:
torch.FloatTensor: Preprocessed image
"""
preprocess = weights.transforms()
batch = preprocess(image).unsqueeze(0)
return batch

@pytest.fixture(scope="session")
def batch_resnet():
"""
    Preprocess an image w/ Resnet weights.transforms()

Returns:
torch.FloatTensor: Preprocessed image
"""
return prepocess_img(img_tv, ResNet50_Weights.IMAGENET1K_V2)

@pytest.fixture(scope="session")
def batch_vit():
"""
    Preprocess an image w/ ViT weights.transforms()

Returns:
torch.FloatTensor: Preprocessed image
"""
return prepocess_img(img_tv, ViT_B_16_Weights.IMAGENET1K_V1)

# Create resnet/vit model fixtures from weights
@pytest.fixture(scope="function")
def model_resnet():
"""
Create Resnet50 model + weights

Returns:
torchvision.models.resnet.ResNet: Resnet50 model
"""
return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

@pytest.fixture(scope="function")
def model_vit():
"""
Create ViT model + weights

Returns:
torchvision.models.vision_transformer.VisionTransformer: ViT model
"""
return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)


img = Image.open(
os.path.realpath(
os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
)
)
).convert("RGB")


# Create resnet/vit batch fixtures from weights
def prepocess_img(image, weights):
def process_img(
pretrained_model: str,
input_img: Image.Image,
):
"""
Preprocess an image w/ a weights.transform()
Process an image w/ AutoImageProcessor

Args:
img (torch.FloatTensor): Image data
weights (torchvision.models): Weight object
        pretrained_model (str): Hugging Face model id used to load the image processor
        input_img (Image.Image): Image data

Returns:
torch.FloatTensor: Preprocessed image
torch.FloatTensor: Processed image
"""
preprocess = weights.transforms()
batch = preprocess(image).unsqueeze(0)
return batch
img_processor = AutoImageProcessor.from_pretrained(pretrained_model, use_fast=True)
batch_dict = img_processor(images=input_img, return_tensors="pt")
return batch_dict["pixel_values"]


@pytest.fixture(scope="session")
def batch_resnet():
@pytest.fixture(scope="function")
def batch_resnet18():
"""
Preprocess an image w/ Resnet weights.transform()
    Preprocess an image w/ MS ResNet18 processor

Returns:
torch.FloatTensor: Preprocessed image
"""
return prepocess_img(img, ResNet50_Weights.IMAGENET1K_V2)
return process_img("microsoft/resnet-18", img)


@pytest.fixture(scope="session")
def batch_vit():
@pytest.fixture(scope="function")
def model_resnet18():
"""
Preprocess an image w/ ViT weights.transform()
Create MS ResNet18 model + weights

Returns:
torch.FloatTensor: Preprocessed image
AutoModelForImageClassification: Resnet18 model
"""
return prepocess_img(img, ViT_B_16_Weights.IMAGENET1K_V1)
return AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")


# Create resnet/vit model fixtures from weights
@pytest.fixture(scope="function")
def model_resnet():
def batch_vit_base():
"""
Create Resnet50 model + weights
Preprocess an image w/ Google ViT-base processor

Returns:
torchvision.models.resnet.ResNet: Resnet50 model
torch.FloatTensor: Preprocessed image
"""
return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
return process_img("google/vit-base-patch16-224", img)


@pytest.fixture(scope="function")
def model_vit():
def model_vit_base():
"""
Create ViT model + weights
Create Google ViT-base model + weights

Returns:
torchvision.models.vision_transformer.VisionTransformer: ViT model
AutoModelForImageClassification: Google ViT-base model
"""
return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
return AutoModelForImageClassification.from_pretrained(
"google/vit-base-patch16-224"
)
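

Taken together, the new fixtures are meant to compose in tests: each `batch_*` fixture returns a `pixel_values` tensor of shape `[1, 3, H, W]`, and each `model_*` fixture returns an `AutoModelForImageClassification` whose forward pass yields an output object with `logits`. A hypothetical test (not part of this PR) sketching that use:

```python
# Hypothetical usage sketch — this test is illustrative, not in the PR.
# Standard
import torch


def test_resnet18_forward(model_resnet18, batch_resnet18):
    """Run one preprocessed image through the HF ResNet18 fixture."""
    model_resnet18.eval()
    with torch.no_grad():
        outputs = model_resnet18(batch_resnet18)
    # AutoModelForImageClassification returns logits of shape
    # [batch, num_labels]; a single image gives batch size 1.
    assert outputs.logits.shape[0] == 1
```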


#######################