Commit 99126abe authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Daniel Borkmann

bpf: selftests: A few improvements to network_helpers.c

This patch makes a few changes to the network_helpers.c

1) Enforce SO_RCVTIMEO and SO_SNDTIMEO
   This patch enforces timeout to the network fds through setsockopt
   SO_RCVTIMEO and SO_SNDTIMEO.

   It will remove the need for SOCK_NONBLOCK that requires a more demanding
   timeout logic with epoll/select, e.g. epoll_create, epoll_ctrl, and
   then epoll_wait for timeout.

   That removes the need for connect_wait() from the
   cgroup_skb_sk_lookup.c. The needed change is made in
   cgroup_skb_sk_lookup.c.

2) start_server():
   Add optional addr_str and port to start_server().
   That removes the need of the start_server_with_port().  The caller
   can pass addr_str==NULL and/or port==0.

   I have a future tcp-hdr-opt test that will pass a non-NULL addr_str
   and it is in general useful for other future tests.

   "int timeout_ms" is also added to control the timeout
   on the "accept(listen_fd)".

3) connect_to_fd(): Fully use the server_fd.
   The server sock address has already been obtained from
   getsockname(server_fd).  The sockaddr includes the family,
   so the "int family" arg is redundant.

   Since the server address is obtained from server_fd,  there
   is little reason not to get the server's socket type from the
   server_fd also.  getsockopt(server_fd) can be used to do that,
   so "int type" arg is also removed.

   "int timeout_ms" is added.

4) connect_fd_to_fd():
   "int timeout_ms" is added.
   Some code is also refactored to connect_fd_to_addr() which is
   shared with connect_to_fd().

5) Preserve errno:
   Some callers need to check errno, e.g. cgroup_skb_sk_lookup.c.
   Make changes to do it more consistently in save_errno_close()
   and log_err().
Signed-off-by: default avatarMartin KaFai Lau <kafai@fb.com>
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Acked-by: default avatarYonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/bpf/20200702004852.2103003-1-kafai@fb.com
parent 91f77560
...@@ -7,8 +7,6 @@ ...@@ -7,8 +7,6 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <sys/epoll.h>
#include <linux/err.h> #include <linux/err.h>
#include <linux/in.h> #include <linux/in.h>
#include <linux/in6.h> #include <linux/in6.h>
...@@ -17,8 +15,13 @@ ...@@ -17,8 +15,13 @@
#include "network_helpers.h" #include "network_helpers.h"
#define clean_errno() (errno == 0 ? "None" : strerror(errno)) #define clean_errno() (errno == 0 ? "None" : strerror(errno))
#define log_err(MSG, ...) fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", \ #define log_err(MSG, ...) ({ \
__FILE__, __LINE__, clean_errno(), ##__VA_ARGS__) int __save = errno; \
fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", \
__FILE__, __LINE__, clean_errno(), \
##__VA_ARGS__); \
errno = __save; \
})
struct ipv4_packet pkt_v4 = { struct ipv4_packet pkt_v4 = {
.eth.h_proto = __bpf_constant_htons(ETH_P_IP), .eth.h_proto = __bpf_constant_htons(ETH_P_IP),
...@@ -37,7 +40,34 @@ struct ipv6_packet pkt_v6 = { ...@@ -37,7 +40,34 @@ struct ipv6_packet pkt_v6 = {
.tcp.doff = 5, .tcp.doff = 5,
}; };
int start_server_with_port(int family, int type, __u16 port) static int settimeo(int fd, int timeout_ms)
{
struct timeval timeout = { .tv_sec = 3 };
if (timeout_ms > 0) {
timeout.tv_sec = timeout_ms / 1000;
timeout.tv_usec = (timeout_ms % 1000) * 1000;
}
if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout,
sizeof(timeout))) {
log_err("Failed to set SO_RCVTIMEO");
return -1;
}
if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout,
sizeof(timeout))) {
log_err("Failed to set SO_SNDTIMEO");
return -1;
}
return 0;
}
#define save_errno_close(fd) ({ int __save = errno; close(fd); errno = __save; })
int start_server(int family, int type, const char *addr_str, __u16 port,
int timeout_ms)
{ {
struct sockaddr_storage addr = {}; struct sockaddr_storage addr = {};
socklen_t len; socklen_t len;
...@@ -48,120 +78,119 @@ int start_server_with_port(int family, int type, __u16 port) ...@@ -48,120 +78,119 @@ int start_server_with_port(int family, int type, __u16 port)
sin->sin_family = AF_INET; sin->sin_family = AF_INET;
sin->sin_port = htons(port); sin->sin_port = htons(port);
if (addr_str &&
inet_pton(AF_INET, addr_str, &sin->sin_addr) != 1) {
log_err("inet_pton(AF_INET, %s)", addr_str);
return -1;
}
len = sizeof(*sin); len = sizeof(*sin);
} else { } else {
struct sockaddr_in6 *sin6 = (void *)&addr; struct sockaddr_in6 *sin6 = (void *)&addr;
sin6->sin6_family = AF_INET6; sin6->sin6_family = AF_INET6;
sin6->sin6_port = htons(port); sin6->sin6_port = htons(port);
if (addr_str &&
inet_pton(AF_INET6, addr_str, &sin6->sin6_addr) != 1) {
log_err("inet_pton(AF_INET6, %s)", addr_str);
return -1;
}
len = sizeof(*sin6); len = sizeof(*sin6);
} }
fd = socket(family, type | SOCK_NONBLOCK, 0); fd = socket(family, type, 0);
if (fd < 0) { if (fd < 0) {
log_err("Failed to create server socket"); log_err("Failed to create server socket");
return -1; return -1;
} }
if (settimeo(fd, timeout_ms))
goto error_close;
if (bind(fd, (const struct sockaddr *)&addr, len) < 0) { if (bind(fd, (const struct sockaddr *)&addr, len) < 0) {
log_err("Failed to bind socket"); log_err("Failed to bind socket");
close(fd); goto error_close;
return -1;
} }
if (type == SOCK_STREAM) { if (type == SOCK_STREAM) {
if (listen(fd, 1) < 0) { if (listen(fd, 1) < 0) {
log_err("Failed to listed on socket"); log_err("Failed to listed on socket");
close(fd); goto error_close;
return -1;
} }
} }
return fd; return fd;
}
int start_server(int family, int type) error_close:
{ save_errno_close(fd);
return start_server_with_port(family, type, 0); return -1;
} }
static const struct timeval timeo_sec = { .tv_sec = 3 }; static int connect_fd_to_addr(int fd,
static const size_t timeo_optlen = sizeof(timeo_sec); const struct sockaddr_storage *addr,
socklen_t addrlen)
int connect_to_fd(int family, int type, int server_fd)
{ {
int fd, save_errno; if (connect(fd, (const struct sockaddr *)addr, addrlen)) {
log_err("Failed to connect to server");
fd = socket(family, type, 0);
if (fd < 0) {
log_err("Failed to create client socket");
return -1; return -1;
} }
if (connect_fd_to_fd(fd, server_fd) < 0 && errno != EINPROGRESS) { return 0;
save_errno = errno;
close(fd);
errno = save_errno;
return -1;
}
return fd;
} }
int connect_fd_to_fd(int client_fd, int server_fd) int connect_to_fd(int server_fd, int timeout_ms)
{ {
struct sockaddr_storage addr; struct sockaddr_storage addr;
socklen_t len = sizeof(addr); struct sockaddr_in *addr_in;
int save_errno; socklen_t addrlen, optlen;
int fd, type;
if (setsockopt(client_fd, SOL_SOCKET, SO_RCVTIMEO, &timeo_sec, optlen = sizeof(type);
timeo_optlen)) { if (getsockopt(server_fd, SOL_SOCKET, SO_TYPE, &type, &optlen)) {
log_err("Failed to set SO_RCVTIMEO"); log_err("getsockopt(SOL_TYPE)");
return -1; return -1;
} }
if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { addrlen = sizeof(addr);
if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
log_err("Failed to get server addr"); log_err("Failed to get server addr");
return -1; return -1;
} }
if (connect(client_fd, (const struct sockaddr *)&addr, len) < 0) { addr_in = (struct sockaddr_in *)&addr;
if (errno != EINPROGRESS) { fd = socket(addr_in->sin_family, type, 0);
save_errno = errno; if (fd < 0) {
log_err("Failed to connect to server"); log_err("Failed to create client socket");
errno = save_errno;
}
return -1; return -1;
} }
return 0; if (settimeo(fd, timeout_ms))
goto error_close;
if (connect_fd_to_addr(fd, &addr, addrlen))
goto error_close;
return fd;
error_close:
save_errno_close(fd);
return -1;
} }
int connect_wait(int fd) int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms)
{ {
struct epoll_event ev = {}, events[2]; struct sockaddr_storage addr;
int timeout_ms = 1000; socklen_t len = sizeof(addr);
int efd, nfd;
efd = epoll_create1(EPOLL_CLOEXEC); if (settimeo(client_fd, timeout_ms))
if (efd < 0) {
log_err("Failed to open epoll fd");
return -1; return -1;
}
ev.events = EPOLLRDHUP | EPOLLOUT;
ev.data.fd = fd;
if (epoll_ctl(efd, EPOLL_CTL_ADD, fd, &ev) < 0) { if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
log_err("Failed to register fd=%d on epoll fd=%d", fd, efd); log_err("Failed to get server addr");
close(efd);
return -1; return -1;
} }
nfd = epoll_wait(efd, events, ARRAY_SIZE(events), timeout_ms); if (connect_fd_to_addr(client_fd, &addr, len))
if (nfd < 0) return -1;
log_err("Failed to wait for I/O event on epoll fd=%d", efd);
close(efd); return 0;
return nfd;
} }
...@@ -33,10 +33,9 @@ struct ipv6_packet { ...@@ -33,10 +33,9 @@ struct ipv6_packet {
} __packed; } __packed;
extern struct ipv6_packet pkt_v6; extern struct ipv6_packet pkt_v6;
int start_server(int family, int type); int start_server(int family, int type, const char *addr, __u16 port,
int start_server_with_port(int family, int type, __u16 port); int timeout_ms);
int connect_to_fd(int family, int type, int server_fd); int connect_to_fd(int server_fd, int timeout_ms);
int connect_fd_to_fd(int client_fd, int server_fd); int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms);
int connect_wait(int client_fd);
#endif #endif
...@@ -13,7 +13,7 @@ static void run_lookup_test(__u16 *g_serv_port, int out_sk) ...@@ -13,7 +13,7 @@ static void run_lookup_test(__u16 *g_serv_port, int out_sk)
socklen_t addr_len = sizeof(addr); socklen_t addr_len = sizeof(addr);
__u32 duration = 0; __u32 duration = 0;
serv_sk = start_server(AF_INET6, SOCK_STREAM); serv_sk = start_server(AF_INET6, SOCK_STREAM, NULL, 0, 0);
if (CHECK(serv_sk < 0, "start_server", "failed to start server\n")) if (CHECK(serv_sk < 0, "start_server", "failed to start server\n"))
return; return;
...@@ -24,17 +24,13 @@ static void run_lookup_test(__u16 *g_serv_port, int out_sk) ...@@ -24,17 +24,13 @@ static void run_lookup_test(__u16 *g_serv_port, int out_sk)
*g_serv_port = addr.sin6_port; *g_serv_port = addr.sin6_port;
/* Client outside of test cgroup should fail to connect by timeout. */ /* Client outside of test cgroup should fail to connect by timeout. */
err = connect_fd_to_fd(out_sk, serv_sk); err = connect_fd_to_fd(out_sk, serv_sk, 1000);
if (CHECK(!err || errno != EINPROGRESS, "connect_fd_to_fd", if (CHECK(!err || errno != EINPROGRESS, "connect_fd_to_fd",
"unexpected result err %d errno %d\n", err, errno)) "unexpected result err %d errno %d\n", err, errno))
goto cleanup; goto cleanup;
err = connect_wait(out_sk);
if (CHECK(err, "connect_wait", "unexpected result %d\n", err))
goto cleanup;
/* Client inside test cgroup should connect just fine. */ /* Client inside test cgroup should connect just fine. */
in_sk = connect_to_fd(AF_INET6, SOCK_STREAM, serv_sk); in_sk = connect_to_fd(serv_sk, 0);
if (CHECK(in_sk < 0, "connect_to_fd", "errno %d\n", errno)) if (CHECK(in_sk < 0, "connect_to_fd", "errno %d\n", errno))
goto cleanup; goto cleanup;
...@@ -85,7 +81,7 @@ void test_cgroup_skb_sk_lookup(void) ...@@ -85,7 +81,7 @@ void test_cgroup_skb_sk_lookup(void)
* differs from that of testing cgroup. Moving selftests process to * differs from that of testing cgroup. Moving selftests process to
* testing cgroup won't change cgroup id of an already created socket. * testing cgroup won't change cgroup id of an already created socket.
*/ */
out_sk = socket(AF_INET6, SOCK_STREAM | SOCK_NONBLOCK, 0); out_sk = socket(AF_INET6, SOCK_STREAM, 0);
if (CHECK_FAIL(out_sk < 0)) if (CHECK_FAIL(out_sk < 0))
return; return;
......
...@@ -114,7 +114,7 @@ static int run_test(int cgroup_fd, int server_fd, int family, int type) ...@@ -114,7 +114,7 @@ static int run_test(int cgroup_fd, int server_fd, int family, int type)
goto close_bpf_object; goto close_bpf_object;
} }
fd = connect_to_fd(family, type, server_fd); fd = connect_to_fd(server_fd, 0);
if (fd < 0) { if (fd < 0) {
err = -1; err = -1;
goto close_bpf_object; goto close_bpf_object;
...@@ -137,25 +137,25 @@ void test_connect_force_port(void) ...@@ -137,25 +137,25 @@ void test_connect_force_port(void)
if (CHECK_FAIL(cgroup_fd < 0)) if (CHECK_FAIL(cgroup_fd < 0))
return; return;
server_fd = start_server_with_port(AF_INET, SOCK_STREAM, 60123); server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 60123, 0);
if (CHECK_FAIL(server_fd < 0)) if (CHECK_FAIL(server_fd < 0))
goto close_cgroup_fd; goto close_cgroup_fd;
CHECK_FAIL(run_test(cgroup_fd, server_fd, AF_INET, SOCK_STREAM)); CHECK_FAIL(run_test(cgroup_fd, server_fd, AF_INET, SOCK_STREAM));
close(server_fd); close(server_fd);
server_fd = start_server_with_port(AF_INET6, SOCK_STREAM, 60124); server_fd = start_server(AF_INET6, SOCK_STREAM, NULL, 60124, 0);
if (CHECK_FAIL(server_fd < 0)) if (CHECK_FAIL(server_fd < 0))
goto close_cgroup_fd; goto close_cgroup_fd;
CHECK_FAIL(run_test(cgroup_fd, server_fd, AF_INET6, SOCK_STREAM)); CHECK_FAIL(run_test(cgroup_fd, server_fd, AF_INET6, SOCK_STREAM));
close(server_fd); close(server_fd);
server_fd = start_server_with_port(AF_INET, SOCK_DGRAM, 60123); server_fd = start_server(AF_INET, SOCK_DGRAM, NULL, 60123, 0);
if (CHECK_FAIL(server_fd < 0)) if (CHECK_FAIL(server_fd < 0))
goto close_cgroup_fd; goto close_cgroup_fd;
CHECK_FAIL(run_test(cgroup_fd, server_fd, AF_INET, SOCK_DGRAM)); CHECK_FAIL(run_test(cgroup_fd, server_fd, AF_INET, SOCK_DGRAM));
close(server_fd); close(server_fd);
server_fd = start_server_with_port(AF_INET6, SOCK_DGRAM, 60124); server_fd = start_server(AF_INET6, SOCK_DGRAM, NULL, 60124, 0);
if (CHECK_FAIL(server_fd < 0)) if (CHECK_FAIL(server_fd < 0))
goto close_cgroup_fd; goto close_cgroup_fd;
CHECK_FAIL(run_test(cgroup_fd, server_fd, AF_INET6, SOCK_DGRAM)); CHECK_FAIL(run_test(cgroup_fd, server_fd, AF_INET6, SOCK_DGRAM));
......
...@@ -23,7 +23,7 @@ void test_load_bytes_relative(void) ...@@ -23,7 +23,7 @@ void test_load_bytes_relative(void)
if (CHECK_FAIL(cgroup_fd < 0)) if (CHECK_FAIL(cgroup_fd < 0))
return; return;
server_fd = start_server(AF_INET, SOCK_STREAM); server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
if (CHECK_FAIL(server_fd < 0)) if (CHECK_FAIL(server_fd < 0))
goto close_cgroup_fd; goto close_cgroup_fd;
...@@ -49,7 +49,7 @@ void test_load_bytes_relative(void) ...@@ -49,7 +49,7 @@ void test_load_bytes_relative(void)
if (CHECK_FAIL(err)) if (CHECK_FAIL(err))
goto close_bpf_object; goto close_bpf_object;
client_fd = connect_to_fd(AF_INET, SOCK_STREAM, server_fd); client_fd = connect_to_fd(server_fd, 0);
if (CHECK_FAIL(client_fd < 0)) if (CHECK_FAIL(client_fd < 0))
goto close_bpf_object; goto close_bpf_object;
close(client_fd); close(client_fd);
......
...@@ -118,7 +118,7 @@ static int run_test(int cgroup_fd, int server_fd) ...@@ -118,7 +118,7 @@ static int run_test(int cgroup_fd, int server_fd)
goto close_bpf_object; goto close_bpf_object;
} }
client_fd = connect_to_fd(AF_INET, SOCK_STREAM, server_fd); client_fd = connect_to_fd(server_fd, 0);
if (client_fd < 0) { if (client_fd < 0) {
err = -1; err = -1;
goto close_bpf_object; goto close_bpf_object;
...@@ -161,7 +161,7 @@ void test_tcp_rtt(void) ...@@ -161,7 +161,7 @@ void test_tcp_rtt(void)
if (CHECK_FAIL(cgroup_fd < 0)) if (CHECK_FAIL(cgroup_fd < 0))
return; return;
server_fd = start_server(AF_INET, SOCK_STREAM); server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
if (CHECK_FAIL(server_fd < 0)) if (CHECK_FAIL(server_fd < 0))
goto close_cgroup_fd; goto close_cgroup_fd;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment