diff --git a/libc-bottom-half/sources/file_utils.c b/libc-bottom-half/sources/file_utils.c index ef2bd74b7..89408fef9 100644 --- a/libc-bottom-half/sources/file_utils.c +++ b/libc-bottom-half/sources/file_utils.c @@ -351,6 +351,11 @@ ssize_t __wasilibc_write(wasi_write_t *write, const void *buffer, #elif defined(__wasip3__) wasip3_io_state_t *state = write->state; + // If this stream is closed, for example with a TCP shutdown, then it's + // closed and we're at EOF. + if (state->stream == 0) + return write->eof(write->eof_data); + // First resolve any pending I/O, should it exist. if (wasip3_write_resolve_pending(write) < 0) return -1; @@ -550,6 +555,11 @@ ssize_t __wasilibc_read(wasi_read_t *read, void *buffer, size_t length) { wasip3_io_state_t *state = read->state; wasip3_event_t event; + // If this stream is closed, for example with a TCP shutdown, then it's + // closed and we're at EOF. + if (state->stream == 0) + return read->eof(read->eof_data); + // If there's active I/O in progress for this stream then this must wait for // it to complete. if (state->flags & WASIP3_IO_INPROGRESS) { @@ -669,6 +679,13 @@ static void wasip3_poll_write_ready(void *data, poll_state_t *state, static int wasip3_stream_poll(wasip3_io_state_t *iostate, poll_state_t *state, short events, poll_ready_t ready) { + // If the stream is closed then it's immediately ready for reading/writing as + // that'll resolve with an error/0/etc. + if (iostate->stream == 0) { + __wasilibc_poll_ready(state, events); + return 0; + } + // If the I/O stream is finished it'll never block so it's always ready. if (iostate->flags & WASIP3_IO_DONE) { __wasilibc_poll_ready(state, events); diff --git a/libc-bottom-half/sources/tcp.c b/libc-bottom-half/sources/tcp.c index f160a16c5..805533758 100644 --- a/libc-bottom-half/sources/tcp.c +++ b/libc-bottom-half/sources/tcp.c @@ -220,7 +220,8 @@ static int tcp_write_eof(void *data) { if (result.is_err) return __wasilibc_socket_error_to_errno(&result.val.err); } - return 0; + errno = EPIPE; + return -1; } #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e35c5b18f..b91697946 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -417,6 +417,7 @@ if (NOT (WASI STREQUAL "p1")) add_wasilibc_test(sockets-nonblocking-udp-no-connection.c NETWORK) add_wasilibc_test(sockets-eof-delayed.c NETWORK) add_wasilibc_test(sockets-nonblocking-accept-multiple.c NETWORK) + add_wasilibc_test(sockets-nonblocking-shutdown.c NETWORK) # Define executables for server/client tests, and they're paired together in # various combinations below for various tests. diff --git a/test/src/sockets-nonblocking-shutdown.c b/test/src/sockets-nonblocking-shutdown.c new file mode 100644 index 000000000..493828a00 --- /dev/null +++ b/test/src/sockets-nonblocking-shutdown.c @@ -0,0 +1,63 @@ +#include "test.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define TEST(c) \ + do { \ + errno = 0; \ + if (!(c)) \ + t_error("%s failed (errno = %d)\n", #c, errno); \ + } while (0) + +#define N 10 + +int main() { + int listener_fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); + + // Setup a listener bound to port 0 to have the OS assign us one. + struct sockaddr_in server_address; + socklen_t server_address_len = sizeof(server_address); + server_address.sin_family = AF_INET; + server_address.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + server_address.sin_port = 0; + TEST(bind(listener_fd, (struct sockaddr *)&server_address, + sizeof(server_address)) != -1); + TEST(getsockname(listener_fd, (struct sockaddr *)&server_address, + &server_address_len) != -1); + TEST(listen(listener_fd, N) != -1); + + int client_fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); + TEST(client_fd != -1); + int rc; + TEST((rc = connect(client_fd, (struct sockaddr *)&server_address, + server_address_len)) != -1 || + errno == EINPROGRESS); + if (rc == -1) { + struct pollfd pfd = {.fd = client_fd, .events = POLLWRNORM}; + TEST(poll(&pfd, 1, -1) != -1); + } + + TEST(shutdown(client_fd, SHUT_RDWR) != -1); + + char buf[10]; + TEST(recv(client_fd, buf, sizeof(buf), 0) == 0); + TEST(send(client_fd, buf, sizeof(buf), 0) == -1 && errno == EPIPE); + + struct pollfd pfd = {.fd = client_fd, .events = POLLWRNORM}; + TEST(poll(&pfd, 1, -1) == 1); + TEST(pfd.revents & POLLWRNORM); + pfd.events = POLLRDNORM; + TEST(poll(&pfd, 1, -1) == 1); + TEST(pfd.revents & POLLRDNORM); + + return t_status; +}