310 lines
8.1 KiB
C
310 lines
8.1 KiB
C
#include <errno.h>
|
|
#include <stddef.h>
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <string.h>
|
|
#include <sys/socket.h>
|
|
#include <unistd.h>
|
|
|
|
#include "compiler.h"
|
|
#include "connection.h"
|
|
|
|
ALWAYS_INLINE static uint32_t connection_buffer_mask(const uint32_t idx) {
|
|
return idx & (CONNECTION_BUFFER_SIZE - 1);
|
|
}
|
|
|
|
ALWAYS_INLINE static uint32_t connection_buffer_size(const struct connection_buffer *b) {
|
|
return b->head - b->tail;
|
|
}
|
|
|
|
ALWAYS_INLINE static void connection_buffer_consume(struct connection_buffer *b, const size_t size) {
|
|
b->tail += size;
|
|
}
|
|
|
|
ALWAYS_INLINE static void connection_buffer_restore(struct connection_buffer *b, const size_t size) {
|
|
b->tail -= size;
|
|
}
|
|
|
|
/*
|
|
* connection_buffer_get_iov prepares I/O vectors pointing to our ring buffer.
|
|
* Two may be used if the buffer has wrapped around.
|
|
*/
|
|
static void connection_buffer_get_iov(struct connection_buffer *b, struct iovec *iov, int *count) {
|
|
uint32_t head = connection_buffer_mask(b->head);
|
|
uint32_t tail = connection_buffer_mask(b->tail);
|
|
if (tail < head) {
|
|
iov[0].iov_base = b->data + tail;
|
|
iov[0].iov_len = head - tail;
|
|
*count = 1;
|
|
} else if (head == 0) {
|
|
iov[0].iov_base = b->data + tail;
|
|
iov[0].iov_len = sizeof b->data - tail;
|
|
*count = 1;
|
|
} else {
|
|
iov[0].iov_base = b->data + tail;
|
|
iov[0].iov_len = sizeof b->data - tail;
|
|
iov[1].iov_base = b->data;
|
|
iov[1].iov_len = head;
|
|
*count = 2;
|
|
}
|
|
}
|
|
|
|
/*
|
|
* connection_buffer_put_iov prepares I/O vectors pointing to our ring buffer.
|
|
* Two may be used if the buffer has wrapped around.
|
|
*/
|
|
static void connection_buffer_put_iov(struct connection_buffer *b, struct iovec *iov, int *count) {
|
|
uint32_t head = connection_buffer_mask(b->head);
|
|
uint32_t tail = connection_buffer_mask(b->tail);
|
|
if (head < tail) {
|
|
iov[0].iov_base = b->data + head;
|
|
iov[0].iov_len = tail - head;
|
|
*count = 1;
|
|
} else if (tail == 0) {
|
|
iov[0].iov_base = b->data + head;
|
|
iov[0].iov_len = sizeof b->data - head;
|
|
*count = 1;
|
|
} else {
|
|
iov[0].iov_base = b->data + head;
|
|
iov[0].iov_len = sizeof b->data - head;
|
|
iov[1].iov_base = b->data;
|
|
iov[1].iov_len = tail;
|
|
*count = 2;
|
|
}
|
|
}
|
|
|
|
/*
|
|
* connection_buffer_copy copies from our ring buffer into a linear buffer.
|
|
*/
|
|
static void connection_buffer_copy(const struct connection_buffer *b, void *data, const size_t count) {
|
|
uint32_t tail = connection_buffer_mask(b->tail);
|
|
if (tail + count <= sizeof b->data) {
|
|
memcpy(data, b->data + tail, count);
|
|
return;
|
|
}
|
|
|
|
uint32_t size = sizeof b->data - tail;
|
|
memcpy(data, b->data + tail, size);
|
|
memcpy((char *)data + size, b->data, count - size);
|
|
}
|
|
/*
|
|
* connection_buffer_copy copies from a linear buffer into our ring buffer.
|
|
*/
|
|
static int connection_buffer_put(struct connection_buffer *b, const void *data, const size_t count) {
|
|
if (count > sizeof(b->data)) {
|
|
errno = EOVERFLOW;
|
|
return -1;
|
|
}
|
|
|
|
uint32_t head = connection_buffer_mask(b->head);
|
|
if (head + count <= sizeof b->data) {
|
|
memcpy(b->data + head, data, count);
|
|
} else {
|
|
uint32_t size = sizeof b->data - head;
|
|
memcpy(b->data + head, data, size);
|
|
memcpy(b->data, (const char *)data + size, count - size);
|
|
}
|
|
|
|
b->head += count;
|
|
return 0;
|
|
}
|
|
|
|
/*
|
|
* close_fds closes all fds within a connection_buffer
|
|
*/
|
|
static void connection_buffer_close_fds(struct connection_buffer *buffer) {
|
|
size_t size = connection_buffer_size(buffer);
|
|
if (size == 0) {
|
|
return;
|
|
}
|
|
int fds[sizeof(buffer->data) / sizeof(int)];
|
|
connection_buffer_copy(buffer, fds, size);
|
|
int count = size / sizeof fds[0];
|
|
size = count * sizeof fds[0];
|
|
for (int idx = 0; idx < count; idx++) {
|
|
close(fds[idx]);
|
|
}
|
|
connection_buffer_consume(buffer, size);
|
|
}
|
|
|
|
/*
|
|
* build_cmsg prepares a cmsg from a buffer full of fds
|
|
*/
|
|
static void build_cmsg(struct connection_buffer *buffer, char *data, int *clen) {
|
|
size_t size = connection_buffer_size(buffer);
|
|
if (size > MAX_FDS * sizeof(int)) {
|
|
size = MAX_FDS * sizeof(int);
|
|
}
|
|
|
|
if (size <= 0) {
|
|
*clen = 0;
|
|
return;
|
|
}
|
|
|
|
struct cmsghdr *cmsg = (struct cmsghdr *)data;
|
|
cmsg->cmsg_level = SOL_SOCKET;
|
|
cmsg->cmsg_type = SCM_RIGHTS;
|
|
cmsg->cmsg_len = CMSG_LEN(size);
|
|
connection_buffer_copy(buffer, CMSG_DATA(cmsg), size);
|
|
*clen = cmsg->cmsg_len;
|
|
}
|
|
|
|
static int decode_cmsg(struct connection_buffer *buffer, struct msghdr *msg) {
|
|
bool overflow = false;
|
|
struct cmsghdr *cmsg;
|
|
for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) {
|
|
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) {
|
|
continue;
|
|
}
|
|
|
|
size_t size = cmsg->cmsg_len - CMSG_LEN(0);
|
|
size_t max = sizeof(buffer->data) - connection_buffer_size(buffer);
|
|
if (size > max || overflow) {
|
|
overflow = true;
|
|
size /= sizeof(int);
|
|
for (size_t idx = 0; idx < size; idx++) {
|
|
close(((int *)CMSG_DATA(cmsg))[idx]);
|
|
}
|
|
} else if (connection_buffer_put(buffer, CMSG_DATA(cmsg), size) < 0) {
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
if (overflow) {
|
|
errno = EOVERFLOW;
|
|
return -1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int connection_read(struct connection *connection) {
|
|
if (connection_buffer_size(&connection->in) >= sizeof(connection->in.data)) {
|
|
errno = EOVERFLOW;
|
|
return -1;
|
|
}
|
|
|
|
int count;
|
|
struct iovec iov[2];
|
|
connection_buffer_put_iov(&connection->in, iov, &count);
|
|
|
|
char cmsg[CMSG_LEN(CONNECTION_BUFFER_SIZE)];
|
|
struct msghdr msg = {
|
|
.msg_name = NULL,
|
|
.msg_namelen = 0,
|
|
.msg_iov = iov,
|
|
.msg_iovlen = count,
|
|
.msg_control = cmsg,
|
|
.msg_controllen = sizeof cmsg,
|
|
.msg_flags = 0,
|
|
};
|
|
|
|
int len;
|
|
do {
|
|
len = recvmsg(connection->fd, &msg, MSG_DONTWAIT | MSG_CMSG_CLOEXEC);
|
|
if (len == -1 && errno != EINTR)
|
|
return -1;
|
|
} while (len == -1);
|
|
|
|
if (decode_cmsg(&connection->fds_in, &msg) != 0) {
|
|
return -1;
|
|
}
|
|
connection->in.head += len;
|
|
|
|
return connection_buffer_size(&connection->in);
|
|
}
|
|
|
|
int connection_flush(struct connection *connection) {
|
|
if (!connection->want_flush) {
|
|
return 0;
|
|
}
|
|
|
|
uint32_t tail = connection->out.tail;
|
|
while (connection->out.head - connection->out.tail > 0) {
|
|
int count;
|
|
struct iovec iov[2];
|
|
connection_buffer_get_iov(&connection->out, iov, &count);
|
|
|
|
int clen;
|
|
char cmsg[CMSG_LEN(CONNECTION_BUFFER_SIZE)];
|
|
build_cmsg(&connection->fds_out, cmsg, &clen);
|
|
struct msghdr msg = {
|
|
.msg_name = NULL,
|
|
.msg_namelen = 0,
|
|
.msg_iov = iov,
|
|
.msg_iovlen = count,
|
|
.msg_control = (clen > 0) ? cmsg : NULL,
|
|
.msg_controllen = clen,
|
|
.msg_flags = 0,
|
|
};
|
|
|
|
int len;
|
|
do {
|
|
len = sendmsg(connection->fd, &msg, MSG_NOSIGNAL | MSG_DONTWAIT);
|
|
if (len == -1 && errno != EINTR)
|
|
return -1;
|
|
} while (len == -1);
|
|
connection_buffer_close_fds(&connection->fds_out);
|
|
connection->out.tail += len;
|
|
}
|
|
connection->want_flush = 0;
|
|
return connection->out.head - tail;
|
|
}
|
|
|
|
int connection_put(struct connection *connection, const void *data, size_t count) {
|
|
if (connection_buffer_size(&connection->out) + count > CONNECTION_BUFFER_SIZE) {
|
|
connection->want_flush = 1;
|
|
if (connection_flush(connection) == -1) {
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
if (connection_buffer_put(&connection->out, data, count) == -1) {
|
|
return -1;
|
|
}
|
|
|
|
connection->want_flush = 1;
|
|
return 0;
|
|
}
|
|
|
|
int connection_put_fd(struct connection *connection, int fd) {
|
|
if (connection_buffer_size(&connection->fds_out) == MAX_FDS * sizeof fd) {
|
|
errno = EOVERFLOW;
|
|
return -1;
|
|
}
|
|
|
|
return connection_buffer_put(&connection->fds_out, &fd, sizeof fd);
|
|
}
|
|
|
|
int connection_get(struct connection *connection, void *dst, size_t count) {
|
|
if (count > connection_buffer_size(&connection->in)) {
|
|
errno = EAGAIN;
|
|
return -1;
|
|
}
|
|
connection_buffer_copy(&connection->in, dst, count);
|
|
connection_buffer_consume(&connection->in, count);
|
|
return count;
|
|
}
|
|
|
|
int connection_get_fd(struct connection *connection) {
|
|
int fd;
|
|
if (sizeof fd > connection_buffer_size(&connection->fds_in)) {
|
|
errno = EAGAIN;
|
|
return -1;
|
|
}
|
|
connection_buffer_copy(&connection->fds_in, &fd, sizeof fd);
|
|
connection_buffer_consume(&connection->fds_in, sizeof fd);
|
|
return fd;
|
|
}
|
|
|
|
void connection_close_fds(struct connection *connection) {
|
|
connection_buffer_close_fds(&connection->fds_in);
|
|
connection_buffer_close_fds(&connection->fds_out);
|
|
}
|
|
|
|
size_t connection_pending(struct connection *connection) {
|
|
return connection_buffer_size(&connection->in);
|
|
}
|
|
|
|
void connection_restore(struct connection *connection, size_t count) {
|
|
connection_buffer_restore(&connection->in, count);
|
|
}
|