diff --git a/.github/workflows/build-pipeline.yml b/.github/workflows/build-pipeline.yml index 1287d3b..9233b18 100644 --- a/.github/workflows/build-pipeline.yml +++ b/.github/workflows/build-pipeline.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install .[dev] + pip install .[dev] numpy torch pillow - name: Run Linting run: | ruff check . diff --git a/README.md b/README.md index 864f2b9..ede3c08 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,16 @@ These nodes are very lightweight and require no additional dependencies. - Path conversions: - relative - Computes a relative path from a start path to a target path - expand_vars - Replaces environment variables in a path with their values +- File loading: + - load STRING from file - Loads a text file and returns its content as a STRING + - load IMAGE from file (RGB) - Loads an image and returns RGB channels as a tensor + - load IMAGE+MASK from file (RGBA) - Loads an image and returns RGB channels as a tensor and alpha channel as a mask + - load MASK from alpha channel - Loads an image and extracts its alpha channel as a mask + - load MASK from greyscale/red - Loads an image and creates a mask from its greyscale or red channel +- File saving: + - save STRING to file - Saves a string to a text file with optional directory creation + - save IMAGE to file - Saves an image tensor to a file in various formats (PNG, JPG, WEBP, JXL) + - save IMAGE+MASK to file - Saves an image with transparency using a mask as the alpha channel ### SET: Python set manipulation nodes (as a single variable) - Creation: diff --git a/pyproject.toml b/pyproject.toml index f55ac08..1d99235 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "basic_data_handling" -version = "0.4.3" +version = "0.4.4" description = """Basic Python functions for manipulating data that every programmer is used to, lightweight with no additional dependencies. Supported data types: @@ -20,7 +20,7 @@ Feature categories: - Mathematical operations - Mathematical formula node in a safe implementation - String manipulation -- File system path handling +- File system path handling, including STRING, IMAGE and MASK load and save - SET operations""" authors = [ {name = "StableLlama"} diff --git a/src/basic_data_handling/path_nodes.py b/src/basic_data_handling/path_nodes.py index a7af9e9..4e6fb9a 100644 --- a/src/basic_data_handling/path_nodes.py +++ b/src/basic_data_handling/path_nodes.py @@ -11,10 +11,79 @@ class IO: FLOAT = "FLOAT" STRING = "STRING" NUMBER = "FLOAT,INT" + IMAGE = "IMAGE" + MASK = "MASK" ANY = "*" ComfyNodeABC = object +# helper functions: + +def load_image_helper(path: str): + """Helper function to load an image from a path""" + from PIL import Image, ImageOps + try: + import pillow_jxl # noqa: F401 - imported but unused, kept for JPEG XL support + except ModuleNotFoundError: + pass + + if not os.path.exists(path): + raise FileNotFoundError(f"Basic data handling: Image file not found: {path}") + + # Open and process the image + img = Image.open(path) + img = ImageOps.exif_transpose(img) + + return img + + +def extract_mask_from_alpha(img): + """Extract a mask from the alpha channel of an image""" + import numpy as np + import torch + + if 'A' in img.getbands(): + alpha = np.array(img.getchannel('A')).astype(np.float32) / 255.0 + mask_tensor = 1.0 - torch.from_numpy(alpha) + elif img.mode == 'P' and 'transparency' in img.info: + alpha = np.array(img.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 + mask_tensor = 1.0 - torch.from_numpy(alpha) + else: + # Create a blank mask if no alpha channel + mask_tensor = torch.zeros((img.height, img.width), dtype=torch.float32) + + # Add batch dimension to mask + mask_tensor = mask_tensor.unsqueeze(0) + + return mask_tensor + + +def extract_mask_from_greyscale(img): + """Extract a mask from a greyscale image or the red channel of an RGB image""" + import numpy as np + import torch + + if img.mode == 'L': + # Image is already greyscale + gray = np.array(img).astype(np.float32) / 255.0 + elif img.mode == 'RGB' or img.mode == 'RGBA': + # Use the red channel of RGB or RGBA + gray = np.array(img.getchannel('R')).astype(np.float32) / 255.0 + else: + # Convert to greyscale if it's another format + gray_img = img.convert('L') + gray = np.array(gray_img).astype(np.float32) / 255.0 + + # Convert to tensor and invert (white pixels in image = transparent in mask) + mask_tensor = 1.0 - torch.from_numpy(gray) + + # Add batch dimension + mask_tensor = mask_tensor.unsqueeze(0) + + return mask_tensor + +# the nodes: + class PathAbspath(ComfyNodeABC): """ Returns the absolute path of a file or directory. @@ -271,6 +340,35 @@ def INPUT_TYPES(cls): FUNCTION = "glob_paths" OUTPUT_IS_LIST = (True,) + # Class variable to store the last matched paths + _last_matched_paths = {} + + @classmethod + def IS_CHANGED(s, pattern: str, recursive: bool = False): + # Get current paths + current_paths = glob.glob(pattern, recursive=recursive) + + # Create a key for this specific pattern and recursive setting + key = f"{pattern}_{recursive}" + + # If we haven't seen this pattern before, store it and trigger recalculation + if key not in s._last_matched_paths: + s._last_matched_paths[key] = current_paths + return float("NaN") + + # Compare with previous paths + previous_paths = s._last_matched_paths[key] + if previous_paths != current_paths: + # Update stored paths and trigger recalculation + s._last_matched_paths[key] = current_paths + return float("NaN") + + # No changes, return a consistent value + import hashlib + m = hashlib.md5() + m.update(str(current_paths).encode()) + return m.hexdigest() + def glob_paths(self, pattern: str, recursive: bool = False) -> tuple[list[str]]: return (glob.glob(pattern, recursive=recursive),) @@ -564,6 +662,415 @@ def split_ext(self, path: str) -> tuple[str, str]: return os.path.splitext(path) +class PathLoadStringFile(ComfyNodeABC): + """ + Loads a text file in UTF-8 encoding and returns its content as a STRING + without any further processing. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "path": (IO.STRING, {"default": ""}), + }, + } + + RETURN_TYPES = (IO.STRING,) + RETURN_NAMES = ("text",) + CATEGORY = "Basic/Path" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "load_text" + + def load_text(self, path: str): + if not os.path.exists(path): + raise FileNotFoundError(f"Basic data handling: String file not found: {path}") + + with open(path, "r", encoding="utf-8") as f: + text = f.read() + return (text,) + + +class PathLoadImageRGB(ComfyNodeABC): + """ + Loads an image from a file path and returns only the RGB channels. + + This node loads an image from the specified path and processes it to + return only the RGB channels as a tensor, ignoring any alpha channel. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "path": (IO.STRING, {"default": ""}), + }, + } + + RETURN_TYPES = (IO.IMAGE,) + RETURN_NAMES = ("image",) + CATEGORY = "Basic/Path" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "load_image_rgb" + + def load_image_rgb(self, path: str): + import numpy as np + import torch + + img = load_image_helper(path) + + # Convert to RGB (removing alpha if present) + img_rgb = img.convert("RGB") + + # Convert to tensor format expected by ComfyUI + image_tensor = np.array(img_rgb).astype(np.float32) / 255.0 + image_tensor = torch.from_numpy(image_tensor)[None,] + + return (image_tensor,) + + +class PathLoadImageRGBA(ComfyNodeABC): + """ + Loads an image from a file path and returns RGB channels and Alpha as a mask. + + This node loads an image from the specified path and processes it to + return the RGB channels as a tensor and the Alpha channel as a mask tensor. + If the image has no alpha channel, a blank mask is returned. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "path": (IO.STRING, {"default": ""}), + }, + } + + RETURN_TYPES = (IO.IMAGE, IO.MASK) + RETURN_NAMES = ("image", "mask") + CATEGORY = "Basic/Path" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "load_image_rgba" + + def load_image_rgba(self, path: str): + import numpy as np + import torch + + img = load_image_helper(path) + + # Convert to RGB for the image + img_rgb = img.convert("RGB") + + # Convert to tensor format expected by ComfyUI + image_tensor = np.array(img_rgb).astype(np.float32) / 255.0 + image_tensor = torch.from_numpy(image_tensor)[None,] + + # Extract alpha channel as mask + mask_tensor = extract_mask_from_alpha(img) + + return (image_tensor, mask_tensor) + + +class PathLoadMaskFromAlpha(ComfyNodeABC): + """ + Loads a mask from the alpha channel of an image. + + This node loads an image from the specified path and extracts the alpha + channel to use as a mask. If the image has no alpha channel, a blank mask + is returned. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "path": (IO.STRING, {"default": ""}), + }, + } + + RETURN_TYPES = (IO.MASK,) + RETURN_NAMES = ("mask",) + CATEGORY = "Basic/Path" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "load_mask_from_alpha" + + def load_mask_from_alpha(self, path: str): + img = load_image_helper(path) + mask_tensor = extract_mask_from_alpha(img) + return (mask_tensor,) + + +class PathLoadMaskFromGreyscale(ComfyNodeABC): + """ + Loads a mask from a greyscale image or the red channel of an RGB image. + + This node loads an image from the specified path and creates a mask from it. + If the image is greyscale, the intensity is used directly. + If the image is RGB, the red channel is used. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "path": (IO.STRING, {"default": ""}), + }, + "optional": { + "invert": (IO.BOOLEAN, {"default": False}), + }, + } + + RETURN_TYPES = (IO.MASK,) + RETURN_NAMES = ("mask",) + CATEGORY = "Basic/Path" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "load_mask_from_greyscale" + + def load_mask_from_greyscale(self, path: str, invert: bool = False): + img = load_image_helper(path) + mask_tensor = extract_mask_from_greyscale(img) + + # Optionally invert the mask (1.0 - mask) + if invert: + mask_tensor = 1.0 - mask_tensor + + return (mask_tensor,) + + +class PathSaveStringFile(ComfyNodeABC): + """ + Saves a string to a text file. + + This node takes a string and saves it to the specified path as a text file. + Optionally, you can choose to create the directory if it doesn't exist. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "text": (IO.STRING, {"default": ""}), + "path": (IO.STRING, {"default": ""}), + }, + "optional": { + "create_dirs": (IO.BOOLEAN, {"default": True}), + "encoding": (IO.STRING, {"default": "utf-8"}), + } + } + + RETURN_TYPES = (IO.BOOLEAN) + RETURN_NAMES = ("success") + CATEGORY = "Basic/Path" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "save_text" + OUTPUT_NODE = True + + def save_text(self, text: str, path: str, create_dirs: bool = True, encoding: str = "utf-8"): + if not path: + return (False,) + + try: + # Create directories if needed + directory = os.path.dirname(path) + if directory and create_dirs and not os.path.exists(directory): + os.makedirs(directory) + + with open(path, "w", encoding=encoding) as f: + f.write(text) + + return (True,) + except Exception as e: + print(f"Basic data handling: Error saving text file: {e}") + return (False,) + + +class PathSaveImageRGB(ComfyNodeABC): + """ + Saves an image to a file. + + This node takes an image tensor and saves it to the specified path. + Supports various image formats like PNG, JPG, WEBP, JXL (if pillow-jxl is installed), etc. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": (IO.IMAGE,), + "path": (IO.STRING, {"default": ""}), + }, + "optional": { + "format": (IO.STRING, {"default": "png"}), + "quality": (IO.INT, {"default": 95, "min": 1, "max": 100}), + "create_dirs": (IO.BOOLEAN, {"default": True}), + } + } + + RETURN_TYPES = (IO.BOOLEAN) + RETURN_NAMES = ("success") + CATEGORY = "Basic/Path" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "save_image" + OUTPUT_NODE = True + + def save_image(self, images, path: str, format: str = "png", quality: int = 95, create_dirs: bool = True): + if not path: + return (False,) + + # If the path doesn't have an extension or it doesn't match the format, add it + if not path.lower().endswith(f".{format.lower()}"): + path = f"{path}.{format.lower()}" + + try: + import numpy as np + from PIL import Image + + # Check if pillow_jxl is available for JXL support + has_jxl_support = False + try: + import pillow_jxl # noqa: F401 - imported but unused, kept for JPEG XL support + has_jxl_support = True + except ModuleNotFoundError: + # pillow_jxl is not installed + if format.lower() == "jxl": + print("Basic data handling: JPEG XL format requested but pillow_jxl module is not installed. " + "Please install it with 'pip install pillow-jxl-plugin'.") + return (False,) + + # Create directories if needed + directory = os.path.dirname(path) + if directory and create_dirs and not os.path.exists(directory): + os.makedirs(directory) + + # Convert from tensor format back to PIL Image + # Extract the first image from the batch + i = 0 + img_tensor = images[i].cpu().numpy() + + # Convert to uint8 format for PIL + img_np = (img_tensor * 255).astype(np.uint8) + + # Create PIL image + pil_img = Image.fromarray(img_np) + + # Save the image + if format.lower() == "jpg" or format.lower() == "jpeg": + pil_img.save(path, format="JPEG", quality=quality) + elif format.lower() == "webp": + pil_img.save(path, format="WEBP", quality=quality) + elif format.lower() == "jxl" and has_jxl_support: + # JPEG XL specific options + pil_img.save(path, format="JXL", quality=quality) + else: + pil_img.save(path, format=format.upper()) + + return (True,) + except Exception as e: + print(f"Basic data handling: Error saving image: {e}") + return (False,) + + +class PathSaveImageRGBA(ComfyNodeABC): + """ + Saves an image with a mask to a file with transparency. + + This node takes an image tensor and a mask tensor and saves them to the + specified path as an image with transparency, where the mask defines the + alpha channel. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "images": (IO.IMAGE,), + "mask": (IO.MASK,), + "path": (IO.STRING, {"default": ""}), + }, + "optional": { + "format": (IO.STRING, {"default": "png"}), + "quality": (IO.INT, {"default": 95, "min": 1, "max": 100}), + "invert_mask": (IO.BOOLEAN, {"default": False}), + "create_dirs": (IO.BOOLEAN, {"default": True}), + } + } + + RETURN_TYPES = (IO.BOOLEAN) + RETURN_NAMES = ("success") + CATEGORY = "Basic/Path" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "save_image_with_mask" + OUTPUT_NODE = True + + def save_image_with_mask(self, images, mask, path: str, format: str = "png", + quality: int = 95, invert_mask: bool = False, + create_dirs: bool = True): + if not path: + return (False,) + + # Check format compatibility - needs to support alpha channel + if format.lower() in ["jpg", "jpeg"]: + print("Basic data handling: JPEG format doesn't support transparency. Using PNG instead.") + format = "png" + + # If the path doesn't have an extension or it doesn't match the format, add it + if not path.lower().endswith(f".{format.lower()}"): + path = f"{path}.{format.lower()}" + + try: + import numpy as np + from PIL import Image + + # Check if pillow_jxl is available for JXL support + has_jxl_support = False + try: + import pillow_jxl # noqa: F401 - imported but unused, kept for JPEG XL support + has_jxl_support = True + except ModuleNotFoundError: + # pillow_jxl is not installed + if format.lower() == "jxl": + print("Basic data handling: JPEG XL format requested but pillow_jxl module is not installed. " + "Please install it with 'pip install pillow-jxl-plugin'.") + return (False,) + + # Create directories if needed + directory = os.path.dirname(path) + if directory and create_dirs and not os.path.exists(directory): + os.makedirs(directory) + + # Convert from tensor format back to PIL Image + # Extract the first image from the batch + i = 0 + img_tensor = images[i].cpu().numpy() + mask_tensor = mask[i].cpu() + + # Invert the mask if needed (1.0 becomes transparent, 0.0 becomes opaque) + if invert_mask: + mask_tensor = 1.0 - mask_tensor + + # Convert to alpha channel (0-255) + alpha_np = (255.0 * (1.0 - mask_tensor.numpy())).astype(np.uint8) + + # Convert to uint8 format for PIL + img_np = (img_tensor * 255).astype(np.uint8) + + # Create PIL image (RGB) + pil_img = Image.fromarray(img_np) + + # Create alpha channel image + alpha_img = Image.fromarray(alpha_np, mode='L') + + # Convert to RGBA and add alpha channel + pil_img_rgba = pil_img.convert("RGBA") + pil_img_rgba.putalpha(alpha_img) + + # Save the image + if format.lower() == "webp": + pil_img_rgba.save(path, format="WEBP", quality=quality) + elif format.lower() == "jxl" and has_jxl_support: + # JPEG XL supports alpha channel + pil_img_rgba.save(path, format="JXL", quality=quality) + else: + pil_img_rgba.save(path, format=format.upper()) + + return (True,) + except Exception as e: + print(f"Basic data handling: Error saving image with mask: {e}") + return (False,) + + NODE_CLASS_MAPPINGS = { "Basic data handling: PathAbspath": PathAbspath, "Basic data handling: PathBasename": PathBasename, @@ -585,6 +1092,14 @@ def split_ext(self, path: str) -> tuple[str, str]: "Basic data handling: PathRelative": PathRelative, "Basic data handling: PathSplit": PathSplit, "Basic data handling: PathSplitExt": PathSplitExt, + "Basic data handling: PathLoadStringFile": PathLoadStringFile, + "Basic data handling: PathLoadImageRGB": PathLoadImageRGB, + "Basic data handling: PathLoadImageRGBA": PathLoadImageRGBA, + "Basic data handling: PathLoadMaskFromAlpha": PathLoadMaskFromAlpha, + "Basic data handling: PathLoadMaskFromGreyscale": PathLoadMaskFromGreyscale, + "Basic data handling: PathSaveStringFile": PathSaveStringFile, + "Basic data handling: PathSaveImageRGB": PathSaveImageRGB, + "Basic data handling: PathSaveImageRGBA": PathSaveImageRGBA, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -608,4 +1123,12 @@ def split_ext(self, path: str) -> tuple[str, str]: "Basic data handling: PathRelative": "relative", "Basic data handling: PathSplit": "split", "Basic data handling: PathSplitExt": "splitext", + "Basic data handling: PathLoadStringFile": "load STRING from file", + "Basic data handling: PathLoadImageRGB": "load IMAGE from file (RGB)", + "Basic data handling: PathLoadImageRGBA": "load IMAGE+MASK from file (RGBA)", + "Basic data handling: PathLoadMaskFromAlpha": "load MASK from alpha channel", + "Basic data handling: PathLoadMaskFromGreyscale": "load MASK from greyscale/red", + "Basic data handling: PathSaveStringFile": "save STRING to file", + "Basic data handling: PathSaveImageRGB": "save IMAGE to file", + "Basic data handling: PathSaveImageRGBA": "save IMAGE+MASK to file", } diff --git a/tests/test_path_nodes.py b/tests/test_path_nodes.py index 0052253..dab0b9c 100644 --- a/tests/test_path_nodes.py +++ b/tests/test_path_nodes.py @@ -1,12 +1,17 @@ import os import pytest import platform +import numpy as np +import torch +from PIL import Image from src.basic_data_handling.path_nodes import ( PathJoin, PathAbspath, PathExists, PathIsFile, PathIsDir, PathGetSize, PathSplit, PathSplitExt, PathBasename, PathDirname, PathGetExtension, PathSetExtension, PathNormalize, PathRelative, PathGlob, PathExpandVars, PathGetCwd, - PathListDir, PathIsAbsolute, PathCommonPrefix, + PathListDir, PathIsAbsolute, PathCommonPrefix, PathLoadStringFile, PathSaveStringFile, + PathLoadImageRGB, PathSaveImageRGB, PathLoadImageRGBA, PathSaveImageRGBA, + PathLoadMaskFromAlpha, PathLoadMaskFromGreyscale, ) @@ -21,6 +26,39 @@ def test_path_join(): nested_path = os.path.join("folder", "subfolder", "file.txt") assert node.join_paths("folder", os.path.join("subfolder", "file.txt")) == (nested_path,) +def test_path_load_save_string_file(tmp_path): + # Test saving a string to a file + save_node = PathSaveStringFile() + test_string = "This is a test string\nwith multiple lines" + file_path = str(tmp_path / "test_string.txt") + + # Save the string + assert save_node.save_text(test_string, file_path) == (True,) + + # Verify file exists + assert os.path.exists(file_path) + + # Test loading the string back + load_node = PathLoadStringFile() + loaded_string = load_node.load_text(file_path) + assert loaded_string == (test_string,) + + # Test creating directories when saving + nested_path = str(tmp_path / "nested" / "dir" / "test.txt") + assert save_node.save_text("Test with nested dirs", nested_path) == (True,) + assert os.path.exists(nested_path) + + # Test different encodings + utf8_text = "UTF-8 text with special chars: 你好, ñ, é, ö" + utf8_path = str(tmp_path / "utf8_test.txt") + assert save_node.save_text(utf8_text, utf8_path, encoding="utf-8") == (True,) + assert load_node.load_text(utf8_path) == (utf8_text,) + + # Test error handling + with pytest.raises(FileNotFoundError): + load_node.load_text(str(tmp_path / "nonexistent.txt")) + + def test_path_abspath(): node = PathAbspath() @@ -217,6 +255,193 @@ def test_path_set_extension(): assert node.set_extension("", ".txt") == (".txt",) +def test_path_load_mask_nodes(tmp_path, monkeypatch): + # Create test images + img_size = (64, 64) + + # Image with alpha channel + rgba_img = Image.new('RGBA', img_size, color=(255, 0, 0, 128)) # Semi-transparent red + rgba_path = str(tmp_path / "alpha_mask.png") + rgba_img.save(rgba_path) + + # Grayscale image + gray_img = Image.new('L', img_size) + # Create a gradient from black to white + for y in range(img_size[1]): + for x in range(img_size[0]): + gray_img.putpixel((x, y), int(255 * (x / img_size[0]))) + gray_path = str(tmp_path / "gray_mask.png") + gray_img.save(gray_path) + + # Mock the helper functions + def mock_load_image_helper(path): + if not os.path.exists(path): + raise FileNotFoundError(f"Image file not found: {path}") + return Image.open(path) + + def mock_extract_mask_from_alpha(img): + if 'A' in img.getbands(): + alpha = np.array(img.getchannel('A')).astype(np.float32) / 255.0 + mask_tensor = 1.0 - torch.from_numpy(alpha) + return mask_tensor.unsqueeze(0) + return torch.zeros((1, img.height, img.width), dtype=torch.float32) + + def mock_extract_mask_from_greyscale(img): + if img.mode == 'L': + gray = np.array(img).astype(np.float32) / 255.0 + else: + gray = np.array(img.getchannel('R')).astype(np.float32) / 255.0 + mask_tensor = 1.0 - torch.from_numpy(gray) + return mask_tensor.unsqueeze(0) + + monkeypatch.setattr("src.basic_data_handling.path_nodes.load_image_helper", mock_load_image_helper) + monkeypatch.setattr("src.basic_data_handling.path_nodes.extract_mask_from_alpha", mock_extract_mask_from_alpha) + monkeypatch.setattr("src.basic_data_handling.path_nodes.extract_mask_from_greyscale", mock_extract_mask_from_greyscale) + + # Test loading mask from alpha channel + alpha_node = PathLoadMaskFromAlpha() + alpha_mask = alpha_node.load_mask_from_alpha(rgba_path) + + # Verify the mask shape + assert isinstance(alpha_mask, tuple) + assert len(alpha_mask) == 1 + assert isinstance(alpha_mask[0], torch.Tensor) + assert alpha_mask[0].shape == (1, img_size[1], img_size[0]) + + # Test loading mask from grayscale image + gray_node = PathLoadMaskFromGreyscale() + gray_mask = gray_node.load_mask_from_greyscale(gray_path) + + # Verify the mask shape + assert isinstance(gray_mask, tuple) + assert len(gray_mask) == 1 + assert isinstance(gray_mask[0], torch.Tensor) + assert gray_mask[0].shape == (1, img_size[1], img_size[0]) + + # Test with invert option + inverted_mask = gray_node.load_mask_from_greyscale(gray_path, invert=True) + assert torch.allclose(inverted_mask[0], 1.0 - gray_mask[0]) + + # Test error handling + with pytest.raises(FileNotFoundError): + alpha_node.load_mask_from_alpha(str(tmp_path / "nonexistent.png")) + + +def test_path_load_save_image_rgba(tmp_path, monkeypatch): + # Create a test image with transparency + img_size = (64, 64) + test_img = Image.new('RGBA', img_size, color=(255, 0, 0, 128)) # Semi-transparent red + img_path = str(tmp_path / "test_rgba.png") + test_img.save(img_path) + + # Mock the load_image_helper and extraction functions + def mock_load_image_helper(path): + if not os.path.exists(path): + raise FileNotFoundError(f"Image file not found: {path}") + return Image.open(path) + + def mock_extract_mask_from_alpha(img): + if 'A' in img.getbands(): + alpha = np.array(img.getchannel('A')).astype(np.float32) / 255.0 + mask_tensor = 1.0 - torch.from_numpy(alpha) + return mask_tensor.unsqueeze(0) + return torch.zeros((1, img.height, img.width), dtype=torch.float32) + + monkeypatch.setattr("src.basic_data_handling.path_nodes.load_image_helper", mock_load_image_helper) + monkeypatch.setattr("src.basic_data_handling.path_nodes.extract_mask_from_alpha", mock_extract_mask_from_alpha) + + # Test loading an image with alpha + load_node = PathLoadImageRGBA() + loaded_img, loaded_mask = load_node.load_image_rgba(img_path) + + # Verify that the returned objects are tensors with the right shapes + assert isinstance(loaded_img, torch.Tensor) + assert isinstance(loaded_mask, torch.Tensor) + assert loaded_img.shape == (1, img_size[1], img_size[0], 3) # (batch, height, width, channels) + assert loaded_mask.shape == (1, img_size[1], img_size[0]) # (batch, height, width) + + # Test saving an image with mask + save_node = PathSaveImageRGBA() + output_path = str(tmp_path / "output_rgba") + + # Create a simple red test image tensor and a gradient mask + red_img = torch.zeros(1, img_size[1], img_size[0], 3) + red_img[0, :, :, 0] = 1.0 # Red channel set to 1 + + # Create a gradient mask (0.0 to 1.0 from left to right) + mask = torch.zeros(1, img_size[1], img_size[0]) + for i in range(img_size[0]): + mask[0, :, i] = i / img_size[0] + + # Save the image with mask + assert save_node.save_image_with_mask(red_img, mask, output_path) == (True,) + + # Verify the image was saved with the correct extension + assert os.path.exists(output_path + ".png") + + # Test with invert_mask option + assert save_node.save_image_with_mask(red_img, mask, str(tmp_path / "inverted"), invert_mask=True) == (True,) + assert os.path.exists(str(tmp_path / "inverted.png")) + + # Test with JPEG format (should switch to PNG for transparency) + assert save_node.save_image_with_mask(red_img, mask, str(tmp_path / "jpeg_test"), format="jpg") == (True,) + # Should be saved as PNG despite the request for JPEG + assert os.path.exists(str(tmp_path / "jpeg_test.png")) + + +def test_path_load_save_image_rgb(tmp_path, monkeypatch): + # Create a test image + img_size = (64, 64) + test_img = Image.new('RGB', img_size, color='red') + img_path = str(tmp_path / "test_rgb.png") + test_img.save(img_path) + + # Mock the load_image_helper function to avoid PIL import issues in testing + def mock_load_image_helper(path): + if not os.path.exists(path): + raise FileNotFoundError(f"Image file not found: {path}") + return Image.open(path) + + monkeypatch.setattr("src.basic_data_handling.path_nodes.load_image_helper", mock_load_image_helper) + + # Test loading an image + load_node = PathLoadImageRGB() + loaded_img = load_node.load_image_rgb(img_path) + + # Verify that the returned object is a tensor with the right shape + assert isinstance(loaded_img, tuple) + assert len(loaded_img) == 1 + assert isinstance(loaded_img[0], torch.Tensor) + assert loaded_img[0].shape == (1, img_size[1], img_size[0], 3) # (batch, height, width, channels) + + # Test saving an image + save_node = PathSaveImageRGB() + output_path = str(tmp_path / "output_rgb") + + # Create a simple red test image tensor + red_img = torch.zeros(1, img_size[1], img_size[0], 3) + red_img[0, :, :, 0] = 1.0 # Red channel set to 1 + + # Save the image + assert save_node.save_image(red_img, output_path) == (True,) + + # Verify the image was saved with the correct extension + assert os.path.exists(output_path + ".png") + + # Test different formats + assert save_node.save_image(red_img, str(tmp_path / "jpeg_test"), format="jpg") == (True,) + assert os.path.exists(str(tmp_path / "jpeg_test.jpg")) + + # Test with directories that don't exist + nested_path = str(tmp_path / "nested" / "images" / "test_rgb") + assert save_node.save_image(red_img, nested_path) == (True,) + assert os.path.exists(nested_path + ".png") + + # Test error handling for loading + with pytest.raises(FileNotFoundError): + load_node.load_image_rgb(str(tmp_path / "nonexistent.png")) + + def test_path_normalize(): node = PathNormalize()