Commit 66f87790 authored by David S. Miller

Merge branch 'udp-peek'

Willem de Bruijn says:

====================
udp: support SO_PEEK_OFF

Support peeking at a non-zero offset for UDP sockets. Match the
existing behavior on Unix datagram sockets.

1/3 makes the sk_peek_offset functions safe to use outside locks
2/3 removes udp headers before enqueue, to simplify offset arithmetic
3/3 introduces SO_PEEK_OFF support, with Unix socket peek semantics.

Changes
  v1->v2
    - squash patches 3 and 4
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents e43d15c8 627d2d6b
...@@ -2949,7 +2949,12 @@ int skb_copy_datagram_from_iter(struct sk_buff *skb, int offset, ...@@ -2949,7 +2949,12 @@ int skb_copy_datagram_from_iter(struct sk_buff *skb, int offset,
struct iov_iter *from, int len); struct iov_iter *from, int len);
int zerocopy_sg_from_iter(struct sk_buff *skb, struct iov_iter *frm); int zerocopy_sg_from_iter(struct sk_buff *skb, struct iov_iter *frm);
void skb_free_datagram(struct sock *sk, struct sk_buff *skb); void skb_free_datagram(struct sock *sk, struct sk_buff *skb);
void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb); void __skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb, int len);
/*
 * Free a datagram skb under the socket lock.
 *
 * Thin wrapper around __skb_free_datagram_locked() that passes a length
 * of 0, i.e. it asks for no rewind of the socket peek offset.
 */
