diff --git a/src/py-opentimelineio/opentimelineio/core/_core_utils.py b/src/py-opentimelineio/opentimelineio/core/_core_utils.py index d58791a67b..eb12a813b1 100644 --- a/src/py-opentimelineio/opentimelineio/core/_core_utils.py +++ b/src/py-opentimelineio/opentimelineio/core/_core_utils.py @@ -305,6 +305,41 @@ def insert(self, index, item): if conversion_func else item ) + def __le__(self, other): # Taken from collections.abc.Set + if not isinstance(other, collections.abc.Sequence): + return NotImplemented + if len(self) > len(other): + return False + for elem in self: + if elem not in other: + return False + return True + + def __lt__(self, other): # Taken from collections.abc.Set + if not isinstance(other, collections.abc.Sequence): + return NotImplemented + return len(self) < len(other) and self.__le__(other) + + def __gt__(self, other): # Taken from collections.abc.Set + if not isinstance(other, collections.abc.Sequence): + return NotImplemented + return len(self) > len(other) and self.__ge__(other) + + def __ge__(self, other): # Taken from collections.abc.Set + if not isinstance(other, collections.abc.Sequence): + return NotImplemented + if len(self) < len(other): + return False + for elem in other: + if elem not in self: + return False + return True + + def __eq__(self, other): # Taken from collections.abc.Set + if not isinstance(other, collections.abc.Sequence): + return NotImplemented + return len(self) == len(other) and self.__le__(other) + collections.abc.MutableSequence.register(sequenceClass) sequenceClass.__radd__ = __radd__ sequenceClass.__add__ = __add__ @@ -314,6 +349,11 @@ def insert(self, index, item): sequenceClass.insert = insert sequenceClass.__str__ = __str__ sequenceClass.__repr__ = __repr__ + sequenceClass.__le__ = __le__ + sequenceClass.__lt__ = __lt__ + sequenceClass.__gt__ = __gt__ + sequenceClass.__ge__ = __ge__ + sequenceClass.__eq__ = __eq__ seen = set() for klass in (collections.abc.MutableSequence, collections.abc.Sequence): diff --git a/tests/test_core_utils.py b/tests/test_core_utils.py index a0a7b9425f..b0755a61ba 100644 --- a/tests/test_core_utils.py +++ b/tests/test_core_utils.py @@ -100,10 +100,10 @@ def test_main(self): v.append(2) self.assertEqual(len(v), 2) - self.assertEqual([value for value in v], [1, 2]) + self.assertEqual(v, [1, 2]) v.insert(0, 5) - self.assertEqual([value for value in v], [5, 1, 2]) + self.assertEqual(v, [5, 1, 2]) self.assertEqual(v[0], 5) self.assertEqual(v[-3], 5) @@ -124,13 +124,11 @@ def test_main(self): del v[0] self.assertEqual(len(v), 2) - # Doesn't work... - # assert v == [1, 100] - self.assertEqual([value for value in v], [1, 100]) + self.assertEqual(v, [1, 100]) del v[1000] # This will surprisingly delete the last item... self.assertEqual(len(v), 1) - self.assertEqual([value for value in v], [1]) + self.assertEqual(v, [1]) # Will delete the last item even if the index doesn't match. # It's a surprising behavior. @@ -144,7 +142,7 @@ def test_main(self): items.append(value) self.assertEqual(items, [1, '234', {}]) - self.assertFalse(v == [1, '234', {}]) # __eq__ is not implemented + self.assertTrue(v == [1, '234', {}]) self.assertTrue(1 in v) # Test __contains__ self.assertTrue('234' in v) @@ -181,13 +179,13 @@ def test_main(self): self.assertEqual(v3[1:7:2], [1, 3, 5]) del v3[2:7] - self.assertEqual(list(v3), [0, 1, 7, 8, 9]) + self.assertEqual(v3, [0, 1, 7, 8, 9]) v4 = opentimelineio.core._core_utils.AnyVector() v4.extend(range(10)) del v4[::2] - self.assertEqual(list(v4), [1, 3, 5, 7, 9]) + self.assertEqual(v4, [1, 3, 5, 7, 9]) v5 = opentimelineio.core._core_utils.AnyVector() tmplist = [1, 2] @@ -225,7 +223,7 @@ def test_raises_if_ref_destroyed(self): def test_copy(self): list1 = [1, 2, [3, 4], 5] copied = copy.copy(list1) - self.assertEqual(list(list1), list(copied)) + self.assertEqual(list1, copied) v = opentimelineio.core._core_utils.AnyVector() v.extend([1, 2, [3, 4], 5])