Skip to content

Commit fdbbb1f

Browse files
committed
test error handling
1 parent a713081 commit fdbbb1f

File tree

1 file changed

+44
-24
lines changed

1 file changed

+44
-24
lines changed

examples/lexico.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
def all_perm(n, k, endian=None):
99
"""all_perm(n, k, endian=None) -> iterator
1010
11-
Return an iterator over all bitarrays of length `n` with `k` bits set to one
12-
in lexicographical order.
11+
Return an iterator over all bitarrays of length `n` and
12+
population count `k` in lexicographical order.
1313
"""
1414
if n < 0:
1515
raise ValueError("length must be >= 0")
@@ -31,23 +31,25 @@ def all_perm(n, k, endian=None):
3131
v = t | ((((t & -t) // (v & -v)) >> 1) - 1)
3232

3333

34-
def next_perm(a):
35-
"""next_perm(bitarray) -> bitarray
34+
def next_perm(__a):
35+
"""next_perm(a, /) -> bitarray
3636
37-
Return the next lexicographical permutation. The length and the number
38-
of 1 bits in the bitarray is unchanged. The integer value (`ba2int`) of the
39-
next permutation will always increase, except when the cycle is completed (in
40-
which case the lowest lexicographical permutation will be returned).
37+
Return the next lexicographical permutation of bitarray `a`. The length
38+
and population count of the result is that of `a`. The integer
39+
value (`ba2int()`) of the next permutation will always increase, except
40+
when the cycle is completed. In that case, the lowest lexicographical
41+
permutation will be returned.
4142
"""
42-
v = ba2int(a)
43+
v = ba2int(__a)
4344
if v == 0:
44-
return a
45+
return __a
46+
4547
t = (v | (v - 1)) + 1
4648
v = t | ((((t & -t) // (v & -v)) >> 1) - 1)
4749
try:
48-
return int2ba(v, length=len(a), endian=a.endian)
50+
return int2ba(v, length=len(__a), endian=__a.endian)
4951
except OverflowError:
50-
return a[::-1]
52+
return __a[::-1]
5153

5254
# ---------------------------------------------------------------------------
5355

@@ -62,13 +64,25 @@ def next_perm(a):
6264

6365
class PermTests(unittest.TestCase):
6466

65-
def test_explicit_1(self):
66-
a = bitarray('00010011', 'big')
67-
for s in ['00010101', '00010110', '00011001',
68-
'00011010', '00011100', '00100011']:
69-
a = next_perm(a)
70-
self.assertEqual(a.count(), 3)
71-
self.assertEqual(a, bitarray(s, 'big'))
67+
def test_errors(self):
68+
N = next_perm
69+
self.assertRaises(TypeError, N)
70+
self.assertRaises(TypeError, N, bitarray('1'), 1)
71+
self.assertRaises(TypeError, N, '1')
72+
self.assertRaises(ValueError, N, bitarray())
73+
74+
A = all_perm
75+
self.assertRaises(TypeError, A)
76+
self.assertRaises(TypeError, A, 4)
77+
self.assertRaises(TypeError, next, A("4", 2))
78+
self.assertRaises(TypeError, next, A(1, "0.5"))
79+
self.assertRaises(TypeError, A, 1, p=1)
80+
self.assertRaises(TypeError, next, A(11, 5.5))
81+
self.assertRaises(ValueError, next, A(-1, 0))
82+
for k in -1, 11: # k is not 0 <= k <= n
83+
self.assertRaises(ValueError, next, A(10, k))
84+
self.assertRaises(ValueError, next, A(10, 7, 'foo'))
85+
self.assertRaises(ValueError, next, A(10, 7, endian='foo'))
7286

7387
def test_zeros_ones(self):
7488
for n in range(1, 30):
@@ -88,7 +102,15 @@ def test_zeros_ones(self):
88102
self.assertEqual(next_perm(a), a)
89103
self.assertEqual(a, c)
90104

91-
def test_turnover(self):
105+
def test_next_perm_explicit(self):
106+
a = bitarray('00010011', 'big')
107+
for s in ['00010101', '00010110', '00011001',
108+
'00011010', '00011100', '00100011']:
109+
a = next_perm(a)
110+
self.assertEqual(a.count(), 3)
111+
self.assertEqual(a, bitarray(s, 'big'))
112+
113+
def test_next_perm_turnover(self):
92114
for a in [bitarray('11111110000', 'big'),
93115
bitarray('0000001111111', 'little')]:
94116
self.assertEqual(next_perm(a), a[::-1])
@@ -104,15 +126,12 @@ def test_next_perm_random(self):
104126
self.assertEqual(b.endian, a.endian)
105127
self.assertNotEqual(a, b)
106128
if ba2int(a) > ba2int(b):
129+
print(n)
107130
c = a.copy()
108131
c.sort(c.endian == 'big')
109132
self.assertEqual(a, c)
110133
self.assertEqual(b, a[::-1])
111134

112-
def test_errors(self):
113-
self.assertRaises(ValueError, next_perm, bitarray())
114-
self.assertRaises(TypeError, next_perm, '1')
115-
116135
def check_perm_cycle(self, start):
117136
n, k = len(start), start.count()
118137
a = bitarray(start)
@@ -173,6 +192,7 @@ def test_all_perm(self):
173192
c.sort(c.endian == "little")
174193
self.assertEqual(a, c)
175194
else:
195+
self.assertNotEqual(a, first)
176196
self.assertEqual(next_perm(prev), a)
177197
self.assertTrue(ba2int(prev) < ba2int(a))
178198
prev = a

0 commit comments

Comments
 (0)