static inline void
skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb)
{
	__skb_free_datagram_locked(sk, skb, 0);
}
int skb_kill_datagram(struct sock *sk, struct sk_buff *skb, unsigned int flags); int skb_kill_datagram(struct sock *sk, struct sk_buff *skb, unsigned int flags);
int skb_copy_bits(const struct sk_buff *skb, int offset, void *to, int len); int skb_copy_bits(const struct sk_buff *skb, int offset, void *to, int len);
int skb_store_bits(struct sk_buff *skb, int offset, const void *from, int len); int skb_store_bits(struct sk_buff *skb, int offset, const void *from, int len);
......
...@@ -457,28 +457,32 @@ struct sock { ...@@ -457,28 +457,32 @@ struct sock {
#define SK_CAN_REUSE 1 #define SK_CAN_REUSE 1
#define SK_FORCE_REUSE 2 #define SK_FORCE_REUSE 2
int sk_set_peek_off(struct sock *sk, int val);
static inline int sk_peek_offset(struct sock *sk, int flags) static inline int sk_peek_offset(struct sock *sk, int flags)
{ {
if ((flags & MSG_PEEK) && (sk->sk_peek_off >= 0)) if (unlikely(flags & MSG_PEEK)) {
return sk->sk_peek_off; s32 off = READ_ONCE(sk->sk_peek_off);
else if (off >= 0)
return 0; return off;
}
return 0;
} }
static inline void sk_peek_offset_bwd(struct sock *sk, int val) static inline void sk_peek_offset_bwd(struct sock *sk, int val)
{ {
if (sk->sk_peek_off >= 0) { s32 off = READ_ONCE(sk->sk_peek_off);
if (sk->sk_peek_off >= val)
sk->sk_peek_off -= val; if (unlikely(off >= 0)) {
else off = max_t(s32, off - val, 0);
sk->sk_peek_off = 0; WRITE_ONCE(sk->sk_peek_off, off);
} }
} }
static inline void sk_peek_offset_fwd(struct sock *sk, int val) static inline void sk_peek_offset_fwd(struct sock *sk, int val)
{ {
if (sk->sk_peek_off >= 0) sk_peek_offset_bwd(sk, -val);
sk->sk_peek_off += val;
} }
/* /*
...@@ -1862,6 +1866,7 @@ void sk_reset_timer(struct sock *sk, struct timer_list *timer, ...@@ -1862,6 +1866,7 @@ void sk_reset_timer(struct sock *sk, struct timer_list *timer,
void sk_stop_timer(struct sock *sk, struct timer_list *timer); void sk_stop_timer(struct sock *sk, struct timer_list *timer);
int __sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
int sock_queue_err_skb(struct sock *sk, struct sk_buff *skb); int sock_queue_err_skb(struct sock *sk, struct sk_buff *skb);
......
...@@ -158,6 +158,15 @@ static inline __sum16 udp_v4_check(int len, __be32 saddr, ...@@ -158,6 +158,15 @@ static inline __sum16 udp_v4_check(int len, __be32 saddr,
void udp_set_csum(bool nocheck, struct sk_buff *skb, void udp_set_csum(bool nocheck, struct sk_buff *skb,
__be32 saddr, __be32 daddr, int len); __be32 saddr, __be32 daddr, int len);
/*
 * Remove the UDP header from the front of @skb, keeping the checksum
 * bookkeeping in step.
 *
 * For packets marked CHECKSUM_NONE the header bytes are folded into
 * skb->csum before the header is pulled with skb_pull_rcsum().  The
 * cscov field of the UDP control block (presumably the checksum coverage
 * length -- confirm against struct udp_skb_cb) is reduced by the header
 * size as well.
 */
static inline void udp_csum_pull_header(struct sk_buff *skb)
{
	const unsigned int hdr_len = sizeof(struct udphdr);

	if (skb->ip_summed == CHECKSUM_NONE) {
		skb->csum = csum_partial(udp_hdr(skb), hdr_len, skb->csum);
	}

	skb_pull_rcsum(skb, hdr_len);
	UDP_SKB_CB(skb)->cscov -= hdr_len;
}
struct sk_buff **udp_gro_receive(struct sk_buff **head, struct sk_buff *skb, struct sk_buff **udp_gro_receive(struct sk_buff **head, struct sk_buff *skb,
struct udphdr *uh); struct udphdr *uh);
int udp_gro_complete(struct sk_buff *skb, int nhoff); int udp_gro_complete(struct sk_buff *skb, int nhoff);
......
...@@ -301,16 +301,19 @@ void skb_free_datagram(struct sock *sk, struct sk_buff *skb) ...@@ -301,16 +301,19 @@ void skb_free_datagram(struct sock *sk, struct sk_buff *skb)
} }
EXPORT_SYMBOL(skb_free_datagram); EXPORT_SYMBOL(skb_free_datagram);
void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb) void __skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb, int len)
{ {
bool slow; bool slow;
if (likely(atomic_read(&skb->users) == 1)) if (likely(atomic_read(&skb->users) == 1))
smp_rmb(); smp_rmb();
else if (likely(!atomic_dec_and_test(&skb->users))) else if (likely(!atomic_dec_and_test(&skb->users))) {
sk_peek_offset_bwd(sk, len);
return; return;
}
slow = lock_sock_fast(sk); slow = lock_sock_fast(sk);
sk_peek_offset_bwd(sk, len);
skb_orphan(skb); skb_orphan(skb);
sk_mem_reclaim_partial(sk); sk_mem_reclaim_partial(sk);
unlock_sock_fast(sk, slow); unlock_sock_fast(sk, slow);
...@@ -318,7 +321,7 @@ void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb) ...@@ -318,7 +321,7 @@ void skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb)
/* skb is now orphaned, can be freed outside of locked section */ /* skb is now orphaned, can be freed outside of locked section */
__kfree_skb(skb); __kfree_skb(skb);
} }
EXPORT_SYMBOL(skb_free_datagram_locked); EXPORT_SYMBOL(__skb_free_datagram_locked);
/** /**
* skb_kill_datagram - Free a datagram skbuff forcibly * skb_kill_datagram - Free a datagram skbuff forcibly
......
...@@ -402,9 +402,8 @@ static void sock_disable_timestamp(struct sock *sk, unsigned long flags) ...@@ -402,9 +402,8 @@ static void sock_disable_timestamp(struct sock *sk, unsigned long flags)
} }
int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) int __sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
{ {
int err;
unsigned long flags; unsigned long flags;
struct sk_buff_head *list = &sk->sk_receive_queue; struct sk_buff_head *list = &sk->sk_receive_queue;
...@@ -414,10 +413,6 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) ...@@ -414,10 +413,6 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
return -ENOMEM; return -ENOMEM;
} }
err = sk_filter(sk, skb);
if (err)
return err;
if (!sk_rmem_schedule(sk, skb, skb->truesize)) { if (!sk_rmem_schedule(sk, skb, skb->truesize)) {
atomic_inc(&sk->sk_drops); atomic_inc(&sk->sk_drops);
return -ENOBUFS; return -ENOBUFS;
...@@ -440,6 +435,18 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) ...@@ -440,6 +435,18 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
sk->sk_data_ready(sk); sk->sk_data_ready(sk);
return 0; return 0;
} }
EXPORT_SYMBOL(__sock_queue_rcv_skb);
/*
 * Run the socket filter on @skb and, if it passes, queue it on @sk.
 *
 * Returns the non-zero result of sk_filter() when the filter rejects the
 * packet; otherwise returns whatever __sock_queue_rcv_skb() returns.
 * Callers that have already applied the filter can call
 * __sock_queue_rcv_skb() directly.
 */
