From eb72019d6f55a5e3b1901ab17a503397e7c5b17b Mon Sep 17 00:00:00 2001 From: Mario Vega Date: Fri, 20 Feb 2026 21:19:35 +0100 Subject: [PATCH 1/4] refactor(testing): Refactor specs to allow reusing logic for benchmark checks (#2203) * feat(testing): Add check to verify no benchmark test exceeds limit refactor(testing/specs): Abstract and reuse logic refactor(testing/specs): Remove `tag` field fix(test-filler): Wrapper test classes overwriting `BaseTest.spec_types` refactor(tests): Kill unused `tag` from every test fix(test-execute): Fix execute fix: tox * fix: review comments * fix: upstream merge issue * Apply suggestion from @marioevz * fix: Bug and typing Co-authored-by: spencer-tb --------- Co-authored-by: spencer-tb --- .../plugins/execute/execute.py | 23 +++- .../pytest_commands/plugins/filler/filler.py | 66 ++++++---- .../filler/tests/test_prealloc_group.py | 44 ++----- .../plugins/shared/execute_fill.py | 21 +++ .../src/execution_testing/execution/base.py | 12 +- .../execution/blob_transaction.py | 15 ++- .../execution/transaction_post.py | 52 +++----- .../src/execution_testing/fixtures/base.py | 15 ++- .../fixtures/tests/test_base.py | 1 + .../fixtures/tests/test_blockchain.py | 4 +- .../fixtures/tests/test_collector.py | 1 + .../src/execution_testing/specs/base.py | 124 ++++++++++-------- .../src/execution_testing/specs/benchmark.py | 13 +- .../src/execution_testing/specs/blobs.py | 5 +- .../src/execution_testing/specs/blockchain.py | 87 ++++++------ .../src/execution_testing/specs/state.py | 54 ++++---- .../specs/tests/test_expect.py | 15 ++- .../specs/tests/test_fixtures.py | 76 ++++++----- .../specs/tests/test_transaction.py | 16 ++- .../execution_testing/specs/transaction.py | 17 ++- .../eip4844_blobs/test_excess_blob_gas.py | 57 -------- .../test_warm_coinbase.py | 2 - .../eip4895_withdrawals/test_withdrawals.py | 4 +- 23 files changed, 374 insertions(+), 350 deletions(-) diff --git 
a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/execute.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/execute.py index 7c1ce4deea..bd94f44d12 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/execute.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/execute.py @@ -625,6 +625,8 @@ def base_test_parametrizer_func( max_fee_per_blob_gas: int, max_gas_limit_per_test: int | None, gas_limit_accumulator: GasInfoAccumulator, + is_tx_gas_heavy_test: bool, + is_exception_test: bool, ) -> Type[BaseTest]: """ Fixture used to instantiate an auto-fillable BaseTest object from @@ -648,6 +650,8 @@ def base_test_parametrizer_func( ) class BaseTestWrapper(cls): # type: ignore + __is_base_test_wrapper__ = True + def __init__(self, *args: Any, **kwargs: Any) -> None: if "pre" not in kwargs: kwargs["pre"] = pre @@ -655,10 +659,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: raise ValueError( "The pre-alloc object was modified by the test." 
) - # Set default for expected_benchmark_gas_used if "expected_benchmark_gas_used" not in kwargs: kwargs["expected_benchmark_gas_used"] = gas_benchmark_value kwargs["fork"] = fork + kwargs["operation_mode"] = request.config.op_mode + kwargs["is_tx_gas_heavy_test"] = is_tx_gas_heavy_test + kwargs["is_exception_test"] = is_exception_test kwargs |= { p: request.getfixturevalue(p) for p in cls_fixture_parameters @@ -668,7 +674,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: request.node.config.sender_address = str(pre._sender) super(BaseTestWrapper, self).__init__(*args, **kwargs) - self._request = request execute = self.execute(execute_format=execute_format) # get balances of required sender accounts @@ -730,12 +735,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: [str(eoa) for eoa in pre._funded_eoa] ) - execute.execute( + execute_result = execute.execute( fork=fork, eth_rpc=eth_rpc, engine_rpc=engine_rpc, request=request, ) + self.validate_benchmark_gas( + benchmark_gas_used=execute_result.benchmark_gas_used, + gas_benchmark_value=gas_benchmark_value, + ) + collector.collect(request.node.nodeid, execute) return BaseTestWrapper @@ -744,7 +754,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Dynamically generate a pytest fixture for each test spec type. -for cls in BaseTest.spec_types.values(): +for name, cls in BaseTest.spec_types.items(): + if getattr(cls, "__is_base_test_wrapper__", False): + raise RuntimeError( + f"Test spec type {name}: {cls.__name__} is already wrapped. " + f"{BaseTest.spec_types.items()}." + ) # Fixture needs to be defined in the global scope so pytest can detect it. 
globals()[cls.pytest_parameter_name()] = base_test_parametrizer(cls) diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py index bb9dfea4f9..65b8b8341c 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py @@ -42,6 +42,7 @@ from execution_testing.fixtures import ( BaseFixture, BlockchainEngineFixture, + BlockchainEngineXFixture, BlockchainFixture, FixtureCollector, FixtureConsumer, @@ -66,7 +67,7 @@ get_transition_forks, ) from execution_testing.specs import BaseTest -from execution_testing.specs.base import OpMode +from execution_testing.specs.base import FillResult, OpMode from execution_testing.test_types import EnvironmentDefaults from execution_testing.tools.utility.versioning import ( generate_github_url, @@ -1183,7 +1184,7 @@ def verify_fixtures_bin(request: pytest.FixtureRequest) -> Path | None: def session_t8n( request: pytest.FixtureRequest, ) -> Generator[TransitionTool, None, None]: - """Return configured transition tool.""" + """Return configured transition tool for the session.""" t8n: TransitionTool = request.config.t8n # type: ignore if not t8n.exception_mapper.reliable: t8n_name = t8n.__class__.__name__ @@ -1600,6 +1601,8 @@ def base_test_parametrizer_func( fixture_source_url: str, gas_benchmark_value: int, fixed_opcode_count: int | None, + is_tx_gas_heavy_test: bool, + is_exception_test: bool, ) -> Any: """ Fixture used to instantiate an auto-fillable BaseTest object from @@ -1623,12 +1626,27 @@ def base_test_parametrizer_func( fork = request.node.fork class BaseTestWrapper(cls): # type: ignore + __is_base_test_wrapper__ = True + def __init__(self, *args: Any, **kwargs: Any) -> None: if "pre" not in kwargs: kwargs["pre"] = pre if "expected_benchmark_gas_used" not in kwargs: 
kwargs["expected_benchmark_gas_used"] = gas_benchmark_value kwargs["fork"] = fork + op_mode: OpMode = request.config.op_mode # type: ignore + kwargs["operation_mode"] = op_mode + kwargs["is_tx_gas_heavy_test"] = is_tx_gas_heavy_test + kwargs["is_exception_test"] = is_exception_test + if ( + op_mode == OpMode.OPTIMIZE_GAS + or op_mode == OpMode.OPTIMIZE_GAS_POST_PROCESSING + ): + kwargs["gas_optimization_max_gas_limit"] = ( + request.config.getoption( + "optimize_gas_max_gas_limit", None + ) + ) kwargs |= { p: request.getfixturevalue(p) for p in cls_fixture_parameters @@ -1636,20 +1654,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: } super(BaseTestWrapper, self).__init__(*args, **kwargs) - self._request = request - self._operation_mode = ( - request.config.op_mode # type: ignore[attr-defined] - ) - if ( - self._operation_mode == OpMode.OPTIMIZE_GAS - or self._operation_mode - == OpMode.OPTIMIZE_GAS_POST_PROCESSING - ): - self._gas_optimization_max_gas_limit = ( - request.config.getoption( - "optimize_gas_max_gas_limit", None - ) - ) # Get the filling session from config session: FillingSession = request.config.filling_session # type: ignore @@ -1703,16 +1707,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: ) group = session.get_pre_alloc_group(pre_alloc_hash) self.pre = group.pre + fill_result: FillResult | None = None try: - fixture = self.generate( + fill_result = self.generate( t8n=t8n, fixture_format=fixture_format, ) finally: if ( - request.config.op_mode # type: ignore[attr-defined] - == OpMode.OPTIMIZE_GAS - or request.config.op_mode # type: ignore[attr-defined] + self.operation_mode == OpMode.OPTIMIZE_GAS + or self.operation_mode == OpMode.OPTIMIZE_GAS_POST_PROCESSING ): gas_optimized_tests = ( @@ -1723,8 +1727,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # None, to keep track of failed tests in the output # file. 
gas_optimized_tests[request.node.nodeid] = ( - self._gas_optimization + fill_result.gas_optimization + if fill_result is not None + else None ) + assert fill_result is not None + fixture = fill_result.fixture + # If operation mode is benchmarking, check the gas used. + self.validate_benchmark_gas( + benchmark_gas_used=fill_result.benchmark_gas_used, + gas_benchmark_value=gas_benchmark_value, + ) # Post-process for Engine X format (add pre_hash and state # diff) @@ -1733,6 +1746,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: in fixture_format.format_phases and pre_alloc_hash is not None ): + # TODO: This should be handled by the `generate` method + # of the spec. + assert isinstance(fixture, BlockchainEngineXFixture) fixture.pre_hash = pre_alloc_hash # Calculate state diff for efficiency @@ -1749,6 +1765,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: t8n.version(), test_case_description, fixture_source_url=fixture_source_url, + opcode_count=t8n.opcode_count, ref_spec=reference_spec, _info_metadata=t8n._info_metadata, ) @@ -1773,7 +1790,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Dynamically generate a pytest fixture for each test spec type. -for cls in BaseTest.spec_types.values(): +for name, cls in BaseTest.spec_types.items(): + if getattr(cls, "__is_base_test_wrapper__", False): + raise RuntimeError( + f"Test spec type {name}: {cls.__name__} is already wrapped. " + f"{BaseTest.spec_types.items()}." + ) # Fixture needs to be defined in the global scope so pytest can detect it. 
globals()[cls.pytest_parameter_name()] = base_test_parametrizer(cls) diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/tests/test_prealloc_group.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/tests/test_prealloc_group.py index e9f8ec6663..736d663386 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/tests/test_prealloc_group.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/tests/test_prealloc_group.py @@ -8,9 +8,9 @@ import pytest from execution_testing.base_types import Address -from execution_testing.fixtures import BaseFixture, PreAllocGroups +from execution_testing.fixtures import PreAllocGroups from execution_testing.forks import Fork, Prague -from execution_testing.specs.base import BaseTest +from execution_testing.specs.base import BaseTest, FillResult from execution_testing.test_types import Environment from execution_testing.vm import Op @@ -30,7 +30,6 @@ def __init__( pre: Alloc, fork: Fork, genesis_environment: Environment, - request: Mock | None = None, ) -> None: """Initialize mock test.""" super().__init__( # type: ignore @@ -38,9 +37,8 @@ def __init__( fork=fork, genesis_environment=genesis_environment, ) - self._request = request - def generate(self, *args: Any, **kwargs: Any) -> BaseFixture: + def generate(self, *args: Any, **kwargs: Any) -> FillResult: """Mock generate method.""" raise NotImplementedError("This is a mock test class") @@ -98,9 +96,7 @@ def test_pre_alloc_group_separate() -> None: mock_marker.args = ("separate",) mock_request.node.get_closest_marker = Mock(return_value=mock_marker) - test2 = MockTest( - pre=pre, genesis_environment=env, request=mock_request, fork=fork - ) + test2 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env2 = test2.get_genesis_environment() # For "separate" marker, use the node ID as the salt hash2 = pre.compute_pre_alloc_group_hash( @@ -136,9 +132,7 @@ 
def test_pre_alloc_group_custom_salt() -> None: mock_marker1.args = ("eip1234",) mock_request1.node.get_closest_marker = Mock(return_value=mock_marker1) - test1 = MockTest( - pre=pre, genesis_environment=env, request=mock_request1, fork=fork - ) + test1 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env1 = test1.get_genesis_environment() hash1 = pre.compute_pre_alloc_group_hash( fork=fork, genesis_environment=genesis_env1, group_salt="eip1234" @@ -154,9 +148,7 @@ def test_pre_alloc_group_custom_salt() -> None: mock_marker2.args = ("eip1234",) # Same group mock_request2.node.get_closest_marker = Mock(return_value=mock_marker2) - test2 = MockTest( - pre=pre, genesis_environment=env, request=mock_request2, fork=fork - ) + test2 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env2 = test2.get_genesis_environment() hash2 = pre.compute_pre_alloc_group_hash( fork=fork, genesis_environment=genesis_env2, group_salt="eip1234" @@ -173,9 +165,7 @@ def test_pre_alloc_group_custom_salt() -> None: mock_marker3.args = ("eip5678",) # Different group mock_request3.node.get_closest_marker = Mock(return_value=mock_marker3) - test3 = MockTest( - pre=pre, genesis_environment=env, request=mock_request3, fork=fork - ) + test3 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env3 = test3.get_genesis_environment() hash3 = pre.compute_pre_alloc_group_hash( fork=fork, genesis_environment=genesis_env3, group_salt="eip5678" @@ -200,9 +190,7 @@ def test_pre_alloc_group_separate_different_nodeids() -> None: mock_marker1.args = ("separate",) mock_request1.node.get_closest_marker = Mock(return_value=mock_marker1) - test1 = MockTest( - pre=pre, genesis_environment=env, request=mock_request1, fork=fork - ) + test1 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env1 = test1.get_genesis_environment() hash1 = pre.compute_pre_alloc_group_hash( fork=fork, @@ -218,9 +206,7 @@ def test_pre_alloc_group_separate_different_nodeids() -> None: 
mock_marker2.args = ("separate",) mock_request2.node.get_closest_marker = Mock(return_value=mock_marker2) - test2 = MockTest( - pre=pre, genesis_environment=env, request=mock_request2, fork=fork - ) + test2 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env2 = test2.get_genesis_environment() hash2 = pre.compute_pre_alloc_group_hash( fork=fork, @@ -244,9 +230,7 @@ def test_no_pre_alloc_group_marker() -> None: mock_request.node.nodeid = "test_module.py::test_function" mock_request.node.get_closest_marker = Mock(return_value=None) # No marker - test1 = MockTest( - pre=pre, genesis_environment=env, request=mock_request, fork=fork - ) + test1 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env1 = test1.get_genesis_environment() hash1 = pre.compute_pre_alloc_group_hash( fork=fork, genesis_environment=genesis_env1, group_salt=None @@ -280,9 +264,7 @@ def test_pre_alloc_group_with_reason() -> None: } mock_request1.node.get_closest_marker = Mock(return_value=mock_marker1) - test1 = MockTest( - pre=pre, genesis_environment=env, request=mock_request1, fork=fork - ) + test1 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env1 = test1.get_genesis_environment() hash1 = pre.compute_pre_alloc_group_hash( fork=fork, @@ -299,9 +281,7 @@ def test_pre_alloc_group_with_reason() -> None: mock_marker2.kwargs = {"reason": "Different reason but same group"} mock_request2.node.get_closest_marker = Mock(return_value=mock_marker2) - test2 = MockTest( - pre=pre, genesis_environment=env, request=mock_request2, fork=fork - ) + test2 = MockTest(pre=pre, genesis_environment=env, fork=fork) genesis_env2 = test2.get_genesis_environment() hash2 = pre.compute_pre_alloc_group_hash( fork=fork, diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/shared/execute_fill.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/shared/execute_fill.py index a0c877bdc3..16152b8199 100644 --- 
a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/shared/execute_fill.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/shared/execute_fill.py @@ -304,6 +304,27 @@ def alloc_flags( return alloc_flags_from_test_markers +@pytest.fixture(scope="function") +def is_tx_gas_heavy_test(request: pytest.FixtureRequest) -> bool: + """ + Check, given the test node properties, whether the test is gas-heavy + for transaction execution. + """ + node = request.node + has_slow_marker = node.get_closest_marker("slow") is not None + has_benchmark_marker = node.get_closest_marker("benchmark") is not None + return has_slow_marker or has_benchmark_marker + + +@pytest.fixture(scope="function") +def is_exception_test(request: pytest.FixtureRequest) -> bool: + """ + Check, given the test node properties, whether the test is an exception + test (invalid block, invalid transaction). + """ + return request.node.get_closest_marker("exception_test") is not None + + def pytest_addoption(parser: pytest.Parser) -> None: """Add command-line options to pytest.""" static_filler_group = parser.getgroup( diff --git a/packages/testing/src/execution_testing/execution/base.py b/packages/testing/src/execution_testing/execution/base.py index dacd8a121b..11cd9248cb 100644 --- a/packages/testing/src/execution_testing/execution/base.py +++ b/packages/testing/src/execution_testing/execution/base.py @@ -11,9 +11,19 @@ from execution_testing.rpc import EngineRPC, EthRPC +class ExecuteResult(CamelModel): + """ + Result of the execute operation. 
+ """ + + benchmark_gas_used: int | None = None + + class BaseExecute(CamelModel): """Represents a base execution format.""" + benchmark_mode: bool = False + # Base Execute class properties formats: ClassVar[Dict[str, Type["BaseExecute"]]] = {} @@ -56,7 +66,7 @@ def execute( eth_rpc: EthRPC, engine_rpc: EngineRPC | None, request: FixtureRequest, - ) -> None: + ) -> ExecuteResult: """Execute the format.""" pass diff --git a/packages/testing/src/execution_testing/execution/blob_transaction.py b/packages/testing/src/execution_testing/execution/blob_transaction.py index 7ffe16641f..75762d947a 100644 --- a/packages/testing/src/execution_testing/execution/blob_transaction.py +++ b/packages/testing/src/execution_testing/execution/blob_transaction.py @@ -26,7 +26,7 @@ TransactionTestMetadata, ) -from .base import BaseExecute +from .base import BaseExecute, ExecuteResult logger = get_logger(__name__) @@ -104,7 +104,7 @@ def execute( eth_rpc: EthRPC, engine_rpc: EngineRPC | None, request: FixtureRequest, - ) -> None: + ) -> ExecuteResult: """Execute the format.""" versioned_hashes: Dict[Hash, BlobAndProofV1 | BlobAndProofV2] = {} sent_txs: List[Transaction] = [] @@ -140,7 +140,9 @@ def execute( logger.info( "Engine RPC is not available, skipping getBlobsV* validation." 
) - return + return ExecuteResult( + benchmark_gas_used=None, + ) version = fork.engine_get_blobs_version() assert version is not None, ( @@ -172,7 +174,9 @@ def execute( "the client correctly returned 'null')" ) eth_rpc.wait_for_transactions(sent_txs) - return + return ExecuteResult( + benchmark_gas_used=None, + ) assert blob_response is not None local_blobs_and_proofs = list(versioned_hashes.values()) @@ -250,3 +254,6 @@ def execute( ) eth_rpc.wait_for_transactions(sent_txs) + return ExecuteResult( + benchmark_gas_used=None, + ) diff --git a/packages/testing/src/execution_testing/execution/transaction_post.py b/packages/testing/src/execution_testing/execution/transaction_post.py index 347d762832..ae6032119a 100644 --- a/packages/testing/src/execution_testing/execution/transaction_post.py +++ b/packages/testing/src/execution_testing/execution/transaction_post.py @@ -20,7 +20,7 @@ TransactionTestMetadata, ) -from .base import BaseExecute +from .base import BaseExecute, ExecuteResult logger = get_logger(__name__) @@ -32,13 +32,6 @@ class TransactionPost(BaseExecute): blocks: List[List[Transaction]] post: Alloc - # Gas validation fields for benchmark tests - expected_benchmark_gas_used: int | None = ( - None # Expected total gas to be consumed - ) - skip_gas_used_validation: bool = ( - False # Skip gas validation even if expected is set - ) format_name: ClassVar[str] = "transaction_post_test" description: ClassVar[str] = ( @@ -78,7 +71,7 @@ def execute( eth_rpc: EthRPC, engine_rpc: EngineRPC | None, request: FixtureRequest, - ) -> None: + ) -> ExecuteResult: """Execute the format.""" del fork del engine_rpc @@ -92,7 +85,8 @@ def execute( ) # Track transaction hashes for gas validation (benchmarking) - all_tx_hashes = [] + all_tx_hashes: List[Hash] = [] + last_block_tx_hashes: List[Hash] = [] for block in self.blocks: signed_txs: List[Transaction] = [] @@ -117,11 +111,12 @@ def execute( tx_index=tx_index, ) signed_txs.append(tx) + current_block_tx_hashes: List[Hash] = 
[] if any(tx.error is not None for tx in signed_txs): for transaction in signed_txs: if transaction.error is None: eth_rpc.send_wait_transaction(transaction) - all_tx_hashes.append(transaction.hash) + current_block_tx_hashes.append(transaction.hash) else: logger.info( f"Sending transaction expecting rejection " @@ -138,32 +133,21 @@ def execute( else: # Send transactions (batching is handled by eth_rpc internally) eth_rpc.send_wait_transactions(signed_txs) - all_tx_hashes.extend([tx.hash for tx in signed_txs]) - - # Perform gas validation if required for benchmarking - # Ensures benchmark tests consume exactly the expected gas - if ( - not self.skip_gas_used_validation - and self.expected_benchmark_gas_used is not None - ): - total_gas_used = 0 - # Fetch transaction receipts to get actual gas used - for tx_hash in all_tx_hashes: + current_block_tx_hashes = [tx.hash for tx in signed_txs] + all_tx_hashes.extend(current_block_tx_hashes) + last_block_tx_hashes = current_block_tx_hashes + + # Fetch transaction receipts to get actual gas used + benchmark_gas_used: int | None = None + if self.benchmark_mode: + benchmark_gas_used = 0 + for tx_hash in last_block_tx_hashes: receipt = eth_rpc.get_transaction_receipt(tx_hash) assert receipt is not None, ( f"Failed to get receipt for transaction {tx_hash}" ) gas_used = int(receipt["gasUsed"], 16) - total_gas_used += gas_used - - # Verify that the total gas consumed matches expectations - expected_gas = self.expected_benchmark_gas_used - diff = total_gas_used - expected_gas - assert total_gas_used == expected_gas, ( - f"Total gas used ({total_gas_used}) does not match " - f"expected benchmark gas ({expected_gas}), " - f"difference: {diff}" - ) + benchmark_gas_used += gas_used for address, account in self.post.root.items(): balance = eth_rpc.get_balance(address) @@ -204,3 +188,7 @@ def execute( f"Storage value at {key} of {address} is " f"{storage_value}, expected {value}." 
) + + return ExecuteResult( + benchmark_gas_used=benchmark_gas_used, + ) diff --git a/packages/testing/src/execution_testing/fixtures/base.py b/packages/testing/src/execution_testing/fixtures/base.py index ce0395dc0a..fa2dc82111 100644 --- a/packages/testing/src/execution_testing/fixtures/base.py +++ b/packages/testing/src/execution_testing/fixtures/base.py @@ -29,6 +29,7 @@ from pydantic_core.core_schema import ValidatorFunctionWrapHandler from execution_testing.base_types import CamelModel, ReferenceSpec +from execution_testing.client_clis.cli_types import OpcodeCount from execution_testing.forks import Fork @@ -152,13 +153,22 @@ def json_dict_with_info(self, hash_only: bool = False) -> Dict[str, Any]: dict_with_info["_info"].update(self.info) return dict_with_info + def model_post_init(self, __context: Any, /) -> None: + """ + Model post-init to assert that the custom pre-allocation was + provided and the default was not used. + """ + super().model_post_init(__context) + self.info["fixture-format"] = self.format_name + def fill_info( self, t8n_version: str, test_case_description: str, fixture_source_url: str, + opcode_count: OpcodeCount | None, ref_spec: ReferenceSpec | None, - _info_metadata: Dict[str, Any], + _info_metadata: Dict[str, Any] | None, ) -> None: """Fill the info field for this fixture.""" if "comment" not in self.info: @@ -166,7 +176,8 @@ def fill_info( self.info["filling-transition-tool"] = t8n_version self.info["description"] = test_case_description self.info["url"] = fixture_source_url - self.info["fixture-format"] = self.format_name + if opcode_count is not None: + self.info["opcode_count"] = opcode_count.model_dump() if ref_spec is not None: ref_spec.write_info(self.info) if _info_metadata: diff --git a/packages/testing/src/execution_testing/fixtures/tests/test_base.py b/packages/testing/src/execution_testing/fixtures/tests/test_base.py index 045c137da5..892eb5ac8b 100644 --- 
a/packages/testing/src/execution_testing/fixtures/tests/test_base.py +++ b/packages/testing/src/execution_testing/fixtures/tests/test_base.py @@ -149,6 +149,7 @@ def test_base_fixtures_parsing(fixture: BaseFixture) -> None: "t8n-version", "test_case_description", fixture_source_url="fixture_source_url", + opcode_count=None, ref_spec=None, _info_metadata={}, ) diff --git a/packages/testing/src/execution_testing/fixtures/tests/test_blockchain.py b/packages/testing/src/execution_testing/fixtures/tests/test_blockchain.py index 126d2a9df3..e8c68fd7ea 100644 --- a/packages/testing/src/execution_testing/fixtures/tests/test_blockchain.py +++ b/packages/testing/src/execution_testing/fixtures/tests/test_blockchain.py @@ -1060,7 +1060,9 @@ ], ), { - "_info": {}, + "_info": { + "fixture-format": "blockchain_test_stateful_engine", + }, "network": "Prague", "postStateHash": Hash(2).hex(), "lastblockhash": Hash(1).hex(), diff --git a/packages/testing/src/execution_testing/fixtures/tests/test_collector.py b/packages/testing/src/execution_testing/fixtures/tests/test_collector.py index 4b41f1c6d4..93b6e3794f 100644 --- a/packages/testing/src/execution_testing/fixtures/tests/test_collector.py +++ b/packages/testing/src/execution_testing/fixtures/tests/test_collector.py @@ -22,6 +22,7 @@ def _make_fixture(nonce: int = 0) -> TransactionFixture: f"test description {nonce}", fixture_source_url="http://example.com", ref_spec=None, + opcode_count=None, _info_metadata={}, ) return fixture diff --git a/packages/testing/src/execution_testing/specs/base.py b/packages/testing/src/execution_testing/specs/base.py index 9eec347556..c37237b1b4 100644 --- a/packages/testing/src/execution_testing/specs/base.py +++ b/packages/testing/src/execution_testing/specs/base.py @@ -17,11 +17,12 @@ ) import pytest -from pydantic import BaseModel, ConfigDict, PrivateAttr +from pydantic import BaseModel, ConfigDict from typing_extensions import Self from execution_testing.base_types import to_hex from 
execution_testing.client_clis import Result, TransitionTool +from execution_testing.client_clis.cli_types import OpcodeCount from execution_testing.execution import ( BaseExecute, ExecuteFormat, @@ -81,6 +82,17 @@ class OpMode(StrEnum): OPTIMIZE_GAS_POST_PROCESSING = "optimize-gas-post-processing" +class FillResult(BaseModel): + """ + Result of the filling operation, returned by the `generate` method. + """ + + fixture: BaseFixture + gas_optimization: int | None + benchmark_gas_used: int | None = None + benchmark_opcode_count: OpcodeCount | None = None + + class BaseTest(BaseModel): """ Represents a base Ethereum test which must return a single test fixture. @@ -88,23 +100,20 @@ class BaseTest(BaseModel): model_config = ConfigDict(extra="forbid") - tag: str = "" fork: Fork = ( BaseFork # type: ignore[type-abstract] # default to BaseFork to allow the filler to set it, # instead of each test having to set it ) - - _request: pytest.FixtureRequest | None = PrivateAttr(None) - _operation_mode: OpMode | None = PrivateAttr(None) - _gas_optimization: int | None = PrivateAttr(None) - _gas_optimization_max_gas_limit: int | None = PrivateAttr(None) - + operation_mode: OpMode | None = None + gas_optimization_max_gas_limit: int | None = None expected_benchmark_gas_used: int | None = None skip_gas_used_validation: bool = False + is_tx_gas_heavy_test: bool = False + is_exception_test: bool = False + # Class variables, to be set by subclasses spec_types: ClassVar[Dict[str, Type["BaseTest"]]] = {} - supported_fixture_formats: ClassVar[ Sequence[FixtureFormat | LabeledFixtureFormat] ] = [] @@ -142,6 +151,12 @@ def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: Register all subclasses of BaseFixture with a fixture format name set as possible fixture formats. 
""" + super().__pydantic_init_subclass__(**kwargs) + + # Don't register dynamically generated wrapper classes + if getattr(cls, "__is_base_test_wrapper__", False): + return + if cls.pytest_parameter_name(): # Register the new fixture format BaseTest.spec_types[cls.pytest_parameter_name()] = cls @@ -154,16 +169,10 @@ def from_test( **kwargs: Any, ) -> Self: """Create a test in a different format from a base test.""" - new_instance = cls( - tag=base_test.tag, - fork=base_test.fork, - expected_benchmark_gas_used=base_test.expected_benchmark_gas_used, - skip_gas_used_validation=base_test.skip_gas_used_validation, - **kwargs, - ) - new_instance._request = base_test._request - new_instance._operation_mode = base_test._operation_mode - return new_instance + for k in BaseTest.model_fields.keys(): + if k not in kwargs: + kwargs[k] = getattr(base_test, k) + return cls(**kwargs) @classmethod def discard_execute_format_by_marks( @@ -185,8 +194,8 @@ def generate( *, t8n: TransitionTool, fixture_format: FixtureFormat, - ) -> BaseFixture: - """Generate the list of test fixtures.""" + ) -> FillResult: + """Generate the test fixture using the given fixture format.""" pass def execute( @@ -211,48 +220,13 @@ def pytest_parameter_name(cls) -> str: lambda x, y: x + ("_" if y.isupper() else "") + y, cls.__name__ ).lower() - def is_tx_gas_heavy_test(self) -> bool: - """Check if the test is gas-heavy for transaction execution.""" - if self._request is not None and hasattr(self._request, "node"): - node = self._request.node - has_slow_marker = node.get_closest_marker("slow") is not None - has_benchmark_marker = ( - node.get_closest_marker("benchmark") is not None - ) - return has_slow_marker or has_benchmark_marker - return False - - def is_exception_test(self) -> bool | None: - """ - Check if the test is an exception test (invalid block, invalid - transaction). - - `None` is returned if it's not possible to determine if the test is - negative or not. 
This is the case when the test is not run in pytest. - """ - if self._request is not None and hasattr(self._request, "node"): - return ( - self._request.node.get_closest_marker("exception_test") - is not None - ) - return None - - def node(self) -> pytest.Item | pytest.Function | None: - """Return the pytest node of the test.""" - if self._request is not None and hasattr(self._request, "node"): - return self._request.node - return None - def check_exception_test( self, *, exception: bool, ) -> None: """Compare the test marker against the outcome of the test.""" - negative_test_marker = self.is_exception_test() - if negative_test_marker is None: - return - if negative_test_marker != exception: + if self.is_exception_test != exception: if exception: raise Exception( "Test produced an invalid block or transaction but was " @@ -278,5 +252,41 @@ def get_genesis_environment(self) -> Environment: "access for use with pre-allocation groups." ) + def validate_benchmark_gas( + self, *, benchmark_gas_used: int | None, gas_benchmark_value: int + ) -> None: + """ + Validates the total consumed gas of the last block in the test matches + the expectation of the benchmark test. + + Requires the following fields to be set: + - expected_benchmark_gas_used + - operation_mode + """ + if self.operation_mode != OpMode.BENCHMARKING: + return + assert benchmark_gas_used is not None, "_benchmark_gas_used is not set" + # Perform gas validation if required for benchmarking. + # Ensures benchmark tests consume exactly the expected gas. 
+ if not self.skip_gas_used_validation: + # Verify that the total gas consumed in the last block + # matches expectations + expected_benchmark_gas_used = self.expected_benchmark_gas_used + if expected_benchmark_gas_used is None: + expected_benchmark_gas_used = gas_benchmark_value + diff = benchmark_gas_used - expected_benchmark_gas_used + assert benchmark_gas_used == expected_benchmark_gas_used, ( + f"Total gas used ({benchmark_gas_used}) does not " + "match expected benchmark gas " + f"({expected_benchmark_gas_used}), " + f"difference: {diff}" + ) + # Gas used should never exceed the maximum benchmark gas allowed. + assert benchmark_gas_used <= gas_benchmark_value, ( + f"benchmark_gas_used ({benchmark_gas_used}) exceeds maximum " + "benchmark gas allowed for this configuration: " + f"{gas_benchmark_value}" + ) + TestSpec = Callable[[Fork], Generator[BaseTest, None, None]] diff --git a/packages/testing/src/execution_testing/specs/benchmark.py b/packages/testing/src/execution_testing/specs/benchmark.py index 8f2fe7fc83..da475d2df9 100644 --- a/packages/testing/src/execution_testing/specs/benchmark.py +++ b/packages/testing/src/execution_testing/specs/benchmark.py @@ -31,7 +31,6 @@ TransactionPost, ) from execution_testing.fixtures import ( - BaseFixture, BlockchainEngineFixture, BlockchainEngineXFixture, BlockchainFixture, @@ -42,7 +41,7 @@ from execution_testing.test_types import Alloc, Environment, Transaction from execution_testing.vm import Bytecode, Op -from .base import BaseTest +from .base import BaseTest, FillResult from .blockchain import Block, BlockchainTest @@ -527,14 +526,13 @@ def generate( self, t8n: TransitionTool, fixture_format: FixtureFormat, - ) -> BaseFixture: + ) -> FillResult: """Generate the blockchain test fixture.""" self.check_exception_test( exception=self.tx.error is not None if self.tx else False ) if fixture_format in BlockchainTest.supported_fixture_formats: - blockchain_test = self.generate_blockchain_test() - fixture = 
blockchain_test.generate( + fill_result = self.generate_blockchain_test().generate( t8n=t8n, fixture_format=fixture_format ) @@ -544,9 +542,9 @@ def generate( and self.fixed_opcode_count is not None ): self._verify_target_opcode_count( - blockchain_test._benchmark_opcode_count + fill_result.benchmark_opcode_count ) - return fixture + return fill_result else: raise Exception(f"Unsupported fixture format: {fixture_format}") @@ -561,6 +559,7 @@ def execute( return TransactionPost( blocks=[block.txs for block in self.blocks], post=self.post, + benchmark_mode=True, ) raise Exception(f"Unsupported execute format: {execute_format}") diff --git a/packages/testing/src/execution_testing/specs/blobs.py b/packages/testing/src/execution_testing/specs/blobs.py index af9b219a88..2375b5c873 100644 --- a/packages/testing/src/execution_testing/specs/blobs.py +++ b/packages/testing/src/execution_testing/specs/blobs.py @@ -7,7 +7,6 @@ from execution_testing.client_clis import TransitionTool from execution_testing.execution import BaseExecute, BlobTransaction from execution_testing.fixtures import ( - BaseFixture, FixtureFormat, ) from execution_testing.test_types import ( @@ -15,7 +14,7 @@ Transaction, ) -from .base import BaseTest, ExecuteFormat, LabeledExecuteFormat +from .base import BaseTest, ExecuteFormat, FillResult, LabeledExecuteFormat class BlobsTest(BaseTest): @@ -38,7 +37,7 @@ def generate( *, t8n: TransitionTool, fixture_format: FixtureFormat, - ) -> BaseFixture: + ) -> FillResult: """Generate the list of test fixtures.""" del t8n raise Exception(f"Unknown fixture format: {fixture_format}") diff --git a/packages/testing/src/execution_testing/specs/blockchain.py b/packages/testing/src/execution_testing/specs/blockchain.py index 7587e0d147..deadb5bc10 100644 --- a/packages/testing/src/execution_testing/specs/blockchain.py +++ b/packages/testing/src/execution_testing/specs/blockchain.py @@ -17,7 +17,6 @@ from pydantic import ( ConfigDict, Field, - PrivateAttr, 
field_validator, model_serializer, ) @@ -89,7 +88,7 @@ BlockAccessListExpectation, ) -from .base import BaseTest, OpMode, verify_result +from .base import BaseTest, FillResult, OpMode, verify_result from .debugging import print_traces from .helpers import verify_block, verify_transactions @@ -516,8 +515,6 @@ class BlockchainTest(BaseTest): Include transaction receipts in the fixture output. """ - _benchmark_opcode_count: OpcodeCount | None = PrivateAttr(None) - supported_fixture_formats: ClassVar[ Sequence[FixtureFormat | LabeledFixtureFormat] ] = [ @@ -613,7 +610,6 @@ def generate_block_data( block: Block, previous_env: Environment, previous_alloc: Alloc | LazyAlloc, - last_block: bool, ) -> BuiltBlock: """ Generate common block data for both make_fixture and make_hive_fixture. @@ -646,7 +642,7 @@ def generate_block_data( ), blob_schedule=self.fork.blob_schedule(), ), - slow_request=self.is_tx_gas_heavy_test(), + slow_request=self.is_tx_gas_heavy_test, ) # One special case of the invalid transactions is the blob gas used, @@ -687,25 +683,6 @@ def generate_block_data( f"Verification of block {int(env.number)} failed" ) from e - if last_block and self._operation_mode == OpMode.BENCHMARKING: - expected_benchmark_gas_used = self.expected_benchmark_gas_used - assert expected_benchmark_gas_used is not None, ( - "expected_benchmark_gas_used is not set" - ) - gas_used = int(transition_tool_output.result.gas_used) - - if not self.skip_gas_used_validation: - diff = gas_used - expected_benchmark_gas_used - assert gas_used == expected_benchmark_gas_used, ( - f"gas_used ({gas_used}) does not match " - f"expected_benchmark_gas_used " - f"({expected_benchmark_gas_used}), difference: {diff}" - ) - - self._benchmark_opcode_count = ( - transition_tool_output.result.opcode_count - ) - requests_list: List[Bytes] | None = None if self.fork.header_requests_required( block_number=header.number, timestamp=header.timestamp @@ -862,7 +839,7 @@ def verify_post_state( def make_fixture( self, 
t8n: TransitionTool, - ) -> BlockchainFixture: + ) -> FillResult: """Create a fixture from the blockchain test definition.""" fixture_blocks: List[FixtureBlock | InvalidFixtureBlock] = [] @@ -873,6 +850,8 @@ def make_fixture( env = environment_from_parent_header(genesis.header) head = genesis.header.block_hash invalid_blocks = 0 + benchmark_gas_used: int | None = None + benchmark_opcode_count: OpcodeCount | None = None for i, block in enumerate(self.blocks): is_last_block = i == len(self.blocks) - 1 # This is the most common case, the RLP needs to be constructed @@ -883,8 +862,10 @@ def make_fixture( block=block, previous_env=env, previous_alloc=alloc, - last_block=is_last_block, ) + if is_last_block and self.operation_mode == OpMode.BENCHMARKING: + benchmark_gas_used = int(built_block.result.gas_used) + benchmark_opcode_count = built_block.result.opcode_count include_receipts = ( block.include_receipts_in_output if block.include_receipts_in_output is not None @@ -919,10 +900,7 @@ def make_fixture( self.check_exception_test(exception=invalid_blocks > 0) alloc = alloc.get() if isinstance(alloc, LazyAlloc) else alloc self.verify_post_state(t8n, t8n_state=alloc) - info = {} - if t8n.opcode_count is not None: - info["opcode_count"] = t8n.opcode_count.model_dump() - return BlockchainFixture( + fixture = BlockchainFixture( fork=self.fork, genesis=genesis.header, genesis_rlp=genesis.rlp, @@ -942,18 +920,19 @@ def make_fixture( ), chain_id=self.chain_id, ), - info=info, + ) + return FillResult( + fixture=fixture, + gas_optimization=None, + benchmark_gas_used=benchmark_gas_used, + benchmark_opcode_count=benchmark_opcode_count, ) def make_hive_fixture( self, t8n: TransitionTool, fixture_format: FixtureFormat = BlockchainEngineFixture, - ) -> ( - BlockchainEngineFixture - | BlockchainEngineXFixture - | BlockchainEngineSyncFixture - ): + ) -> FillResult: """Create a hive fixture from the blocktest definition.""" fixture_payloads: List[FixtureEngineNewPayload] = [] @@ -966,14 
+945,19 @@ def make_hive_fixture( env = environment_from_parent_header(genesis.header) head_hash = genesis.header.block_hash invalid_blocks = 0 + benchmark_gas_used: int | None = None + benchmark_opcode_count: OpcodeCount | None = None for i, block in enumerate(self.blocks): + is_last_block = i == len(self.blocks) - 1 built_block = self.generate_block_data( t8n=t8n, block=block, previous_env=env, previous_alloc=alloc, - last_block=i == len(self.blocks) - 1, ) + if is_last_block and self.operation_mode == OpMode.BENCHMARKING: + benchmark_gas_used = int(built_block.result.gas_used) + benchmark_opcode_count = built_block.result.opcode_count fixture_payloads.append( built_block.get_fixture_engine_new_payload() ) @@ -1007,9 +991,6 @@ def make_hive_fixture( self.verify_post_state(t8n, t8n_state=alloc) # Create base fixture data, common to all fixture formats - info = {} - if t8n.opcode_count is not None: - info["opcode_count"] = t8n.opcode_count.model_dump() fixture_data = { "fork": self.fork, "genesis": genesis.header, @@ -1025,10 +1006,10 @@ def make_hive_fixture( self.fork.blob_schedule() ), ), - "info": info, } # Add format-specific fields + fixture: BaseFixture if fixture_format == BlockchainEngineXFixture: # For Engine X format, exclude pre (will be provided via shared # state) and prepare for state diff optimization @@ -1040,7 +1021,7 @@ def make_hive_fixture( "pre_hash": "", # Will be set by BaseTestWrapper } ) - return BlockchainEngineXFixture(**fixture_data) + fixture = BlockchainEngineXFixture(**fixture_data) elif fixture_format == BlockchainEngineSyncFixture: # Sync fixture format assert genesis.header.block_hash != head_hash, ( @@ -1055,7 +1036,6 @@ def make_hive_fixture( block=Block(), previous_env=env, previous_alloc=alloc, - last_block=False, ) fixture_data.update( { @@ -1068,7 +1048,7 @@ def make_hive_fixture( else None, } ) - return BlockchainEngineSyncFixture(**fixture_data) + fixture = BlockchainEngineSyncFixture(**fixture_data) else: # Standard 
engine fixture fixture_data.update( @@ -1079,13 +1059,20 @@ def make_hive_fixture( else None, } ) - return BlockchainEngineFixture(**fixture_data) + fixture = BlockchainEngineFixture(**fixture_data) + + return FillResult( + fixture=fixture, + gas_optimization=None, + benchmark_gas_used=benchmark_gas_used, + benchmark_opcode_count=benchmark_opcode_count, + ) def generate( self, t8n: TransitionTool, fixture_format: FixtureFormat, - ) -> BaseFixture: + ) -> FillResult: """Generate the BlockchainTest fixture.""" if fixture_format in [ BlockchainEngineFixture, @@ -1110,14 +1097,14 @@ def execute( blocks += [block.txs] # Pass gas validation params for benchmark tests # If not benchmark mode, skip gas used validation - if self._operation_mode != OpMode.BENCHMARKING: + if self.operation_mode != OpMode.BENCHMARKING: self.skip_gas_used_validation = True + benchmark_mode = self.operation_mode == OpMode.BENCHMARKING return TransactionPost( blocks=blocks, post=self.post, - expected_benchmark_gas_used=self.expected_benchmark_gas_used, - skip_gas_used_validation=self.skip_gas_used_validation, + benchmark_mode=benchmark_mode, ) raise Exception(f"Unsupported execute format: {execute_format}") diff --git a/packages/testing/src/execution_testing/specs/state.py b/packages/testing/src/execution_testing/specs/state.py index 6774d2b91a..87ad9f4d6e 100644 --- a/packages/testing/src/execution_testing/specs/state.py +++ b/packages/testing/src/execution_testing/specs/state.py @@ -33,7 +33,6 @@ TransactionPost, ) from execution_testing.fixtures import ( - BaseFixture, FixtureFormat, LabeledFixtureFormat, StateFixture, @@ -57,7 +56,7 @@ Transaction, ) -from .base import BaseTest, OpMode +from .base import BaseTest, FillResult, OpMode from .blockchain import Block, BlockchainTest, Header from .debugging import print_traces from .helpers import verify_transactions @@ -153,7 +152,7 @@ def verify_modified_gas_limit( blob_schedule=fork.blob_schedule(), state_test=True, ), - 
slow_request=self.is_tx_gas_heavy_test(), + slow_request=self.is_tx_gas_heavy_test, ) modified_traces = modified_tool_output.result.traces assert modified_traces is not None, ( @@ -337,7 +336,7 @@ def generate_blockchain_test(self) -> BlockchainTest: def make_state_test_fixture( self, t8n: TransitionTool, - ) -> StateFixture: + ) -> FillResult: """Create a fixture from the state test definition.""" # We can't generate a state test fixture that names a transition fork, # so we get the fork at the block number and timestamp of the state @@ -366,7 +365,7 @@ def make_state_test_fixture( blob_schedule=fork.blob_schedule(), state_test=True, ), - slow_request=self.is_tx_gas_heavy_test(), + slow_request=self.is_tx_gas_heavy_test, ) output_alloc = transition_tool_output.alloc.get() @@ -388,12 +387,14 @@ def make_state_test_fixture( pprint(output_alloc) raise e + gas_optimization: int | None = None + if ( - self._operation_mode == OpMode.OPTIMIZE_GAS - or self._operation_mode == OpMode.OPTIMIZE_GAS_POST_PROCESSING + self.operation_mode == OpMode.OPTIMIZE_GAS + or self.operation_mode == OpMode.OPTIMIZE_GAS_POST_PROCESSING ): enable_post_processing = ( - self._operation_mode == OpMode.OPTIMIZE_GAS_POST_PROCESSING + self.operation_mode == OpMode.OPTIMIZE_GAS_POST_PROCESSING ) base_tool_output = transition_tool_output base_tool_alloc = base_tool_output.alloc.get() @@ -434,13 +435,13 @@ def make_state_test_fixture( else: minimum_gas_limit = current_gas_limit + 1 if ( - self._gas_optimization_max_gas_limit is not None + self.gas_optimization_max_gas_limit is not None and minimum_gas_limit - > self._gas_optimization_max_gas_limit + > self.gas_optimization_max_gas_limit ): raise Exception( "Requires more than the minimum " - f"{self._gas_optimization_max_gas_limit} " + f"{self.gas_optimization_max_gas_limit} " "wanted." 
) @@ -454,23 +455,10 @@ def make_state_test_fixture( env=env, enable_post_processing=enable_post_processing, ) - self._gas_optimization = current_gas_limit + gas_optimization = current_gas_limit else: raise Exception("Impossible to compare.") - if self._operation_mode == OpMode.BENCHMARKING: - expected_benchmark_gas_used = self.expected_benchmark_gas_used - assert expected_benchmark_gas_used is not None, ( - "expected_benchmark_gas_used is not set" - ) - gas_used = int(transition_tool_output.result.gas_used) - if not self.skip_gas_used_validation: - diff = gas_used - expected_benchmark_gas_used - assert gas_used == expected_benchmark_gas_used, ( - f"gas_used ({gas_used}) does not match " - f"expected_benchmark_gas_used " - f"({expected_benchmark_gas_used}), difference: {diff}" - ) if len(transition_tool_output.result.receipts) == 1: receipt = FixtureTransactionReceipt.from_transaction_receipt( transition_tool_output.result.receipts[0], tx @@ -487,7 +475,7 @@ def make_state_test_fixture( else: receipt = None - return StateFixture( + fixture = StateFixture( env=FixtureEnvironment(**env.model_dump(exclude_none=True)), pre=pre_alloc, post={ @@ -510,6 +498,12 @@ def make_state_test_fixture( chain_id=self.chain_id, ), ) + return FillResult( + fixture=fixture, + gas_optimization=gas_optimization, + benchmark_gas_used=transition_tool_output.result.gas_used, + benchmark_opcode_count=transition_tool_output.result.opcode_count, + ) def get_genesis_environment(self) -> Environment: """Get the genesis environment for pre-allocation groups.""" @@ -519,7 +513,7 @@ def generate( self, t8n: TransitionTool, fixture_format: FixtureFormat, - ) -> BaseFixture: + ) -> FillResult: """Generate the BlockchainTest fixture.""" self.check_exception_test(exception=self.tx.error is not None) if fixture_format in BlockchainTest.supported_fixture_formats: @@ -540,13 +534,13 @@ def execute( if execute_format == TransactionPost: # Pass gas validation params for benchmark tests # If not benchmark 
mode, skip gas used validation - if self._operation_mode != OpMode.BENCHMARKING: + if self.operation_mode != OpMode.BENCHMARKING: self.skip_gas_used_validation = True + benchmark_mode = self.operation_mode == OpMode.BENCHMARKING return TransactionPost( blocks=[[self.tx]], post=self.post, - expected_benchmark_gas_used=self.expected_benchmark_gas_used, - skip_gas_used_validation=self.skip_gas_used_validation, + benchmark_mode=benchmark_mode, ) raise Exception(f"Unsupported execute format: {execute_format}") diff --git a/packages/testing/src/execution_testing/specs/tests/test_expect.py b/packages/testing/src/execution_testing/specs/tests/test_expect.py index 6a5e69a10e..096bc250e9 100644 --- a/packages/testing/src/execution_testing/specs/tests/test_expect.py +++ b/packages/testing/src/execution_testing/specs/tests/test_expect.py @@ -73,12 +73,18 @@ def fork() -> Fork: # noqa: D103 return get_deployed_forks()[-1] +@pytest.fixture +def is_exception_test() -> bool: # noqa: D103 + return False + + @pytest.fixture def state_test( # noqa: D103 pre: Mapping[Any, Any], post: Mapping[Any, Any], tx: Transaction, fork: Fork, + is_exception_test: bool, ) -> StateTest: return StateTest( env=Environment(), @@ -86,6 +92,7 @@ def state_test( # noqa: D103 post=post, tx=tx, fork=fork, + is_exception_test=is_exception_test, ) @@ -368,7 +375,7 @@ def test_post_account_mismatch( # Transaction result mismatch tests @pytest.mark.parametrize( - "tx,exception_type", + "tx,exception_type,is_exception_test", [ pytest.param( Transaction( @@ -377,6 +384,7 @@ def test_post_account_mismatch( error=TransactionException.SENDER_NOT_EOA, ), ExecutionExceptionMismatchError, + True, id="TransactionExecutionExceptionMismatchError", ), pytest.param( @@ -388,6 +396,7 @@ def test_post_account_mismatch( ), ), UnexpectedExecutionSuccessError, + True, id="TransactionUnexpectedExecutionSuccessError", ), pytest.param( @@ -399,6 +408,7 @@ def test_post_account_mismatch( ), ), UnexpectedExecutionFailError, + False, 
id="TransactionUnexpectedExecutionFailError", ), pytest.param( @@ -409,6 +419,7 @@ def test_post_account_mismatch( ), ), TransactionReceiptMismatchError, + False, id="TransactionReceiptMismatchError", ), pytest.param( @@ -420,6 +431,7 @@ def test_post_account_mismatch( ), ), UnexpectedExecutionFailError, + False, id="TransactionUnexpectedExecutionFailError+TransactionReceiptMismatchError", ), pytest.param( @@ -431,6 +443,7 @@ def test_post_account_mismatch( ), ), UnexpectedExecutionSuccessError, + True, id="TransactionUnexpectedExecutionSuccessError+TransactionReceiptMismatchError", ), ], diff --git a/packages/testing/src/execution_testing/specs/tests/test_fixtures.py b/packages/testing/src/execution_testing/specs/tests/test_fixtures.py index 21d5267857..beea47d1cd 100644 --- a/packages/testing/src/execution_testing/specs/tests/test_fixtures.py +++ b/packages/testing/src/execution_testing/specs/tests/test_fixtures.py @@ -108,14 +108,17 @@ def test_make_genesis( # noqa: D103 } ) - fixture = BlockchainTest( - fork=fork, - genesis_environment=env, - pre=pre, - post={}, - blocks=[], - tag="some_state_test", - ).generate(t8n=default_t8n, fixture_format=BlockchainFixture) + fixture = ( + BlockchainTest( + fork=fork, + genesis_environment=env, + pre=pre, + post={}, + blocks=[], + ) + .generate(t8n=default_t8n, fixture_format=BlockchainFixture) + .fixture + ) assert isinstance(fixture, BlockchainFixture) assert fixture.genesis is not None @@ -193,14 +196,17 @@ def test_fill_state_test( ), } - generated_fixture = StateTest( - fork=fork, - env=env, - pre=pre, - post=post, - tx=tx, - tag="my_chain_id_test", - ).generate(t8n=default_t8n, fixture_format=fixture_format) + generated_fixture = ( + StateTest( + fork=fork, + env=env, + pre=pre, + post=post, + tx=tx, + ) + .generate(t8n=default_t8n, fixture_format=fixture_format) + .fixture + ) assert generated_fixture.__class__ == fixture_format fixture_key = f"000/my_chain_id_test/{fork}/tx_type_{tx_type}" fixture = { @@ -521,14 
+527,17 @@ def blockchain_test_fixture( # noqa: D102 fixture_format: FixtureFormat, default_t8n: TransitionTool, ) -> BaseFixture: - return BlockchainTest( - fork=fork, - pre=pre, - post=post, - blocks=blocks, - genesis_environment=genesis_environment, - tag="my_blockchain_test_valid_txs", - ).generate(t8n=default_t8n, fixture_format=fixture_format) + return ( + BlockchainTest( + fork=fork, + pre=pre, + post=post, + blocks=blocks, + genesis_environment=genesis_environment, + ) + .generate(t8n=default_t8n, fixture_format=fixture_format) + .fixture + ) @pytest.mark.parametrize("fork", [London, Shanghai], indirect=True) def test_fill_blockchain_valid_txs( # noqa: D102 @@ -922,13 +931,18 @@ def test_fill_blockchain_invalid_txs( fixture_format: FixtureFormat = ( BlockchainEngineFixture if check_hive else BlockchainFixture ) - generated_fixture = BlockchainTest( - fork=fork, - pre=pre, - post=post, - blocks=blocks, - genesis_environment=genesis_environment, - ).generate(t8n=default_t8n, fixture_format=fixture_format) + generated_fixture = ( + BlockchainTest( + fork=fork, + pre=pre, + post=post, + blocks=blocks, + genesis_environment=genesis_environment, + is_exception_test=True, + ) + .generate(t8n=default_t8n, fixture_format=fixture_format) + .fixture + ) assert generated_fixture.__class__ == fixture_format # BlockchainEngineFixture inherits from BlockchainEngineFixtureCommon # (not BlockchainFixtureCommon) diff --git a/packages/testing/src/execution_testing/specs/tests/test_transaction.py b/packages/testing/src/execution_testing/specs/tests/test_transaction.py index 0114e38029..abc58a3421 100644 --- a/packages/testing/src/execution_testing/specs/tests/test_transaction.py +++ b/packages/testing/src/execution_testing/specs/tests/test_transaction.py @@ -27,12 +27,16 @@ def test_transaction_test_filling( name: str, tx: Transaction, fork: Fork ) -> None: """Test the transaction test filling.""" - generated_fixture = TransactionTest( - tx=tx.with_signature_and_sender(), - 
fork=fork, - ).generate( - t8n=None, # type: ignore - fixture_format=TransactionFixture, + generated_fixture = ( + TransactionTest( + tx=tx.with_signature_and_sender(), + fork=fork, + ) + .generate( + t8n=None, # type: ignore + fixture_format=TransactionFixture, + ) + .fixture ) assert generated_fixture.__class__ == TransactionFixture fixture_json_dict = generated_fixture.json_dict_with_info() diff --git a/packages/testing/src/execution_testing/specs/transaction.py b/packages/testing/src/execution_testing/specs/transaction.py index b1e108cdb3..d66125f9fa 100644 --- a/packages/testing/src/execution_testing/specs/transaction.py +++ b/packages/testing/src/execution_testing/specs/transaction.py @@ -10,7 +10,6 @@ TransactionPost, ) from execution_testing.fixtures import ( - BaseFixture, FixtureFormat, LabeledFixtureFormat, TransactionFixture, @@ -18,7 +17,7 @@ from execution_testing.fixtures.transaction import FixtureResult from execution_testing.test_types import Alloc, Transaction -from .base import BaseTest +from .base import BaseTest, FillResult, OpMode class TransactionTest(BaseTest): @@ -44,7 +43,7 @@ class TransactionTest(BaseTest): def make_transaction_test_fixture( self, - ) -> TransactionFixture: + ) -> FillResult: """Create a fixture from the transaction test definition.""" if self.tx.error is not None: result = FixtureResult( @@ -70,18 +69,24 @@ def make_transaction_test_fixture( sender=self.tx.sender, ) - return TransactionFixture( + fixture = TransactionFixture( result={ self.fork: result, }, transaction=self.tx.with_signature_and_sender().rlp(), ) + return FillResult( + fixture=fixture, + gas_optimization=None, + benchmark_gas_used=None, + benchmark_opcode_count=None, + ) def generate( self, t8n: TransitionTool, fixture_format: FixtureFormat, - ) -> BaseFixture: + ) -> FillResult: """Generate the TransactionTest fixture.""" del t8n @@ -98,9 +103,11 @@ def execute( ) -> BaseExecute: """Execute the transaction test by sending it to the live network.""" if 
execute_format == TransactionPost: + benchmark_mode = self.operation_mode == OpMode.BENCHMARKING return TransactionPost( blocks=[[self.tx]], post={}, + benchmark_mode=benchmark_mode, ) raise Exception(f"Unsupported execute format: {execute_format}") diff --git a/tests/cancun/eip4844_blobs/test_excess_blob_gas.py b/tests/cancun/eip4844_blobs/test_excess_blob_gas.py index ffa29e5a3c..84952a629c 100644 --- a/tests/cancun/eip4844_blobs/test_excess_blob_gas.py +++ b/tests/cancun/eip4844_blobs/test_excess_blob_gas.py @@ -332,7 +332,6 @@ def test_correct_excess_blob_gas_calculation( post=post, blocks=blocks, genesis_environment=env, - tag=f"expected_excess_blob_gas:{hex(correct_excess_blob_gas)}", ) @@ -399,7 +398,6 @@ def test_correct_increasing_blob_gas_costs( post=post, blocks=blocks, genesis_environment=env, - tag=f"expected_excess_blob_gas:{hex(correct_excess_blob_gas)}", ) @@ -431,7 +429,6 @@ def test_correct_decreasing_blob_gas_costs( post=post, blocks=blocks, genesis_environment=env, - tag=f"expected_excess_blob_gas:{hex(correct_excess_blob_gas)}", ) @@ -466,12 +463,6 @@ def test_invalid_zero_excess_blob_gas_in_header( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(correct_excess_blob_gas)}", - f"header:{hex(header_excess_blob_gas)}", - ] - ), ) @@ -517,12 +508,6 @@ def test_invalid_blob_gas_used_in_header( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(new_blobs * blob_gas_per_blob)}", - f"header:{hex(header_blob_gas_used)}", - ] - ), ) @@ -573,12 +558,6 @@ def test_invalid_excess_blob_gas_above_target_change( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(correct_excess_blob_gas)}", - f"header:{hex(header_excess_blob_gas)}", - ] - ), ) @@ -617,12 +596,6 @@ def test_invalid_static_excess_blob_gas( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(correct_excess_blob_gas)}", - 
f"header:{hex(parent_excess_blob_gas)}", - ] - ), ) @@ -662,12 +635,6 @@ def test_invalid_excess_blob_gas_target_blobs_increase_from_zero( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(correct_excess_blob_gas)}", - f"header:{hex(header_excess_blob_gas)}", - ] - ), ) @@ -707,12 +674,6 @@ def test_invalid_static_excess_blob_gas_from_zero_on_blobs_above_target( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(correct_excess_blob_gas)}", - f"header:{hex(header_excess_blob_gas)}", - ] - ), ) @@ -761,12 +722,6 @@ def test_invalid_excess_blob_gas_change( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(correct_excess_blob_gas)}", - f"header:{hex(header_excess_blob_gas)}", - ] - ), ) @@ -814,12 +769,6 @@ def test_invalid_negative_excess_blob_gas( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(correct_excess_blob_gas)}", - f"header:{hex(header_excess_blob_gas)}", - ] - ), ) @@ -867,10 +816,4 @@ def test_invalid_non_multiple_excess_blob_gas( post={}, blocks=blocks, genesis_environment=env, - tag="-".join( - [ - f"correct:{hex(correct_excess_blob_gas)}", - f"header:{hex(header_excess_blob_gas)}", - ] - ), ) diff --git a/tests/shanghai/eip3651_warm_coinbase/test_warm_coinbase.py b/tests/shanghai/eip3651_warm_coinbase/test_warm_coinbase.py index a9c2a18280..1a33c17447 100644 --- a/tests/shanghai/eip3651_warm_coinbase/test_warm_coinbase.py +++ b/tests/shanghai/eip3651_warm_coinbase/test_warm_coinbase.py @@ -119,7 +119,6 @@ def test_warm_coinbase_call_out_of_gas( pre=pre, post=post, tx=tx, - tag="opcode_" + opcode, ) @@ -205,5 +204,4 @@ def test_warm_coinbase_gas_usage( pre=pre, post=post, tx=tx, - tag="opcode_" + opcode.lower(), ) diff --git a/tests/shanghai/eip4895_withdrawals/test_withdrawals.py b/tests/shanghai/eip4895_withdrawals/test_withdrawals.py index 9af9975c46..272dc6293a 100644 --- 
a/tests/shanghai/eip4895_withdrawals/test_withdrawals.py +++ b/tests/shanghai/eip4895_withdrawals/test_withdrawals.py @@ -356,7 +356,6 @@ def blocks(self, addresses: Address, test_case: str) -> List[Block]: # noqa: D1 def test_multiple_withdrawals_same_address( self, blockchain_test: BlockchainTestFiller, - test_case: str, pre: Alloc, addresses: List[Address], blocks: List[Block], @@ -372,7 +371,7 @@ def test_multiple_withdrawals_same_address( storage={}, ) - blockchain_test(pre=pre, post=post, blocks=blocks, tag=test_case) + blockchain_test(pre=pre, post=post, blocks=blocks) def test_many_withdrawals( @@ -710,7 +709,6 @@ def test_zero_amount( # to allow for Account.NONEXISTENT post=post, blocks=[Block(withdrawals=withdrawals)], - tag=test_case.value, ) From c8e583b8a77dd5b2691d60c338dd79d287252b0d Mon Sep 17 00:00:00 2001 From: Guruprasad Kamath <48196632+gurukamath@users.noreply.github.com> Date: Fri, 20 Feb 2026 23:37:01 +0100 Subject: [PATCH 2/4] feat(amsterdam): add PreState protocol and DictPreState implementation (#2207) Introduce a shared `PreState` protocol (`src/ethereum/state.py`) that defines the read-only interface any pre-execution state provider must support: `get_account_optional`, `get_storage`, `account_has_storage`, and `compute_state_root_and_trie_changes`. Move `Account`, `Address`, `Root`, and trie internal node types to the shared module so they can be referenced by the protocol without circular imports. Add `DictPreState` in amsterdam's `state.py` as an in-memory implementation backed by `Trie` dicts, along with helper functions `dict_pre_state_set_account` and `dict_pre_state_set_storage` for populating it. feat(amsterdam): replace mutable State with BlockStateTracker/TxStateTracker Replace the mutable `State` class (with in-memory tries, snapshots, and rollback) with lightweight state trackers that record diffs on top of a read-only `PreState`. 
Add `state_tracking.py` with `BlockStateTracker` (block-level accumulator) and `TxStateTracker` (per-transaction tracker). All state read/write operations now go through these trackers with the read chain: TxStateTracker -> BlockStateTracker -> PreState. Snapshot/rollback is replaced by copy-on-write via `copy_tx_state_tracker`/`restore_tx_state_tracker`. Transient storage moves from a separate `Trie`-based class to a simple dict on the tracker. At block end, `extract_block_diffs` produces the accumulated changes and `PreState.compute_state_root_and_trie_changes` computes the state root. `apply_diffs_to_state` advances `chain.state` for subsequent blocks. Simplify `State` to just hold tries (no snapshots/created_accounts). Remove all mutable state operations from `state.py` (moved to `state_tracking.py`). Update all consumer modules (fork.py, VM, interpreter, instructions, eoa_delegation, message utils). Co-authored-by: Peter Miller feat(t8n): update tooling to support BlockStateTracker forks Update the t8n tool and fork loader to work with forks that use `BlockStateTracker` instead of the mutable `State`. In `fork_loader.py`, add `has_block_tracker` detection and conditionally return `DictPreState`, `dict_pre_state_set_account`, and `dict_pre_state_set_storage` for Amsterdam. In `t8n/__init__.py`, create a `BlockStateTracker` wrapping the `DictPreState` alloc when `has_block_tracker` is true, and pass `block_tracker=` instead of `state=` to `BlockEnvironment`. In `t8n_types.py`, compute the state root from block tracker diffs via `extract_block_diffs` + `compute_state_root_and_trie_changes`, then apply diffs back to the `DictPreState` for alloc output. refactor(amsterdam): consolidate DictPreState into State Make State implement the PreState protocol directly instead of maintaining a separate DictPreState wrapper class. This eliminates duplication since both had identical _main_trie and _storage_tries fields. 
Simplifies fork_loader.py by removing has_block_tracker conditionals from State/set_account/set_storage properties. replace BAL with state tracker fix(spec): clean-up code fix(specs): remove tracker suffix from names fix(specs): post-review updates fix(specs): fix docstring references fix(tool): fix EMPTY_ACCOUNT in fork loader add state_root to amsterdam fix(ci): remove optimized ci run The optimized run is no longer compatible with the state tracker and will have to be refactored/redesigned. The commit takes it out of the CI with the aim of adding it back once ready --- .github/workflows/test.yaml | 18 - .../amsterdam/block_access_lists/__init__.py | 2 + .../amsterdam/block_access_lists/builder.py | 166 ++-- .../amsterdam/block_access_lists/rlp_types.py | 14 +- src/ethereum/forks/amsterdam/blocks.py | 12 +- src/ethereum/forks/amsterdam/fork.py | 189 ++-- src/ethereum/forks/amsterdam/fork_types.py | 26 +- src/ethereum/forks/amsterdam/state.py | 709 ++------------ src/ethereum/forks/amsterdam/state_tracker.py | 874 +++++++++++------- src/ethereum/forks/amsterdam/transactions.py | 3 +- src/ethereum/forks/amsterdam/trie.py | 63 +- src/ethereum/forks/amsterdam/utils/address.py | 3 +- .../forks/amsterdam/utils/hexadecimal.py | 3 +- src/ethereum/forks/amsterdam/utils/message.py | 14 +- src/ethereum/forks/amsterdam/vm/__init__.py | 20 +- .../forks/amsterdam/vm/eoa_delegation.py | 41 +- .../amsterdam/vm/instructions/environment.py | 32 +- .../amsterdam/vm/instructions/storage.py | 47 +- .../forks/amsterdam/vm/instructions/system.py | 132 +-- .../forks/amsterdam/vm/interpreter.py | 115 +-- .../vm/precompiled_contracts/mapping.py | 3 +- src/ethereum/state.py | 136 +++ .../evm_tools/loaders/fork_loader.py | 12 + .../evm_tools/t8n/__init__.py | 38 +- .../evm_tools/t8n/t8n_types.py | 26 +- tox.ini | 16 - 26 files changed, 1170 insertions(+), 1544 deletions(-) create mode 100644 src/ethereum/state.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 
326cce02cf..d4980ab094 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -121,24 +121,6 @@ jobs: - name: Run json infra tests run: tox -e json_infra -- --file-list="${{ steps.get-changed-files.outputs.file_list }}" - optimized: - runs-on: [self-hosted-ghr, size-xl-x64] - needs: static - steps: - - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 - with: - submodules: recursive - fetch-depth: 0 # Fetch full history for commit comparison - - name: Setup Python - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 - with: - python-version: "3.11" - - uses: ./.github/actions/setup-env - - uses: ./.github/actions/get-changed-files - id: get-changed-files - - name: Run optimized tests - run: tox -e optimized -- --file-list="${{ steps.get-changed-files.outputs.file_list }}" - tests_pytest_py3: runs-on: [self-hosted-ghr, size-xl-x64] needs: static diff --git a/src/ethereum/forks/amsterdam/block_access_lists/__init__.py b/src/ethereum/forks/amsterdam/block_access_lists/__init__.py index 33681d3145..338c290d1c 100644 --- a/src/ethereum/forks/amsterdam/block_access_lists/__init__.py +++ b/src/ethereum/forks/amsterdam/block_access_lists/__init__.py @@ -11,6 +11,7 @@ add_storage_write, add_touched_account, build_block_access_list, + update_builder_from_tx, ) from .rlp_utils import compute_block_access_list_hash @@ -24,4 +25,5 @@ "add_touched_account", "build_block_access_list", "compute_block_access_list_hash", + "update_builder_from_tx", ] diff --git a/src/ethereum/forks/amsterdam/block_access_lists/builder.py b/src/ethereum/forks/amsterdam/block_access_lists/builder.py index 3b5d446e6b..a8b32f41f5 100644 --- a/src/ethereum/forks/amsterdam/block_access_lists/builder.py +++ b/src/ethereum/forks/amsterdam/block_access_lists/builder.py @@ -14,12 +14,14 @@ """ # noqa: E501 from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Set +from typing import Dict, List, Optional, Set -from 
ethereum_types.bytes import Bytes -from ethereum_types.numeric import U64, U256 +from ethereum_types.bytes import Bytes, Bytes32 +from ethereum_types.numeric import U64, U256, Uint -from ..fork_types import Address +from ethereum.state import Account, Address, PreState + +from ..state_tracker import BlockState, TransactionState from .rlp_types import ( AccountChanges, BalanceChange, @@ -31,9 +33,6 @@ StorageChange, ) -if TYPE_CHECKING: - from ..state_tracker import StateChanges - @dataclass class AccountData: @@ -89,6 +88,15 @@ class BlockAccessListBuilder: [`BlockAccessList`]: ref:ethereum.forks.amsterdam.block_access_lists.rlp_types.BlockAccessList """ # noqa: E501 + block_access_index: BlockAccessIndex = BlockAccessIndex(0) + """ + Current block access index. Set by the caller before each + [`incorporate_tx_into_block`] call (0 for system txs, i+1 for the + i-th user tx, N+1 for post-execution operations). + + [`incorporate_tx_into_block`]: ref:ethereum.forks.amsterdam.state_tracker.incorporate_tx_into_block + """ # noqa: E501 + accounts: Dict[Address, AccountData] = field(default_factory=dict) """ Mapping from account address to its tracked changes during block execution. @@ -104,7 +112,6 @@ def ensure_account(builder: BlockAccessListBuilder, address: Address) -> None: multiple times for the same address. [`AccountData`]: ref:ethereum.forks.amsterdam.block_access_lists.builder.AccountData - """ # noqa: E501 if address not in builder.accounts: builder.accounts[address] = AccountData() @@ -156,8 +163,6 @@ def add_storage_read( Records that a storage slot was read during execution. Storage slots that are both read and written will only appear in the storage changes list, not in the storage reads list, as per [EIP-7928]. 
- - [EIP-7928]: https://eips.ethereum.org/EIPS/eip-7928 """ ensure_account(builder, address) builder.accounts[address].storage_reads.add(slot) @@ -358,58 +363,111 @@ def _build_from_builder( return block_access_list +def _get_pre_tx_account( + pre_tx_accounts: Dict[Address, Optional[Account]], + pre_state: PreState, + address: Address, +) -> Optional[Account]: + """ + Look up an account in cumulative state, falling back to `pre_state`. + + The cumulative account state (`pre_tx_accounts`) should contain state up + to (but not including) the current transaction. + + Returns `None` if the `address` does not exist. + """ + if address in pre_tx_accounts: + return pre_tx_accounts[address] + return pre_state.get_account_optional(address) + + +def _get_pre_tx_storage( + pre_tx_storage: Dict[Address, Dict[Bytes32, U256]], + pre_state: PreState, + address: Address, + key: Bytes32, +) -> U256: + """ + Look up a storage value in cumulative state, falling back to `pre_state`. + + Returns `0` if not set. + """ + if address in pre_tx_storage and key in pre_tx_storage[address]: + return pre_tx_storage[address][key] + return pre_state.get_storage(address, key) + + +def update_builder_from_tx( + builder: BlockAccessListBuilder, + tx_state: TransactionState, +) -> None: + """ + Update the BAL builder with changes from a single transaction. + + Compare the transaction's writes against the block's cumulative + state (falling back to `pre_state`) to extract balance, nonce, code, and + storage changes. Net-zero filtering is automatic: if the pre-tx value + equals the post-tx value, no change is recorded. + + Must be called **before** the transaction's writes are merged into + the block state. 
+ """ + block_state = tx_state.parent + pre_state = block_state.pre_state + idx = builder.block_access_index + + # Compare account writes against block cumulative state + for address, post_account in tx_state.account_writes.items(): + pre_account = _get_pre_tx_account( + block_state.account_writes, pre_state, address + ) + + pre_balance = pre_account.balance if pre_account else U256(0) + post_balance = post_account.balance if post_account else U256(0) + if pre_balance != post_balance: + add_balance_change(builder, address, idx, post_balance) + + pre_nonce = pre_account.nonce if pre_account else Uint(0) + post_nonce = post_account.nonce if post_account else Uint(0) + if pre_nonce != post_nonce: + add_nonce_change(builder, address, idx, U64(post_nonce)) + + pre_code = pre_account.code if pre_account else b"" + post_code = post_account.code if post_account else b"" + if pre_code != post_code: + add_code_change(builder, address, idx, post_code) + + # Compare storage writes against block cumulative state + for address, slots in tx_state.storage_writes.items(): + for key, post_value in slots.items(): + pre_value = _get_pre_tx_storage( + block_state.storage_writes, pre_state, address, key + ) + if pre_value != post_value: + # Convert slot from internal Bytes32 format to U256 for BAL. + # EIP-7928 uses U256 as it's more space-efficient in RLP. + u256_slot = U256.from_be_bytes(key) + add_storage_write(builder, address, u256_slot, idx, post_value) + + def build_block_access_list( - state_changes: "StateChanges", + builder: BlockAccessListBuilder, + block_state: BlockState, ) -> BlockAccessList: """ - Build a [`BlockAccessList`] from a StateChanges frame. + Build a [`BlockAccessList`] from the builder and block state. - Converts the accumulated state changes from the frame-based architecture - into the final deterministic BlockAccessList format. + Feed accumulated reads from the block state into the builder, then produce + the final sorted and encoded block access list. 
[`BlockAccessList`]: ref:ethereum.forks.amsterdam.block_access_lists.rlp_types.BlockAccessList - [`StateChanges`]: ref:ethereum.forks.amsterdam.state_tracker.StateChanges """ # noqa: E501 - builder = BlockAccessListBuilder() + # Add storage reads (convert Bytes32 to U256 for BAL encoding) + for address, slot in block_state.storage_reads: + add_storage_read(builder, address, U256.from_be_bytes(slot)) - # Add all touched addresses - for address in state_changes.touched_addresses: + # Add touched addresses + for address in block_state.account_reads: add_touched_account(builder, address) - # Add all storage reads - for address, slot in state_changes.storage_reads: - add_storage_read(builder, address, U256(int.from_bytes(slot))) - - # Add all storage writes - # Net-zero filtering happens at transaction commit time, not here. - # At block level, we track ALL writes at their respective indices. - for ( - address, - slot, - block_access_index, - ), value in state_changes.storage_writes.items(): - u256_slot = U256(int.from_bytes(slot)) - add_storage_write( - builder, address, u256_slot, block_access_index, value - ) - - # Add all balance changes (balance_changes is keyed by (address, index)) - for ( - address, - block_access_index, - ), new_balance in state_changes.balance_changes.items(): - add_balance_change(builder, address, block_access_index, new_balance) - - # Add all nonce changes - for address, block_access_index, new_nonce in state_changes.nonce_changes: - add_nonce_change(builder, address, block_access_index, new_nonce) - - # Add all code changes - # Filtering happens at transaction level in eoa_delegation.py - for ( - address, - block_access_index, - ), new_code in state_changes.code_changes.items(): - add_code_change(builder, address, block_access_index, new_code) - return _build_from_builder(builder) diff --git a/src/ethereum/forks/amsterdam/block_access_lists/rlp_types.py b/src/ethereum/forks/amsterdam/block_access_lists/rlp_types.py index 
d2518f1400..17e812bf78 100644 --- a/src/ethereum/forks/amsterdam/block_access_lists/rlp_types.py +++ b/src/ethereum/forks/amsterdam/block_access_lists/rlp_types.py @@ -14,7 +14,7 @@ from ethereum_types.frozen import slotted_freezable from ethereum_types.numeric import U16, U64, U256 -from ..fork_types import Address +from ethereum.state import Address # Type aliases for clarity (matching EIP-7928 specification) StorageKey: TypeAlias = U256 @@ -33,7 +33,7 @@ class StorageChange: storage slot. [slot]: ref:ethereum.forks.amsterdam.block_access_lists.rlp_types.SlotChanges - [`Account`]: ref:ethereum.forks.amsterdam.fork_types.Account + [`Account`]: ref:ethereum.state.Account """ # noqa: E501 block_access_index: BlockAccessIndex @@ -48,7 +48,7 @@ class BalanceChange: balance. [bal]: ref:ethereum.forks.amsterdam.block_access_lists.rlp_types.BlockAccessList - [`Account`]: ref:ethereum.forks.amsterdam.fork_types.Account + [`Account`]: ref:ethereum.state.Account """ # noqa: E501 block_access_index: BlockAccessIndex @@ -63,7 +63,7 @@ class NonceChange: nonce. [bal]: ref:ethereum.forks.amsterdam.block_access_lists.rlp_types.BlockAccessList - [`Account`]: ref:ethereum.forks.amsterdam.fork_types.Account + [`Account`]: ref:ethereum.state.Account """ # noqa: E501 block_access_index: BlockAccessIndex @@ -78,7 +78,7 @@ class CodeChange: code. [bal]: ref:ethereum.forks.amsterdam.block_access_lists.rlp_types.BlockAccessList - [`Account`]: ref:ethereum.forks.amsterdam.fork_types.Account + [`Account`]: ref:ethereum.state.Account """ # noqa: E501 block_access_index: BlockAccessIndex @@ -93,7 +93,7 @@ class SlotChanges: storage. [bal]: ref:ethereum.forks.amsterdam.block_access_lists.rlp_types.BlockAccessList - [`Account`]: ref:ethereum.forks.amsterdam.fork_types.Account + [`Account`]: ref:ethereum.state.Account """ # noqa: E501 slot: StorageKey @@ -106,7 +106,7 @@ class AccountChanges: """ All changes for a single [`Account`], grouped by field type. 
- [`Account`]: ref:ethereum.forks.amsterdam.fork_types.Account + [`Account`]: ref:ethereum.state.Account """ address: Address diff --git a/src/ethereum/forks/amsterdam/blocks.py b/src/ethereum/forks/amsterdam/blocks.py index ef44549d28..07186e3411 100644 --- a/src/ethereum/forks/amsterdam/blocks.py +++ b/src/ethereum/forks/amsterdam/blocks.py @@ -18,8 +18,9 @@ from ethereum_types.numeric import U64, U256, Uint from ethereum.crypto.hash import Hash32 +from ethereum.state import Address, Root -from .fork_types import Address, Bloom, Root +from .fork_types import Bloom from .transactions import ( AccessListTransaction, BlobTransaction, @@ -106,13 +107,14 @@ class Header: Root hash ([`keccak256`]) of the state trie after executing all transactions in this block. It represents the state of the Ethereum Virtual Machine (EVM) after all transactions in this block have been processed. It - is computed using the [`state_root()`] function, which computes the root - of the Merkle-Patricia [Trie] representing the Ethereum world state. + is computed using [`compute_state_root_and_trie_changes()`][changes], + which computes the root of the Merkle-Patricia [Trie] representing the + Ethereum world state after applying the block's state changes. [`keccak256`]: ref:ethereum.crypto.hash.keccak256 - [`state_root()`]: ref:ethereum.forks.amsterdam.state.state_root + [changes]: ref:ethereum.forks.amsterdam.state.State.compute_state_root_and_trie_changes [Trie]: ref:ethereum.forks.amsterdam.trie.Trie - """ + """ # noqa: E501 transactions_root: Root """ diff --git a/src/ethereum/forks/amsterdam/fork.py b/src/ethereum/forks/amsterdam/fork.py index 3e45c3e953..3bafb85beb 100644 --- a/src/ethereum/forks/amsterdam/fork.py +++ b/src/ethereum/forks/amsterdam/fork.py @@ -27,9 +27,14 @@ InvalidSenderError, NonceMismatchError, ) +from ethereum.state import Address from . 
import vm -from .block_access_lists.builder import build_block_access_list +from .block_access_lists.builder import ( + BlockAccessListBuilder, + build_block_access_list, +) +from .block_access_lists.rlp_types import BlockAccessIndex from .block_access_lists.rlp_utils import compute_block_access_list_hash from .blocks import Block, Header, Log, Receipt, Withdrawal, encode_receipt from .bloom import logs_bloom @@ -44,7 +49,7 @@ PriorityFeeGreaterThanMaxFeeError, TransactionTypeContractCreationError, ) -from .fork_types import Account, Address, Authorization, VersionedHash +from .fork_types import Authorization, VersionedHash from .requests import ( CONSOLIDATION_REQUEST_TYPE, DEPOSIT_REQUEST_TYPE, @@ -54,26 +59,19 @@ ) from .state import ( State, - TransientStorage, + apply_changes_to_state, +) +from .state_tracker import ( + BlockState, + TransactionState, account_exists_and_is_empty, destroy_account, + extract_block_diffs, get_account, + incorporate_tx_into_block, increment_nonce, - modify_state, set_account_balance, - state_root, -) -from .state_tracker import ( - StateChanges, - capture_pre_balance, - commit_transaction_frame, - create_child_frame, - filter_net_zero_frame_changes, - increment_block_access_index, track_address, - track_balance_change, - track_nonce_change, - track_selfdestruct, ) from .transactions import ( AccessListTransaction, @@ -115,6 +113,7 @@ SYSTEM_TRANSACTION_GAS = Uint(30000000) MAX_BLOB_GAS_PER_BLOCK = BLOB_SCHEDULE_MAX * GAS_PER_BLOB VERSIONED_HASH_VERSION_KZG = b"\x01" +GWEI_TO_WEI = U256(10**9) WITHDRAWAL_REQUEST_PREDEPLOY_ADDRESS = hex_to_address( "0x00000961Ef480Eb55e80D19ad83579A64c007002" @@ -236,9 +235,11 @@ def state_transition(chain: BlockChain, block: Block) -> None: if block.ommers != (): raise InvalidBlock + block_state = BlockState(pre_state=chain.state) + block_env = vm.BlockEnvironment( chain_id=chain.chain_id, - state=chain.state, + state=block_state, block_gas_limit=block.header.gas_limit, 
block_hashes=get_last_256_block_hashes(chain), coinbase=block.header.coinbase, @@ -248,7 +249,7 @@ def state_transition(chain: BlockChain, block: Block) -> None: prev_randao=block.header.prev_randao, excess_blob_gas=block.header.excess_blob_gas, parent_beacon_block_root=block.header.parent_beacon_block_root, - state_changes=StateChanges(), + block_access_list_builder=BlockAccessListBuilder(), ) block_output = apply_body( @@ -256,7 +257,11 @@ def state_transition(chain: BlockChain, block: Block) -> None: transactions=block.transactions, withdrawals=block.withdrawals, ) - block_state_root = state_root(block_env.state) + account_changes, storage_changes = extract_block_diffs(block_state) + block_state_root, _ = chain.state.compute_state_root_and_trie_changes( + account_changes, storage_changes + ) + apply_changes_to_state(chain.state, account_changes, storage_changes) transactions_root = root(block_output.transactions_trie) receipt_root = root(block_output.receipts_trie) block_logs_bloom = logs_bloom(block_output.block_logs) @@ -418,6 +423,7 @@ def check_transaction( block_env: vm.BlockEnvironment, block_output: vm.BlockOutput, tx: Transaction, + tx_state: TransactionState, ) -> Tuple[Address, Uint, Tuple[VersionedHash, ...], U64]: """ Check if the transaction is includable in the block. @@ -430,6 +436,8 @@ def check_transaction( The block output for the current block. tx : The transaction. + tx_state : + The transaction state tracker. Returns ------- @@ -488,7 +496,7 @@ def check_transaction( raise BlobGasLimitExceededError("blob gas limit exceeded") sender_address = recover_sender(block_env.chain_id, tx) - sender_account = get_account(block_env.state, sender_address) + sender_account = get_account(tx_state, sender_address) if isinstance( tx, (FeeMarketTransaction, BlobTransaction, SetCodeTransaction) @@ -634,9 +642,7 @@ def process_system_transaction( Output of processing the system transaction. 
""" - # EIP-7928: Create a child frame for system transaction - # This allows proper pre-state capture for net-zero filtering - system_tx_state_changes = create_child_frame(block_env.state_changes) + system_tx_state = TransactionState(parent=block_env.state) tx_env = vm.TransactionEnvironment( origin=SYSTEM_ADDRESS, @@ -644,17 +650,13 @@ def process_system_transaction( gas=SYSTEM_TRANSACTION_GAS, access_list_addresses=set(), access_list_storage_keys=set(), - transient_storage=TransientStorage(), + state=system_tx_state, blob_versioned_hashes=(), authorizations=(), index_in_block=None, tx_hash=None, - state_changes=system_tx_state_changes, ) - # Create call frame as child of tx frame - call_frame = create_child_frame(tx_env.state_changes) - system_tx_message = Message( block_env=block_env, tx_env=tx_env, @@ -674,14 +676,13 @@ def process_system_transaction( disable_precompiles=False, parent_evm=None, is_create=False, - state_changes=call_frame, ) system_tx_output = process_message_call(system_tx_message) - # Commit system transaction changes to block frame - # System transactions always succeed (or block is invalid) - commit_transaction_frame(tx_env.state_changes) + incorporate_tx_into_block( + system_tx_state, block_env.block_access_list_builder + ) return system_tx_output @@ -710,7 +711,8 @@ def process_checked_system_transaction( Output of processing the system transaction. """ - system_contract_code = get_account(block_env.state, target_address).code + system_tx_state = TransactionState(parent=block_env.state) + system_contract_code = get_account(system_tx_state, target_address).code if len(system_contract_code) == 0: raise InvalidBlock( @@ -758,7 +760,8 @@ def process_unchecked_system_transaction( Output of processing the system transaction. 
""" - system_contract_code = get_account(block_env.state, target_address).code + system_tx_state = TransactionState(parent=block_env.state) + system_contract_code = get_account(system_tx_state, target_address).code return process_system_transaction( block_env, target_address, @@ -799,10 +802,6 @@ def apply_body( """ block_output = vm.BlockOutput() - # EIP-7928: System contracts use block_access_index 0 - # The block frame already starts at index 0, so system transactions - # naturally use that index through the block frame - process_unchecked_system_transaction( block_env=block_env, target_address=BEACON_ROOTS_ADDRESS, @@ -818,10 +817,10 @@ def apply_body( for i, tx in enumerate(map(decode_transaction, transactions)): process_transaction(block_env, block_output, tx, Uint(i)) - # EIP-7928: Increment block frame to post-execution index - # After N transactions, block frame is at index N - # Post-execution operations (withdrawals, etc.) use index N+1 - increment_block_access_index(block_env.state_changes) + # EIP-7928: Post-execution operations use index N+1 + block_env.block_access_list_builder.block_access_index = BlockAccessIndex( + Uint(len(transactions)) + Uint(1) + ) process_withdrawals(block_env, block_output, withdrawals) @@ -829,9 +828,9 @@ def apply_body( block_env=block_env, block_output=block_output, ) - # Build block access list from block_env.state_changes + block_output.block_access_list = build_block_access_list( - block_env.state_changes + block_env.block_access_list_builder, block_env.state ) return block_output @@ -912,19 +911,12 @@ def process_transaction( Index of the transaction in the block. 
""" - # EIP-7928: Create a transaction-level StateChanges frame - # The frame will read the current block_access_index from the block frame - increment_block_access_index(block_env.state_changes) - tx_state_changes = create_child_frame(block_env.state_changes) - - # Capture coinbase pre-balance for net-zero filtering - coinbase_pre_balance = get_account( - block_env.state, block_env.coinbase - ).balance - track_address(tx_state_changes, block_env.coinbase) - capture_pre_balance( - tx_state_changes, block_env.coinbase, coinbase_pre_balance + block_env.block_access_list_builder.block_access_index = BlockAccessIndex( + index + Uint(1) ) + tx_state = TransactionState(parent=block_env.state) + + track_address(tx_state, block_env.coinbase) trie_set( block_output.transactions_trie, @@ -943,9 +935,10 @@ def process_transaction( block_env=block_env, block_output=block_output, tx=tx, + tx_state=tx_state, ) - sender_account = get_account(block_env.state, sender) + sender_account = get_account(tx_state, sender) if isinstance(tx, BlobTransaction): blob_gas_fee = calculate_data_fee(block_env.excess_blob_gas, tx) @@ -956,27 +949,13 @@ def process_transaction( gas = tx.gas - intrinsic_gas - # Track sender nonce increment - increment_nonce(block_env.state, sender) - sender_nonce_after = get_account(block_env.state, sender).nonce - track_nonce_change(tx_state_changes, sender, U64(sender_nonce_after)) - - # Track sender balance deduction for gas fee - sender_balance_before = get_account(block_env.state, sender).balance - track_address(tx_state_changes, sender) - capture_pre_balance(tx_state_changes, sender, sender_balance_before) + increment_nonce(tx_state, sender) + track_address(tx_state, sender) sender_balance_after_gas_fee = ( Uint(sender_account.balance) - effective_gas_fee - blob_gas_fee ) - set_account_balance( - block_env.state, sender, U256(sender_balance_after_gas_fee) - ) - track_balance_change( - tx_state_changes, - sender, - U256(sender_balance_after_gas_fee), - ) + 
set_account_balance(tx_state, sender, U256(sender_balance_after_gas_fee)) access_list_addresses = set() access_list_storage_keys = set() @@ -1005,12 +984,11 @@ def process_transaction( gas=gas, access_list_addresses=access_list_addresses, access_list_storage_keys=access_list_storage_keys, - transient_storage=TransientStorage(), + state=tx_state, blob_versioned_hashes=blob_versioned_hashes, authorizations=authorizations, index_in_block=index, tx_hash=get_transaction_hash(encode_transaction(tx)), - state_changes=tx_state_changes, ) message = prepare_message( @@ -1043,33 +1021,23 @@ def process_transaction( transaction_fee = tx_gas_used_after_refund * priority_fee_per_gas # refund gas - sender_balance_after_refund = get_account( - block_env.state, sender - ).balance + U256(gas_refund_amount) - set_account_balance(block_env.state, sender, sender_balance_after_refund) - track_balance_change( - tx_env.state_changes, - sender, - sender_balance_after_refund, + sender_balance_after_refund = get_account(tx_state, sender).balance + U256( + gas_refund_amount ) + set_account_balance(tx_state, sender, sender_balance_after_refund) coinbase_balance_after_mining_fee = get_account( - block_env.state, block_env.coinbase + tx_state, block_env.coinbase ).balance + U256(transaction_fee) set_account_balance( - block_env.state, block_env.coinbase, coinbase_balance_after_mining_fee - ) - track_balance_change( - tx_env.state_changes, - block_env.coinbase, - coinbase_balance_after_mining_fee, + tx_state, block_env.coinbase, coinbase_balance_after_mining_fee ) if coinbase_balance_after_mining_fee == 0 and account_exists_and_is_empty( - block_env.state, block_env.coinbase + tx_state, block_env.coinbase ): - destroy_account(block_env.state, block_env.coinbase) + destroy_account(tx_state, block_env.coinbase) block_output.block_gas_used += tx_gas_used_after_refund block_output.blob_gas_used += tx_blob_gas_used @@ -1090,12 +1058,9 @@ def process_transaction( block_output.block_logs += 
tx_output.logs for address in tx_output.accounts_to_delete: - destroy_account(block_env.state, address) - track_selfdestruct(tx_env.state_changes, address) + destroy_account(tx_state, address) - # EIP-7928: Commit transaction frame (includes net-zero filtering). - # Must happen AFTER destroy_account so filtering sees correct state. - commit_transaction_frame(tx_env.state_changes) + incorporate_tx_into_block(tx_state, block_env.block_access_list_builder) def process_withdrawals( @@ -1106,15 +1071,7 @@ def process_withdrawals( """ Increase the balance of the withdrawing account. """ - # Capture pre-state for withdrawal balance filtering - withdrawal_addresses = {wd.address for wd in withdrawals} - for address in withdrawal_addresses: - pre_balance = get_account(block_env.state, address).balance - track_address(block_env.state_changes, address) - capture_pre_balance(block_env.state_changes, address, pre_balance) - - def increase_recipient_balance(recipient: Account) -> None: - recipient.balance += wd.amount * U256(10**9) + wd_state = TransactionState(parent=block_env.state) for i, wd in enumerate(withdrawals): trie_set( @@ -1123,20 +1080,12 @@ def increase_recipient_balance(recipient: Account) -> None: rlp.encode(wd), ) - modify_state(block_env.state, wd.address, increase_recipient_balance) - - new_balance = get_account(block_env.state, wd.address).balance - track_balance_change( - block_env.state_changes, - wd.address, - new_balance, - ) - - if account_exists_and_is_empty(block_env.state, wd.address): - destroy_account(block_env.state, wd.address) + track_address(wd_state, wd.address) + current_balance = get_account(wd_state, wd.address).balance + new_balance = current_balance + wd.amount * GWEI_TO_WEI + set_account_balance(wd_state, wd.address, new_balance) - # EIP-7928: Filter net-zero balance changes for withdrawals - filter_net_zero_frame_changes(block_env.state_changes) + incorporate_tx_into_block(wd_state, block_env.block_access_list_builder) def 
check_gas_limit(gas_limit: Uint, parent_gas_limit: Uint) -> bool: diff --git a/src/ethereum/forks/amsterdam/fork_types.py b/src/ethereum/forks/amsterdam/fork_types.py index 24fbf08baf..a88a2aae28 100644 --- a/src/ethereum/forks/amsterdam/fork_types.py +++ b/src/ethereum/forks/amsterdam/fork_types.py @@ -14,38 +14,18 @@ from dataclasses import dataclass from ethereum_rlp import rlp -from ethereum_types.bytes import Bytes, Bytes20, Bytes256 +from ethereum_types.bytes import Bytes, Bytes256 from ethereum_types.frozen import slotted_freezable -from ethereum_types.numeric import U8, U64, U256, Uint +from ethereum_types.numeric import U8, U64, U256 from ethereum.crypto.hash import Hash32, keccak256 +from ethereum.state import Account, Address -Address = Bytes20 -Root = Hash32 VersionedHash = Hash32 Bloom = Bytes256 -@slotted_freezable -@dataclass -class Account: - """ - State associated with an address. - """ - - nonce: Uint - balance: U256 - code: Bytes - - -EMPTY_ACCOUNT = Account( - nonce=Uint(0), - balance=U256(0), - code=b"", -) - - def encode_account(raw_account_data: Account, storage_root: Bytes) -> Bytes: """ Encode `Account` dataclass. 
diff --git a/src/ethereum/forks/amsterdam/state.py b/src/ethereum/forks/amsterdam/state.py index 2d04c39834..d6fef3be37 100644 --- a/src/ethereum/forks/amsterdam/state.py +++ b/src/ethereum/forks/amsterdam/state.py @@ -17,17 +17,14 @@ """ from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Tuple -from ethereum_types.bytes import Bytes, Bytes32 -from ethereum_types.frozen import modify -from ethereum_types.numeric import U256, Uint +from ethereum_types.bytes import Bytes32 +from ethereum_types.numeric import U256 -from .fork_types import EMPTY_ACCOUNT, Account, Address, Root -from .trie import EMPTY_TRIE_ROOT, Trie, copy_trie, root, trie_get, trie_set +from ethereum.state import Account, Address, InternalNode, Root -if TYPE_CHECKING: - from .vm import BlockEnvironment # noqa: F401 +from .trie import EMPTY_TRIE_ROOT, Trie, copy_trie, root, trie_get, trie_set @dataclass @@ -42,286 +39,141 @@ class State: _storage_tries: Dict[Address, Trie[Bytes32, U256]] = field( default_factory=dict ) - _snapshots: List[ - Tuple[ - Trie[Address, Optional[Account]], - Dict[Address, Trie[Bytes32, U256]], - ] - ] = field(default_factory=list) - created_accounts: Set[Address] = field(default_factory=set) + def get_account_optional(self, address: Address) -> Optional[Account]: + """ + Get the account at an address. -@dataclass -class TransientStorage: - """ - Contains all information that is preserved between message calls - within a transaction. - """ + Return ``None`` if there is no account at the address. + """ + return trie_get(self._main_trie, address) - _tries: Dict[Address, Trie[Bytes32, U256]] = field(default_factory=dict) - _snapshots: List[Dict[Address, Trie[Bytes32, U256]]] = field( - default_factory=list - ) + def get_storage(self, address: Address, key: Bytes32) -> U256: + """ + Get a storage value. + Return ``U256(0)`` if the key has not been set. 
+ """ + trie = self._storage_tries.get(address) + if trie is None: + return U256(0) -def close_state(state: State) -> None: - """ - Free resources held by the state. Used by optimized implementations to - release file descriptors. - """ - del state._main_trie - del state._storage_tries - del state._snapshots - del state.created_accounts + value = trie_get(trie, key) + assert isinstance(value, U256) + return value -def begin_transaction( - state: State, transient_storage: TransientStorage -) -> None: - """ - Start a state transaction. + def account_has_storage(self, address: Address) -> bool: + """ + Check whether an account has any storage. - Transactions are entirely implicit and can be nested. It is not possible to - calculate the state root during a transaction. + Only needed for EIP-7610. + """ + return address in self._storage_tries - Parameters - ---------- - state : State - The state. - transient_storage : TransientStorage - The transient storage of the transaction. + def compute_state_root_and_trie_changes( + self, + account_changes: Dict[Address, Optional[Account]], + storage_changes: Dict[Address, Dict[Bytes32, U256]], + ) -> Tuple[Root, List[InternalNode]]: + """ + Compute the state root after applying changes to the pre-state. - """ - state._snapshots.append( - ( - copy_trie(state._main_trie), - {k: copy_trie(t) for (k, t) in state._storage_tries.items()}, - ) - ) - transient_storage._snapshots.append( - {k: copy_trie(t) for (k, t) in transient_storage._tries.items()} - ) + Return the new state root together with the internal trie nodes + that were created or modified. + """ + main_trie = copy_trie(self._main_trie) + storage_tries = { + k: copy_trie(v) for k, v in self._storage_tries.items() + } + for address, account in account_changes.items(): + trie_set(main_trie, address, account) -def commit_transaction( - state: State, transient_storage: TransientStorage -) -> None: - """ - Commit a state transaction. 
+ for address, slots in storage_changes.items(): + trie = storage_tries.get(address) + if trie is None: + trie = Trie(secured=True, default=U256(0)) + storage_tries[address] = trie + for key, value in slots.items(): + trie_set(trie, key, value) + if trie._data == {}: + del storage_tries[address] - Parameters - ---------- - state : State - The state. - transient_storage : TransientStorage - The transient storage of the transaction. + def get_storage_root(addr: Address) -> Root: + if addr in storage_tries: + return root(storage_tries[addr]) + return EMPTY_TRIE_ROOT - """ - state._snapshots.pop() - if not state._snapshots: - state.created_accounts.clear() + state_root_value = root(main_trie, get_storage_root=get_storage_root) - transient_storage._snapshots.pop() + return state_root_value, [] -def rollback_transaction( - state: State, transient_storage: TransientStorage -) -> None: +def close_state(state: State) -> None: """ - Rollback a state transaction, resetting the state to the point when the - corresponding `begin_transaction()` call was made. - - Parameters - ---------- - state : State - The state. - transient_storage : TransientStorage - The transient storage of the transaction. - + Free resources held by the state. Used by optimized implementations to + release file descriptors. """ - state._main_trie, state._storage_tries = state._snapshots.pop() - if not state._snapshots: - state.created_accounts.clear() - - transient_storage._tries = transient_storage._snapshots.pop() + del state._main_trie + del state._storage_tries -def get_account(state: State, address: Address) -> Account: +def apply_changes_to_state( + state: State, + account_changes: Dict[Address, Optional[Account]], + storage_changes: Dict[Address, Dict[Bytes32, U256]], +) -> None: """ - Get the `Account` object at an address. Returns `EMPTY_ACCOUNT` if there - is no account at the address. 
- - Use `get_account_optional()` if you care about the difference between a - non-existent account and `EMPTY_ACCOUNT`. + Apply block-level diffs to the ``State`` for the next block. Parameters ---------- - state: `State` - The state - address : `Address` - Address to lookup. - - Returns - ------- - account : `Account` - Account at address. + state : + The state to update. + account_changes : + Account changes to apply. + storage_changes : + Storage changes to apply. """ - account = get_account_optional(state, address) - if isinstance(account, Account): - return account - else: - return EMPTY_ACCOUNT + for address, account in account_changes.items(): + trie_set(state._main_trie, address, account) - -def get_account_optional(state: State, address: Address) -> Optional[Account]: - """ - Get the `Account` object at an address. Returns `None` (rather than - `EMPTY_ACCOUNT`) if there is no account at the address. - - Parameters - ---------- - state: `State` - The state - address : `Address` - Address to lookup. - - Returns - ------- - account : `Account` - Account at address. - - """ - account = trie_get(state._main_trie, address) - return account + for address, slots in storage_changes.items(): + trie = state._storage_tries.get(address) + if trie is None: + trie = Trie(secured=True, default=U256(0)) + state._storage_tries[address] = trie + for key, value in slots.items(): + trie_set(trie, key, value) + if trie._data == {}: + del state._storage_tries[address] def set_account( - state: State, address: Address, account: Optional[Account] + state: State, + address: Address, + account: Optional[Account], ) -> None: """ - Set the `Account` object at an address. Setting to `None` deletes - the account (but not its storage, see `destroy_account()`). - - Parameters - ---------- - state: `State` - The state - address : `Address` - Address to set. - account : `Account` - Account to set at address. + Set an account in a ``State``. + Setting to ``None`` deletes the account. 
""" trie_set(state._main_trie, address, account) -def destroy_account(state: State, address: Address) -> None: - """ - Completely remove the account at `address` and all of its storage. - - This function is made available exclusively for the `SELFDESTRUCT` - opcode. It is expected that `SELFDESTRUCT` will be disabled in a future - hardfork and this function will be removed. - - Parameters - ---------- - state: `State` - The state - address : `Address` - Address of account to destroy. - - """ - destroy_storage(state, address) - set_account(state, address, None) - - -def destroy_storage(state: State, address: Address) -> None: - """ - Completely remove the storage at `address`. - - Parameters - ---------- - state: `State` - The state - address : `Address` - Address of account whose storage is to be deleted. - - """ - if address in state._storage_tries: - del state._storage_tries[address] - - -def mark_account_created(state: State, address: Address) -> None: - """ - Mark an account as having been created in the current transaction. - This information is used by `get_storage_original()` to handle an obscure - edgecase, and to respect the constraints added to SELFDESTRUCT by - EIP-6780. - - The marker is not removed even if the account creation reverts. Since the - account cannot have had code prior to its creation and can't call - `get_storage_original()`, this is harmless. - - Parameters - ---------- - state: `State` - The state - address : `Address` - Address of the account that has been created. - - """ - state.created_accounts.add(address) - - -def get_storage(state: State, address: Address, key: Bytes32) -> U256: - """ - Get a value at a storage key on an account. Returns `U256(0)` if the - storage key has not been set previously. - - Parameters - ---------- - state: `State` - The state - address : `Address` - Address of the account. - key : `Bytes` - Key to lookup. - - Returns - ------- - value : `U256` - Value at the key. 
- - """ - trie = state._storage_tries.get(address) - if trie is None: - return U256(0) - - value = trie_get(trie, key) - - assert isinstance(value, U256) - return value - - def set_storage( - state: State, address: Address, key: Bytes32, value: U256 + state: State, + address: Address, + key: Bytes32, + value: U256, ) -> None: """ - Set a value at a storage key on an account. Setting to `U256(0)` deletes - the key. - - Parameters - ---------- - state: `State` - The state - address : `Address` - Address of the account. - key : `Bytes` - Key to set. - value : `U256` - Value to set at the key. + Set a storage value in a ``State``. + Setting to ``U256(0)`` deletes the key. """ assert trie_get(state._main_trie, address) is not None @@ -334,368 +186,9 @@ def set_storage( del state._storage_tries[address] -def storage_root(state: State, address: Address) -> Root: - """ - Calculate the storage root of an account. - - Parameters - ---------- - state: - The state - address : - Address of the account. - - Returns - ------- - root : `Root` - Storage root of the account. - - """ - assert not state._snapshots - if address in state._storage_tries: - return root(state._storage_tries[address]) - else: - return EMPTY_TRIE_ROOT - - def state_root(state: State) -> Root: """ - Calculate the state root. - - Parameters - ---------- - state: - The current state. - - Returns - ------- - root : `Root` - The state root. - - """ - assert not state._snapshots - - def get_storage_root(address: Address) -> Root: - return storage_root(state, address) - - return root(state._main_trie, get_storage_root=get_storage_root) - - -def account_exists(state: State, address: Address) -> bool: - """ - Checks if an account exists in the state trie. - - Parameters - ---------- - state: - The state - address: - Address of the account that needs to be checked. 
- - Returns - ------- - account_exists : `bool` - True if account exists in the state trie, False otherwise - - """ - return get_account_optional(state, address) is not None - - -def account_has_code_or_nonce(state: State, address: Address) -> bool: + Compute the state root of the current state. """ - Checks if an account has non-zero nonce or non-empty code. - - Parameters - ---------- - state: - The state - address: - Address of the account that needs to be checked. - - Returns - ------- - has_code_or_nonce : `bool` - True if the account has non-zero nonce or non-empty code, - False otherwise. - - """ - account = get_account(state, address) - return account.nonce != Uint(0) or account.code != b"" - - -def account_has_storage(state: State, address: Address) -> bool: - """ - Checks if an account has storage. - - Parameters - ---------- - state: - The state - address: - Address of the account that needs to be checked. - - Returns - ------- - has_storage : `bool` - True if the account has storage, False otherwise. - - """ - return address in state._storage_tries - - -def account_exists_and_is_empty(state: State, address: Address) -> bool: - """ - Checks if an account exists and has zero nonce, empty code and zero - balance. - - Parameters - ---------- - state: - The state - address: - Address of the account that needs to be checked. - - Returns - ------- - exists_and_is_empty : `bool` - True if an account exists and has zero nonce, empty code and zero - balance, False otherwise. - - """ - account = get_account_optional(state, address) - return ( - account is not None - and account.nonce == Uint(0) - and account.code == b"" - and account.balance == 0 - ) - - -def is_account_alive(state: State, address: Address) -> bool: - """ - Check whether an account is both in the state and non-empty. - - Parameters - ---------- - state: - The state - address: - Address of the account that needs to be checked. 
- - Returns - ------- - is_alive : `bool` - True if the account is alive. - - """ - account = get_account_optional(state, address) - return account is not None and account != EMPTY_ACCOUNT - - -def modify_state( - state: State, address: Address, f: Callable[[Account], None] -) -> None: - """ - Modify an `Account` in the `State`. If, after modification, the account - exists and has zero nonce, empty code, and zero balance, it is destroyed. - """ - set_account(state, address, modify(get_account(state, address), f)) - if account_exists_and_is_empty(state, address): - destroy_account(state, address) - - -def move_ether( - state: State, - sender_address: Address, - recipient_address: Address, - amount: U256, -) -> None: - """ - Move funds between accounts. - - Parameters - ---------- - state: - The current state. - sender_address: - Address of the sender. - recipient_address: - Address of the recipient. - amount: - The amount to transfer. - - """ - - def reduce_sender_balance(sender: Account) -> None: - if sender.balance < amount: - raise AssertionError - sender.balance -= amount - - def increase_recipient_balance(recipient: Account) -> None: - recipient.balance += amount - - modify_state(state, sender_address, reduce_sender_balance) - modify_state(state, recipient_address, increase_recipient_balance) - - -def set_account_balance(state: State, address: Address, amount: U256) -> None: - """ - Sets the balance of an account. - - Parameters - ---------- - state: - The current state. - - address: - Address of the account whose balance needs to be set. - - amount: - The amount that needs to be set in the balance. - - """ - - def set_balance(account: Account) -> None: - account.balance = amount - - modify_state(state, address, set_balance) - - -def increment_nonce(state: State, address: Address) -> None: - """ - Increments the nonce of an account. - - Parameters - ---------- - state: - The current state. 
- - address: - Address of the account whose nonce needs to be incremented. - - """ - - def increase_nonce(sender: Account) -> None: - sender.nonce += Uint(1) - - modify_state(state, address, increase_nonce) - - -def set_code(state: State, address: Address, code: Bytes) -> None: - """ - Sets Account code. - - Parameters - ---------- - state: - The current state. - - address: - Address of the account whose code needs to be updated. - - code: - The bytecode that needs to be set. - - """ - - def write_code(sender: Account) -> None: - sender.code = code - - modify_state(state, address, write_code) - - -def get_storage_original(state: State, address: Address, key: Bytes32) -> U256: - """ - Get the original value in a storage slot i.e. the value before the current - transaction began. This function reads the value from the snapshots taken - before executing the transaction. - - Parameters - ---------- - state: - The current state. - address: - Address of the account to read the value from. - key: - Key of the storage slot. - - """ - # In the transaction where an account is created, its preexisting storage - # is ignored. - if address in state.created_accounts: - return U256(0) - - _, original_trie = state._snapshots[0] - original_account_trie = original_trie.get(address) - - if original_account_trie is None: - original_value = U256(0) - else: - original_value = trie_get(original_account_trie, key) - - assert isinstance(original_value, U256) - - return original_value - - -def get_transient_storage( - transient_storage: TransientStorage, address: Address, key: Bytes32 -) -> U256: - """ - Get a value at a storage key on an account from transient storage. - Returns `U256(0)` if the storage key has not been set previously. - - Parameters - ---------- - transient_storage: `TransientStorage` - The transient storage - address : `Address` - Address of the account. - key : `Bytes` - Key to lookup. - - Returns - ------- - value : `U256` - Value at the key. 
- - """ - trie = transient_storage._tries.get(address) - if trie is None: - return U256(0) - - value = trie_get(trie, key) - - assert isinstance(value, U256) - return value - - -def set_transient_storage( - transient_storage: TransientStorage, - address: Address, - key: Bytes32, - value: U256, -) -> None: - """ - Set a value at a storage key on an account. Setting to `U256(0)` deletes - the key. - - Parameters - ---------- - transient_storage: `TransientStorage` - The transient storage - address : `Address` - Address of the account. - key : `Bytes` - Key to set. - value : `U256` - Value to set at the key. - - """ - trie = transient_storage._tries.get(address) - if trie is None: - trie = Trie(secured=True, default=U256(0)) - transient_storage._tries[address] = trie - trie_set(trie, key, value) - if trie._data == {}: - del transient_storage._tries[address] + root_value, _ = state.compute_state_root_and_trie_changes({}, {}) + return root_value diff --git a/src/ethereum/forks/amsterdam/state_tracker.py b/src/ethereum/forks/amsterdam/state_tracker.py index d0996a7b93..f6f12f2a1f 100644 --- a/src/ethereum/forks/amsterdam/state_tracker.py +++ b/src/ethereum/forks/amsterdam/state_tracker.py @@ -1,556 +1,760 @@ """ -EIP-7928 Block Access Lists: Hierarchical State Change Tracking. +State Tracking for Block Execution. -Frame hierarchy mirrors EVM execution: Block -> Transaction -> Call frames. -Each frame tracks state accesses and merges to parent on completion. +Track state changes on top of a read-only ``PreState``. At block end, +accumulated diffs feed into +``PreState.compute_state_root_and_trie_changes()``. -On success, changes merge upward with net-zero filtering (pre-state vs final). -On failure, only reads merge (writes discarded). Pre-state captures use -first-write-wins semantics and are stored at the transaction frame level. +.. 
contents:: Table of Contents + :backlinks: none + :local: -[EIP-7928]: https://eips.ethereum.org/EIPS/eip-7928 +Introduction +------------ + +Replace the mutable ``State`` class with lightweight state trackers that +record diffs. ``BlockState`` accumulates committed transaction +changes across a block. ``TransactionState`` tracks in-flight changes +within a single transaction and supports copy-on-write rollback. """ from dataclasses import dataclass, field -from typing import Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple from ethereum_types.bytes import Bytes, Bytes32 -from ethereum_types.numeric import U16, U64, U256 +from ethereum_types.frozen import modify +from ethereum_types.numeric import U256, Uint + +from ethereum.state import EMPTY_ACCOUNT, Account, Address, PreState -from .block_access_lists.rlp_types import BlockAccessIndex -from .fork_types import Address +if TYPE_CHECKING: + from .block_access_lists.builder import BlockAccessListBuilder @dataclass -class StateChanges: +class BlockState: """ - Tracks state changes within a single execution frame. + Accumulate committed transaction-level changes across a block. - Frames form a hierarchy (Block -> Transaction -> Call) linked by parent - references. The block_access_index is stored at the root frame. Pre-state - captures (pre_balances, etc.) are only populated at the transaction level. - """ + Read chain: block writes -> pre_state. - parent: Optional["StateChanges"] = None - block_access_index: BlockAccessIndex = BlockAccessIndex(0) + ``account_reads`` and ``storage_reads`` accumulate across all + transactions for BAL generation. 
+ """ - touched_addresses: Set[Address] = field(default_factory=set) + pre_state: PreState + account_reads: Set[Address] = field(default_factory=set) + account_writes: Dict[Address, Optional[Account]] = field( + default_factory=dict + ) storage_reads: Set[Tuple[Address, Bytes32]] = field(default_factory=set) - storage_writes: Dict[Tuple[Address, Bytes32, BlockAccessIndex], U256] = ( - field(default_factory=dict) + storage_writes: Dict[Address, Dict[Bytes32, U256]] = field( + default_factory=dict ) - balance_changes: Dict[Tuple[Address, BlockAccessIndex], U256] = field( + +@dataclass +class TransactionState: + """ + Track in-flight state changes within a single transaction. + + Read chain: tx writes -> block writes -> pre_state. + + ``storage_reads`` and ``account_reads`` are shared references + that survive rollback (reads from failed calls still appear in the + Block Access List). + """ + + parent: BlockState + account_reads: Set[Address] = field(default_factory=set) + account_writes: Dict[Address, Optional[Account]] = field( default_factory=dict ) - nonce_changes: Set[Tuple[Address, BlockAccessIndex, U64]] = field( - default_factory=set - ) - code_changes: Dict[Tuple[Address, BlockAccessIndex], Bytes] = field( + storage_reads: Set[Tuple[Address, Bytes32]] = field(default_factory=set) + storage_writes: Dict[Address, Dict[Bytes32, U256]] = field( default_factory=dict ) - - # Pre-state captures (transaction-scoped, only populated at tx frame) - pre_balances: Dict[Address, U256] = field(default_factory=dict) - pre_storage: Dict[Tuple[Address, Bytes32], U256] = field( + created_accounts: Set[Address] = field(default_factory=set) + transient_storage: Dict[Tuple[Address, Bytes32], U256] = field( default_factory=dict ) - pre_code: Dict[Address, Bytes] = field(default_factory=dict) -def get_block_frame(state_changes: StateChanges) -> StateChanges: +def get_account_optional( + tx_state: TransactionState, address: Address +) -> Optional[Account]: """ - Walk to the root 
(block-level) frame. + Get the ``Account`` object at an address. Return ``None`` (rather than + ``EMPTY_ACCOUNT``) if there is no account at the address. Parameters ---------- - state_changes : - Any frame in the hierarchy. + tx_state : + The transaction state. + address : + Address to look up. Returns ------- - block_frame : StateChanges - The root block-level frame. + account : ``Optional[Account]`` + Account at address. """ - block_frame = state_changes - while block_frame.parent is not None: - block_frame = block_frame.parent - return block_frame + if address in tx_state.account_writes: + return tx_state.account_writes[address] + if address in tx_state.parent.account_writes: + return tx_state.parent.account_writes[address] + return tx_state.parent.pre_state.get_account_optional(address) -def increment_block_access_index(root_frame: StateChanges) -> None: +def get_account(tx_state: TransactionState, address: Address) -> Account: """ - Increment the block access index in the root frame. + Get the ``Account`` object at an address. Return ``EMPTY_ACCOUNT`` + if there is no account at the address. + + Use ``get_account_optional()`` if you care about the difference + between a non-existent account and ``EMPTY_ACCOUNT``. Parameters ---------- - root_frame : - The root block-level frame. + tx_state : + The transaction state. + address : + Address to look up. + + Returns + ------- + account : ``Account`` + Account at address. """ - root_frame.block_access_index = root_frame.block_access_index + U16(1) + account = get_account_optional(tx_state, address) + if isinstance(account, Account): + return account + else: + return EMPTY_ACCOUNT -def get_transaction_frame(state_changes: StateChanges) -> StateChanges: +def get_storage( + tx_state: TransactionState, address: Address, key: Bytes32 +) -> U256: """ - Walk to the transaction-level frame (child of block frame). + Get a value at a storage key on an account. 
Return ``U256(0)`` if + the storage key has not been set previously. Parameters ---------- - state_changes : - Any frame in the hierarchy. + tx_state : + The transaction state. + address : + Address of the account. + key : + Key to look up. Returns ------- - tx_frame : StateChanges - The transaction-level frame. + value : ``U256`` + Value at the key. """ - tx_frame = state_changes - while tx_frame.parent is not None and tx_frame.parent.parent is not None: - tx_frame = tx_frame.parent - return tx_frame + if address in tx_state.storage_writes: + if key in tx_state.storage_writes[address]: + return tx_state.storage_writes[address][key] + if address in tx_state.parent.storage_writes: + if key in tx_state.parent.storage_writes[address]: + return tx_state.parent.storage_writes[address][key] + return tx_state.parent.pre_state.get_storage(address, key) -def capture_pre_balance( - tx_frame: StateChanges, address: Address, balance: U256 -) -> None: +def get_storage_original( + tx_state: TransactionState, address: Address, key: Bytes32 +) -> U256: """ - Capture pre-balance if not already captured (first-write-wins). + Get the original value in a storage slot i.e. the value before the + current transaction began. Read from block-level writes, then + pre_state. Return ``U256(0)`` for accounts created in the current + transaction. Parameters ---------- - tx_frame : - The transaction-level frame. + tx_state : + The transaction state. address : - The address whose balance to capture. - balance : - The current balance value. + Address of the account to read the value from. + key : + Key of the storage slot. 
""" - # Only capture pre-values in a transaction level - # or block level frame - assert tx_frame.parent is None or tx_frame.parent.parent is None - if address not in tx_frame.pre_balances: - tx_frame.pre_balances[address] = balance + if address in tx_state.created_accounts: + return U256(0) + if address in tx_state.parent.storage_writes: + if key in tx_state.parent.storage_writes[address]: + return tx_state.parent.storage_writes[address][key] + return tx_state.parent.pre_state.get_storage(address, key) -def capture_pre_storage( - tx_frame: StateChanges, address: Address, key: Bytes32, value: U256 -) -> None: +def get_transient_storage( + tx_state: TransactionState, address: Address, key: Bytes32 +) -> U256: """ - Capture pre-storage value if not already captured (first-write-wins). + Get a value at a storage key on an account from transient storage. + Return ``U256(0)`` if the storage key has not been set previously. Parameters ---------- - tx_frame : - The transaction-level frame. + tx_state : + The transaction state. address : - The address whose storage to capture. + Address of the account. key : - The storage key. - value : - The current storage value. + Key to look up. + + Returns + ------- + value : ``U256`` + Value at the key. """ - # Only capture pre-values in a transaction level - # or block level frame - assert tx_frame.parent is None or tx_frame.parent.parent is None - slot = (address, key) - if slot not in tx_frame.pre_storage: - tx_frame.pre_storage[slot] = value + return tx_state.transient_storage.get((address, key), U256(0)) -def capture_pre_code( - tx_frame: StateChanges, address: Address, code: Bytes -) -> None: +def account_exists(tx_state: TransactionState, address: Address) -> bool: """ - Capture pre-code if not already captured (first-write-wins). + Check if an account exists in the state trie. Parameters ---------- - tx_frame : - The transaction-level frame. + tx_state : + The transaction state. address : - The address whose code to capture. 
- code : - The current code value. + Address of the account that needs to be checked. + + Returns + ------- + account_exists : ``bool`` + True if account exists in the state trie, False otherwise. """ - # Only capture pre-values in a transaction level - # or block level frame - assert tx_frame.parent is None or tx_frame.parent.parent is None - if address not in tx_frame.pre_code: - tx_frame.pre_code[address] = code + return get_account_optional(tx_state, address) is not None -def track_address(state_changes: StateChanges, address: Address) -> None: +def account_has_code_or_nonce( + tx_state: TransactionState, address: Address +) -> bool: """ - Record that an address was accessed. + Check if an account has non-zero nonce or non-empty code. Parameters ---------- - state_changes : - The state changes frame. + tx_state : + The transaction state. address : - The address that was accessed. + Address of the account that needs to be checked. + + Returns + ------- + has_code_or_nonce : ``bool`` + True if the account has non-zero nonce or non-empty code, + False otherwise. """ - state_changes.touched_addresses.add(address) + account = get_account(tx_state, address) + return account.nonce != Uint(0) or account.code != b"" -def track_storage_read( - state_changes: StateChanges, address: Address, key: Bytes32 +def account_has_storage(tx_state: TransactionState, address: Address) -> bool: + """ + Check if an account has storage. + + Parameters + ---------- + tx_state : + The transaction state. + address : + Address of the account that needs to be checked. + + Returns + ------- + has_storage : ``bool`` + True if the account has storage, False otherwise. 
+ + """ + if tx_state.storage_writes.get(address): + return True + if tx_state.parent.storage_writes.get(address): + return True + return tx_state.parent.pre_state.account_has_storage(address) + + +def account_exists_and_is_empty( + tx_state: TransactionState, address: Address +) -> bool: + """ + Check if an account exists and has zero nonce, empty code and zero + balance. + + Parameters + ---------- + tx_state : + The transaction state. + address : + Address of the account that needs to be checked. + + Returns + ------- + exists_and_is_empty : ``bool`` + True if an account exists and has zero nonce, empty code and + zero balance, False otherwise. + + """ + account = get_account_optional(tx_state, address) + return ( + account is not None + and account.nonce == Uint(0) + and account.code == b"" + and account.balance == 0 + ) + + +def is_account_alive(tx_state: TransactionState, address: Address) -> bool: + """ + Check whether an account is both in the state and non-empty. + + Parameters + ---------- + tx_state : + The transaction state. + address : + Address of the account that needs to be checked. + + Returns + ------- + is_alive : ``bool`` + True if the account is alive. + + """ + account = get_account_optional(tx_state, address) + return account is not None and account != EMPTY_ACCOUNT + + +def set_account( + tx_state: TransactionState, + address: Address, + account: Optional[Account], ) -> None: """ - Record a storage read operation. + Set the ``Account`` object at an address. Setting to ``None`` + deletes the account (but not its storage, see + ``destroy_account()``). Parameters ---------- - state_changes : - The state changes frame. + tx_state : + The transaction state. address : - The address whose storage was read. - key : - The storage key that was read. + Address to set. + account : + Account to set at address. 
""" - state_changes.storage_reads.add((address, key)) + tx_state.account_writes[address] = account -def track_storage_write( - state_changes: StateChanges, +def set_storage( + tx_state: TransactionState, address: Address, key: Bytes32, value: U256, ) -> None: """ - Record a storage write keyed by (address, key, block_access_index). + Set a value at a storage key on an account. Parameters ---------- - state_changes : - The state changes frame. + tx_state : + The transaction state. address : - The address whose storage was written. + Address of the account. key : - The storage key that was written. + Key to set. value : - The new storage value. + Value to set at the key. """ - idx = state_changes.block_access_index - state_changes.storage_writes[(address, key, idx)] = value + assert get_account_optional(tx_state, address) is not None + if address not in tx_state.storage_writes: + tx_state.storage_writes[address] = {} + tx_state.storage_writes[address][key] = value -def track_balance_change( - state_changes: StateChanges, - address: Address, - new_balance: U256, -) -> None: +def destroy_account(tx_state: TransactionState, address: Address) -> None: """ - Record a balance change keyed by (address, block_access_index). + Completely remove the account at ``address`` and all of its storage. + + This function is made available exclusively for the ``SELFDESTRUCT`` + opcode. It is expected that ``SELFDESTRUCT`` will be disabled in a + future hardfork and this function will be removed. Only supports same + transaction destruction. Parameters ---------- - state_changes : - The state changes frame. + tx_state : + The transaction state. address : - The address whose balance changed. - new_balance : - The new balance value. + Address of account to destroy. 
""" - idx = state_changes.block_access_index - state_changes.balance_changes[(address, idx)] = new_balance + destroy_storage(tx_state, address) + set_account(tx_state, address, None) -def track_nonce_change( - state_changes: StateChanges, - address: Address, - new_nonce: U64, -) -> None: +def destroy_storage(tx_state: TransactionState, address: Address) -> None: """ - Record a nonce change as (address, block_access_index, new_nonce). + Completely remove the storage at ``address``. + + Convert storage writes to reads before deleting so that accesses + from created-then-destroyed accounts appear in the Block Access + List. Only supports same transaction destruction. Parameters ---------- - state_changes : - The state changes frame. + tx_state : + The transaction state. address : - The address whose nonce changed. - new_nonce : - The new nonce value. + Address of account whose storage is to be deleted. """ - idx = state_changes.block_access_index - state_changes.nonce_changes.add((address, idx, new_nonce)) + if address in tx_state.storage_writes: + for key in tx_state.storage_writes[address]: + tx_state.storage_reads.add((address, key)) + del tx_state.storage_writes[address] + +def mark_account_created(tx_state: TransactionState, address: Address) -> None: + """ + Mark an account as having been created in the current transaction. + This information is used by ``get_storage_original()`` to handle an + obscure edgecase, and to respect the constraints added to + SELFDESTRUCT by EIP-6780. -def track_code_change( - state_changes: StateChanges, + The marker is not removed even if the account creation reverts. + Since the account cannot have had code prior to its creation and + can't call ``get_storage_original()``, this is harmless. + + Parameters + ---------- + tx_state : + The transaction state. + address : + Address of the account that has been created. 
+ + """ + tx_state.created_accounts.add(address) + + +def set_transient_storage( + tx_state: TransactionState, address: Address, - new_code: Bytes, + key: Bytes32, + value: U256, ) -> None: """ - Record a code change keyed by (address, block_access_index). + Set a value at a storage key on an account in transient storage. Parameters ---------- - state_changes : - The state changes frame. + tx_state : + The transaction state. address : - The address whose code changed. - new_code : - The new code value. + Address of the account. + key : + Key to set. + value : + Value to set at the key. """ - idx = state_changes.block_access_index - state_changes.code_changes[(address, idx)] = new_code + if value == U256(0): + tx_state.transient_storage.pop((address, key), None) + else: + tx_state.transient_storage[(address, key)] = value -def track_selfdestruct( - tx_frame: StateChanges, +def modify_state( + tx_state: TransactionState, address: Address, + f: Callable[[Account], None], ) -> None: """ - Handle selfdestruct of account created in same transaction. + Modify an ``Account`` in the state. If, after modification, the + account exists and has zero nonce, empty code, and zero balance, it + is destroyed. + """ + set_account(tx_state, address, modify(get_account(tx_state, address), f)) + if account_exists_and_is_empty(tx_state, address): + destroy_account(tx_state, address) + - Per EIP-7928/EIP-6780: removes nonce/code changes, converts storage - writes to reads. Balance changes handled by net-zero filtering. +def move_ether( + tx_state: TransactionState, + sender_address: Address, + recipient_address: Address, + amount: U256, +) -> None: + """ + Move funds between accounts. Parameters ---------- - tx_frame : - The state changes tracker. Should be a transaction frame. - address : - The address that self-destructed. + tx_state : + The transaction state. + sender_address : + Address of the sender. + recipient_address : + Address of the recipient. 
+ amount : + The amount to transfer. """ - # Has to be a transaction frame - assert tx_frame.parent is not None and tx_frame.parent.parent is None - idx = tx_frame.block_access_index + def reduce_sender_balance(sender: Account) -> None: + if sender.balance < amount: + raise AssertionError + sender.balance -= amount - # Remove nonce changes from current transaction - tx_frame.nonce_changes = { - (addr, i, nonce) - for addr, i, nonce in tx_frame.nonce_changes - if not (addr == address and i == idx) - } + def increase_recipient_balance(recipient: Account) -> None: + recipient.balance += amount - # Remove balance changes from current transaction - if (address, idx) in tx_frame.balance_changes: - pre_balance = tx_frame.pre_balances[address] - if pre_balance == U256(0): - # Post balance will be U256(0) after deletion. - # So no change and hence bal does not need to - # capture anything. - del tx_frame.balance_changes[(address, idx)] + modify_state(tx_state, sender_address, reduce_sender_balance) + modify_state(tx_state, recipient_address, increase_recipient_balance) - # Remove code changes from current transaction - if (address, idx) in tx_frame.code_changes: - del tx_frame.code_changes[(address, idx)] - # Convert storage writes from current transaction to reads - for addr, key, i in list(tx_frame.storage_writes.keys()): - if addr == address and i == idx: - del tx_frame.storage_writes[(addr, key, i)] - tx_frame.storage_reads.add((addr, key)) +def set_account_balance( + tx_state: TransactionState, address: Address, amount: U256 +) -> None: + """ + Set the balance of an account. + Parameters + ---------- + tx_state : + The transaction state. + address : + Address of the account whose balance needs to be set. + amount : + The amount that needs to be set in the balance. -def merge_on_success(child_frame: StateChanges) -> None: """ - Merge child frame into parent on success. - Child values overwrite parent values (most recent wins). 
No net-zero - filtering here - that happens once at transaction commit via - normalize_transaction(). + def set_balance(account: Account) -> None: + account.balance = amount + + modify_state(tx_state, address, set_balance) + + +def increment_nonce(tx_state: TransactionState, address: Address) -> None: + """ + Increment the nonce of an account. Parameters ---------- - child_frame : - The child frame being merged. + tx_state : + The transaction state. + address : + Address of the account whose nonce needs to be incremented. """ - assert child_frame.parent is not None - parent_frame = child_frame.parent - # Merge address accesses - parent_frame.touched_addresses.update(child_frame.touched_addresses) + def increase_nonce(sender: Account) -> None: + sender.nonce += Uint(1) - # Merge storage: reads union, writes overwrite (child supersedes parent) - parent_frame.storage_reads.update(child_frame.storage_reads) - for storage_key, storage_value in child_frame.storage_writes.items(): - parent_frame.storage_writes[storage_key] = storage_value + modify_state(tx_state, address, increase_nonce) - # Merge balance changes: child overwrites parent for same key - for balance_key, balance_value in child_frame.balance_changes.items(): - parent_frame.balance_changes[balance_key] = balance_value - # Merge nonce changes: keep highest nonce per address - address_final_nonces: Dict[Address, Tuple[BlockAccessIndex, U64]] = {} - for addr, idx, nonce in child_frame.nonce_changes: - if ( - addr not in address_final_nonces - or nonce > address_final_nonces[addr][1] - ): - address_final_nonces[addr] = (idx, nonce) - for addr, (idx, final_nonce) in address_final_nonces.items(): - parent_frame.nonce_changes.add((addr, idx, final_nonce)) +def set_code( + tx_state: TransactionState, address: Address, code: Bytes +) -> None: + """ + Set Account code. 
- # Merge code changes: child overwrites parent for same key - for code_key, code_value in child_frame.code_changes.items(): - parent_frame.code_changes[code_key] = code_value + Parameters + ---------- + tx_state : + The transaction state. + address : + Address of the account whose code needs to be updated. + code : + The bytecode that needs to be set. + + """ + + def write_code(sender: Account) -> None: + sender.code = code + + modify_state(tx_state, address, write_code) + + +# -- Snapshot / Rollback --------------------------------------------------- -def merge_on_failure(child_frame: StateChanges) -> None: +def copy_tx_state(tx_state: TransactionState) -> TransactionState: """ - Merge child frame into parent on failure/revert. + Create a snapshot of the transaction state for rollback. - Only reads merge; writes are discarded (converted to reads). + Deep-copy writes and transient storage. The parent reference, + ``created_accounts``, ``storage_reads``, and ``account_reads`` + are shared (not rolled back). Parameters ---------- - child_frame : - The failed child frame. + tx_state : + The transaction state to snapshot. + + Returns + ------- + snapshot : ``TransactionState`` + A copy of the transaction state. + + """ + return TransactionState( + parent=tx_state.parent, + account_writes=dict(tx_state.account_writes), + storage_writes={ + addr: dict(slots) + for addr, slots in tx_state.storage_writes.items() + }, + created_accounts=tx_state.created_accounts, + transient_storage=dict(tx_state.transient_storage), + storage_reads=tx_state.storage_reads, + account_reads=tx_state.account_reads, + ) + + +def restore_tx_state( + tx_state: TransactionState, snapshot: TransactionState +) -> None: + """ + Restore transaction state from a snapshot (rollback on failure). + + Parameters + ---------- + tx_state : + The transaction state to restore. + snapshot : + The snapshot to restore from. 
""" - assert child_frame.parent is not None - parent_frame = child_frame.parent - # Only merge reads and address accesses on failure - parent_frame.touched_addresses.update(child_frame.touched_addresses) - parent_frame.storage_reads.update(child_frame.storage_reads) + tx_state.account_writes = snapshot.account_writes + tx_state.storage_writes = snapshot.storage_writes + tx_state.transient_storage = snapshot.transient_storage - # Convert writes to reads (failed writes still accessed the slots) - for address, key, _idx in child_frame.storage_writes.keys(): - parent_frame.storage_reads.add((address, key)) - # Note: balance_changes, nonce_changes, and code_changes are NOT - # merged on failure - they are discarded +# -- Lifecycle -------------------------------------------------------------- -def commit_transaction_frame(tx_frame: StateChanges) -> None: +def incorporate_tx_into_block( + tx_state: TransactionState, + builder: "BlockAccessListBuilder", +) -> None: """ - Commit transaction frame to block frame. + Merge transaction writes into the block state and clear for reuse. - Filters net-zero changes before merging to ensure only actual state - modifications are recorded in the block access list. + Update the BAL builder incrementally by diffing this transaction's + writes against the block's cumulative state. Merge reads and + touches into block-level sets. Parameters ---------- - tx_frame : - The transaction frame to commit. + tx_state : + The transaction state to commit. + builder : + The BAL builder for incremental updates. 
""" - assert tx_frame.parent is not None - block_frame = tx_frame.parent + from .block_access_lists.builder import update_builder_from_tx - # Filter net-zero changes before committing - filter_net_zero_frame_changes(tx_frame) + block = tx_state.parent - # Merge address accesses - block_frame.touched_addresses.update(tx_frame.touched_addresses) + # Update BAL builder before merging writes into block state + update_builder_from_tx(builder, tx_state) - # Merge storage operations - block_frame.storage_reads.update(tx_frame.storage_reads) - for (addr, key, idx), value in tx_frame.storage_writes.items(): - block_frame.storage_writes[(addr, key, idx)] = value + # Merge reads and touches into block-level sets + block.storage_reads.update(tx_state.storage_reads) + block.account_reads.update(tx_state.account_reads) - # Merge balance changes - for (addr, idx), final_balance in tx_frame.balance_changes.items(): - block_frame.balance_changes[(addr, idx)] = final_balance + # Merge cumulative writes + for address, account in tx_state.account_writes.items(): + block.account_writes[address] = account - # Merge nonce changes - for addr, idx, nonce in tx_frame.nonce_changes: - block_frame.nonce_changes.add((addr, idx, nonce)) + for address, slots in tx_state.storage_writes.items(): + if address not in block.storage_writes: + block.storage_writes[address] = {} + block.storage_writes[address].update(slots) - # Merge code changes - for (addr, idx), final_code in tx_frame.code_changes.items(): - block_frame.code_changes[(addr, idx)] = final_code + tx_state.account_writes.clear() + tx_state.storage_writes.clear() + tx_state.created_accounts.clear() + tx_state.transient_storage.clear() + tx_state.storage_reads = set() + tx_state.account_reads = set() -def create_child_frame(parent: StateChanges) -> StateChanges: +def extract_block_diffs( + block_state: BlockState, +) -> Tuple[ + Dict[Address, Optional[Account]], + Dict[Address, Dict[Bytes32, U256]], +]: """ - Create a child frame linked to 
the given parent. - - Inherits block_access_index from parent so track functions can - access it directly without walking up the frame hierarchy. + Extract account and storage diffs from the block state. Parameters ---------- - parent : - The parent frame. + block_state : + The block state. Returns ------- - child : StateChanges - A new child frame with parent reference and inherited - block_access_index. + account_diffs : + Account changes to apply. + storage_diffs : + Storage changes to apply. """ - return StateChanges( - parent=parent, - block_access_index=parent.block_access_index, - ) + return block_state.account_writes, block_state.storage_writes + +# -- BAL Tracking ----------------------------------------------------------- + + +def track_address(tx_state: TransactionState, address: Address) -> None: + """ + Record that an address was accessed. + + Parameters + ---------- + tx_state : + The transaction state. + address : + The address that was accessed. -def filter_net_zero_frame_changes(tx_frame: StateChanges) -> None: """ - Filter net-zero changes from transaction frame before commit. + tx_state.account_reads.add(address) + - Compares final values against pre-tx state for storage, balance, and code. - Net-zero storage writes are converted to reads. Net-zero balance/code - changes are removed entirely. Nonces are not filtered (only increment). +def track_storage_read( + tx_state: TransactionState, address: Address, key: Bytes32 +) -> None: + """ + Record a storage read operation. Parameters ---------- - tx_frame : - The transaction-level state changes frame. 
- - """ - idx = tx_frame.block_access_index - - # Filter storage: compare against pre_storage, convert net-zero to reads - addresses_to_check_storage = [ - (addr, key) - for (addr, key, i) in tx_frame.storage_writes.keys() - if i == idx - ] - for addr, key in addresses_to_check_storage: - # For any (address, key) whose balance has changed, its - # pre-value should have been captured - assert (addr, key) in tx_frame.pre_storage - pre_value = tx_frame.pre_storage[(addr, key)] - post_value = tx_frame.storage_writes[(addr, key, idx)] - if pre_value == post_value: - # Net-zero write - convert to read - del tx_frame.storage_writes[(addr, key, idx)] - tx_frame.storage_reads.add((addr, key)) - - # Filter balance: compare pre vs post, remove if equal - addresses_to_check_balance = [ - addr for (addr, i) in tx_frame.balance_changes.keys() if i == idx - ] - for addr in addresses_to_check_balance: - # For any account whose balance has changed, its - # pre-balance should have been captured - assert addr in tx_frame.pre_balances - pre_balance = tx_frame.pre_balances[addr] - post_balance = tx_frame.balance_changes[(addr, idx)] - if pre_balance == post_balance: - del tx_frame.balance_changes[(addr, idx)] - - # Filter code: compare pre vs post, remove if equal - addresses_to_check_code = [ - addr for (addr, i) in tx_frame.code_changes.keys() if i == idx - ] - for addr in addresses_to_check_code: - assert addr in tx_frame.pre_code - pre_code = tx_frame.pre_code[addr] - post_code = tx_frame.code_changes[(addr, idx)] - if pre_code == post_code: - del tx_frame.code_changes[(addr, idx)] - - # Nonces: no filtering needed (nonces only increment, never net-zero) + tx_state : + The transaction state. + address : + The address whose storage was read. + key : + The storage key that was read. 
+ + """ + tx_state.storage_reads.add((address, key)) diff --git a/src/ethereum/forks/amsterdam/transactions.py b/src/ethereum/forks/amsterdam/transactions.py index bea74819ef..9246f85903 100644 --- a/src/ethereum/forks/amsterdam/transactions.py +++ b/src/ethereum/forks/amsterdam/transactions.py @@ -19,13 +19,14 @@ InvalidSignatureError, NonceOverflowError, ) +from ethereum.state import Address from .exceptions import ( InitCodeTooLargeError, TransactionGasLimitExceededError, TransactionTypeError, ) -from .fork_types import Address, Authorization, VersionedHash +from .fork_types import Authorization, VersionedHash TX_BASE_COST = Uint(21000) """ diff --git a/src/ethereum/forks/amsterdam/trie.py b/src/ethereum/forks/amsterdam/trie.py index 77c41f7ff8..0f89d29583 100644 --- a/src/ethereum/forks/amsterdam/trie.py +++ b/src/ethereum/forks/amsterdam/trie.py @@ -30,16 +30,25 @@ from ethereum_rlp import Extended, rlp from ethereum_types.bytes import Bytes -from ethereum_types.frozen import slotted_freezable from ethereum_types.numeric import U256, Uint from typing_extensions import assert_type from ethereum.crypto.hash import keccak256 from ethereum.forks.bpo5 import trie as previous_trie +from ethereum.state import ( + Account, + Address, + BranchNode, + BranchSubnodes, + ExtensionNode, + InternalNode, + LeafNode, + Root, +) from ethereum.utils.hexadecimal import hex_to_bytes from .blocks import Receipt, Withdrawal -from .fork_types import Account, Address, Root, encode_account +from .fork_types import encode_account from .transactions import LegacyTransaction # note: an empty trie (regardless of whether it is secured) has root: @@ -85,56 +94,6 @@ ) -@slotted_freezable -@dataclass -class LeafNode: - """Leaf node in the Merkle Trie.""" - - rest_of_key: Bytes - value: Extended - - -@slotted_freezable -@dataclass -class ExtensionNode: - """Extension node in the Merkle Trie.""" - - key_segment: Bytes - subnode: Extended - - -BranchSubnodes = Tuple[ - Extended, - Extended, - 
Extended, - Extended, - Extended, - Extended, - Extended, - Extended, - Extended, - Extended, - Extended, - Extended, - Extended, - Extended, - Extended, - Extended, -] - - -@slotted_freezable -@dataclass -class BranchNode: - """Branch node in the Merkle Trie.""" - - subnodes: BranchSubnodes - value: Extended - - -InternalNode = LeafNode | ExtensionNode | BranchNode - - def encode_internal_node(node: Optional[InternalNode]) -> Extended: """ Encodes a Merkle Trie node into its RLP form. The RLP will then be diff --git a/src/ethereum/forks/amsterdam/utils/address.py b/src/ethereum/forks/amsterdam/utils/address.py index 270d562ca3..1ae7eadf49 100644 --- a/src/ethereum/forks/amsterdam/utils/address.py +++ b/src/ethereum/forks/amsterdam/utils/address.py @@ -17,10 +17,9 @@ from ethereum_types.numeric import U256, Uint from ethereum.crypto.hash import keccak256 +from ethereum.state import Address from ethereum.utils.byte import left_pad_zero_bytes -from ..fork_types import Address - def to_address_masked(data: Uint | U256) -> Address: """ diff --git a/src/ethereum/forks/amsterdam/utils/hexadecimal.py b/src/ethereum/forks/amsterdam/utils/hexadecimal.py index 23401e5d4f..995f9b1132 100644 --- a/src/ethereum/forks/amsterdam/utils/hexadecimal.py +++ b/src/ethereum/forks/amsterdam/utils/hexadecimal.py @@ -14,10 +14,9 @@ from ethereum_types.bytes import Bytes +from ethereum.state import Address, Root from ethereum.utils.hexadecimal import remove_hex_prefix -from ..fork_types import Address, Root - def hex_to_root(hex_string: str) -> Root: """ diff --git a/src/ethereum/forks/amsterdam/utils/message.py b/src/ethereum/forks/amsterdam/utils/message.py index 130532fef6..d74a6deb7b 100644 --- a/src/ethereum/forks/amsterdam/utils/message.py +++ b/src/ethereum/forks/amsterdam/utils/message.py @@ -15,9 +15,9 @@ from ethereum_types.bytes import Bytes, Bytes0 from ethereum_types.numeric import Uint -from ..fork_types import Address -from ..state import get_account -from ..state_tracker 
import create_child_frame +from ethereum.state import Address + +from ..state_tracker import get_account from ..transactions import Transaction from ..vm import BlockEnvironment, Message, TransactionEnvironment from ..vm.precompiled_contracts.mapping import PRE_COMPILED_CONTRACTS @@ -55,7 +55,7 @@ def prepare_message( if isinstance(tx.to, Bytes0): current_target = compute_contract_address( tx_env.origin, - get_account(block_env.state, tx_env.origin).nonce - Uint(1), + get_account(tx_env.state, tx_env.origin).nonce - Uint(1), ) msg_data = Bytes(b"") code = tx.data @@ -63,16 +63,13 @@ def prepare_message( elif isinstance(tx.to, Address): current_target = tx.to msg_data = tx.data - code = get_account(block_env.state, tx.to).code + code = get_account(tx_env.state, tx.to).code code_address = tx.to else: raise AssertionError("Target must be address or empty bytes") accessed_addresses.add(current_target) - # Create call frame as child of transaction frame - call_frame = create_child_frame(tx_env.state_changes) - return Message( block_env=block_env, tx_env=tx_env, @@ -92,5 +89,4 @@ def prepare_message( disable_precompiles=False, parent_evm=None, is_create=isinstance(tx.to, Bytes0), - state_changes=call_frame, ) diff --git a/src/ethereum/forks/amsterdam/vm/__init__.py b/src/ethereum/forks/amsterdam/vm/__init__.py index 3d69fbd706..3f588961ac 100644 --- a/src/ethereum/forks/amsterdam/vm/__init__.py +++ b/src/ethereum/forks/amsterdam/vm/__init__.py @@ -20,12 +20,13 @@ from ethereum.crypto.hash import Hash32 from ethereum.exceptions import EthereumException +from ethereum.state import Address +from ..block_access_lists.builder import BlockAccessListBuilder from ..block_access_lists.rlp_types import BlockAccessList from ..blocks import Log, Receipt, Withdrawal -from ..fork_types import Address, Authorization, VersionedHash -from ..state import State, TransientStorage -from ..state_tracker import StateChanges, merge_on_failure, merge_on_success +from ..fork_types import 
Authorization, VersionedHash +from ..state_tracker import BlockState, TransactionState from ..transactions import LegacyTransaction from ..trie import Trie @@ -39,7 +40,7 @@ class BlockEnvironment: """ chain_id: U64 - state: State + state: BlockState block_gas_limit: Uint block_hashes: List[Hash32] coinbase: Address @@ -49,7 +50,7 @@ class BlockEnvironment: prev_randao: Bytes32 excess_blob_gas: U64 parent_beacon_block_root: Hash32 - state_changes: StateChanges + block_access_list_builder: BlockAccessListBuilder @dataclass @@ -108,12 +109,11 @@ class TransactionEnvironment: gas: Uint access_list_addresses: Set[Address] access_list_storage_keys: Set[Tuple[Address, Bytes32]] - transient_storage: TransientStorage + state: TransactionState blob_versioned_hashes: Tuple[VersionedHash, ...] authorizations: Tuple[Authorization, ...] index_in_block: Optional[Uint] tx_hash: Optional[Hash32] - state_changes: "StateChanges" = field(default_factory=StateChanges) @dataclass @@ -140,7 +140,6 @@ class Message: disable_precompiles: bool parent_evm: Optional["Evm"] is_create: bool - state_changes: "StateChanges" = field(default_factory=StateChanges) @dataclass @@ -163,7 +162,6 @@ class Evm: error: Optional[EthereumException] accessed_addresses: Set[Address] accessed_storage_keys: Set[Tuple[Address, Bytes32]] - state_changes: StateChanges def incorporate_child_on_success(evm: Evm, child_evm: Evm) -> None: @@ -185,8 +183,6 @@ def incorporate_child_on_success(evm: Evm, child_evm: Evm) -> None: evm.accessed_addresses.update(child_evm.accessed_addresses) evm.accessed_storage_keys.update(child_evm.accessed_storage_keys) - merge_on_success(child_evm.state_changes) - def incorporate_child_on_error(evm: Evm, child_evm: Evm) -> None: """ @@ -201,5 +197,3 @@ def incorporate_child_on_error(evm: Evm, child_evm: Evm) -> None: """ evm.gas_left += child_evm.gas_left - - merge_on_failure(child_evm.state_changes) diff --git a/src/ethereum/forks/amsterdam/vm/eoa_delegation.py 
b/src/ethereum/forks/amsterdam/vm/eoa_delegation.py index ba645ae7c4..7d8ed23982 100644 --- a/src/ethereum/forks/amsterdam/vm/eoa_delegation.py +++ b/src/ethereum/forks/amsterdam/vm/eoa_delegation.py @@ -10,19 +10,15 @@ from ethereum.crypto.elliptic_curve import SECP256K1N, secp256k1_recover from ethereum.crypto.hash import keccak256 from ethereum.exceptions import InvalidBlock, InvalidSignatureError +from ethereum.state import Address -from ..fork_types import Address, Authorization -from ..state import ( +from ..fork_types import Authorization +from ..state_tracker import ( account_exists, get_account, increment_nonce, set_code, -) -from ..state_tracker import ( - capture_pre_code, track_address, - track_code_change, - track_nonce_change, ) from ..utils.hexadecimal import hex_to_address from ..vm.gas import GAS_COLD_ACCOUNT_ACCESS, GAS_WARM_ACCESS @@ -143,10 +139,10 @@ def calculate_delegation_cost( The delegation address and access gas cost. """ - state = evm.message.block_env.state + tx_state = evm.message.tx_env.state - code = get_account(state, address).code - track_address(evm.state_changes, address) + code = get_account(tx_state, address).code + track_address(tx_state, address) if not is_valid_delegation(code): return False, address, Uint(0) @@ -176,7 +172,7 @@ def set_delegation(message: Message) -> U256: Refund from authority which already exists in state. 
""" - state = message.block_env.state + tx_state = message.tx_env.state refund_counter = U256(0) for auth in message.tx_env.authorizations: if auth.chain_id not in (message.block_env.chain_id, U256(0)): @@ -192,9 +188,9 @@ def set_delegation(message: Message) -> U256: message.accessed_addresses.add(authority) - authority_account = get_account(state, authority) + authority_account = get_account(tx_state, authority) authority_code = authority_account.code - track_address(message.tx_env.state_changes, authority) + track_address(tx_state, authority) if authority_code and not is_valid_delegation(authority_code): continue @@ -203,7 +199,7 @@ def set_delegation(message: Message) -> U256: if authority_nonce != auth.nonce: continue - if account_exists(state, authority): + if account_exists(tx_state, authority): refund_counter += U256(PER_EMPTY_ACCOUNT_COST - PER_AUTH_BASE_COST) if auth.address == NULL_ADDRESS: @@ -211,23 +207,12 @@ def set_delegation(message: Message) -> U256: else: code_to_set = EOA_DELEGATION_MARKER + auth.address - tx_frame = message.tx_env.state_changes - # EIP-7928: Capture pre-code before any changes - capture_pre_code(tx_frame, authority, authority_code) - - set_code(state, authority, code_to_set) - - if authority_code != code_to_set: - # Track code change if different from current - track_code_change(tx_frame, authority, code_to_set) - - increment_nonce(state, authority) - nonce_after = get_account(state, authority).nonce - track_nonce_change(tx_frame, authority, U64(nonce_after)) + set_code(tx_state, authority, code_to_set) + increment_nonce(tx_state, authority) if message.code_address is None: raise InvalidBlock("Invalid type 4 transaction: no target") - message.code = get_account(state, message.code_address).code + message.code = get_account(tx_state, message.code_address).code return refund_counter diff --git a/src/ethereum/forks/amsterdam/vm/instructions/environment.py b/src/ethereum/forks/amsterdam/vm/instructions/environment.py index 
45c3bfe835..0ae81489b2 100644 --- a/src/ethereum/forks/amsterdam/vm/instructions/environment.py +++ b/src/ethereum/forks/amsterdam/vm/instructions/environment.py @@ -15,12 +15,10 @@ from ethereum_types.numeric import U256, Uint, ulen from ethereum.crypto.hash import keccak256 +from ethereum.state import EMPTY_ACCOUNT from ethereum.utils.numeric import ceil32 -# track_address_access removed - now using state_changes.track_address() -from ...fork_types import EMPTY_ACCOUNT -from ...state import get_account -from ...state_tracker import track_address +from ...state_tracker import get_account, track_address from ...utils.address import to_address_masked from ...vm.memory import buffer_read, memory_write from .. import Evm @@ -87,9 +85,9 @@ def balance(evm: Evm) -> None: # OPERATION # Non-existent accounts default to EMPTY_ACCOUNT, which has balance 0. - state = evm.message.block_env.state - balance = get_account(state, address).balance - track_address(evm.state_changes, address) + tx_state = evm.message.tx_env.state + balance = get_account(tx_state, address).balance + track_address(tx_state, address) push(evm.stack, balance) @@ -356,9 +354,9 @@ def extcodesize(evm: Evm) -> None: charge_gas(evm, access_gas_cost) # OPERATION - state = evm.message.block_env.state - code = get_account(state, address).code - track_address(evm.state_changes, address) + tx_state = evm.message.tx_env.state + code = get_account(tx_state, address).code + track_address(tx_state, address) codesize = U256(len(code)) push(evm.stack, codesize) @@ -403,9 +401,9 @@ def extcodecopy(evm: Evm) -> None: # OPERATION evm.memory += b"\x00" * extend_memory.expand_by - state = evm.message.block_env.state - code = get_account(state, address).code - track_address(evm.state_changes, address) + tx_state = evm.message.tx_env.state + code = get_account(tx_state, address).code + track_address(tx_state, address) value = buffer_read(code, code_start_index, size) memory_write(evm.memory, memory_start_index, value) @@ 
-496,9 +494,9 @@ def extcodehash(evm: Evm) -> None: charge_gas(evm, access_gas_cost) # OPERATION - state = evm.message.block_env.state - account = get_account(state, address) - track_address(evm.state_changes, address) + tx_state = evm.message.tx_env.state + account = get_account(tx_state, address) + track_address(tx_state, address) if account == EMPTY_ACCOUNT: codehash = U256(0) @@ -531,7 +529,7 @@ def self_balance(evm: Evm) -> None: # OPERATION # Non-existent accounts default to EMPTY_ACCOUNT, which has balance 0. balance = get_account( - evm.message.block_env.state, evm.message.current_target + evm.message.tx_env.state, evm.message.current_target ).balance push(evm.stack, balance) diff --git a/src/ethereum/forks/amsterdam/vm/instructions/storage.py b/src/ethereum/forks/amsterdam/vm/instructions/storage.py index 18afa2a2ba..88d4adea03 100644 --- a/src/ethereum/forks/amsterdam/vm/instructions/storage.py +++ b/src/ethereum/forks/amsterdam/vm/instructions/storage.py @@ -13,17 +13,13 @@ from ethereum_types.numeric import Uint -from ...state import ( +from ...state_tracker import ( get_storage, get_storage_original, get_transient_storage, set_storage, set_transient_storage, -) -from ...state_tracker import ( - capture_pre_storage, track_storage_read, - track_storage_write, ) from .. 
import Evm from ..exceptions import WriteInStaticContext @@ -62,14 +58,9 @@ def sload(evm: Evm) -> None: charge_gas(evm, GAS_COLD_SLOAD) # OPERATION - value = get_storage( - evm.message.block_env.state, evm.message.current_target, key - ) - track_storage_read( - evm.state_changes, - evm.message.current_target, - key, - ) + tx_state = evm.message.tx_env.state + value = get_storage(tx_state, evm.message.current_target, key) + track_storage_read(tx_state, evm.message.current_target, key) push(evm.stack, value) @@ -97,11 +88,11 @@ def sstore(evm: Evm) -> None: # check we have at least the stipend gas check_gas(evm, GAS_CALL_STIPEND + Uint(1)) - state = evm.message.block_env.state + tx_state = evm.message.tx_env.state original_value = get_storage_original( - state, evm.message.current_target, key + tx_state, evm.message.current_target, key ) - current_value = get_storage(state, evm.message.current_target, key) + current_value = get_storage(tx_state, evm.message.current_target, key) gas_cost = Uint(0) @@ -109,17 +100,7 @@ def sstore(evm: Evm) -> None: evm.accessed_storage_keys.add((evm.message.current_target, key)) gas_cost += GAS_COLD_SLOAD - capture_pre_storage( - evm.message.tx_env.state_changes, - evm.message.current_target, - key, - current_value, - ) - track_storage_read( - evm.state_changes, - evm.message.current_target, - key, - ) + track_storage_read(tx_state, evm.message.current_target, key) if original_value == current_value and current_value != new_value: if original_value == 0: @@ -151,13 +132,7 @@ def sstore(evm: Evm) -> None: ) charge_gas(evm, gas_cost) - set_storage(state, evm.message.current_target, key, new_value) - track_storage_write( - evm.state_changes, - evm.message.current_target, - key, - new_value, - ) + set_storage(tx_state, evm.message.current_target, key, new_value) # PROGRAM COUNTER evm.pc += Uint(1) @@ -182,7 +157,7 @@ def tload(evm: Evm) -> None: # OPERATION value = get_transient_storage( - evm.message.tx_env.transient_storage, 
evm.message.current_target, key + evm.message.tx_env.state, evm.message.current_target, key ) push(evm.stack, value) @@ -210,7 +185,7 @@ def tstore(evm: Evm) -> None: # GAS charge_gas(evm, GAS_WARM_ACCESS) set_transient_storage( - evm.message.tx_env.transient_storage, + evm.message.tx_env.state, evm.message.current_target, key, new_value, diff --git a/src/ethereum/forks/amsterdam/vm/instructions/system.py b/src/ethereum/forks/amsterdam/vm/instructions/system.py index 72f44cdf70..62a3644345 100644 --- a/src/ethereum/forks/amsterdam/vm/instructions/system.py +++ b/src/ethereum/forks/amsterdam/vm/instructions/system.py @@ -12,12 +12,12 @@ """ from ethereum_types.bytes import Bytes, Bytes0 -from ethereum_types.numeric import U64, U256, Uint +from ethereum_types.numeric import U256, Uint +from ethereum.state import Address from ethereum.utils.numeric import ceil32 -from ...fork_types import Address -from ...state import ( +from ...state_tracker import ( account_has_code_or_nonce, account_has_storage, get_account, @@ -25,13 +25,7 @@ is_account_alive, move_ether, set_account_balance, -) -from ...state_tracker import ( - capture_pre_balance, - create_child_frame, track_address, - track_balance_change, - track_nonce_change, ) from ...utils.address import ( compute_contract_address, @@ -95,7 +89,7 @@ def generic_create( if memory_size > U256(MAX_INIT_CODE_SIZE): raise OutOfGasError - state = evm.message.block_env.state + tx_state = evm.message.tx_env.state call_data = memory_read_bytes( evm.memory, memory_start_position, memory_size @@ -106,7 +100,7 @@ def generic_create( evm.return_data = b"" sender_address = evm.message.current_target - sender = get_account(state, sender_address) + sender = get_account(tx_state, sender_address) if ( sender.balance < endowment @@ -119,31 +113,15 @@ def generic_create( evm.accessed_addresses.add(contract_address) - track_address(evm.state_changes, contract_address) + track_address(tx_state, contract_address) if account_has_code_or_nonce( - 
state, contract_address - ) or account_has_storage(state, contract_address): - increment_nonce(state, evm.message.current_target) - nonce_after = get_account(state, evm.message.current_target).nonce - track_nonce_change( - evm.state_changes, - evm.message.current_target, - U64(nonce_after), - ) + tx_state, contract_address + ) or account_has_storage(tx_state, contract_address): + increment_nonce(tx_state, evm.message.current_target) push(evm.stack, U256(0)) return - # Track nonce increment for CREATE - increment_nonce(state, evm.message.current_target) - nonce_after = get_account(state, evm.message.current_target).nonce - track_nonce_change( - evm.state_changes, - evm.message.current_target, - U64(nonce_after), - ) - - # Create call frame as child of parent EVM's frame - child_state_changes = create_child_frame(evm.state_changes) + increment_nonce(tx_state, evm.message.current_target) child_message = Message( block_env=evm.message.block_env, @@ -164,7 +142,6 @@ def generic_create( disable_precompiles=False, parent_evm=evm, is_create=True, - state_changes=child_state_changes, ) child_evm = process_create_message(child_message) @@ -206,7 +183,7 @@ def create(evm: Evm) -> None: contract_address = compute_contract_address( evm.message.current_target, get_account( - evm.message.block_env.state, evm.message.current_target + evm.message.tx_env.state, evm.message.current_target ).nonce, ) @@ -340,9 +317,6 @@ def generic_call( evm.memory, memory_input_start_position, memory_input_size ) - # Create call frame as child of parent EVM's frame - child_state_changes = create_child_frame(evm.state_changes) - child_message = Message( block_env=evm.message.block_env, tx_env=evm.message.tx_env, @@ -362,7 +336,6 @@ def generic_call( disable_precompiles=disable_precompiles, parent_evm=evm, is_create=False, - state_changes=child_state_changes, ) child_evm = process_message(child_message) @@ -430,12 +403,12 @@ def call(evm: Evm) -> None: ) # STATE ACCESS - state = 
evm.message.block_env.state + tx_state = evm.message.tx_env.state if is_cold_access: evm.accessed_addresses.add(to) create_gas_cost = GAS_NEW_ACCOUNT - if value == 0 or is_account_alive(state, to): + if value == 0 or is_account_alive(tx_state, to): create_gas_cost = Uint(0) extra_gas = access_gas_cost + transfer_gas_cost + create_gas_cost @@ -449,11 +422,11 @@ def call(evm: Evm) -> None: # check enough gas for delegation access extra_gas += delegation_access_cost check_gas(evm, extra_gas + extend_memory.cost) - track_address(evm.state_changes, code_address) + track_address(tx_state, code_address) if code_address not in evm.accessed_addresses: evm.accessed_addresses.add(code_address) - code = get_account(state, code_address).code + code = get_account(tx_state, code_address).code message_call_gas = calculate_message_call_gas( value, @@ -465,7 +438,7 @@ def call(evm: Evm) -> None: charge_gas(evm, message_call_gas.cost + extend_memory.cost) evm.memory += b"\x00" * extend_memory.expand_by - sender_balance = get_account(state, evm.message.current_target).balance + sender_balance = get_account(tx_state, evm.message.current_target).balance if sender_balance < value: push(evm.stack, U256(0)) evm.return_data = b"" @@ -537,7 +510,7 @@ def callcode(evm: Evm) -> None: ) # STATE ACCESS - state = evm.message.block_env.state + tx_state = evm.message.tx_env.state if is_cold_access: evm.accessed_addresses.add(code_address) @@ -552,11 +525,11 @@ def callcode(evm: Evm) -> None: # check enough gas for delegation access extra_gas += delegation_access_cost check_gas(evm, extra_gas + extend_memory.cost) - track_address(evm.state_changes, code_address) + track_address(tx_state, code_address) if code_address not in evm.accessed_addresses: evm.accessed_addresses.add(code_address) - code = get_account(state, code_address).code + code = get_account(tx_state, code_address).code message_call_gas = calculate_message_call_gas( value, @@ -569,19 +542,7 @@ def callcode(evm: Evm) -> None: # OPERATION 
evm.memory += b"\x00" * extend_memory.expand_by - sender_balance = get_account( - evm.message.block_env.state, evm.message.current_target - ).balance - - # EIP-7928: For CALLCODE with value transfer, capture pre-balance - # in transaction frame. CALLCODE transfers value from/to current_target - # (same address), affecting current storage context, not child frame - if value != 0 and sender_balance >= value: - capture_pre_balance( - evm.message.tx_env.state_changes, - evm.message.current_target, - sender_balance, - ) + sender_balance = get_account(tx_state, evm.message.current_target).balance if sender_balance < value: push(evm.stack, U256(0)) @@ -636,57 +597,34 @@ def selfdestruct(evm: Evm) -> None: check_gas(evm, gas_cost) # STATE ACCESS - state = evm.message.block_env.state + tx_state = evm.message.tx_env.state if is_cold_access: evm.accessed_addresses.add(beneficiary) - track_address(evm.state_changes, beneficiary) + track_address(tx_state, beneficiary) if ( - not is_account_alive(state, beneficiary) - and get_account(state, evm.message.current_target).balance != 0 + not is_account_alive(tx_state, beneficiary) + and get_account(tx_state, evm.message.current_target).balance != 0 ): gas_cost += GAS_SELF_DESTRUCT_NEW_ACCOUNT charge_gas(evm, gas_cost) - state = evm.message.block_env.state originator = evm.message.current_target - originator_balance = get_account(state, originator).balance - beneficiary_balance = get_account(state, beneficiary).balance + originator_balance = get_account(tx_state, originator).balance - # Get tracking context - tx_frame = evm.message.tx_env.state_changes - - # Capture pre-balances for net-zero filtering - track_address(evm.state_changes, originator) - capture_pre_balance(tx_frame, originator, originator_balance) - capture_pre_balance(tx_frame, beneficiary, beneficiary_balance) + track_address(tx_state, originator) # Transfer balance - move_ether(state, originator, beneficiary, originator_balance) - - # Track balance changes - 
originator_new_balance = get_account(state, originator).balance - beneficiary_new_balance = get_account(state, beneficiary).balance - track_balance_change( - evm.state_changes, - originator, - originator_new_balance, - ) - track_balance_change( - evm.state_changes, - beneficiary, - beneficiary_new_balance, - ) + move_ether(tx_state, originator, beneficiary, originator_balance) # register account for deletion only if it was created # in the same transaction - if originator in state.created_accounts: + if originator in tx_state.created_accounts: # If beneficiary is the same as originator, then # the ether is burnt. - set_account_balance(state, originator, U256(0)) - track_balance_change(evm.state_changes, originator, U256(0)) + set_account_balance(tx_state, originator, U256(0)) evm.accounts_to_delete.add(originator) # HALT the execution @@ -733,7 +671,7 @@ def delegatecall(evm: Evm) -> None: check_gas(evm, access_gas_cost + extend_memory.cost) # STATE ACCESS - state = evm.message.block_env.state + tx_state = evm.message.tx_env.state if is_cold_access: evm.accessed_addresses.add(code_address) @@ -748,11 +686,11 @@ def delegatecall(evm: Evm) -> None: # check enough gas for delegation access extra_gas += delegation_access_cost check_gas(evm, extra_gas + extend_memory.cost) - track_address(evm.state_changes, code_address) + track_address(tx_state, code_address) if code_address not in evm.accessed_addresses: evm.accessed_addresses.add(code_address) - code = get_account(state, code_address).code + code = get_account(tx_state, code_address).code message_call_gas = calculate_message_call_gas( U256(0), @@ -823,7 +761,7 @@ def staticcall(evm: Evm) -> None: check_gas(evm, access_gas_cost + extend_memory.cost) # STATE ACCESS - state = evm.message.block_env.state + tx_state = evm.message.tx_env.state if is_cold_access: evm.accessed_addresses.add(to) @@ -838,11 +776,11 @@ def staticcall(evm: Evm) -> None: # check enough gas for delegation access extra_gas += delegation_access_cost 
check_gas(evm, extra_gas + extend_memory.cost) - track_address(evm.state_changes, code_address) + track_address(tx_state, code_address) if code_address not in evm.accessed_addresses: evm.accessed_addresses.add(code_address) - code = get_account(state, code_address).code + code = get_account(tx_state, code_address).code message_call_gas = calculate_message_call_gas( U256(0), diff --git a/src/ethereum/forks/amsterdam/vm/interpreter.py b/src/ethereum/forks/amsterdam/vm/interpreter.py index 21aecab4a9..ee5c59e1b3 100644 --- a/src/ethereum/forks/amsterdam/vm/interpreter.py +++ b/src/ethereum/forks/amsterdam/vm/interpreter.py @@ -15,9 +15,10 @@ from typing import Optional, Set, Tuple from ethereum_types.bytes import Bytes, Bytes0 -from ethereum_types.numeric import U64, U256, Uint, ulen +from ethereum_types.numeric import U256, Uint, ulen from ethereum.exceptions import EthereumException +from ethereum.state import Address from ethereum.trace import ( EvmStop, OpEnd, @@ -30,29 +31,18 @@ ) from ..blocks import Log -from ..fork_types import Address -from ..state import ( +from ..state_tracker import ( account_has_code_or_nonce, account_has_storage, - begin_transaction, - commit_transaction, + copy_tx_state, destroy_storage, get_account, increment_nonce, mark_account_created, move_ether, - rollback_transaction, + restore_tx_state, set_code, -) -from ..state_tracker import ( - capture_pre_balance, - capture_pre_code, - merge_on_failure, - merge_on_success, track_address, - track_balance_change, - track_code_change, - track_nonce_change, ) from ..vm import Message from ..vm.eoa_delegation import get_delegated_code_address, set_delegation @@ -115,13 +105,13 @@ def process_message_call(message: Message) -> MessageCallOutput: Output of the message call """ - block_env = message.block_env + tx_state = message.tx_env.state refund_counter = U256(0) if message.target == Bytes0(b""): is_collision = account_has_code_or_nonce( - block_env.state, message.current_target - ) or 
account_has_storage(block_env.state, message.current_target) - track_address(message.tx_env.state_changes, message.current_target) + tx_state, message.current_target + ) or account_has_storage(tx_state, message.current_target) + track_address(tx_state, message.current_target) if is_collision: return MessageCallOutput( Uint(0), @@ -141,9 +131,9 @@ def process_message_call(message: Message) -> MessageCallOutput: if delegated_address is not None: message.disable_precompiles = True message.accessed_addresses.add(delegated_address) - message.code = get_account(block_env.state, delegated_address).code + message.code = get_account(tx_state, delegated_address).code message.code_address = delegated_address - track_address(message.block_env.state_changes, delegated_address) + track_address(tx_state, delegated_address) evm = process_message(message) @@ -185,10 +175,9 @@ def process_create_message(message: Message) -> Evm: Items containing execution specific objects. """ - state = message.block_env.state - transient_storage = message.tx_env.transient_storage + tx_state = message.tx_env.state # take snapshot of state before processing the message - begin_transaction(state, transient_storage) + snapshot = copy_tx_state(tx_state) # If the address where the account is being created has storage, it is # destroyed. This can only happen in the following highly unlikely @@ -197,23 +186,15 @@ def process_create_message(message: Message) -> Evm: # `CREATE` or `CREATE2` call. # * The first `CREATE` happened before Spurious Dragon and left empty # code. - destroy_storage(state, message.current_target) + destroy_storage(tx_state, message.current_target) # In the previously mentioned edge case the preexisting storage is ignored # for gas refund purposes. In order to do this we must track created # accounts. This tracking is also needed to respect the constraints # added to SELFDESTRUCT by EIP-6780. 
- mark_account_created(state, message.current_target) - - increment_nonce(state, message.current_target) - nonce_after = get_account(state, message.current_target).nonce - track_nonce_change( - message.state_changes, - message.current_target, - U64(nonce_after), - ) + mark_account_created(tx_state, message.current_target) - capture_pre_code(message.tx_env.state_changes, message.current_target, b"") + increment_nonce(tx_state, message.current_target) evm = process_message(message) if not evm.error: @@ -227,25 +208,14 @@ def process_create_message(message: Message) -> Evm: if len(contract_code) > MAX_CODE_SIZE: raise OutOfGasError except ExceptionalHalt as error: - rollback_transaction(state, transient_storage) - merge_on_failure(message.state_changes) + restore_tx_state(tx_state, snapshot) evm.gas_left = Uint(0) evm.output = b"" evm.error = error else: - # Note: No need to capture pre code since it's always b"" here - set_code(state, message.current_target, contract_code) - if contract_code != b"": - track_code_change( - message.state_changes, - message.current_target, - contract_code, - ) - commit_transaction(state, transient_storage) - merge_on_success(message.state_changes) + set_code(tx_state, message.current_target, contract_code) else: - rollback_transaction(state, transient_storage) - merge_on_failure(message.state_changes) + restore_tx_state(tx_state, snapshot) return evm @@ -264,8 +234,7 @@ def process_message(message: Message) -> Evm: Items containing execution specific objects """ - state = message.block_env.state - transient_storage = message.tx_env.transient_storage + tx_state = message.tx_env.state if message.depth > STACK_DEPTH_LIMIT: raise StackDepthLimitError("Stack depth limit reached") @@ -288,47 +257,21 @@ def process_message(message: Message) -> Evm: error=None, accessed_addresses=message.accessed_addresses, accessed_storage_keys=message.accessed_storage_keys, - state_changes=message.state_changes, ) # take snapshot of state before processing 
the message - begin_transaction(state, transient_storage) + snapshot = copy_tx_state(tx_state) - track_address(message.state_changes, message.current_target) + track_address(tx_state, message.current_target) if message.should_transfer_value and message.value != 0: - # Track value transfer - sender_balance = get_account(state, message.caller).balance - recipient_balance = get_account(state, message.current_target).balance - - track_address(message.state_changes, message.caller) - capture_pre_balance( - message.tx_env.state_changes, message.caller, sender_balance - ) - capture_pre_balance( - message.tx_env.state_changes, - message.current_target, - recipient_balance, - ) + track_address(tx_state, message.caller) move_ether( - state, message.caller, message.current_target, message.value - ) - - sender_new_balance = get_account(state, message.caller).balance - recipient_new_balance = get_account( - state, message.current_target - ).balance - - track_balance_change( - message.state_changes, + tx_state, message.caller, - U256(sender_new_balance), - ) - track_balance_change( - message.state_changes, message.current_target, - U256(recipient_new_balance), + message.value, ) try: @@ -360,11 +303,5 @@ def process_message(message: Message) -> Evm: evm.error = error if evm.error: - rollback_transaction(state, transient_storage) - if not message.is_create: - merge_on_failure(evm.state_changes) - else: - commit_transaction(state, transient_storage) - if not message.is_create: - merge_on_success(evm.state_changes) + restore_tx_state(tx_state, snapshot) return evm diff --git a/src/ethereum/forks/amsterdam/vm/precompiled_contracts/mapping.py b/src/ethereum/forks/amsterdam/vm/precompiled_contracts/mapping.py index 7486203c3e..1d32bdce81 100644 --- a/src/ethereum/forks/amsterdam/vm/precompiled_contracts/mapping.py +++ b/src/ethereum/forks/amsterdam/vm/precompiled_contracts/mapping.py @@ -13,7 +13,8 @@ from typing import Callable, Dict -from ...fork_types import Address +from 
ethereum.state import Address + from . import ( ALT_BN128_ADD_ADDRESS, ALT_BN128_MUL_ADDRESS, diff --git a/src/ethereum/state.py b/src/ethereum/state.py new file mode 100644 index 0000000000..8300e39e53 --- /dev/null +++ b/src/ethereum/state.py @@ -0,0 +1,136 @@ +""" +Shared state types and the `PreState` protocol used by the state transition +function. + +The `PreState` protocol specifies the operations that any pre-execution state +provider must support, allowing multiple backing implementations (in-memory +`dict`, on-disk database, witness, etc.). +""" + +from dataclasses import dataclass +from typing import Dict, List, Optional, Protocol, Tuple + +from ethereum_rlp import Extended +from ethereum_types.bytes import Bytes, Bytes20, Bytes32 +from ethereum_types.frozen import slotted_freezable +from ethereum_types.numeric import U256, Uint + +from ethereum.crypto.hash import Hash32 + +Address = Bytes20 +Root = Hash32 + + +@slotted_freezable +@dataclass +class Account: + """ + State associated with an address. + """ + + nonce: Uint + balance: U256 + code: Bytes + + +EMPTY_ACCOUNT = Account( + nonce=Uint(0), + balance=U256(0), + code=b"", +) + + +@slotted_freezable +@dataclass +class LeafNode: + """Leaf node in the Merkle Trie.""" + + rest_of_key: Bytes + value: Extended + + +@slotted_freezable +@dataclass +class ExtensionNode: + """Extension node in the Merkle Trie.""" + + key_segment: Bytes + subnode: Extended + + +BranchSubnodes = Tuple[ + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, + Extended, +] + + +@slotted_freezable +@dataclass +class BranchNode: + """Branch node in the Merkle Trie.""" + + subnodes: BranchSubnodes + value: Extended + + +InternalNode = LeafNode | ExtensionNode | BranchNode + + +class PreState(Protocol): + """ + Protocol for providing pre-execution state. 
+ + Specify the operations that any pre-state provider (dict, database, + witness, etc.) must support for the EELS state transition. + """ + + def get_account_optional(self, address: Address) -> Optional[Account]: + """ + Get the account at an address. + + Return ``None`` if there is no account at the address. + """ + ... + + def get_storage(self, address: Address, key: Bytes32) -> U256: + """ + Get a storage value. + + Return ``U256(0)`` if the key has not been set. + """ + ... + + def account_has_storage(self, address: Address) -> bool: + """ + Check whether an account has any storage. + + Only needed for EIP-7610. + """ + ... + + def compute_state_root_and_trie_changes( + self, + account_changes: Dict[Address, Optional[Account]], + storage_changes: Dict[Address, Dict[Bytes32, U256]], + ) -> Tuple[Root, List[InternalNode]]: + """ + Compute the state root after applying changes to the pre-state. + + Return the new state root together with the internal trie nodes + that were created or modified. + """ + ... 
diff --git a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py index 9a14efa54c..d8e353e986 100644 --- a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py +++ b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py @@ -5,6 +5,7 @@ from inspect import signature from typing import Any, Final +from ethereum.state import EMPTY_ACCOUNT from ethereum_spec_tools.forks import Hardfork @@ -203,6 +204,8 @@ def Bloom(self) -> Any: @property def EMPTY_ACCOUNT(self) -> Any: """EMPTY_ACCOUNT of the fork.""" + if self.has_block_state: + return EMPTY_ACCOUNT return self._module("fork_types").EMPTY_ACCOUNT @property @@ -278,6 +281,15 @@ def has_decode_transaction(self) -> bool: """Check if this fork has a `decode_transaction`.""" return hasattr(self._module("transactions"), "decode_transaction") + @property + def has_block_state(self) -> bool: + """Check if the fork uses BlockState instead of State.""" + try: + module = self._module("state_tracker") + except ModuleNotFoundError: + return False + return hasattr(module, "BlockState") + @property def State(self) -> Any: """State class of the fork.""" diff --git a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py index 03939e8601..003fb5e756 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py @@ -16,7 +16,15 @@ from ethereum import trace from ethereum.exceptions import EthereumException, InvalidBlock from ethereum.fork_criteria import ByBlockNumber, ByTimestamp, Unscheduled -from ethereum.forks.amsterdam.state_tracker import StateChanges + +# TODO: Make this not amsterdam specific once the state tracker has +# been added to older forks. 
+from ethereum.forks.amsterdam.block_access_lists.builder import ( + BlockAccessListBuilder, +) +from ethereum.forks.amsterdam.block_access_lists.rlp_types import ( + BlockAccessIndex, +) from ethereum_spec_tools.forks import Hardfork, TemporaryHardfork from ..loaders.fixture_loader import Load @@ -285,11 +293,21 @@ def block_environment(self) -> Any: "coinbase": self.env.coinbase, "number": self.env.block_number, "time": self.env.block_timestamp, - "state": self.alloc.state, "block_gas_limit": self.env.block_gas_limit, "chain_id": self.chain_id, } + if self.fork.has_block_state: + from ethereum.forks.amsterdam.state_tracker import ( + BlockState, + ) + + block_state = BlockState(pre_state=self.alloc.state) + kw_arguments["state"] = block_state + self._block_state = block_state + else: + kw_arguments["state"] = self.alloc.state + block_environment = self.fork.BlockEnvironment if self.fork.has_calculate_base_fee_per_gas: @@ -307,7 +325,9 @@ def block_environment(self) -> Any: kw_arguments["excess_blob_gas"] = self.env.excess_blob_gas if self.fork.has_block_access_list_hash: - kw_arguments["state_changes"] = StateChanges() + kw_arguments["block_access_list_builder"] = ( + BlockAccessListBuilder() + ) return block_environment(**kw_arguments) @@ -408,14 +428,13 @@ def _run_blockchain_test(self, block_env: Any, block_output: Any) -> None: f"Transaction {original_idx} failed: {e!r}" ) - # Post-execution operations use index N+1 + # EIP-7928: Post-execution operations use index N+1 + num_txs = len(self.txs.transactions) if self.fork.has_block_access_list_hash: - from ethereum.forks.amsterdam.state_tracker import ( - increment_block_access_index, + block_env.block_access_list_builder.block_access_index = ( + BlockAccessIndex(Uint(num_txs) + Uint(1)) ) - increment_block_access_index(block_env.state_changes) - if not self.fork.proof_of_stake: if self.options.state_reward is None: self.pay_block_rewards(self.fork.BLOCK_REWARD, block_env) @@ -433,9 +452,8 @@ def 
_run_blockchain_test(self, block_env: Any, block_output: Any) -> None: self.fork.process_general_purpose_requests(block_env, block_output) if self.fork.has_block_access_list_hash: - # Build block access list from block_env.state_changes block_output.block_access_list = self.fork.build_block_access_list( - block_env.state_changes + block_env.block_access_list_builder, block_env.state ) def run_blockchain_test(self) -> None: diff --git a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py index 5595c38960..c3aa13cdcc 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py @@ -307,7 +307,31 @@ def update(self, t8n: "T8N", block_env: Any, block_output: Any) -> None: self.receipt_root = t8n.fork.root(block_output.receipts_trie) self.bloom = t8n.fork.logs_bloom(block_output.block_logs) self.logs_hash = keccak256(rlp.encode(block_output.block_logs)) - self.state_root = t8n.fork.state_root(block_env.state) + if t8n.fork.has_block_state: + # TODO: remove this once the state tracker is ported over + # to the older forks + from ethereum.forks.amsterdam.state import apply_changes_to_state + from ethereum.forks.amsterdam.state_tracker import ( + extract_block_diffs, + ) + + account_changes, storage_changes = extract_block_diffs( + t8n._block_state + ) + state_root_value, _ = ( + t8n.alloc.state.compute_state_root_and_trie_changes( + account_changes, storage_changes + ) + ) + self.state_root = state_root_value + # Apply diffs to pre-state for alloc output + apply_changes_to_state( + t8n.alloc.state, + account_changes, + storage_changes, + ) + else: + self.state_root = t8n.fork.state_root(block_env.state) self.receipts = self.get_receipts_from_output(t8n, block_output) if hasattr(block_env, "base_fee_per_gas"): diff --git a/tox.ini b/tox.ini index deedbbc515..f23707dc35 100644 --- a/tox.ini +++ b/tox.ini @@ -8,7 +8,6 @@ envlist = pypy3 json_infra static - optimized 
tests_pytest_py3 tests_pytest_pypy3 tests_benchmark_pytest_py3 @@ -135,21 +134,6 @@ commands = {posargs} \ tests -[testenv:optimized] -description = Run unit tests for optimized state and ethash -passenv = - PYPY_GC_MAX - PYPY_GC_MIN -commands = - pytest \ - -m "not slow and not evm_tools" \ - -n auto --maxprocesses 5 --dist=loadfile \ - --ignore-glob='tests/test_t8n.py' \ - --basetemp="{temp_dir}/pytest" \ - --optimized \ - {posargs} \ - tests/json_infra - [testenv:benchmark-gas-values] description = Run benchmark tests with `--gas-benchmark-values` passenv = From 7b8b30826e6d7c12c4d96bf7d7ee359f28b5d033 Mon Sep 17 00:00:00 2001 From: felipe Date: Fri, 20 Feb 2026 15:41:53 -0700 Subject: [PATCH 3/4] test: Add more invalid bal tests (#2261) * chore(ai): ask user to fill new tests after impl; check against JSON * feat(test): add more invalid BAL tests for duplicate entries within accounts * chore: add unit tests for BAL modifiers * chore: add modifier and test for duplicated code changes --- .claude/commands/write-test.md | 4 + .../test_types/block_access_list/modifiers.py | 205 +++++++++++++++- .../tests/test_block_access_list_modifiers.py | 219 ++++++++++++++++++ .../test_block_access_lists_invalid.py | 138 +++++++++++ 4 files changed, 565 insertions(+), 1 deletion(-) create mode 100644 packages/testing/src/execution_testing/test_types/tests/test_block_access_list_modifiers.py diff --git a/.claude/commands/write-test.md b/.claude/commands/write-test.md index a7cc488f89..018317a744 100644 --- a/.claude/commands/write-test.md +++ b/.claude/commands/write-test.md @@ -61,6 +61,10 @@ Conventions and patterns for writing consensus tests. 
Run this skill before writ - `@pytest.mark.parametrize("name", [pytest.param(val, id="label"), ...])` with descriptive `id=` strings - Stack parametrize decorators for multiple dimensions +## After Writing Tests + +After writing or modifying tests, ask the user: "Would you like me to load the `/fill-tests` skill to verify the new tests fill correctly? (This loads an additional skill into context.)" If they agree, run `/fill-tests`, fill the new tests, then inspect the generated fixture JSON to verify the fixture contents match what the test intends. + ## References See `docs/writing_tests/` and `docs/writing_tests/opcode_metadata.md` for detailed documentation. diff --git a/packages/testing/src/execution_testing/test_types/block_access_list/modifiers.py b/packages/testing/src/execution_testing/test_types/block_access_list/modifiers.py index 550c21e8c8..121ad26d13 100644 --- a/packages/testing/src/execution_testing/test_types/block_access_list/modifiers.py +++ b/packages/testing/src/execution_testing/test_types/block_access_list/modifiers.py @@ -500,6 +500,202 @@ def transform(bal: BlockAccessList) -> BlockAccessList: return transform +def _duplicate_in_field( + address: Address, + field_name: str, + match_fn: Callable[[Any], bool], + error_msg: str, + sub_field: Optional[str] = None, + sub_match_fn: Optional[Callable[[Any], bool]] = None, +) -> Callable[[BlockAccessList], BlockAccessList]: + """ + Duplicate the first matching entry in an account's field list. + + When sub_field and sub_match_fn are provided, find the parent entry + via match_fn then duplicate within sub_field using sub_match_fn. 
+ """ + found = False + + def _copy(entry: Any) -> Any: + if hasattr(entry, "model_copy"): + return entry.model_copy(deep=True) + return ZeroPaddedHexNumber(entry) + + def transform(bal: BlockAccessList) -> BlockAccessList: + nonlocal found + new_root = [] + for account_change in bal.root: + if account_change.address == address: + new_account = account_change.model_copy(deep=True) + entries = getattr(new_account, field_name) + + if sub_field is not None and sub_match_fn is not None: + for parent in entries: + if match_fn(parent): + children = getattr(parent, sub_field) + new_children = [] + for child in children: + new_children.append(child) + if not found and sub_match_fn(child): + found = True + new_children.append(_copy(child)) + setattr(parent, sub_field, new_children) + break + else: + new_entries = [] + for entry in entries: + new_entries.append(entry) + if not found and match_fn(entry): + found = True + new_entries.append(_copy(entry)) + setattr(new_account, field_name, new_entries) + + new_root.append(new_account) + else: + new_root.append(account_change) + + if not found: + raise ValueError(error_msg) + + return BlockAccessList(root=new_root) + + return transform + + +def duplicate_nonce_change( + address: Address, block_access_index: int +) -> Callable[[BlockAccessList], BlockAccessList]: + """Duplicate a nonce change entry for a given block access index.""" + return _duplicate_in_field( + address, + "nonce_changes", + match_fn=lambda c: c.block_access_index == block_access_index, + error_msg=( + f"Block access index {block_access_index} not found in " + f"nonce_changes of account {address}" + ), + ) + + +def duplicate_balance_change( + address: Address, block_access_index: int +) -> Callable[[BlockAccessList], BlockAccessList]: + """Duplicate a balance change entry for a given block access index.""" + return _duplicate_in_field( + address, + "balance_changes", + match_fn=lambda c: c.block_access_index == block_access_index, + error_msg=( + f"Block access 
index {block_access_index} not found in " + f"balance_changes of account {address}" + ), + ) + + +def duplicate_code_change( + address: Address, block_access_index: int +) -> Callable[[BlockAccessList], BlockAccessList]: + """Duplicate a code change entry for a given block access index.""" + return _duplicate_in_field( + address, + "code_changes", + match_fn=lambda c: c.block_access_index == block_access_index, + error_msg=( + f"Block access index {block_access_index} not found in " + f"code_changes of account {address}" + ), + ) + + +def duplicate_storage_slot( + address: Address, slot: int +) -> Callable[[BlockAccessList], BlockAccessList]: + """Duplicate a storage slot entry in storage_changes.""" + return _duplicate_in_field( + address, + "storage_changes", + match_fn=lambda s: s.slot == slot, + error_msg=( + f"Storage slot {slot} not found in storage_changes " + f"of account {address}" + ), + ) + + +def duplicate_storage_read( + address: Address, slot: int +) -> Callable[[BlockAccessList], BlockAccessList]: + """Duplicate a storage read entry.""" + return _duplicate_in_field( + address, + "storage_reads", + match_fn=lambda r: r == slot, + error_msg=( + f"Storage slot {slot} not found in storage_reads " + f"of account {address}" + ), + ) + + +def duplicate_slot_change( + address: Address, slot: int, block_access_index: int +) -> Callable[[BlockAccessList], BlockAccessList]: + """Duplicate a slot change within a specific storage slot.""" + return _duplicate_in_field( + address, + "storage_changes", + match_fn=lambda s: s.slot == slot, + error_msg=( + f"Block access index {block_access_index} not found " + f"in storage slot {slot} of account {address}" + ), + sub_field="slot_changes", + sub_match_fn=lambda c: c.block_access_index == block_access_index, + ) + + +def insert_storage_read( + address: Address, slot: int +) -> Callable[[BlockAccessList], BlockAccessList]: + """ + Insert a storage read at the correct sorted position. 
+ + Useful for testing that a key must not appear in both + storage_changes and storage_reads. + """ + found_address = False + + def transform(bal: BlockAccessList) -> BlockAccessList: + nonlocal found_address + new_root = [] + for account_change in bal.root: + if account_change.address == address: + found_address = True + new_account = account_change.model_copy(deep=True) + reads = list(new_account.storage_reads) + new_slot = ZeroPaddedHexNumber(slot) + # Find insertion point to maintain sorted order + insert_idx = len(reads) + for i, existing in enumerate(reads): + if existing >= new_slot: + insert_idx = i + break + reads.insert(insert_idx, new_slot) + new_account.storage_reads = reads + new_root.append(new_account) + else: + new_root.append(account_change) + + if not found_address: + raise ValueError( + f"Address {address} not found in BAL to insert storage read" + ) + + return BlockAccessList(root=new_root) + + return transform + + def reverse_accounts() -> Callable[[BlockAccessList], BlockAccessList]: """Reverse the order of accounts in the BAL.""" @@ -567,7 +763,6 @@ def transform(bal: BlockAccessList) -> BlockAccessList: __all__ = [ - # Core functions # Account-level modifiers "remove_accounts", "append_account", @@ -589,4 +784,12 @@ def transform(bal: BlockAccessList) -> BlockAccessList: "modify_code", # Block access index modifiers "swap_bal_indices", + # Duplicate entry modifiers (uniqueness constraint testing) + "duplicate_nonce_change", + "duplicate_balance_change", + "duplicate_code_change", + "duplicate_storage_slot", + "duplicate_storage_read", + "duplicate_slot_change", + "insert_storage_read", ] diff --git a/packages/testing/src/execution_testing/test_types/tests/test_block_access_list_modifiers.py b/packages/testing/src/execution_testing/test_types/tests/test_block_access_list_modifiers.py new file mode 100644 index 0000000000..50ca7d9dc8 --- /dev/null +++ 
b/packages/testing/src/execution_testing/test_types/tests/test_block_access_list_modifiers.py @@ -0,0 +1,219 @@ +"""Unit tests for BAL modifier functions.""" + +import pytest + +from execution_testing.base_types import Address +from execution_testing.test_types.block_access_list import ( + BalAccountChange, + BalBalanceChange, + BalCodeChange, + BalNonceChange, + BalStorageChange, + BalStorageSlot, + BlockAccessList, +) +from execution_testing.test_types.block_access_list.modifiers import ( + duplicate_account, + duplicate_balance_change, + duplicate_code_change, + duplicate_nonce_change, + duplicate_slot_change, + duplicate_storage_read, + duplicate_storage_slot, + insert_storage_read, +) + +ALICE = Address(0xA) +CONTRACT = Address(0xC) + + +@pytest.fixture() +def sample_bal() -> BlockAccessList: + """Build a minimal BAL with one flat account and one storage account.""" + return BlockAccessList( + [ + BalAccountChange( + address=ALICE, + nonce_changes=[ + BalNonceChange(block_access_index=1, post_nonce=1), + ], + balance_changes=[ + BalBalanceChange(block_access_index=1, post_balance=100), + ], + code_changes=[ + BalCodeChange(block_access_index=1, new_code=b"\x60"), + ], + ), + BalAccountChange( + address=CONTRACT, + storage_changes=[ + BalStorageSlot( + slot=1, + slot_changes=[ + BalStorageChange( + block_access_index=1, post_value=0x42 + ), + ], + ), + ], + storage_reads=[2, 5], + ), + ] + ) + + +def test_duplicate_account(sample_bal: BlockAccessList) -> None: + """Duplicate an account entry.""" + result = duplicate_account(ALICE)(sample_bal) + alice_entries = [a for a in result.root if a.address == ALICE] + assert len(alice_entries) == 2 + + +def test_duplicate_account_missing_raises() -> None: + """Raise when the target address is absent.""" + bal = BlockAccessList([BalAccountChange(address=ALICE, nonce_changes=[])]) + with pytest.raises(ValueError, match="not found"): + duplicate_account(CONTRACT)(bal) + + +def test_duplicate_nonce_change(sample_bal: 
BlockAccessList) -> None: + """Duplicate a nonce change by block_access_index.""" + result = duplicate_nonce_change(ALICE, 1)(sample_bal) + assert len(result.root[0].nonce_changes) == 2 + assert ( + result.root[0].nonce_changes[0].block_access_index + == result.root[0].nonce_changes[1].block_access_index + ) + + +def test_duplicate_nonce_change_missing_index_raises( + sample_bal: BlockAccessList, +) -> None: + """Raise when the block_access_index is absent.""" + with pytest.raises(ValueError, match="not found"): + duplicate_nonce_change(ALICE, 99)(sample_bal) + + +def test_duplicate_balance_change(sample_bal: BlockAccessList) -> None: + """Duplicate a balance change by block_access_index.""" + result = duplicate_balance_change(ALICE, 1)(sample_bal) + assert len(result.root[0].balance_changes) == 2 + + +def test_duplicate_balance_change_missing_index_raises( + sample_bal: BlockAccessList, +) -> None: + """Raise when the block_access_index is absent.""" + with pytest.raises(ValueError, match="not found"): + duplicate_balance_change(ALICE, 99)(sample_bal) + + +# --- duplicate_code_change --- + + +def test_duplicate_code_change(sample_bal: BlockAccessList) -> None: + """Duplicate a code change by block_access_index.""" + result = duplicate_code_change(ALICE, 1)(sample_bal) + assert len(result.root[0].code_changes) == 2 + assert ( + result.root[0].code_changes[0].block_access_index + == result.root[0].code_changes[1].block_access_index + ) + + +def test_duplicate_code_change_missing_index_raises( + sample_bal: BlockAccessList, +) -> None: + """Raise when the block_access_index is absent.""" + with pytest.raises(ValueError, match="not found"): + duplicate_code_change(ALICE, 99)(sample_bal) + + +def test_duplicate_storage_slot(sample_bal: BlockAccessList) -> None: + """Duplicate a storage slot entry.""" + result = duplicate_storage_slot(CONTRACT, 1)(sample_bal) + contract = [a for a in result.root if a.address == CONTRACT][0] + assert len(contract.storage_changes) == 2 + 
assert contract.storage_changes[0].slot == contract.storage_changes[1].slot + + +def test_duplicate_storage_slot_missing_raises( + sample_bal: BlockAccessList, +) -> None: + """Raise when the slot is absent.""" + with pytest.raises(ValueError, match="not found"): + duplicate_storage_slot(CONTRACT, 99)(sample_bal) + + +def test_duplicate_storage_read(sample_bal: BlockAccessList) -> None: + """Duplicate a storage read entry.""" + result = duplicate_storage_read(CONTRACT, 2)(sample_bal) + contract = [a for a in result.root if a.address == CONTRACT][0] + assert len(contract.storage_reads) == 3 + assert contract.storage_reads[0] == contract.storage_reads[1] == 2 + + +def test_duplicate_storage_read_missing_raises( + sample_bal: BlockAccessList, +) -> None: + """Raise when the slot is absent from storage_reads.""" + with pytest.raises(ValueError, match="not found"): + duplicate_storage_read(CONTRACT, 99)(sample_bal) + + +def test_duplicate_slot_change(sample_bal: BlockAccessList) -> None: + """Duplicate a slot change within a storage slot.""" + result = duplicate_slot_change(CONTRACT, 1, 1)(sample_bal) + contract = [a for a in result.root if a.address == CONTRACT][0] + assert len(contract.storage_changes[0].slot_changes) == 2 + assert ( + contract.storage_changes[0].slot_changes[0].block_access_index + == contract.storage_changes[0].slot_changes[1].block_access_index + ) + + +def test_duplicate_slot_change_missing_index_raises( + sample_bal: BlockAccessList, +) -> None: + """Raise when the block_access_index is absent within the slot.""" + with pytest.raises(ValueError, match="not found"): + duplicate_slot_change(CONTRACT, 1, 99)(sample_bal) + + +def test_duplicate_slot_change_missing_slot_raises( + sample_bal: BlockAccessList, +) -> None: + """Raise when the parent slot is absent.""" + with pytest.raises(ValueError, match="not found"): + duplicate_slot_change(CONTRACT, 99, 1)(sample_bal) + + +def test_insert_storage_read(sample_bal: BlockAccessList) -> None: + """Insert 
a storage read at the correct sorted position.""" + result = insert_storage_read(CONTRACT, 3)(sample_bal) + contract = [a for a in result.root if a.address == CONTRACT][0] + assert len(contract.storage_reads) == 3 + assert list(contract.storage_reads) == [2, 3, 5] + + +def test_insert_storage_read_at_beginning( + sample_bal: BlockAccessList, +) -> None: + """Insert before all existing reads.""" + result = insert_storage_read(CONTRACT, 1)(sample_bal) + contract = [a for a in result.root if a.address == CONTRACT][0] + assert list(contract.storage_reads) == [1, 2, 5] + + +def test_insert_storage_read_at_end(sample_bal: BlockAccessList) -> None: + """Insert after all existing reads.""" + result = insert_storage_read(CONTRACT, 10)(sample_bal) + contract = [a for a in result.root if a.address == CONTRACT][0] + assert list(contract.storage_reads) == [2, 5, 10] + + +def test_insert_storage_read_missing_address_raises() -> None: + """Raise when the address is absent.""" + bal = BlockAccessList([BalAccountChange(address=ALICE, nonce_changes=[])]) + with pytest.raises(ValueError, match="not found"): + insert_storage_read(CONTRACT, 1)(bal) diff --git a/tests/amsterdam/eip7928_block_level_access_lists/test_block_access_lists_invalid.py b/tests/amsterdam/eip7928_block_level_access_lists/test_block_access_lists_invalid.py index f819c64143..0b1f8e4965 100644 --- a/tests/amsterdam/eip7928_block_level_access_lists/test_block_access_lists_invalid.py +++ b/tests/amsterdam/eip7928_block_level_access_lists/test_block_access_lists_invalid.py @@ -21,15 +21,24 @@ BlockAccessListExpectation, BlockchainTestFiller, BlockException, + Initcode, Op, Storage, Transaction, + compute_create_address, ) from execution_testing.test_types.block_access_list.modifiers import ( append_account, append_change, append_storage, duplicate_account, + duplicate_balance_change, + duplicate_code_change, + duplicate_nonce_change, + duplicate_slot_change, + duplicate_storage_read, + duplicate_storage_slot, + 
insert_storage_read, modify_balance, modify_nonce, modify_storage, @@ -836,3 +845,132 @@ def test_bal_invalid_extraneous_entries( ) ], ) + + +@pytest.mark.valid_from("Amsterdam") +@pytest.mark.exception_test +@pytest.mark.parametrize( + "modifier", + [ + pytest.param( + lambda alice, **_: duplicate_nonce_change(alice, 1), + id="duplicate_nonce_change", + ), + pytest.param( + lambda oracle, **_: duplicate_balance_change(oracle, 1), + id="duplicate_balance_change", + ), + pytest.param( + lambda created, **_: duplicate_code_change(created, 1), + id="duplicate_code_change", + ), + pytest.param( + lambda oracle, **_: duplicate_storage_slot(oracle, 1), + id="duplicate_storage_slot", + ), + pytest.param( + lambda oracle, **_: duplicate_storage_read(oracle, 2), + id="duplicate_storage_read", + ), + pytest.param( + lambda oracle, **_: duplicate_slot_change(oracle, 1, 1), + id="duplicate_slot_change", + ), + pytest.param( + lambda oracle, **_: insert_storage_read(oracle, 1), + id="storage_key_in_both_changes_and_reads", + ), + ], +) +def test_bal_invalid_duplicate_entries( + blockchain_test: BlockchainTestFiller, + pre: Alloc, + modifier: Callable, +) -> None: + """ + Test that clients reject blocks where BAL contains duplicate entries. + + Oracle writes storage, reads storage, and CREATEs a small contract. + Verify the EIP-7928 uniqueness constraints: each block_access_index + must appear at most once per change list (nonce, balance, code, + slot), each storage key at most once in storage_changes and + storage_reads, and no key in both. 
+ """ + alice = pre.fund_eoa() + deploy_code = b"\x13\x37" + initcode = Initcode(deploy_code=deploy_code) + initcode_word = int.from_bytes(bytes(initcode).ljust(32, b"\x00"), "big") + oracle = pre.deploy_contract( + code=( + Op.SSTORE(1, 0x42) + + Op.SLOAD(2) + + Op.MSTORE(0, initcode_word) + + Op.CREATE(0, 0, len(initcode)) + ), + storage={2: 0x84}, + ) + created = compute_create_address(address=oracle, nonce=1) + + tx = Transaction( + sender=alice, + to=oracle, + value=100, + gas_limit=2_000_000, + ) + + blockchain_test( + pre=pre, + post=pre, + blocks=[ + Block( + txs=[tx], + exception=BlockException.INVALID_BLOCK_ACCESS_LIST, + expected_block_access_list=BlockAccessListExpectation( + account_expectations={ + alice: BalAccountExpectation( + nonce_changes=[ + BalNonceChange( + block_access_index=1, + post_nonce=1, + ), + ], + ), + oracle: BalAccountExpectation( + balance_changes=[ + BalBalanceChange( + block_access_index=1, + post_balance=100, + ), + ], + storage_changes=[ + BalStorageSlot( + slot=1, + slot_changes=[ + BalStorageChange( + block_access_index=1, + post_value=0x42, + ), + ], + ), + ], + storage_reads=[2], + ), + created: BalAccountExpectation( + code_changes=[ + BalCodeChange( + block_access_index=1, + new_code=deploy_code, + ), + ], + ), + } + ).modify( + modifier( + alice=alice, + oracle=oracle, + created=created, + ) + ), + ) + ], + ) From fa39bdcb06914abd14c6417b6e18eea5481d0b8e Mon Sep 17 00:00:00 2001 From: Mario Vega Date: Sat, 21 Feb 2026 00:26:32 +0100 Subject: [PATCH 4/4] feat(test-execute): Batch RPC requests (#2243) * refactor(test-execute): Use batch rpc requests Add skip_code option, batch worker seed fundings save fix: Chain builder EthRPC fix: tox Add deterministic deploy contract to hive's pre Handle invalid transaction tests tox tox refactor(tests-benchmark): fewer post checks * Update packages/testing/src/execution_testing/rpc/rpc.py Co-authored-by: spencer * fix: review comments * fix: tox * fix: remove unused code --------- 
Co-authored-by: spencer --- .../plugins/execute/contracts.py | 4 +- .../plugins/execute/execute.py | 43 +- .../plugins/execute/pre_alloc.py | 124 ++-- .../execute/rpc/chain_builder_eth_rpc.py | 349 +--------- .../plugins/execute/rpc/hive.py | 13 +- .../plugins/execute/rpc/remote.py | 2 - .../plugins/execute/rpc/remote_seed_sender.py | 5 +- .../pytest_commands/plugins/execute/sender.py | 144 +++-- .../execution/transaction_post.py | 62 +- .../src/execution_testing/rpc/__init__.py | 6 + .../testing/src/execution_testing/rpc/rpc.py | 603 ++++++++++++++---- .../src/execution_testing/rpc/rpc_types.py | 57 +- tests/benchmark/compute/helpers.py | 22 +- .../compute/instruction/test_account_query.py | 21 +- .../scenario/test_unchunkified_bytecode.py | 22 +- 15 files changed, 789 insertions(+), 688 deletions(-) diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/contracts.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/contracts.py index 8566869c72..8e05289608 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/contracts.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/contracts.py @@ -89,7 +89,7 @@ def deploy_deterministic_factory_contract( tx_index=tx_index, ) tx_index += 1 - eth_rpc.send_wait_transaction(fund_tx) + eth_rpc.send_wait_transactions([fund_tx]) logger.info(f"Funding transaction mined: {fund_tx.hash}") # Add deployment transaction. 
@@ -102,7 +102,7 @@ def deploy_deterministic_factory_contract( tx_index=tx_index, ) tx_index += 1 - eth_rpc.send_wait_transaction(deploy_tx) + eth_rpc.send_wait_transactions([deploy_tx]) logger.info(f"Deployment transaction mined: {deploy_tx.hash}") deployment_contract_code = eth_rpc.get_code(DETERMINISTIC_FACTORY_ADDRESS) logger.info(f"Deployment contract code: {deployment_contract_code}") diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/execute.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/execute.py index bd94f44d12..79cb2f8f6b 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/execute.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/execute.py @@ -10,6 +10,8 @@ import pytest from pytest_metadata.plugin import metadata_key +from execution_testing.base_types import Account +from execution_testing.base_types import Alloc as BaseAlloc from execution_testing.execution import BaseExecute from execution_testing.forks import Fork from execution_testing.logging import get_logger @@ -716,21 +718,32 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # send the funds to the required sender accounts pre.send_pending_transactions() - for ( - deployed_contract, - expected_code, - ) in pre._deployed_contracts: - actual_code = eth_rpc.get_code(deployed_contract) - if actual_code != expected_code: - msg = ( - f"Deployed test contract didn't match expected " - f"code at address {deployed_contract} " - f"(not enough gas_limit?).\n" - f"Expected: {expected_code}\n" - f"Actual: {actual_code}" - ) - logger.error(msg) - raise Exception(msg) + if pre._deployed_contracts: + contract_alloc = BaseAlloc( + root={ + addr: Account() + for addr, _ in pre._deployed_contracts + } + ) + actual_alloc = eth_rpc.get_alloc(contract_alloc) + for ( + deployed_contract, + expected_code, + ) in pre._deployed_contracts: + actual_account = 
actual_alloc.root[deployed_contract] + assert actual_account is not None + actual_code = actual_account.code + if actual_code != expected_code: + msg = ( + f"Deployed test contract didn't match " + f"expected code at address " + f"{deployed_contract} " + f"(not enough gas_limit?).\n" + f"Expected: {expected_code}\n" + f"Actual: {actual_code}" + ) + logger.error(msg) + raise Exception(msg) request.node.config.funded_accounts = ", ".join( [str(eoa) for eoa in pre._funded_eoa] ) diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/pre_alloc.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/pre_alloc.py index bf39504513..6ae0fcf7ee 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/pre_alloc.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/pre_alloc.py @@ -394,17 +394,8 @@ def _deterministic_deploy_contract( self._deployed_contracts.append((contract_address, deploy_code)) - balance = self._eth_rpc.get_balance(contract_address) - nonce = self._eth_rpc.get_transaction_count(contract_address) - self.__internal_setitem__( - contract_address, - Account( - nonce=nonce, - balance=balance, - code=deploy_code, - storage={}, - ), - ) + account = self._eth_rpc.get_account(contract_address) + self.__internal_setitem__(contract_address, account) contract_address.label = label return contract_address @@ -446,27 +437,20 @@ def _deploy_contract( f"Using address stub '{stub}' at {contract_address} " f"(label={label})" ) - code = self._eth_rpc.get_code(contract_address) + account = self._eth_rpc.get_account(contract_address) + code = account.code if code == b"": raise ValueError( f"Stub {stub} at {contract_address} has no code" ) - balance = self._eth_rpc.get_balance(contract_address) - nonce = self._eth_rpc.get_transaction_count(contract_address) + balance = account.balance + nonce = account.nonce bal_eth = balance / 10**18 logger.debug( f"Stub 
contract {contract_address}: balance={bal_eth:.18f} " f"ETH, nonce={nonce}, code_size={len(code)} bytes" ) - self.__internal_setitem__( - contract_address, - Account( - nonce=nonce, - balance=balance, - code=code, - storage={}, - ), - ) + self.__internal_setitem__(contract_address, account) return contract_address initcode_prefix = Bytecode() @@ -827,36 +811,14 @@ def send_pending_transactions(self) -> List[TransactionByHashResponse]: f"(deployed_contracts={len(self._deployed_contracts)}, " f"funded_eoas={len(self._funded_eoa)})" ) - transaction_batches: List[List[PendingTransaction]] = [] - last_tx_batch: List[PendingTransaction] = [] - max_txs_per_batch = 100 for tx in self._pending_txs: assert tx.value is not None, ( "Transaction value must be set before sending them to the RPC." ) - if len(last_tx_batch) >= max_txs_per_batch: - transaction_batches.append(last_tx_batch) - last_tx_batch = [] - last_tx_batch.append(tx) - if last_tx_batch: - transaction_batches.append(last_tx_batch) - - responses: List[TransactionByHashResponse] = [] - for tx_batch in transaction_batches: - txs = [tx.with_signature_and_sender() for tx in tx_batch] - tx_hashes = self._eth_rpc.send_transactions(txs) - hash_strs = [str(h) for h in tx_hashes[:5]] - n_hashes = len(tx_hashes) - extra = f" and {n_hashes - 5} more" if n_hashes > 5 else "" - logger.info(f"Sent {n_hashes} transactions: {hash_strs}{extra}") - logger.info( - f"Waiting for {len(tx_batch)} transactions to be included " - "in blocks" - ) - responses += self._eth_rpc.wait_for_transactions(tx_batch) - logger.info( - f"All {len(responses)} transactions confirmed in blocks" - ) + + txs = [tx.with_signature_and_sender() for tx in self._pending_txs] + responses = self._eth_rpc.send_wait_transactions(txs) + for response in responses: logger.debug(f"Transaction response: {response.model_dump_json()}") return responses @@ -928,22 +890,30 @@ def pre( return # Refund all EOAs (regardless of whether the test passed or failed) + funded_eoas = 
pre._funded_eoa logger.info( - f"Starting cleanup phase: refunding {len(pre._funded_eoa)} funded EOAs" + f"Starting cleanup phase: refunding {len(funded_eoas)} funded EOAs" ) - refund_txs = [] + + if not funded_eoas: + logger.info("No funded EOAs to refund") + return + + # Build refund transactions + refund_txs: List[Transaction] = [] skipped_refunds = 0 - error_refunds = 0 - for idx, eoa in enumerate(pre._funded_eoa): - remaining_balance = eth_rpc.get_balance(eoa) - eoa.nonce = Number(eth_rpc.get_transaction_count(eoa)) - refund_gas_limit = 21_000 - tx_cost = refund_gas_limit * max_fee_per_gas + refund_gas_limit = 21_000 + tx_cost = refund_gas_limit * max_fee_per_gas + for idx, eoa in enumerate(funded_eoas): + account = eth_rpc.get_account(eoa, skip_code=True) + remaining_balance = account.balance + eoa.nonce = Number(account.nonce) if remaining_balance < tx_cost: rem_eth = remaining_balance / 10**18 cost_eth = tx_cost / 10**18 logger.debug( - f"Skipping refund for EOA {eoa} (label={eoa.label}): " + f"Skipping refund for EOA {eoa} " + f"(label={eoa.label}): " f"insufficient balance {rem_eth:.18f} ETH < " f"transaction cost {cost_eth:.18f} ETH" ) @@ -954,14 +924,15 @@ def pre( rem_eth = remaining_balance / 10**18 cost_eth = tx_cost / 10**18 logger.debug( - f"Preparing refund transaction for EOA {eoa} (label={eoa.label}): " + f"Preparing refund transaction for EOA {eoa} " + f"(label={eoa.label}): " f"{ref_eth:.18f} ETH (remaining: {rem_eth:.18f} ETH, " f"cost: {cost_eth:.18f} ETH)" ) refund_tx = Transaction( sender=eoa, to=worker_key, - gas_limit=21_000, + gas_limit=refund_gas_limit, max_fee_per_gas=max_fee_per_gas, max_priority_fee_per_gas=max_priority_fee_per_gas, value=refund_value, @@ -973,35 +944,18 @@ def pre( target=eoa.label, tx_index=idx, ) - try: - logger.info( - f"Sending refund transaction for EOA {eoa}: {refund_tx.hash}" - ) - refund_tx_hash = eth_rpc.send_transaction(refund_tx) - logger.info(f"Refund transaction sent: {refund_tx_hash}") - 
refund_txs.append(refund_tx) - except Exception as e: - eoa_key = eoa.key - logger.error( - f"Error sending refund transaction for EOA {eoa}: {e}." - ) - if eoa_key is not None: - logger.info( - f"Retrieve funds manually from EOA {eoa} " - f"using private key {eoa_key.hex()}." - ) - error_refunds += 1 - continue + refund_txs.append(refund_tx) + if refund_txs: logger.info( - f"Waiting for {len(refund_txs)} refund transactions " - f"({skipped_refunds} skipped due to insufficient balance, " - f"{error_refunds} errored)" + f"Sending {len(refund_txs)} refund transactions " + f"({skipped_refunds} skipped due to insufficient balance)" ) - eth_rpc.wait_for_transactions(refund_txs) + eth_rpc.send_wait_transactions(refund_txs) logger.info(f"All {len(refund_txs)} refund transactions confirmed") else: logger.info( - f"No refund transactions to send ({skipped_refunds} EOAs skipped " - f"due to insufficient balance, {error_refunds} errored)" + f"No refund transactions to send " + f"({skipped_refunds} EOAs skipped " + f"due to insufficient balance)" ) diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/chain_builder_eth_rpc.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/chain_builder_eth_rpc.py index 6702c5d0ea..c23276b064 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/chain_builder_eth_rpc.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/chain_builder_eth_rpc.py @@ -4,181 +4,21 @@ """ import time +from contextlib import AbstractContextManager from pathlib import Path -from typing import Any, Dict, Iterator, List, Sequence +from typing import Any, List from filelock import FileLock -from pydantic import RootModel -from typing_extensions import Self from execution_testing.base_types import Address, Hash, HexNumber from execution_testing.forks import Fork -from execution_testing.rpc import EngineRPC, TransactionProtocol 
+from execution_testing.rpc import EngineRPC from execution_testing.rpc import EthRPC as BaseEthRPC from execution_testing.rpc.rpc_types import ( ForkchoiceState, PayloadAttributes, PayloadStatusEnum, - TransactionByHashResponse, ) -from execution_testing.test_types.trie import keccak256 - - -class HashList(RootModel[List[Hash]]): - """Hash list class.""" - - root: List[Hash] - - def append(self, item: Hash) -> None: - """Append an item to the list.""" - self.root.append(item) - - def clear(self) -> None: - """Clear the list.""" - self.root.clear() - - def remove(self, item: Hash) -> None: - """Remove an item from the list.""" - self.root.remove(item) - - def __contains__(self, item: Hash) -> bool: - """Check if an item is in the list.""" - return item in self.root - - def __len__(self) -> int: - """Get the length of the list.""" - return len(self.root) - - def __iter__(self) -> Iterator[Hash]: # type: ignore - """Iterate over the list.""" - return iter(self.root) - - -class AddressList(RootModel[List[Address]]): - """Address list class.""" - - root: List[Address] - - def append(self, item: Address) -> None: - """Append an item to the list.""" - self.root.append(item) - - def clear(self) -> None: - """Clear the list.""" - self.root.clear() - - def remove(self, item: Address) -> None: - """Remove an item from the list.""" - self.root.remove(item) - - def __contains__(self, item: Address) -> bool: - """Check if an item is in the list.""" - return item in self.root - - def __len__(self) -> int: - """Get the length of the list.""" - return len(self.root) - - def __iter__(self) -> Iterator[Address]: # type: ignore - """Iterate over the list.""" - return iter(self.root) - - -class PendingTxHashes: - """ - A class to manage the pending transaction hashes in a multi-process - environment. - - It uses a lock file to ensure that only one process can access the pending - hashes file at a time. 
- """ - - pending_hashes_file: Path - pending_hashes_lock: Path - pending_tx_hashes: HashList | None - lock: FileLock | None - - def __init__(self, temp_folder: Path): - """Initialize the pending transaction hashes manager.""" - self.pending_hashes_file = temp_folder / "pending_tx_hashes" - self.pending_hashes_lock = temp_folder / "pending_tx_hashes.lock" - self.pending_tx_hashes = None - self.lock = None - - def __enter__(self) -> Self: - """Lock the pending hashes file and load it.""" - assert self.lock is None, "Lock already acquired" - self.lock = FileLock(self.pending_hashes_lock, timeout=-1) - self.lock.acquire() - assert self.pending_tx_hashes is None, ( - "Pending transaction hashes already loaded" - ) - if self.pending_hashes_file.exists(): - with open(self.pending_hashes_file, "r") as f: - self.pending_tx_hashes = HashList.model_validate_json(f.read()) - else: - self.pending_tx_hashes = HashList([]) - return self - - def __exit__( - self, exc_type: object, exc_value: object, traceback: object - ) -> None: - """Flush the pending hashes to the file and release the lock.""" - assert self.lock is not None, "Lock not acquired" - assert self.pending_tx_hashes is not None, ( - "Pending transaction hashes not loaded" - ) - with open(self.pending_hashes_file, "w") as f: - f.write(self.pending_tx_hashes.model_dump_json()) - self.lock.release() - self.lock = None - self.pending_tx_hashes = None - - def append(self, tx_hash: Hash) -> None: - """Add a transaction hash to the pending list.""" - assert self.lock is not None, "Lock not acquired" - assert self.pending_tx_hashes is not None, ( - "Pending transaction hashes not loaded" - ) - self.pending_tx_hashes.append(tx_hash) - - def clear(self) -> None: - """Remove a transaction hash from the pending list.""" - assert self.lock is not None, "Lock not acquired" - assert self.pending_tx_hashes is not None - self.pending_tx_hashes.clear() - - def remove(self, tx_hash: Hash) -> None: - """Remove a transaction hash from the 
pending list.""" - assert self.lock is not None, "Lock not acquired" - assert self.pending_tx_hashes is not None, ( - "Pending transaction hashes not loaded" - ) - self.pending_tx_hashes.remove(tx_hash) - - def __contains__(self, tx_hash: Hash) -> bool: - """Check if a transaction hash is in the pending list.""" - assert self.lock is not None, "Lock not acquired" - assert self.pending_tx_hashes is not None, ( - "Pending transaction hashes not loaded" - ) - return tx_hash in self.pending_tx_hashes - - def __len__(self) -> int: - """Get the number of pending transaction hashes.""" - assert self.lock is not None, "Lock not acquired" - assert self.pending_tx_hashes is not None, ( - "Pending transaction hashes not loaded" - ) - return len(self.pending_tx_hashes) - - def __iter__(self) -> Iterator[Hash]: - """Iterate over the pending transaction hashes.""" - assert self.lock is not None, "Lock not acquired" - assert self.pending_tx_hashes is not None, ( - "Pending transaction hashes not loaded" - ) - return iter(self.pending_tx_hashes) class ChainBuilderEthRPC(BaseEthRPC, namespace="eth"): @@ -190,9 +30,8 @@ class ChainBuilderEthRPC(BaseEthRPC, namespace="eth"): fork: Fork engine_rpc: EngineRPC - transactions_per_block: int get_payload_wait_time: float - pending_tx_hashes: PendingTxHashes + block_building_lock: FileLock def __init__( self, @@ -200,7 +39,6 @@ def __init__( rpc_endpoint: str, fork: Fork, engine_rpc: EngineRPC, - transactions_per_block: int, session_temp_folder: Path, get_payload_wait_time: float, initial_forkchoice_update_retries: int = 5, @@ -215,17 +53,17 @@ def __init__( ) self.fork = fork self.engine_rpc = engine_rpc - self.transactions_per_block = transactions_per_block - self.pending_tx_hashes = PendingTxHashes(session_temp_folder) + self.block_building_lock = FileLock( + session_temp_folder / "chain_builder_fcu.lock" + ) self.get_payload_wait_time = get_payload_wait_time # Send initial forkchoice updated only if we are the first worker base_name = 
"eth_rpc_forkchoice_updated" base_file = session_temp_folder / base_name base_error_file = session_temp_folder / f"{base_name}.err" - base_lock_file = session_temp_folder / f"{base_name}.lock" - with FileLock(base_lock_file): + with self.block_building_lock: if base_error_file.exists(): raise Exception( "Error occurred during initial forkchoice_updated" @@ -262,6 +100,17 @@ def __init__( base_error_file.unlink() # Success base_file.touch() + @property + def transaction_polling_context(self) -> AbstractContextManager: + """ + Return the block building lock as context manager so it's acquired + during transaction polling. + + Reasoning is that the lock gets acquired once while all processes + wait for transactions, only one of them produces a new block. + """ + return self.block_building_lock + def generate_block(self: "ChainBuilderEthRPC") -> None: """Generate a block using the Engine API.""" # Get the head block hash @@ -358,162 +207,12 @@ def generate_block(self: "ChainBuilderEthRPC") -> None: assert response.payload_status.status == PayloadStatusEnum.VALID, ( "Payload was invalid" ) - for tx in new_payload.execution_payload.transactions: - tx_hash = Hash(keccak256(tx)) - if tx_hash in self.pending_tx_hashes: - self.pending_tx_hashes.remove(tx_hash) - def send_transaction(self, transaction: TransactionProtocol) -> Hash: - """`eth_sendRawTransaction`: Send a transaction to the client.""" - returned_hash = super().send_transaction(transaction) - with self.pending_tx_hashes: - self.pending_tx_hashes.append(transaction.hash) - if len(self.pending_tx_hashes) >= self.transactions_per_block: - self.generate_block() - return returned_hash - - def wait_for_transaction( - self, transaction: TransactionProtocol - ) -> TransactionByHashResponse: + def pending_transactions_handler(self) -> None: """ - Wait for a specific transaction to be included in a block. 
- - Waits for a specific transaction to be included in a block by polling - `eth_getTransactionByHash` until it is confirmed or a timeout occurs. - - Args: - transaction: The transaction to track. + Called inside the transaction inclusion wait-loop. - Returns: - The transaction details after it is included in a block. - - """ - return self.wait_for_transactions([transaction])[0] - - def wait_for_transactions( - self, transactions: Sequence[TransactionProtocol] - ) -> List[TransactionByHashResponse]: + This method triggers the block building process if it's still waiting + for transactions to be included. """ - Wait for all transactions in the provided list to be included in a - block. - - Waits for all transactions in the provided list to be included in a - block by polling `eth_getTransactionByHash` until they are confirmed or - a timeout occurs. - - Args: - transactions: A list of transactions to track. - - Returns: - A list of transaction details after they are included in a block. - - Raises: - Exception: If one or more transactions are not included in a block - within the timeout period.
- - """ - tx_hashes = [tx.hash for tx in transactions] - responses: List[TransactionByHashResponse] = [] - pending_responses: Dict[Hash, TransactionByHashResponse] = {} - - start_time = time.time() - pending_transactions_handler = PendingTransactionHandler(self) - while True: - tx_id = 0 - pending_responses = {} - while tx_id < len(tx_hashes): - tx_hash = tx_hashes[tx_id] - tx = self.get_transaction_by_hash(tx_hash) - assert tx is not None, f"Transaction {tx_hash} not found" - if tx.block_number is not None: - responses.append(tx) - tx_hashes.pop(tx_id) - else: - pending_responses[tx_hash] = tx - tx_id += 1 - - if not tx_hashes: - return responses - - pending_transactions_handler.handle() - - if (time.time() - start_time) > self.transaction_wait_timeout: - break - time.sleep(0.1) - - missing_txs_strings = [ - f"{tx.hash} ({tx.model_dump_json()})" - for tx in transactions - if tx.hash in tx_hashes - ] - - pending_tx_responses_string = "\n".join( - [ - f"{tx_hash}: {tx.model_dump_json()}" - for tx_hash, tx in pending_responses.items() - ] - ) - missing_str = ", ".join(missing_txs_strings) - raise Exception( - f"Transactions {missing_str} were not included in a block " - f"within {self.transaction_wait_timeout} seconds:\n" - f"{pending_tx_responses_string}" - ) - - -class PendingTransactionHandler: - """ - Manages block generation based on the number of pending transactions or a - block generation interval. - - Attributes: - block_generation_interval: The number of iterations after which a block - is generated if no new transactions are added (default: 10). 
- - """ - - chain_builder_eth_rpc: ChainBuilderEthRPC - block_generation_interval: int - last_pending_tx_hashes_count: int | None = None - i: int = 0 - - def __init__( - self, - chain_builder_eth_rpc: ChainBuilderEthRPC, - block_generation_interval: int = 10, - ): - """Initialize the pending transaction handler.""" - self.chain_builder_eth_rpc = chain_builder_eth_rpc - self.block_generation_interval = block_generation_interval - - def handle(self) -> None: - """ - Handle pending transactions and generate blocks if necessary. - - If the number of pending transactions reaches the limit, a block is - generated. - - If no new transactions have been added to the pending list and the - block generation interval has been reached, a block is generated to - avoid potential deadlock. - """ - with self.chain_builder_eth_rpc.pending_tx_hashes: - if ( - len(self.chain_builder_eth_rpc.pending_tx_hashes) - >= self.chain_builder_eth_rpc.transactions_per_block - ): - self.chain_builder_eth_rpc.generate_block() - else: - if ( - self.last_pending_tx_hashes_count is not None - and len(self.chain_builder_eth_rpc.pending_tx_hashes) - == self.last_pending_tx_hashes_count - and self.i % self.block_generation_interval == 0 - ): - # If no new transactions have been added to the pending - # list, generate a block to avoid potential deadlock. 
- self.chain_builder_eth_rpc.generate_block() - self.last_pending_tx_hashes_count = len( - self.chain_builder_eth_rpc.pending_tx_hashes - ) - self.i += 1 + self.generate_block() diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/hive.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/hive.py index 10c23aa78b..218511dd50 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/hive.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/hive.py @@ -24,6 +24,8 @@ from execution_testing.forks import Fork from execution_testing.rpc import EngineRPC, EthRPC from execution_testing.test_types import ( + DETERMINISTIC_FACTORY_ADDRESS, + DETERMINISTIC_FACTORY_BYTECODE, EOA, Alloc, ChainConfig, @@ -104,7 +106,14 @@ def base_pre( seed_key_initial_balance = request.config.getoption( "seed_key_initial_balance" ) - return Alloc({seed_key: Account(balance=seed_key_initial_balance)}) + return Alloc( + { + seed_key: Account(balance=seed_key_initial_balance), + DETERMINISTIC_FACTORY_ADDRESS: Account( + nonce=1, code=DETERMINISTIC_FACTORY_BYTECODE + ), + } + ) @pytest.fixture(scope="session") @@ -387,7 +396,6 @@ def eth_rpc( client: Client, engine_rpc: EngineRPC, session_fork: Fork, - transactions_per_block: int, session_temp_folder: Path, max_transactions_per_batch: int | None, ) -> EthRPC: @@ -398,7 +406,6 @@ def eth_rpc( rpc_endpoint=f"http://{client.ip}:8545", fork=session_fork, engine_rpc=engine_rpc, - transactions_per_block=transactions_per_block, session_temp_folder=session_temp_folder, get_payload_wait_time=get_payload_wait_time, transaction_wait_timeout=tx_wait_timeout, diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/remote.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/remote.py index 78609d4fe9..b2045e8978 100644 --- 
a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/remote.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/remote.py @@ -177,7 +177,6 @@ def eth_rpc( rpc_endpoint: str, engine_rpc: EngineRPC | None, session_fork: Fork, - transactions_per_block: int, session_temp_folder: Path, max_transactions_per_batch: int | None, ) -> EthRPC: @@ -194,7 +193,6 @@ def eth_rpc( rpc_endpoint=rpc_endpoint, fork=session_fork, engine_rpc=engine_rpc, - transactions_per_block=transactions_per_block, session_temp_folder=session_temp_folder, get_payload_wait_time=get_payload_wait_time, transaction_wait_timeout=tx_wait_timeout, diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/remote_seed_sender.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/remote_seed_sender.py index 5a0480f08c..fa91b082a5 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/remote_seed_sender.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/rpc/remote_seed_sender.py @@ -54,10 +54,11 @@ def seed_key( ) # check the nonce through the rpc client seed_key = EOA(key=rpc_seed_key) - seed_key.nonce = Number(eth_rpc.get_transaction_count(seed_key)) + seed_account = eth_rpc.get_account(seed_key, skip_code=True) + seed_key.nonce = Number(seed_account.nonce) # Record the start balance of the worker key - start_balance = eth_rpc.get_balance(seed_key) + start_balance = seed_account.balance yield seed_key diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/sender.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/sender.py index fc403f3464..1640a01925 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/sender.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/execute/sender.py @@ -1,16 +1,22 @@ 
"""Sender mutex class that allows sending transactions one at a time.""" +import json +import time from pathlib import Path -from typing import Generator, Iterator +from typing import Generator, Iterator, List import pytest from filelock import FileLock from pytest_metadata.plugin import metadata_key -from execution_testing.base_types import Number, Wei +from execution_testing.base_types import Address, Number, Wei from execution_testing.logging import get_logger from execution_testing.rpc import EthRPC -from execution_testing.test_types import EOA, Transaction +from execution_testing.test_types import ( + EOA, + Transaction, + TransactionTestMetadata, +) logger = get_logger(__name__) @@ -55,6 +61,17 @@ def pytest_addoption(parser: pytest.Parser) -> None: ), ) + sender_group.addoption( + "--worker-funding-timeout", + action="store", + dest="worker_funding_timeout", + type=int, + default=120, + help=( + "Timeout in seconds for workers waiting to be funded. Default=120" + ), + ) + @pytest.fixture(scope="session") def sender_funding_transactions_gas_price( @@ -256,55 +273,68 @@ def session_worker_key( assert worker_key_funding_amount is not None, ( "`worker_key_funding_amount` is None" ) - # For the seed sender we do need to keep track of the nonce because it is - # shared among different processes, and there might not be a new block - # produced between the transactions. 
- seed_sender_nonce_file_name = "seed_sender_nonce" - seed_sender_lock_file_name = f"{seed_sender_nonce_file_name}.lock" - seed_sender_nonce_file = session_temp_folder / seed_sender_nonce_file_name - seed_sender_lock_file = session_temp_folder / seed_sender_lock_file_name + worker_keys_file_name = "worker_keys.json" + worker_keys_lock_file_name = f"{worker_keys_file_name}.lock" + worker_keys_file = session_temp_folder / worker_keys_file_name + worker_keys_lock_file = session_temp_folder / worker_keys_lock_file_name + worker_funded_file = session_temp_folder / "worker_keys_funded" worker_key = next(eoa_iterator) logger.info(f"Allocated worker key: {worker_key}") - # Prepare funding transaction for this specific worker. - # Each worker locks the next nonce by using a file lock to coordinate. - with FileLock(seed_sender_lock_file): - if seed_sender_nonce_file.exists(): - with seed_sender_nonce_file.open("r") as f: - seed_key.nonce = Number(f.read()) - logger.debug( - f"Loaded seed key nonce from file: {seed_key.nonce}" - ) + # Each worker registers its key. The last worker to register builds + # and sends all funding transactions as a single batch. 
+ is_last_worker = False + with FileLock(worker_keys_lock_file): + if worker_keys_file.exists(): + registered_keys = json.loads(worker_keys_file.read_text()) else: - logger.debug( - "No existing seed key nonce file, using current nonce" - ) - fund_tx = Transaction( - sender=seed_key, - to=worker_key, - gas_limit=sender_fund_refund_gas_limit, - gas_price=sender_funding_transactions_gas_price, - value=worker_key_funding_amount, - ).with_signature_and_sender() - fund_eth = worker_key_funding_amount / 10**18 - logger.info( - f"Preparing funding transaction: {fund_eth:.18f} ETH " - f"from {seed_key} to {worker_key} (nonce={seed_key.nonce})" - ) - if not dry_run: - eth_rpc.send_transaction(fund_tx) - logger.info(f"Sent funding transaction: {fund_tx.hash}") - else: - logger.info("Dry run: skipping funding transaction send") - with seed_sender_nonce_file.open("w") as f: - f.write(str(seed_key.nonce)) - if not dry_run: + registered_keys = [] + registered_keys.append(str(worker_key)) + worker_keys_file.write_text(json.dumps(registered_keys)) + + if len(registered_keys) == worker_count: + is_last_worker = True + fund_txs: List[Transaction] = [] + for i, key_str in enumerate(registered_keys): + fund_tx = Transaction( + sender=seed_key, + to=Address(key_str), + gas_limit=sender_fund_refund_gas_limit, + gas_price=sender_funding_transactions_gas_price, + value=worker_key_funding_amount, + metadata=TransactionTestMetadata( + test_id="global", + phase="setup", + action="fund_eoa", + target="Session Worker Key", + tx_index=i, + ), + ).with_signature_and_sender() + fund_txs.append(fund_tx) + + if not dry_run: + eth_rpc.send_wait_transactions(fund_txs) + logger.info("All worker funding transactions confirmed") + else: + logger.info("Dry run: skipping funding transaction send") + worker_funded_file.touch() + + if not is_last_worker: + funding_timeout = request.config.option.worker_funding_timeout logger.info( - f"Waiting for funding transaction to be mined: {fund_tx.hash}" + "Waiting 
for all workers to be funded " + f"(timeout={funding_timeout}s)..." ) - eth_rpc.wait_for_transaction(fund_tx) - logger.info(f"Funding transaction confirmed: {fund_tx.hash}") + deadline = time.monotonic() + funding_timeout + while not worker_funded_file.exists(): + if time.monotonic() > deadline: + raise TimeoutError( + f"Timed out after {funding_timeout}s waiting " + "for all workers to be funded" + ) + time.sleep(0.1) + logger.info("All workers funded, proceeding") # Run all tests for this worker. yield worker_key @@ -313,8 +343,8 @@ def session_worker_key( logger.info( f"All tests completed for worker {worker_key}, preparing refund" ) - remaining_balance = eth_rpc.get_balance(worker_key) - worker_key.nonce = Number(eth_rpc.get_transaction_count(worker_key)) + worker_account = eth_rpc.get_account(worker_key, skip_code=True) + remaining_balance = worker_account.balance used_balance = worker_key_funding_amount - remaining_balance logger.info( f"Worker {worker_key} used balance: {used_balance / 10**18:.18f} ETH " @@ -346,7 +376,7 @@ def session_worker_key( # Update the nonce of the sender in case one of the pre-alloc transactions # failed - worker_key.nonce = Number(eth_rpc.get_transaction_count(worker_key)) + worker_key.nonce = worker_account.nonce refund_value = remaining_balance - tx_cost - 1 logger.info( f"Preparing refund transaction: {refund_value / 10**18:.18f} ETH " @@ -359,12 +389,19 @@ def session_worker_key( gas_limit=refund_gas_limit, gas_price=refund_gas_price, value=refund_value, + metadata=TransactionTestMetadata( + test_id="global", + phase="cleanup", + action="fund_eoa", + target="Session Worker Key Refund", + tx_index=worker_key.nonce, + ), ).with_signature_and_sender() logger.info( f"Sending and waiting for refund transaction: {refund_tx.hash}" ) - eth_rpc.send_wait_transaction(refund_tx) + eth_rpc.send_wait_transactions([refund_tx]) logger.info(f"Refund transaction confirmed: {refund_tx.hash}") @@ -374,11 +411,10 @@ def worker_key( ) -> 
Generator[EOA, None, None]: """Prepare the worker key for the current test.""" logger.debug(f"Preparing worker key {session_worker_key} for test") - rpc_nonce = Number( - eth_rpc.get_transaction_count( - session_worker_key, block_number="pending" - ) + session_worker_account = eth_rpc.get_account( + session_worker_key, block_number="pending", skip_code=True ) + rpc_nonce = Number(session_worker_account.nonce) if rpc_nonce != session_worker_key.nonce: wk_nonce = session_worker_key.nonce logger.info(f"Worker key nonce mismatch: {wk_nonce} != {rpc_nonce}") @@ -386,7 +422,7 @@ def worker_key( session_worker_key.nonce = rpc_nonce # Record the start balance of the worker key - worker_key_start_balance = eth_rpc.get_balance(session_worker_key) + worker_key_start_balance = session_worker_account.balance start_eth = worker_key_start_balance / 10**18 logger.debug(f"Worker key start balance: {start_eth:.18f} ETH") diff --git a/packages/testing/src/execution_testing/execution/transaction_post.py b/packages/testing/src/execution_testing/execution/transaction_post.py index ae6032119a..d034e6e523 100644 --- a/packages/testing/src/execution_testing/execution/transaction_post.py +++ b/packages/testing/src/execution_testing/execution/transaction_post.py @@ -113,11 +113,17 @@ def execute( signed_txs.append(tx) current_block_tx_hashes: List[Hash] = [] if any(tx.error is not None for tx in signed_txs): + tx_queue: List[Transaction] = [] for transaction in signed_txs: if transaction.error is None: - eth_rpc.send_wait_transaction(transaction) - current_block_tx_hashes.append(transaction.hash) + tx_queue.append(transaction) else: + if tx_queue: + eth_rpc.send_wait_transactions(tx_queue) + current_block_tx_hashes.extend( + tx.hash for tx in tx_queue + ) + tx_queue = [] logger.info( f"Sending transaction expecting rejection " f"(expected error: {transaction.error})..." 
@@ -130,6 +136,9 @@ def execute( "Transaction rejected as expected: " f"{exc_info.value}" ) + if tx_queue: + eth_rpc.send_wait_transactions(tx_queue) + current_block_tx_hashes.extend(tx.hash for tx in tx_queue) else: # Send transactions (batching is handled by eth_rpc internally) eth_rpc.send_wait_transactions(signed_txs) @@ -149,45 +158,24 @@ def execute( gas_used = int(receipt["gasUsed"], 16) benchmark_gas_used += gas_used - for address, account in self.post.root.items(): - balance = eth_rpc.get_balance(address) - code = eth_rpc.get_code(address) - nonce = eth_rpc.get_transaction_count(address) - if account is None: - assert balance == 0, ( - f"Balance of {address} is {balance}, expected 0." + actual_alloc = eth_rpc.get_alloc(self.post) + for address, expected_account in self.post.root.items(): + actual_account = actual_alloc.root[address] + assert actual_account is not None + if expected_account is None: + assert actual_account.balance == 0, ( + f"Balance of {address} is " + f"{actual_account.balance}, expected 0." ) - assert code == b"", ( - f"Code of {address} is {code}, expected 0x." + assert actual_account.code == b"", ( + f"Code of {address} is {actual_account.code}, expected 0x." ) - assert nonce == 0, ( - f"Nonce of {address} is {nonce}, expected 0." + assert actual_account.nonce == 0, ( + f"Nonce of {address} is " + f"{actual_account.nonce}, expected 0." ) else: - if "balance" in account.model_fields_set: - assert balance == account.balance, ( - f"Balance of {address} is {balance}, " - f"expected {account.balance}." - ) - if "code" in account.model_fields_set: - assert code == account.code, ( - f"Code of {address} is {code}, " - f"expected {account.code}." - ) - if "nonce" in account.model_fields_set: - assert nonce == account.nonce, ( - f"Nonce of {address} is {nonce}, " - f"expected {account.nonce}." 
- ) - if "storage" in account.model_fields_set: - for key, value in account.storage.items(): - storage_value = eth_rpc.get_storage_at( - address, Hash(key) - ) - assert storage_value == value, ( - f"Storage value at {key} of {address} is " - f"{storage_value}, expected {value}." - ) + expected_account.check_alloc(address, actual_account) return ExecuteResult( benchmark_gas_used=benchmark_gas_used, diff --git a/packages/testing/src/execution_testing/rpc/__init__.py b/packages/testing/src/execution_testing/rpc/__init__.py index 603cf0c282..6e0eef7680 100644 --- a/packages/testing/src/execution_testing/rpc/__init__.py +++ b/packages/testing/src/execution_testing/rpc/__init__.py @@ -20,6 +20,9 @@ EthConfigResponse, ForkConfig, ForkConfigBlobSchedule, + JSONRPCRequest, + JSONRPCResponse, + RPCCall, TransactionProtocol, ) @@ -36,7 +39,10 @@ "ForkConfig", "ForkConfigBlobSchedule", "ForkchoiceUpdateTimeoutError", + "JSONRPCRequest", + "JSONRPCResponse", "NetRPC", + "RPCCall", "PeerConnectionTimeoutError", "SendTransactionExceptionError", "TransactionProtocol", diff --git a/packages/testing/src/execution_testing/rpc/rpc.py b/packages/testing/src/execution_testing/rpc/rpc.py index 65e4643f7b..7f1477f1af 100644 --- a/packages/testing/src/execution_testing/rpc/rpc.py +++ b/packages/testing/src/execution_testing/rpc/rpc.py @@ -5,6 +5,7 @@ import logging import os import time +from contextlib import AbstractContextManager, nullcontext from itertools import count from pprint import pprint from typing import Any, Callable, ClassVar, Dict, List, Literal, Sequence @@ -24,7 +25,14 @@ wait_fixed as wait_fixed_tenacity, ) -from execution_testing.base_types import Address, Bytes, Hash, to_json +from execution_testing.base_types import ( + Account, + Address, + Alloc, + Bytes, + Hash, + to_json, +) from execution_testing.logging import ( get_logger, ) @@ -35,10 +43,12 @@ ForkchoiceUpdateResponse, GetBlobsResponse, GetPayloadResponse, - JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, 
PayloadAttributes, PayloadStatus, PayloadStatusEnum, + RPCCall, TransactionByHashResponse, TransactionProtocol, ) @@ -192,7 +202,7 @@ def __init_subclass__(cls, namespace: str | None = None) -> None: def _make_request( self, url: str, - json_payload: dict[str, Any], + json_payload: dict[str, Any] | list[dict[str, Any]], headers: dict[str, str], timeout: int | None, ) -> requests.Response: @@ -212,59 +222,108 @@ def _make_request( url, json=json_payload, headers=headers, timeout=timeout ) + def _build_json_rpc_request( + self, + call: RPCCall, + ) -> JSONRPCRequest: + """Build a JSON-RPC request object with namespace prefix.""" + assert self.namespace, "RPC namespace not set" + + next_request_id_counter = next(self.request_id_counter) + request_id = call.request_id + if request_id is None: + request_id = next_request_id_counter + + return JSONRPCRequest( + method=f"{self.namespace}_{call.method}", + params=call.params, + id=request_id, + ) + def post_request( self, *, - method: str, - params: List[Any] | None = None, + request: RPCCall, extra_headers: Dict[str, str] | None = None, - request_id: int | str | None = None, timeout: int | None = None, - ) -> Any: + ) -> JSONRPCResponse: """ Send JSON-RPC POST request to the client RPC server at port defined in the url. """ if extra_headers is None: extra_headers = {} - if params is None: - params = [] - assert self.namespace, "RPC namespace not set" + json_rpc_request = self._build_json_rpc_request(request) + base_header = { + "Content-Type": "application/json", + } + headers = base_header | extra_headers - next_request_id_counter = next(self.request_id_counter) - if request_id is None: - request_id = next_request_id_counter + logger.debug( + f"Sending RPC request to {self.url}, " + f"method={json_rpc_request.method}, timeout={timeout}..." 
+ ) - payload = { - "jsonrpc": "2.0", - "method": f"{self.namespace}_{method}", - "params": params, - "id": request_id, - } + response = self._make_request( + self.url, json_rpc_request.model_dump(), headers, timeout + ) + response.raise_for_status() + + return JSONRPCResponse.model_validate(response.json()) + + def post_batch_request( + self, + *, + calls: Sequence[RPCCall], + extra_headers: Dict[str, str] | None = None, + timeout: int | None = None, + ) -> List[JSONRPCResponse]: + """ + Send a JSON-RPC batch POST request to the client RPC server at port + defined in the url. + """ + if extra_headers is None: + extra_headers = {} + + json_rpc_requests = [ + self._build_json_rpc_request(call) for call in calls + ] + payload = [r.model_dump() for r in json_rpc_requests] base_header = { "Content-Type": "application/json", } headers = base_header | extra_headers logger.debug( - f"Sending RPC request to {self.url}, " - f"method={self.namespace}_{method}, timeout={timeout}..." + f"Sending batch RPC request to {self.url}, " + f"{len(json_rpc_requests)} calls, timeout={timeout}..." 
) response = self._make_request(self.url, payload, headers, timeout) response.raise_for_status() response_json = response.json() - if "error" in response_json: - raise JSONRPCError(**response_json["error"]) - - assert "result" in response_json, ( - "RPC response didn't contain a result field" + assert isinstance(response_json, list), ( + "Batch RPC response is not a list" ) - result = response_json["result"] - logger.info(f"RPC Result: {result}") - return result + + response_map: dict[int | str, JSONRPCResponse] = { + r.id: r + for r in [ + JSONRPCResponse.model_validate(item) for item in response_json + ] + } + + results = [] + for json_rpc_request in json_rpc_requests: + assert json_rpc_request.id in response_map, ( + f"Missing response for request ID {json_rpc_request.id}" + ) + results.append(response_map[json_rpc_request.id]) + + logger.info(f"Batch RPC: {len(results)} responses received") + return results class EthRPC(BaseRPC): @@ -351,7 +410,9 @@ def config(self, timeout: int | None = None) -> EthConfigResponse | None: """ try: logger.info("Requesting eth_config..") - response = self.post_request(method="config", timeout=timeout) + response = self.post_request( + request=RPCCall(method="config"), timeout=timeout + ).result_or_raise() if response is None: logger.warning("eth_config request: failed to get response") return None @@ -370,7 +431,9 @@ def config(self, timeout: int | None = None) -> EthConfigResponse | None: def chain_id(self) -> int: """`eth_chainId`: Returns the current chain id.""" logger.info("Requesting chainid of provided RPC endpoint..") - response = self.post_request(method="chainId", timeout=10) + response = self.post_request( + request=RPCCall(method="chainId"), timeout=10 + ).result_or_raise() return int(response, 16) def get_block_by_number( @@ -387,8 +450,9 @@ def get_block_by_number( ) logger.info(f"Requesting info about block {block}..") params = [block, full_txs] - response = self.post_request(method="getBlockByNumber", 
params=params) - return response + return self.post_request( + request=RPCCall(method="getBlockByNumber", params=params) + ).result_or_raise() def get_block_by_hash( self, block_hash: Hash, full_txs: bool = True @@ -396,8 +460,9 @@ def get_block_by_hash( """`eth_getBlockByHash`: Returns information about a block by hash.""" logger.info(f"Requesting block info of {block_hash}..") params = [f"{block_hash}", full_txs] - response = self.post_request(method="getBlockByHash", params=params) - return response + return self.post_request( + request=RPCCall(method="getBlockByHash", params=params) + ).result_or_raise() def get_block_by_hash_with_retry( self, @@ -470,7 +535,9 @@ def get_balance( ) logger.info(f"Requesting balance of {address} at block {block}") params = [f"{address}", block] - response = self.post_request(method="getBalance", params=params) + response = self.post_request( + request=RPCCall(method="getBalance", params=params) + ).result_or_raise() return int(response, 16) def get_code( @@ -484,7 +551,9 @@ def get_code( ) logger.info(f"Requesting code of {address} at block {block}") params = [f"{address}", block] - response = self.post_request(method="getCode", params=params) + response = self.post_request( + request=RPCCall(method="getCode", params=params) + ).result_or_raise() return Bytes(response) def get_transaction_count( @@ -502,8 +571,8 @@ def get_transaction_count( logger.info(f"Requesting nonce of {address}") params = [f"{address}", block] response = self.post_request( - method="getTransactionCount", params=params - ) + request=RPCCall(method="getTransactionCount", params=params) + ).result_or_raise() return int(response, 16) def get_transaction_by_hash( @@ -513,8 +582,11 @@ def get_transaction_by_hash( try: logger.info(f"Requesting tx details of {transaction_hash}") response = self.post_request( - method="getTransactionByHash", params=[f"{transaction_hash}"] - ) + request=RPCCall( + method="getTransactionByHash", + params=[f"{transaction_hash}"], + ) 
+ ).result_or_raise() if response is None: return None return TransactionByHashResponse.model_validate( @@ -524,6 +596,39 @@ def get_transaction_by_hash( pprint(e.errors()) raise e + def get_transactions_by_hash( + self, transaction_hashes: Sequence[Hash] + ) -> List[TransactionByHashResponse | None]: + """ + Batch `eth_getTransactionByHash` for multiple hashes. + + Return a list of responses in the same order as the input + hashes. Entries are `None` if the transaction was not found. + """ + if not transaction_hashes: + return [] + calls = [ + RPCCall( + method="getTransactionByHash", + params=[f"{tx_hash}"], + ) + for tx_hash in transaction_hashes + ] + responses = self.post_batch_request(calls=calls) + results: List[TransactionByHashResponse | None] = [] + for response in responses: + result = response.result_or_raise() + if result is None: + results.append(None) + else: + results.append( + TransactionByHashResponse.model_validate( + result, + context=self.response_validation_context, + ) + ) + return results + def get_transaction_receipt( self, transaction_hash: Hash ) -> dict[str, Any] | None: @@ -534,10 +639,12 @@ def get_transaction_receipt( in benchmark tests. 
""" logger.info(f"Requesting tx receipt of {transaction_hash}") - response = self.post_request( - method="getTransactionReceipt", params=[f"{transaction_hash}"] - ) - return response + return self.post_request( + request=RPCCall( + method="getTransactionReceipt", + params=[f"{transaction_hash}"], + ) + ).result_or_raise() def get_storage_at( self, @@ -559,7 +666,9 @@ def get_storage_at( f"of contract {address}" ) params = [f"{address}", f"{position}", block] - response = self.post_request(method="getStorageAt", params=params) + response = self.post_request( + request=RPCCall(method="getStorageAt", params=params) + ).result_or_raise() return Hash(response) def _get_gas_information( @@ -572,7 +681,9 @@ def _get_gas_information( time.time() - self._gas_information_cache_timestamp[method] > self.gas_information_stale_seconds ): - response = self.post_request(method=method) + response = self.post_request( + request=RPCCall(method=method) + ).result_or_raise() logger.info(f"Requesting stale {method}") self._gas_information_cache[method] = int(response, 16) self._gas_information_cache_timestamp[method] = time.time() @@ -602,10 +713,12 @@ def send_raw_transaction( try: logger.info("Sending raw tx..") response = self.post_request( - method="sendRawTransaction", - params=[transaction_rlp.hex()], - request_id=request_id, - ) + request=RPCCall( + method="sendRawTransaction", + params=[transaction_rlp.hex()], + request_id=request_id, + ) + ).result_or_raise() result_hash = Hash(response) assert result_hash is not None return result_hash @@ -623,10 +736,12 @@ def send_transaction(self, transaction: TransactionProtocol) -> Hash: try: logger.info("Sending tx..") response = self.post_request( - method="sendRawTransaction", - params=[transaction.rlp().hex()], - request_id=transaction.metadata_string(), - ) + request=RPCCall( + method="sendRawTransaction", + params=[transaction.rlp().hex()], + request_id=transaction.metadata_string(), + ) + ).result_or_raise() result_hash = 
Hash(response) assert result_hash == transaction.hash assert result_hash is not None @@ -638,26 +753,198 @@ def send_transactions( self, transactions: Sequence[TransactionProtocol] ) -> List[Hash]: """ - Use `eth_sendRawTransaction` to send a list of transactions to the + Use `eth_sendRawTransaction` to send a batch of transactions to the client. """ - return [self.send_transaction(tx) for tx in transactions] + if not transactions: + return [] - def storage_at_keys( + calls = [ + RPCCall( + method="sendRawTransaction", + params=[tx.rlp().hex()], + request_id=tx.metadata_string(), + ) + for tx in transactions + ] + responses = self.post_batch_request(calls=calls) + + results: List[Hash] = [] + for tx, response in zip(transactions, responses, strict=True): + try: + result_hash = Hash(response.result_or_raise()) + assert result_hash == tx.hash + assert result_hash is not None + results.append(tx.hash) + except Exception as e: + raise SendTransactionExceptionError(str(e), tx=tx) from e + return results + + def _build_get_account_calls( self, - account: Address, - keys: List[Hash], + address: Address, + account: Account | None, + block: str, + skip_code: bool = False, + ) -> tuple[List[RPCCall], List[tuple[str, Any]]]: + """Build the RPC calls needed to fetch an account's state.""" + calls: List[RPCCall] = [] + # (field_name, storage_key) + call_info: List[tuple[str, Any]] = [] + + calls.append( + RPCCall( + method="getBalance", + params=[f"{address}", block], + ) + ) + call_info.append(("balance", None)) + if not skip_code: + calls.append( + RPCCall( + method="getCode", + params=[f"{address}", block], + ) + ) + call_info.append(("code", None)) + calls.append( + RPCCall( + method="getTransactionCount", + params=[f"{address}", block], + ) + ) + call_info.append(("nonce", None)) + + if account is not None and "storage" in account.model_fields_set: + for key in account.storage.root: + calls.append( + RPCCall( + method="getStorageAt", + params=[ + f"{address}", + 
f"{Hash(key)}", + block, + ], + ) + ) + call_info.append(("storage", key)) + + return calls, call_info + + @staticmethod + def _parse_account_responses( + call_info: List[tuple[str, Any]], + responses: List[JSONRPCResponse], + ) -> Account: + """Parse RPC responses into an Account.""" + data: Dict[str, Any] = {} + for (field, key), response in zip(call_info, responses, strict=True): + result = response.result_or_raise() + if field == "balance": + data["balance"] = int(result, 16) + elif field == "code": + data["code"] = Bytes(result) + elif field == "nonce": + data["nonce"] = int(result, 16) + elif field == "storage": + if "storage" not in data: + data["storage"] = {} + data["storage"][key] = Hash(result) + return Account(**data) + + def get_account( + self, + address: Address, + account: Account | None = None, block_number: BlockNumberType = "latest", - ) -> Dict[Hash, Hash]: + skip_code: bool = False, + ) -> Account: """ - Retrieve the storage values for the specified keys at a given address - and block number. + Fetch account state from the chain for a single address using + a batch RPC request. + + If `account` is provided, its storage keys are also fetched. + If `skip_code` is True, the code fetch is omitted. """ - results: Dict[Hash, Hash] = {} - for key in keys: - storage_value = self.get_storage_at(account, key, block_number) - results[key] = storage_value - return results + block = ( + hex(block_number) + if isinstance(block_number, int) + else block_number + ) + calls, call_info = self._build_get_account_calls( + address, account, block, skip_code=skip_code + ) + responses = self.post_batch_request(calls=calls) + return self._parse_account_responses(call_info, responses) + + def get_alloc( + self, + alloc: Alloc, + block_number: BlockNumberType = "latest", + skip_code: bool = False, + ) -> Alloc: + """ + Fetch account state from the chain for all addresses in the + given alloc using a batch RPC request. 
+ + If `skip_code` is True, the code fetch is omitted for all + accounts. + """ + if not alloc.root: + return Alloc() + + block = ( + hex(block_number) + if isinstance(block_number, int) + else block_number + ) + + all_calls: List[RPCCall] = [] + # (address, per-account call_info list, call count) + address_info: List[tuple[Address, List[tuple[str, Any]]]] = [] + + for address, account in alloc.root.items(): + calls, call_info = self._build_get_account_calls( + address, account, block, skip_code=skip_code + ) + all_calls.extend(calls) + address_info.append((address, call_info)) + + responses = self.post_batch_request(calls=all_calls) + + result_alloc: Dict[Address, Account | None] = {} + offset = 0 + for address, call_info in address_info: + n = len(call_info) + result_alloc[address] = self._parse_account_responses( + call_info, responses[offset : offset + n] + ) + offset += n + + return Alloc(root=result_alloc) + + @property + def transaction_polling_context(self) -> AbstractContextManager: + """ + Return a context manager acquired during transaction polling. + + By default a no-op. Subclasses can override to synchronize + transaction querying with block building. + """ + return nullcontext() + + def pending_transactions_handler(self) -> None: + """ + Called inside the transaction_polling_context context during the + transaction inclusion wait-loop. + + Useful for subclasses to override to introduce logic to perform + between transaction waits, such as triggering the block building + process. + + By default it only waits the `poll_interval`. 
+ """ + time.sleep(self.poll_interval) def wait_for_transaction( self, transaction: TransactionProtocol @@ -670,56 +957,76 @@ def wait_for_transaction( start_time = time.time() while True: logger.info(f"Waiting for inclusion of tx {tx_hash} in a block..") - tx = self.get_transaction_by_hash(tx_hash) - if tx is not None and tx.block_number is not None: - return tx - if (time.time() - start_time) > self.transaction_wait_timeout: - break - time.sleep(self.poll_interval) + with self.transaction_polling_context: + tx = self.get_transaction_by_hash(tx_hash) + if tx is not None and tx.block_number is not None: + return tx + if (time.time() - start_time) > self.transaction_wait_timeout: + break + self.pending_transactions_handler() raise Exception( f"Transaction {tx_hash} ({transaction.model_dump_json()}) " - f"not included in a block after {self.transaction_wait_timeout} " - "seconds" + f"not included in a block after " + f"{self.transaction_wait_timeout} seconds" ) def wait_for_transactions( self, transactions: Sequence[TransactionProtocol] ) -> List[TransactionByHashResponse]: """ - Use `eth_getTransactionByHash` to wait until all transactions in list - are included in a block. + Use `eth_getTransactionByHash` batch requests to wait until all + transactions in list are included in a block. 
""" - tx_hashes = [tx.hash for tx in transactions] - responses: List[TransactionByHashResponse] = [] + if not transactions: + return [] + + pending: dict[Hash, TransactionProtocol] = { + tx.hash: tx for tx in transactions + } + found: dict[Hash, TransactionByHashResponse] = {} start_time = time.time() - logger.info("Waiting for all transaction to be included in a block..") - while True: - i = 0 - while i < len(tx_hashes): - tx_hash = tx_hashes[i] - tx = self.get_transaction_by_hash(tx_hash) - if tx is not None and tx.block_number is not None: - responses.append(tx) - logger.info( - f"Tx {tx.hash} was included in block {tx.block_number}" + logger.info("Waiting for all transactions to be included in a block..") + + while pending: + with self.transaction_polling_context: + pending_hashes = list(pending.keys()) + tx_responses = self.get_transactions_by_hash(pending_hashes) + + newly_found: List[Hash] = [] + for tx_hash, tx_response in zip( + pending_hashes, tx_responses, strict=True + ): + if tx_response is None: + continue + if tx_response.block_number is not None: + found[tx_hash] = tx_response + newly_found.append(tx_hash) + logger.info( + f"Tx {tx_response.hash} was included " + f"in block {tx_response.block_number}" + ) + + for tx_hash in newly_found: + del pending[tx_hash] + + if not pending: + break + + if (time.time() - start_time) > self.transaction_wait_timeout: + missing_txs_strings = [ + f"{tx.hash} ({tx.model_dump_json()})" + for tx in transactions + if tx.hash in pending + ] + raise Exception( + f"Transactions " + f"{', '.join(missing_txs_strings)} not " + f"included in a block after " + f"{self.transaction_wait_timeout} seconds" ) - tx_hashes.pop(i) - else: - i += 1 - if not tx_hashes: - return responses - if (time.time() - start_time) > self.transaction_wait_timeout: - break - time.sleep(self.poll_interval) - missing_txs_strings = [ - f"{tx.hash} ({tx.model_dump_json()})" - for tx in transactions - if tx.hash in tx_hashes - ] - raise Exception( - 
f"Transactions {', '.join(missing_txs_strings)} not included " - f"in a block after {self.transaction_wait_timeout} seconds" - ) + self.pending_transactions_handler() + + return [found[tx.hash] for tx in transactions] def send_wait_transaction(self, transaction: TransactionProtocol) -> Any: """Send transaction and waits until it is included in a block.""" @@ -761,7 +1068,9 @@ class DebugRPC(EthRPC): def trace_call(self, tr: dict[str, str], block_number: str) -> Any | None: """`debug_traceCall`: Returns pre state required for transaction.""" params = [tr, block_number, {"tracer": "prestateTracer"}] - return self.post_request(method="traceCall", params=params) + return self.post_request( + request=RPCCall(method="traceCall", params=params) + ).result_or_raise() class EngineRPC(BaseRPC): @@ -785,19 +1094,10 @@ def __init__( super().__init__(*args, **kwargs) self.jwt_secret = jwt_secret - def post_request( - self, - *, - method: str, - params: Any | None = None, - extra_headers: Dict[str, str] | None = None, - request_id: int | str | None = None, - timeout: int | None = None, - ) -> Any: - """ - Send JSON-RPC POST request to the client RPC server at port defined in - the url. - """ + def _jwt_extra_headers( + self, extra_headers: Dict[str, str] | None = None + ) -> Dict[str, str]: + """Build extra headers with JWT authentication.""" if extra_headers is None: extra_headers = {} jwt_token = encode( @@ -805,16 +1105,40 @@ def post_request( self.jwt_secret, algorithm="HS256", ) - extra_headers = { + return { "Authorization": f"Bearer {jwt_token}", } | extra_headers + def post_request( + self, + *, + request: RPCCall, + extra_headers: Dict[str, str] | None = None, + timeout: int | None = None, + ) -> JSONRPCResponse: + """ + Send JSON-RPC POST request with Engine API JWT authentication. 
+ """ return super().post_request( - method=method, - params=params, - extra_headers=extra_headers, + request=request, + extra_headers=self._jwt_extra_headers(extra_headers), + timeout=timeout, + ) + + def post_batch_request( + self, + *, + calls: Sequence[RPCCall], + extra_headers: Dict[str, str] | None = None, + timeout: int | None = None, + ) -> List[JSONRPCResponse]: + """ + Send JSON-RPC batch POST request with Engine API JWT authentication. + """ + return super().post_batch_request( + calls=calls, + extra_headers=self._jwt_extra_headers(extra_headers), timeout=timeout, - request_id=request_id, ) def new_payload(self, *params: Any, version: int) -> PayloadStatus: @@ -826,7 +1150,9 @@ def new_payload(self, *params: Any, version: int) -> PayloadStatus: params_list = [to_json(param) for param in params] return PayloadStatus.model_validate( - self.post_request(method=method, params=params_list), + self.post_request( + request=RPCCall(method=method, params=params_list) + ).result_or_raise(), context=self.response_validation_context, ) @@ -850,9 +1176,8 @@ def forkchoice_updated( return ForkchoiceUpdateResponse.model_validate( self.post_request( - method=method, - params=params, - ), + request=RPCCall(method=method, params=params), + ).result_or_raise(), context=self.response_validation_context, ) @@ -870,9 +1195,8 @@ def get_payload( return GetPayloadResponse.model_validate( self.post_request( - method=method, - params=[f"{payload_id}"], - ), + request=RPCCall(method=method, params=[f"{payload_id}"]), + ).result_or_raise(), context=self.response_validation_context, ) @@ -889,9 +1213,8 @@ def get_blobs( params = [f"{h}" for h in versioned_hashes] response = self.post_request( - method=method, - params=[params], - ) + request=RPCCall(method=method, params=[params]), + ).result_or_raise() if response is None: # for tests that request non-existing blobs logger.debug("get_blobs response received but it has value: None") return None @@ -984,7 +1307,9 @@ class 
NetRPC(BaseRPC): def peer_count(self) -> int: """`net_peerCount`: Get the number of peers connected to the client.""" - response = self.post_request(method="peerCount") + response = self.post_request( + request=RPCCall(method="peerCount") + ).result_or_raise() return int(response, 16) # hex -> int def wait_for_peer_connection( @@ -1055,4 +1380,6 @@ class AdminRPC(BaseRPC): def add_peer(self, enode: str) -> bool: """`admin_addPeer`: Add a peer by enode URL.""" - return self.post_request(method="addPeer", params=[enode]) + return self.post_request( + request=RPCCall(method="addPeer", params=[enode]) + ).result_or_raise() diff --git a/packages/testing/src/execution_testing/rpc/rpc_types.py b/packages/testing/src/execution_testing/rpc/rpc_types.py index d543fc56ff..eff0e8f65c 100644 --- a/packages/testing/src/execution_testing/rpc/rpc_types.py +++ b/packages/testing/src/execution_testing/rpc/rpc_types.py @@ -6,7 +6,7 @@ from hashlib import sha256 from typing import Annotated, Any, Dict, List, Protocol, Self -from pydantic import AliasChoices, Field, model_validator +from pydantic import AliasChoices, BaseModel, Field, model_validator from execution_testing.base_types import ( Address, @@ -57,6 +57,61 @@ def __str__(self) -> str: return f"JSONRPCError(code={self.code}, message={self.message})" +class RPCCall(BaseModel): + """Represent a JSON-RPC method call before namespace prefixing.""" + + method: str + params: List[Any] = [] + request_id: int | str | None = None + + +class JSONRPCRequest(BaseModel): + """Represent a JSON-RPC 2.0 request object.""" + + jsonrpc: str = "2.0" + method: str + params: List[Any] = [] + id: int | str + + +class JSONRPCErrorObject(BaseModel): + """Represent the error object in a JSON-RPC 2.0 response.""" + + code: int + message: str + data: str | None = None + + +class JSONRPCResponse(BaseModel): + """Represent a JSON-RPC 2.0 response object.""" + + jsonrpc: str = "2.0" + id: int | str + result: Any = None + error: JSONRPCErrorObject | None = 
None + + @model_validator(mode="before") + @classmethod + def check_result_or_error(cls, data: Any) -> Any: + """Validate that the response contains 'result' or 'error'.""" + if isinstance(data, dict): + if "result" not in data and "error" not in data: + raise ValueError( + "RPC response must contain 'result' or 'error'" + ) + return data + + def result_or_raise(self) -> Any: + """Return the result or raise JSONRPCError.""" + if self.error is not None: + raise JSONRPCError( + code=self.error.code, + message=self.error.message, + data=self.error.data, + ) + return self.result + + class TransactionByHashResponse(Transaction): """Represents the response of a transaction by hash request.""" diff --git a/tests/benchmark/compute/helpers.py b/tests/benchmark/compute/helpers.py index 6f52cb1582..13be98d788 100644 --- a/tests/benchmark/compute/helpers.py +++ b/tests/benchmark/compute/helpers.py @@ -2,7 +2,7 @@ import math from enum import Enum, auto -from typing import Dict, Generator, Self, Sequence, cast +from typing import Dict, Generator, List, Self, Sequence, cast from execution_testing import ( EOA, @@ -21,6 +21,7 @@ compute_create2_address, compute_deterministic_create2_address, ) +from pydantic import Field from tests.osaka.eip7951_p256verify_precompiles.spec import ( FieldElement, @@ -334,6 +335,12 @@ def address(self) -> Address: return self._cached_address +class ContractDeploymentTransaction(TransactionWithCost): + """Transaction object that can include the expected gas to be consumed.""" + + deployed_contracts: List[Address] = Field(..., exclude=True) + + class CustomSizedContractFactory(IteratingBytecode): """ Factory contract that creates contracts with a custom size. 
@@ -439,7 +446,7 @@ def transactions_by_total_contract_count( sender: EOA, contract_count: int, contract_start_index: int = 0, - ) -> Generator[TransactionWithCost, None, None]: + ) -> Generator[ContractDeploymentTransaction, None, None]: """ Create a list of transactions calling the factory to create the given number of contracts, each capped tx properly capped by the @@ -485,12 +492,21 @@ def calldata(iteration_count: int, start_iteration: int) -> bytes: start_iteration=start_iteration, calldata=calldata_max, ) - yield TransactionWithCost( + deployed_contracts = [ + self.created_contract_address( + salt=i, + ) + for i in range( + start_iteration, start_iteration + iteration_count + ) + ] + yield ContractDeploymentTransaction( to=to, gas_limit=tx_gas_limit, sender=sender, gas_cost=tx_gas_cost, data=calldata(iteration_count, start_iteration), + deployed_contracts=deployed_contracts, ) start_iteration += iteration_count last_iteration_count = iteration_count diff --git a/tests/benchmark/compute/instruction/test_account_query.py b/tests/benchmark/compute/instruction/test_account_query.py index 60819b5e17..4cf815f7ce 100644 --- a/tests/benchmark/compute/instruction/test_account_query.py +++ b/tests/benchmark/compute/instruction/test_account_query.py @@ -37,7 +37,7 @@ While, ) -from ..helpers import CustomSizedContractFactory +from ..helpers import ContractDeploymentTransaction, CustomSizedContractFactory @pytest.mark.repricing(contract_balance=1) @@ -616,15 +616,22 @@ def access_list_generator( ) # Deploy num_contracts via multiple txs (each capped by tx gas limit). 
+ post = {} with TestPhaseManager.setup(): setup_sender = pre.fund_eoa() - contracts_deployment_txs = list( + contracts_deployment_txs: List[ContractDeploymentTransaction] = [] + for contract_creating_tx in ( custom_sized_contract_factory.transactions_by_total_contract_count( fork=fork, sender=setup_sender, contract_count=num_contracts, ) - ) + ): + contracts_deployment_txs.append(contract_creating_tx) + if custom_sized_contract_factory.contract_size > 0: + post[contract_creating_tx.deployed_contracts[-1]] = Account( + nonce=1 + ) with TestPhaseManager.execution(): attack_sender = pre.fund_eoa() @@ -652,14 +659,6 @@ def access_list_generator( ) total_gas_cost = sum(tx.gas_cost for tx in attack_txs) - post = {} - if custom_sized_contract_factory.contract_size > 0: - for i in range(num_contracts): - deployed_contract_address = ( - custom_sized_contract_factory.created_contract_address(salt=i) - ) - post[deployed_contract_address] = Account(nonce=1) - benchmark_test( pre=pre, post=post, diff --git a/tests/benchmark/compute/scenario/test_unchunkified_bytecode.py b/tests/benchmark/compute/scenario/test_unchunkified_bytecode.py index f226895ec2..8b9420e142 100644 --- a/tests/benchmark/compute/scenario/test_unchunkified_bytecode.py +++ b/tests/benchmark/compute/scenario/test_unchunkified_bytecode.py @@ -3,6 +3,8 @@ This scenario is relevant in forks that have unchunkified bytecode. """ +from typing import List + import pytest from execution_testing import ( Account, @@ -19,7 +21,7 @@ While, ) -from ..helpers import CustomSizedContractFactory +from ..helpers import ContractDeploymentTransaction, CustomSizedContractFactory @pytest.mark.repricing @@ -131,15 +133,22 @@ def calldata(iteration_count: int, start_iteration: int) -> bytes: ) # Deploy num_contracts via multiple txs (each capped by tx gas limit). 
+ post = {} with TestPhaseManager.setup(): setup_sender = pre.fund_eoa() - contracts_deployment_txs = list( + contracts_deployment_txs: List[ContractDeploymentTransaction] = [] + for contract_creating_tx in ( custom_sized_contract_factory.transactions_by_total_contract_count( fork=fork, sender=setup_sender, contract_count=num_contracts, ) - ) + ): + contracts_deployment_txs.append(contract_creating_tx) + if custom_sized_contract_factory.contract_size > 0: + post[contract_creating_tx.deployed_contracts[-1]] = Account( + nonce=1 + ) with TestPhaseManager.execution(): attack_sender = pre.fund_eoa() @@ -166,13 +175,6 @@ def calldata(iteration_count: int, start_iteration: int) -> bytes: ) total_gas_cost = sum(tx.gas_cost for tx in attack_txs) - post = {} - for i in range(num_contracts): - deployed_contract_address = ( - custom_sized_contract_factory.created_contract_address(salt=i) - ) - post[deployed_contract_address] = Account(nonce=1) - benchmark_test( pre=pre, post=post,