@@ -73,6 +73,242 @@ def test_masking_affects_entropy(self):
        dist.apply_masking(masks)
        assert int(dist.entropy().exp()) == v

    def test_numerical_stability_with_masking(self):
        """
        Test that masking does not cause numerical precision issues.
        This test is related to issue #81 and PR #302.

        The bug occurred when using -1e8 as the masked logit value, which could
        cause numerical precision issues in rare cases: the probabilities
        computed via logits_to_probs would not sum exactly to 1.0 within
        PyTorch's tolerance, causing a ValueError during Categorical
        distribution initialization when validate_args=True.

        With the fix (using -inf and storing pristine logits), this should not occur.

        Note: The original bug was intermittent and difficult to reproduce
        deterministically. This test verifies the correct behavior of the fix
        rather than attempting to trigger the original bug.
        """
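        # Hedged sketch of the old failure mode (illustration only, not the
        # library's exact code): re-initialization used logits filled with -1e8,
        # roughly
        #   masked_logits = th.where(mask, logits, th.tensor(-1e8))
        #   Categorical(logits=masked_logits, validate_args=True)
        # and float32 rounding during softmax normalization could leave a row of
        # probs summing to slightly more or less than 1, beyond the Simplex
        # constraint's tolerance. With -inf, masked entries contribute exactly
        # zero probability mass, so the constraint check cannot fail this way.
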
        # Use logits with various ranges to test numerical stability
        test_cases = [
            # Small logits
            th.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]),
            # Mixed positive/negative
            th.tensor([[-1.0, 2.0, -0.5, 1.5, 0.0]]),
            # Large logits (more susceptible to precision issues)
            th.tensor([[10.0, -5.0, 3.0, -2.0, 0.5]]),
            # Very large batch similar to the bug report
            th.randn(64, 400) * 2.0,
        ]

        for logits in test_cases:
            # Test with validation enabled (validate_args=True);
            # this is where the bug would manifest in the old code
            dist = MaskableCategorical(logits=logits, validate_args=True)

            # Apply various masks - this triggers re-initialization of the
            # distribution, which is where the numerical precision issue would occur
            num_actions = logits.shape[-1]
            batch_size = logits.shape[0]

            # Test with different mask patterns
            masks = [
                # Mask out every other action
                th.tensor([[i % 2 == 0 for i in range(num_actions)]] * batch_size),
                # Mask out the second half (only the first half stays valid)
                th.tensor([[i < num_actions // 2 for i in range(num_actions)]] * batch_size),
                # Random mask
                th.rand(batch_size, num_actions) > 0.3,
            ]

            for mask in masks:
                # Ensure at least one action is valid per batch element
                for i in range(batch_size):
                    if not mask[i].any():
                        mask[i, 0] = True

                # This should not raise a ValueError about the Simplex constraint
                dist.apply_masking(mask)

                # Verify that probs are valid (sum to 1.0)
                prob_sums = dist.probs.sum(dim=-1)
                assert th.allclose(prob_sums, th.ones_like(prob_sums), atol=1e-6), f"Probs don't sum to 1: {prob_sums}"

                # Verify entropy can be computed without NaN/inf issues
                entropy = dist.entropy()
                assert th.isfinite(entropy).all(), f"Entropy not finite: {entropy}"

                # Verify masked actions have very low or zero probability.
                # With -inf masking (PR #302 fix), they should be exactly 0;
                # with -1e8 masking (old code), they would be very small but non-zero.
                masked_actions = ~mask
                if masked_actions.any():
                    masked_probs = dist.probs[masked_actions]
                    # After PR #302, masked probabilities should be exactly or very close to 0
                    assert th.allclose(masked_probs, th.zeros_like(masked_probs), atol=1e-7), f"Masked probs not near zero: {masked_probs[:10]}"

            # Test with None mask (removing masking)
            dist.apply_masking(None)
            prob_sums = dist.probs.sum(dim=-1)
            assert th.allclose(prob_sums, th.ones_like(prob_sums), atol=1e-6)

    def test_entropy_with_all_but_one_masked(self):
        """
        Test entropy calculation when all but one action is masked.
        This is an edge case that should result in zero entropy (no uncertainty).
        Related to issue #81 and PR #302.
        """
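        # Why zero entropy: H = -sum_j p_j * log(p_j); with a single valid action,
        # p_i = 1 and log(1) = 0, while masked terms contribute nothing
        # (0 * log 0 is treated as 0), so the sum is exactly 0.
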
        NUM_DIMS = 5
        logits = th.randn(10, NUM_DIMS)  # Random logits for a batch of 10

        dist = MaskableCategorical(logits=logits, validate_args=True)

        # Mask all but one action, iterating over which single action is valid
        # (in each iteration the same action is valid for every batch element)
        for i in range(NUM_DIMS):
            mask = th.zeros(10, NUM_DIMS, dtype=th.bool)
            mask[:, i] = True  # Only action i is valid

            dist.apply_masking(mask)

            # With only one valid action, entropy should be 0 (or very close to 0)
            entropy = dist.entropy()
            assert th.allclose(entropy, th.zeros_like(entropy), atol=1e-5)

            # The valid action should have probability 1.0
            assert th.allclose(dist.probs[:, i], th.ones(10), atol=1e-5)

            # All other actions should have probability 0.0
            for j in range(NUM_DIMS):
                if j != i:
                    assert th.allclose(dist.probs[:, j], th.zeros(10), atol=1e-5)

    def test_repeated_masking_stability(self):
        """
        Test that repeatedly applying different masks maintains numerical stability.
        This test verifies the fix from PR #302, where pristine logits are stored
        and used for each masking operation, avoiding accumulated numerical errors.
        Related to issue #81 and PR #302.
        """
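        # Hedged sketch of the fixed behavior (the attribute name below is
        # illustrative, not necessarily the real one): apply_masking() re-masks
        # a pristine copy of the original logits on every call, roughly
        #   new_logits = th.where(mask, self._original_logits, -inf)
        # rather than masking the already-masked logits, so neither rounding
        # error nor stale -inf entries can accumulate across calls.
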
        # Start with some logits
        original_logits = th.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
        dist = MaskableCategorical(logits=original_logits.clone(), validate_args=True)

        # Apply a series of different masks
        masks = [
            th.tensor([[True, True, False, False, False]]),
            th.tensor([[False, False, True, True, True]]),
            th.tensor([[True, False, True, False, True]]),
            th.tensor([[False, True, False, True, False]]),
            th.tensor([[True, True, True, True, True]]),  # All valid
        ]

        for mask in masks:
            dist.apply_masking(mask)

            # Verify probabilities are valid
            prob_sum = dist.probs.sum(dim=-1)
            assert th.allclose(prob_sum, th.ones_like(prob_sum), atol=1e-6), f"Probs sum: {prob_sum}, expected 1.0"

            # Verify masked actions have 0 probability
            masked_out = ~mask
            masked_probs = dist.probs[masked_out]
            if masked_probs.numel() > 0:
                assert th.allclose(masked_probs, th.zeros_like(masked_probs), atol=1e-7), f"Masked probs: {masked_probs}"

            # Verify entropy is finite and non-negative
            entropy = dist.entropy()
            assert th.isfinite(entropy).all(), f"Entropy contains inf/nan: {entropy}"
            assert (entropy >= 0).all(), f"Entropy is negative: {entropy}"

        # After all masks, remove masking and verify we get consistent results
        dist.apply_masking(None)
        prob_sum = dist.probs.sum(dim=-1)
        assert th.allclose(prob_sum, th.ones_like(prob_sum), atol=1e-6)

    def test_masked_actions_have_zero_probability(self):
        """
        Test that masked actions have exactly zero probability with proper masking.

        This test verifies that masked actions get zero probability, which is
        important for the fix in PR #302. While both -1e8 and -inf produce zero
        probabilities after softmax due to underflow, using -inf is mathematically
        more correct and avoids potential numerical issues in edge cases.

        Related to issue #81 and PR #302.
        """
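        # Sanity check of the underflow claim above: exp(-1e8) underflows to 0.0
        # in float32/float64, so even the old fill value produces a hard zero
        # after softmax; -inf gives the same zero by definition rather than by
        # accident of underflow.
        assert th.softmax(th.tensor([0.0, -1e8]), dim=-1)[1] == 0.0
        assert th.softmax(th.tensor([0.0, -float("inf")]), dim=-1)[1] == 0.0
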
        # Test with various logit scales
        test_logits = [
            th.tensor([[0.0, 1.0, 2.0, 3.0]]),  # Small scale
            th.tensor([[10.0, 20.0, 30.0, 40.0]]),  # Large scale
            th.randn(5, 10) * 5.0,  # Random batch
        ]

        for logits in test_logits:
            dist = MaskableCategorical(logits=logits, validate_args=True)

            # Create a mask that masks out alternating actions
            mask = th.zeros_like(logits, dtype=th.bool)
            mask[:, ::2] = True  # Keep even indices, mask odd indices

            dist.apply_masking(mask)

            # Check that masked actions have exactly zero probability
            masked_indices = ~mask
            if masked_indices.any():
                masked_probs = dist.probs[masked_indices]
                # Both the old (-1e8) and new (-inf) implementations should produce
                # 0 here due to softmax underflow, but -inf is more robust
                assert th.allclose(masked_probs, th.zeros_like(masked_probs), atol=1e-10), f"Masked actions should have ~0 probability, got: {masked_probs[:10]}"

            # Verify unmasked actions have non-zero probabilities
            unmasked_probs = dist.probs[mask]
            assert th.all(unmasked_probs > 0.0), "Unmasked actions should have positive probability"

            # Verify probabilities sum to 1
            prob_sums = dist.probs.sum(dim=-1)
            assert th.allclose(prob_sums, th.ones_like(prob_sums), atol=1e-6)

    def test_entropy_numerical_stability_with_masking(self):
        """
        Test entropy calculation numerical stability with masked actions.

        This specifically tests the improved entropy calculation from PR #302.
        The old entropy calculation could have issues with -1e8 logits, while
        the new calculation properly handles -inf values.

        Related to issue #81 and PR #302.
        """
281+ """
282+ # Test with various scenarios including edge cases
283+ test_cases = [
284+ # All but one action masked (entropy should be ~0)
285+ (th .tensor ([[1.0 , 2.0 , 3.0 , 4.0 ]]), th .tensor ([[True , False , False , False ]])),
286+ # Half actions masked
287+ (th .tensor ([[1.0 , 2.0 , 3.0 , 4.0 ]]), th .tensor ([[True , True , False , False ]])),
288+ # Large batch with various masks
289+ (th .randn (32 , 10 ), th .rand (32 , 10 ) > 0.3 ),
290+ ]
291+
292+ for logits , mask in test_cases :
293+ # Ensure at least one action is valid per batch
294+ for i in range (mask .shape [0 ]):
295+ if not mask [i ].any ():
296+ mask [i , 0 ] = True
297+
298+ dist = MaskableCategorical (logits = logits , validate_args = True )
299+ dist .apply_masking (mask )
300+
301+ # Compute entropy - should not produce NaN or inf
302+ entropy = dist .entropy ()
303+ assert th .isfinite (entropy ).all (), f"Entropy should be finite, got: { entropy } "
304+ assert (entropy >= 0 ).all (), f"Entropy should be non-negative, got: { entropy } "
305+
306+ # For single valid action, entropy should be close to 0
307+ single_action_mask = mask .sum (dim = - 1 ) == 1
308+ if single_action_mask .any ():
309+ single_action_entropy = entropy [single_action_mask ]
310+ assert th .allclose (single_action_entropy , th .zeros_like (single_action_entropy ), atol = 1e-4 ), f"Single action entropy should be ~0, got: { single_action_entropy } "
311+
76312
77313class TestMaskableCategoricalDistribution :
78314 def test_distribution_must_be_initialized (self ):