diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 2c9c2ca..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.gitignore b/.gitignore index 9312504..fa4024a 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,4 @@ sample.py vault.py parser.py *.pem +.DS_Store diff --git a/VERSION b/VERSION index 6da28dd..0ea3a94 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.1 \ No newline at end of file +0.2.0 diff --git a/requirements.txt b/requirements.txt index bf5c16a..9523fee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -cffi==1.17.1 -cryptography==44.0.0 -ecdsa==0.19.0 -pycparser==2.22 -pycryptodome==3.21.0 -six==1.16.0 -sqlcipher3==0.5.4 +cffi==2.0.0 +cryptography==46.0.3 +pycparser==2.23 +pycryptodome==3.23.0 +six==1.17.0 +sqlcipher3==0.6.0 diff --git a/smswithoutborders_libsig/keypairs.py b/smswithoutborders_libsig/keypairs.py index b1ce595..0c94deb 100755 --- a/smswithoutborders_libsig/keypairs.py +++ b/smswithoutborders_libsig/keypairs.py @@ -1,22 +1,21 @@ #!/usr/bin/env python3 -from abc import ABC, abstractmethod - -# X25519 -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey -from cryptography.hazmat.primitives import serialization -import binascii import base64 +import binascii +import secrets +import struct +import uuid +from typing import Self + +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric.x25519 import ( + X25519PrivateKey, + X25519PublicKey, +) from cryptography.hazmat.primitives.kdf.hkdf import HKDF from smswithoutborders_libsig.keystore import Keystore -import base64 -import secrets -import uuid -import struct -from typing import Self # Available in Python 3.11+ class x25519: def __init__(self, keystore_path=None, pnt_keystore=None, secret_key=None): @@ -40,56 +39,56 @@ def init(self): if not self.keystore_path: self.keystore_path = f"db_keys/{self.pnt_keystore}.db" - self.secret_key = self.store(pk, _pk, self.keystore_path, - self.pnt_keystore, secret_key=self.secret_key) + self.secret_key = self.store( + pk, _pk, self.keystore_path, self.pnt_keystore, secret_key=self.secret_key + ) return pk def serialize(self) -> bytes: - """ - """ - if not hasattr(self, 'pnt_keystore') or self.pnt_keystore == None or \ - not hasattr(self, 'keystore_path') or self.keystore_path == None or \ - not hasattr(self, 'secret_key') or self.secret_key == None: + """ """ + if ( + not hasattr(self, "pnt_keystore") + or self.pnt_keystore == None + or not hasattr(self, "keystore_path") + or self.keystore_path == None + or not hasattr(self, "secret_key") + or self.secret_key == None + ): raise Exception("keypair not initialized -- init()") keystore_path_len = len(self.keystore_path) pnt_keystore_len = len(self.pnt_keystore) - return struct.pack(" Self: - """ - """ + """ """ x = x25519() keystore_path_len, pnt_keystore_len = struct.unpack(" bytes: def store(self, pk, _pk, keystore_path, pnt_keystore, secret_key=None) -> bytes: if not secret_key: - secret_key = secrets.token_bytes(self.size) # store this + secret_key = secrets.token_bytes(self.size) # store this keystore = Keystore(keystore_path, secret_key) keystore.store(keypair=(pk, _pk), pnt=pnt_keystore) @@ -110,32 +109,15 @@ def store(self, pk, _pk, keystore_path, pnt_keystore, secret_key=None) -> bytes: def fetch(self, pnt_keystore, secret_key, keystore_path=None): keystore = Keystore(keystore_path, secret_key) return keystore.fetch(pnt_keystore) - def __agree__(self, secret_key, info=b"x25591_key_exchange", salt=None): - return HKDF(algorithm=hashes.SHA256(), - length=self.size, salt=salt, info=info,).derive(secret_key) + return HKDF( + algorithm=hashes.SHA256(), + length=self.size, + salt=salt, + info=info, + ).derive(secret_key) def migrate(self, old_key: str) -> str: dec_key = base64.b64decode(old_key) return binascii.hexlify(dec_key).decode("ascii") - - -if __name__ == "__main__": - client1 = x25519() - client1_public_key = client1.init() - - client2 = x25519() - client2_public_key = client2.init() - - dk = client1.agree(client2_public_key) - dk1 = client2.agree(client1_public_key) - - assert(dk != None) - assert(dk1 != None) - assert(dk == dk1) - - s_c1 = client1.serialize() - d_c1 = client1.deserialize(s_c1) - - assert(d_c1 == client1) diff --git a/smswithoutborders_libsig/protocols.py b/smswithoutborders_libsig/protocols.py index f8a9874..a227c9e 100755 --- a/smswithoutborders_libsig/protocols.py +++ b/smswithoutborders_libsig/protocols.py @@ -1,18 +1,19 @@ #!/usr/bin/env python3 -from smswithoutborders_libsig.keypairs import x25519 +import base64 +import json +import pickle +import struct +import warnings -from Crypto.Protocol.KDF import HKDF -from Crypto.Random import get_random_bytes -from Crypto.Hash import SHA512, SHA256, HMAC from Crypto.Cipher import AES +from Crypto.Hash import HMAC, SHA256, SHA512 +from Crypto.Protocol.KDF import HKDF from Crypto.Util.Padding import pad, unpad -import logging -import struct import smswithoutborders_libsig.helpers as helpers -import pickle -import base64 +from smswithoutborders_libsig.keypairs import x25519 + class States: DHs: x25519 = None @@ -30,72 +31,165 @@ class States: MKSKIPPED = {} def serialize(self) -> bytes: - if not hasattr(self, 'DHs') or self.DHs == None or \ - not hasattr(self, 'RK') or self.RK == None: - raise Exception("State cannot be serialized: reason DHs == None or RK == None") + warnings.warn( + "serialize() is deprecated due to pickle usage. Use serialize_json() instead.", + DeprecationWarning, + stacklevel=2, + ) + if ( + not hasattr(self, "DHs") + or self.DHs is None + or not hasattr(self, "RK") + or self.RK is None + ): + raise Exception( + "State cannot be serialized: reason DHs == None or RK == None" + ) s_keypairs = self.DHs.serialize() s_keypairs_len = len(s_keypairs) - dhr_len = len(self.DHr) if not self.DHr is None else 0 - rk_len = len(self.RK) if not self.RK is None else 0 - ck_len = len(self.CKs) if not self.CKs is None else 0 - cr_len = len(self.CKr) if not self.CKr is None else 0 - - len_start = struct.pack(f"<{'i'*5}", s_keypairs_len, rk_len, dhr_len, ck_len, cr_len) + dhr_len = len(self.DHr) if self.DHr is not None else 0 + rk_len = len(self.RK) if self.RK is not None else 0 + ck_len = len(self.CKs) if self.CKs is not None else 0 + cr_len = len(self.CKr) if self.CKr is not None else 0 + + len_start = struct.pack( + f"<{'i' * 5}", s_keypairs_len, rk_len, dhr_len, ck_len, cr_len + ) _serialized = len_start + s_keypairs + self.RK for i in [self.DHr, self.CKs, self.CKr]: - if i: + if i: _serialized = _serialized + i - _serialized = _serialized + struct.pack(" bytes: + """ + Serialize state to JSON format + Returns bytes containing JSON-encoded state. + """ + if ( + not hasattr(self, "DHs") + or self.DHs is None + or not hasattr(self, "RK") + or self.RK is None + ): + raise Exception( + "State cannot be serialized: reason DHs == None or RK == None" + ) + + mkskipped_encoded = {} + for (dh_key, n), mk_value in self.MKSKIPPED.items(): + key_str = f"{base64.b64encode(dh_key).decode('ascii')}:{n}" + mkskipped_encoded[key_str] = base64.b64encode(mk_value).decode("ascii") + + state_dict = { + "version": 1, + "DHs": base64.b64encode(self.DHs.serialize()).decode("ascii"), + "DHr": base64.b64encode(self.DHr).decode("ascii") if self.DHr else None, + "RK": base64.b64encode(self.RK).decode("ascii"), + "CKs": base64.b64encode(self.CKs).decode("ascii") if self.CKs else None, + "CKr": base64.b64encode(self.CKr).decode("ascii") if self.CKr else None, + "Ns": self.Ns, + "Nr": self.Nr, + "PN": self.PN, + "MKSKIPPED": mkskipped_encoded, + } + + return json.dumps(state_dict).encode("utf-8") + + @staticmethod + def deserialize_json(data: bytes): + """ + Deserialize state from JSON format. + """ + state = States() + state_dict = json.loads(data.decode("utf-8")) + + if state_dict.get("version") != 1: + raise ValueError(f"Unsupported state version: {state_dict.get('version')}") + + state.DHs = x25519().deserialize(base64.b64decode(state_dict["DHs"])) + state.RK = base64.b64decode(state_dict["RK"]) + state.DHr = base64.b64decode(state_dict["DHr"]) if state_dict["DHr"] else None + state.CKs = base64.b64decode(state_dict["CKs"]) if state_dict["CKs"] else None + state.CKr = base64.b64decode(state_dict["CKr"]) if state_dict["CKr"] else None + state.Ns = state_dict["Ns"] + state.Nr = state_dict["Nr"] + state.PN = state_dict["PN"] + + state.MKSKIPPED = {} + for key_str, mk_value_encoded in state_dict["MKSKIPPED"].items(): + dh_b64, n_str = key_str.rsplit(":", 1) + dh_key = base64.b64decode(dh_b64) + n = int(n_str) + mk_value = base64.b64decode(mk_value_encoded) + state.MKSKIPPED[(dh_key, n)] = mk_value + + return state - return (self.DHs == other.DHs and self.DHr == other.DHr - and self.RK == other.RK and self.CKs == other.CKs - and self.CKr == other.CKr and self.Ns == other.Ns - and self.Nr == other.Nr and self.PN == other.PN - and self.MKSKIPPED == other.MKSKIPPED) class HEADERS: - dh: bytes # public key bytes + dh: bytes # public key bytes pn = None n = None - + LEN = None - - def __init__(self, dh_pair: bytes=None, pn=None, n=None): + + def __init__(self, dh_pair: bytes = None, pn=None, n=None): if dh_pair: self.dh = dh_pair.get_public_key() self.pn = pn @@ -118,12 +212,8 @@ def deserialize(data): return headers - def __eq__(self, other): - return (self.dh == other.dh and - self.pn == other.pn and - self.n == other.n) -def DHRatchet(state: States, header: HEADERS): +def DHRatchet(state: States, header: HEADERS): state.PN = state.Ns state.Ns = 0 state.Nr = 0 @@ -136,41 +226,46 @@ def DHRatchet(state: States, header: HEADERS): state.RK, state.CKs = KDF_RK(state.RK, shared_secret) -def GENERATE_DH(keystore_path: str=None, secret_key = None) -> bytes: +def GENERATE_DH(keystore_path: str = None, secret_key=None) -> bytes: x = x25519(keystore_path=keystore_path, secret_key=secret_key) x.init() return x + def DH(dh_pair: x25519, dh_pub: bytes) -> bytes: return dh_pair.agree(dh_pub) -def KDF_RK(rk, dh_out): - length=32 - num_keys=2 + +def KDF_RK(rk, dh_out): + length = 32 + num_keys = 2 # TODO: make meaninful information - information=b'KDF_RK' + information = b"KDF_RK" - return HKDF(master=dh_out, - key_len=length, - salt=rk, - hashmod=SHA512, - num_keys=num_keys, context=information) + return HKDF( + master=dh_out, + key_len=length, + salt=rk, + hashmod=SHA512, + num_keys=num_keys, + context=information, + ) def KDF_CK(ck): d_ck = HMAC.new(ck, digestmod=SHA256) - _ck = d_ck.update(b'\x01').digest() + _ck = d_ck.update(b"\x01").digest() d_ck = HMAC.new(ck, digestmod=SHA256) - mk = d_ck.update(b'\x02').digest() + mk = d_ck.update(b"\x02").digest() return _ck, mk def ENCRYPT(mk, plaintext, associated_data) -> bytes: key, auth_key, iv = helpers.get_mac_parameters(mk) cipher = AES.new(key, AES.MODE_CBC, iv) - cipher_text = cipher.encrypt(pad(plaintext, AES.block_size)) + cipher_text = cipher.encrypt(pad(plaintext, AES.block_size)) hmac = helpers.build_verification_hash(auth_key, associated_data, cipher_text) return cipher_text + hmac.digest() diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000..de3e367 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,66 @@ +"""Test utilities.""" + +from cryptography.hazmat.primitives import constant_time + +from smswithoutborders_libsig.keypairs import x25519 +from smswithoutborders_libsig.protocols import HEADERS, States + + +def states_equal(state1, state2) -> bool: + """Compare two States objects.""" + + if not isinstance(state1, States) or not isinstance(state2, States): + return False + + dhr_equal = constant_time.bytes_eq( + state1.DHr if state1.DHr else b"", state2.DHr if state2.DHr else b"" + ) + rk_equal = constant_time.bytes_eq( + state1.RK if state1.RK else b"", state2.RK if state2.RK else b"" + ) + cks_equal = constant_time.bytes_eq( + state1.CKs if state1.CKs else b"", state2.CKs if state2.CKs else b"" + ) + ckr_equal = constant_time.bytes_eq( + state1.CKr if state1.CKr else b"", state2.CKr if state2.CKr else b"" + ) + + return ( + keypairs_equal(state1.DHs, state2.DHs) + and dhr_equal + and rk_equal + and cks_equal + and ckr_equal + and state1.Ns == state2.Ns + and state1.Nr == state2.Nr + and state1.PN == state2.PN + and state1.MKSKIPPED == state2.MKSKIPPED + ) + + +def headers_equal(header1, header2) -> bool: + """Compare two HEADERS objects.""" + + if not isinstance(header1, HEADERS) or not isinstance(header2, HEADERS): + return False + + return ( + constant_time.bytes_eq(header1.dh, header2.dh) + and header1.pn == header2.pn + and header1.n == header2.n + ) + + +def keypairs_equal(keypair1, keypair2) -> bool: + """Compare two Keypairs objects.""" + + if not isinstance(keypair1, x25519) or not isinstance(keypair2, x25519): + return False + + return ( + keypair1.keystore_path == keypair2.keystore_path + and keypair1.pnt_keystore == keypair2.pnt_keystore + and constant_time.bytes_eq( + keypair1.secret_key.encode(), keypair2.secret_key.encode() + ) + ) diff --git a/tests/test_json_serialization.py b/tests/test_json_serialization.py new file mode 100644 index 0000000..e21890e --- /dev/null +++ b/tests/test_json_serialization.py @@ -0,0 +1,429 @@ +import os +import secrets + +import pytest + +from smswithoutborders_libsig.keypairs import x25519 +from smswithoutborders_libsig.protocols import States +from tests.test_helpers import states_equal + + +class TestJSONSerialization: + """Test JSON-based serialization.""" + + def setup_method(self): + """Setup test fixtures""" + self.keystore_path = "db_keys/test_json_states.db" + self.dh_keypair = x25519(keystore_path=self.keystore_path) + self.dh_keypair.init() + + def teardown_method(self): + """Cleanup test files""" + if os.path.exists(self.keystore_path): + os.remove(self.keystore_path) + + def test_json_serialization_basic(self): + """Test basic JSON serialization""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.CKs = secrets.token_bytes(32) + state.CKr = secrets.token_bytes(32) + state.Ns = 5 + state.Nr = 3 + state.PN = 2 + + serialized = state.serialize_json() + assert isinstance(serialized, bytes) + assert len(serialized) > 0 + + def test_json_deserialization_basic(self): + """Test basic JSON deserialization""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.CKs = secrets.token_bytes(32) + state.CKr = secrets.token_bytes(32) + state.Ns = 5 + state.Nr = 3 + state.PN = 2 + + serialized = state.serialize_json() + deserialized = States.deserialize_json(serialized) + + assert states_equal(deserialized, state) + assert deserialized.Ns == 5 + assert deserialized.Nr == 3 + assert deserialized.PN == 2 + + def test_json_serialization_with_mkskipped(self): + """Test JSON serialization with MKSKIPPED data""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.CKs = secrets.token_bytes(32) + state.CKr = secrets.token_bytes(32) + state.Ns = 10 + state.Nr = 7 + state.PN = 5 + state.MKSKIPPED = { + (b"dh_key_1" * 4, 1): b"message_key_1" * 2, + (b"dh_key_2" * 4, 5): b"message_key_2" * 2, + (b"dh_key_3" * 4, 10): b"message_key_3" * 2, + } + + serialized = state.serialize_json() + deserialized = States.deserialize_json(serialized) + + assert states_equal(deserialized, state) + assert deserialized.MKSKIPPED == state.MKSKIPPED + assert len(deserialized.MKSKIPPED) == 3 + + def test_json_serialization_with_none_values(self): + """Test JSON serialization handles None values correctly""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.CKs = None + state.CKr = None + + serialized = state.serialize_json() + deserialized = States.deserialize_json(serialized) + + assert deserialized.DHr == state.DHr + assert deserialized.CKs is None + assert deserialized.CKr is None + assert states_equal(deserialized, state) + + def test_json_serialization_empty_mkskipped(self): + """Test JSON serialization with empty MKSKIPPED""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.MKSKIPPED = {} + + serialized = state.serialize_json() + deserialized = States.deserialize_json(serialized) + + assert deserialized.MKSKIPPED == {} + assert states_equal(deserialized, state) + + def test_json_serialization_without_required_fields(self): + """Test JSON serialization fails without DHs or RK""" + state = States() + with pytest.raises(Exception, match="State cannot be serialized"): + state.serialize_json() + + state.DHs = self.dh_keypair + with pytest.raises(Exception, match="State cannot be serialized"): + state.serialize_json() + + def test_json_serialization_deterministic(self): + """Test JSON serialization is deterministic""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.Ns = 5 + + serialized1 = state.serialize_json() + serialized2 = state.serialize_json() + + assert serialized1 == serialized2 + + def test_json_output_is_valid_json(self): + """Test that serialized output is valid JSON""" + import json + + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + + serialized = state.serialize_json() + + # Should not raise exception + parsed = json.loads(serialized.decode("utf-8")) + assert isinstance(parsed, dict) + assert "version" in parsed + assert parsed["version"] == 1 + + def test_json_no_binary_in_output(self): + """Test that JSON output contains no binary data (only base64 strings)""" + import json + + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.MKSKIPPED = {(b"test_key" * 4, 1): b"test_value" * 2} + + serialized = state.serialize_json() + parsed = json.loads(serialized.decode("utf-8")) + + # All values should be strings, integers, or dicts (no bytes) + assert isinstance(parsed["DHs"], str) + assert isinstance(parsed["RK"], str) + assert isinstance(parsed["DHr"], str) + assert isinstance(parsed["Ns"], int) + assert isinstance(parsed["MKSKIPPED"], dict) + for key, value in parsed["MKSKIPPED"].items(): + assert isinstance(key, str) + assert isinstance(value, str) + + +class TestPickleToJSONMigration: + """Test migration from pickle serialization to JSON serialization""" + + def setup_method(self): + """Setup test fixtures""" + self.keystore_path = "db_keys/test_migration.db" + self.dh_keypair = x25519(keystore_path=self.keystore_path) + self.dh_keypair.init() + + def teardown_method(self): + """Cleanup test files""" + if os.path.exists(self.keystore_path): + os.remove(self.keystore_path) + + def test_migrate_pickle_to_json_basic(self): + """Test migrating basic state from pickle to JSON""" + # Create state and serialize with pickle + state_original = States() + state_original.DHs = self.dh_keypair + state_original.RK = secrets.token_bytes(32) + state_original.DHr = secrets.token_bytes(32) + state_original.CKs = secrets.token_bytes(32) + state_original.CKr = secrets.token_bytes(32) + state_original.Ns = 5 + state_original.Nr = 3 + state_original.PN = 2 + + pickle_serialized = state_original.serialize() + + # Deserialize with pickle + state_from_pickle = States.deserialize(pickle_serialized) + + # Re-serialize with JSON + json_serialized = state_from_pickle.serialize_json() + + # Deserialize with JSON + state_from_json = States.deserialize_json(json_serialized) + + # Verify all data is preserved + assert states_equal(state_from_json, state_original) + assert state_from_json.Ns == 5 + assert state_from_json.Nr == 3 + assert state_from_json.PN == 2 + + def test_migrate_pickle_to_json_with_mkskipped(self): + """Test migrating state with MKSKIPPED from pickle to JSON""" + # Create state with MKSKIPPED + state_original = States() + state_original.DHs = self.dh_keypair + state_original.RK = secrets.token_bytes(32) + state_original.DHr = secrets.token_bytes(32) + state_original.CKs = secrets.token_bytes(32) + state_original.CKr = secrets.token_bytes(32) + state_original.Ns = 10 + state_original.Nr = 7 + state_original.PN = 5 + state_original.MKSKIPPED = { + (b"dh_key_1" * 4, 1): b"message_key_1" * 2, + (b"dh_key_2" * 4, 5): b"message_key_2" * 2, + (b"dh_key_3" * 4, 10): b"message_key_3" * 2, + } + + # Pickle serialize + pickle_serialized = state_original.serialize() + + # Deserialize with pickle + state_from_pickle = States.deserialize(pickle_serialized) + + # Re-serialize with JSON + json_serialized = state_from_pickle.serialize_json() + + # Deserialize with JSON + state_from_json = States.deserialize_json(json_serialized) + + # Verify all data is preserved including MKSKIPPED + assert states_equal(state_from_json, state_original) + assert state_from_json.MKSKIPPED == state_original.MKSKIPPED + assert len(state_from_json.MKSKIPPED) == 3 + + # Verify each MKSKIPPED entry + for key, value in state_original.MKSKIPPED.items(): + assert key in state_from_json.MKSKIPPED + assert state_from_json.MKSKIPPED[key] == value + + def test_migrate_pickle_to_json_with_none_values(self): + """Test migrating state with None values from pickle to JSON""" + state_original = States() + state_original.DHs = self.dh_keypair + state_original.RK = secrets.token_bytes(32) + state_original.DHr = secrets.token_bytes(32) + state_original.CKs = None + state_original.CKr = None + + pickle_serialized = state_original.serialize() + state_from_pickle = States.deserialize(pickle_serialized) + json_serialized = state_from_pickle.serialize_json() + state_from_json = States.deserialize_json(json_serialized) + + assert states_equal(state_from_json, state_original) + assert state_from_json.CKs is None + assert state_from_json.CKr is None + + def test_migrate_pickle_to_json_preserves_crypto_keys(self): + """Test that cryptographic keys are preserved during migration""" + state_original = States() + state_original.DHs = self.dh_keypair + state_original.RK = secrets.token_bytes(32) + state_original.DHr = secrets.token_bytes(32) + state_original.CKs = secrets.token_bytes(32) + state_original.CKr = secrets.token_bytes(32) + + pickle_serialized = state_original.serialize() + state_from_pickle = States.deserialize(pickle_serialized) + json_serialized = state_from_pickle.serialize_json() + state_from_json = States.deserialize_json(json_serialized) + + # Verify cryptographic keys are byte-for-byte identical + assert state_from_json.RK == state_original.RK + assert state_from_json.DHr == state_original.DHr + assert state_from_json.CKs == state_original.CKs + assert state_from_json.CKr == state_original.CKr + assert ( + state_from_json.DHs.get_public_key() == state_original.DHs.get_public_key() + ) + + def test_migrate_multiple_states(self): + """Test migrating multiple different states from pickle to JSON""" + states = [] + + for i in range(5): + keystore_path = f"db_keys/test_migration_{i}.db" + if os.path.exists(keystore_path): + os.remove(keystore_path) + + dh = x25519(keystore_path=keystore_path) + dh.init() + + state = States() + state.DHs = dh + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.CKs = secrets.token_bytes(32) + state.CKr = secrets.token_bytes(32) + state.Ns = i * 10 + state.Nr = i * 5 + state.PN = i * 3 + state.MKSKIPPED = { + (secrets.token_bytes(32), j): secrets.token_bytes(32) for j in range(i) + } + + states.append((state, keystore_path)) + + for original_state, keystore_path in states: + pickle_serialized = original_state.serialize() + state_from_pickle = States.deserialize(pickle_serialized) + json_serialized = state_from_pickle.serialize_json() + state_from_json = States.deserialize_json(json_serialized) + + assert states_equal(state_from_json, original_state) + + if os.path.exists(keystore_path): + os.remove(keystore_path) + + def test_json_deserialization_invalid_version(self): + """Test that invalid version raises an error""" + import json + + invalid_data = json.dumps({"version": 99}).encode("utf-8") + + with pytest.raises(ValueError, match="Unsupported state version"): + States.deserialize_json(invalid_data) + + def test_roundtrip_consistency_pickle_vs_json(self): + """Test that pickle and JSON produce equivalent results after roundtrip""" + state_original = States() + state_original.DHs = self.dh_keypair + state_original.RK = secrets.token_bytes(32) + state_original.DHr = secrets.token_bytes(32) + state_original.CKs = secrets.token_bytes(32) + state_original.CKr = secrets.token_bytes(32) + state_original.Ns = 42 + state_original.Nr = 24 + state_original.PN = 12 + state_original.MKSKIPPED = {(b"key" * 8, 7): b"value" * 8} + + # Roundtrip through pickle + pickle_roundtrip = States.deserialize(state_original.serialize()) + + # Roundtrip through JSON + json_roundtrip = States.deserialize_json(state_original.serialize_json()) + + # Both should equal the original + assert states_equal(pickle_roundtrip, state_original) + assert states_equal(json_roundtrip, state_original) + + # And should equal each other + assert states_equal(pickle_roundtrip, json_roundtrip) + + +class TestJSONSecurityProperties: + """Test security properties of JSON serialization""" + + def setup_method(self): + """Setup test fixtures""" + self.keystore_path = "db_keys/test_security.db" + self.dh_keypair = x25519(keystore_path=self.keystore_path) + self.dh_keypair.init() + + def teardown_method(self): + """Cleanup test files""" + if os.path.exists(self.keystore_path): + os.remove(self.keystore_path) + + def test_json_no_code_execution_risk(self): + """Test that JSON deserialization does not execute code""" + # JSON should not allow code execution unlike pickle + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + + serialized = state.serialize_json() + + # This should safely deserialize without any code execution + deserialized = States.deserialize_json(serialized) + assert states_equal(deserialized, state) + + def test_json_malformed_input_handling(self): + """Test that malformed JSON input raises appropriate errors""" + with pytest.raises((ValueError, Exception)): + States.deserialize_json(b"not valid json") + + with pytest.raises((ValueError, Exception)): + States.deserialize_json(b"{incomplete") + + def test_json_tampering_detection(self): + """Test that tampering with JSON is detectable through data validation""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + + serialized = state.serialize_json() + + # Tamper with the data + tampered = serialized.replace(b'"version": 1', b'"version": "hacked"') + + # Should fail validation + with pytest.raises((ValueError, TypeError, Exception)): + States.deserialize_json(tampered) diff --git a/tests/test_protocols.py b/tests/test_protocols.py new file mode 100644 index 0000000..d3ed0ea --- /dev/null +++ b/tests/test_protocols.py @@ -0,0 +1,297 @@ +"""Tests for protocol components.""" + +import os +import secrets + +import pytest + +from smswithoutborders_libsig.keypairs import x25519 +from smswithoutborders_libsig.protocols import ( + CONCAT, + DECRYPT, + DH, + ENCRYPT, + GENERATE_DH, + HEADERS, + KDF_CK, + KDF_RK, + DHRatchet, + States, +) +from tests.test_helpers import headers_equal, states_equal + + +class TestStates: + """Test States serialization and deserialization.""" + + def setup_method(self): + self.keystore_path = "db_keys/test_states.db" + self.dh_keypair = x25519(keystore_path=self.keystore_path) + self.dh_keypair.init() + + def teardown_method(self): + if os.path.exists(self.keystore_path): + os.remove(self.keystore_path) + + def test_state_initialization(self): + """Test States initializes with default values.""" + state = States() + assert state.DHs is None + assert state.RK is None + assert state.Ns == 0 + assert state.MKSKIPPED == {} + + def test_state_serialization_roundtrip(self): + """Test States serialization and deserialization.""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.CKs = secrets.token_bytes(32) + state.CKr = secrets.token_bytes(32) + state.Ns = 5 + state.Nr = 3 + state.PN = 2 + state.MKSKIPPED = {(b"key1", 1): b"value1"} + + serialized = state.serialize() + deserialized = States.deserialize(serialized) + + assert states_equal(deserialized, state) + assert deserialized.Ns == 5 + assert deserialized.MKSKIPPED == {(b"key1", 1): b"value1"} + + def test_state_serialization_requires_fields(self): + """Test serialization requires DHs and RK.""" + state = States() + with pytest.raises(Exception, match="State cannot be serialized"): + state.serialize() + + def test_state_serialization_with_none_values(self): + """Test States handles None values.""" + state = States() + state.DHs = self.dh_keypair + state.RK = secrets.token_bytes(32) + state.DHr = secrets.token_bytes(32) + state.CKs = None + state.CKr = None + + serialized = state.serialize() + deserialized = States.deserialize(serialized) + + assert deserialized.CKs is None + assert deserialized.CKr is None + + def test_states_equality(self): + """Test states comparison.""" + state1 = States() + state1.DHs = self.dh_keypair + state1.RK = secrets.token_bytes(32) + state1.Ns = 1 + + state2 = States() + state2.DHs = self.dh_keypair + state2.RK = secrets.token_bytes(32) + state2.Ns = 2 + + assert not states_equal(state1, state2) + + +class TestHEADERS: + """Test HEADERS serialization.""" + + def setup_method(self): + self.keystore_path = "db_keys/test_headers.db" + self.dh_keypair = x25519(keystore_path=self.keystore_path) + self.dh_keypair.init() + + def teardown_method(self): + if os.path.exists(self.keystore_path): + os.remove(self.keystore_path) + + def test_header_serialization_roundtrip(self): + """Test HEADERS serialization and deserialization.""" + header1 = HEADERS(dh_pair=self.dh_keypair, pn=5, n=10) + serialized = header1.serialize() + header2 = HEADERS.deserialize(serialized) + + assert headers_equal(header1, header2) + assert header2.pn == 5 + assert header2.n == 10 + + def test_headers_equality(self): + """Test headers comparison.""" + header1 = HEADERS(dh_pair=self.dh_keypair, pn=5, n=10) + header2 = HEADERS(dh_pair=self.dh_keypair, pn=6, n=10) + + assert not headers_equal(header1, header2) + + +class TestDHRatchet: + """Test DH operations.""" + + def setup_method(self): + self.keystore_path1 = "db_keys/test_dh1.db" + self.keystore_path2 = "db_keys/test_dh2.db" + + def teardown_method(self): + for path in [self.keystore_path1, self.keystore_path2]: + if os.path.exists(path): + os.remove(path) + + def test_dh_agreement(self): + """Test DH key agreement.""" + dh1 = GENERATE_DH(keystore_path=self.keystore_path1) + dh2 = GENERATE_DH(keystore_path=self.keystore_path2) + + shared1 = DH(dh1, dh2.get_public_key()) + shared2 = DH(dh2, dh1.get_public_key()) + + assert shared1 == shared2 + assert len(shared1) == 32 + + def test_dh_ratchet_updates_state(self): + """Test DHRatchet updates state.""" + state = States() + state.DHs = GENERATE_DH(keystore_path=self.keystore_path1) + state.RK = secrets.token_bytes(32) + state.Ns = 5 + + dh_remote = GENERATE_DH(keystore_path=self.keystore_path2) + header = HEADERS(dh_pair=dh_remote, pn=0, n=0) + + DHRatchet(state, header) + + assert state.PN == 5 + assert state.Ns == 0 + assert state.CKr is not None + assert state.CKs is not None + + +class TestKDF: + """Test key derivation functions.""" + + def test_kdf_rk(self): + """Test KDF_RK generates deterministic keys.""" + rk = secrets.token_bytes(32) + dh_out = secrets.token_bytes(32) + + new_rk1, ck1 = KDF_RK(rk, dh_out) + new_rk2, ck2 = KDF_RK(rk, dh_out) + + assert new_rk1 == new_rk2 + assert ck1 == ck2 + assert len(new_rk1) == 32 + + def test_kdf_ck(self): + """Test KDF_CK generates deterministic keys.""" + ck = secrets.token_bytes(32) + + new_ck1, mk1 = KDF_CK(ck) + new_ck2, mk2 = KDF_CK(ck) + + assert new_ck1 == new_ck2 + assert mk1 == mk2 + assert len(new_ck1) == 32 + + +class TestEncryption: + """Test encryption and decryption.""" + + def test_encrypt_decrypt_roundtrip(self): + """Test encryption and decryption.""" + mk = secrets.token_bytes(32) + plaintext = b"Hello, World!" + associated_data = b"metadata" + + ciphertext = ENCRYPT(mk, plaintext, associated_data) + decrypted = DECRYPT(mk, ciphertext, associated_data) + + assert decrypted == plaintext + assert ciphertext != plaintext + + def test_decrypt_wrong_key_fails(self): + """Test decryption with wrong key fails.""" + mk1 = secrets.token_bytes(32) + mk2 = secrets.token_bytes(32) + plaintext = b"Secret" + associated_data = b"metadata" + + ciphertext = ENCRYPT(mk1, plaintext, associated_data) + + with pytest.raises(ValueError): + DECRYPT(mk2, ciphertext, associated_data) + + def test_decrypt_wrong_ad_fails(self): + """Test decryption with wrong associated data fails.""" + mk = secrets.token_bytes(32) + plaintext = b"Secret" + + ciphertext = ENCRYPT(mk, plaintext, b"ad1") + + with pytest.raises(ValueError): + DECRYPT(mk, ciphertext, b"ad2") + + def test_decrypt_tampered_ciphertext_fails(self): + """Test decryption with tampered ciphertext fails.""" + mk = secrets.token_bytes(32) + plaintext = b"Secret" + associated_data = b"metadata" + + ciphertext = ENCRYPT(mk, plaintext, associated_data) + tampered = bytearray(ciphertext) + tampered[0] ^= 0xFF + + with pytest.raises(ValueError): + DECRYPT(mk, bytes(tampered), associated_data) + + +class TestCONCAT: + """Test CONCAT function.""" + + def setup_method(self): + self.keystore_path = "db_keys/test_concat.db" + self.dh_keypair = x25519(keystore_path=self.keystore_path) + self.dh_keypair.init() + + def teardown_method(self): + if os.path.exists(self.keystore_path): + os.remove(self.keystore_path) + + def test_concat(self): + """Test CONCAT combines data and header.""" + ad = b"associated_data" + header = HEADERS(dh_pair=self.dh_keypair, pn=5, n=10) + + result = CONCAT(ad, header) + + assert result.startswith(ad) + assert len(result) == len(ad) + len(header.serialize()) + + +class TestIntegration: + """Integration tests.""" + + def setup_method(self): + self.keystore_path = "db_keys/test_int.db" + + def teardown_method(self): + if os.path.exists(self.keystore_path): + os.remove(self.keystore_path) + + def test_state_persistence(self): + """Test state can be persisted and restored.""" + state1 = States() + state1.DHs = GENERATE_DH(keystore_path=self.keystore_path) + state1.RK = secrets.token_bytes(32) + state1.DHr = secrets.token_bytes(32) + state1.CKs = secrets.token_bytes(32) + state1.Ns = 10 + state1.MKSKIPPED = {(b"key", 1): b"value"} + + serialized = state1.serialize() + state2 = States.deserialize(serialized) + + assert state2.Ns == 10 + assert state2.MKSKIPPED == {(b"key", 1): b"value"} + assert states_equal(state2, state1) diff --git a/tests/test_x25519_keypairs.py b/tests/test_x25519_keypairs.py index a30542d..62016ad 100644 --- a/tests/test_x25519_keypairs.py +++ b/tests/test_x25519_keypairs.py @@ -1,9 +1,9 @@ -""" -Tests for the x25519 key exchange mechanism. -""" +"""Tests for x25519 keypair operations.""" import os + import pytest + from smswithoutborders_libsig.keypairs import x25519 @@ -19,11 +19,7 @@ def keypair_paths(tmp_path): def test_keypair_initialization(keypair_paths): - """Test the initialization of Alice's and Bob's public keys. - - Ensures that the public keys are not None, are of type bytes, - and have the correct length. - """ + """Test keypair initialization generates valid public keys.""" alice_db_path, bob_db_path = keypair_paths alice = x25519(alice_db_path) @@ -41,11 +37,7 @@ def test_keypair_initialization(keypair_paths): def test_key_agreement_protocol(keypair_paths): - """Test the key agreement protocol between Alice and Bob. - - Verifies that the shared keys are correctly generated, are equal, - and have the correct length. - """ + """Test key agreement produces matching shared secrets.""" alice_db_path, bob_db_path = keypair_paths alice = x25519(alice_db_path) @@ -67,31 +59,18 @@ def test_key_agreement_protocol(keypair_paths): def test_invalid_key_agreement(keypair_paths): - """Test the key agreement with invalid inputs. - - Ensures that appropriate exceptions are raised for invalid inputs. - """ - alice_db_path, bob_db_path = keypair_paths + """Test key agreement rejects invalid public keys.""" + alice_db_path, _ = keypair_paths alice = x25519(alice_db_path) - bob = x25519(bob_db_path) - - alice_public_key = alice.init() - bob_public_key = bob.init() + alice.init() with pytest.raises(ValueError): alice.agree(b"invalid_key") - with pytest.raises(ValueError): - bob.agree(b"invalid_key") - -def test_keypair_reinitialization(keypair_paths): - """Test the reinitialization of the x25519 object with existing keys. - - Ensures that x25519 objects can be reinitialized using the same database - paths and secret keys, and that the key agreement process remains functional. - """ +def test_keypair_serialization(keypair_paths): + """Test keypair serialization and deserialization.""" alice_db_path, bob_db_path = keypair_paths alice = x25519(alice_db_path) @@ -100,32 +79,16 @@ def test_keypair_reinitialization(keypair_paths): alice_public_key = alice.init() bob_public_key = bob.init() - alice_pnt_keystore = alice.pnt_keystore - alice_secret_key = alice.secret_key - bob_pnt_keystore = bob.pnt_keystore - bob_secret_key = bob.secret_key + alice_pnt = alice.pnt_keystore + alice_secret = alice.secret_key del alice - del bob - alice = x25519( - pnt_keystore=alice_pnt_keystore, - keystore_path=alice_db_path, - secret_key=alice_secret_key, - ) - bob = x25519( - pnt_keystore=bob_pnt_keystore, - keystore_path=bob_db_path, - secret_key=bob_secret_key, + alice_restored = x25519( + pnt_keystore=alice_pnt, keystore_path=alice_db_path, secret_key=alice_secret ) - alice_shared_key = alice.agree(bob_public_key) + alice_shared_key = alice_restored.agree(bob_public_key) bob_shared_key = bob.agree(alice_public_key) - assert alice_shared_key is not None - assert bob_shared_key is not None - assert isinstance(alice_shared_key, bytes) - assert isinstance(bob_shared_key, bytes) - assert len(alice_shared_key) == 32 - assert len(bob_shared_key) == 32 assert alice_shared_key == bob_shared_key