diff --git a/pyproject.toml b/pyproject.toml index 997058e..0a9b282 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "basic_data_handling" -version = "0.3.3" +version = "0.3.4" description = """NOTE: Still in development! Expect breaking changes! Basic Python functions for manipulating data that every programmer is used to. Currently supported ComfyUI data types: BOOLEAN, FLOAT, INT, STRING and data lists. @@ -52,9 +52,9 @@ pythonpath = [ testpaths = [ "tests", ] -#python_files = ["test_*.py"] +python_files = ["test_*.py"] #python_files = ["conftest.py", "test_boolean_nodes.py"] -python_files = ["test_boolean_nodes.py"] +#python_files = ["test_boolean_nodes.py"] [tool.mypy] files = "." diff --git a/src/basic_data_handling/_dynamic_input.py b/src/basic_data_handling/_dynamic_input.py index e02d69a..fce3bc3 100644 --- a/src/basic_data_handling/_dynamic_input.py +++ b/src/basic_data_handling/_dynamic_input.py @@ -24,7 +24,6 @@ def __contains__(self, key): def __getitem__(self, key): # Dynamically return the value for keys matching a `prefix` pattern - print(f'_ dynamic prefixes: {self._dynamic_prefixes}; get key: {key}') for prefix, value in self._dynamic_prefixes.items(): if key.startswith(prefix) and key[len(prefix):].isdigit(): return value diff --git a/src/basic_data_handling/data_list_nodes.py b/src/basic_data_handling/data_list_nodes.py index 4c157bd..583f76d 100644 --- a/src/basic_data_handling/data_list_nodes.py +++ b/src/basic_data_handling/data_list_nodes.py @@ -348,6 +348,33 @@ def select(self, **kwargs: list[Any]) -> tuple[list[Any]]: return result_true, result_false +class DataListFirst(ComfyNodeABC): + """ + Returns the first element in a list. + + This node takes a list as input and returns the first element of the list. + If the list is empty, it returns None. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.ANY,) + RETURN_NAMES = ("first_element",) + CATEGORY = "Basic/Data List" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "get_first_element" + INPUT_IS_LIST = True + + def get_first_element(self, **kwargs: list[Any]) -> tuple[Any]: + input_list = kwargs.get('list', []) + return (input_list[0] if input_list else None,) + + class DataListGetItem(ComfyNodeABC): """ Retrieves an item at a specified position in a list. @@ -453,6 +480,33 @@ def insert(self, **kwargs: list[Any]) -> tuple[list[Any]]: return (result,) +class DataListLast(ComfyNodeABC): + """ + Returns the last element in a list. + + This node takes a list as input and returns the last element of the list. + If the list is empty, it returns None. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.ANY,) + RETURN_NAMES = ("last_element",) + CATEGORY = "Basic/Data List" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "get_last_element" + INPUT_IS_LIST = True + + def get_last_element(self, **kwargs: list[Any]) -> tuple[Any]: + input_list = kwargs.get('list', []) + return (input_list[-1] if input_list else None,) + + class DataListLength(ComfyNodeABC): """ Counts the number of items in a list. @@ -588,6 +642,38 @@ def pop(self, **kwargs: list[Any]) -> tuple[list[Any], Any]: return result, None +class DataListPopRandom(ComfyNodeABC): + """ + Removes and returns a random element from a list. + + This node takes a list as input and returns the list with the random element removed + and the removed element itself. If the list is empty, it returns None for the element. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.ANY, IO.ANY) + RETURN_NAMES = ("list", "item") + CATEGORY = "Basic/Data List" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "pop_random_element" + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, False) + + def pop_random_element(self, **kwargs: list[Any]) -> tuple[list[Any], Any]: + from random import randrange + input_list = kwargs.get('list', []).copy() + if input_list: + random_element = input_list.pop(randrange(len(input_list))) + return input_list, random_element + return input_list, None + + class DataListRemove(ComfyNodeABC): """ Removes the first occurrence of a specified value from a list. @@ -864,13 +950,16 @@ def convert(self, **kwargs: list[Any]) -> tuple[set[Any]]: "Basic data handling: DataListExtend": DataListExtend, "Basic data handling: DataListFilter": DataListFilter, "Basic data handling: DataListFilterSelect": DataListFilterSelect, + "Basic data handling: DataListFirst": DataListFirst, "Basic data handling: DataListGetItem": DataListGetItem, "Basic data handling: DataListIndex": DataListIndex, "Basic data handling: DataListInsert": DataListInsert, + "Basic data handling: DataListLast": DataListLast, "Basic data handling: DataListLength": DataListLength, "Basic data handling: DataListMax": DataListMax, "Basic data handling: DataListMin": DataListMin, "Basic data handling: DataListPop": DataListPop, + "Basic data handling: DataListPopRandom": DataListPopRandom, "Basic data handling: DataListRemove": DataListRemove, "Basic data handling: DataListReverse": DataListReverse, "Basic data handling: DataListSetItem": DataListSetItem, @@ -893,13 +982,16 @@ def convert(self, **kwargs: list[Any]) -> tuple[set[Any]]: "Basic data handling: DataListExtend": "extend", "Basic data handling: DataListFilter": "filter", "Basic data handling: DataListFilterSelect": "filter select", + "Basic data handling: DataListFirst": "first", "Basic data handling: DataListGetItem": "get item", "Basic data handling: DataListIndex": "index", "Basic data handling: DataListInsert": "insert", + "Basic data handling: DataListLast": "last", "Basic data handling: DataListLength": "length", "Basic data handling: DataListMax": "max", "Basic data handling: DataListMin": "min", "Basic data handling: DataListPop": "pop", + "Basic data handling: DataListPopRandom": "pop random", "Basic data handling: DataListRemove": "remove", "Basic data handling: DataListReverse": "reverse", "Basic data handling: DataListSetItem": "set item", diff --git a/src/basic_data_handling/dict_nodes.py b/src/basic_data_handling/dict_nodes.py index d243aa2..c2f0871 100644 --- a/src/basic_data_handling/dict_nodes.py +++ b/src/basic_data_handling/dict_nodes.py @@ -719,6 +719,42 @@ def popitem(self, input_dict: dict) -> tuple[dict, str, Any, bool]: return result, "", None, False +class DictPopRandom(ComfyNodeABC): + """ + Removes and returns a random key-value pair from a dictionary. + + This node takes a dictionary as input, removes a random key-value pair, + and returns the modified dictionary along with the removed key and value. + If the dictionary is empty, it returns empty values. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input_dict": ("DICT", {}), + } + } + + RETURN_TYPES = ("DICT", IO.STRING, IO.ANY, IO.BOOLEAN) + RETURN_NAMES = ("dict", "key", "value", "success") + CATEGORY = "Basic/DICT" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "pop_random" + + def pop_random(self, input_dict: dict) -> tuple[dict, str, Any, bool]: + import random + result = input_dict.copy() + try: + if result: + random_key = random.choice(list(result.keys())) + random_value = result.pop(random_key) + return result, random_key, random_value, True + else: + return result, "", None, False + except: + return result, "", None, False + + class DictRemove(ComfyNodeABC): """ Removes a key-value pair from a dictionary. @@ -883,6 +919,7 @@ def values(self, input_dict: dict) -> tuple[list]: "Basic data handling: DictMerge": DictMerge, "Basic data handling: DictPop": DictPop, "Basic data handling: DictPopItem": DictPopItem, + "Basic data handling: DictPopRandom": DictPopRandom, "Basic data handling: DictRemove": DictRemove, "Basic data handling: DictSet": DictSet, "Basic data handling: DictSetDefault": DictSetDefault, @@ -913,7 +950,8 @@ def values(self, input_dict: dict) -> tuple[list]: "Basic data handling: DictLength": "length", "Basic data handling: DictMerge": "merge", "Basic data handling: DictPop": "pop", - "Basic data handling: DictPopItem": "popitem", + "Basic data handling: DictPopItem": "pop item", + "Basic data handling: DictPopRandom": "pop random", "Basic data handling: DictRemove": "remove", "Basic data handling: DictSet": "set", "Basic data handling: DictSetDefault": "setdefault", diff --git a/src/basic_data_handling/list_nodes.py b/src/basic_data_handling/list_nodes.py index 11d8300..49fe20f 100644 --- a/src/basic_data_handling/list_nodes.py +++ b/src/basic_data_handling/list_nodes.py @@ -248,6 +248,31 @@ def extend(self, list1: list[Any], list2: list[Any]) -> tuple[list[Any]]: return (result,) +class ListFirst(ComfyNodeABC): + """ + Returns the first element in a LIST. + + This node takes a LIST as input and returns the first element of the list. + If the LIST is empty, it returns None. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": ("LIST", {}), + } + } + + RETURN_TYPES = (IO.ANY,) + RETURN_NAMES = ("first_element",) + CATEGORY = "Basic/LIST" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "get_first_element" + + def get_first_element(self, list: list[Any]) -> tuple[Any]: + return (list[0] if list else None,) + + class ListGetItem(ComfyNodeABC): """ Retrieves an item at a specified position in a LIST. @@ -343,6 +368,31 @@ def insert(self, list: list[Any], index: int, item: Any) -> tuple[list[Any]]: return (result,) +class ListLast(ComfyNodeABC): + """ + Returns the last element in a LIST. + + This node takes a LIST as input and returns the last element of the list. + If the LIST is empty, it returns None. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": ("LIST", {}), + } + } + + RETURN_TYPES = (IO.ANY,) + RETURN_NAMES = ("last_element",) + CATEGORY = "Basic/LIST" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "get_last_element" + + def get_last_element(self, list: list[Any]) -> tuple[Any]: + return (list[-1] if list else None,) + + class ListLength(ComfyNodeABC): """ Returns the number of items in a LIST. @@ -466,6 +516,37 @@ def pop(self, list: list[Any], index: int = -1) -> tuple[list[Any], Any]: return result, None +class ListPopRandom(ComfyNodeABC): + """ + Removes and returns a random element from a LIST. + + This node takes a LIST as input and returns the LIST with the random element removed + and the removed element itself. If the LIST is empty, it returns None for the element. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": ("LIST", {}), + } + } + + RETURN_TYPES = ("LIST", IO.ANY) + RETURN_NAMES = ("list", "item") + CATEGORY = "Basic/LIST" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "pop_random_element" + + def pop_random_element(self, list: list[Any]) -> tuple[list[Any], Any]: + import random + result = list.copy() + if result: + random_index = random.randrange(len(result)) + random_element = result.pop(random_index) + return result, random_element + return result, None + + class ListRemove(ComfyNodeABC): """ Removes the first occurrence of a specified value from a LIST. @@ -680,13 +761,16 @@ def convert(self, list: list[Any]) -> tuple[set[Any]]: "Basic data handling: ListContains": ListContains, "Basic data handling: ListCount": ListCount, "Basic data handling: ListExtend": ListExtend, + "Basic data handling: ListFirst": ListFirst, "Basic data handling: ListGetItem": ListGetItem, "Basic data handling: ListIndex": ListIndex, "Basic data handling: ListInsert": ListInsert, + "Basic data handling: ListLast": ListLast, "Basic data handling: ListLength": ListLength, "Basic data handling: ListMax": ListMax, "Basic data handling: ListMin": ListMin, "Basic data handling: ListPop": ListPop, + "Basic data handling: ListPopRandom": ListPopRandom, "Basic data handling: ListRemove": ListRemove, "Basic data handling: ListReverse": ListReverse, "Basic data handling: ListSetItem": ListSetItem, @@ -706,13 +790,16 @@ def convert(self, list: list[Any]) -> tuple[set[Any]]: "Basic data handling: ListContains": "contains", "Basic data handling: ListCount": "count", "Basic data handling: ListExtend": "extend", + "Basic data handling: ListFirst": "first", "Basic data handling: ListGetItem": "get item", "Basic data handling: ListIndex": "index", "Basic data handling: ListInsert": "insert", + "Basic data handling: ListLast": "last", "Basic data handling: ListLength": "length", "Basic data handling: ListMax": "max", "Basic data handling: ListMin": "min", "Basic data handling: ListPop": "pop", + "Basic data handling: ListPopRandom": "pop random", "Basic data handling: ListRemove": "remove", "Basic data handling: ListReverse": "reverse", "Basic data handling: ListSetItem": "set item", diff --git a/src/basic_data_handling/set_nodes.py b/src/basic_data_handling/set_nodes.py index ea9014e..7bebdd9 100644 --- a/src/basic_data_handling/set_nodes.py +++ b/src/basic_data_handling/set_nodes.py @@ -417,6 +417,37 @@ def pop(self, set: set[Any]) -> tuple[set[Any], Any]: return result, None +class SetPopRandom(ComfyNodeABC): + """ + Removes and returns a random element from a SET. + + This node takes a SET as input and returns the SET with a random element removed + and the removed element itself. If the SET is empty, it returns None for the element. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "set": ("SET", {}), + } + } + + RETURN_TYPES = ("SET", IO.ANY) + RETURN_NAMES = ("set", "item") + CATEGORY = "Basic/SET" + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "pop_random_element" + + def pop_random_element(self, set: set[Any]) -> tuple[set[Any], Any]: + import random + result = set.copy() + if result: + random_element = random.choice(list(result)) + result.remove(random_element) + return result, random_element + return result, None + + class SetRemove(ComfyNodeABC): """ Removes an item from a SET. @@ -581,6 +612,7 @@ def convert(self, set: set[Any]) -> tuple[list[Any]]: "Basic data handling: SetIsSuperset": SetIsSuperset, "Basic data handling: SetLength": SetLength, "Basic data handling: SetPop": SetPop, + "Basic data handling: SetPopRandom": SetPopRandom, "Basic data handling: SetRemove": SetRemove, "Basic data handling: SetSymmetricDifference": SetSymmetricDifference, "Basic data handling: SetUnion": SetUnion, @@ -604,6 +636,7 @@ def convert(self, set: set[Any]) -> tuple[list[Any]]: "Basic data handling: SetIsSuperset": "is superset", "Basic data handling: SetLength": "length", "Basic data handling: SetPop": "pop", + "Basic data handling: SetPopRandom": "pop random", "Basic data handling: SetRemove": "remove", "Basic data handling: SetSymmetricDifference": "symmetric difference", "Basic data handling: SetUnion": "union", diff --git a/tests/test_data_list_nodes.py b/tests/test_data_list_nodes.py index ab98877..ac0c5a5 100644 --- a/tests/test_data_list_nodes.py +++ b/tests/test_data_list_nodes.py @@ -1,10 +1,16 @@ import pytest from src.basic_data_handling.data_list_nodes import ( DataListAppend, + #DataListCreate, + #DataListCreateFromBoolean, + #DataListCreateFromFloat, + #DataListCreateFromInt, + #DataListCreateFromString, DataListExtend, DataListInsert, DataListRemove, DataListPop, + DataListPopRandom, DataListIndex, DataListCount, DataListSort, @@ -16,8 +22,13 @@ DataListContains, DataListZip, DataListFilter, + DataListFilterSelect, + DataListFirst, + DataListLast, DataListMin, DataListMax, + #DataListToList, + #DataListToSet, ) @@ -144,3 +155,49 @@ def test_max(): assert node.find_max(list=[3, 1, 2]) == (3,) assert node.find_max(list=[-1, -5, 0]) == (0,) assert node.find_max(list=[]) == (None,) + + +def test_filter_select(): + node = DataListFilterSelect() + true_list, false_list = node.select(value=[1, 2, 3], select=[True, False, True]) + assert true_list == [1, 3] + assert false_list == [2] + + # All true case + true_list, false_list = node.select(value=["a", "b"], select=[True, True]) + assert true_list == ["a", "b"] + assert false_list == [] + + # All false case + true_list, false_list = node.select(value=[1, 2, 3], select=[False, False, False]) + assert true_list == [] + assert false_list == [1, 2, 3] + + +def test_pop_random(): + node = DataListPopRandom() + result_list, item = node.pop_random_element(list=[1]) + assert result_list == [] and item == 1 # Only one item, so it must be chosen + + # With multiple items, we can't predict which one will be popped + # But we can check that an item was removed and returned + result_list, item = node.pop_random_element(list=[1, 2, 3]) + assert len(result_list) == 2 and item in [1, 2, 3] + + # Empty list case + assert node.pop_random_element(list=[]) == ([], None) + + +def test_first(): + node = DataListFirst() + assert node.get_first_element(list=[1, 2, 3]) == (1,) + assert node.get_first_element(list=["a", "b", "c"]) == ("a",) + assert node.get_first_element(list=[]) == (None,) # Empty list + + +def test_last(): + node = DataListLast() + assert node.get_last_element(list=[1, 2, 3]) == (3,) + assert node.get_last_element(list=["a", "b", "c"]) == ("c",) + assert node.get_last_element(list=[]) == (None,) # Empty list + diff --git a/tests/test_dict_nodes.py b/tests/test_dict_nodes.py index 39fd979..7b21d50 100644 --- a/tests/test_dict_nodes.py +++ b/tests/test_dict_nodes.py @@ -1,27 +1,32 @@ #import pytest from src.basic_data_handling.dict_nodes import ( + DictCompare, + DictContainsKey, DictCreate, + DictCreateFromBoolean, + DictCreateFromFloat, + DictCreateFromInt, + DictCreateFromLists, + DictCreateFromString, + DictExcludeKeys, + DictFilterByKeys, + DictFromKeys, DictGet, - DictSet, - DictKeys, - DictValues, + DictGetKeysValues, + DictGetMultiple, + DictInvert, DictItems, - DictContainsKey, - DictFromKeys, + DictKeys, + DictLength, + DictMerge, DictPop, DictPopItem, + DictPopRandom, + DictRemove, + DictSet, DictSetDefault, DictUpdate, - DictLength, - DictMerge, - DictGetKeysValues, - DictRemove, - DictFilterByKeys, - DictExcludeKeys, - DictGetMultiple, - DictInvert, - DictCreateFromLists, - DictCompare, + DictValues, ) def test_dict_create(): @@ -47,6 +52,65 @@ def test_dict_set(): # Test with empty dict assert node.set({}, "key", "value") == ({"key": "value"},) +def test_dict_create_from_boolean(): + node = DictCreateFromBoolean() + # Test with dynamic inputs + result = node.create(key_0="key1", value_0=True, key_1="key2", value_1=False) + assert result == ({"key1": True, "key2": False},) + # Test with empty inputs + assert node.create() == ({},) + + +def test_dict_create_from_float(): + node = DictCreateFromFloat() + # Test with dynamic inputs + result = node.create(key_0="key1", value_0=1.5, key_1="key2", value_1=2.5) + assert result == ({"key1": 1.5, "key2": 2.5},) + # Test with empty inputs + assert node.create() == ({},) + + +def test_dict_create_from_int(): + node = DictCreateFromInt() + # Test with dynamic inputs + result = node.create(key_0="key1", value_0=1, key_1="key2", value_1=2) + assert result == ({"key1": 1, "key2": 2},) + # Test with empty inputs + assert node.create() == ({},) + + +def test_dict_create_from_string(): + node = DictCreateFromString() + # Test with dynamic inputs + result = node.create(key_0="key1", value_0="value1", key_1="key2", value_1="value2") + assert result == ({"key1": "value1", "key2": "value2"},) + # Test with empty inputs + assert node.create() == ({},) + + +def test_dict_pop_random(): + node = DictPopRandom() + # Test with non-empty dictionary + my_dict = {"key1": "value1", "key2": "value2"} + result_dict, key, value, success = node.pop_random(my_dict) + + # Check that operation was successful + assert success is True + # Check that one item was removed + assert len(result_dict) == len(my_dict) - 1 + # Check that removed key is not in result dict + assert key not in result_dict + # Check that the original key-value pair matches + assert my_dict[key] == value + + # Test with empty dictionary + empty_result_dict, empty_key, empty_value, empty_success = node.pop_random({}) + assert empty_result_dict == {} + assert empty_key == "" + assert empty_value is None + assert empty_success is False + + def test_dict_keys(): node = DictKeys() diff --git a/tests/test_list_nodes.py b/tests/test_list_nodes.py index fa36b2f..2c8abbf 100644 --- a/tests/test_list_nodes.py +++ b/tests/test_list_nodes.py @@ -1,28 +1,31 @@ import pytest from src.basic_data_handling.list_nodes import ( + ListAppend, + ListContains, + ListCount, ListCreate, ListCreateFromBoolean, ListCreateFromFloat, ListCreateFromInt, ListCreateFromString, - ListAppend, ListExtend, + ListFirst, + ListGetItem, + ListIndex, ListInsert, - ListRemove, + ListLast, + ListLength, + ListMax, + ListMin, ListPop, - ListIndex, - ListCount, - ListSort, + ListPopRandom, + ListRemove, ListReverse, - ListLength, - ListSlice, - ListGetItem, ListSetItem, - ListContains, - ListMin, - ListMax, + ListSlice, + ListSort, ListToDataList, - ListToSet + ListToSet, ) @@ -61,6 +64,38 @@ def test_list_pop(): assert node.pop([], 0) == ([], None) # Empty list pop +def test_list_pop_random(): + node = ListPopRandom() + # Test with single item - must remove that item + result, item = node.pop_random_element([42]) + assert result == [] and item == 42 + + # Test with multiple items - can't predict which one will be popped + # but we can check the result list length and that the popped item was from the original list + original_list = [1, 2, 3, 4] + result, item = node.pop_random_element(original_list) + assert len(result) == len(original_list) - 1 + assert item in original_list + assert item not in result + + # Test with empty list + assert node.pop_random_element([]) == ([], None) + + +def test_list_first(): + node = ListFirst() + assert node.get_first_element([1, 2, 3]) == (1,) + assert node.get_first_element(["a", "b", "c"]) == ("a",) + assert node.get_first_element([]) == (None,) # Empty list + + +def test_list_last(): + node = ListLast() + assert node.get_last_element([1, 2, 3]) == (3,) + assert node.get_last_element(["a", "b", "c"]) == ("c",) + assert node.get_last_element([]) == (None,) # Empty list + + def test_list_index(): node = ListIndex() assert node.index([1, 2, 3, 2], 2) == (1,) diff --git a/tests/test_set_nodes.py b/tests/test_set_nodes.py index aeb3fb4..83a91d9 100644 --- a/tests/test_set_nodes.py +++ b/tests/test_set_nodes.py @@ -1,25 +1,26 @@ #import pytest from src.basic_data_handling.set_nodes import ( SetAdd, - SetRemove, + SetContains, + SetCreate, + SetCreateFromBoolean, + SetCreateFromFloat, + SetCreateFromInt, + SetCreateFromString, + SetDifference, SetDiscard, - SetPop, - SetUnion, SetIntersection, - SetDifference, - SetSymmetricDifference, + SetIsDisjoint, SetIsSubset, SetIsSuperset, - SetIsDisjoint, - SetContains, SetLength, - SetToList, - SetCreate, + SetPop, + SetPopRandom, + SetRemove, + SetSymmetricDifference, SetToDataList, - SetCreateFromInt, - SetCreateFromString, - SetCreateFromFloat, - SetCreateFromBoolean + SetToList, + SetUnion, ) def test_set_create(): @@ -28,6 +29,8 @@ def test_set_create(): assert node.create_set(item_0=1, item_1=2, item_2=3) == ({1, 2, 3},) assert node.create_set(item_0="a", item_1="b") == ({"a", "b"},) assert node.create_set() == (set(),) # Empty set with no arguments + # Mixed types + assert node.create_set(item_0=1, item_1="b", item_2=True) == ({1, "b", True},) def test_set_create_from_int(): @@ -35,15 +38,21 @@ def test_set_create_from_int(): assert node.create_set(item_0=1, item_1=2, item_2=3) == ({1, 2, 3},) assert node.create_set(item_0=5) == ({5},) # Single item set assert node.create_set(item_0=1, item_1=1) == ({1},) # Duplicate items become single item + assert node.create_set() == (set(),) # Empty set with no arguments def test_set_create_from_string(): node = SetCreateFromString() - # Note: Mocking the string function behavior as it's not defined in the file - # This simulates what would happen assuming string() acts like str() - node.create_set = lambda **kwargs: (set([str(value) for value in kwargs.values()]),) - assert node.create_set(item_0="apple", item_1="banana") == ({"apple", "banana"},) - assert node.create_set(item_0="apple", item_1="apple") == ({"apple"},) # Duplicate strings + result = node.create_set(item_0="apple", item_1="banana") + assert isinstance(result[0], set) + assert result[0] == {"apple", "banana"} + + # Duplicate strings + result = node.create_set(item_0="apple", item_1="apple") + assert result[0] == {"apple"} + + # Empty set + assert node.create_set() == (set(),) def test_set_create_from_float(): @@ -51,30 +60,40 @@ def test_set_create_from_float(): assert node.create_set(item_0=1.5, item_1=2.5) == ({1.5, 2.5},) assert node.create_set(item_0=3.14) == ({3.14},) # Single item set assert node.create_set(item_0=1.0, item_1=1.0) == ({1.0},) # Duplicate items + assert node.create_set() == (set(),) # Empty set with no arguments def test_set_create_from_boolean(): node = SetCreateFromBoolean() assert node.create_set(item_0=True, item_1=False) == ({True, False},) assert node.create_set(item_0=True, item_1=True) == ({True},) # Duplicate booleans + assert node.create_set() == (set(),) # Empty set with no arguments + # Test conversion from non-boolean values + assert node.create_set(item_0=1, item_1=0) == ({True, False},) def test_set_add(): node = SetAdd() assert node.add({1, 2}, 3) == ({1, 2, 3},) assert node.add({1, 2}, 1) == ({1, 2},) # Adding an existing item + assert node.add(set(), "first") == ({"first"},) # Adding to empty set + assert node.add({1, 2}, "string") == ({1, 2, "string"},) # Adding different type def test_set_remove(): node = SetRemove() assert node.remove({1, 2, 3}, 2) == ({1, 3}, True) # Successful removal assert node.remove({1, 2, 3}, 4) == ({1, 2, 3}, False) # Item not in set + assert node.remove({1}, 1) == (set(), True) # Removing the only element + assert node.remove(set(), 1) == (set(), False) # Removing from empty set def test_set_discard(): node = SetDiscard() assert node.discard({1, 2, 3}, 2) == ({1, 3},) # Successful removal assert node.discard({1, 2, 3}, 4) == ({1, 2, 3},) # No error for missing item + assert node.discard({1}, 1) == (set(),) # Discarding the only element + assert node.discard(set(), 1) == (set(),) # Discarding from empty set def test_set_pop(): @@ -83,16 +102,40 @@ def test_set_pop(): result_set, removed_item = node.pop(input_set) assert result_set != input_set # Arbitrary item removed assert removed_item in input_set # Removed item was part of original set + assert len(result_set) == len(input_set) - 1 # One item was removed + assert removed_item not in result_set # Removed item is not in result set empty_set = set() assert node.pop(empty_set) == (set(), None) # Handle empty set +def test_set_pop_random(): + node = SetPopRandom() + # Test with single item - must remove that item + single_item_set = {42} + result_set, removed_item = node.pop_random_element(single_item_set) + assert result_set == set() and removed_item == 42 + + # Test with multiple items - can't predict which one will be popped + # but we can check the result set size and that the popped item was from the original set + original_set = {1, 2, 3, 4} + result_set, removed_item = node.pop_random_element(original_set) + assert len(result_set) == len(original_set) - 1 + assert removed_item in original_set + assert removed_item not in result_set + + # Test with empty set + empty_set = set() + assert node.pop_random_element(empty_set) == (set(), None) + + def test_set_union(): node = SetUnion() assert node.union({1, 2}, {3, 4}) == ({1, 2, 3, 4},) assert node.union({1}, {2}, {3}, {4}) == ({1, 2, 3, 4},) assert node.union({1, 2}, set()) == ({1, 2},) # Union with empty set + assert node.union(set(), set()) == (set(),) # Union of empty sets + assert node.union({1, 2}, {2, 3}) == ({1, 2, 3},) # Overlapping sets def test_set_intersection(): @@ -100,18 +143,25 @@ def test_set_intersection(): assert node.intersection({1, 2, 3}, {2, 3, 4}) == ({2, 3},) assert node.intersection({1, 2, 3}, {4, 5}) == (set(),) # No common elements assert node.intersection({1, 2, 3}, {2, 3}, {3, 4}) == ({3},) # Multiple sets + assert node.intersection({1, 2, 3}, {1, 2, 3}) == ({1, 2, 3},) # Identical sets + assert node.intersection(set(), {1, 2, 3}) == (set(),) # Empty set intersection def test_set_difference(): node = SetDifference() assert node.difference({1, 2, 3}, {2, 3, 4}) == ({1},) assert node.difference({1, 2, 3}, {4, 5}) == ({1, 2, 3},) # Nothing to remove + assert node.difference({1, 2, 3}, {1, 2, 3}) == (set(),) # Identical sets + assert node.difference(set(), {1, 2, 3}) == (set(),) # Empty set difference + assert node.difference({1, 2, 3}, set()) == ({1, 2, 3},) # Difference with empty set def test_set_symmetric_difference(): node = SetSymmetricDifference() assert node.symmetric_difference({1, 2, 3}, {3, 4, 5}) == ({1, 2, 4, 5},) assert node.symmetric_difference({1, 2, 3}, {1, 2, 3}) == (set(),) # No unique elements + assert node.symmetric_difference(set(), {1, 2, 3}) == ({1, 2, 3},) # Empty set symmetric difference + assert node.symmetric_difference({1, 2, 3}, set()) == ({1, 2, 3},) # Symmetric difference with empty set def test_set_is_subset(): @@ -119,6 +169,8 @@ def test_set_is_subset(): assert node.is_subset({1, 2}, {1, 2, 3}) == (True,) assert node.is_subset({1, 4}, {1, 2, 3}) == (False,) assert node.is_subset(set(), {1, 2, 3}) == (True,) # Empty set is subset of all sets + assert node.is_subset({1, 2}, {1, 2}) == (True,) # Set is subset of itself + assert node.is_subset({1, 2, 3}, {1, 2}) == (False,) # Superset is not a subset def test_set_is_superset(): @@ -126,32 +178,51 @@ def test_set_is_superset(): assert node.is_superset({1, 2, 3}, {1, 2}) == (True,) assert node.is_superset({1, 2}, {1, 2, 3}) == (False,) assert node.is_superset(set(), set()) == (True,) # Empty set is a superset of itself + assert node.is_superset({1, 2}, {1, 2}) == (True,) # Set is superset of itself + assert node.is_superset({1, 2}, set()) == (True,) # Any set is superset of empty set def test_set_is_disjoint(): node = SetIsDisjoint() assert node.is_disjoint({1, 2}, {3, 4}) == (True,) # No common elements assert node.is_disjoint({1, 2}, {2, 3}) == (False,) # Common element + assert node.is_disjoint(set(), {1, 2, 3}) == (True,) # Empty set is disjoint with any set + assert node.is_disjoint({1, 2}, set()) == (True,) # Empty set is disjoint with any set + assert node.is_disjoint(set(), set()) == (True,) # Empty sets are disjoint def test_set_contains(): node = SetContains() assert node.contains({1, 2, 3}, 2) == (True,) assert node.contains({1, 2, 3}, 4) == (False,) + assert node.contains(set(), 1) == (False,) # Empty set contains nothing + assert node.contains({1, "string", True}, "string") == (True,) # Mixed type set + assert node.contains({1, "string", True}, False) == (False,) # Boolean check def test_set_length(): node = SetLength() assert node.length({1, 2, 3}) == (3,) assert node.length(set()) == (0,) # Empty set + assert node.length({1, 1, 1, 1}) == (1,) # Set with duplicate values (only counts unique) + assert node.length({12, "string", True, 3.14}) == (4,) # Mixed types def test_set_to_list(): node = SetToList() result = node.convert({1, 2, 3}) assert isinstance(result, tuple) + assert isinstance(result[0], list) assert sorted(result[0]) == [1, 2, 3] # Validate conversion to list + # Empty set + result = node.convert(set()) + assert result[0] == [] + + # Mixed types + result = node.convert({1, "string", True}) + assert set(result[0]) == {1, "string", True} # Can't check order, just content + def test_set_to_data_list(): node = SetToDataList() @@ -159,3 +230,19 @@ def test_set_to_data_list(): assert isinstance(result, tuple) assert isinstance(result[0], list) assert sorted(result[0]) == [1, 2, 3] # Validate conversion to data list + + # Empty set + result = node.convert(set()) + assert result[0] == [] + + # Mixed types + result = node.convert({1, "string", True}) + assert set(result[0]) == {1, "string", True} # Can't check order, just content + + # Empty set + result = node.convert(set()) + assert result[0] == [] + + # Mixed types + result = node.convert({1, "string", True}) + assert set(result[0]) == {1, "string", True} # Can't check order, just content