Commit 44a59ae

Copilot and araffin committed
Add tests for PR #302 fix (issue #81 numerical stability)
Co-authored-by: araffin <1973948+araffin@users.noreply.github.com>
1 parent 541161e · commit 44a59ae

1 file changed: +236 -0 lines changed

tests/test_distributions.py

Lines changed: 236 additions & 0 deletions
@@ -73,6 +73,242 @@ def test_masking_affects_entropy(self):
             dist.apply_masking(masks)
             assert int(dist.entropy().exp()) == v
 
+    def test_numerical_stability_with_masking(self):
+        """
+        Test that masking does not cause numerical precision issues.
+        This test is related to issue #81 and PR #302.
+
+        The bug occurred when using -1e8 as the masked logit value, which could cause
+        numerical precision issues in rare cases: the probabilities computed via
+        logits_to_probs would not sum exactly to 1.0 within PyTorch's tolerance,
+        causing a ValueError during Categorical distribution initialization when
+        validate_args=True.
+
+        With the fix (using -inf and storing pristine logits), this should not occur.
+
+        Note: The original bug was intermittent and difficult to reproduce
+        deterministically. This test verifies the correct behavior of the fix rather
+        than attempting to trigger the original bug.
+        """
+        # Use logits with various ranges to test numerical stability
+        test_cases = [
+            # Small logits
+            th.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]),
+            # Mixed positive/negative
+            th.tensor([[-1.0, 2.0, -0.5, 1.5, 0.0]]),
+            # Large logits (more susceptible to precision issues)
+            th.tensor([[10.0, -5.0, 3.0, -2.0, 0.5]]),
+            # Very large batch similar to bug report
+            th.randn(64, 400) * 2.0,
+        ]
+
+        for logits in test_cases:
+            # Test with validation enabled (validate_args=True)
+            # This is where the bug would manifest in the old code
+            dist = MaskableCategorical(logits=logits, validate_args=True)
+
+            # Apply various masks - this triggers re-initialization of the distribution
+            # which is where the numerical precision issue would occur
+            num_actions = logits.shape[-1]
+            batch_size = logits.shape[0]
+
+            # Test with different mask patterns
+            masks = [
+                # Mask out every other action
+                th.tensor([[i % 2 == 0 for i in range(num_actions)]] * batch_size),
+                # Mask out first half
+                th.tensor([[i < num_actions // 2 for i in range(num_actions)]] * batch_size),
+                # Random mask
+                th.rand(batch_size, num_actions) > 0.3,
+            ]
+
+            for mask in masks:
+                # Ensure at least one action is valid per batch
+                for i in range(batch_size):
+                    if not mask[i].any():
+                        mask[i, 0] = True
+
+                # This should not raise a ValueError about Simplex constraint
+                dist.apply_masking(mask)
+
+                # Verify that probs are valid (sum to 1.0)
+                prob_sums = dist.probs.sum(dim=-1)
+                assert th.allclose(prob_sums, th.ones_like(prob_sums), atol=1e-6), f"Probs don't sum to 1: {prob_sums}"
+
+                # Verify entropy can be computed without NaN/inf issues
+                entropy = dist.entropy()
+                assert th.isfinite(entropy).all(), f"Entropy not finite: {entropy}"
+
+                # Verify masked actions have very low or zero probability
+                # With -inf masking (PR #302 fix), they should be exactly 0
+                # With -1e8 masking (old code), they would be very small but non-zero
+                masked_actions = ~mask
+                if masked_actions.any():
+                    masked_probs = dist.probs[masked_actions]
+                    # After PR #302, masked probabilities should be exactly or very close to 0
+                    assert th.allclose(masked_probs, th.zeros_like(masked_probs), atol=1e-7), f"Masked probs not near zero: {masked_probs[:10]}"
+
+            # Test with None mask (removing masking)
+            dist.apply_masking(None)
+            prob_sums = dist.probs.sum(dim=-1)
+            assert th.allclose(prob_sums, th.ones_like(prob_sums), atol=1e-6)
+
+    def test_entropy_with_all_but_one_masked(self):
+        """
+        Test entropy calculation when all but one action is masked.
+        This is an edge case that should result in zero entropy (no uncertainty).
+        Related to issue #81 and PR #302.
+        """
+        NUM_DIMS = 5
+        logits = th.randn(10, NUM_DIMS)  # Random logits for batch of 10
+
+        dist = MaskableCategorical(logits=logits, validate_args=True)
+
+        # Mask all but one action (different valid action for each batch element)
+        for i in range(NUM_DIMS):
+            mask = th.zeros(10, NUM_DIMS, dtype=th.bool)
+            mask[:, i] = True  # Only action i is valid
+
+            dist.apply_masking(mask)
+
+            # With only one valid action, entropy should be 0 (or very close to 0)
+            entropy = dist.entropy()
+            assert th.allclose(entropy, th.zeros_like(entropy), atol=1e-5)
+
+            # The valid action should have probability 1.0
+            assert th.allclose(dist.probs[:, i], th.ones(10), atol=1e-5)
+
+            # All other actions should have probability 0.0
+            for j in range(NUM_DIMS):
+                if j != i:
+                    assert th.allclose(dist.probs[:, j], th.zeros(10), atol=1e-5)
+
+    def test_repeated_masking_stability(self):
+        """
+        Test that repeatedly applying different masks maintains numerical stability.
+        This test verifies the fix from PR #302 where pristine logits are stored
+        and used for each masking operation, avoiding accumulated numerical errors.
+        Related to issue #81 and PR #302.
+        """
+        # Start with some logits
+        original_logits = th.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
+        dist = MaskableCategorical(logits=original_logits.clone(), validate_args=True)
+
+        # Apply a series of different masks
+        masks = [
+            th.tensor([[True, True, False, False, False]]),
+            th.tensor([[False, False, True, True, True]]),
+            th.tensor([[True, False, True, False, True]]),
+            th.tensor([[False, True, False, True, False]]),
+            th.tensor([[True, True, True, True, True]]),  # All valid
+        ]
+
+        for mask in masks:
+            dist.apply_masking(mask)
+
+            # Verify probabilities are valid
+            prob_sum = dist.probs.sum(dim=-1)
+            assert th.allclose(prob_sum, th.ones_like(prob_sum), atol=1e-6), f"Probs sum: {prob_sum}, expected 1.0"
+
+            # Verify masked actions have 0 probability
+            masked_out = ~mask
+            masked_probs = dist.probs[masked_out]
+            if masked_probs.numel() > 0:
+                assert th.allclose(masked_probs, th.zeros_like(masked_probs), atol=1e-7), f"Masked probs: {masked_probs}"
+
+            # Verify entropy is finite and non-negative
+            entropy = dist.entropy()
+            assert th.isfinite(entropy).all(), f"Entropy contains inf/nan: {entropy}"
+            assert (entropy >= 0).all(), f"Entropy is negative: {entropy}"
+
+        # After all masks, remove masking and verify we get consistent results
+        dist.apply_masking(None)
+        prob_sum = dist.probs.sum(dim=-1)
+        assert th.allclose(prob_sum, th.ones_like(prob_sum), atol=1e-6)
+
+    def test_masked_actions_have_zero_probability(self):
+        """
+        Test that masked actions have exactly zero probability with proper masking.
+
+        This test verifies that masked actions get zero probability, which is important
+        for the fix in PR #302. While both -1e8 and -inf produce zero probabilities
+        after softmax due to underflow, using -inf is mathematically more correct
+        and avoids potential numerical issues in edge cases.
+
+        Related to issue #81 and PR #302.
+        """
+        # Test with various logit scales
+        test_logits = [
+            th.tensor([[0.0, 1.0, 2.0, 3.0]]),  # Small scale
+            th.tensor([[10.0, 20.0, 30.0, 40.0]]),  # Large scale
+            th.randn(5, 10) * 5.0,  # Random batch
+        ]
+
+        for logits in test_logits:
+            dist = MaskableCategorical(logits=logits, validate_args=True)
+
+            # Create a mask that masks out alternating actions
+            mask = th.zeros_like(logits, dtype=th.bool)
+            mask[:, ::2] = True  # Keep even indices, mask odd indices
+
+            dist.apply_masking(mask)
+
+            # Check that masked actions have exactly zero probability
+            masked_indices = ~mask
+            if masked_indices.any():
+                masked_probs = dist.probs[masked_indices]
+                # Both old (-1e8) and new (-inf) implementations should produce 0 here
+                # due to softmax underflow, but -inf is more robust
+                assert th.allclose(masked_probs, th.zeros_like(masked_probs), atol=1e-10), f"Masked actions should have ~0 probability, got: {masked_probs[:10]}"
+
+            # Verify unmasked actions have non-zero probabilities
+            unmasked_probs = dist.probs[mask]
+            assert th.all(unmasked_probs > 0.0), "Unmasked actions should have positive probability"
+
+            # Verify probabilities sum to 1
+            prob_sums = dist.probs.sum(dim=-1)
+            assert th.allclose(prob_sums, th.ones_like(prob_sums), atol=1e-6)
+
+    def test_entropy_numerical_stability_with_masking(self):
+        """
+        Test entropy calculation numerical stability with masked actions.
+
+        This specifically tests the improved entropy calculation from PR #302.
+        The old entropy calculation could have issues with -1e8 logits, while
+        the new calculation properly handles -inf values.
+
+        Related to issue #81 and PR #302.
+        """
+        # Test with various scenarios including edge cases
+        test_cases = [
+            # All but one action masked (entropy should be ~0)
+            (th.tensor([[1.0, 2.0, 3.0, 4.0]]), th.tensor([[True, False, False, False]])),
+            # Half actions masked
+            (th.tensor([[1.0, 2.0, 3.0, 4.0]]), th.tensor([[True, True, False, False]])),
+            # Large batch with various masks
+            (th.randn(32, 10), th.rand(32, 10) > 0.3),
+        ]
+
+        for logits, mask in test_cases:
+            # Ensure at least one action is valid per batch
+            for i in range(mask.shape[0]):
+                if not mask[i].any():
+                    mask[i, 0] = True
+
+            dist = MaskableCategorical(logits=logits, validate_args=True)
+            dist.apply_masking(mask)
+
+            # Compute entropy - should not produce NaN or inf
+            entropy = dist.entropy()
+            assert th.isfinite(entropy).all(), f"Entropy should be finite, got: {entropy}"
+            assert (entropy >= 0).all(), f"Entropy should be non-negative, got: {entropy}"
+
+            # For single valid action, entropy should be close to 0
+            single_action_mask = mask.sum(dim=-1) == 1
+            if single_action_mask.any():
+                single_action_entropy = entropy[single_action_mask]
+                assert th.allclose(single_action_entropy, th.zeros_like(single_action_entropy), atol=1e-4), f"Single action entropy should be ~0, got: {single_action_entropy}"
+
 
 class TestMaskableCategoricalDistribution:
     def test_distribution_must_be_initialized(self):
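
For context while reading the assertions above: the diff only shows the tests, not the fix itself. Below is a minimal sketch of the technique the docstrings describe, masking pristine logits with -inf so that softmax assigns exactly zero probability, and zeroing out the resulting (-inf) * 0 terms when computing entropy. This is an illustration under those assumptions, not the actual sb3-contrib MaskableCategorical implementation; the class name MaskedCategoricalSketch is hypothetical, and th is the "import torch as th" alias the test file uses.

import torch as th
from torch.distributions import Categorical


class MaskedCategoricalSketch(Categorical):
    """Hypothetical sketch of -inf masking with pristine logits (not the sb3-contrib code)."""

    def __init__(self, logits: th.Tensor, validate_args=None):
        # Keep an untouched copy so every mask is applied to the original
        # logits rather than to an already-masked tensor.
        self._original_logits = logits
        super().__init__(logits=logits, validate_args=validate_args)

    def apply_masking(self, mask) -> None:
        if mask is None:
            masked_logits = self._original_logits
        else:
            # mask is a bool tensor of the same shape; True means the action is allowed.
            # Using -inf (rather than -1e8) makes softmax give exactly 0 to masked actions.
            masked_logits = th.where(mask, self._original_logits, th.tensor(-float("inf")))
        # Re-initialize the distribution from the freshly masked logits.
        super().__init__(logits=masked_logits)

    def entropy(self) -> th.Tensor:
        # For masked actions, logits * probs is (-inf) * 0 = nan,
        # so replace those terms with 0 before summing.
        p_log_p = self.logits * self.probs
        p_log_p = th.where(self.probs > 0, p_log_p, th.zeros_like(p_log_p))
        return -p_log_p.sum(-1)

Storing the pristine logits is what makes repeated apply_masking calls stable: each mask is applied to the original values, so no error accumulates across calls, which is exactly what test_repeated_masking_stability exercises.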

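A quick usage check of the sketch, mirroring the properties the new tests assert (probabilities sum to 1, masked actions get exactly zero probability, a single valid action yields zero entropy, and apply_masking(None) restores the pristine distribution):

logits = th.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
dist = MaskedCategoricalSketch(logits=logits, validate_args=True)

# Mask all but action 0: the probability mass collapses onto the one valid action.
dist.apply_masking(th.tensor([[True, False, False, False, False]]))
assert th.allclose(dist.probs.sum(dim=-1), th.ones(1))
assert (dist.probs[0, 1:] == 0).all()            # exactly zero, not merely tiny
assert th.allclose(dist.entropy(), th.zeros(1))  # one valid action -> no uncertainty

# Removing the mask recovers the original distribution from the pristine logits.
dist.apply_masking(None)
assert th.allclose(dist.probs.sum(dim=-1), th.ones(1))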