diff --git a/README.md b/README.md index 263373a..88706f4 100644 --- a/README.md +++ b/README.md @@ -274,6 +274,46 @@ ComfyUI has different data types that serve different purposes: - Supports built-in ComfyUI iteration over each item - Best for: - Working directly with multiple items in parallel + - Processing each item in a collection separately + - When you need ComfyUI's automatic iteration functionality + +### 2. LIST +- A Python list passed as a single ComfyUI variable +- Must be processed as a complete unit by compatible nodes +- Operations apply to the entire LIST at once +- Best for: + - Storing and manipulating structured data as a single unit + - When you need to preserve ordered collections + - Passing complex data structures between nodes + +### 3. SET +- A Python set passed as a single ComfyUI variable +- Unordered collection of unique items +- Useful for membership testing, removing duplicates, and set operations +- Best for: + - When you need to ensure uniqueness of items + - Performing mathematical set operations (union, intersection, difference) + - Efficient membership testing (contains operation) + - When item order doesn't matter + +## Control Flow Nodes + +Control flow nodes provide mechanisms to direct the flow of execution in your ComfyUI workflows, allowing for conditional processing and dynamic execution paths. + +### Available Control Flow Nodes: + +#### Conditional Processing +- **if/else** - Routes execution based on a boolean condition +- **if/elif/.../else** - Supports multiple conditional branches +- **switch/case** - Selects from multiple options based on an index + +#### Execution Management +- **disable flow** - Conditionally enables or disables a flow +- **flow select** - Directs output to either "true" or "false" path +- **force calculation** - Prevents caching and forces recalculation +- **force execution order** - Controls the sequence of node execution + +These control flow nodes enable building more complex, dynamic workflows with decision-making capabilities based on runtime conditions. - Batch processing scenarios - When you need to apply the same operation to multiple inputs - When your operation needs to work with individual items separately diff --git a/pyproject.toml b/pyproject.toml index 7dcd4fd..ba588ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "basic_data_handling" -version = "0.4.0" +version = "0.4.1" description = """Basic Python functions for manipulating data that every programmer is used to. Comprehensive node collection for data manipulation in ComfyUI workflows. diff --git a/src/basic_data_handling/control_flow_nodes.py b/src/basic_data_handling/control_flow_nodes.py index a18d53a..772d622 100644 --- a/src/basic_data_handling/control_flow_nodes.py +++ b/src/basic_data_handling/control_flow_nodes.py @@ -90,7 +90,7 @@ def check_lazy_status(self, **kwargs) -> list[str]: # Check if condition if kwargs.get("if", False) and kwargs.get("then") is None: needed.append("then") - return needed # If main condition is true, we only need "then" + return needed # If the main condition is true, we only need "then" # Check each elif condition elif_index = 0 @@ -153,7 +153,7 @@ class SwitchCase(ComfyNodeABC): def INPUT_TYPES(cls): return { "required": ContainsDynamicDict({ - "selector": (IO.INT, {"default": 0, "min": 0}), + "select": (IO.INT, {"default": 0, "min": 0}), "case_0": (IO.ANY, {"lazy": True, "_dynamic": "number"}), }), "optional": { @@ -167,29 +167,29 @@ def INPUT_TYPES(cls): DESCRIPTION = cleandoc(__doc__ or "") FUNCTION = "execute" - def check_lazy_status(self, selector: int, **kwargs) -> list[str]: + def check_lazy_status(self, select: int, **kwargs) -> list[str]: needed = [] - # Check for needed case inputs based on selector + # Check for necessary case inputs based on select case_count = 0 for key, value in kwargs.items(): if key.startswith("case_"): try: case_index = int(key.split("_")[1]) case_count = max(case_count, case_index + 1) - if value is None and selector == case_index: + if value is None and select == case_index: needed.append(key) except ValueError: pass # Not a numeric case key - # Check if default is needed when selector is out of range - if "default" in kwargs and kwargs["default"] is None and not 0 <= selector < case_count: + # Check if default is needed when select is out of range + if "default" in kwargs and kwargs["default"] is None and not 0 <= select < case_count: needed.append("default") return needed def execute(self, selector: int, **kwargs) -> tuple[Any]: - # Build cases array from all case_X inputs + # Build a case array from all case_X inputs cases = [] for i in range(len(kwargs)): case_key = f"case_{i}" @@ -202,10 +202,44 @@ def execute(self, selector: int, **kwargs) -> tuple[Any]: if 0 <= selector < len(cases) and cases[selector] is not None: return (cases[selector],) - # If selector is out of range or the selected case is None, return default + # If select is out of range or the selected case is None, return default return (kwargs.get("default"),) +class DisableFlow(ComfyNodeABC): + """ + Conditionally enable or disable a flow. + + This node takes a value and either passes it through or blocks execution + based on the 'select' parameter. When 'select' is True, the value passes through; + when False, execution is blocked. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + "select": (IO.BOOLEAN, {"default": True}), + } + } + + RETURN_TYPES = (IO.ANY,) + RETURN_NAMES = ("value",) + CATEGORY = "Basic/flow control" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "execute" + + @classmethod + def IS_CHANGED(s, value: Any): + return float("NaN") # not equal to anything -> trigger recalculation + + def execute(self, value: Any, select: bool = True) -> tuple[Any]: + if select: + return (value,) + else: + return (ExecutionBlocker(None),) + + class FlowSelect(ComfyNodeABC): """ Select the direction of the flow. @@ -237,6 +271,35 @@ def select(self, value, select = True) -> tuple[Any, Any]: return ExecutionBlocker(None), value +class ForceCalculation(ComfyNodeABC): + """ + Forces recalculation of the connected nodes. + + This node passes the input directly to the output but prevents caching + by marking itself as an output node and also indicates the out has changed. + Use this when you need to ensure nodes are always recalculated. + """ + + OUTPUT_NODE = True # Marks as an output node to force calculation + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.ANY,) + RETURN_NAMES = ("value",) + CATEGORY = "Basic/flow control" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "execute" + + def execute(self, value: Any) -> tuple[Any, int]: + return (value,) + + class ExecutionOrder(ComfyNodeABC): """ Force execution order in the workflow. @@ -268,7 +331,9 @@ def execute(self, **kwargs) -> tuple[Any]: "Basic data handling: IfElse": IfElse, "Basic data handling: IfElifElse": IfElifElse, "Basic data handling: SwitchCase": SwitchCase, + "Basic data handling: DisableFlow": DisableFlow, "Basic data handling: FlowSelect": FlowSelect, + "Basic data handling: ForceCalculation": ForceCalculation, "Basic data handling: ExecutionOrder": ExecutionOrder, } @@ -276,6 +341,8 @@ def execute(self, **kwargs) -> tuple[Any]: "Basic data handling: IfElse": "if/else", "Basic data handling: IfElifElse": "if/elif/.../else", "Basic data handling: SwitchCase": "switch/case", + "Basic data handling: DisableFlow": "disable flow", "Basic data handling: FlowSelect": "flow select", + "Basic data handling: ForceCalculation": "force calculation", "Basic data handling: ExecutionOrder": "force execution order", } diff --git a/src/basic_data_handling/data_list_nodes.py b/src/basic_data_handling/data_list_nodes.py index 5da7b9c..894d68d 100644 --- a/src/basic_data_handling/data_list_nodes.py +++ b/src/basic_data_handling/data_list_nodes.py @@ -153,6 +153,56 @@ def create_list(self, **kwargs: list[Any]) -> tuple[list[Any]]: return (values[:-1],) +class DataListAll(ComfyNodeABC): + """ + Check if all elements in the data list are true. + Returns true if all elements are true (or if the list is empty). + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.BOOLEAN,) + RETURN_NAMES = ("result",) + CATEGORY = "Basic/Data List" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "check_all" + INPUT_IS_LIST = True + + def check_all(self, **kwargs: list[Any]) -> tuple[bool]: + return (all(kwargs.get('list', [])),) + + +class DataListAny(ComfyNodeABC): + """ + Check if any element in the data list is true. + Returns true if at least one element is true. Returns false if the list is empty. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.BOOLEAN,) + RETURN_NAMES = ("result",) + CATEGORY = "Basic/Data List" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "check_any" + INPUT_IS_LIST = True + + def check_any(self, **kwargs: list[Any]) -> tuple[bool]: + return (any(kwargs.get('list', [])),) + + class DataListAppend(ComfyNodeABC): """ Adds an item to the end of a list. @@ -243,6 +293,37 @@ def count(self, **kwargs: list[Any]) -> tuple[int]: return (kwargs.get('list', []).count(value),) +class DataListEnumerate(ComfyNodeABC): + """ + Enumerate a data list, returning a list of [index, value] pairs. + Optionally, specify a starting value for the index. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": (IO.ANY, {}), + }, + "optional": { + "start": (IO.INT, {"default": 0}), + } + } + + RETURN_TYPES = (IO.ANY,) + RETURN_NAMES = ("list",) + CATEGORY = "Basic/Data List" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "enumerate_list" + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True,) + + def enumerate_list(self, **kwargs: list[Any]) -> tuple[list]: + input_list = kwargs.get('list', []) + start = kwargs.get('start', [0])[0] + return ([list(item) for item in enumerate(input_list, start=start)],) + + class DataListExtend(ComfyNodeABC): """ Extends a list by appending elements from another list. @@ -675,6 +756,37 @@ def pop_random_element(self, **kwargs: list[Any]) -> tuple[list[Any], Any]: return input_list, None +class DataListRange(ComfyNodeABC): + """ + Creates a data list containing a sequence of numbers. + + This node generates a sequence of numbers similar to Python's range() function. + It takes start, stop, and step parameters to define the sequence. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "start": ("INT", {"default": 0}), + "stop": ("INT", {"default": 10}), + }, + "optional": { + "step": ("INT", {"default": 1}), + } + } + + RETURN_TYPES = ("INT",) + FUNCTION = "create_range" + CATEGORY = "Basic/Data List" + DESCRIPTION = cleandoc(__doc__ or "") + OUTPUT_IS_LIST = (True,) + + def create_range(self, stop: int, start: int = 0, step: int = 1) -> tuple[list[int]]: + if step == 0: + raise ValueError("Step cannot be zero") + return (list(range(start, stop, step)),) + + class DataListRemove(ComfyNodeABC): """ Removes the first occurrence of a specified value from a list. @@ -845,6 +957,37 @@ def sort(self, **kwargs: list[Any]) -> tuple[list[Any]]: return (result,) +class DataListSum(ComfyNodeABC): + """ + Sum all elements of the data list. + Returns 0 for an empty list. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": (IO.NUMBER, {}), + }, + "optional": { + "start": (IO.INT, {"default": 0}), + } + } + + RETURN_TYPES = (IO.INT, IO.FLOAT,) + RETURN_NAMES = ("int_sum", "float_sum",) + CATEGORY = "Basic/Data List" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "sum_list" + INPUT_IS_LIST = True + + def sum_list(self, **kwargs: list[Any]) -> tuple[int, float]: + input_list = kwargs.get('list', []) + start = kwargs.get('start', [0])[0] + result = sum(input_list, start) + return int(result), float(result) + + class DataListZip(ComfyNodeABC): """ Combines multiple lists element-wise. @@ -944,9 +1087,12 @@ def convert(self, **kwargs: list[Any]) -> tuple[set[Any]]: "Basic data handling: DataListCreateFromFloat": DataListCreateFromFloat, "Basic data handling: DataListCreateFromInt": DataListCreateFromInt, "Basic data handling: DataListCreateFromString": DataListCreateFromString, + "Basic data handling: DataListAll": DataListAll, + "Basic data handling: DataListAny": DataListAny, "Basic data handling: DataListAppend": DataListAppend, "Basic data handling: DataListContains": DataListContains, "Basic data handling: DataListCount": DataListCount, + "Basic data handling: DataListEnumerate": DataListEnumerate, "Basic data handling: DataListExtend": DataListExtend, "Basic data handling: DataListFilter": DataListFilter, "Basic data handling: DataListFilterSelect": DataListFilterSelect, @@ -960,11 +1106,13 @@ def convert(self, **kwargs: list[Any]) -> tuple[set[Any]]: "Basic data handling: DataListMin": DataListMin, "Basic data handling: DataListPop": DataListPop, "Basic data handling: DataListPopRandom": DataListPopRandom, + "Basic data handling: DataListRange": DataListRange, "Basic data handling: DataListRemove": DataListRemove, "Basic data handling: DataListReverse": DataListReverse, "Basic data handling: DataListSetItem": DataListSetItem, "Basic data handling: DataListSlice": DataListSlice, "Basic data handling: DataListSort": DataListSort, + "Basic data handling: DataListSum": DataListSum, "Basic data handling: DataListZip": DataListZip, "Basic data handling: DataListToList": DataListToList, "Basic data handling: DataListToSet": DataListToSet, @@ -976,9 +1124,12 @@ def convert(self, **kwargs: list[Any]) -> tuple[set[Any]]: "Basic data handling: DataListCreateFromFloat": "create Data List from FLOATs", "Basic data handling: DataListCreateFromInt": "create Data List from INTs", "Basic data handling: DataListCreateFromString": "create Data List from STRINGs", + "Basic data handling: DataListAll": "all", + "Basic data handling: DataListAny": "any", "Basic data handling: DataListAppend": "append", "Basic data handling: DataListContains": "contains", "Basic data handling: DataListCount": "count", + "Basic data handling: DataListEnumerate": "enumerate", "Basic data handling: DataListExtend": "extend", "Basic data handling: DataListFilter": "filter", "Basic data handling: DataListFilterSelect": "filter select", @@ -992,11 +1143,13 @@ def convert(self, **kwargs: list[Any]) -> tuple[set[Any]]: "Basic data handling: DataListMin": "min", "Basic data handling: DataListPop": "pop", "Basic data handling: DataListPopRandom": "pop random", + "Basic data handling: DataListRange": "range", "Basic data handling: DataListRemove": "remove", "Basic data handling: DataListReverse": "reverse", "Basic data handling: DataListSetItem": "set item", "Basic data handling: DataListSlice": "slice", "Basic data handling: DataListSort": "sort", + "Basic data handling: DataListSum": "sum", "Basic data handling: DataListZip": "zip", "Basic data handling: DataListToList": "convert to LIST", "Basic data handling: DataListToSet": "convert to SET", diff --git a/src/basic_data_handling/list_nodes.py b/src/basic_data_handling/list_nodes.py index 5d3e3e8..870d5ca 100644 --- a/src/basic_data_handling/list_nodes.py +++ b/src/basic_data_handling/list_nodes.py @@ -143,6 +143,52 @@ def create_list(self, **kwargs: list[Any]) -> tuple[list[Any]]: return (values,) +class ListAll: + """ + Check if all elements in the list are true. + Returns true if all elements are true (or if the list is empty). + """ + + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "list": ("LIST",), + } + } + + RETURN_TYPES = ("BOOLEAN",) + CATEGORY = "Basic/LIST" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "check_all" + + def check_all(self, list: list[Any]) -> tuple[bool]: + return (all(list),) + + +class ListAny: + """ + Check if any element in the list is true. + Returns true if at least one element is true. Returns false if the list is empty. + """ + + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "list": ("LIST",), + } + } + + RETURN_TYPES = ("BOOLEAN",) + CATEGORY = "Basic/LIST" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "check_any" + + def check_any(self, list: list[Any]) -> tuple[bool]: + return (any(list),) + + class ListAppend(ComfyNodeABC): """ Adds an item to the end of a LIST. @@ -222,6 +268,32 @@ def count(self, list: list[Any], value: Any) -> tuple[int]: return (list.count(value),) +class ListEnumerate: + """ + Enumerate a list, returning a list of [index, value] pairs. + Optionally, specify a starting value for the index. + """ + + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "list": ("LIST",), + }, + "optional": { + "start": ("INT", {"default": 0}), + } + } + + RETURN_TYPES = ("LIST",) + CATEGORY = "Basic/LIST" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "enumerate_list" + + def enumerate_list(self, list: list[Any], start: int = 0) -> tuple[list]: + return ([__builtins__['list'](enumerate(list, start=start))],) + + class ListExtend(ComfyNodeABC): """ Extends a LIST by appending elements from another LIST. @@ -548,6 +620,36 @@ def pop_random_element(self, list: list[Any]) -> tuple[list[Any], Any]: return result, None +class ListRange(ComfyNodeABC): + """ + Creates a LIST containing a sequence of numbers. + + This node generates a LIST of numbers similar to Python's range() function. + It takes start, stop, and step parameters to define the sequence. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "start": ("INT", {"default": 0}), + "stop": ("INT", {"default": 10}), + }, + "optional": { + "step": ("INT", {"default": 1}), + } + } + + RETURN_TYPES = ("LIST",) + CATEGORY = "Basic/LIST" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "create_range" + + def create_range(self, start: int, stop: int, step: int = 1) -> tuple[list[int]]: + if step == 0: + raise ValueError("Step cannot be zero") + return (list(range(start, stop, step)),) + + class ListRemove(ComfyNodeABC): """ Removes the first occurrence of a specified value from a LIST. @@ -635,7 +737,6 @@ def set_item(self, list: list[Any], index: int, value: Any) -> tuple[list[Any]]: except IndexError: raise IndexError(f"Index {index} out of range for LIST of length {len(list)}") - class ListSlice(ComfyNodeABC): """ Creates a slice of a LIST. @@ -665,7 +766,6 @@ def slice(self, list: list[Any], start: int = 0, stop: int = INT_MAX, step: int = 1) -> tuple[list[Any]]: return (list[start:stop:step],) - class ListSort(ComfyNodeABC): """ Sorts the items in a LIST. @@ -702,6 +802,33 @@ def sort(self, list: list[Any], reverse: str = "False") -> tuple[list[Any]]: return (list.copy(),) +class ListSum: + """ + Sum all elements of the list. + Returns 0 for an empty list. + """ + + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "list": ("LIST",), + }, + "optional": { + "start": ("INT", {"default": 0}), + } + } + + RETURN_TYPES = ("INT", "FLOAT",) + CATEGORY = "Basic/LIST" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "sum_list" + + def sum_list(self, list: list[Any], start: int = 0) -> tuple[int, float]: + result = sum(list, start) + return result, float(result) + + class ListToDataList(ComfyNodeABC): """ Converts a LIST object into a ComfyUI data list. @@ -758,9 +885,12 @@ def convert(self, list: list[Any]) -> tuple[set[Any]]: "Basic data handling: ListCreateFromFloat": ListCreateFromFloat, "Basic data handling: ListCreateFromInt": ListCreateFromInt, "Basic data handling: ListCreateFromString": ListCreateFromString, + "Basic data handling: ListAll": ListAll, + "Basic data handling: ListAny": ListAny, "Basic data handling: ListAppend": ListAppend, "Basic data handling: ListContains": ListContains, "Basic data handling: ListCount": ListCount, + "Basic data handling: ListEnumerate": ListEnumerate, "Basic data handling: ListExtend": ListExtend, "Basic data handling: ListFirst": ListFirst, "Basic data handling: ListGetItem": ListGetItem, @@ -772,11 +902,13 @@ def convert(self, list: list[Any]) -> tuple[set[Any]]: "Basic data handling: ListMin": ListMin, "Basic data handling: ListPop": ListPop, "Basic data handling: ListPopRandom": ListPopRandom, + "Basic data handling: ListRange": ListRange, "Basic data handling: ListRemove": ListRemove, "Basic data handling: ListReverse": ListReverse, "Basic data handling: ListSetItem": ListSetItem, "Basic data handling: ListSlice": ListSlice, "Basic data handling: ListSort": ListSort, + "Basic data handling: ListSum": ListSum, "Basic data handling: ListToDataList": ListToDataList, "Basic data handling: ListToSet": ListToSet, } @@ -787,9 +919,12 @@ def convert(self, list: list[Any]) -> tuple[set[Any]]: "Basic data handling: ListCreateFromFloat": "create LIST from FLOATs", "Basic data handling: ListCreateFromInt": "create LIST from INTs", "Basic data handling: ListCreateFromString": "create LIST from STRINGs", + "Basic data handling: ListAll": "all", + "Basic data handling: ListAny": "any", "Basic data handling: ListAppend": "append", "Basic data handling: ListContains": "contains", "Basic data handling: ListCount": "count", + "Basic data handling: ListEnumerate": "enumerate", "Basic data handling: ListExtend": "extend", "Basic data handling: ListFirst": "first", "Basic data handling: ListGetItem": "get item", @@ -801,11 +936,13 @@ def convert(self, list: list[Any]) -> tuple[set[Any]]: "Basic data handling: ListMin": "min", "Basic data handling: ListPop": "pop", "Basic data handling: ListPopRandom": "pop random", + "Basic data handling: ListRange": "range", "Basic data handling: ListRemove": "remove", "Basic data handling: ListReverse": "reverse", "Basic data handling: ListSetItem": "set item", "Basic data handling: ListSlice": "slice", "Basic data handling: ListSort": "sort", + "Basic data handling: ListSum": "sum", "Basic data handling: ListToDataList": "convert to Data List", "Basic data handling: ListToSet": "convert to SET", } diff --git a/src/basic_data_handling/set_nodes.py b/src/basic_data_handling/set_nodes.py index e4c12c2..18ce224 100644 --- a/src/basic_data_handling/set_nodes.py +++ b/src/basic_data_handling/set_nodes.py @@ -168,6 +168,56 @@ def add(self, set: set[Any], item: Any) -> tuple[set[Any]]: return (result,) +class SetAll(ComfyNodeABC): + """ + Checks if all elements in the SET are true. + + This node takes a SET as input and returns True if all elements in the SET + evaluate to True (or if the SET is empty), and False otherwise. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "set": ("SET", {}), + } + } + + RETURN_TYPES = (IO.BOOLEAN,) + RETURN_NAMES = ("all_true",) + CATEGORY = "Basic/SET" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "check_all" + + def check_all(self, set: set[Any]) -> tuple[bool]: + return (all(set),) + + +class SetAny(ComfyNodeABC): + """ + Checks if any element in the SET is true. + + This node takes a SET as input and returns True if at least one element + in the SET evaluates to True, and False otherwise (including if the SET is empty). + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "set": ("SET", {}), + } + } + + RETURN_TYPES = (IO.BOOLEAN,) + RETURN_NAMES = ("any_true",) + CATEGORY = "Basic/SET" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "check_any" + + def check_any(self, set: set[Any]) -> tuple[bool]: + return (any(set),) + + class SetContains(ComfyNodeABC): """ Checks if a SET contains a specified value. @@ -248,6 +298,37 @@ def discard(self, set: set[Any], item: Any) -> tuple[set[Any]]: return (result,) +class SetEnumerate(ComfyNodeABC): + """ + Enumerates elements in a SET. + + This node takes a SET as input and returns a LIST of tuples where each tuple + contains an index and a value from the SET. The start parameter specifies the + initial index value (default is 0). + + Note: Since SETs are unordered, the enumeration order is arbitrary but consistent + within a single operation. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "set": ("SET", {}), + }, + "optional": { + "start": ("INT", {"default": 0}), + } + } + + RETURN_TYPES = ("LIST",) + CATEGORY = "Basic/SET" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "enumerate_set" + + def enumerate_set(self, set: set[Any], start: int = 0) -> tuple[list]: + return (list(enumerate(set, start=start)),) + + class SetIntersection(ComfyNodeABC): """ Returns the intersection of two or more SETs. @@ -481,6 +562,39 @@ def remove(self, set: set[Any], item: Any) -> tuple[set[Any], bool]: return result, False +class SetSum(ComfyNodeABC): + """ + Calculates the sum of all elements in a SET. + + This node takes a SET as input and returns the sum of all its elements. + The optional start parameter specifies the initial value (default is 0). + + Note: This operation requires all elements to be numeric or otherwise + compatible with addition. If the SET contains mixed or incompatible types, + it may raise an error. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "set": ("SET", {}), + }, + "optional": { + "start": ("INT", {"default": 0}), + } + } + + RETURN_TYPES = ("INT", "FLOAT",) + RETURN_NAMES = ("sum_int", "sum_float",) + CATEGORY = "Basic/SET" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "sum_set" + + def sum_set(self, set: set[Any], start: int = 0) -> tuple[int, float]: + result = sum(set, start) + return result, float(result) + + class SetSymmetricDifference(ComfyNodeABC): """ Returns the symmetric difference between two SETs. @@ -604,9 +718,12 @@ def convert(self, set: set[Any]) -> tuple[list[Any]]: "Basic data handling: SetCreateFromInt": SetCreateFromInt, "Basic data handling: SetCreateFromString": SetCreateFromString, "Basic data handling: SetAdd": SetAdd, + "Basic data handling: SetAll": SetAll, + "Basic data handling: SetAny": SetAny, "Basic data handling: SetContains": SetContains, "Basic data handling: SetDifference": SetDifference, "Basic data handling: SetDiscard": SetDiscard, + "Basic data handling: SetEnumerate": SetEnumerate, "Basic data handling: SetIntersection": SetIntersection, "Basic data handling: SetIsDisjoint": SetIsDisjoint, "Basic data handling: SetIsSubset": SetIsSubset, @@ -615,6 +732,7 @@ def convert(self, set: set[Any]) -> tuple[list[Any]]: "Basic data handling: SetPop": SetPop, "Basic data handling: SetPopRandom": SetPopRandom, "Basic data handling: SetRemove": SetRemove, + "Basic data handling: SetSum": SetSum, "Basic data handling: SetSymmetricDifference": SetSymmetricDifference, "Basic data handling: SetUnion": SetUnion, "Basic data handling: SetToDataList": SetToDataList, @@ -628,9 +746,12 @@ def convert(self, set: set[Any]) -> tuple[list[Any]]: "Basic data handling: SetCreateFromInt": "create SET from INTs", "Basic data handling: SetCreateFromString": "create SET from STRINGs", "Basic data handling: SetAdd": "add", + "Basic data handling: SetAll": "all", + "Basic data handling: SetAny": "any", "Basic data handling: SetContains": "contains", "Basic data handling: SetDifference": "difference", "Basic data handling: SetDiscard": "discard", + "Basic data handling: SetEnumerate": "enumerate", "Basic data handling: SetIntersection": "intersection", "Basic data handling: SetIsDisjoint": "is disjoint", "Basic data handling: SetIsSubset": "is subset", @@ -639,6 +760,7 @@ def convert(self, set: set[Any]) -> tuple[list[Any]]: "Basic data handling: SetPop": "pop", "Basic data handling: SetPopRandom": "pop random", "Basic data handling: SetRemove": "remove", + "Basic data handling: SetSum": "sum", "Basic data handling: SetSymmetricDifference": "symmetric difference", "Basic data handling: SetUnion": "union", "Basic data handling: SetToDataList": "convert to Data List", diff --git a/tests/test_data_list_nodes.py b/tests/test_data_list_nodes.py index ac0c5a5..6560716 100644 --- a/tests/test_data_list_nodes.py +++ b/tests/test_data_list_nodes.py @@ -1,34 +1,39 @@ import pytest from src.basic_data_handling.data_list_nodes import ( + DataListAll, + DataListAny, DataListAppend, - #DataListCreate, - #DataListCreateFromBoolean, - #DataListCreateFromFloat, - #DataListCreateFromInt, - #DataListCreateFromString, + DataListContains, + DataListCount, + DataListCreate, + DataListCreateFromBoolean, + DataListCreateFromFloat, + DataListCreateFromInt, + DataListCreateFromString, + DataListEnumerate, DataListExtend, + DataListFilter, + DataListFilterSelect, + DataListFirst, + DataListGetItem, + DataListIndex, DataListInsert, - DataListRemove, + DataListLast, + DataListLength, + DataListMax, + DataListMin, DataListPop, DataListPopRandom, - DataListIndex, - DataListCount, - DataListSort, + DataListRange, + DataListRemove, DataListReverse, - DataListLength, - DataListSlice, - DataListGetItem, DataListSetItem, - DataListContains, + DataListSlice, + DataListSort, + DataListSum, + DataListToList, + DataListToSet, DataListZip, - DataListFilter, - DataListFilterSelect, - DataListFirst, - DataListLast, - DataListMin, - DataListMax, - #DataListToList, - #DataListToSet, ) @@ -174,6 +179,66 @@ def test_filter_select(): assert false_list == [1, 2, 3] +def test_create(): + node = DataListCreate() + # Testing with one item + assert node.create_list(item_0="test", _dynamic_number=1) == (["test"],) + + # Testing with multiple items of different types + assert node.create_list(item_0=1, item_1="two", item_2=3.0, _dynamic_number=3) == ([1, "two", 3.0],) + + # Testing with empty list (no items) + assert node.create_list(_dynamic_number=0) == ([],) + + +def test_create_from_boolean(): + node = DataListCreateFromBoolean() + # Testing with boolean values + assert node.create_list(item_0=True, item_1=False, _dynamic_number=2) == ([True, False],) + + # Testing with boolean-convertible values + assert node.create_list(item_0=1, item_1=0, _dynamic_number=2) == ([True, False],) + + # Testing with empty list + assert node.create_list(_dynamic_number=0) == ([],) + + +def test_create_from_float(): + node = DataListCreateFromFloat() + # Testing with float values + assert node.create_list(item_0=1.5, item_1=2.5, _dynamic_number=2) == ([1.5, 2.5],) + + # Testing with float-convertible values + assert node.create_list(item_0=1, item_1="2.5", _dynamic_number=2) == ([1.0, 2.5],) + + # Testing with empty list + assert node.create_list(_dynamic_number=0) == ([],) + + +def test_create_from_int(): + node = DataListCreateFromInt() + # Testing with integer values + assert node.create_list(item_0=1, item_1=2, _dynamic_number=2) == ([1, 2],) + + # Testing with int-convertible values + assert node.create_list(item_0="1", item_1=2.0, _dynamic_number=2) == ([1, 2],) + + # Testing with empty list + assert node.create_list(_dynamic_number=0) == ([],) + + +def test_create_from_string(): + node = DataListCreateFromString() + # Testing with string values + assert node.create_list(item_0="hello", item_1="world", _dynamic_number=2) == (["hello", "world"],) + + # Testing with string-convertible values + assert node.create_list(item_0=123, item_1=True, _dynamic_number=2) == (["123", "True"],) + + # Testing with empty list + assert node.create_list(_dynamic_number=0) == ([],) + + def test_pop_random(): node = DataListPopRandom() result_list, item = node.pop_random_element(list=[1]) @@ -201,3 +266,105 @@ def test_last(): assert node.get_last_element(list=["a", "b", "c"]) == ("c",) assert node.get_last_element(list=[]) == (None,) # Empty list + +def test_to_list(): + node = DataListToList() + # Test with regular list + assert node.convert(list=[1, 2, 3]) == ([1, 2, 3],) + # Test with empty list + assert node.convert(list=[]) == ([],) + # Test with mixed types + assert node.convert(list=[1, "two", 3.0]) == ([1, "two", 3.0],) + + +def test_to_set(): + node = DataListToSet() + # Test with regular list + assert node.convert(list=[1, 2, 3]) == ({1, 2, 3},) + # Test with empty list + assert node.convert(list=[]) == (set(),) + # Test with duplicates + assert node.convert(list=[1, 2, 1, 3, 2]) == ({1, 2, 3},) + # Test with mixed types (that can be in a set) + assert node.convert(list=[1, "two", 3.0]) == ({1, "two", 3.0},) + + +def test_all(): + node = DataListAll() + # All true values + assert node.check_all(list=[True, True, 1, "text"]) == (True,) + # Contains false value + assert node.check_all(list=[True, False, True]) == (False,) + # Empty list (returns True) + assert node.check_all(list=[]) == (True,) + + +def test_any(): + node = DataListAny() + # Contains true value + assert node.check_any(list=[False, True, False]) == (True,) + # All false values + assert node.check_any(list=[False, 0, "", None]) == (False,) + # Empty list (returns False) + assert node.check_any(list=[]) == (False,) + + +def test_enumerate(): + node = DataListEnumerate() + # Basic enumeration starting from 0 + assert node.enumerate_list(list=['a', 'b', 'c']) == ([[0, 'a'], [1, 'b'], [2, 'c']],) + # Custom start index + assert node.enumerate_list(list=['x', 'y', 'z'], start=[10]) == ([[10, 'x'], [11, 'y'], [12, 'z']],) + # Empty list + assert node.enumerate_list(list=[]) == ([],) + + +def test_sum(): + node = DataListSum() + # Integer sum + int_sum, float_sum = node.sum_list(list=[1, 2, 3]) + assert int_sum == 6 + assert float_sum == 6.0 + + # Mixed number types + int_sum, float_sum = node.sum_list(list=[1, 2.5, 3]) + assert int_sum == 6 # Integer part of the sum + assert float_sum == 6.5 # Full float sum + + # With start value + int_sum, float_sum = node.sum_list(list=[1, 2, 3], start=[10]) + assert int_sum == 16 + assert float_sum == 16.0 + + # Empty list with start value + int_sum, float_sum = node.sum_list(list=[], start=[5]) + assert int_sum == 5 + assert float_sum == 5.0 + + +def test_range(): + node = DataListRange() + # Test with default start (0) and step (1) + assert node.create_range(stop=5) == ([0, 1, 2, 3, 4],) + + # Test with custom start and stop + assert node.create_range(start=2, stop=6) == ([2, 3, 4, 5],) + + # Test with custom step + assert node.create_range(start=0, stop=10, step=2) == ([0, 2, 4, 6, 8],) + + # Test with negative numbers + assert node.create_range(start=-3, stop=3) == ([-3, -2, -1, 0, 1, 2],) + + # Test backward counting + assert node.create_range(start=5, stop=0, step=-1) == ([5, 4, 3, 2, 1],) + + # Test empty range + assert node.create_range(start=0, stop=0) == ([],) + + # Test when start > stop with positive step (returns empty list) + assert node.create_range(start=10, stop=5) == ([],) + + # Test with ValueError (step cannot be zero) + with pytest.raises(ValueError, match="Step cannot be zero"): + node.create_range(start=0, stop=5, step=0) diff --git a/tests/test_list_nodes.py b/tests/test_list_nodes.py index 668fba8..1da2fbb 100644 --- a/tests/test_list_nodes.py +++ b/tests/test_list_nodes.py @@ -1,5 +1,7 @@ import pytest from src.basic_data_handling.list_nodes import ( + ListAll, + ListAny, ListAppend, ListContains, ListCount, @@ -8,6 +10,7 @@ ListCreateFromFloat, ListCreateFromInt, ListCreateFromString, + ListEnumerate, ListExtend, ListFirst, ListGetItem, @@ -19,11 +22,13 @@ ListMin, ListPop, ListPopRandom, + ListRange, ListRemove, ListReverse, ListSetItem, ListSlice, ListSort, + ListSum, ListToDataList, ListToSet, ) @@ -221,3 +226,121 @@ def test_list_to_set(): assert node.convert([1, 2, 3, 2, 1]) == ({1, 2, 3},) assert node.convert([]) == (set(),) assert node.convert(["a", "b", "a", "c"]) == ({"a", "b", "c"},) + + +def test_list_range(): + node = ListRange() + # Basic range with default step=1 + assert node.create_range(start=0, stop=5) == ([0, 1, 2, 3, 4],) + + # Range with custom step + assert node.create_range(start=0, stop=10, step=2) == ([0, 2, 4, 6, 8],) + + # Negative step (counting down) + assert node.create_range(start=5, stop=0, step=-1) == ([5, 4, 3, 2, 1],) + + # Start > stop with positive step (empty list) + assert node.create_range(start=10, stop=5) == ([],) + + # Empty range + assert node.create_range(start=0, stop=0) == ([],) + + # Test with negative indices + assert node.create_range(start=-5, stop=0) == ([-5, -4, -3, -2, -1],) + + # Test error case: step=0 + with pytest.raises(ValueError, match="Step cannot be zero"): + node.create_range(start=0, stop=5, step=0) + + +def test_list_all(): + node = ListAll() + # All true values + assert node.check_all([True, True, True]) == (True,) + + # Mixed values (all truthy) + assert node.check_all([1, "text", [1, 2]]) == (True,) + + # Contains a false value + assert node.check_all([True, False, True]) == (False,) + + # Contains a falsy value + assert node.check_all([True, 0, True]) == (False,) + + # Empty list (returns True according to Python's all() behavior) + assert node.check_all([]) == (True,) + + +def test_list_any(): + node = ListAny() + # All true values + assert node.check_any([True, True, True]) == (True,) + + # Mixed values with at least one true + assert node.check_any([False, True, False]) == (True,) + + # Mixed values with at least one truthy + assert node.check_any([0, "", 1]) == (True,) + + # All false values + assert node.check_any([False, False, False]) == (False,) + + # All falsy values + assert node.check_any([0, "", []]) == (False,) + + # Empty list (returns False according to Python's any() behavior) + assert node.check_any([]) == (False,) + + +def test_list_enumerate(): + node = ListEnumerate() + + # Basic enumeration with default start=0 + result = node.enumerate_list(["a", "b", "c"]) + assert result == ([[(0, "a"), (1, "b"), (2, "c")]],) + + # Enumeration with custom start + result = node.enumerate_list(["x", "y", "z"], start=10) + assert result == ([[(10, "x"), (11, "y"), (12, "z")]],) + + # Enumeration of empty list + result = node.enumerate_list([]) + assert result == ([[]],) + + # Enumeration of list with mixed types + result = node.enumerate_list([1, "text", True]) + assert result == ([[(0, 1), (1, "text"), (2, True)]],) + + +def test_list_sum(): + node = ListSum() + + # Sum of integers + int_result, float_result = node.sum_list([1, 2, 3, 4]) + assert int_result == 10 + assert float_result == 10.0 + + # Sum with default start value (0) + int_result, float_result = node.sum_list([5, 10, 15]) + assert int_result == 30 + assert float_result == 30.0 + + # Sum with custom start value + int_result, float_result = node.sum_list([1, 2, 3], start=10) + assert int_result == 16 + assert float_result == 16.0 + + # Sum of floats (should still return both int and float) + int_result, float_result = node.sum_list([1.5, 2.5, 3.0]) + assert int_result == 7.0 # Note: This will be a float, but returned as first value + assert float_result == 7.0 + + # Sum of empty list + int_result, float_result = node.sum_list([]) + assert int_result == 0 + assert float_result == 0.0 + + # Sum of mixed numbers + int_result, float_result = node.sum_list([1, 2.5, 3]) + assert int_result == 6.5 + assert float_result == 6.5 diff --git a/tests/test_set_nodes.py b/tests/test_set_nodes.py index 6ad37a1..dfb2276 100644 --- a/tests/test_set_nodes.py +++ b/tests/test_set_nodes.py @@ -1,6 +1,8 @@ #import pytest from src.basic_data_handling.set_nodes import ( SetAdd, + SetAll, + SetAny, SetContains, SetCreate, SetCreateFromBoolean, @@ -9,6 +11,7 @@ SetCreateFromString, SetDifference, SetDiscard, + SetEnumerate, SetIntersection, SetIsDisjoint, SetIsSubset, @@ -17,6 +20,7 @@ SetPop, SetPopRandom, SetRemove, + SetSum, SetSymmetricDifference, SetToDataList, SetToList, @@ -246,3 +250,95 @@ def test_set_to_data_list(): # Mixed types result = node.convert({1, "string", True}) assert set(result[0]) == {1, "string", True} # Can't check order, just content + + +def test_set_all(): + node = SetAll() + + # Test with all truthy values + assert node.check_all({1, True, "string", 3.14}) == (True,) + + # Test with one falsy value + assert node.check_all({1, False, "string"}) == (False,) + + # Test with all falsy values + assert node.check_all({False, 0, "", None}) == (False,) + + # Test with empty set (should return True per Python's all() behavior) + assert node.check_all(set()) == (True,) + + +def test_set_any(): + node = SetAny() + + # Test with all truthy values + assert node.check_any({1, True, "string", 3.14}) == (True,) + + # Test with one truthy value + assert node.check_any({0, False, "", 1}) == (True,) + + # Test with all falsy values + assert node.check_any({False, 0, "", None}) == (False,) + + # Test with empty set (should return False per Python's any() behavior) + assert node.check_any(set()) == (False,) + + +def test_set_enumerate(): + node = SetEnumerate() + + # Basic test with default start=0 + result = node.enumerate_set({10, 20, 30}) + assert isinstance(result, tuple) + assert isinstance(result[0], list) + + # Convert to set of tuples for comparison (order may vary) + result_set = {tuple(item) for item in result[0]} + assert result_set == {(0, 10), (1, 20), (2, 30)} + + # Test with custom start value + result = node.enumerate_set({10, 20, 30}, start=5) + result_set = {tuple(item) for item in result[0]} + assert result_set == {(5, 10), (6, 20), (7, 30)} + + # Test with empty set + result = node.enumerate_set(set()) + assert result[0] == [] + + # Test with mixed types + result = node.enumerate_set({1, "string", False}) + assert len(result[0]) == 3 + # Check format but not exact values due to arbitrary order + for item in result[0]: + assert isinstance(item, tuple) + assert len(item) == 2 + assert isinstance(item[0], int) + + +def test_set_sum(): + node = SetSum() + + # Test with integer set + int_result, float_result = node.sum_set({1, 2, 3}) + assert int_result == 6 + assert float_result == 6.0 + + # Test with float set + int_result, float_result = node.sum_set({1.5, 2.5, 3.0}) + assert int_result == 7.0 + assert float_result == 7.0 + + # Test with mixed numeric types + int_result, float_result = node.sum_set({1, 2.5, 3}) + assert int_result == 6.5 + assert float_result == 6.5 + + # Test with custom start value + int_result, float_result = node.sum_set({1, 2, 3}, start=10) + assert int_result == 16 + assert float_result == 16.0 + + # Test with empty set + int_result, float_result = node.sum_set(set(), start=5) + assert int_result == 5 + assert float_result == 5.0