diff --git a/src/crypto/openssl.c b/src/crypto/openssl.c index 6f9b0cf8..3df96e05 100644 --- a/src/crypto/openssl.c +++ b/src/crypto/openssl.c @@ -431,6 +431,78 @@ void free_ssl_context(SSL_CTX *ssl_context) { SSL_CTX_free(ssl_context); } +int number_of_digits(int number) { + if (number < 0) number = (number == INT_MIN) ? INT_MAX : -number; + if (number < 10) return 1; + if (number < 100) return 2; + if (number < 1000) return 3; + if (number < 10000) return 4; + if (number < 100000) return 5; + if (number < 1000000) return 6; + if (number < 10000000) return 7; + if (number < 100000000) return 8; + if (number < 1000000000) return 9; + return 10; +} + +int add_alpn(const char *alpn, char* out) +{ + unsigned int alpn_len, all_len; + + alpn_len = strlen(alpn); + if (alpn_len > 255 || alpn_len == 0) + return -1; + + all_len = strlen(out); + out[all_len] = alpn_len; + memcpy(&out[all_len + 1], alpn, alpn_len); + out[all_len + 1 + alpn_len] = '\0'; + return 0; +} + +char* build_alpn(const char **alpn, const unsigned int alpn_count) +{ + unsigned int length_letters = 0; + unsigned int total_length = 0; + for (unsigned int i = 0; i < alpn_count; i++) { + int cur_length = strlen(alpn[i]); + total_length += cur_length + number_of_digits(cur_length); + } + + char * alpn_wire_string = (char *) malloc(total_length + 1); + for (unsigned int i = 0; i < alpn_count; i++) { + int res = add_alpn(alpn[i], alpn_wire_string); + if (res == -1) { + free(alpn_wire_string); + return NULL; + } + } + + return alpn_wire_string; +} + +void free_alpn(void *parent, void *ptr, CRYPTO_EX_DATA *ad, int idx, long argl, void *argp) +{ + if (ptr != NULL) { + char* alpn_wire_string = (char*)ptr; + free(alpn_wire_string); + } +} + +int alpn_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg) { + int r; + char* alpn_wire_string = (char*)arg; + if (arg == NULL || strlen(alpn_wire_string) == 0) { + return SSL_TLSEXT_ERR_NOACK; + } + r = SSL_select_next_proto((unsigned char **) out, outlen, in, inlen, (unsigned char *)alpn_wire_string, strlen(alpn_wire_string)); + if (r == OPENSSL_NPN_NO_OVERLAP) { + return SSL_TLSEXT_ERR_ALERT_FATAL; + } else { + return SSL_TLSEXT_ERR_OK; + } +} + /* This function should take any options and return SSL_CTX - which has to be free'd with * our destructor function - free_ssl_context() */ SSL_CTX *create_ssl_context_from_options(struct us_socket_context_options_t options) { @@ -530,6 +602,20 @@ SSL_CTX *create_ssl_context_from_options(struct us_socket_context_options_t opti } } + if (options.alpn_protocols) { + static int alpn_index = -1; + if (alpn_index == -1) { + alpn_index = SSL_CTX_get_ex_new_index(0, NULL, NULL, NULL, free_alpn); + } + char *alpn_wire_string = build_alpn(options.alpn_protocols, options.alpn_protocols_length); + if (alpn_wire_string == NULL) { + free_ssl_context(ssl_context); + return NULL; + } + SSL_CTX_set_ex_data(ssl_context, alpn_index, alpn_wire_string); + SSL_CTX_set_alpn_select_cb(ssl_context, alpn_cb, alpn_wire_string); + } + /* This must be free'd with free_ssl_context, not SSL_CTX_free */ return ssl_context; } diff --git a/src/libusockets.h b/src/libusockets.h index dde4bb67..3a5fc367 100644 --- a/src/libusockets.h +++ b/src/libusockets.h @@ -131,6 +131,8 @@ struct us_socket_context_options_t { const char *dh_params_file_name; const char *ca_file_name; const char *ssl_ciphers; + const char **alpn_protocols; + unsigned int alpn_protocols_length; int ssl_prefer_low_memory_usage; /* Todo: rename to prefer_low_memory_usage and apply for TCP as well */ };