Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions tests/test_trustme.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,22 @@ def test_ca_from_pem(tmp_path: Path) -> None:
assert ca1.private_key_pem.bytes() == ca2.private_key_pem.bytes()


def safe_close(sock: Optional[Union[socket.socket, SslSocket]]) -> None:
if sock is None:
return
try:
if isinstance(sock, (socket.socket, ssl.SSLSocket)):
sock.shutdown(socket.SHUT_RDWR)
elif isinstance(sock, OpenSSL.SSL.Connection):
sock.shutdown()
except Exception:
pass
try:
sock.close()
except Exception:
pass


def check_connection_end_to_end(
wrap_client: Callable[[CA, socket.socket, str], SslSocket],
wrap_server: Callable[[LeafCert, socket.socket], SslSocket],
Expand All @@ -296,12 +312,12 @@ def fake_ssl_client(ca: CA, raw_client_sock: socket.socket, hostname: str) -> No
# Send and receive some data to prove the connection is good
wrapped_client_sock.send(b"x")
assert wrapped_client_sock.recv(1) == b"y"
wrapped_client_sock.close()
safe_close(wrapped_client_sock)
except: # pragma: no cover
sys.excepthook(*sys.exc_info())
raise
finally:
raw_client_sock.close()
safe_close(raw_client_sock)

# Server side
def fake_ssl_server(server_cert: LeafCert, raw_server_sock: socket.socket) -> None:
Expand All @@ -310,23 +326,25 @@ def fake_ssl_server(server_cert: LeafCert, raw_server_sock: socket.socket) -> No
# Prove that we're connected
assert wrapped_server_sock.recv(1) == b"x"
wrapped_server_sock.send(b"y")
wrapped_server_sock.close()
safe_close(wrapped_server_sock)
except: # pragma: no cover
sys.excepthook(*sys.exc_info())
raise
finally:
raw_server_sock.close()
safe_close(raw_server_sock)

def doit(ca: CA, hostname: str, server_cert: LeafCert) -> None:
# socketpair and ssl don't work together on py2, because... reasons.
# So we need to do this the hard way.
listener = socket.socket()
listener.bind(("127.0.0.1", 0))
listener.listen(1)
raw_client_sock = socket.socket()
raw_client_sock.connect(listener.getsockname())
raw_server_sock, _ = listener.accept()
listener.close()
try:
listener.bind(("127.0.0.1", 0))
listener.listen(1)
raw_client_sock = socket.socket()
raw_client_sock.connect(listener.getsockname())
raw_server_sock, _ = listener.accept()
finally:
safe_close(listener)
with ThreadPoolExecutor(2) as tpe:
f1 = tpe.submit(fake_ssl_client, ca, raw_client_sock, hostname)
f2 = tpe.submit(fake_ssl_server, server_cert, raw_server_sock)
Expand Down