diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py index e7ce646e9987..864fe62e825f 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py @@ -88,14 +88,17 @@ def extract_cert_chain(pem_bytes: bytes) -> bytes: return b"".join(chain.splitlines()) -_Cert = NamedTuple("_Cert", [("pem_bytes", bytes), ("private_key", "Any"), ("fingerprint", bytes)]) +_Cert = NamedTuple( + "_Cert", [("pem_bytes", bytes), ("private_key", "Any"), ("fingerprint", bytes), ("sha256_fingerprint", bytes)] +) def load_pem_certificate(certificate_data: bytes, password: Optional[bytes] = None) -> _Cert: private_key = serialization.load_pem_private_key(certificate_data, password, backend=default_backend()) cert = x509.load_pem_x509_certificate(certificate_data, default_backend()) fingerprint = cert.fingerprint(hashes.SHA1()) # nosec - return _Cert(certificate_data, private_key, fingerprint) + sha256_fingerprint = cert.fingerprint(hashes.SHA256()) + return _Cert(certificate_data, private_key, fingerprint, sha256_fingerprint) def load_pkcs12_certificate(certificate_data: bytes, password: Optional[bytes] = None) -> _Cert: @@ -121,8 +124,9 @@ def load_pkcs12_certificate(certificate_data: bytes, password: Optional[bytes] = pem_bytes = b"".join(pem_sections) fingerprint = cert.fingerprint(hashes.SHA1()) # nosec + sha256_fingerprint = cert.fingerprint(hashes.SHA256()) - return _Cert(pem_bytes, private_key, fingerprint) + return _Cert(pem_bytes, private_key, fingerprint, sha256_fingerprint) def get_client_credential( @@ -166,7 +170,11 @@ def get_client_credential( if not isinstance(cert.private_key, RSAPrivateKey): raise ValueError("The certificate must have an RSA private key because RS256 is used for signing") - client_credential = {"private_key": cert.pem_bytes, "thumbprint": hexlify(cert.fingerprint).decode("utf-8")} + client_credential = { + "private_key": cert.pem_bytes, + "thumbprint": hexlify(cert.fingerprint).decode("utf-8"), + "sha256_thumbprint": hexlify(cert.sha256_fingerprint).decode("utf-8"), + } if password: client_credential["passphrase"] = password diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index 24015b4da414..e65beb6f4310 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -10,7 +10,7 @@ from azure.identity import CertificateCredential, TokenCachePersistenceOptions from azure.identity._enums import RegionalAuthority from azure.identity._constants import EnvironmentVariables -from azure.identity._credentials.certificate import load_pkcs12_certificate +from azure.identity._credentials.certificate import get_client_credential, load_pkcs12_certificate from azure.identity._internal.user_agent import USER_AGENT from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -50,6 +50,27 @@ EC_CERT_PATH = os.path.join(os.path.dirname(__file__), "ec-certificate.pem") +@pytest.mark.parametrize("cert_path, cert_password", ALL_CERTS) +def test_get_client_credential_includes_sha256_thumbprint(cert_path, cert_password): + client_credential = get_client_credential(certificate_path=cert_path, password=cert_password) + + assert "sha256_thumbprint" in client_credential + + with open(cert_path, "rb") as f: + cert_bytes = f.read() + + if b"-----BEGIN" in cert_bytes: + cert = x509.load_pem_x509_certificate(cert_bytes, default_backend()) + else: + from cryptography.hazmat.primitives.serialization import pkcs12 + + pw = cert_password.encode("utf-8") if isinstance(cert_password, str) else cert_password + _, cert, _ = pkcs12.load_key_and_certificates(cert_bytes, pw, backend=default_backend()) + + expected_sha256 = cert.fingerprint(hashes.SHA256()).hex() + assert client_credential["sha256_thumbprint"] == expected_sha256 + + def test_non_rsa_key(): """The credential should raise ValueError when given a cert without an RSA private key""" with pytest.raises(ValueError, match=".*RS256.*"):