diff --git a/src/privsep-root.c b/src/privsep-root.c index 28dfcc4d..5f9b9401 100644 --- a/src/privsep-root.c +++ b/src/privsep-root.c @@ -71,21 +71,20 @@ struct psr_ctx { struct psr_error psr_error; size_t psr_datalen; void *psr_data; - size_t psr_mdatalen; - void *psr_mdata; - bool psr_usemdata; + bool psr_mallocdata; }; static ssize_t -ps_root_readerrorcb(struct psr_ctx *psr_ctx) +ps_root_readerrorcb(struct psr_ctx *pc) { - struct dhcpcd_ctx *ctx = psr_ctx->psr_ctx; + struct dhcpcd_ctx *ctx = pc->psr_ctx; int fd = PS_ROOT_FD(ctx); - struct psr_error *psr_error = &psr_ctx->psr_error; + struct psr_error *psr_error = &pc->psr_error; struct iovec iov[] = { { .iov_base = psr_error, .iov_len = sizeof(*psr_error) }, - { .iov_base = NULL, .iov_len = 0 }, + { .iov_base = pc->psr_data, .iov_len = pc->psr_datalen }, }; + struct msghdr msg = { .msg_iov = iov, .msg_iovlen = __arraycount(iov) }; ssize_t len; #define PSR_ERROR(e) \ @@ -98,81 +97,84 @@ ps_root_readerrorcb(struct psr_ctx *psr_ctx) if (eloop_waitfd(fd) == -1) PSR_ERROR(errno); - len = recv(fd, psr_error, sizeof(*psr_error), MSG_PEEK); + if (!pc->psr_mallocdata) + goto recv; + + /* We peek at the psr_error structure to tell us how much of a buffer + * we need to read the whole packet. */ + msg.msg_iovlen--; + len = recvmsg(fd, &msg, MSG_PEEK | MSG_WAITALL); if (len == -1) PSR_ERROR(errno); - else if ((size_t)len < sizeof(*psr_error)) - PSR_ERROR(EINVAL); - if (psr_error->psr_datalen > SSIZE_MAX) - PSR_ERROR(ENOBUFS); - if (psr_ctx->psr_usemdata && - psr_error->psr_datalen > psr_ctx->psr_mdatalen) - { - void *d = realloc(psr_ctx->psr_mdata, psr_error->psr_datalen); - if (d == NULL) - PSR_ERROR(errno); - psr_ctx->psr_mdata = d; - psr_ctx->psr_mdatalen = psr_error->psr_datalen; + /* After this point, we MUST do another recvmsg even on a failure + * to remove the message after peeking. */ + if ((size_t)len < sizeof(*psr_error)) { + /* We can't use the header to work out buffers, so + * remove the message and bail. */ + (void)recvmsg(fd, &msg, MSG_WAITALL); + PSR_ERROR(EINVAL); } - if (psr_error->psr_datalen != 0) { - if (psr_ctx->psr_usemdata) - iov[1].iov_base = psr_ctx->psr_mdata; - else { - if (psr_error->psr_datalen > psr_ctx->psr_datalen) - PSR_ERROR(ENOBUFS); - iov[1].iov_base = psr_ctx->psr_data; - } + + /* No data to read? Unlikely but ... */ + if (psr_error->psr_datalen == 0) + goto recv; + + pc->psr_data = malloc(psr_error->psr_datalen); + if (pc->psr_data != NULL) { + iov[1].iov_base = pc->psr_data; iov[1].iov_len = psr_error->psr_datalen; + msg.msg_iovlen++; } - len = readv(fd, iov, __arraycount(iov)); +recv: + len = recvmsg(fd, &msg, MSG_WAITALL); if (len == -1) PSR_ERROR(errno); - else if ((size_t)len != sizeof(*psr_error) + psr_error->psr_datalen) + else if ((size_t)len < sizeof(*psr_error)) PSR_ERROR(EINVAL); + else if (msg.msg_flags & MSG_TRUNC) + PSR_ERROR(ENOBUFS); + else if ((size_t)len != sizeof(*psr_error) + psr_error->psr_datalen) { + logerrx("%s: recvmsg returned %zd, expecting %zu", __func__, + len, sizeof(*psr_error) + psr_error->psr_datalen); + PSR_ERROR(EBADMSG); + } return len; } ssize_t ps_root_readerror(struct dhcpcd_ctx *ctx, void *data, size_t len) { - struct psr_ctx *pc = ctx->ps_root->psp_data; + struct psr_ctx pc = { + .psr_ctx = ctx, + .psr_data = data, + .psr_datalen = len, + .psr_mallocdata = false + }; - pc->psr_data = data; - pc->psr_datalen = len; - pc->psr_usemdata = false; - ps_root_readerrorcb(pc); + ps_root_readerrorcb(&pc); - errno = pc->psr_error.psr_errno; - return pc->psr_error.psr_result; + errno = pc.psr_error.psr_errno; + return pc.psr_error.psr_result; } ssize_t ps_root_mreaderror(struct dhcpcd_ctx *ctx, void **data, size_t *len) { - struct psr_ctx *pc = ctx->ps_root->psp_data; - void *d; + struct psr_ctx pc = { + .psr_ctx = ctx, + .psr_data = NULL, + .psr_datalen = 0, + .psr_mallocdata = true + }; - pc->psr_usemdata = true; - ps_root_readerrorcb(pc); + ps_root_readerrorcb(&pc); - if (pc->psr_error.psr_datalen != 0) { - if (pc->psr_error.psr_datalen > pc->psr_mdatalen) { - errno = EINVAL; - return -1; - } - d = malloc(pc->psr_error.psr_datalen); - if (d == NULL) - return -1; - memcpy(d, pc->psr_mdata, pc->psr_error.psr_datalen); - } else - d = NULL; - - errno = pc->psr_error.psr_errno; - *data = d; - *len = pc->psr_error.psr_datalen; - return pc->psr_error.psr_result; + errno = pc.psr_error.psr_errno; + *data = pc.psr_data; + *len = pc.psr_error.psr_datalen; + return pc.psr_error.psr_result; } static ssize_t @@ -196,6 +198,8 @@ ps_root_writeerror(struct dhcpcd_ctx *ctx, ssize_t result, logdebugx("%s: result %zd errno %d", __func__, result, errno); #endif + if (len == 0) + msg.msg_iovlen = 1; err = sendmsg(fd, &msg, MSG_EOR); /* Error sending the message? Try sending the error of sending. */ @@ -204,8 +208,8 @@ ps_root_writeerror(struct dhcpcd_ctx *ctx, ssize_t result, __func__, result, data, len); psr.psr_result = err; psr.psr_errno = errno; - iov[1].iov_base = NULL; - iov[1].iov_len = 0; + psr.psr_datalen = 0; + msg.msg_iovlen = 1; err = sendmsg(fd, &msg, MSG_EOR); } @@ -602,7 +606,7 @@ ps_root_recvmsgcb(void *arg, struct ps_msghdr *psm, struct msghdr *msg) break; } - err = ps_root_writeerror(ctx, err, rlen != 0 ? rdata : 0, rlen); + err = ps_root_writeerror(ctx, err, rdata, rlen); if (free_rdata) free(rdata); return err; @@ -843,17 +847,6 @@ ps_root_log(void *arg, unsigned short events) logerr(__func__); } -static void -ps_root_freepsdata(void *arg) -{ - struct psr_ctx *pc = arg; - - if (pc == NULL) - return; - free(pc->psr_mdata); - free(pc); -} - pid_t ps_root_start(struct dhcpcd_ctx *ctx) { @@ -864,7 +857,6 @@ ps_root_start(struct dhcpcd_ctx *ctx) struct ps_process *psp; int logfd[2] = { -1, -1}, datafd[2] = { -1, -1}; pid_t pid; - struct psr_ctx *pc; if (xsocketpair(AF_UNIX, SOCK_SEQPACKET | SOCK_CXNB, 0, logfd) == -1) return -1; @@ -883,27 +875,15 @@ ps_root_start(struct dhcpcd_ctx *ctx) return -1; #endif - pc = calloc(1, sizeof(*pc)); - if (pc == NULL) - return -1; - pc->psr_ctx = ctx; - psp = ctx->ps_root = ps_newprocess(ctx, &id); if (psp == NULL) - { - free(pc); return -1; - } - psp->psp_freedata = ps_root_freepsdata; + strlcpy(psp->psp_name, "privileged proxy", sizeof(psp->psp_name)); pid = ps_startprocess(psp, ps_root_recvmsg, NULL, ps_root_startcb, PSF_ELOOP); - if (pid == -1) { - free(pc); + if (pid == -1) return -1; - } - - psp->psp_data = pc; if (pid == 0) { ctx->ps_log_fd = logfd[0]; /* Keep open to pass to processes */ diff --git a/src/privsep.c b/src/privsep.c index 7a81c196..1cb2dd77 100644 --- a/src/privsep.c +++ b/src/privsep.c @@ -761,11 +761,6 @@ ps_freeprocess(struct ps_process *psp) TAILQ_REMOVE(&ctx->ps_processes, psp, next); - if (psp->psp_freedata != NULL) - psp->psp_freedata(psp->psp_data); - else - free(psp->psp_data); - if (psp->psp_fd != -1) { eloop_event_delete(ctx->eloop, psp->psp_fd); close(psp->psp_fd); diff --git a/src/privsep.h b/src/privsep.h index 37380d4c..496c9cd5 100644 --- a/src/privsep.h +++ b/src/privsep.h @@ -184,8 +184,6 @@ struct ps_process { char psp_name[PSP_NAMESIZE]; uint16_t psp_proto; const char *psp_protostr; - void *psp_data; - void (*psp_freedata)(void *); bool psp_started; #ifdef INET