diff --git a/modules/proto_ipsec/ipsec.c b/modules/proto_ipsec/ipsec.c index eae353fa11..507b00f3f5 100644 --- a/modules/proto_ipsec/ipsec.c +++ b/modules/proto_ipsec/ipsec.c @@ -779,14 +779,34 @@ void ipsec_ctx_release(struct ipsec_ctx *ctx) { int free = 0; - if (!ctx) + LM_DBG("Releasing IPSec ctx %p (state %d), ref %d\n", ctx, ctx?ctx->state:0, ctx?ctx->ref:0); + + if (!ctx || ctx->ref <= 0) { + LM_DBG("ctx %p is NULL or has invalid ref %d\n", ctx, ctx?ctx->ref:0); + return; + } + + if (!VALID_IPSEC_STATE(ctx->state)) { + LM_DBG("ctx %p is not in a valid state %d\n", ctx, ctx->state); return; + } lock_get(&ctx->lock); free = ipsec_ctx_release_unsafe(ctx); + if (free) ctx->state = IPSEC_STATE_INVALID; /* mark as invalid */ lock_release(&ctx->lock); - if (free) + if (free) { + LM_DBG("IPSec ctx %p released\n", ctx); + if (ctx->user) { + ipsec_ctx_release_user(ctx); + ctx->user = NULL; /* avoid double release */ + } + ipsec_ctx_remove_free_tmp(ctx, 0); ipsec_ctx_free(ctx); + } else { + LM_DBG("IPSec ctx %p not released, ref=%d\n", ctx, ctx->ref); + } + } struct ipsec_ctx_tmp { @@ -828,6 +848,24 @@ void ipsec_ctx_push_user(struct ipsec_user *user, struct ipsec_ctx *ctx, enum ip } } +void ipsec_ctx_add_tmp(struct ipsec_ctx *ctx) +{ + struct ipsec_ctx_tmp *tmp = shm_malloc(sizeof *tmp); + if (!tmp) { + LM_ERR("could not push ctx in ue - dropping it!\n"); + return; + } + memset(tmp, 0, sizeof *tmp); + INIT_LIST_HEAD(&tmp->list); + tmp->expire = get_ticks() + ipsec_tmp_timeout; + tmp->ctx = ctx; + ctx->state = IPSEC_STATE_TMP; + + lock_get(ipsec_tmp_contexts_lock); + list_add_tail(&tmp->list, ipsec_tmp_contexts); + lock_release(ipsec_tmp_contexts_lock); +} + void ipsec_ctx_release_tmp_user(struct ipsec_user *user) { struct list_head *it, *safe; @@ -846,12 +884,46 @@ void ipsec_ctx_release_user(struct ipsec_ctx *ctx) { int release = 0; struct ipsec_user *user = ctx->user; + struct list_head *it, *safe, *prev = NULL; + struct list_head new; + struct ipsec_ctx *tmp_ctx; + + INIT_LIST_HEAD(&new); lock_get(&user->lock); + LM_DBG("User %.*s has %d contexts, list %d\n", + user->impi.len, user->impi.s, list_size(&user->sas), list_size(&user->list)); + + list_for_each_safe(it, safe, &user->sas) { + tmp_ctx = list_entry(it, struct ipsec_ctx, list); + LM_DBG("User: Context %p (state %d)\n", + tmp_ctx, tmp_ctx->state); + if (tmp_ctx == ctx) { + LM_DBG("User: Found context %p (state %d)\n", + tmp_ctx, tmp_ctx->state); + prev = it; + break; /* found */ + } + } + if (prev) { + LM_DBG("Found Context in user %.*s\n", + user->impi.len, user->impi.s); + list_cut_position(&new, &user->sas, prev); + if (list_size(&user->sas) > 0) { + LM_DBG("User %.*s has %d contexts left\n", + user->impi.len, user->impi.s, list_size(&user->sas)); + } else { + LM_DBG("User %.*s has no contexts left, releasing\n", + user->impi.len, user->impi.s); + release = 1; + } + } + if (list_is_valid(&ctx->list)) { list_del(&ctx->list); - release = 1; } + ctx->user = NULL; /* avoid double release */ + lock_release(&user->lock); if (release) ipsec_release_user(user); @@ -869,36 +941,61 @@ void ipsec_ctx_timer(unsigned int ticks, void* param) lock_get(ipsec_tmp_contexts_lock); list_for_each_safe(it, safe, ipsec_tmp_contexts) { + tmp = list_entry(it, struct ipsec_ctx_tmp, list); + LM_DBG("Context %p (state %d) expire at %u, current ticks %u\n", + tmp->ctx, tmp->ctx->state, (unsigned int)tmp->expire, ticks); if (ticks < tmp->expire) break; /* finished */ - prev = it; - IPSEC_CTX_UNREF(tmp->ctx); LM_DBG("IPSec ctx %p removing\n", tmp->ctx); + prev = it; + } + if (!prev) { + LM_DBG("No expired contexts found\n"); + lock_release(ipsec_tmp_contexts_lock); + return; /* nothing to do */ } /* unlink from the shared list */ if (prev) list_cut_position(&new, ipsec_tmp_contexts, prev); + LM_DBG("Unlinked %d expired contexts\n", list_size(&new)); lock_release(ipsec_tmp_contexts_lock); list_for_each_safe(it, safe, &new) { tmp = list_entry(it, struct ipsec_ctx_tmp, list); - lock_get(&tmp->ctx->lock); - if (tmp->ctx->state == IPSEC_STATE_TMP) { - tmp->ctx->state = IPSEC_STATE_INVALID; - LM_DBG("IPSec ctx %p expired\n", tmp->ctx); + LM_DBG("Context %p (state %d) Refcount %d\n", + tmp->ctx, tmp->ctx->state, tmp->ctx->ref); + if (VALID_IPSEC_STATE(tmp->ctx->state)) { + lock_get(&tmp->ctx->lock); + LM_DBG("Got lock for context %p (state %d)\n", tmp->ctx, tmp->ctx->state); + if (tmp->ctx->state == IPSEC_STATE_TMP) { + tmp->ctx->state = IPSEC_STATE_INVALID; + LM_ERR("IPSec ctx %p expired\n", tmp->ctx); + } + list_del(&tmp->list); + ctx = tmp->ctx; + free = IPSEC_CTX_UNREF_UNSAFE(tmp->ctx); + lock_release(&tmp->ctx->lock); + LM_DBG("Released lock for context %p (state %d), free=%d\n", ctx, ctx->state, free); + shm_free(tmp); + if (free) + ipsec_ctx_free(ctx); + LM_DBG("IPSec ctx %p deleted\n", ctx); + } else { + LM_DBG("IPSec ctx %p already deleted\n", tmp->ctx); + list_del(&tmp->list); + shm_free(tmp); } - list_del(&tmp->list); - ctx = tmp->ctx; - free = IPSEC_CTX_UNREF_UNSAFE(tmp->ctx); - lock_release(&tmp->ctx->lock); - shm_free(tmp); - if (free) - ipsec_ctx_free(ctx); } + LM_DBG("Finished removing expired contexts\n"); } void ipsec_ctx_remove_tmp(struct ipsec_ctx *ctx) +{ + return ipsec_ctx_remove_free_tmp(ctx, 1); +} + +void ipsec_ctx_remove_free_tmp(struct ipsec_ctx *ctx, int _free) { struct list_head *it, *safe; struct ipsec_ctx_tmp *tmp; @@ -911,7 +1008,8 @@ void ipsec_ctx_remove_tmp(struct ipsec_ctx *ctx) if (tmp->ctx != ctx) continue; list_del(&tmp->list); - free = IPSEC_CTX_UNREF_UNSAFE(tmp->ctx); + if (_free) + free = IPSEC_CTX_UNREF_UNSAFE(tmp->ctx); shm_free(tmp); break; } diff --git a/modules/proto_ipsec/ipsec.h b/modules/proto_ipsec/ipsec.h index 95561c06e4..4786b26982 100644 --- a/modules/proto_ipsec/ipsec.h +++ b/modules/proto_ipsec/ipsec.h @@ -42,6 +42,10 @@ enum ipsec_state { IPSEC_STATE_INVALID, }; +#define VALID_IPSEC_STATE(_s) \ + ((_s) == IPSEC_STATE_TMP || \ + (_s) == IPSEC_STATE_OK) + #define ipsec_socket mnl_socket #include "../../str.h" @@ -122,11 +126,13 @@ void ipsec_ctx_push(struct ipsec_ctx *ctx); struct ipsec_ctx *ipsec_ctx_get(void); void ipsec_ctx_push_user(struct ipsec_user *user, struct ipsec_ctx *ctx, enum ipsec_state state); void ipsec_ctx_push_tmp_user(struct ipsec_user *user, struct ipsec_ctx *ctx); +void ipsec_ctx_add_tmp(struct ipsec_ctx *ctx); void ipsec_ctx_release_tmp_user(struct ipsec_user *user); void ipsec_ctx_release_user(struct ipsec_ctx *ctx); void ipsec_ctx_release(struct ipsec_ctx *ctx); int ipsec_ctx_release_unsafe(struct ipsec_ctx *ctx); void ipsec_ctx_remove_tmp(struct ipsec_ctx *ctx); +void ipsec_ctx_remove_free_tmp(struct ipsec_ctx *ctx, int _free); void ipsec_ctx_extend_tmp(struct ipsec_ctx *ctx); #endif /* _IPSEC_H_ */ diff --git a/modules/proto_ipsec/ipsec_user.c b/modules/proto_ipsec/ipsec_user.c index cbd66c4e6e..13cff4fdee 100644 --- a/modules/proto_ipsec/ipsec_user.c +++ b/modules/proto_ipsec/ipsec_user.c @@ -434,8 +434,12 @@ struct ipsec_ctx *ipsec_get_ctx_user(struct ipsec_user *user, struct receive_inf lock_get(&user->lock); list_for_each(it, &user->sas) { ctx = list_entry(it, struct ipsec_ctx, list); - if (ctx->ue.port_c == ri->src_port && ctx->me.port_s == ri->dst_port) - break; + LM_DBG("checking ctx %p (state %d) for src_port %u, dst_port %u\n", ctx, ctx->state, ri->src_port, ri->dst_port); + if (VALID_IPSEC_STATE(ctx->state)) { + LM_DBG("ctx %p is valid, state %d\n", ctx, ctx->state); + if (ctx->ue.port_c == ri->src_port && ctx->me.port_s == ri->dst_port) + break; + } ctx = NULL; } lock_release(&user->lock); @@ -446,14 +450,25 @@ struct ipsec_ctx *ipsec_get_ctx_user_port(struct ipsec_user *user, unsigned shor { struct list_head *it; struct ipsec_ctx *ctx = NULL; + struct ipsec_ctx *tmp_ctx = NULL; lock_get(&user->lock); list_for_each(it, &user->sas) { ctx = list_entry(it, struct ipsec_ctx, list); - if (ctx->ue.port_s == port || ctx->ue.port_c) - break; + LM_DBG("checking ctx %p for port %u\n", ctx, port); + if (ctx->ue.port_s == port || ctx->ue.port_c) { + if (ctx->state == IPSEC_STATE_TMP) { + tmp_ctx = ctx; + } + if (ctx->state == IPSEC_STATE_OK) + break; + } ctx = NULL; } lock_release(&user->lock); + if (!ctx && tmp_ctx) { + /* Just found a temporary context, not an "OK" context, return temporary context */ + return tmp_ctx; + } return ctx; } diff --git a/modules/proto_ipsec/proto_ipsec.c b/modules/proto_ipsec/proto_ipsec.c index 0364a3e5fd..083f824c12 100644 --- a/modules/proto_ipsec/proto_ipsec.c +++ b/modules/proto_ipsec/proto_ipsec.c @@ -428,8 +428,31 @@ static int proto_ipsec_send(const struct socket_info* source, t = tm_ipsec.t_gett(); if (t && t != T_UNDEFINED) ctx = IPSEC_CTX_TM_GET(t); - if (!ctx) + if (ctx) { + LM_DBG("Got context %p (state = %d) from transaction\n", + ctx, ctx->state); + if (ctx->state == IPSEC_STATE_INVALID) { + LM_DBG("invalid state (%d) for %s:%hu\n", ctx->state, + ip_addr2a(&ip), port); + IPSEC_CTX_UNREF(ctx); + ctx = NULL; + } else { + IPSEC_CTX_REF(ctx); /* ref it, so we can use it */ + } + } + if (!ctx) { + LM_DBG("no valid IPSec context for %s:%hu\n", ip_addr2a(&ip), port); ctx = ipsec_get_ctx_ip_port(&ip, port); + if (ctx) { + LM_DBG("Got context %p (state = %d) for %s:%hu\n", + ctx, ctx->state, ip_addr2a(&ip), port); + if (ctx->state == IPSEC_STATE_INVALID) { + LM_DBG("invalid state (%d) for %s:%hu\n", ctx->state, + ip_addr2a(&ip), port); + ctx = NULL; + } + } + } if (ctx) { ipsec_si = ctx->client; if (source->proto == PROTO_UDP) @@ -443,8 +466,9 @@ static int proto_ipsec_send(const struct socket_info* source, to = &ue_to; } IPSEC_CTX_UNREF(ctx); + ctx = NULL; } else { - LM_WARN("could not find ctx for %s:%hu\n", ip_addr2a(&ip), port); + LM_WARN("could not find valid ctx for %s:%hu\n", ip_addr2a(&ip), port); } } else { /* this should be a TCP reply - preserve the socket */ @@ -730,7 +754,7 @@ static int w_ipsec_create(struct sip_msg *msg, int *_port_ps, int *_port_pc, user = ipsec_get_user(&req->rcv.src_ip, impi, impu); if (!user) { - LM_ERR("could not get a new IPSec user\n"); + LM_WARN("could not get a new IPSec user\n"); return -1; } @@ -752,10 +776,17 @@ static int w_ipsec_create(struct sip_msg *msg, int *_port_ps, int *_port_pc, * existing SA/ctx for this USER - try to locate it */ ctx = IPSEC_CTX_TM_GET(t); - if (ctx) + LM_DBG("got ctx %p for t=%p, state %d\n", ctx?ctx:0, t, ctx?ctx->state:-1); + if (ctx && ctx->state == IPSEC_STATE_OK && ctx->me.port_c != port_pc) prev_port_pc = ctx->me.port_c; else prev_port_pc = 0; + if (ctx && ctx->state == IPSEC_STATE_OK) { + LM_DBG("found existing IPSec context %p for user %.*s, marked as temporary, to be deleted\n", + ctx, impi->len, impi->s); + /* add the context as temporarily, so the "old" context gets removed on timer */ + ipsec_ctx_add_tmp(ctx); + } } else { prev_port_pc = 0; /* message was received unprotected - remove all temporary SAs */ @@ -770,6 +801,7 @@ static int w_ipsec_create(struct sip_msg *msg, int *_port_ps, int *_port_pc, ip_addr2a(&req->rcv.dst_ip), port_ps); goto release_user; } + /* locate the client IP */ sc = find_ipsec_socket_info(&req->rcv.dst_ip, port_pc, ss->port_no, prev_port_pc); if (!sc) { LM_INFO("could not find a client listener on %s:%d!\n", @@ -1038,6 +1070,7 @@ static int ipsec_handle_register(struct sip_msg *msg, struct socket_info *si) LM_ERR("could not find any IPSec context!\n"); goto drop_user; } + LM_DBG("got ctx %p for user %.*s, state %d\n", ctx, impi->len, impi->s, ctx->state); lock_get(&ctx->lock); switch (ctx->state) { case IPSEC_STATE_TMP: