diff --git a/wolfcrypt/src/wc_mlkem.c b/wolfcrypt/src/wc_mlkem.c index 67f34013f5e..6cd325888a7 100644 --- a/wolfcrypt/src/wc_mlkem.c +++ b/wolfcrypt/src/wc_mlkem.c @@ -1793,7 +1793,9 @@ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in, if (ret == 0) { mlkemkey_decode_public(key->pub, key->pubSeed, p, k); - + ret = mlkem_check_public(key->pub, k); + } + if (ret == 0) { /* Calculate public hash. */ ret = MLKEM_HASH_H(&key->hash, in, len, key->h); } diff --git a/wolfcrypt/src/wc_mlkem_poly.c b/wolfcrypt/src/wc_mlkem_poly.c index b13d9305ac1..6e8ce95f751 100644 --- a/wolfcrypt/src/wc_mlkem_poly.c +++ b/wolfcrypt/src/wc_mlkem_poly.c @@ -6074,4 +6074,27 @@ void mlkem_to_bytes(byte* b, sword16* p, int k) } } +/** + * Check the public key values are smaller than the modulus. + * + * @param [in] pub Public key - vector. + * @param [in] k Number of polynomials in vector. + * @return 0 when all values are in range. + * @return PUBLIC_KEY_E when at least one value is out of range. + */ +int mlkem_check_public(sword16* pub, int k) +{ + int ret = 0; + int i; + + for (i = 0; i < k * MLKEM_N; i++) { + if (pub[i] >= MLKEM_Q) { + ret = PUBLIC_KEY_E; + break; + } + } + + return ret; +} + #endif /* WOLFSSL_WC_MLKEM */ diff --git a/wolfssl/wolfcrypt/wc_mlkem.h b/wolfssl/wolfcrypt/wc_mlkem.h index 91e015f368a..460b13ee344 100644 --- a/wolfssl/wolfcrypt/wc_mlkem.h +++ b/wolfssl/wolfcrypt/wc_mlkem.h @@ -238,6 +238,8 @@ WOLFSSL_LOCAL void mlkem_from_bytes(sword16* p, const byte* b, int k); WOLFSSL_LOCAL void mlkem_to_bytes(byte* b, sword16* p, int k); +WOLFSSL_LOCAL +int mlkem_check_public(sword16* p, int k); #ifdef USE_INTEL_SPEEDUP WOLFSSL_LOCAL