diff --git a/README.md b/README.md
index a27e36a9..aa422d2f 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,26 @@ cd fms-model-optimizer
pip install -e .
```
+#### Optional Dependencies
+The following optional dependencies are available:
+- `fp8`: `llmcompressor` package for fp8 quantization
+- `gptq`: `GPTQModel` package for W4A16 quantization
+- `mx`: `microxcaling` package for MX quantization
+- `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
+- `torchvision`: `torchvision` package for image recognition training and inference
+- `visualize`: Dependencies for visualizing models and performance data
+- `test`: Dependencies needed for unit testing
+- `dev`: Dependencies needed for development
+
+To install optional dependencies, append a comma-separated list of these names, enclosed in brackets, to the `pip install` commands above. The example below installs `llmcompressor` and `torchvision` alongside FMS Model Optimizer:
+
+```shell
+pip install fms-model-optimizer[fp8,torchvision]
+
+pip install -e .[fp8,torchvision]
+```
+If FMS Model Optimizer is already installed, only the missing optional dependencies will be installed.
+
### Try It Out!

To help you get up and running as quickly as possible with the FMS Model Optimizer framework, check out the following resources which demonstrate how to use the framework with different quantization techniques:
diff --git a/fms_mo/fx/dynamo_utils.py b/fms_mo/fx/dynamo_utils.py
index 71bc069b..62967ec2 100644
--- a/fms_mo/fx/dynamo_utils.py
+++ b/fms_mo/fx/dynamo_utils.py
@@ -29,6 +29,7 @@
    get_target_op_from_mod_or_str,
    get_target_op_from_node,
)
+from fms_mo.utils.import_utils import available_packages

logger = logging.getLogger(__name__)

@@ -1133,7 +1134,6 @@ def cus_backend_model_analyzer(
        from functools import partial

        # Third Party
-        from torchvision.models import VisionTransformer
        from transformers import PreTrainedModel

    if issubclass(type(model), torch.nn.Module):
@@ -1145,7 +1145,16 @@ def cus_backend_model_analyzer(
            model_to_be_traced = model
            model_param_size = 999

-    is_transformers = issubclass(type(model), (PreTrainedModel, VisionTransformer))
+    transformer_model_classes = (PreTrainedModel,)
+
+    if available_packages["torchvision"]:
+        # Third Party
+        # pylint: disable = import-error
+        from torchvision.models import VisionTransformer
+
+        transformer_model_classes += (VisionTransformer,)
+
+    is_transformers = issubclass(type(model), transformer_model_classes)

    if model_param_size > 1:
        # Standard
        import sys
@@ -1188,11 +1197,13 @@ def call_seq_hook(mod, *_args, **_kwargs):
            # only add last layer
            qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][-1]]
-            # unless it's a ViT, skip first Conv as well
-            if issubclass(type(model), VisionTransformer) and isinstance(
-                model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
-            ):
-                qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]
+
+            if available_packages["torchvision"]:
+                # if it's a ViT, skip the first Conv (patch embedding) as well
+                if issubclass(type(model), VisionTransformer) and isinstance(
+                    model.get_submodule(qcfg["mod_call_seq"][0]), torch.nn.Conv2d
+                ):
+                    qcfg["qskip_layer_name"] += [qcfg["mod_call_seq"][0]]

    with torch.no_grad():
        model_opt = torch.compile(
@@ -1271,21 +1282,23 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
        # c) identify RPN/FPN
        # TODO this hack only works for torchvision models. 
will use find_rpn_fpn_gm() - # Third Party - from torchvision.models.detection.rpn import RegionProposalNetwork - from torchvision.ops import FeaturePyramidNetwork - - rpnfpn_prefix = [] - rpnfpn_convs = [] - for n, m in model.named_modules(): - if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)): - rpnfpn_prefix.append(n) - if isinstance(m, torch.nn.Conv2d) and any( - n.startswith(p) for p in rpnfpn_prefix - ): - rpnfpn_convs.append(n) - if n not in qcfg["qskip_layer_name"]: - qcfg["qskip_layer_name"].append(n) + if available_packages["torchvision"]: + # Third Party + # pylint: disable = import-error + from torchvision.models.detection.rpn import RegionProposalNetwork + from torchvision.ops import FeaturePyramidNetwork + + rpnfpn_prefix = [] + rpnfpn_convs = [] + for n, m in model.named_modules(): + if isinstance(m, (FeaturePyramidNetwork, RegionProposalNetwork)): + rpnfpn_prefix.append(n) + if isinstance(m, torch.nn.Conv2d) and any( + n.startswith(p) for p in rpnfpn_prefix + ): + rpnfpn_convs.append(n) + if n not in qcfg["qskip_layer_name"]: + qcfg["qskip_layer_name"].append(n) if qcfg["N_backend_called"] > 1: logger.warning( diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index f4b1538d..51b113ee 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -31,6 +31,7 @@ "pygraphviz", "fms", "triton", + "torchvision", ] available_packages = {} diff --git a/pyproject.toml b/pyproject.toml index fe5f086e..11a0cc52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "ninja>=1.11.1.1,<2.0", "tensorboard", "notebook", -"torchvision>=0.17", "evaluate", "huggingface_hub", "pandas", @@ -42,13 +41,15 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["pre-commit>=3.0.4,<5.0"] fp8 = ["llmcompressor"] gptq = ["Cython", "gptqmodel>=1.7.3"] mx = ["microxcaling>=1.1"] -visualize = ["matplotlib", "graphviz", "pygraphviz"] +opt = ["fms-model-optimizer[fp8, gptq, mx]"] +torchvision = ["torchvision>=0.17"] flash-attn = ["flash-attn>=2.5.3,<3.0"] -opt = ["fms-model-optimizer[fp8, gptq]"] +visualize = ["matplotlib", "graphviz", "pygraphviz"] +dev = ["pre-commit>=3.0.4,<5.0"] +test = ["pytest", "pillow"] [project.urls] homepage = "https://github.com/foundation-model-stack/fms-model-optimizer" diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 9ddc7d0c..8dc70379 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -22,10 +22,11 @@ import os # Third Party +from PIL import Image # pylint: disable=import-error from torch.utils.data import DataLoader, TensorDataset -from torchvision.io import read_image -from torchvision.models import ResNet50_Weights, ViT_B_16_Weights, resnet50, vit_b_16 from transformers import ( + AutoImageProcessor, + AutoModelForImageClassification, BertConfig, BertModel, BertTokenizer, @@ -43,6 +44,7 @@ # fms_mo imports from fms_mo import qconfig_init from fms_mo.modules import QLSTM, QBmm, QConv2d, QConvTranspose2d, QLinear +from fms_mo.utils.import_utils import available_packages from fms_mo.utils.qconfig_utils import get_mx_specs_defaults, set_mx_specs ######################## @@ -1123,75 +1125,155 @@ def required_pair(request): # Vision Model Fixtures # ######################### -# Create img -# downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory) -img = read_image( + +if available_packages["torchvision"]: + # Third Party + # pylint: disable = import-error + from torchvision.io import read_image + from torchvision.models 
import (
+        ResNet50_Weights,
+        ViT_B_16_Weights,
+        resnet50,
+        vit_b_16,
+    )
+
+    # Create img
+    # downloaded from torchvision github (vision/test/assets/encoder_jpeg/ directory)
+    img_tv = read_image(
+        os.path.realpath(
+            os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
+        )
+    )
+
+    # Create resnet/vit batch fixtures from weights
+    def preprocess_img(image, weights):
+        """
+        Preprocess an image w/ a weights.transforms()
+
+        Args:
+            image (torch.Tensor): Image data
+            weights (torchvision.models.WeightsEnum): Pretrained weights object
+
+        Returns:
+            torch.FloatTensor: Preprocessed image
+        """
+        preprocess = weights.transforms()
+        batch = preprocess(image).unsqueeze(0)
+        return batch
+
+    @pytest.fixture(scope="session")
+    def batch_resnet():
+        """
+        Preprocess an image w/ Resnet weights.transforms()
+
+        Returns:
+            torch.FloatTensor: Preprocessed image
+        """
+        return preprocess_img(img_tv, ResNet50_Weights.IMAGENET1K_V2)
+
+    @pytest.fixture(scope="session")
+    def batch_vit():
+        """
+        Preprocess an image w/ ViT weights.transforms()
+
+        Returns:
+            torch.FloatTensor: Preprocessed image
+        """
+        return preprocess_img(img_tv, ViT_B_16_Weights.IMAGENET1K_V1)
+
+    # Create resnet/vit model fixtures from weights
+    @pytest.fixture(scope="function")
+    def model_resnet():
+        """
+        Create Resnet50 model + weights
+
+        Returns:
+            torchvision.models.resnet.ResNet: Resnet50 model
+        """
+        return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
+
+    @pytest.fixture(scope="function")
+    def model_vit():
+        """
+        Create ViT model + weights
+
+        Returns:
+            torchvision.models.vision_transformer.VisionTransformer: ViT model
+        """
+        return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
+
+
img = Image.open(
    os.path.realpath(
        os.path.join(os.path.dirname(__file__), "grace_hopper_517x606.jpg")
    )
-)
+).convert("RGB")


-# Create resnet/vit batch fixtures from weights
-def prepocess_img(image, weights):
+def process_img(
+    pretrained_model: str,
+    input_img: Image.Image,
+):
    """
-    Preprocess an image w/ a weights.transform()
+    Process an image w/ AutoImageProcessor

    Args:
-        img (torch.FloatTensor): Image data
-        weights (torchvision.models): Weight object
+        pretrained_model (str): Name or path of the pretrained model
+            whose AutoImageProcessor will be loaded
+        input_img (Image.Image): Image data

    Returns:
-        torch.FloatTensor: Preprocessed image
+        torch.FloatTensor: Processed image
    """
-    preprocess = weights.transforms()
-    batch = preprocess(image).unsqueeze(0)
-    return batch
+    img_processor = AutoImageProcessor.from_pretrained(pretrained_model, use_fast=True)
+    batch_dict = img_processor(images=input_img, return_tensors="pt")
+    return batch_dict["pixel_values"]


-@pytest.fixture(scope="session")
-def batch_resnet():
+@pytest.fixture(scope="function")
+def batch_resnet18():
    """
-    Preprocess an image w/ Resnet weights.transform()
+    Preprocess an image w/ MS ResNet-18 processor

    Returns:
        torch.FloatTensor: Preprocessed image
    """
-    return prepocess_img(img, ResNet50_Weights.IMAGENET1K_V2)
+    return process_img("microsoft/resnet-18", img)


-@pytest.fixture(scope="session")
-def batch_vit():
+@pytest.fixture(scope="function")
+def model_resnet18():
    """
-    Preprocess an image w/ ViT weights.transform()
+    Create MS ResNet-18 model + weights

    Returns:
-        torch.FloatTensor: Preprocessed image
+        AutoModelForImageClassification: ResNet-18 model
    """
-    return prepocess_img(img, ViT_B_16_Weights.IMAGENET1K_V1)
+    return AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")


-# Create resnet/vit model fixtures from weights 
@pytest.fixture(scope="function")
-def model_resnet():
+def batch_vit_base():
    """
-    Create Resnet50 model + weights
+    Preprocess an image w/ Google ViT-base processor

    Returns:
-        torchvision.models.resnet.ResNet: Resnet50 model
+        torch.FloatTensor: Preprocessed image
    """
-    return resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
+    return process_img("google/vit-base-patch16-224", img)


@pytest.fixture(scope="function")
-def model_vit():
+def model_vit_base():
    """
-    Create ViT model + weights
+    Create Google ViT-base model + weights

    Returns:
-        torchvision.models.vision_transformer.VisionTransformer: ViT model
+        AutoModelForImageClassification: Google ViT-base model
    """
-    return vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
+    return AutoModelForImageClassification.from_pretrained(
+        "google/vit-base-patch16-224"
+    )


#######################
diff --git a/tests/models/test_qmodelprep.py b/tests/models/test_qmodelprep.py
index c412d650..d41a5bfe 100644
--- a/tests/models/test_qmodelprep.py
+++ b/tests/models/test_qmodelprep.py
@@ -19,13 +19,13 @@
# Third Party
import pytest
import torch
-import torchvision
import transformers

# Local
# fms_mo imports
from fms_mo import qconfig_init, qmodel_prep
from fms_mo.prep import has_quantized_module
+from fms_mo.utils.import_utils import available_packages
from fms_mo.utils.utils import patch_torch_bmm
from tests.models.test_model_utils import count_qmodules, delete_file, qmodule_error

@@ -159,8 +159,12 @@ def test_config_fp32_qmodes(
###########################


+@pytest.mark.skipif(
+    not available_packages["torchvision"],
+    reason="Requires torchvision",
+)
def test_resnet50_torchscript(
-    model_resnet: torchvision.models.resnet.ResNet,
+    model_resnet,
    batch_resnet: torch.FloatTensor,
    config_int8: dict,
):
@@ -177,8 +181,12 @@ def test_resnet50_torchscript(
    qmodule_error(model_resnet, 6, 48)


+@pytest.mark.skipif(
+    not available_packages["torchvision"],
+    reason="Requires torchvision",
+)
def test_resnet50_dynamo(
-    model_resnet: torchvision.models.resnet.ResNet,
+    model_resnet,
    batch_resnet: torch.FloatTensor,
    config_int8: dict,
):
@@ -195,8 +203,12 @@ def test_resnet50_dynamo(
    qmodule_error(model_resnet, 6, 48)


+@pytest.mark.skipif(
+    not available_packages["torchvision"],
+    reason="Requires torchvision",
+)
def test_resnet50_dynamo_layers(
-    model_resnet: torchvision.models.resnet.ResNet,
+    model_resnet,
    batch_resnet: torch.FloatTensor,
    config_int8: dict,
):
@@ -216,8 +228,12 @@
# Vision Transformer tests


+@pytest.mark.skipif(
+    not available_packages["torchvision"],
+    reason="Requires torchvision",
+)
def test_vit_torchscript(
-    model_vit: torchvision.models.vision_transformer.VisionTransformer,
+    model_vit,
    batch_vit: torch.FloatTensor,
    config_int8: dict,
):
@@ -234,8 +250,12 @@ def test_vit_torchscript(
    qmodule_error(model_vit, 2, 36)


+@pytest.mark.skipif(
+    not available_packages["torchvision"],
+    reason="Requires torchvision",
+)
def test_vit_dynamo(
-    model_vit: torchvision.models.vision_transformer.VisionTransformer,
+    model_vit,
    batch_vit: torch.FloatTensor,
    config_int8: dict,
):
@@ -252,6 +272,42 @@ def test_vit_dynamo(
    qmodule_error(model_vit, 2, 36)


+def test_resnet18(
+    model_resnet18,
+    batch_resnet18,
+    config_int8: dict,
+):
+    """
+    Perform int8 quantization on ResNet-18 w/ Dynamo tracer
+
+    Args:
+        model_resnet18 (AutoModelForImageClassification): ResNet-18 model + weights
+        batch_resnet18 (torch.FloatTensor): Batch image data for ResNet-18
+        config_int8 (dict): Recipe config w/ int8 settings
+    """
+    # Run 
qmodel_prep w/ Dynamo tracer
+    qmodel_prep(model_resnet18, batch_resnet18, config_int8, use_dynamo=True)
+    qmodule_error(model_resnet18, 4, 17)
+
+
+def test_vit_base(
+    model_vit_base,
+    batch_vit_base,
+    config_int8: dict,
+):
+    """
+    Perform int8 quantization on ViT-base w/ Dynamo tracer
+
+    Args:
+        model_vit_base (AutoModelForImageClassification): ViT-base model + weights
+        batch_vit_base (torch.FloatTensor): Batch image data for ViT-base
+        config_int8 (dict): Recipe config w/ int8 settings
+    """
+    # Run qmodel_prep w/ Dynamo tracer
+    qmodel_prep(model_vit_base, batch_vit_base, config_int8, use_dynamo=True)
+    qmodule_error(model_vit_base, 1, 73)
+
+
def test_bert_dynamo(
    model_bert: transformers.models.bert.modeling_bert.BertModel,
    input_bert: torch.FloatTensor,
diff --git a/tox.ini b/tox.ini
index 447c5513..2f2d0441 100644
--- a/tox.ini
+++ b/tox.ini
@@ -6,6 +6,7 @@ minversion = 4.4
description = run tests (unit, unitcov)
extras =
    dev
+    test
package = wheel
wheel_build_env = pkg
deps =
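
Note on the gating pattern: every torchvision-dependent path in this patch is guarded by `available_packages["torchvision"]`, the flag this change registers in `fms_mo/utils/import_utils.py`. The sketch below shows one common way such a guard is built; the `importlib.util.find_spec` probe and the trimmed package list are illustrative assumptions, not the module's actual code — only the package names visible in the hunk above are taken from the source.

```python
# Illustrative sketch only -- not the actual fms_mo/utils/import_utils.py.
# Assumption: availability is probed with importlib.util.find_spec, which
# checks whether a package is installed without importing it.
import importlib.util

# Abridged; the real candidate list in import_utils.py is longer and now
# includes "torchvision" (added by this change).
candidate_packages = ["pygraphviz", "fms", "triton", "torchvision"]

# Map each candidate name to True iff it is resolvable on the current path.
available_packages = {
    name: importlib.util.find_spec(name) is not None for name in candidate_packages
}

# Callers then import the optional dependency lazily, inside the guard:
if available_packages["torchvision"]:
    # pylint: disable = import-error
    from torchvision.models import VisionTransformer
```

Keeping the probe import-free means merely loading `dynamo_utils` or the test modules never pulls in torchvision; only the code paths that need it import it, and `pytest.mark.skipif(not available_packages["torchvision"], ...)` turns the same flag into a test skip.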