Skip to content
24 changes: 14 additions & 10 deletions src/dtls13.c
Original file line number Diff line number Diff line change
Expand Up @@ -979,31 +979,35 @@ static int Dtls13SendFragmentedInternal(WOLFSSL* ssl)
{
int fragLength, rlHeaderLength;
int remainingSize, maxFragment;
int recordLength;
int recordLength, outputSz;
byte isEncrypted;
byte* output;
int ret;

isEncrypted = Dtls13TypeIsEncrypted(
(enum HandShakeType)ssl->dtls13FragHandshakeType);
rlHeaderLength = Dtls13GetRlHeaderLength(ssl, isEncrypted);
maxFragment = wolfSSL_GetMaxFragSize(ssl, MAX_RECORD_SIZE);
maxFragment = wolfssl_local_GetMaxPlaintextSize(ssl);

remainingSize = ssl->dtls13MessageLength - ssl->dtls13FragOffset;

while (remainingSize > 0) {

fragLength = maxFragment - rlHeaderLength - DTLS_HANDSHAKE_HEADER_SZ;

recordLength = maxFragment;
fragLength = maxFragment - DTLS_HANDSHAKE_HEADER_SZ;

if (fragLength > remainingSize) {
fragLength = remainingSize;
recordLength =
fragLength + rlHeaderLength + DTLS_HANDSHAKE_HEADER_SZ;
}

ret = CheckAvailableSize(ssl, recordLength + MAX_MSG_EXTRA);
recordLength = fragLength + rlHeaderLength + DTLS_HANDSHAKE_HEADER_SZ;
outputSz = wolfssl_local_GetRecordSize(ssl,
fragLength + DTLS_HANDSHAKE_HEADER_SZ, isEncrypted);
if (outputSz < 0) {
Dtls13FreeFragmentsBuffer(ssl);
return recordLength;
}

ret = CheckAvailableSize(ssl, outputSz);
if (ret != 0) {
Dtls13FreeFragmentsBuffer(ssl);
return ret;
Expand All @@ -1025,7 +1029,7 @@ static int Dtls13SendFragmentedInternal(WOLFSSL* ssl)

ret = Dtls13SendOneFragmentRtx(ssl,
(enum HandShakeType)ssl->dtls13FragHandshakeType,
(word16)recordLength + MAX_MSG_EXTRA, output, (word32)recordLength, 0);
(word16)outputSz, output, (word32)recordLength, 0);
if (ret == WC_NO_ERR_TRACE(WANT_WRITE)) {
ssl->dtls13FragOffset += fragLength;
return ret;
Expand Down Expand Up @@ -2018,7 +2022,7 @@ int Dtls13HandshakeSend(WOLFSSL* ssl, byte* message, word16 outputSize,
return ret;
}

maxFrag = wolfSSL_GetMaxFragSize(ssl, MAX_RECORD_SIZE);
maxFrag = wolfssl_local_GetMaxPlaintextSize(ssl);
maxLen = length;

if (handshakeType == key_update)
Expand Down
214 changes: 138 additions & 76 deletions src/internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -10745,7 +10745,13 @@ static int SendHandshakeMsg(WOLFSSL* ssl, byte* input, word32 inputSz,
inputSz += HANDSHAKE_HEADER_SZ;
rHdrSz = RECORD_HEADER_SZ;
}
maxFrag = wolfSSL_GetMaxFragSize(ssl, (int)inputSz);
maxFrag = wolfssl_local_GetMaxPlaintextSize(ssl);
#ifdef WOLFSSL_DTLS
if (ssl->options.dtls) {
/* In DTLS the handshake header is per fragment */
maxFrag -= DTLS_HANDSHAKE_HEADER_SZ;
}
#endif

/* Make sure input is not the ssl output buffer as this
* function doesn't handle that */
Expand Down Expand Up @@ -24793,9 +24799,12 @@ int SendCertificate(WOLFSSL* ssl)
if (ssl->fragOffset != 0)
length -= (ssl->fragOffset + headerSz);

maxFragment = MAX_RECORD_SIZE;

maxFragment = (word32)wolfSSL_GetMaxFragSize(ssl, (int)maxFragment);
maxFragment = (word32)wolfssl_local_GetMaxPlaintextSize(ssl);
if (ssl->options.dtls)
maxFragment -= DTLS_HANDSHAKE_HEADER_SZ;
else
maxFragment -= HANDSHAKE_HEADER_SZ;

while (length > 0 && ret == 0) {
byte* output = NULL;
Expand Down Expand Up @@ -25572,27 +25581,6 @@ int IsSCR(WOLFSSL* ssl)
}


#ifdef WOLFSSL_DTLS
static int ModifyForMTU(WOLFSSL* ssl, int buffSz, int outputSz, int mtuSz)
{
int recordExtra = outputSz - buffSz;

(void)ssl;

if (recordExtra > 0 && outputSz > mtuSz) {
buffSz = mtuSz - recordExtra;
#ifndef WOLFSSL_AEAD_ONLY
/* Subtract a block size to be certain that returned fragment
* size won't get more padding. */
if (ssl->specs.cipher_type == block)
buffSz -= ssl->specs.block_size;
#endif
}

return buffSz;
}
#endif /* WOLFSSL_DTLS */

#if !defined(NO_TLS) && defined(WOLFSSL_TLS13) && \
!defined(WOLFSSL_TLS13_IGNORE_AEAD_LIMITS)
/*
Expand Down Expand Up @@ -25970,31 +25958,33 @@ int SendData(WOLFSSL* ssl, const void* data, size_t sz)
}
#endif /* WOLFSSL_DTLS13 */

buffSz = wolfSSL_GetMaxFragSize(ssl, (word32)sz - sent);

if (sent == (word32)sz) break;

#if defined(WOLFSSL_DTLS) && !defined(WOLFSSL_NO_DTLS_SIZE_CHECK)
if (ssl->options.dtls && ((size_t)buffSz < (word32)sz - sent)) {
error = DTLS_SIZE_ERROR;
ssl->error = error;
WOLFSSL_ERROR(error);
return error;
}
#endif
outputSz = buffSz + COMP_EXTRA + DTLS_RECORD_HEADER_SZ;
if (IsEncryptionOn(ssl, 1) || ssl->options.tls1_3)
outputSz += cipherExtraData(ssl);

#if defined(WOLFSSL_DTLS) && defined(WOLFSSL_DTLS_CID)
buffSz = (word32)sz - sent;
outputSz = wolfssl_local_GetRecordSize(ssl, (word32)buffSz, 1);
#if defined(WOLFSSL_DTLS)
if (ssl->options.dtls) {
byte cidSz = 0;
if ((cidSz = DtlsGetCidTxSize(ssl)) > 0)
outputSz += cidSz + 1; /* +1 for inner content type */
}
#if defined(WOLFSSL_DTLS_MTU)
int mtu = ssl->dtlsMtuSz;
#else
int mtu = MAX_MTU;
#endif
if (outputSz > mtu) {
#if defined(WOLFSSL_NO_DTLS_SIZE_CHECK)
/* split instead of error out */
buffSz = min(buffSz, wolfssl_local_GetMaxPlaintextSize(ssl));
outputSz = wolfssl_local_GetRecordSize(ssl, (word32)buffSz, 1);
#else
error = DTLS_SIZE_ERROR;
ssl->error = error;
WOLFSSL_ERROR(error);
return error;
#endif /* WOLFSSL_NO_DTLS_SIZE_CHECK */
}
}
#endif /* WOLFSSL_DTLS */

/* check for available size */
/* check for available size, it does also DTLS MTU checks */
if ((ret = CheckAvailableSize(ssl, outputSz)) != 0)
return (ssl->error = ret);

Expand Down Expand Up @@ -41811,53 +41801,125 @@ int wolfSSL_AsyncPush(WOLFSSL* ssl, WC_ASYNC_DEV* asyncDev)

#endif /* WOLFSSL_ASYNC_CRYPT */

#if !defined(NO_TLS)
/** Return the record size for sending payloadSz of data
* @param ssl WOLFSSL object
* @param payloadSz Size of data to be sent in record
* @param isEncrypted 1 if encryption is on, 0 if not
* @return Record size for sending payloadSz of data
*/
int wolfssl_local_GetRecordSize(WOLFSSL *ssl, int payloadSz, int isEncrypted)
{
int recordSz;

if (ssl == NULL)
return BAD_FUNC_ARG;

if (isEncrypted) {
recordSz = BuildMessage(ssl, NULL, 0, NULL, payloadSz, application_data,
0, 1, 0, CUR_ORDER);
/* use a safe upper bound in case of error */
if (recordSz < 0) {
recordSz = payloadSz + RECORD_HEADER_SZ
+ cipherExtraData(ssl) + COMP_EXTRA;
if (ssl->options.dtls) {
recordSz += DTLS_RECORD_EXTRA;
}
}
}
else {
recordSz = payloadSz + RECORD_HEADER_SZ;
if (ssl->options.dtls) {
recordSz += DTLS_RECORD_EXTRA;
}
}
return recordSz;
}
#endif

/** Return the maximum plaintext size for the current Max Fragment and MTU.
* @param ssl WOLFSSL object containing ciphersuite information.
* @return Max plaintext size for current MTU
*/
int wolfssl_local_GetMaxPlaintextSize(WOLFSSL *ssl)
{
int maxFrag;

if (ssl == NULL)
return BAD_FUNC_ARG;

maxFrag = wolfSSL_GetMaxFragSize(ssl);

#if defined(WOLFSSL_DTLS)
if (IsDtlsNotSctpMode(ssl)) {
int recordSz;
int mtu;

#if defined(WOLFSSL_DTLS_MTU)
mtu = ssl->dtlsMtuSz;
#else
mtu = MAX_MTU;
#endif

recordSz = wolfssl_local_GetRecordSize(ssl, maxFrag,
IsEncryptionOn(ssl, 1));
/* record size of maxFrag fits in MTU */
if (recordSz <= mtu) {
return maxFrag;
}

/* adjust plaintext size to fit in MTU */
maxFrag -= (recordSz - mtu);
if (maxFrag <= 0) {
WOLFSSL_MSG("MTU too small for any plaintext");
return DTLS_SIZE_ERROR;
}

#ifndef WOLFSSL_AEAD_ONLY
/* For block ciphers, reducing maxFrag may change padding alignment,
* causing the record to still exceed MTU. Iterate to find exact fit.
* Converges in at most 2 iterations due to bounded padding variance. */
if (ssl->specs.cipher_type == block) {
int iter;
for (iter = 0; iter < 2; iter++) {
recordSz = wolfssl_local_GetRecordSize(ssl, maxFrag,
IsEncryptionOn(ssl, 1));
if (recordSz <= mtu)
break;
maxFrag -= (recordSz - mtu);
}
if (recordSz > mtu) {
/* this should never happen */
WOLFSSL_MSG("Failed to fit record in MTU after padding adjust");
return DTLS_SIZE_ERROR;
}
}
#endif
}
#endif /* WOLFSSL_DTLS */

return maxFrag;
}
/**
* Return the max fragment size. This is essentially the maximum
* fragment_length available.
* @param ssl WOLFSSL object containing ciphersuite information.
* @param maxFragment The amount of space we want to check is available. This
* is only the fragment length WITHOUT the (D)TLS headers.
* @return Max fragment size
*/
int wolfSSL_GetMaxFragSize(WOLFSSL* ssl, int maxFragment)
int wolfSSL_GetMaxFragSize(WOLFSSL* ssl)
{
(void) ssl; /* Avoid compiler warnings */
int maxFragment;

if (maxFragment > MAX_RECORD_SIZE) {
maxFragment = MAX_RECORD_SIZE;
}
if (ssl == NULL)
return BAD_FUNC_ARG;

maxFragment = MAX_RECORD_SIZE;

#ifdef HAVE_MAX_FRAGMENT
if ((ssl->max_fragment != 0) && ((word16)maxFragment > ssl->max_fragment)) {
maxFragment = ssl->max_fragment;
}
#endif /* HAVE_MAX_FRAGMENT */
#ifdef WOLFSSL_DTLS
if (IsDtlsNotSctpMode(ssl)) {
int outputSz, mtuSz;

/* Given a input buffer size of maxFragment, how big will the
* encrypted output be? */
if (IsEncryptionOn(ssl, 1)) {
outputSz = BuildMessage(ssl, NULL, 0, NULL,
maxFragment + DTLS_HANDSHAKE_HEADER_SZ,
application_data, 0, 1, 0, CUR_ORDER);
}
else {
outputSz = maxFragment + DTLS_RECORD_HEADER_SZ +
DTLS_HANDSHAKE_HEADER_SZ;
}

/* Readjust maxFragment for MTU size. */
#if defined(WOLFSSL_DTLS_MTU)
mtuSz = ssl->dtlsMtuSz;
#else
mtuSz = MAX_MTU;
#endif
maxFragment = ModifyForMTU(ssl, maxFragment, outputSz, mtuSz);
}
#endif

return maxFragment;
}
Expand Down
5 changes: 2 additions & 3 deletions src/ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -2905,7 +2905,7 @@ int wolfSSL_GetMaxOutputSize(WOLFSSL* ssl)
return BAD_FUNC_ARG;
}

return wolfSSL_GetMaxFragSize(ssl, OUTPUT_RECORD_SIZE);
return min(OUTPUT_RECORD_SIZE, wolfssl_local_GetMaxPlaintextSize(ssl));
}


