diff --git a/libc-bottom-half/cloudlibc/src/libc/poll/poll.c b/libc-bottom-half/cloudlibc/src/libc/poll/poll.c index de4c2c842..6d77f3381 100644 --- a/libc-bottom-half/cloudlibc/src/libc/poll/poll.c +++ b/libc-bottom-half/cloudlibc/src/libc/poll/poll.c @@ -317,6 +317,7 @@ int __wasilibc_poll_add(poll_state_t *state, uint32_t waitable, } void __wasilibc_poll_ready(poll_state_t *state, short events) { + events = events & state->pollfd->events; if (events != 0) { if (state->pollfd->revents == 0) { ++state->event_count; diff --git a/test/src/sockets-nonblocking.c b/test/src/sockets-nonblocking.c index 9a6664eec..31f9cc7c3 100644 --- a/test/src/sockets-nonblocking.c +++ b/test/src/sockets-nonblocking.c @@ -156,8 +156,64 @@ void test_tcp_client() { TEST(close(server_socket_fd) == 0); } +static void test_poll_events() { + int listener_fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); + + struct sockaddr_in server_address; + socklen_t server_address_len = sizeof(server_address); + server_address.sin_family = AF_INET; + server_address.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + server_address.sin_port = 0; + TEST(bind(listener_fd, (struct sockaddr *)&server_address, + sizeof(server_address)) != -1); + TEST(getsockname(listener_fd, (struct sockaddr *)&server_address, + &server_address_len) != -1); + TEST(listen(listener_fd, 1) != -1); + + int client_fd; + TEST((client_fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0)) != -1); + TEST(connect(client_fd, (struct sockaddr *)&server_address, + server_address_len) != -1 || + errno == EINPROGRESS); + + struct pollfd pfds[3]; + pfds[0].fd = listener_fd; + pfds[0].events = POLLRDNORM; + pfds[0].revents = 0; + pfds[1].fd = client_fd; + pfds[1].events = POLLRDNORM; + pfds[1].revents = 0; + + while (!(pfds[0].revents & POLLRDNORM)) { + TEST(poll(pfds, 2, -1) != -1); + TEST(pfds[1].revents == 0); + } + TEST(pfds[0].revents == POLLRDNORM); + + int server_fd; + struct pollfd pfd; + TEST((server_fd = accept4(listener_fd, NULL, NULL, SOCK_NONBLOCK)) != -1); + pfd.fd = server_fd; + pfd.events = POLLWRNORM; + pfd.revents = 0; + while (!(pfd.revents & POLLWRNORM)) + TEST(poll(&pfd, 1, -1) != -1); + TEST(write(server_fd, "x", 1) == 1); + TEST(close(server_fd) == 0); + + while (!(pfds[1].revents & POLLRDNORM)) { + TEST(poll(pfds, 2, -1) != -1); + TEST(pfds[0].revents == 0); + } + TEST(pfds[1].revents == POLLRDNORM); + + TEST(close(client_fd) == 0); + TEST(close(listener_fd) == 0); +} + int main(void) { test_tcp_client(); + test_poll_events(); return t_status; }