Skip to content

Commit 2924e15

Browse files
authored
Add function to calculate checksum address (#1215)
1 parent d35a04a commit 2924e15

File tree

3 files changed

+88
-2
lines changed

3 files changed

+88
-2
lines changed

starknet_py/hash/address.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from typing import Sequence
22

33
from starknet_py.constants import CONTRACT_ADDRESS_PREFIX, L2_ADDRESS_UPPER_BOUND
4-
from starknet_py.hash.utils import compute_hash_on_elements
4+
from starknet_py.hash.utils import (
5+
HEX_PREFIX,
6+
_starknet_keccak,
7+
compute_hash_on_elements,
8+
encode_uint,
9+
get_bytes_length,
10+
)
511

612

713
def compute_address(
@@ -33,3 +39,29 @@ def compute_address(
3339
)
3440

3541
return raw_address % L2_ADDRESS_UPPER_BOUND
42+
43+
44+
def get_checksum_address(address: str) -> str:
45+
if not address.lower().startswith(HEX_PREFIX):
46+
raise ValueError(f"{address} is not a valid hexadecimal address.")
47+
48+
int_address = int(address, 16)
49+
string_address = address[2:].zfill(64)
50+
51+
address_in_bytes = encode_uint(int_address, get_bytes_length(int_address))
52+
address_hash = _starknet_keccak(address_in_bytes)
53+
54+
result = "".join(
55+
(
56+
char.upper()
57+
if char.isalpha() and (address_hash >> 256 - 4 * i - 1) & 1
58+
else char
59+
)
60+
for i, char in enumerate(string_address)
61+
)
62+
63+
return f"{HEX_PREFIX}{result}"
64+
65+
66+
def is_checksum_address(address: str) -> bool:
67+
return get_checksum_address(address) == address

starknet_py/hash/address_test.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from starknet_py.hash.address import compute_address
1+
import pytest
2+
3+
from starknet_py.hash.address import (
4+
compute_address,
5+
get_checksum_address,
6+
is_checksum_address,
7+
)
28

39

410
def test_compute_address():
@@ -22,3 +28,46 @@ def test_compute_address_with_deployer_address():
2228
)
2329
== 3179899882984850239687045389724311807765146621017486664543269641150383510696
2430
)
31+
32+
33+
@pytest.mark.parametrize(
34+
"address, checksum_address",
35+
[
36+
(
37+
"0x2fd23d9182193775423497fc0c472e156c57c69e4089a1967fb288a2d84e914",
38+
"0x02Fd23d9182193775423497fc0c472E156C57C69E4089A1967fb288A2d84e914",
39+
),
40+
(
41+
"0x00abcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefab",
42+
"0x00AbcDefaBcdefabCDEfAbCDEfAbcdEFAbCDEfabCDefaBCdEFaBcDeFaBcDefAb",
43+
),
44+
(
45+
"0xfedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafe",
46+
"0x00fEdCBafEdcbafEDCbAFedCBAFeDCbafEdCBAfeDcbaFeDCbAfEDCbAfeDcbAFE",
47+
),
48+
("0xa", "0x000000000000000000000000000000000000000000000000000000000000000A"),
49+
(
50+
"0x0",
51+
"0x0000000000000000000000000000000000000000000000000000000000000000",
52+
),
53+
],
54+
)
55+
def test_get_checksum_address(address, checksum_address):
56+
assert get_checksum_address(address) == checksum_address
57+
58+
59+
@pytest.mark.parametrize("address", ["", "0xx", "0123"])
60+
def test_get_checksum_address_raises_on_invalid_address(address):
61+
with pytest.raises(ValueError):
62+
get_checksum_address(address)
63+
64+
65+
@pytest.mark.parametrize(
66+
"address, is_checksum",
67+
[
68+
("0x02Fd23d9182193775423497fc0c472E156C57C69E4089A1967fb288A2d84e914", True),
69+
("0x000000000000000000000000000000000000000000000000000000000000000a", False),
70+
],
71+
)
72+
def test_is_checksum_address(address, is_checksum):
73+
assert is_checksum_address(address) == is_checksum

starknet_py/hash/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from starknet_py.constants import EC_ORDER
1515

1616
MASK_250 = 2**250 - 1
17+
HEX_PREFIX = "0x"
1718

1819

1920
def _starknet_keccak(data: bytes) -> int:
@@ -84,3 +85,7 @@ def encode_uint(value: int, bytes_length: int = 32) -> bytes:
8485

8586
def encode_uint_list(data: List[int]) -> bytes:
8687
return b"".join(encode_uint(x) for x in data)
88+
89+
90+
def get_bytes_length(value: int) -> int:
91+
return (value.bit_length() + 7) // 8

0 commit comments

Comments
 (0)