diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml
index a28107d..85b7921 100644
--- a/.github/workflows/e2e-tests.yml
+++ b/.github/workflows/e2e-tests.yml
@@ -124,7 +124,7 @@ jobs:
         os:
           - ubuntu-latest
         test-file: ${{ fromJson(needs.find-tests.outputs.test-files) }}
-        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
+        python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
     steps:
       - name: Check-out repository
         uses: actions/checkout@v4
diff --git a/.github/workflows/unit-and-integration-test.yml b/.github/workflows/unit-and-integration-test.yml
index fe2b222..d301d18 100644
--- a/.github/workflows/unit-and-integration-test.yml
+++ b/.github/workflows/unit-and-integration-test.yml
@@ -15,7 +15,7 @@ jobs:
       fail-fast: false
       max-parallel: 5
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
+        python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
 
     steps:
       - name: Checkout repository
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a1cf15f..c601550 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,17 @@
 # Changelog
 
+## 1.6.2 /2025-02-19
+
+## What's Changed
+* Typing by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/265
+* Cache Improvements by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/267
+* improve (async) query map result by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/266
+* Use threaded bt-decode by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/268
+* Feat: Handle new error message fmt for InBlock source by @ibraheem-abe in https://github.com/opentensor/async-substrate-interface/pull/269
+* Update: Remove python 3.9 support by @ibraheem-abe in https://github.com/opentensor/async-substrate-interface/pull/271
+
+**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.6.1...v1.6.2
+
 ## 1.6.1 /2025-02-03
 * RuntimeCache updates by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/260
 * fix memory leak by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/261
diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py
index 177e0e2..dcc5ceb 100644
--- a/async_substrate_interface/async_substrate.py
+++ b/async_substrate_interface/async_substrate.py
@@ -31,7 +31,6 @@
 from scalecodec.types import (
     GenericCall,
     GenericExtrinsic,
-    GenericRuntimeCallDefinition,
     ss58_encode,
     MultiAccountId,
 )
@@ -74,12 +73,10 @@
     _bt_decode_to_dict_or_list,
     legacy_scale_decode,
     convert_account_ids,
+    decode_query_map_async,
 )
 from async_substrate_interface.utils.storage import StorageKey
 from async_substrate_interface.type_registry import _TYPE_REGISTRY
-from async_substrate_interface.utils.decoding import (
-    decode_query_map,
-)
 
 ResultHandler = Callable[[dict, Any], Awaitable[tuple[dict, bool]]]
@@ -526,6 +523,18 @@ async def retrieve_next_page(self, start_key) -> list:
         self.last_key = result.last_key
         return result.records
 
+    async def retrieve_all_records(self) -> list[Any]:
+        """
+        Retrieves all records from all subsequent pages of the AsyncQueryMapResult,
+        returning them as a list.
+
+        Side effect:
+            The `self.records` list is fully populated after this method runs.
+ """ + async for _ in self: + pass + return self.records + def __aiter__(self): return self @@ -558,6 +567,7 @@ async def __anext__(self): self.loading_complete = True raise StopAsyncIteration + self.records.extend(next_page) # Update the buffer with the newly fetched records self._buffer = iter(next_page) return next(self._buffer) @@ -1408,7 +1418,9 @@ async def decode_scale( if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) if runtime.metadata_v15 is not None and force_legacy is False: - obj = decode_by_type_string(type_string, runtime.registry, scale_bytes) + obj = await asyncio.to_thread( + decode_by_type_string, type_string, runtime.registry, scale_bytes + ) if self.decode_ss58: try: type_str_int = int(type_string.split("::")[1]) @@ -2762,19 +2774,34 @@ async def rpc_request( logger.error(f"Substrate Request Exception: {result[payload_id]}") raise SubstrateRequestException(result[payload_id][0]) - @cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE) - async def get_block_hash(self, block_id: int) -> str: + async def get_block_hash(self, block_id: Optional[int]) -> str: """ - Retrieves the hash of the specified block number + Retrieves the hash of the specified block number, or the chaintip if None Args: block_id: block number Returns: Hash of the block """ + if block_id is None: + return await self.get_chain_head() + else: + if (block_hash := self.runtime_cache.blocks.get(block_id)) is not None: + return block_hash + + block_hash = await self._cached_get_block_hash(block_id) + self.runtime_cache.add_item(block_hash=block_hash, block=block_id) + return block_hash + + @cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE) + async def _cached_get_block_hash(self, block_id: int) -> str: + """ + The design of this method is as such, because it allows for an easy drop-in for a different cache, such + as is the case with DiskCachedAsyncSubstrateInterface._cached_get_block_hash + """ return await self._get_block_hash(block_id) - async def _get_block_hash(self, block_id: int) -> str: + async def _get_block_hash(self, block_id: Optional[int]) -> str: return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"] async def get_chain_head(self) -> str: @@ -3852,18 +3879,20 @@ async def query_map( params=[result_keys, block_hash], runtime=runtime, ) + changes = [] for result_group in response["result"]: - result = decode_query_map( - result_group["changes"], - prefix, - runtime, - param_types, - params, - value_type, - key_hashers, - ignore_decoding_errors, - self.decode_ss58, - ) + changes.extend(result_group["changes"]) + result = await decode_query_map_async( + changes, + prefix, + runtime, + param_types, + params, + value_type, + key_hashers, + ignore_decoding_errors, + self.decode_ss58, + ) else: # storage item and value scale type are not included here because this is batch-decoded in rust page_batches = [ @@ -3881,8 +3910,8 @@ async def query_map( results: RequestResults = await self._make_rpc_request( payloads, runtime=runtime ) - for result in results.values(): - res = result[0] + for result_ in results.values(): + res = result_[0] if "error" in res: err_msg = res["error"]["message"] if ( @@ -3900,7 +3929,7 @@ async def query_map( else: for result_group in res["result"]: changes.extend(result_group["changes"]) - result = decode_query_map( + result = await decode_query_map_async( changes, prefix, runtime, @@ -4113,6 +4142,14 @@ async def result_handler(message: dict, subscription_id) -> tuple[dict, bool]: "extrinsic_hash": 
"0x{}".format(extrinsic.extrinsic_hash.hex()), "finalized": False, }, True + + elif "params" in message and message["params"].get("result") == "invalid": + failure_message = f"Subscription {subscription_id} invalid: {message}" + async with self.ws as ws: + await ws.unsubscribe(subscription_id) + logger.error(failure_message) + raise SubstrateRequestException(failure_message) + return message, False if wait_for_inclusion or wait_for_finalization: @@ -4250,13 +4287,25 @@ async def get_metadata_event( async def get_block_number(self, block_hash: Optional[str] = None) -> int: """Async version of `substrateinterface.base.get_block_number` method.""" - response = await self.rpc_request("chain_getHeader", [block_hash]) + if block_hash is None: + return await self._get_block_number(None) + if (block := self.runtime_cache.blocks_reverse.get(block_hash)) is not None: + return block + block = await self._cached_get_block_number(block_hash) + self.runtime_cache.add_item(block_hash=block_hash, block=block) + return block - if response["result"]: - return int(response["result"]["number"], 16) - raise SubstrateRequestException( - f"Unable to retrieve block number for {block_hash}" - ) + @cached_fetcher(max_size=SUBSTRATE_CACHE_METHOD_SIZE) + async def _cached_get_block_number(self, block_hash: str) -> int: + """ + The design of this method is as such, because it allows for an easy drop-in for a different cache, such + as is the case with DiskCachedAsyncSubstrateInterface._cached_get_block_number + """ + return await self._get_block_number(block_hash=block_hash) + + async def _get_block_number(self, block_hash: Optional[str]) -> int: + response = await self.rpc_request("chain_getHeader", [block_hash]) + return int(response["result"]["number"], 16) async def close(self): """ @@ -4351,9 +4400,13 @@ async def get_block_runtime_version_for(self, block_hash: str): return await self._get_block_runtime_version_for(block_hash) @async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) - async def get_block_hash(self, block_id: int) -> str: + async def _cached_get_block_hash(self, block_id: int) -> str: return await self._get_block_hash(block_id) + @async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) + async def _cached_get_block_number(self, block_hash: str) -> int: + return await self._get_block_number(block_hash=block_hash) + async def get_async_substrate_interface( url: str, diff --git a/async_substrate_interface/py.typed b/async_substrate_interface/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index 5b6db72..f6ff5b6 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -482,6 +482,18 @@ def retrieve_next_page(self, start_key) -> list: self.last_key = result.last_key return result.records + def retrieve_all_records(self) -> list[Any]: + """ + Retrieves all records from all subsequent pages for the QueryMapResult, + returning them as a list. + + Side effect: + The self.records list will be populated fully after running this method. 
+ """ + for _ in self: + pass + return self.records + def __iter__(self): return self @@ -511,6 +523,7 @@ def __next__(self): self.loading_complete = True raise StopIteration + self.records.extend(next_page) # Update the buffer with the newly fetched records self._buffer = iter(next_page) return next(self._buffer) @@ -2052,8 +2065,21 @@ def rpc_request( else: raise SubstrateRequestException(result[payload_id][0]) + def get_block_hash(self, block_id: Optional[int]) -> str: + """ + Retrieves the block hash for a given block number, or the chaintip hash if None + """ + if block_id is None: + return self.get_chain_head() + else: + if (block_hash := self.runtime_cache.blocks.get(block_id)) is not None: + return block_hash + block_hash = self._get_block_hash(block_id) + self.runtime_cache.add_item(block_hash=block_hash, block=block_id) + return block_hash + @functools.lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) - def get_block_hash(self, block_id: int) -> str: + def _get_block_hash(self, block_id: int) -> str: return self.rpc_request("chain_getBlockHash", [block_id])["result"] def get_chain_head(self) -> str: @@ -3247,6 +3273,13 @@ def result_handler(message: dict, subscription_id) -> tuple[dict, bool]: "extrinsic_hash": "0x{}".format(extrinsic.extrinsic_hash.hex()), "finalized": False, }, True + + elif "params" in message and message["params"].get("result") == "invalid": + failure_message = f"Subscription {subscription_id} invalid: {message}" + self.rpc_request("author_unwatchExtrinsic", [subscription_id]) + logger.error(failure_message) + raise SubstrateRequestException(failure_message) + return message, False if wait_for_inclusion or wait_for_finalization: @@ -3380,15 +3413,27 @@ def get_metadata_event( return self._get_metadata_event(module_name, event_name, runtime) def get_block_number(self, block_hash: Optional[str] = None) -> int: - """Async version of `substrateinterface.base.get_block_number` method.""" - response = self.rpc_request("chain_getHeader", [block_hash]) - - if response["result"]: - return int(response["result"]["number"], 16) + """ + Retrieves the block number for a given block hash or chaintip. 
+ """ + if block_hash is None: + return self._get_block_number(None) else: - raise SubstrateRequestException( - f"Unable to determine block number for {block_hash}" - ) + if ( + block_number := self.runtime_cache.blocks_reverse.get(block_hash) + ) is not None: + return block_number + block_number = self._cached_get_block_number(block_hash=block_hash) + self.runtime_cache.add_item(block_hash=block_hash, block=block_number) + return block_number + + @functools.lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE) + def _cached_get_block_number(self, block_hash: Optional[str]) -> int: + return self._get_block_number(block_hash=block_hash) + + def _get_block_number(self, block_hash: Optional[str]) -> int: + response = self.rpc_request("chain_getHeader", [block_hash]) + return int(response["result"]["number"], 16) def close(self): """ @@ -3404,6 +3449,7 @@ def close(self): self.get_block_runtime_info.cache_clear() self.get_block_runtime_version_for.cache_clear() self.supports_rpc_method.cache_clear() - self.get_block_hash.cache_clear() + self._get_block_hash.cache_clear() + self._cached_get_block_number.cache_clear() encode_scale = SubstrateMixin._encode_scale diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index 842e260..7af5e83 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -2,7 +2,7 @@ import logging import os from abc import ABC -from collections import defaultdict, deque +from collections import defaultdict, deque, OrderedDict from collections.abc import Iterable from contextlib import suppress from dataclasses import dataclass @@ -48,6 +48,8 @@ class RuntimeCache: def __init__(self, known_versions: Optional[Sequence[tuple[int, int]]] = None): # {block: block_hash, ...} self.blocks: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE) + # {block_hash: block, ...} + self.blocks_reverse: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE) # {block_hash: specVersion, ...} self.block_hashes: LRUCache = LRUCache(max_size=SUBSTRATE_CACHE_METHOD_SIZE) # {specVersion: Runtime, ...} @@ -87,7 +89,7 @@ def add_known_versions(self, known_versions: Sequence[tuple[int, int]]): def add_item( self, - runtime: "Runtime", + runtime: Optional["Runtime"] = None, block: Optional[int] = None, block_hash: Optional[str] = None, runtime_version: Optional[int] = None, @@ -95,13 +97,15 @@ def add_item( """ Adds a Runtime object to the cache mapped to its version, block number, and/or block hash. """ - self.last_used = runtime + if runtime is not None: + self.last_used = runtime + if runtime_version is not None: + self.versions.set(runtime_version, runtime) if block is not None and block_hash is not None: self.blocks.set(block, block_hash) + self.blocks_reverse.set(block_hash, block) if block_hash is not None and runtime_version is not None: self.block_hashes.set(block_hash, runtime_version) - if runtime_version is not None: - self.versions.set(runtime_version, runtime) def retrieve( self, @@ -114,16 +118,24 @@ def retrieve( Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`. 
""" # No reason to do this lookup if the runtime version is already supplied in this call - if block is not None and runtime_version is None and self._known_version_blocks: - # _known_version_blocks excludes the last item (see note in `add_known_versions`) - idx = bisect.bisect_right(self._known_version_blocks, block) - 1 - if idx >= 0: - runtime_version = self.known_versions[idx][1] + if runtime_version is None and self._known_version_blocks: + if block is not None: + block_ = block + elif block_hash is not None: + block_ = self.blocks_reverse.get(block_hash) + else: + block_ = None + if block_ is not None: + # _known_version_blocks excludes the last item (see note in `add_known_versions`) + idx = bisect.bisect_right(self._known_version_blocks, block_) - 1 + if idx >= 0: + runtime_version = self.known_versions[idx][1] runtime = None if block is not None: if block_hash is not None: self.blocks.set(block, block_hash) + self.blocks_reverse.set(block_hash, block) if runtime_version is not None: self.block_hashes.set(block_hash, runtime_version) with suppress(AttributeError): @@ -158,6 +170,9 @@ async def load_from_disk(self, chain_endpoint: str): else: logger.debug("Found runtime mappings in disk cache") self.blocks.cache = block_mapping + self.blocks_reverse.cache = OrderedDict( + {v: k for k, v in block_mapping.items()} + ) self.block_hashes.cache = block_hash_mapping for x, y in runtime_version_mapping.items(): self.versions.cache[x] = Runtime.deserialize(y) diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index 431a430..8de077b 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -108,7 +108,9 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any] await self._db.commit() return result - async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: + async def load_runtime_cache( + self, chain: str + ) -> tuple[OrderedDict[int, str], OrderedDict[str, int], OrderedDict[int, dict]]: async with self._lock: if not self._db: _ensure_dir() @@ -125,7 +127,7 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: async with self._lock: local_chain = await self._create_if_not_exists(chain, table) if local_chain: - return {}, {}, {} + return block_mapping, block_hash_mapping, version_mapping for table_name, mapping in tables.items(): try: async with self._lock: @@ -143,7 +145,7 @@ async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]: mapping[key] = runtime except (pickle.PickleError, sqlite3.Error) as e: logger.exception("Cache error", exc_info=e) - return {}, {}, {} + return block_mapping, block_hash_mapping, version_mapping return block_mapping, block_hash_mapping, version_mapping async def dump_runtime_cache( diff --git a/async_substrate_interface/utils/decoding.py b/async_substrate_interface/utils/decoding.py index 8b191b3..adbba88 100644 --- a/async_substrate_interface/utils/decoding.py +++ b/async_substrate_interface/utils/decoding.py @@ -1,3 +1,4 @@ +import asyncio from typing import Union, TYPE_CHECKING, Any from bt_decode import AxonInfo, PrometheusInfo, decode_list @@ -72,16 +73,34 @@ def _decode_scale_list_with_runtime( return obj -def decode_query_map( +async def _async_decode_scale_list_with_runtime( + type_strings: list[str], + scale_bytes_list: list[bytes], + runtime: "Runtime", + return_scale_obj: bool = False, +): + if runtime.metadata_v15 is not None: + obj = await asyncio.to_thread( + decode_list, 
diff --git a/async_substrate_interface/utils/decoding.py b/async_substrate_interface/utils/decoding.py
index 8b191b3..adbba88 100644
--- a/async_substrate_interface/utils/decoding.py
+++ b/async_substrate_interface/utils/decoding.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Union, TYPE_CHECKING, Any
 
 from bt_decode import AxonInfo, PrometheusInfo, decode_list
@@ -72,16 +73,34 @@ def _decode_scale_list_with_runtime(
     return obj
 
 
-def decode_query_map(
+async def _async_decode_scale_list_with_runtime(
+    type_strings: list[str],
+    scale_bytes_list: list[bytes],
+    runtime: "Runtime",
+    return_scale_obj: bool = False,
+):
+    if runtime.metadata_v15 is not None:
+        obj = await asyncio.to_thread(
+            decode_list, type_strings, runtime.registry, scale_bytes_list
+        )
+    else:
+        obj = [
+            legacy_scale_decode(x, y, runtime)
+            for (x, y) in zip(type_strings, scale_bytes_list)
+        ]
+    if return_scale_obj:
+        return [ScaleObj(x) for x in obj]
+    else:
+        return obj
+
+
+def _decode_query_map_pre(
     result_group_changes: list,
     prefix,
-    runtime: "Runtime",
     param_types,
     params,
     value_type,
     key_hashers,
-    ignore_decoding_errors,
-    decode_ss58: bool = False,
 ):
     def concat_hash_len(key_hasher: str) -> int:
         """
@@ -98,7 +117,6 @@ def concat_hash_len(key_hasher: str) -> int:
 
     hex_to_bytes_ = hex_to_bytes
 
-    result = []
     # Determine type string
     key_type_string_ = []
     for n in range(len(params), len(param_types)):
@@ -116,11 +134,25 @@ def concat_hash_len(key_hasher: str) -> int:
             pre_decoded_values.append(
                 hex_to_bytes_(item[1]) if item[1] is not None else b""
             )
-    all_decoded = _decode_scale_list_with_runtime(
-        pre_decoded_key_types + pre_decoded_value_types,
-        pre_decoded_keys + pre_decoded_values,
-        runtime,
+    return (
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        pre_decoded_keys,
+        pre_decoded_values,
     )
+
+
+def _decode_query_map_post(
+    pre_decoded_key_types,
+    pre_decoded_value_types,
+    all_decoded,
+    runtime: "Runtime",
+    param_types,
+    params,
+    ignore_decoding_errors,
+    decode_ss58: bool = False,
+):
+    result = []
     middl_index = len(all_decoded) // 2
     decoded_keys = all_decoded[:middl_index]
     decoded_values = all_decoded[middl_index:]
@@ -167,6 +199,88 @@ def concat_hash_len(key_hasher: str) -> int:
     return result
 
 
+async def decode_query_map_async(
+    result_group_changes: list,
+    prefix,
+    runtime: "Runtime",
+    param_types,
+    params,
+    value_type,
+    key_hashers,
+    ignore_decoding_errors,
+    decode_ss58: bool = False,
+):
+    (
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        pre_decoded_keys,
+        pre_decoded_values,
+    ) = _decode_query_map_pre(
+        result_group_changes,
+        prefix,
+        param_types,
+        params,
+        value_type,
+        key_hashers,
+    )
+    all_decoded = await _async_decode_scale_list_with_runtime(
+        pre_decoded_key_types + pre_decoded_value_types,
+        pre_decoded_keys + pre_decoded_values,
+        runtime,
+    )
+    return _decode_query_map_post(
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        all_decoded,
+        runtime,
+        param_types,
+        params,
+        ignore_decoding_errors,
+        decode_ss58=decode_ss58,
+    )
+
+
+def decode_query_map(
+    result_group_changes: list,
+    prefix,
+    runtime: "Runtime",
+    param_types,
+    params,
+    value_type,
+    key_hashers,
+    ignore_decoding_errors,
+    decode_ss58: bool = False,
+):
+    (
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        pre_decoded_keys,
+        pre_decoded_values,
+    ) = _decode_query_map_pre(
+        result_group_changes,
+        prefix,
+        param_types,
+        params,
+        value_type,
+        key_hashers,
+    )
+    all_decoded = _decode_scale_list_with_runtime(
+        pre_decoded_key_types + pre_decoded_value_types,
+        pre_decoded_keys + pre_decoded_values,
+        runtime,
+    )
+    return _decode_query_map_post(
+        pre_decoded_key_types,
+        pre_decoded_value_types,
+        all_decoded,
+        runtime,
+        param_types,
+        params,
+        ignore_decoding_errors,
+        decode_ss58=decode_ss58,
+    )
+
+
 def legacy_scale_decode(
     type_string: str, scale_bytes: Union[str, bytes, ScaleBytes], runtime: "Runtime"
 ):
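The decoding refactor hinges on `asyncio.to_thread`: the CPU-bound rust decode (`bt_decode.decode_list`) runs in a worker thread so the event loop keeps servicing other coroutines. A self-contained sketch of that pattern, with a sleep standing in for the real decode work:

```python
# Standalone sketch of the offloading pattern used above; cpu_bound_decode is
# a stand-in for bt_decode.decode_list, not part of the library.
import asyncio
import time


def cpu_bound_decode(payloads: list[bytes]) -> list[int]:
    time.sleep(0.5)  # pretend this is heavy SCALE decoding done in rust
    return [len(p) for p in payloads]


async def heartbeat():
    for _ in range(5):
        await asyncio.sleep(0.1)  # keeps running while the decode is in flight


async def main():
    decoded, _ = await asyncio.gather(
        asyncio.to_thread(cpu_bound_decode, [b"\x00\x01", b"\x02"]),
        heartbeat(),
    )
    print(decoded)  # [2, 1]


asyncio.run(main())
```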
diff --git a/pyproject.toml b/pyproject.toml
index 5eb5015..c7e6eef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "async-substrate-interface"
-version = "1.6.1"
+version = "1.6.2"
 description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface"
 readme = "README.md"
 license = { file = "LICENSE" }
@@ -15,7 +15,7 @@ dependencies = [
     "xxhash",
 ]
 
-requires-python = ">=3.9,<3.15"
+requires-python = ">=3.10,<3.15"
 
 authors = [
     { name = "Opentensor Foundation" },
@@ -33,7 +33,6 @@ classifiers = [
     "Topic :: Software Development :: Libraries",
     "License :: OSI Approved :: MIT License",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
@@ -45,6 +44,9 @@ classifiers = [
 [project.urls]
 Repository = "https://github.com/opentensor/async-substrate-interface/"
 
+[tool.setuptools.package-data]
+async_substrate_interface = ["py.typed"]
+
 [build-system]
 requires = ["setuptools~=70.0.0", "wheel"]
 build-backend = "setuptools.build_meta"
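Shipping `py.typed` (and listing it under `[tool.setuptools.package-data]` so the wheel actually includes it) marks the distribution as typed per PEP 561, so downstream type checkers read the package's inline annotations. A hypothetical consumer snippet that mypy can now verify; the function name is ours, not the library's:

```python
# Hypothetical downstream code; mypy checks it against the library's annotations.
from async_substrate_interface.sync_substrate import SubstrateInterface


def chaintip_hash(substrate: SubstrateInterface) -> str:
    return substrate.get_block_hash(None)  # Optional[int] is accepted as of 1.6.2
```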
diff --git a/tests/benchmarks/benchmark_to_thread_decoding.py b/tests/benchmarks/benchmark_to_thread_decoding.py
new file mode 100644
index 0000000..1460770
--- /dev/null
+++ b/tests/benchmarks/benchmark_to_thread_decoding.py
@@ -0,0 +1,114 @@
+"""
+Results:
+
+93 items
+
+original (not threading) decoding:
+median 3.9731219584937207
+mean 3.810443129093619
+stdev 0.9819147187144933
+
+to_thread decoding:
+median 2.7345559374953154
+mean 2.7784998625924344
+stdev 0.11112115146834547
+"""
+
+import asyncio
+
+from scalecodec import ss58_encode
+
+from async_substrate_interface.async_substrate import (
+    AsyncSubstrateInterface,
+    AsyncQueryMapResult,
+)
+from tests.helpers.settings import LATENT_LITE_ENTRYPOINT
+
+
+async def benchmark_to_thread_decoding():
+    async def _query_alpha(hk_: str, sem: asyncio.Semaphore) -> list:
+        try:
+            async with sem:
+                results = []
+                qm: AsyncQueryMapResult = await substrate.query_map(
+                    "SubtensorModule",
+                    "Alpha",
+                    params=[hk_],
+                    block_hash=block_hash,
+                    fully_exhaust=False,
+                    page_size=100,
+                )
+                async for result in qm:
+                    results.append(result)
+                return results
+        except Exception as e:
+            raise type(e)(f"[hotkey={hk_}] {e}") from e
+
+    loop = asyncio.get_running_loop()
+    async with AsyncSubstrateInterface(
+        LATENT_LITE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor"
+    ) as substrate:
+        block_hash = (
+            "0xb0f4a6fb95279f035f145600590e6d5508edea986c2e703e16b6bfbe08f29dbd"
+        )
+        start = loop.time()
+        total_hotkey_alpha_q, total_hotkey_shares_q = await asyncio.gather(
+            substrate.query_map(
+                "SubtensorModule",
+                "TotalHotkeyAlpha",
+                block_hash=block_hash,
+                page_size=100,
+                fully_exhaust=False,
+                params=[],
+            ),
+            substrate.query_map(
+                "SubtensorModule",
+                "TotalHotkeyShares",
+                block_hash=block_hash,
+                fully_exhaust=False,
+                page_size=100,
+                params=[],
+            ),
+        )
+        hotkeys = set()
+        tasks: list[asyncio.Task] = []
+        sema4 = asyncio.Semaphore(100)
+        for (hk, netuid), alpha in total_hotkey_alpha_q.records:
+            hotkey = ss58_encode(bytes(hk[0]), 42)
+            if alpha.value > 0:
+                if hotkey not in hotkeys:
+                    hotkeys.add(hotkey)
+                    tasks.append(
+                        loop.create_task(_query_alpha(hotkey, sema4), name=hotkey)
+                    )
+        for (hk, netuid), alpha_bits in total_hotkey_shares_q.records:
+            hotkey = ss58_encode(bytes(hk[0]), 42)
+            alpha_bits_value = alpha_bits.value["bits"]
+            if alpha_bits_value > 0:
+                if hotkey not in hotkeys:
+                    hotkeys.add(hotkey)
+                    tasks.append(
+                        loop.create_task(_query_alpha(hotkey, sema4), name=hotkey)
+                    )
+        await asyncio.gather(*tasks)
+        end = loop.time()
+        return len(tasks), end - start
+
+
+if __name__ == "__main__":
+    results = []
+    for _ in range(10):
+        len_tasks, time = asyncio.run(benchmark_to_thread_decoding())
+        results.append((len_tasks, time))
+
+    for len_tasks, time in results:
+        if len_tasks != 910:
+            print(len_tasks, time)
+    time_results = [x[1] for x in results]
+    import statistics
+
+    median = statistics.median(time_results)
+    mean = statistics.mean(time_results)
+    stdev = statistics.stdev(time_results)
+    print(median, mean, stdev)
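The benchmark fans out roughly 900 `query_map` calls but caps how many are in flight at once with a semaphore. Distilled to its essentials (the counts and sleep here are illustrative, not measurements):

```python
# Generic sketch of the Semaphore-bounded fan-out used in the benchmark above.
import asyncio


async def fetch(i: int, sem: asyncio.Semaphore) -> int:
    async with sem:  # at most 100 bodies execute concurrently
        await asyncio.sleep(0.01)  # stand-in for a query_map round trip
        return i


async def main():
    sem = asyncio.Semaphore(100)
    results = await asyncio.gather(*(fetch(i, sem) for i in range(910)))
    print(len(results))  # 910


asyncio.run(main())
```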
diff --git a/tests/unit_tests/asyncio_/test_env_vars.py b/tests/unit_tests/asyncio_/test_env_vars.py
index 10a0933..3e35565 100644
--- a/tests/unit_tests/asyncio_/test_env_vars.py
+++ b/tests/unit_tests/asyncio_/test_env_vars.py
@@ -10,7 +10,7 @@ def test_env_vars(monkeypatch):
     assert asi.get_block_runtime_info._max_size == 9
     assert asi.get_parent_block_hash._max_size == 10
     assert asi.get_block_runtime_version_for._max_size == 10
-    assert asi.get_block_hash._max_size == 10
+    assert asi._cached_get_block_hash._max_size == 10
 
 
 def test_defaults():
@@ -20,4 +20,4 @@ def test_defaults():
     assert asi.get_block_runtime_info._max_size == 16
     assert asi.get_parent_block_hash._max_size == 512
     assert asi.get_block_runtime_version_for._max_size == 512
-    assert asi.get_block_hash._max_size == 512
+    assert asi._cached_get_block_hash._max_size == 512
diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py
index 721804b..632af81 100644
--- a/tests/unit_tests/asyncio_/test_substrate_interface.py
+++ b/tests/unit_tests/asyncio_/test_substrate_interface.py
@@ -7,6 +7,7 @@
 from websockets.protocol import State
 
 from async_substrate_interface.async_substrate import (
+    AsyncQueryMapResult,
     AsyncSubstrateInterface,
     get_async_substrate_interface,
 )
@@ -175,3 +176,114 @@ async def test_memory_leak():
             f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, "
             f"peak={peak / 1024:.2f} KiB"
         )
+
+
+@pytest.mark.asyncio
+async def test_async_query_map_result_retrieve_all_records():
+    """Test that retrieve_all_records fetches all pages and returns the full record list."""
+    page1 = [("key1", "val1"), ("key2", "val2")]
+    page2 = [("key3", "val3"), ("key4", "val4")]
+    page3 = [("key5", "val5")]  # partial page signals loading_complete
+
+    mock_substrate = MagicMock()
+
+    qm = AsyncQueryMapResult(
+        records=list(page1),
+        page_size=2,
+        substrate=mock_substrate,
+        module="TestModule",
+        storage_function="TestStorage",
+        last_key="key2",
+    )
+
+    # Build mock pages: first call returns page2 (full page), second returns page3 (partial)
+    page2_result = AsyncQueryMapResult(
+        records=list(page2),
+        page_size=2,
+        substrate=mock_substrate,
+        last_key="key4",
+    )
+    page3_result = AsyncQueryMapResult(
+        records=list(page3),
+        page_size=2,
+        substrate=mock_substrate,
+        last_key="key5",
+    )
+    mock_substrate.query_map = AsyncMock(side_effect=[page2_result, page3_result])
+
+    result = await qm.retrieve_all_records()
+
+    assert result == page1 + page2 + page3
+    assert qm.records == page1 + page2 + page3
+    assert qm.loading_complete is True
+    assert mock_substrate.query_map.call_count == 2
+
+
+class TestGetBlockHash:
+    @pytest.fixture
+    def substrate(self):
+        s = AsyncSubstrateInterface("ws://localhost", _mock=True)
+        s.runtime_cache = MagicMock()
+        s._cached_get_block_hash = AsyncMock(return_value="0xCACHED")
+        s.get_chain_head = AsyncMock(return_value="0xHEAD")
+        return s
+
+    @pytest.mark.asyncio
+    async def test_none_block_id_returns_chain_head(self, substrate):
+        result = await substrate.get_block_hash(None)
+        assert result == "0xHEAD"
+        substrate.get_chain_head.assert_awaited_once()
+        substrate._cached_get_block_hash.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_cache_hit_returns_cached_hash(self, substrate):
+        substrate.runtime_cache.blocks.get.return_value = "0xFROMCACHE"
+        result = await substrate.get_block_hash(42)
+        assert result == "0xFROMCACHE"
+        substrate.runtime_cache.blocks.get.assert_called_once_with(42)
+        substrate._cached_get_block_hash.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_cache_miss_fetches_and_stores(self, substrate):
+        substrate.runtime_cache.blocks.get.return_value = None
+        result = await substrate.get_block_hash(42)
+        assert result == "0xCACHED"
+        substrate._cached_get_block_hash.assert_awaited_once_with(42)
+        substrate.runtime_cache.add_item.assert_called_once_with(
+            block_hash="0xCACHED", block=42
+        )
+
+
+class TestGetBlockNumber:
+    @pytest.fixture
+    def substrate(self):
+        s = AsyncSubstrateInterface("ws://localhost", _mock=True)
+        s.runtime_cache = MagicMock()
+        s._cached_get_block_number = AsyncMock(return_value=100)
+        s._get_block_number = AsyncMock(return_value=99)
+        return s
+
+    @pytest.mark.asyncio
+    async def test_none_block_hash_calls_get_block_number_directly(self, substrate):
+        result = await substrate.get_block_number(None)
+        assert result == 99
+        substrate._get_block_number.assert_awaited_once_with(None)
+        substrate._cached_get_block_number.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_cache_hit_returns_cached_number(self, substrate):
+        substrate.runtime_cache.blocks_reverse.get.return_value = 42
+        result = await substrate.get_block_number("0xABC")
+        assert result == 42
+        substrate.runtime_cache.blocks_reverse.get.assert_called_once_with("0xABC")
+        substrate._cached_get_block_number.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_cache_miss_fetches_and_stores(self, substrate):
+        substrate.runtime_cache.blocks_reverse.get.return_value = None
+        result = await substrate.get_block_number("0xABC")
+        assert result == 100
+        substrate._cached_get_block_number.assert_awaited_once_with("0xABC")
+        substrate.runtime_cache.add_item.assert_called_once_with(
+            block_hash="0xABC", block=100
+        )
diff --git a/tests/unit_tests/sync/test_env_vars.py b/tests/unit_tests/sync/test_env_vars.py
index 05d5ded..e53991c 100644
--- a/tests/unit_tests/sync/test_env_vars.py
+++ b/tests/unit_tests/sync/test_env_vars.py
@@ -10,7 +10,7 @@ def test_env_vars(monkeypatch):
     assert asi.get_block_runtime_info.cache_parameters()["maxsize"] == 9
     assert asi.get_parent_block_hash.cache_parameters()["maxsize"] == 10
     assert asi.get_block_runtime_version_for.cache_parameters()["maxsize"] == 10
-    assert asi.get_block_hash.cache_parameters()["maxsize"] == 10
+    assert asi._get_block_hash.cache_parameters()["maxsize"] == 10
 
 
 def test_defaults():
@@ -20,4 +20,4 @@ def test_defaults():
     assert asi.get_block_runtime_info.cache_parameters()["maxsize"] == 16
     assert asi.get_parent_block_hash.cache_parameters()["maxsize"] == 512
     assert asi.get_block_runtime_version_for.cache_parameters()["maxsize"] == 512
-    assert asi.get_block_hash.cache_parameters()["maxsize"] == 512
+    assert asi._get_block_hash.cache_parameters()["maxsize"] == 512
diff --git a/tests/unit_tests/sync/test_substrate_interface.py b/tests/unit_tests/sync/test_substrate_interface.py
index 54a5b7d..94a43a0 100644
--- a/tests/unit_tests/sync/test_substrate_interface.py
+++ b/tests/unit_tests/sync/test_substrate_interface.py
@@ -1,7 +1,7 @@
 import tracemalloc
 from unittest.mock import MagicMock
 
-from async_substrate_interface.sync_substrate import SubstrateInterface
+from async_substrate_interface.sync_substrate import SubstrateInterface, QueryMapResult
 from async_substrate_interface.types import ScaleObj
 
 from tests.helpers.settings import ARCHIVE_ENTRYPOINT, LATENT_LITE_ENTRYPOINT
@@ -122,3 +122,111 @@ def test_memory_leak():
             f"Loop {i}: diff={total_diff / 1024:.2f} KiB, current={current / 1024:.2f} KiB, "
             f"peak={peak / 1024:.2f} KiB"
         )
+
+
+def test_query_map_result_retrieve_all_records():
+    """Test that retrieve_all_records fetches all pages and returns the full record list."""
+    page1 = [("key1", "val1"), ("key2", "val2")]
+    page2 = [("key3", "val3"), ("key4", "val4")]
+    page3 = [("key5", "val5")]  # partial page signals loading_complete
+
+    mock_substrate = MagicMock()
+
+    qm = QueryMapResult(
+        records=list(page1),
+        page_size=2,
+        substrate=mock_substrate,
+        module="TestModule",
+        storage_function="TestStorage",
+        last_key="key2",
+    )
+
+    # Build mock pages: first call returns page2 (full page), second returns page3 (partial)
+    page2_result = QueryMapResult(
+        records=list(page2),
+        page_size=2,
+        substrate=mock_substrate,
+        last_key="key4",
+    )
+    page3_result = QueryMapResult(
+        records=list(page3),
+        page_size=2,
+        substrate=mock_substrate,
+        last_key="key5",
+    )
+    mock_substrate.query_map = MagicMock(side_effect=[page2_result, page3_result])
+
+    result = qm.retrieve_all_records()
+
+    assert result == page1 + page2 + page3
+    assert qm.records == page1 + page2 + page3
+    assert qm.loading_complete is True
+    assert mock_substrate.query_map.call_count == 2
+
+
+class TestGetBlockHash:
+    def _make_substrate(self):
+        s = SubstrateInterface("ws://localhost", _mock=True)
+        s.runtime_cache = MagicMock()
+        s._get_block_hash = MagicMock(return_value="0xCACHED")
+        s.get_chain_head = MagicMock(return_value="0xHEAD")
+        return s
+
+    def test_none_block_id_returns_chain_head(self):
+        substrate = self._make_substrate()
+        result = substrate.get_block_hash(None)
+        assert result == "0xHEAD"
+        substrate.get_chain_head.assert_called_once()
+        substrate._get_block_hash.assert_not_called()
+
+    def test_cache_hit_returns_cached_hash(self):
+        substrate = self._make_substrate()
+        substrate.runtime_cache.blocks.get.return_value = "0xFROMCACHE"
+        result = substrate.get_block_hash(42)
+        assert result == "0xFROMCACHE"
+        substrate.runtime_cache.blocks.get.assert_called_once_with(42)
+        substrate._get_block_hash.assert_not_called()
+
+    def test_cache_miss_fetches_and_stores(self):
+        substrate = self._make_substrate()
+        substrate.runtime_cache.blocks.get.return_value = None
+        result = substrate.get_block_hash(42)
+        assert result == "0xCACHED"
+        substrate._get_block_hash.assert_called_once_with(42)
+        substrate.runtime_cache.add_item.assert_called_once_with(
+            block_hash="0xCACHED", block=42
+        )
+
+
+class TestGetBlockNumber:
+    def _make_substrate(self):
+        s = SubstrateInterface("ws://localhost", _mock=True)
+        s.runtime_cache = MagicMock()
+        s._cached_get_block_number = MagicMock(return_value=100)
+        s._get_block_number = MagicMock(return_value=99)
+        return s
+
+    def test_none_block_hash_calls_get_block_number_directly(self):
+        substrate = self._make_substrate()
+        result = substrate.get_block_number(None)
+        assert result == 99
+        substrate._get_block_number.assert_called_once_with(None)
+        substrate._cached_get_block_number.assert_not_called()
+
+    def test_cache_hit_returns_cached_number(self):
+        substrate = self._make_substrate()
+        substrate.runtime_cache.blocks_reverse.get.return_value = 42
substrate.get_block_number("0xABC") + assert result == 42 + substrate.runtime_cache.blocks_reverse.get.assert_called_once_with("0xABC") + substrate._cached_get_block_number.assert_not_called() + + def test_cache_miss_fetches_and_stores(self): + substrate = self._make_substrate() + substrate.runtime_cache.blocks_reverse.get.return_value = None + result = substrate.get_block_number("0xABC") + assert result == 100 + substrate._cached_get_block_number.assert_called_once_with(block_hash="0xABC") + substrate.runtime_cache.add_item.assert_called_once_with( + block_hash="0xABC", block=100 + )