Expand All @@ -2925,8 +2925,7 @@ int wolfSSL_GetOutputSize(WOLFSSL* ssl, int inSz)
if (inSz > maxSize)
return INPUT_SIZE_E;

return BuildMessage(ssl, NULL, 0, NULL, inSz, application_data, 0, 1, 0,
CUR_ORDER);
return wolfssl_local_GetRecordSize(ssl, inSz, 1);
}


Expand Down
4 changes: 2 additions & 2 deletions src/tls13.c
Original file line number Diff line number Diff line change
Expand Up @@ -4522,7 +4522,7 @@ int SendTls13ClientHello(WOLFSSL* ssl)

{
#ifdef WOLFSSL_DTLS_CH_FRAG
word16 maxFrag = wolfSSL_GetMaxFragSize(ssl, MAX_RECORD_SIZE);
word16 maxFrag = wolfssl_local_GetMaxPlaintextSize(ssl);
word16 lenWithoutExts = args->length;
#endif

Expand Down Expand Up @@ -8872,7 +8872,7 @@ static int SendTls13Certificate(WOLFSSL* ssl)
if (ssl->fragOffset != 0)
length -= (ssl->fragOffset + headerSz);

maxFragment = (word32)wolfSSL_GetMaxFragSize(ssl, MAX_RECORD_SIZE);
maxFragment = (word32)wolfssl_local_GetMaxPlaintextSize(ssl);

extIdx = 0;

Expand Down
Loading