seatd/common/connection.c

310 lines
8 KiB
C
Raw Normal View History

#include <errno.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>
#include "connection.h"
static inline uint32_t connection_buffer_mask(const uint32_t idx) {
return idx & (CONNECTION_BUFFER_SIZE - 1);
}
static inline uint32_t connection_buffer_size(const struct connection_buffer *b) {
return b->head - b->tail;
}
static inline void connection_buffer_consume(struct connection_buffer *b, const size_t size) {
b->tail += size;
}
static inline void connection_buffer_restore(struct connection_buffer *b, const size_t size) {
b->tail -= size;
}
/*
* connection_buffer_get_iov prepares I/O vectors pointing to our ring buffer.
* Two may be used if the buffer has wrapped around.
*/
static void connection_buffer_get_iov(struct connection_buffer *b, struct iovec *iov, int *count) {
uint32_t head = connection_buffer_mask(b->head);
uint32_t tail = connection_buffer_mask(b->tail);
if (tail < head) {
iov[0].iov_base = b->data + tail;
iov[0].iov_len = head - tail;
*count = 1;
} else if (head == 0) {
iov[0].iov_base = b->data + tail;
iov[0].iov_len = sizeof b->data - tail;
*count = 1;
} else {
iov[0].iov_base = b->data + tail;
iov[0].iov_len = sizeof b->data - tail;
iov[1].iov_base = b->data;
iov[1].iov_len = head;
*count = 2;
}
}
/*
* connection_buffer_put_iov prepares I/O vectors pointing to our ring buffer.
* Two may be used if the buffer has wrapped around.
*/
static void connection_buffer_put_iov(struct connection_buffer *b, struct iovec *iov, int *count) {
uint32_t head = connection_buffer_mask(b->head);
uint32_t tail = connection_buffer_mask(b->tail);
if (head < tail) {
iov[0].iov_base = b->data + head;
iov[0].iov_len = tail - head;
*count = 1;
} else if (tail == 0) {
iov[0].iov_base = b->data + head;
iov[0].iov_len = sizeof b->data - head;
*count = 1;
} else {
iov[0].iov_base = b->data + head;
iov[0].iov_len = sizeof b->data - head;
iov[1].iov_base = b->data;
iov[1].iov_len = tail;
*count = 2;
}
}
/*
* connection_buffer_copy copies from our ring buffer into a linear buffer.
*/
static void connection_buffer_copy(const struct connection_buffer *b, void *data, const size_t count) {
uint32_t tail = connection_buffer_mask(b->tail);
if (tail + count <= sizeof b->data) {
memcpy(data, b->data + tail, count);
return;
}
uint32_t size = sizeof b->data - tail;
memcpy(data, b->data + tail, size);
memcpy((char *)data + size, b->data, count - size);
}
/*
* connection_buffer_copy copies from a linear buffer into our ring buffer.
*/
static int connection_buffer_put(struct connection_buffer *b, const void *data, const size_t count) {
if (count > sizeof(b->data)) {
errno = EOVERFLOW;
return -1;
}
uint32_t head = connection_buffer_mask(b->head);
if (head + count <= sizeof b->data) {
memcpy(b->data + head, data, count);
} else {
uint32_t size = sizeof b->data - head;
memcpy(b->data + head, data, size);
memcpy(b->data, (const char *)data + size, count - size);
}
b->head += count;
return 0;
}
/*
* close_fds closes all fds within a connection_buffer
*/
static void connection_buffer_close_fds(struct connection_buffer *buffer) {
size_t size = connection_buffer_size(buffer);
if (size == 0) {
return;
}
int fds[sizeof(buffer->data) / sizeof(int)];
connection_buffer_copy(buffer, fds, size);
int count = size / sizeof fds[0];
size = count * sizeof fds[0];
for (int idx = 0; idx < count; idx++) {
close(fds[idx]);
}
connection_buffer_consume(buffer, size);
}
/*
* build_cmsg prepares a cmsg from a buffer full of fds
*/
static void build_cmsg(struct connection_buffer *buffer, char *data, int *clen) {
size_t size = connection_buffer_size(buffer);
if (size > MAX_FDS * sizeof(int)) {
size = MAX_FDS * sizeof(int);
}
if (size <= 0) {
*clen = 0;
return;
}
struct cmsghdr *cmsg = (struct cmsghdr *)data;
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(size);
connection_buffer_copy(buffer, CMSG_DATA(cmsg), size);
*clen = cmsg->cmsg_len;
}
static int decode_cmsg(struct connection_buffer *buffer, struct msghdr *msg) {
bool overflow = false;
struct cmsghdr *cmsg;
for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) {
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) {
continue;
}
size_t size = cmsg->cmsg_len - CMSG_LEN(0);
size_t max = sizeof(buffer->data) - connection_buffer_size(buffer);
if (size > max || overflow) {
overflow = true;
size /= sizeof(int);
for (size_t idx = 0; idx < size; idx++) {
close(((int *)CMSG_DATA(cmsg))[idx]);
}
} else if (connection_buffer_put(buffer, CMSG_DATA(cmsg), size) < 0) {
return -1;
}
}
if (overflow) {
errno = EOVERFLOW;
return -1;
}
return 0;
}
int connection_read(struct connection *connection) {
if (connection_buffer_size(&connection->in) >= sizeof(connection->in.data)) {
errno = EOVERFLOW;
return -1;
}
int count;
struct iovec iov[2];
connection_buffer_put_iov(&connection->in, iov, &count);
char cmsg[CMSG_LEN(CONNECTION_BUFFER_SIZE)];
struct msghdr msg = {
.msg_name = NULL,
.msg_namelen = 0,
.msg_iov = iov,
.msg_iovlen = count,
.msg_control = cmsg,
.msg_controllen = sizeof cmsg,
.msg_flags = 0,
};
int len;
do {
len = recvmsg(connection->fd, &msg, MSG_DONTWAIT | MSG_CMSG_CLOEXEC);
if (len == -1 && errno != EINTR)
return -1;
} while (len == -1);
if (decode_cmsg(&connection->fds_in, &msg) != 0) {
return -1;
}
connection->in.head += len;
return connection_buffer_size(&connection->in);
}
int connection_flush(struct connection *connection) {
if (!connection->want_flush) {
return 0;
}
uint32_t tail = connection->out.tail;
while (connection->out.head - connection->out.tail > 0) {
int count;
struct iovec iov[2];
connection_buffer_get_iov(&connection->out, iov, &count);
int clen;
char cmsg[CMSG_LEN(CONNECTION_BUFFER_SIZE)];
build_cmsg(&connection->fds_out, cmsg, &clen);
struct msghdr msg = {
.msg_name = NULL,
.msg_namelen = 0,
.msg_iov = iov,
.msg_iovlen = count,
.msg_control = (clen > 0) ? cmsg : NULL,
.msg_controllen = clen,
.msg_flags = 0,
};
int len;
do {
len = sendmsg(connection->fd, &msg, MSG_NOSIGNAL | MSG_DONTWAIT);
if (len == -1 && errno != EINTR)
return -1;
} while (len == -1);
connection_buffer_close_fds(&connection->fds_out);
connection->out.tail += len;
}
connection->want_flush = 0;
return connection->out.head - tail;
}
int connection_put(struct connection *connection, const void *data, size_t count) {
if (connection_buffer_size(&connection->out) + count > CONNECTION_BUFFER_SIZE) {
connection->want_flush = 1;
if (connection_flush(connection) == -1) {
return -1;
}
}
if (connection_buffer_put(&connection->out, data, count) == -1) {
return -1;
}
connection->want_flush = 1;
return 0;
}
int connection_put_fd(struct connection *connection, int fd) {
if (connection_buffer_size(&connection->fds_out) == MAX_FDS * sizeof fd) {
errno = EOVERFLOW;
return -1;
}
return connection_buffer_put(&connection->fds_out, &fd, sizeof fd);
}
int connection_get(struct connection *connection, void *dst, size_t count) {
if (count > connection_buffer_size(&connection->in)) {
errno = EAGAIN;
return -1;
}
connection_buffer_copy(&connection->in, dst, count);
connection_buffer_consume(&connection->in, count);
return count;
}
int connection_get_fd(struct connection *connection) {
int fd;
if (sizeof fd > connection_buffer_size(&connection->fds_in)) {
errno = EAGAIN;
return -1;
}
connection_buffer_copy(&connection->fds_in, &fd, sizeof fd);
connection_buffer_consume(&connection->fds_in, sizeof fd);
return fd;
}
void connection_close_fds(struct connection *connection) {
connection_buffer_close_fds(&connection->fds_in);
connection_buffer_close_fds(&connection->fds_out);
}
size_t connection_pending(struct connection *connection) {
return connection_buffer_size(&connection->in);
}
void connection_restore(struct connection *connection, size_t count) {
connection_buffer_restore(&connection->in, count);
}