diff --git a/tensorrt_loader.py b/tensorrt_loader.py index 5e2ccac..7b1b4bc 100644 --- a/tensorrt_loader.py +++ b/tensorrt_loader.py @@ -114,7 +114,7 @@ class TensorRTLoader: @classmethod def INPUT_TYPES(s): return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ), - "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"], ), + "model_type": (["sdxl_base", "sdxl_refiner", "sd15_instructpix2pix", "sdxl_instructpix2pix", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"], ), }} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" @@ -134,6 +134,14 @@ def load_unet(self, unet_name, model_type): {"adm_in_channels": 2560}) conf.unet_config["disable_unet_model_creation"] = True model = comfy.model_base.SDXLRefiner(conf) + elif model_type == "sdxl_instructpix2pix": + conf = comfy.supported_models.SDXL_instructpix2pix({}) + conf.unet_config["disable_unet_model_creation"] = True + model = comfy.model_base.SDXL_instructpix2pix(conf) + elif model_type == "sd15_instructpix2pix": + conf = comfy.supported_models.SD15_instructpix2pix({}) + conf.unet_config["disable_unet_model_creation"] = True + model = comfy.model_base.SD15_instructpix2pix(conf) elif model_type == "sd1.x": conf = comfy.supported_models.SD15({}) conf.unet_config["disable_unet_model_creation"] = True