Commit 04a88637 authored by Jakub Kicinski's avatar Jakub Kicinski

Merge branch 'tcp-add-cmsg-rx-timestamps-to-rx-zerocopy'

Arjun Roy says:

====================
tcp: add CMSG+rx timestamps to rx. zerocopy

Provide CMSG and receive timestamp support to TCP
receive zerocopy. Patch 1 refactors CMSG pending state for
tcp_recvmsg() to avoid the use of magic numbers; patch 2 implements
receive timestamp via CMSG support for receive zerocopy, and uses the
constants added in patch 1.
====================

Link: https://lore.kernel.org/r/20210121004148.2340206-1-arjunroy.kdev@gmail.comSigned-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 5225d5f5 7eeba170
...@@ -354,5 +354,9 @@ struct tcp_zerocopy_receive { ...@@ -354,5 +354,9 @@ struct tcp_zerocopy_receive {
__u64 copybuf_address; /* in: copybuf address (small reads) */ __u64 copybuf_address; /* in: copybuf address (small reads) */
__s32 copybuf_len; /* in/out: copybuf bytes avail/used or error */ __s32 copybuf_len; /* in/out: copybuf bytes avail/used or error */
__u32 flags; /* in: flags */ __u32 flags; /* in: flags */
__u64 msg_control; /* ancillary data */
__u64 msg_controllen;
__u32 msg_flags;
/* __u32 hole; Next we must add >1 u32 otherwise length checks fail. */
}; };
#endif /* _UAPI_LINUX_TCP_H */ #endif /* _UAPI_LINUX_TCP_H */
...@@ -280,6 +280,12 @@ ...@@ -280,6 +280,12 @@
#include <asm/ioctls.h> #include <asm/ioctls.h>
#include <net/busy_poll.h> #include <net/busy_poll.h>
/* Track pending CMSGs. */
enum {
TCP_CMSG_INQ = 1,
TCP_CMSG_TS = 2
};
struct percpu_counter tcp_orphan_count; struct percpu_counter tcp_orphan_count;
EXPORT_SYMBOL_GPL(tcp_orphan_count); EXPORT_SYMBOL_GPL(tcp_orphan_count);
...@@ -1739,6 +1745,20 @@ int tcp_set_rcvlowat(struct sock *sk, int val) ...@@ -1739,6 +1745,20 @@ int tcp_set_rcvlowat(struct sock *sk, int val)
} }
EXPORT_SYMBOL(tcp_set_rcvlowat); EXPORT_SYMBOL(tcp_set_rcvlowat);
static void tcp_update_recv_tstamps(struct sk_buff *skb,
struct scm_timestamping_internal *tss)
{
if (skb->tstamp)
tss->ts[0] = ktime_to_timespec64(skb->tstamp);
else
tss->ts[0] = (struct timespec64) {0};
if (skb_hwtstamps(skb)->hwtstamp)
tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp);
else
tss->ts[2] = (struct timespec64) {0};
}
#ifdef CONFIG_MMU #ifdef CONFIG_MMU
static const struct vm_operations_struct tcp_vm_ops = { static const struct vm_operations_struct tcp_vm_ops = {
}; };
...@@ -1842,13 +1862,13 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -1842,13 +1862,13 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
struct scm_timestamping_internal *tss, struct scm_timestamping_internal *tss,
int *cmsg_flags); int *cmsg_flags);
static int receive_fallback_to_copy(struct sock *sk, static int receive_fallback_to_copy(struct sock *sk,
struct tcp_zerocopy_receive *zc, int inq) struct tcp_zerocopy_receive *zc, int inq,
struct scm_timestamping_internal *tss)
{ {
unsigned long copy_address = (unsigned long)zc->copybuf_address; unsigned long copy_address = (unsigned long)zc->copybuf_address;
struct scm_timestamping_internal tss_unused;
int err, cmsg_flags_unused;
struct msghdr msg = {}; struct msghdr msg = {};
struct iovec iov; struct iovec iov;
int err;
zc->length = 0; zc->length = 0;
zc->recv_skip_hint = 0; zc->recv_skip_hint = 0;
...@@ -1862,7 +1882,7 @@ static int receive_fallback_to_copy(struct sock *sk, ...@@ -1862,7 +1882,7 @@ static int receive_fallback_to_copy(struct sock *sk,
return err; return err;
err = tcp_recvmsg_locked(sk, &msg, inq, /*nonblock=*/1, /*flags=*/0, err = tcp_recvmsg_locked(sk, &msg, inq, /*nonblock=*/1, /*flags=*/0,
&tss_unused, &cmsg_flags_unused); tss, &zc->msg_flags);
if (err < 0) if (err < 0)
return err; return err;
...@@ -1903,21 +1923,27 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc, ...@@ -1903,21 +1923,27 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc,
return (__s32)copylen; return (__s32)copylen;
} }
static int tcp_zerocopy_handle_leftover_data(struct tcp_zerocopy_receive *zc, static int tcp_zc_handle_leftover(struct tcp_zerocopy_receive *zc,
struct sock *sk, struct sock *sk,
struct sk_buff *skb, struct sk_buff *skb,
u32 *seq, u32 *seq,
s32 copybuf_len) s32 copybuf_len,
struct scm_timestamping_internal *tss)
{ {
u32 offset, copylen = min_t(u32, copybuf_len, zc->recv_skip_hint); u32 offset, copylen = min_t(u32, copybuf_len, zc->recv_skip_hint);
if (!copylen) if (!copylen)
return 0; return 0;
/* skb is null if inq < PAGE_SIZE. */ /* skb is null if inq < PAGE_SIZE. */
if (skb) if (skb) {
offset = *seq - TCP_SKB_CB(skb)->seq; offset = *seq - TCP_SKB_CB(skb)->seq;
else } else {
skb = tcp_recv_skb(sk, *seq, &offset); skb = tcp_recv_skb(sk, *seq, &offset);
if (TCP_SKB_CB(skb)->has_rxtstamp) {
tcp_update_recv_tstamps(skb, tss);
zc->msg_flags |= TCP_CMSG_TS;
}
}
zc->copybuf_len = tcp_copy_straggler_data(zc, skb, copylen, &offset, zc->copybuf_len = tcp_copy_straggler_data(zc, skb, copylen, &offset,
seq); seq);
...@@ -2004,9 +2030,37 @@ static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma, ...@@ -2004,9 +2030,37 @@ static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma,
err); err);
} }
static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
struct scm_timestamping_internal *tss);
static void tcp_zc_finalize_rx_tstamp(struct sock *sk,
struct tcp_zerocopy_receive *zc,
struct scm_timestamping_internal *tss)
{
unsigned long msg_control_addr;
struct msghdr cmsg_dummy;
msg_control_addr = (unsigned long)zc->msg_control;
cmsg_dummy.msg_control = (void *)msg_control_addr;
cmsg_dummy.msg_controllen =
(__kernel_size_t)zc->msg_controllen;
cmsg_dummy.msg_flags = in_compat_syscall()
? MSG_CMSG_COMPAT : 0;
zc->msg_flags = 0;
if (zc->msg_control == msg_control_addr &&
zc->msg_controllen == cmsg_dummy.msg_controllen) {
tcp_recv_timestamp(&cmsg_dummy, sk, tss);
zc->msg_control = (__u64)
((uintptr_t)cmsg_dummy.msg_control);
zc->msg_controllen =
(__u64)cmsg_dummy.msg_controllen;
zc->msg_flags = (__u32)cmsg_dummy.msg_flags;
}
}
#define TCP_ZEROCOPY_PAGE_BATCH_SIZE 32 #define TCP_ZEROCOPY_PAGE_BATCH_SIZE 32
static int tcp_zerocopy_receive(struct sock *sk, static int tcp_zerocopy_receive(struct sock *sk,
struct tcp_zerocopy_receive *zc) struct tcp_zerocopy_receive *zc,
struct scm_timestamping_internal *tss)
{ {
u32 length = 0, offset, vma_len, avail_len, copylen = 0; u32 length = 0, offset, vma_len, avail_len, copylen = 0;
unsigned long address = (unsigned long)zc->address; unsigned long address = (unsigned long)zc->address;
...@@ -2023,6 +2077,7 @@ static int tcp_zerocopy_receive(struct sock *sk, ...@@ -2023,6 +2077,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
int ret; int ret;
zc->copybuf_len = 0; zc->copybuf_len = 0;
zc->msg_flags = 0;
if (address & (PAGE_SIZE - 1) || address != zc->address) if (address & (PAGE_SIZE - 1) || address != zc->address)
return -EINVAL; return -EINVAL;
...@@ -2033,7 +2088,7 @@ static int tcp_zerocopy_receive(struct sock *sk, ...@@ -2033,7 +2088,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
sock_rps_record_flow(sk); sock_rps_record_flow(sk);
if (inq && inq <= copybuf_len) if (inq && inq <= copybuf_len)
return receive_fallback_to_copy(sk, zc, inq); return receive_fallback_to_copy(sk, zc, inq, tss);
if (inq < PAGE_SIZE) { if (inq < PAGE_SIZE) {
zc->length = 0; zc->length = 0;
...@@ -2078,6 +2133,11 @@ static int tcp_zerocopy_receive(struct sock *sk, ...@@ -2078,6 +2133,11 @@ static int tcp_zerocopy_receive(struct sock *sk,
} else { } else {
skb = tcp_recv_skb(sk, seq, &offset); skb = tcp_recv_skb(sk, seq, &offset);
} }
if (TCP_SKB_CB(skb)->has_rxtstamp) {
tcp_update_recv_tstamps(skb, tss);
zc->msg_flags |= TCP_CMSG_TS;
}
zc->recv_skip_hint = skb->len - offset; zc->recv_skip_hint = skb->len - offset;
frags = skb_advance_to_frag(skb, offset, &offset_frag); frags = skb_advance_to_frag(skb, offset, &offset_frag);
if (!frags || offset_frag) if (!frags || offset_frag)
...@@ -2120,8 +2180,7 @@ static int tcp_zerocopy_receive(struct sock *sk, ...@@ -2120,8 +2180,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
mmap_read_unlock(current->mm); mmap_read_unlock(current->mm);
/* Try to copy straggler data. */ /* Try to copy straggler data. */
if (!ret) if (!ret)
copylen = tcp_zerocopy_handle_leftover_data(zc, sk, skb, &seq, copylen = tcp_zc_handle_leftover(zc, sk, skb, &seq, copybuf_len, tss);
copybuf_len);
if (length + copylen) { if (length + copylen) {
WRITE_ONCE(tp->copied_seq, seq); WRITE_ONCE(tp->copied_seq, seq);
...@@ -2142,20 +2201,6 @@ static int tcp_zerocopy_receive(struct sock *sk, ...@@ -2142,20 +2201,6 @@ static int tcp_zerocopy_receive(struct sock *sk,
} }
#endif #endif
static void tcp_update_recv_tstamps(struct sk_buff *skb,
struct scm_timestamping_internal *tss)
{
if (skb->tstamp)
tss->ts[0] = ktime_to_timespec64(skb->tstamp);
else
tss->ts[0] = (struct timespec64) {0};
if (skb_hwtstamps(skb)->hwtstamp)
tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp);
else
tss->ts[2] = (struct timespec64) {0};
}
/* Similar to __sock_recv_timestamp, but does not require an skb */ /* Similar to __sock_recv_timestamp, but does not require an skb */
static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk, static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
struct scm_timestamping_internal *tss) struct scm_timestamping_internal *tss)
...@@ -2272,7 +2317,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -2272,7 +2317,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
goto out; goto out;
if (tp->recvmsg_inq) if (tp->recvmsg_inq)
*cmsg_flags = 1; *cmsg_flags = TCP_CMSG_INQ;
timeo = sock_rcvtimeo(sk, nonblock); timeo = sock_rcvtimeo(sk, nonblock);
/* Urgent data needs to be handled specially. */ /* Urgent data needs to be handled specially. */
...@@ -2453,7 +2498,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -2453,7 +2498,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
if (TCP_SKB_CB(skb)->has_rxtstamp) { if (TCP_SKB_CB(skb)->has_rxtstamp) {
tcp_update_recv_tstamps(skb, tss); tcp_update_recv_tstamps(skb, tss);
*cmsg_flags |= 2; *cmsg_flags |= TCP_CMSG_TS;
} }
if (used + offset < skb->len) if (used + offset < skb->len)
...@@ -2513,9 +2558,9 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock, ...@@ -2513,9 +2558,9 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock,
release_sock(sk); release_sock(sk);
if (cmsg_flags && ret >= 0) { if (cmsg_flags && ret >= 0) {
if (cmsg_flags & 2) if (cmsg_flags & TCP_CMSG_TS)
tcp_recv_timestamp(msg, sk, &tss); tcp_recv_timestamp(msg, sk, &tss);
if (cmsg_flags & 1) { if (cmsg_flags & TCP_CMSG_INQ) {
inq = tcp_inq_hint(sk); inq = tcp_inq_hint(sk);
put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq); put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq);
} }
...@@ -4099,6 +4144,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level, ...@@ -4099,6 +4144,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
} }
#ifdef CONFIG_MMU #ifdef CONFIG_MMU
case TCP_ZEROCOPY_RECEIVE: { case TCP_ZEROCOPY_RECEIVE: {
struct scm_timestamping_internal tss;
struct tcp_zerocopy_receive zc = {}; struct tcp_zerocopy_receive zc = {};
int err; int err;
...@@ -4114,11 +4160,18 @@ static int do_tcp_getsockopt(struct sock *sk, int level, ...@@ -4114,11 +4160,18 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
if (copy_from_user(&zc, optval, len)) if (copy_from_user(&zc, optval, len))
return -EFAULT; return -EFAULT;
lock_sock(sk); lock_sock(sk);
err = tcp_zerocopy_receive(sk, &zc); err = tcp_zerocopy_receive(sk, &zc, &tss);
release_sock(sk); release_sock(sk);
if (len >= offsetofend(struct tcp_zerocopy_receive, err)) if (len >= offsetofend(struct tcp_zerocopy_receive, msg_flags))
goto zerocopy_rcv_sk_err; goto zerocopy_rcv_cmsg;
switch (len) { switch (len) {
case offsetofend(struct tcp_zerocopy_receive, msg_flags):
goto zerocopy_rcv_cmsg;
case offsetofend(struct tcp_zerocopy_receive, msg_controllen):
case offsetofend(struct tcp_zerocopy_receive, msg_control):
case offsetofend(struct tcp_zerocopy_receive, flags):
case offsetofend(struct tcp_zerocopy_receive, copybuf_len):
case offsetofend(struct tcp_zerocopy_receive, copybuf_address):
case offsetofend(struct tcp_zerocopy_receive, err): case offsetofend(struct tcp_zerocopy_receive, err):
goto zerocopy_rcv_sk_err; goto zerocopy_rcv_sk_err;
case offsetofend(struct tcp_zerocopy_receive, inq): case offsetofend(struct tcp_zerocopy_receive, inq):
...@@ -4127,6 +4180,11 @@ static int do_tcp_getsockopt(struct sock *sk, int level, ...@@ -4127,6 +4180,11 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
default: default:
goto zerocopy_rcv_out; goto zerocopy_rcv_out;
} }
zerocopy_rcv_cmsg:
if (zc.msg_flags & TCP_CMSG_TS)
tcp_zc_finalize_rx_tstamp(sk, &zc, &tss);
else
zc.msg_flags = 0;
zerocopy_rcv_sk_err: zerocopy_rcv_sk_err:
if (!err) if (!err)
zc.err = sock_error(sk); zc.err = sock_error(sk);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment