diff --git a/src/ipc.c b/src/ipc.c index ec1abd0..612a85e 100644 --- a/src/ipc.c +++ b/src/ipc.c @@ -43,32 +43,73 @@ extern char *ident; static struct sockaddr_un sun; static int ipc_socket = -1; +/* max word count in one command: + * smcroutectl add in1 source group out1 out2 .. out32 + */ +#define CMD_MAX_WORDS (MAXVIFS + 3) + /* Receive command from the smcroutectl */ static void ipc_read(int sd) { + /* since command len must be limited by the max number of oifs + preallocate ipc_msg only once in advance */ + char msg_buf[sizeof(struct ipc_msg) + CMD_MAX_WORDS * sizeof(char *)]; char buf[MX_CMDPKT_SZ]; - struct ipc_msg *msg; + const char* buf_ptr; + ssize_t rx = 0, rx_curr; + int first_call = 1; memset(buf, 0, sizeof(buf)); - msg = (struct ipc_msg *)ipc_receive(sd, buf, sizeof(buf)); - if (!msg) { - /* Skip logging client disconnects */ - if (errno != ECONNRESET) - smclog(LOG_WARNING, "Failed receiving IPC message from client: %s", strerror(errno)); - return; - } - if (msg_do(sd, msg)) { - if (EINVAL == errno) - smclog(LOG_WARNING, "Unknown or malformed IPC message '%c' from client.", msg->cmd); - errno = 0; - ipc_send(sd, log_message, strlen(log_message) + 1); - } else { - ipc_send(sd, "", 1); - } + /* since client message would be big enough and couldn't fit into buffer + we have to make multiple iterations to receive all data */ + while(1) { + rx_curr = ipc_receive(sd, buf + rx, sizeof(buf) - rx, first_call); + first_call = 0; + if(rx_curr <= 0) { + if (errno == EAGAIN) + return; /* no more data from client */ + /* Skip logging client disconnects */ + else if (errno != ECONNRESET) + smclog(LOG_WARNING, "Failed receiving IPC message from client: %s", strerror(errno)); + return; + } + + rx += rx_curr; - free(msg); + /* Make sure to always have at least one NUL, for strlen() */ + buf[rx] = 0; + + buf_ptr = buf; + while(rx > 0) { + struct ipc_msg* msg = (struct ipc_msg*)msg_buf; + /* extract one command at a time */ + if (ipc_parse(buf_ptr, rx, msg)) { + if (EAGAIN == errno) { + /* need more data from client? + move last unused bytes (if any) to the begging of the buffer + and lets try to receive more data */ + memmove(buf, buf_ptr, rx); + break; + } + smclog(LOG_WARNING, "Failed to parse IPC message from client: %s", strerror(errno)); + return; + } + + if (msg_do(sd, msg)) { + if (EINVAL == errno) + smclog(LOG_WARNING, "Unknown or malformed IPC message '%c' from client.", msg->cmd); + errno = 0; + ipc_send(sd, log_message, strlen(log_message) + 1); + } else { + ipc_send(sd, "", 1); + } + /* shift to the next command if any and reduce remaining bytes in buffer */ + buf_ptr += msg->len; + rx -= msg->len; + } + } } static void ipc_accept(int sd, void *arg) @@ -149,7 +190,7 @@ void ipc_exit(void) * Returns: * Number of bytes successfully sent, or -1 with @errno on failure. */ -int ipc_send(int sd, char *buf, size_t len) +int ipc_send(int sd, const char *buf, size_t len) { if (write(sd, buf, len) != (ssize_t)len) return -1; @@ -162,68 +203,79 @@ int ipc_send(int sd, char *buf, size_t len) * @sd: Client socket from ipc_accept() * @buf: Buffer for message * @len: Size of @buf in bytes + * @first_call: non-zero set on first read after accept, 0 - subsequent calls * - * Reads a message from the IPC socket and stores in @buf, respecting + * Reads a message(s) from the IPC socket and stores in @buf, respecting * the size @len. Connects and resets connection as necessary. * * Returns: - * Pointer to a successfuly read command packet in @buf, or %NULL on error. + * Size of a successfuly read command packet in @buf, or 0 on error. */ -void *ipc_receive(int sd, char *buf, size_t len) +ssize_t ipc_receive(int sd, char *buf, size_t len, int first_call) { ssize_t sz; + /* since we can call this multiple times during receive of multipart + command lets pass `don't wait` flag to not block forever + when client finish transmission */ + int flags = first_call ? 0 : MSG_DONTWAIT; - sz = recv(sd, buf, len - 1, 0); - if (sz <= 0) { - if (!sz) - errno = ECONNRESET; - return NULL; - } + sz = recv(sd, buf, len - 1, flags); + if (!sz) + errno = ECONNRESET; - /* successful read */ - if ((size_t)sz >= sizeof(struct ipc_msg)) { - struct ipc_msg *msg = (struct ipc_msg *)buf; + return sz; +} - /* Make sure to always have at least one NUL, for strlen() */ - buf[sz] = 0; +/** + * ipc_server_parse - Parse IPC message(s) from client + * @buf: Buffer of message(s) + * @sz: Size of @buf in bytes + * @msg_buf: Preallocated ipc_msg + * + * Parse message(s) from the IPC socket, respecting + * the size @sz. + * + * Returns: + * POSIX OK(0) on a successfuly read command in @buf, or non-zero on error. + */ +int ipc_parse(const char *buf, size_t sz, void* msg_buf) +{ + struct ipc_msg* msg; - if ((size_t)sz == msg->len) { + /* successful read */ + if (sz >= sizeof(struct ipc_msg)) { + memcpy(msg_buf, buf, sizeof(struct ipc_msg)); + msg = (struct ipc_msg*)msg_buf; + /* enough bytes to extract just one message? */ + if (sz >= msg->len) { size_t i, count; - char *ptr; + /* We are not going to modify anything here */ + const char *ptr; - /* Upper bound: smcroutectl add in1 source group out1 out2 .. out32 */ count = msg->count; - if (count > (MAXVIFS + 3)) { + if (count > CMD_MAX_WORDS) { errno = EINVAL; - return NULL; + return 1; } - msg = malloc(sizeof(struct ipc_msg) + msg->count * sizeof(char *)); - if (!msg) - return NULL; - - memcpy(msg, buf, sizeof(struct ipc_msg)); - ptr = buf + offsetof(struct ipc_msg, argv); for (i = 0; i < count; i++) { /* Verify ptr, attacker may set too large msg->count */ - if (ptr >= (buf + len)) { - free(msg); + if (ptr >= (buf + msg->len)) { errno = EBADMSG; - return NULL; + return 1; } - msg->argv[i] = ptr; + msg->argv[i] = (char*)ptr; ptr += strlen(ptr) + 1; } - msg->count = count; - return msg; + return 0; } } - + /* we've parsed all commands or not enough bytes to parse next */ errno = EAGAIN; - return NULL; + return 1; } /** diff --git a/src/ipc.h b/src/ipc.h index 791e00d..e27bc12 100644 --- a/src/ipc.h +++ b/src/ipc.h @@ -7,8 +7,9 @@ int ipc_init (char *path); void ipc_exit (void); -int ipc_send (int sd, char *buf, size_t len); -void *ipc_receive (int sd, char *buf, size_t len); +int ipc_send (int sd, const char *buf, size_t len); +ssize_t ipc_receive(int sd, char *buf, size_t len, int first_call); +int ipc_parse (const char *buf, size_t sz, void *msg_buf); #endif /* SMCROUTE_IPC_H_ */