int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
{
	int rc = sk_filter(sk, skb);

	if (rc != 0)
		return rc;

	return __sock_queue_rcv_skb(sk, skb);
}
EXPORT_SYMBOL(sock_queue_rcv_skb); EXPORT_SYMBOL(sock_queue_rcv_skb);
int sk_receive_skb(struct sock *sk, struct sk_buff *skb, const int nested) int sk_receive_skb(struct sock *sk, struct sk_buff *skb, const int nested)
...@@ -2180,6 +2187,15 @@ void __sk_mem_reclaim(struct sock *sk, int amount) ...@@ -2180,6 +2187,15 @@ void __sk_mem_reclaim(struct sock *sk, int amount)
} }
EXPORT_SYMBOL(__sk_mem_reclaim); EXPORT_SYMBOL(__sk_mem_reclaim);
/*
 * sk_set_peek_off - set the peek offset of a socket
 * @sk:  socket to update
 * @val: new peek offset; must not be negative
 *
 * Returns 0 on success, or -EINVAL when @val is negative.
 *
 * The store is done with WRITE_ONCE() so that it pairs with the
 * READ_ONCE()/WRITE_ONCE() accesses in sk_peek_offset() and
 * sk_peek_offset_bwd(), which are meant to be usable outside locks.
 */
