Skip to content

Commit f32c6ff

Browse files
authored
Refactor code for creating broadcasted transaction (#1288)
* Move _create_broadcasted_txn to client utils * add broadcasted schemas * refactor * clean up * clean up * fix linter * rename * clean up code * feedback * revert transaction types * adds tests * feedback * fix * fix tests * Remove obsolate code broadcasted txn (#1297)
1 parent f992bc8 commit f32c6ff

File tree

5 files changed

+218
-161
lines changed

5 files changed

+218
-161
lines changed

starknet_py/net/client_test.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,30 @@
1+
import dataclasses
2+
13
import pytest
24

35
from starknet_py.constants import ADDR_BOUND
6+
from starknet_py.hash.selector import get_selector_from_name
47
from starknet_py.net.client_models import (
8+
Call,
59
DAMode,
610
ResourceBoundsMapping,
711
Transaction,
12+
TransactionType,
813
TransactionV3,
914
)
15+
from starknet_py.net.client_utils import _create_broadcasted_txn
1016
from starknet_py.net.full_node_client import _to_storage_key
1117
from starknet_py.net.http_client import RpcHttpClient, ServerError
18+
from starknet_py.net.models.transaction import (
19+
DeclareV1,
20+
DeclareV2,
21+
DeclareV3,
22+
DeployAccountV1,
23+
DeployAccountV3,
24+
InvokeV1,
25+
InvokeV3,
26+
)
27+
from starknet_py.tests.e2e.fixtures.constants import MAX_FEE, MAX_RESOURCE_BOUNDS_L1
1228

1329

1430
@pytest.mark.asyncio
@@ -74,3 +90,118 @@ def test_get_rpc_storage_key(key, expected):
7490
def test_get_rpc_storage_key_raises_on_non_representable_key(key):
7591
with pytest.raises(ValueError, match="cannot be represented"):
7692
_to_storage_key(key)
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_broadcasted_txn_declare_v3(
97+
account, abi_types_compiled_contract_and_class_hash
98+
):
99+
declare_v3 = await account.sign_declare_v3(
100+
compiled_contract=abi_types_compiled_contract_and_class_hash[0],
101+
compiled_class_hash=abi_types_compiled_contract_and_class_hash[1],
102+
l1_resource_bounds=MAX_RESOURCE_BOUNDS_L1,
103+
)
104+
105+
brodcasted_txn = _create_broadcasted_txn(declare_v3)
106+
assert brodcasted_txn["type"] == TransactionType.DECLARE.name
107+
108+
expected_keys = dataclasses.fields(DeclareV3)
109+
assert all(key.name in brodcasted_txn for key in expected_keys)
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_broadcasted_txn_declare_v2(
114+
account, abi_types_compiled_contract_and_class_hash
115+
):
116+
declare_v2 = await account.sign_declare_v2(
117+
compiled_contract=abi_types_compiled_contract_and_class_hash[0],
118+
compiled_class_hash=abi_types_compiled_contract_and_class_hash[1],
119+
max_fee=MAX_FEE,
120+
)
121+
122+
brodcasted_txn = _create_broadcasted_txn(declare_v2)
123+
124+
assert brodcasted_txn["type"] == TransactionType.DECLARE.name
125+
126+
expected_keys = dataclasses.fields(DeclareV2)
127+
assert all(key.name in brodcasted_txn for key in expected_keys)
128+
129+
130+
@pytest.mark.asyncio
131+
async def test_broadcasted_txn_declare_v1(account, map_compiled_contract):
132+
declare_v1 = await account.sign_declare_v1(
133+
compiled_contract=map_compiled_contract,
134+
max_fee=MAX_FEE,
135+
)
136+
137+
brodcasted_txn = _create_broadcasted_txn(declare_v1)
138+
139+
assert brodcasted_txn["type"] == TransactionType.DECLARE.name
140+
141+
expected_keys = dataclasses.fields(DeclareV1)
142+
assert all(key.name in brodcasted_txn for key in expected_keys)
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_broadcasted_txn_invoke_v3(account, map_contract):
147+
invoke_tx = await account.sign_invoke_v3(
148+
calls=Call(map_contract.address, get_selector_from_name("put"), [3, 4]),
149+
l1_resource_bounds=MAX_RESOURCE_BOUNDS_L1,
150+
)
151+
152+
brodcasted_txn = _create_broadcasted_txn(invoke_tx)
153+
154+
assert brodcasted_txn["type"] == TransactionType.INVOKE.name
155+
156+
expected_keys = dataclasses.fields(InvokeV3)
157+
assert all(key.name in brodcasted_txn for key in expected_keys)
158+
159+
160+
@pytest.mark.asyncio
161+
async def test_broadcasted_txn_invoke_v1(account, map_contract):
162+
invoke_tx = await account.sign_invoke_v1(
163+
calls=Call(map_contract.address, get_selector_from_name("put"), [3, 4]),
164+
max_fee=int(1e16),
165+
)
166+
167+
brodcasted_txn = _create_broadcasted_txn(invoke_tx)
168+
169+
assert brodcasted_txn["type"] == TransactionType.INVOKE.name
170+
171+
expected_keys = dataclasses.fields(InvokeV1)
172+
assert all(key.name in brodcasted_txn for key in expected_keys)
173+
174+
175+
@pytest.mark.asyncio
176+
async def test_broadcasted_txn_deploy_account_v3(account):
177+
class_hash = 0x1234
178+
salt = 0x123
179+
calldata = [1, 2, 3]
180+
signed_tx = await account.sign_deploy_account_v3(
181+
class_hash,
182+
salt,
183+
l1_resource_bounds=MAX_RESOURCE_BOUNDS_L1,
184+
constructor_calldata=calldata,
185+
)
186+
brodcasted_txn = _create_broadcasted_txn(signed_tx)
187+
assert brodcasted_txn["type"] == TransactionType.DEPLOY_ACCOUNT.name
188+
189+
expected_keys = dataclasses.fields(DeployAccountV3)
190+
assert all(key.name in brodcasted_txn for key in expected_keys)
191+
192+
193+
@pytest.mark.asyncio
194+
async def test_broadcasted_txn_deploy_account_v1(account):
195+
class_hash = 0x1234
196+
salt = 0x123
197+
calldata = [1, 2, 3]
198+
signed_tx = await account.sign_deploy_account_v1(
199+
class_hash, salt, calldata, max_fee=MAX_FEE
200+
)
201+
202+
brodcasted_txn = _create_broadcasted_txn(signed_tx)
203+
204+
assert brodcasted_txn["type"] == TransactionType.DEPLOY_ACCOUNT.name
205+
206+
expected_keys = dataclasses.fields(DeployAccountV1)
207+
assert all(key.name in brodcasted_txn for key in expected_keys)

starknet_py/net/client_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import re
2-
from typing import Union
2+
from typing import Dict, Union, cast
33

44
from typing_extensions import get_args
55

66
from starknet_py.hash.utils import encode_uint, encode_uint_list
77
from starknet_py.net.client_models import Hash, L1HandlerTransaction, Tag
8+
from starknet_py.net.models.transaction import AccountTransaction
9+
from starknet_py.net.schemas.broadcasted_txn import BroadcastedTransactionSchema
810

911

1012
def hash_to_felt(value: Hash) -> str:
@@ -77,3 +79,10 @@ def _is_valid_eth_address(address: str) -> bool:
7779
A function checking if an address matches Ethereum address regex. Note that it doesn't validate any checksums etc.
7880
"""
7981
return bool(re.fullmatch("^0x[a-fA-F0-9]{40}$", address))
82+
83+
84+
def _create_broadcasted_txn(transaction: AccountTransaction) -> dict:
85+
return cast(
86+
Dict,
87+
BroadcastedTransactionSchema().dump(obj=transaction),
88+
)

starknet_py/net/full_node_client.py

Lines changed: 2 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Optional, Tuple, Union, cast
1+
from typing import List, Optional, Tuple, Union, cast
22

33
import aiohttp
44
from marshmallow import EXCLUDE
@@ -34,9 +34,9 @@
3434
TransactionReceipt,
3535
TransactionStatusResponse,
3636
TransactionTrace,
37-
TransactionType,
3837
)
3938
from starknet_py.net.client_utils import (
39+
_create_broadcasted_txn,
4040
_is_valid_eth_address,
4141
_to_rpc_felt,
4242
_to_storage_key,
@@ -46,16 +46,9 @@
4646
from starknet_py.net.models.transaction import (
4747
AccountTransaction,
4848
Declare,
49-
DeclareV1Schema,
50-
DeclareV2,
51-
DeclareV2Schema,
52-
DeclareV3,
5349
DeployAccount,
54-
DeployAccountV3,
5550
Invoke,
56-
InvokeV3,
5751
)
58-
from starknet_py.net.schemas.gateway import SierraCompiledContractSchema
5952
from starknet_py.net.schemas.rpc import (
6053
BlockHashAndNumberSchema,
6154
BlockStateUpdateSchema,
@@ -77,10 +70,8 @@
7770
TransactionReceiptSchema,
7871
TransactionStatusResponseSchema,
7972
TransactionTraceSchema,
80-
TransactionV3Schema,
8173
TypesOfTransactionsSchema,
8274
)
83-
from starknet_py.net.schemas.utils import _extract_tx_version
8475
from starknet_py.transaction_errors import TransactionNotReceivedError
8576
from starknet_py.utils.sync import add_sync_methods
8677

@@ -810,152 +801,3 @@ def _get_raw_block_identifier(
810801
return {"block_number": block_number}
811802

812803
return "pending"
813-
814-
815-
def _create_broadcasted_txn(transaction: AccountTransaction) -> dict:
816-
txn_map = {
817-
TransactionType.DECLARE: _create_broadcasted_declare_properties,
818-
TransactionType.INVOKE: _create_broadcasted_invoke_properties,
819-
TransactionType.DEPLOY_ACCOUNT: _create_broadcasted_deploy_account_properties,
820-
}
821-
822-
common_properties = _create_broadcasted_txn_common_properties(transaction)
823-
transaction_specific_properties = txn_map[transaction.type](transaction)
824-
825-
return {
826-
**common_properties,
827-
**transaction_specific_properties,
828-
}
829-
830-
831-
def _create_broadcasted_declare_properties(
832-
transaction: Union[Declare, DeclareV2, DeclareV3]
833-
) -> dict:
834-
if isinstance(transaction, DeclareV2):
835-
return _create_broadcasted_declare_v2_properties(transaction)
836-
if isinstance(transaction, DeclareV3):
837-
return _create_broadcasted_declare_v3_properties(transaction)
838-
839-
contract_class = cast(Dict, DeclareV1Schema().dump(obj=transaction))[
840-
"contract_class"
841-
]
842-
declare_properties = {
843-
"contract_class": {
844-
"entry_points_by_type": contract_class["entry_points_by_type"],
845-
"program": contract_class["program"],
846-
},
847-
"sender_address": _to_rpc_felt(transaction.sender_address),
848-
}
849-
if contract_class["abi"] is not None:
850-
declare_properties["contract_class"]["abi"] = contract_class["abi"]
851-
852-
return declare_properties
853-
854-
855-
def _create_broadcasted_declare_v2_properties(transaction: DeclareV2) -> dict:
856-
contract_class = cast(Dict, DeclareV2Schema().dump(obj=transaction))[
857-
"contract_class"
858-
]
859-
declare_v2_properties = {
860-
"contract_class": {
861-
"entry_points_by_type": contract_class["entry_points_by_type"],
862-
"sierra_program": contract_class["sierra_program"],
863-
"contract_class_version": contract_class["contract_class_version"],
864-
},
865-
"sender_address": _to_rpc_felt(transaction.sender_address),
866-
"compiled_class_hash": _to_rpc_felt(transaction.compiled_class_hash),
867-
}
868-
if contract_class["abi"] is not None:
869-
declare_v2_properties["contract_class"]["abi"] = contract_class["abi"]
870-
871-
return declare_v2_properties
872-
873-
874-
def _create_broadcasted_declare_v3_properties(transaction: DeclareV3) -> dict:
875-
contract_class = cast(
876-
Dict, SierraCompiledContractSchema().dump(obj=transaction.contract_class)
877-
)
878-
879-
declare_v3_properties = {
880-
"contract_class": {
881-
"entry_points_by_type": contract_class["entry_points_by_type"],
882-
"sierra_program": contract_class["sierra_program"],
883-
"contract_class_version": contract_class["contract_class_version"],
884-
},
885-
"sender_address": _to_rpc_felt(transaction.sender_address),
886-
"compiled_class_hash": _to_rpc_felt(transaction.compiled_class_hash),
887-
"account_deployment_data": [
888-
_to_rpc_felt(data) for data in transaction.account_deployment_data
889-
],
890-
}
891-
892-
if contract_class["abi"] is not None:
893-
declare_v3_properties["contract_class"]["abi"] = contract_class["abi"]
894-
895-
return {
896-
**_create_broadcasted_txn_v3_common_properties(transaction),
897-
**declare_v3_properties,
898-
}
899-
900-
901-
def _create_broadcasted_invoke_properties(transaction: Union[Invoke, InvokeV3]) -> dict:
902-
invoke_properties = {
903-
"sender_address": _to_rpc_felt(transaction.sender_address),
904-
"calldata": [_to_rpc_felt(data) for data in transaction.calldata],
905-
}
906-
907-
if isinstance(transaction, InvokeV3):
908-
return {
909-
**_create_broadcasted_txn_v3_common_properties(transaction),
910-
**invoke_properties,
911-
"account_deployment_data": [
912-
_to_rpc_felt(data) for data in transaction.account_deployment_data
913-
],
914-
}
915-
916-
return invoke_properties
917-
918-
919-
def _create_broadcasted_deploy_account_properties(
920-
transaction: Union[DeployAccount, DeployAccountV3]
921-
) -> dict:
922-
deploy_account_txn_properties = {
923-
"contract_address_salt": _to_rpc_felt(transaction.contract_address_salt),
924-
"constructor_calldata": [
925-
_to_rpc_felt(data) for data in transaction.constructor_calldata
926-
],
927-
"class_hash": _to_rpc_felt(transaction.class_hash),
928-
}
929-
930-
if isinstance(transaction, DeployAccountV3):
931-
return {
932-
**_create_broadcasted_txn_v3_common_properties(transaction),
933-
**deploy_account_txn_properties,
934-
}
935-
936-
return deploy_account_txn_properties
937-
938-
939-
def _create_broadcasted_txn_common_properties(transaction: AccountTransaction) -> dict:
940-
broadcasted_txn_common_properties = {
941-
"type": transaction.type.name,
942-
"version": _to_rpc_felt(transaction.version),
943-
"signature": [_to_rpc_felt(sig) for sig in transaction.signature],
944-
"nonce": _to_rpc_felt(transaction.nonce),
945-
}
946-
947-
if _extract_tx_version(transaction.version) < 3 and hasattr(transaction, "max_fee"):
948-
broadcasted_txn_common_properties["max_fee"] = _to_rpc_felt(
949-
transaction.max_fee # pyright: ignore
950-
)
951-
952-
return broadcasted_txn_common_properties
953-
954-
955-
def _create_broadcasted_txn_v3_common_properties(
956-
transaction: Union[DeclareV3, InvokeV3, DeployAccountV3]
957-
) -> dict:
958-
return cast(
959-
Dict,
960-
TransactionV3Schema(exclude=["version", "signature"]).dump(obj=transaction),
961-
)

0 commit comments

Comments
 (0)