Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: major
changes:
changed:
- Replace global execution counter in random() with variable-name-based salting for order-independent reproducibility. This is a breaking change that will produce different random values for existing simulations.
77 changes: 65 additions & 12 deletions policyengine_core/commons/formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,38 +305,91 @@ def amount_between(
return clip(amount, threshold_1, threshold_2) - threshold_1


def random(population):
def _stable_string_hash(s: str) -> np.uint64:
"""Deterministic hash consistent across Python processes.

Python's built-in hash() is not deterministic across processes (since 3.3),
so we use a polynomial rolling hash with additional mixing.
"""
import warnings

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "overflow encountered", RuntimeWarning
)
h = np.uint64(0)
for byte in s.encode("utf-8"):
h = h * np.uint64(31) + np.uint64(byte)
h = h ^ (h >> np.uint64(33))
h = h * np.uint64(0xFF51AFD7ED558CCD)
h = h ^ (h >> np.uint64(33))
return h


def random(population, salt: str = None):
"""
Generate random values for each entity in the population.

Uses the current variable name (from the tracer stack) to create
deterministic random values that are independent of variable execution
order. This prevents the "ripple effect" where adding or reordering
variables would change random values for all subsequent variables.

Args:
population: The population object containing simulation data.
salt: Optional salt string to differentiate random streams when
called outside a formula context or to create distinct
streams within the same formula.

Returns:
np.ndarray: Array of random values for each entity.
np.ndarray: Array of random values in [0, 1) for each entity.

Raises:
ValueError: If called outside a formula context without a salt.
"""
# Initialize count of random calls if not already present
if not hasattr(population.simulation, "count_random_calls"):
population.simulation.count_random_calls = 0
population.simulation.count_random_calls += 1
simulation = population.simulation

# Get variable name from tracer stack
if simulation.tracer.stack:
variable_name = simulation.tracer.stack[-1]["name"]
elif salt is not None:
variable_name = "__no_variable__"
else:
raise ValueError(
"random() called outside formula context without a salt. "
"Provide salt parameter: random(population, salt='unique_name')"
)

# Per-variable counter (replaces global counter)
if not hasattr(simulation, "random_call_counts"):
simulation.random_call_counts = {}

counter_key = f"{variable_name}:{salt or ''}"
simulation.random_call_counts[counter_key] = (
simulation.random_call_counts.get(counter_key, 0) + 1
)
call_count = simulation.random_call_counts[counter_key]

# Get known periods or use default calculation period
known_periods = population.simulation.get_holder(
known_periods = simulation.get_holder(
f"{population.entity.key}_id"
).get_known_periods()
period = (
known_periods[0]
if known_periods
else population.simulation.default_calculation_period
else simulation.default_calculation_period
)

# Get entity IDs for the period
entity_ids = population(f"{population.entity.key}_id", period)

# Generate deterministic random values using vectorised hash
seeds = np.abs(
entity_ids * 100 + population.simulation.count_random_calls
).astype(np.uint64)
# Create deterministic seed from variable name + call count
base_seed = _stable_string_hash(
f"{variable_name}:{call_count}:{salt or ''}"
)

# Combine with entity IDs via XOR
seeds = (entity_ids.astype(np.uint64) ^ base_seed).astype(np.uint64)

# PCG-style mixing function for high-quality pseudo-random generation
x = seeds * np.uint64(0x5851F42D4C957F2D)
Expand Down
Loading