Skip to content

Commit edec759

Browse files
tkumor3ddoktorski
andauthored
Simplify using a custom network (#1327)
* Simplifies using a custom network * update test * update calculate_hash params type * for chain id in account map string to hash and then to integate * add test and getting chain_id from network in default * fix test * fixes after remove goerli * feedback * Improve logic in `Account` constructor * Remove `chain` parameter from `Account` methods * Add `RECOGNIZED_CHAIN_IDS` constant * Update migration guide --------- Co-authored-by: Dariusz Doktorski <dariusz.doktorski@swmansion.com>
1 parent 4825ff5 commit edec759

File tree

14 files changed

+115
-74
lines changed

14 files changed

+115
-74
lines changed

docs/migration_guide.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,17 @@ Version 0.22.0 of **starknet.py**
1111
-----------------------
1212

1313
1. Support for Goerli has been removed
14-
2. ``StarknetChainId.SEPOLIA_TESTNET`` has been renamed to ``StarknetChainId.SEPOLIA``
14+
15+
.. currentmodule:: starknet_py.net.models
16+
17+
2. ``StarknetChainId.SEPOLIA_TESTNET`` has been renamed to :class:`StarknetChainId.SEPOLIA`
18+
19+
.. currentmodule:: starknet_py.net.account.account
20+
21+
3. Parameter ``chain`` has been removed from the methods :meth:`Account.deploy_account_v1` and :meth:`Account.deploy_account_v3`
22+
4. Parameter ``chain_id`` has been removed from the method :meth:`~Account.get_balance`
23+
24+
1525

1626
******************************
1727
0.21.0 Migration guide

starknet_py/net/account/account.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
Tag,
2424
)
2525
from starknet_py.net.full_node_client import FullNodeClient
26-
from starknet_py.net.models import AddressRepresentation, StarknetChainId, parse_address
26+
from starknet_py.net.models import AddressRepresentation, parse_address
27+
from starknet_py.net.models.chains import RECOGNIZED_CHAIN_IDS, Chain, parse_chain
2728
from starknet_py.net.models.transaction import (
2829
AccountTransaction,
2930
DeclareV1,
@@ -72,7 +73,7 @@ def __init__(
7273
client: Client,
7374
signer: Optional[BaseSigner] = None,
7475
key_pair: Optional[KeyPair] = None,
75-
chain: Optional[StarknetChainId] = None,
76+
chain: Optional[Chain] = None,
7677
):
7778
"""
7879
:param address: Address of the account contract.
@@ -81,11 +82,18 @@ def __init__(
8182
If none is provided, default
8283
:py:class:`starknet_py.net.signer.stark_curve_signer.StarkCurveSigner` is used.
8384
:param key_pair: Key pair that will be used to create a default `Signer`.
84-
:param chain: ChainId of the chain used to create the default signer.
85+
:param chain: Chain ID associated with the account.
86+
This can be supplied in multiple formats:
87+
88+
- an enum :py:class:`starknet_py.net.models.StarknetChainId`
89+
- a string name (e.g. 'SN_SEPOLIA')
90+
- a hexadecimal value (e.g. '0x1')
91+
- an integer (e.g. 1)
8592
"""
8693
self._address = parse_address(address)
8794
self._client = client
8895
self._cairo_version = None
96+
self._chain_id = None if chain is None else parse_chain(chain)
8997

9098
if signer is not None and key_pair is not None:
9199
raise ValueError("Arguments signer and key_pair are mutually exclusive.")
@@ -95,14 +103,13 @@ def __init__(
95103
raise ValueError(
96104
"Either a signer or a key_pair must be provided in Account constructor."
97105
)
98-
if chain is None:
106+
if self._chain_id is None:
99107
raise ValueError("One of chain or signer must be provided.")
100108

101109
signer = StarkCurveSigner(
102-
account_address=self.address, key_pair=key_pair, chain_id=chain
110+
account_address=self.address, key_pair=key_pair, chain_id=self._chain_id
103111
)
104112
self.signer: BaseSigner = signer
105-
self._chain_id = chain
106113

107114
@property
108115
def address(self) -> int:
@@ -293,13 +300,18 @@ async def get_nonce(
293300
async def get_balance(
294301
self,
295302
token_address: Optional[AddressRepresentation] = None,
296-
chain_id: Optional[StarknetChainId] = None,
297303
*,
298304
block_hash: Optional[Union[Hash, Tag]] = None,
299305
block_number: Optional[Union[int, Tag]] = None,
300306
) -> int:
301307
if token_address is None:
302-
token_address = self._default_token_address_for_chain(chain_id)
308+
chain_id = await self._get_chain_id()
309+
if chain_id in RECOGNIZED_CHAIN_IDS:
310+
token_address = FEE_CONTRACT_ADDRESS
311+
else:
312+
raise ValueError(
313+
"Argument token_address must be specified when using a custom network."
314+
)
303315

304316
low, high = await self._client.call_contract(
305317
Call(
@@ -596,7 +608,6 @@ async def deploy_account_v1(
596608
salt: int,
597609
key_pair: KeyPair,
598610
client: Client,
599-
chain: StarknetChainId,
600611
constructor_calldata: Optional[List[int]] = None,
601612
nonce: int = 0,
602613
max_fee: Optional[int] = None,
@@ -618,7 +629,6 @@ async def deploy_account_v1(
618629
:param salt: Salt used to calculate the address.
619630
:param key_pair: KeyPair used to calculate address and sign deploy account transaction.
620631
:param client: Client instance used for deployment.
621-
:param chain: Id of the Starknet chain used.
622632
:param constructor_calldata: Optional calldata to account contract constructor. If ``None`` is passed,
623633
``[key_pair.public_key]`` will be used as calldata.
624634
:param nonce: Nonce of the transaction.
@@ -631,6 +641,8 @@ async def deploy_account_v1(
631641
else [key_pair.public_key]
632642
)
633643

644+
chain = await client.get_chain_id()
645+
634646
account = _prepare_account_to_deploy(
635647
address=address,
636648
class_hash=class_hash,
@@ -650,7 +662,7 @@ async def deploy_account_v1(
650662
auto_estimate=auto_estimate,
651663
)
652664

653-
if chain in StarknetChainId:
665+
if parse_chain(chain) in RECOGNIZED_CHAIN_IDS:
654666
balance = await account.get_balance()
655667
if balance < deploy_account_tx.max_fee:
656668
raise ValueError(
@@ -671,7 +683,6 @@ async def deploy_account_v3(
671683
salt: int,
672684
key_pair: KeyPair,
673685
client: Client,
674-
chain: StarknetChainId,
675686
constructor_calldata: Optional[List[int]] = None,
676687
nonce: int = 0,
677688
l1_resource_bounds: Optional[ResourceBounds] = None,
@@ -690,7 +701,6 @@ async def deploy_account_v3(
690701
:param salt: Salt used to calculate the address.
691702
:param key_pair: KeyPair used to calculate address and sign deploy account transaction.
692703
:param client: Client instance used for deployment.
693-
:param chain: Id of the Starknet chain used.
694704
:param constructor_calldata: Optional calldata to account contract constructor. If ``None`` is passed,
695705
``[key_pair.public_key]`` will be used as calldata.
696706
:param nonce: Nonce of the transaction.
@@ -704,6 +714,8 @@ async def deploy_account_v3(
704714
else [key_pair.public_key]
705715
)
706716

717+
chain = await client.get_chain_id()
718+
707719
account = _prepare_account_to_deploy(
708720
address=address,
709721
class_hash=class_hash,
@@ -729,19 +741,12 @@ async def deploy_account_v3(
729741
hash=result.transaction_hash, account=account, _client=account.client
730742
)
731743

732-
def _default_token_address_for_chain(
733-
self, chain_id: Optional[StarknetChainId] = None
734-
) -> str:
735-
if (chain_id or self._chain_id) not in [
736-
StarknetChainId.SEPOLIA,
737-
StarknetChainId.SEPOLIA_INTEGRATION,
738-
StarknetChainId.MAINNET,
739-
]:
740-
raise ValueError(
741-
"Argument token_address must be specified when using a custom network."
742-
)
744+
async def _get_chain_id(self) -> int:
745+
if self._chain_id is None:
746+
chain = await self._client.get_chain_id()
747+
self._chain_id = parse_chain(chain)
743748

744-
return FEE_CONTRACT_ADDRESS
749+
return self._chain_id
745750

746751

747752
def _prepare_account_to_deploy(
@@ -750,7 +755,7 @@ def _prepare_account_to_deploy(
750755
salt: int,
751756
key_pair: KeyPair,
752757
client: Client,
753-
chain: StarknetChainId,
758+
chain: Chain,
754759
calldata: List[int],
755760
) -> Account:
756761
# pylint: disable=too-many-arguments

starknet_py/net/account/account_test.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,27 +88,42 @@ def test_create_account():
8888
assert account.signer.public_key == key_pair.public_key
8989

9090

91-
def test_create_account_from_signer():
91+
@pytest.mark.parametrize(
92+
"chain",
93+
[
94+
StarknetChainId.SEPOLIA,
95+
"SN_SEPOLIA",
96+
"0x534e5f5345504f4c4941",
97+
393402133025997798000961,
98+
],
99+
)
100+
def test_create_account_parses_chain(chain):
101+
key_pair = KeyPair.from_private_key(0x111)
102+
account = Account(
103+
address=0x1,
104+
client=FullNodeClient(node_url=""),
105+
key_pair=key_pair,
106+
chain=chain,
107+
)
108+
109+
assert account.address == 0x1
110+
assert account.signer.public_key == key_pair.public_key
111+
assert isinstance(account.signer, StarkCurveSigner)
112+
assert account.signer.chain_id == 0x534E5F5345504F4C4941
113+
114+
115+
def test_create_account_from_signer(client):
92116
signer = StarkCurveSigner(
93117
account_address=0x1,
94118
key_pair=KeyPair.from_private_key(0x111),
95119
chain_id=StarknetChainId.MAINNET,
96120
)
97-
account = Account(address=0x1, client=FullNodeClient(node_url=""), signer=signer)
121+
account = Account(address=0x1, client=client, signer=signer)
98122

99123
assert account.address == 0x1
100124
assert account.signer == signer
101125

102126

103-
def test_create_account_raises_on_no_chain_and_signer():
104-
with pytest.raises(ValueError, match="One of chain or signer must be provided"):
105-
Account(
106-
address=0x1,
107-
client=FullNodeClient(node_url=""),
108-
key_pair=KeyPair.from_private_key(0x111),
109-
)
110-
111-
112127
def test_create_account_raises_on_no_keypair_and_signer():
113128
with pytest.raises(
114129
ValueError,

starknet_py/net/account/base_account.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
SentTransactionResponse,
1111
Tag,
1212
)
13-
from starknet_py.net.models import AddressRepresentation, StarknetChainId
13+
from starknet_py.net.models import AddressRepresentation
1414
from starknet_py.net.models.transaction import (
1515
AccountTransaction,
1616
DeclareV1,
@@ -93,18 +93,15 @@ async def get_nonce(
9393
async def get_balance(
9494
self,
9595
token_address: Optional[AddressRepresentation] = None,
96-
chain_id: Optional[StarknetChainId] = None,
9796
*,
9897
block_hash: Optional[Union[Hash, Tag]] = None,
9998
block_number: Optional[Union[int, Tag]] = None,
10099
) -> int:
101100
"""
102-
Checks account's balance of specified token.
101+
Checks account's balance of the specified token.
102+
By default, it uses the L2 ETH address for mainnet and sepolia networks.
103103
104104
:param token_address: Address of the ERC20 contract.
105-
:param chain_id: Identifier of the Starknet chain used.
106-
If token_address is not specified it will be used to determine network's payment token address.
107-
If token_address is provided, chain_id will be ignored.
108105
:param block_hash: Block's hash or literals `"pending"` or `"latest"`
109106
:param block_number: Block's number or literals `"pending"` or `"latest"`
110107
:return: Token balance.

starknet_py/net/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,7 @@ async def get_contract_nonce(
301301
:param block_number: Block's number or literals `"pending"` or `"latest"`
302302
:return: The last nonce used for the given contract
303303
"""
304+
305+
@abstractmethod
306+
async def get_chain_id(self) -> str:
307+
"""Return the currently configured Starknet chain id"""

starknet_py/net/full_node_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,6 @@ async def get_block_hash_and_number(self) -> BlockHashAndNumber:
455455
return cast(BlockHashAndNumber, BlockHashAndNumberSchema().load(res))
456456

457457
async def get_chain_id(self) -> str:
458-
"""Return the currently configured Starknet chain id"""
459458
return await self._client.call(method_name="chainId", params={})
460459

461460
async def get_syncing_status(self) -> Union[bool, SyncStatus]:

starknet_py/net/models/chains.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import IntEnum
2-
from typing import Optional
2+
from typing import Optional, Union
33

44
from starknet_py.common import int_from_bytes
55
from starknet_py.net.networks import MAINNET, SEPOLIA, SEPOLIA_INTEGRATION, Network
@@ -15,6 +15,13 @@ class StarknetChainId(IntEnum):
1515
SEPOLIA_INTEGRATION = int_from_bytes(b"SN_INTEGRATION_SEPOLIA")
1616

1717

18+
RECOGNIZED_CHAIN_IDS = [
19+
StarknetChainId.MAINNET,
20+
StarknetChainId.SEPOLIA,
21+
StarknetChainId.SEPOLIA_INTEGRATION,
22+
]
23+
24+
1825
def chain_from_network(
1926
net: Network, chain: Optional[StarknetChainId] = None
2027
) -> StarknetChainId:
@@ -31,3 +38,17 @@ def chain_from_network(
3138
raise ValueError("Chain is required when not using predefined networks.")
3239

3340
return chain
41+
42+
43+
ChainId = Union[StarknetChainId, int]
44+
Chain = Union[str, ChainId]
45+
46+
47+
def parse_chain(chain: Chain) -> ChainId:
48+
if isinstance(chain, str):
49+
try:
50+
return int(chain, 16)
51+
except ValueError:
52+
return int_from_bytes(chain.encode())
53+
else:
54+
return chain

0 commit comments

Comments
 (0)