int sk_set_peek_off(struct sock *sk, int val)
{
	if (val < 0)
		return -EINVAL;

	WRITE_ONCE(sk->sk_peek_off, val);
	return 0;
}
EXPORT_SYMBOL_GPL(sk_set_peek_off);
/* /*
* Set of default routines for initialising struct proto_ops when * Set of default routines for initialising struct proto_ops when
......
...@@ -948,6 +948,7 @@ const struct proto_ops inet_dgram_ops = { ...@@ -948,6 +948,7 @@ const struct proto_ops inet_dgram_ops = {
.recvmsg = inet_recvmsg, .recvmsg = inet_recvmsg,
.mmap = sock_no_mmap, .mmap = sock_no_mmap,
.sendpage = inet_sendpage, .sendpage = inet_sendpage,
.set_peek_off = sk_set_peek_off,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_setsockopt = compat_sock_common_setsockopt, .compat_setsockopt = compat_sock_common_setsockopt,
.compat_getsockopt = compat_sock_common_getsockopt, .compat_getsockopt = compat_sock_common_getsockopt,
......
...@@ -1294,7 +1294,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, ...@@ -1294,7 +1294,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock,
DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name);
struct sk_buff *skb; struct sk_buff *skb;
unsigned int ulen, copied; unsigned int ulen, copied;
int peeked, off = 0; int peeked, peeking, off;
int err; int err;
int is_udplite = IS_UDPLITE(sk); int is_udplite = IS_UDPLITE(sk);
bool checksum_valid = false; bool checksum_valid = false;
...@@ -1304,15 +1304,16 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, ...@@ -1304,15 +1304,16 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock,
return ip_recv_error(sk, msg, len, addr_len); return ip_recv_error(sk, msg, len, addr_len);
try_again: try_again:
peeking = off = sk_peek_offset(sk, flags);
skb = __skb_recv_datagram(sk, flags | (noblock ? MSG_DONTWAIT : 0), skb = __skb_recv_datagram(sk, flags | (noblock ? MSG_DONTWAIT : 0),
&peeked, &off, &err); &peeked, &off, &err);
if (!skb) if (!skb)
goto out; return err;
ulen = skb->len - sizeof(struct udphdr); ulen = skb->len;
copied = len; copied = len;
if (copied > ulen) if (copied > ulen - off)
copied = ulen; copied = ulen - off;
else if (copied < ulen) else if (copied < ulen)
msg->msg_flags |= MSG_TRUNC; msg->msg_flags |= MSG_TRUNC;
...@@ -1322,18 +1323,16 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, ...@@ -1322,18 +1323,16 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock,
* coverage checksum (UDP-Lite), do it before the copy. * coverage checksum (UDP-Lite), do it before the copy.
*/ */
if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) { if (copied < ulen || UDP_SKB_CB(skb)->partial_cov || peeking) {
checksum_valid = !udp_lib_checksum_complete(skb); checksum_valid = !udp_lib_checksum_complete(skb);
if (!checksum_valid) if (!checksum_valid)
goto csum_copy_err; goto csum_copy_err;
} }
if (checksum_valid || skb_csum_unnecessary(skb)) if (checksum_valid || skb_csum_unnecessary(skb))
err = skb_copy_datagram_msg(skb, sizeof(struct udphdr), err = skb_copy_datagram_msg(skb, off, msg, copied);
msg, copied);
else { else {
err = skb_copy_and_csum_datagram_msg(skb, sizeof(struct udphdr), err = skb_copy_and_csum_datagram_msg(skb, off, msg);
msg);
if (err == -EINVAL) if (err == -EINVAL)
goto csum_copy_err; goto csum_copy_err;
...@@ -1346,7 +1345,8 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, ...@@ -1346,7 +1345,8 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock,
UDP_INC_STATS_USER(sock_net(sk), UDP_INC_STATS_USER(sock_net(sk),
UDP_MIB_INERRORS, is_udplite); UDP_MIB_INERRORS, is_udplite);
} }
goto out_free; skb_free_datagram_locked(sk, skb);
return err;
} }
if (!peeked) if (!peeked)
...@@ -1370,9 +1370,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock, ...@@ -1370,9 +1370,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock,
if (flags & MSG_TRUNC) if (flags & MSG_TRUNC)
err = ulen; err = ulen;
out_free: __skb_free_datagram_locked(sk, skb, peeking ? -err : err);
skb_free_datagram_locked(sk, skb);
out:
return err; return err;
csum_copy_err: csum_copy_err:
...@@ -1500,7 +1498,7 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) ...@@ -1500,7 +1498,7 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
sk_incoming_cpu_update(sk); sk_incoming_cpu_update(sk);
} }
rc = sock_queue_rcv_skb(sk, skb); rc = __sock_queue_rcv_skb(sk, skb);
if (rc < 0) { if (rc < 0) {
int is_udplite = IS_UDPLITE(sk); int is_udplite = IS_UDPLITE(sk);
...@@ -1616,10 +1614,14 @@ int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) ...@@ -1616,10 +1614,14 @@ int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
} }
} }
if (rcu_access_pointer(sk->sk_filter) && if (rcu_access_pointer(sk->sk_filter)) {
udp_lib_checksum_complete(skb)) if (udp_lib_checksum_complete(skb))
goto csum_error; goto csum_error;
if (sk_filter(sk, skb))
goto drop;
}
udp_csum_pull_header(skb);
if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) { if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) {
UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS, UDP_INC_STATS_BH(sock_net(sk), UDP_MIB_RCVBUFERRORS,
is_udplite); is_udplite);
......
...@@ -561,6 +561,7 @@ const struct proto_ops inet6_dgram_ops = { ...@@ -561,6 +561,7 @@ const struct proto_ops inet6_dgram_ops = {
.recvmsg = inet_recvmsg, /* ok */ .recvmsg = inet_recvmsg, /* ok */
.mmap = sock_no_mmap, .mmap = sock_no_mmap,
.sendpage = sock_no_sendpage, .sendpage = sock_no_sendpage,
.set_peek_off = sk_set_peek_off,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_setsockopt = compat_sock_common_setsockopt, .compat_setsockopt = compat_sock_common_setsockopt,
.compat_getsockopt = compat_sock_common_getsockopt, .compat_getsockopt = compat_sock_common_getsockopt,
......
...@@ -357,7 +357,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -357,7 +357,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
struct sk_buff *skb; struct sk_buff *skb;
unsigned int ulen, copied; unsigned int ulen, copied;
int peeked, off = 0; int peeked, peeking, off;
int err; int err;
int is_udplite = IS_UDPLITE(sk); int is_udplite = IS_UDPLITE(sk);
bool checksum_valid = false; bool checksum_valid = false;
...@@ -371,15 +371,16 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -371,15 +371,16 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
return ipv6_recv_rxpmtu(sk, msg, len, addr_len); return ipv6_recv_rxpmtu(sk, msg, len, addr_len);
try_again: try_again:
peeking = off = sk_peek_offset(sk, flags);
skb = __skb_recv_datagram(sk, flags | (noblock ? MSG_DONTWAIT : 0), skb = __skb_recv_datagram(sk, flags | (noblock ? MSG_DONTWAIT : 0),
&peeked, &off, &err); &peeked, &off, &err);
if (!skb) if (!skb)
goto out; return err;
ulen = skb->len - sizeof(struct udphdr); ulen = skb->len;
copied = len; copied = len;
if (copied > ulen) if (copied > ulen - off)
copied = ulen; copied = ulen - off;
else if (copied < ulen) else if (copied < ulen)
msg->msg_flags |= MSG_TRUNC; msg->msg_flags |= MSG_TRUNC;
...@@ -391,17 +392,16 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -391,17 +392,16 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
* coverage checksum (UDP-Lite), do it before the copy. * coverage checksum (UDP-Lite), do it before the copy.
*/ */
if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) { if (copied < ulen || UDP_SKB_CB(skb)->partial_cov || peeking) {
checksum_valid = !udp_lib_checksum_complete(skb); checksum_valid = !udp_lib_checksum_complete(skb);
if (!checksum_valid) if (!checksum_valid)
goto csum_copy_err; goto csum_copy_err;
} }
if (checksum_valid || skb_csum_unnecessary(skb)) if (checksum_valid || skb_csum_unnecessary(skb))
err = skb_copy_datagram_msg(skb, sizeof(struct udphdr), err = skb_copy_datagram_msg(skb, off, msg, copied);
msg, copied);
else { else {
err = skb_copy_and_csum_datagram_msg(skb, sizeof(struct udphdr), msg); err = skb_copy_and_csum_datagram_msg(skb, off, msg);
if (err == -EINVAL) if (err == -EINVAL)
goto csum_copy_err; goto csum_copy_err;
} }
...@@ -418,7 +418,8 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -418,7 +418,8 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
UDP_MIB_INERRORS, UDP_MIB_INERRORS,
is_udplite); is_udplite);
} }
goto out_free; skb_free_datagram_locked(sk, skb);
return err;
} }
if (!peeked) { if (!peeked) {
if (is_udp4) if (is_udp4)
...@@ -466,9 +467,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -466,9 +467,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
if (flags & MSG_TRUNC) if (flags & MSG_TRUNC)
err = ulen; err = ulen;
out_free: __skb_free_datagram_locked(sk, skb, peeking ? -err : err);
skb_free_datagram_locked(sk, skb);
out:
return err; return err;
csum_copy_err: csum_copy_err:
...@@ -554,7 +553,7 @@ static int __udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) ...@@ -554,7 +553,7 @@ static int __udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
sk_incoming_cpu_update(sk); sk_incoming_cpu_update(sk);
} }
rc = sock_queue_rcv_skb(sk, skb); rc = __sock_queue_rcv_skb(sk, skb);
if (rc < 0) { if (rc < 0) {
int is_udplite = IS_UDPLITE(sk); int is_udplite = IS_UDPLITE(sk);
...@@ -648,8 +647,11 @@ int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) ...@@ -648,8 +647,11 @@ int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
if (rcu_access_pointer(sk->sk_filter)) { if (rcu_access_pointer(sk->sk_filter)) {
if (udp_lib_checksum_complete(skb)) if (udp_lib_checksum_complete(skb))
goto csum_error; goto csum_error;
if (sk_filter(sk, skb))
goto drop;
} }
udp_csum_pull_header(skb);
if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) { if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) {
UDP6_INC_STATS_BH(sock_net(sk), UDP6_INC_STATS_BH(sock_net(sk),
UDP_MIB_RCVBUFERRORS, is_udplite); UDP_MIB_RCVBUFERRORS, is_udplite);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment