Commit 2bc793e3, authored by Cong Wang, committed by Alexei Starovoitov

skmsg: Extract __tcp_bpf_recvmsg() and tcp_bpf_wait_data()

Although these two functions are only used by TCP, they are not
specific to TCP at all, both operate on skmsg and ingress_msg,
so fit in net/core/skmsg.c very well.

And we will need them for non-TCP, so rename and move them to
skmsg.c and export them to modules.
Signed-off-by: Cong Wang <cong.wang@bytedance.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20210331023237.41094-13-xiyou.wangcong@gmail.com
parent d7f57118
...@@ -125,6 +125,10 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, ...@@ -125,6 +125,10 @@ int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes); struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes); struct sk_msg *msg, u32 bytes);
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
long timeo, int *err);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags);
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes) static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{ {
......
...@@ -2209,8 +2209,6 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk); ...@@ -2209,8 +2209,6 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes, int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
int flags); int flags);
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
struct msghdr *msg, int len, int flags);
#endif /* CONFIG_NET_SOCK_MSG */ #endif /* CONFIG_NET_SOCK_MSG */
#if !defined(CONFIG_BPF_SYSCALL) || !defined(CONFIG_NET_SOCK_MSG) #if !defined(CONFIG_BPF_SYSCALL) || !defined(CONFIG_NET_SOCK_MSG)
......
...@@ -399,6 +399,104 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, ...@@ -399,6 +399,104 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
} }
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter); EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
int sk_msg_wait_data(struct sock *sk, struct sk_psock *psock, int flags,
long timeo, int *err)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 1;
if (!timeo)
return ret;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
ret = sk_wait_event(sk, &timeo,
!list_empty(&psock->ingress_msg) ||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_wait_data);
/* Receive sk_msg from psock->ingress_msg to @msg.
 *
 * Copies up to @len bytes from the sk_msg entries queued on @psock into
 * @msg->msg_iter, walking each message's scatterlist from sg.start to
 * sg.end.
 *
 * Without MSG_PEEK in @flags the copied bytes are consumed: element
 * offset/length and sg.size are adjusted, socket memory is uncharged and
 * pages are released for messages that have no skb attached, and a fully
 * drained message is dequeued and freed.
 *
 * With MSG_PEEK nothing is modified; messages are visited with
 * sk_psock_next_msg() and the walk stops at the first element that could
 * not be copied in full.
 *
 * Returns the number of bytes copied, or -EFAULT when copy_page_to_iter()
 * copied nothing and no bytes had been copied before.
 */
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags)
{
struct iov_iter *iter = &msg->msg_iter;
int peek = flags & MSG_PEEK;
struct sk_msg *msg_rx;
int i, copied = 0;
msg_rx = sk_psock_peek_msg(psock);
/* Outer loop: one iteration per queued message until @len is satisfied. */
while (copied != len) {
struct scatterlist *sge;
if (unlikely(!msg_rx))
break;
i = msg_rx->sg.start;
/* Inner loop: one iteration per scatterlist element of this message. */
do {
struct page *page;
int copy;
sge = sk_msg_elem(msg_rx, i);
copy = sge->length;
page = sg_page(sge);
/* Clamp to the number of bytes the caller still wants. */
if (copied + copy > len)
copy = len - copied;
copy = copy_page_to_iter(page, sge->offset, copy, iter);
/* Nothing copied: report partial progress if any, else -EFAULT. */
if (!copy)
return copied ? copied : -EFAULT;
copied += copy;
if (likely(!peek)) {
/* Consume the copied bytes from this element. */
sge->offset += copy;
sge->length -= copy;
/* Memory is only uncharged here when no skb is attached. */
if (!msg_rx->skb)
sk_mem_uncharge(sk, copy);
msg_rx->sg.size -= copy;
if (!sge->length) {
/* Element drained: advance and drop the page reference
 * when the message has no skb attached.
 */
sk_msg_iter_var_next(i);
if (!msg_rx->skb)
put_page(page);
}
} else {
/* The peek path is deliberately not optimized: if the
 * whole element was not copied, stop and return what
 * has been copied so far.
 */
if (copy != sge->length)
return copied;
sk_msg_iter_var_next(i);
}
if (copied == len)
break;
} while (i != msg_rx->sg.end);
if (unlikely(peek)) {
/* Peeking leaves the queue intact; just step to the next message. */
msg_rx = sk_psock_next_msg(psock, msg_rx);
if (!msg_rx)
break;
continue;
}
/* Record how far this message has been consumed. */
msg_rx->sg.start = i;
/* Last touched element is empty and the walk reached sg.end:
 * the message is fully consumed, so dequeue and free it.
 */
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
msg_rx = sk_psock_dequeue_msg(psock);
kfree_sk_msg(msg_rx);
}
msg_rx = sk_psock_peek_msg(psock);
}
return copied;
}
EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk, static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
struct sk_buff *skb) struct sk_buff *skb)
{ {
......
...@@ -10,80 +10,6 @@ ...@@ -10,80 +10,6 @@
#include <net/inet_common.h> #include <net/inet_common.h>
#include <net/tls.h> #include <net/tls.h>
/* Copy up to @len bytes from the sk_msg entries queued on @psock into
 * @msg->msg_iter, walking each message's scatterlist from sg.start to
 * sg.end.
 *
 * Without MSG_PEEK in @flags the copied bytes are consumed: element
 * offset/length and sg.size are adjusted, socket memory is uncharged and
 * pages are released for messages that have no skb attached, and a fully
 * drained message is dequeued and freed.
 *
 * With MSG_PEEK nothing is modified; messages are visited with
 * sk_psock_next_msg() and the walk stops at the first element that could
 * not be copied in full.
 *
 * Returns the number of bytes copied, or -EFAULT when copy_page_to_iter()
 * copied nothing and no bytes had been copied before.
 */
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
struct msghdr *msg, int len, int flags)
{
struct iov_iter *iter = &msg->msg_iter;
int peek = flags & MSG_PEEK;
struct sk_msg *msg_rx;
int i, copied = 0;
msg_rx = sk_psock_peek_msg(psock);
/* Outer loop: one iteration per queued message until @len is satisfied. */
while (copied != len) {
struct scatterlist *sge;
if (unlikely(!msg_rx))
break;
i = msg_rx->sg.start;
/* Inner loop: one iteration per scatterlist element of this message. */
do {
struct page *page;
int copy;
sge = sk_msg_elem(msg_rx, i);
copy = sge->length;
page = sg_page(sge);
/* Clamp to the number of bytes the caller still wants. */
if (copied + copy > len)
copy = len - copied;
copy = copy_page_to_iter(page, sge->offset, copy, iter);
/* Nothing copied: report partial progress if any, else -EFAULT. */
if (!copy)
return copied ? copied : -EFAULT;
copied += copy;
if (likely(!peek)) {
/* Consume the copied bytes from this element. */
sge->offset += copy;
sge->length -= copy;
/* Memory is only uncharged here when no skb is attached. */
if (!msg_rx->skb)
sk_mem_uncharge(sk, copy);
msg_rx->sg.size -= copy;
if (!sge->length) {
/* Element drained: advance and drop the page reference
 * when the message has no skb attached.
 */
sk_msg_iter_var_next(i);
if (!msg_rx->skb)
put_page(page);
}
} else {
/* The peek path is deliberately not optimized: if the
 * whole element was not copied, stop and return what
 * has been copied so far.
 */
if (copy != sge->length)
return copied;
sk_msg_iter_var_next(i);
}
if (copied == len)
break;
} while (i != msg_rx->sg.end);
if (unlikely(peek)) {
/* Peeking leaves the queue intact; just step to the next message. */
msg_rx = sk_psock_next_msg(psock, msg_rx);
if (!msg_rx)
break;
continue;
}
/* Record how far this message has been consumed. */
msg_rx->sg.start = i;
/* Last touched element is empty and the walk reached sg.end:
 * the message is fully consumed, so dequeue and free it.
 */
if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
msg_rx = sk_psock_dequeue_msg(psock);
kfree_sk_msg(msg_rx);
}
msg_rx = sk_psock_peek_msg(psock);
}
return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
struct sk_msg *msg, u32 apply_bytes, int flags) struct sk_msg *msg, u32 apply_bytes, int flags)
{ {
...@@ -237,28 +163,6 @@ static bool tcp_bpf_stream_read(const struct sock *sk) ...@@ -237,28 +163,6 @@ static bool tcp_bpf_stream_read(const struct sock *sk)
return !empty; return !empty;
} }
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
int flags, long timeo, int *err)
{
DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 1;
if (!timeo)
return ret;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
ret = sk_wait_event(sk, &timeo,
!list_empty(&psock->ingress_msg) ||
!skb_queue_empty(&sk->sk_receive_queue), &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
return ret;
}
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len) int nonblock, int flags, int *addr_len)
{ {
...@@ -278,13 +182,13 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -278,13 +182,13 @@ static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
} }
lock_sock(sk); lock_sock(sk);
msg_bytes_ready: msg_bytes_ready:
copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags); copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
if (!copied) { if (!copied) {
int data, err = 0; int data, err = 0;
long timeo; long timeo;
timeo = sock_rcvtimeo(sk, nonblock); timeo = sock_rcvtimeo(sk, nonblock);
data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err); data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
if (data) { if (data) {
if (!sk_psock_queue_empty(psock)) if (!sk_psock_queue_empty(psock))
goto msg_bytes_ready; goto msg_bytes_ready;
......
...@@ -1789,8 +1789,8 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1789,8 +1789,8 @@ int tls_sw_recvmsg(struct sock *sk,
skb = tls_wait_data(sk, psock, flags, timeo, &err); skb = tls_wait_data(sk, psock, flags, timeo, &err);
if (!skb) { if (!skb) {
if (psock) { if (psock) {
int ret = __tcp_bpf_recvmsg(sk, psock, int ret = sk_msg_recvmsg(sk, psock, msg, len,
msg, len, flags); flags);
if (ret > 0) { if (ret > 0) {
decrypted += ret; decrypted += ret;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment