Skip to content

Commit 5692cf1

Browse files
authored
don't mark just-connected socket ready-to-read in poll(2) (#732)
This fixes a bug in the `poll(2)` implementation for `wasm32-wasip2` such that it spuriously set the `POLLRDNORM` bit on the `revents` field for a socket which just finished connecting even though there's nothing immediately available to read. The fix is to set at most the `POLLWRNORM` bit in that situation.
1 parent d02bdc2 commit 5692cf1

4 files changed

Lines changed: 133 additions & 4 deletions

File tree

libc-bottom-half/cloudlibc/src/libc/poll/poll.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,12 @@ int __wasilibc_poll_add(poll_state_t *state, short events,
163163
}
164164

165165
void __wasilibc_poll_ready(poll_state_t *state, short events) {
166-
if (state->pollfd->revents == 0) {
167-
++state->event_count;
166+
if (events != 0) {
167+
if (state->pollfd->revents == 0) {
168+
++state->event_count;
169+
}
170+
state->pollfd->revents |= events;
168171
}
169-
state->pollfd->revents |= events;
170172
}
171173

172174
static int poll_impl(struct pollfd *fds, size_t nfds, int timeout) {

libc-bottom-half/sources/wasip2_tcp.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,9 @@ static int tcp_poll_finish(void *data, poll_state_t *state, short events) {
12111211
memset(&socket->state.connected, 0, sizeof(socket->state.connected));
12121212
socket->state.connected.input = tuple.f0;
12131213
socket->state.connected.output = tuple.f1;
1214-
__wasilibc_poll_ready(state, events);
1214+
// Now that it's connected, it's immediately writable but not necessarily
1215+
// immediately readable:
1216+
__wasilibc_poll_ready(state, events & POLLWRNORM);
12151217
} else if (error == NETWORK_ERROR_CODE_WOULD_BLOCK) {
12161218
// No events yet -- application will need to poll again
12171219
} else {

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ endif()
306306
# ========= sockets-related tests ===============================
307307

308308
if (NOT (WASI STREQUAL "p1"))
309+
add_wasilibc_test(poll-connect.c NETWORK FAILP3)
309310
add_wasilibc_test(poll-nonblocking-socket.c NETWORK FAILP3)
310311
add_wasilibc_test(setsockopt.c NETWORK FAILP3)
311312
add_wasilibc_test(sockets-nonblocking-udp.c NETWORK FAILP3)

test/src/poll-connect.c

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#include "test.h"
2+
#include <errno.h>
3+
#include <netinet/in.h>
4+
#include <poll.h>
5+
#include <sys/socket.h>
6+
7+
#define TEST(c) \
8+
do { \
9+
errno = 0; \
10+
if (!(c)) \
11+
t_error("%s failed (errno = %d)\n", #c, errno); \
12+
} while (0)
13+
14+
#define TEST2(c) \
15+
do { \
16+
if (!(c)) \
17+
t_error("%s failed (errno = %d)\n", #c, errno); \
18+
} while (0)
19+
20+
#define ASSERT(c) \
21+
do { \
22+
if (!(c)) \
23+
t_error("%s failed\n", #c); \
24+
} while (0)
25+
26+
int main() {
27+
errno = 0;
28+
int server = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
29+
TEST2(server != -1);
30+
31+
// Bind to any available port on the loopback address.
32+
struct sockaddr_in server_address;
33+
server_address.sin_family = AF_INET;
34+
server_address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
35+
server_address.sin_port = htons(0);
36+
TEST(bind(server, (struct sockaddr *)&server_address,
37+
sizeof(server_address)) != -1);
38+
39+
// Get the port we bound to.
40+
socklen_t server_address_len = sizeof(server_address);
41+
TEST(getsockname(server, (struct sockaddr *)&server_address,
42+
&server_address_len) != -1);
43+
44+
TEST(listen(server, 1) != -1);
45+
46+
errno = 0;
47+
int client = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
48+
TEST2(client != -1);
49+
50+
// Connect a client socket. This may or may not complete immediately; if not,
51+
// we'll poll and finish connecting later.
52+
TEST(connect(client, (const struct sockaddr *)&server_address,
53+
sizeof(server_address)) != -1 ||
54+
errno == EAGAIN || errno == EWOULDBLOCK || errno == EINPROGRESS);
55+
56+
// Start accepting connections. This may or may not complete immediately; if
57+
// not, we'll poll and accept again later.
58+
socklen_t client_len = sizeof(struct sockaddr_in);
59+
struct sockaddr_in client_address;
60+
errno = 0;
61+
int server_client =
62+
accept(server, (struct sockaddr *)&client_address, &client_len);
63+
TEST2(server_client != -1 || errno == EAGAIN || errno == EWOULDBLOCK ||
64+
errno == EINPROGRESS);
65+
66+
struct pollfd client_fd = {
67+
.fd = client, .events = POLLWRNORM | POLLRDNORM, .revents = 0};
68+
69+
struct pollfd fds[2];
70+
fds[0] = client_fd;
71+
72+
// Poll until the client socket finished connecting and the server socket
73+
// finishes accepting.
74+
int loop_count = 0;
75+
while ((server_client == -1 || fds[0].events) && loop_count < 10) {
76+
++loop_count;
77+
78+
int fd_count = 1;
79+
if (server_client == -1) {
80+
struct pollfd server_fd = {
81+
.fd = server, .events = POLLWRNORM | POLLRDNORM, .revents = 0};
82+
fds[1] = server_fd;
83+
fd_count = 2;
84+
}
85+
86+
TEST(poll(fds, fd_count, 100) != -1);
87+
if (fds[0].revents) {
88+
// The newly-connected client socket should be writable but not yet
89+
// readable:
90+
ASSERT((fds[0].revents & POLLRDNORM) == 0);
91+
ASSERT((fds[0].revents & POLLWRNORM) != 0);
92+
fds[0].events = 0;
93+
}
94+
95+
if (server_client == -1 && fds[1].revents) {
96+
errno = 0;
97+
server_client =
98+
accept(server, (struct sockaddr *)&client_address, &client_len);
99+
TEST2(server_client != -1);
100+
}
101+
}
102+
ASSERT(loop_count < 10);
103+
104+
// Send some data from the server to the client.
105+
uint8_t data[5] = {1, 2, 3, 4, 5};
106+
TEST(send(server_client, data, sizeof(data), 0) == sizeof(data));
107+
108+
fds[0].events = POLLRDNORM;
109+
errno = 0;
110+
int count = poll(fds, 1, 100);
111+
TEST2(count == 1);
112+
// Now the server has sent something, so the client socket should be readable.
113+
ASSERT((fds[0].revents & POLLRDNORM) != 0);
114+
115+
uint8_t received[sizeof(data)];
116+
TEST(recv(client, received, sizeof(data), 0) == sizeof(data));
117+
118+
// Assert that what was received matches what was sent.
119+
for (int i = 0; i < sizeof(data); ++i) {
120+
ASSERT(received[i] == data[i]);
121+
}
122+
123+
return t_status;
124+
}

0 commit comments

Comments
 (0)