From 83bf8b76604caa11e51132bd64b172907e17e42d Mon Sep 17 00:00:00 2001
From: Jeff Cook
Date: Wed, 5 Jun 2024 03:14:33 -0600
Subject: [PATCH] Support loading CosXL engines.

---
 tensorrt_loader.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tensorrt_loader.py b/tensorrt_loader.py
index 8ceff29..11098ef 100644
--- a/tensorrt_loader.py
+++ b/tensorrt_loader.py
@@ -110,7 +110,9 @@ class TensorRTLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
-                             "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow"], ),
+                             "model_type": (["sdxl_base", "sdxl_refiner",
+                                             "sd1.x", "sd2.x-768v", "svd",
+                                             "sd3", "auraflow", "cosxl"], ),
                              }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -150,6 +152,10 @@ def load_unet(self, unet_name, model_type):
             conf = comfy.supported_models.AuraFlow({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = conf.get_model({})
+        elif model_type == "cosxl":
+            conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
+            conf.unet_config["disable_unet_model_creation"] = True
+            model = comfy.model_base.SDXL(conf, model_type=comfy.model_base.ModelType.V_PREDICTION_EDM)
         model.diffusion_model = unet
         model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting

@@ -159,4 +165,4 @@

 NODE_CLASS_MAPPINGS = {
     "TensorRTLoader": TensorRTLoader,
-}
\ No newline at end of file
+}