Commit d4253c62 authored by David S. Miller's avatar David S. Miller

Merge branch 'ip_cmsg_csum'

Tom Herbert says:

====================
ip: Support checksum returned in csmg

This patch set allows the packet checksum for a datagram socket
to be returned in csum data in recvmsg. This allows userspace
to implement its own checksum over the data, for instance if an
IP tunnel was be implemented in user space, the inner checksum
could be validated.

Changes in this patch set:
  - Move checksum conversion to inet_sock from udp_sock. This
    generalizes checksum conversion for use with other protocols.
  - Move IP cmsg constants to a header file and make processing
    of the flags more efficient in ip_cmsg_recv
  - Return checksum value in cmsg. This is specifically the unfolded
    32 bit checksum of the full packet starting from the first byte
    returned in recvmsg

Tested: Wrote a little server to get checksums in cmsg for UDP and
        verfied correct checksum is returned.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 149118d8 ad6f939a
...@@ -49,11 +49,7 @@ struct udp_sock { ...@@ -49,11 +49,7 @@ struct udp_sock {
unsigned int corkflag; /* Cork is required */ unsigned int corkflag; /* Cork is required */
__u8 encap_type; /* Is this an Encapsulation socket? */ __u8 encap_type; /* Is this an Encapsulation socket? */
unsigned char no_check6_tx:1,/* Send zero UDP6 checksums on TX? */ unsigned char no_check6_tx:1,/* Send zero UDP6 checksums on TX? */
no_check6_rx:1,/* Allow zero UDP6 checksums on RX? */ no_check6_rx:1;/* Allow zero UDP6 checksums on RX? */
convert_csum:1;/* On receive, convert checksum
* unnecessary to checksum complete
* if possible.
*/
/* /*
* Following member retains the information to create a UDP header * Following member retains the information to create a UDP header
* when the socket is uncorked. * when the socket is uncorked.
...@@ -102,16 +98,6 @@ static inline bool udp_get_no_check6_rx(struct sock *sk) ...@@ -102,16 +98,6 @@ static inline bool udp_get_no_check6_rx(struct sock *sk)
return udp_sk(sk)->no_check6_rx; return udp_sk(sk)->no_check6_rx;
} }
static inline void udp_set_convert_csum(struct sock *sk, bool val)
{
udp_sk(sk)->convert_csum = val;
}
static inline bool udp_get_convert_csum(struct sock *sk)
{
return udp_sk(sk)->convert_csum;
}
#define udp_portaddr_for_each_entry(__sk, node, list) \ #define udp_portaddr_for_each_entry(__sk, node, list) \
hlist_nulls_for_each_entry(__sk, node, list, __sk_common.skc_portaddr_node) hlist_nulls_for_each_entry(__sk, node, list, __sk_common.skc_portaddr_node)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#ifndef _INET_SOCK_H #ifndef _INET_SOCK_H
#define _INET_SOCK_H #define _INET_SOCK_H
#include <linux/bitops.h>
#include <linux/kmemcheck.h> #include <linux/kmemcheck.h>
#include <linux/string.h> #include <linux/string.h>
#include <linux/types.h> #include <linux/types.h>
...@@ -184,6 +184,7 @@ struct inet_sock { ...@@ -184,6 +184,7 @@ struct inet_sock {
mc_all:1, mc_all:1,
nodefrag:1; nodefrag:1;
__u8 rcv_tos; __u8 rcv_tos;
__u8 convert_csum;
int uc_index; int uc_index;
int mc_index; int mc_index;
__be32 mc_addr; __be32 mc_addr;
...@@ -194,6 +195,16 @@ struct inet_sock { ...@@ -194,6 +195,16 @@ struct inet_sock {
#define IPCORK_OPT 1 /* ip-options has been held in ipcork.opt */ #define IPCORK_OPT 1 /* ip-options has been held in ipcork.opt */
#define IPCORK_ALLFRAG 2 /* always fragment (for ipv6 for now) */ #define IPCORK_ALLFRAG 2 /* always fragment (for ipv6 for now) */
/* cmsg flags for inet */
#define IP_CMSG_PKTINFO BIT(0)
#define IP_CMSG_TTL BIT(1)
#define IP_CMSG_TOS BIT(2)
#define IP_CMSG_RECVOPTS BIT(3)
#define IP_CMSG_RETOPTS BIT(4)
#define IP_CMSG_PASSSEC BIT(5)
#define IP_CMSG_ORIGDSTADDR BIT(6)
#define IP_CMSG_CHECKSUM BIT(7)
static inline struct inet_sock *inet_sk(const struct sock *sk) static inline struct inet_sock *inet_sk(const struct sock *sk)
{ {
return (struct inet_sock *)sk; return (struct inet_sock *)sk;
...@@ -250,4 +261,20 @@ static inline __u8 inet_sk_flowi_flags(const struct sock *sk) ...@@ -250,4 +261,20 @@ static inline __u8 inet_sk_flowi_flags(const struct sock *sk)
return flags; return flags;
} }
static inline void inet_inc_convert_csum(struct sock *sk)
{
inet_sk(sk)->convert_csum++;
}
static inline void inet_dec_convert_csum(struct sock *sk)
{
if (inet_sk(sk)->convert_csum > 0)
inet_sk(sk)->convert_csum--;
}
static inline bool inet_get_convert_csum(struct sock *sk)
{
return !!inet_sk(sk)->convert_csum;
}
#endif /* _INET_SOCK_H */ #endif /* _INET_SOCK_H */
...@@ -537,7 +537,7 @@ int ip_options_rcv_srr(struct sk_buff *skb); ...@@ -537,7 +537,7 @@ int ip_options_rcv_srr(struct sk_buff *skb);
*/ */
void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb); void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb);
void ip_cmsg_recv(struct msghdr *msg, struct sk_buff *skb); void ip_cmsg_recv_offset(struct msghdr *msg, struct sk_buff *skb, int offset);
int ip_cmsg_send(struct net *net, struct msghdr *msg, int ip_cmsg_send(struct net *net, struct msghdr *msg,
struct ipcm_cookie *ipc, bool allow_ipv6); struct ipcm_cookie *ipc, bool allow_ipv6);
int ip_setsockopt(struct sock *sk, int level, int optname, char __user *optval, int ip_setsockopt(struct sock *sk, int level, int optname, char __user *optval,
...@@ -557,6 +557,11 @@ void ip_icmp_error(struct sock *sk, struct sk_buff *skb, int err, __be16 port, ...@@ -557,6 +557,11 @@ void ip_icmp_error(struct sock *sk, struct sk_buff *skb, int err, __be16 port,
void ip_local_error(struct sock *sk, int err, __be32 daddr, __be16 dport, void ip_local_error(struct sock *sk, int err, __be32 daddr, __be16 dport,
u32 info); u32 info);
static inline void ip_cmsg_recv(struct msghdr *msg, struct sk_buff *skb)
{
ip_cmsg_recv_offset(msg, skb, 0);
}
bool icmp_global_allow(void); bool icmp_global_allow(void);
extern int sysctl_icmp_msgs_per_sec; extern int sysctl_icmp_msgs_per_sec;
extern int sysctl_icmp_msgs_burst; extern int sysctl_icmp_msgs_burst;
......
...@@ -109,6 +109,7 @@ struct in_addr { ...@@ -109,6 +109,7 @@ struct in_addr {
#define IP_MINTTL 21 #define IP_MINTTL 21
#define IP_NODEFRAG 22 #define IP_NODEFRAG 22
#define IP_CHECKSUM 23
/* IP_MTU_DISCOVER values */ /* IP_MTU_DISCOVER values */
#define IP_PMTUDISC_DONT 0 /* Never send DF frames */ #define IP_PMTUDISC_DONT 0 /* Never send DF frames */
......
...@@ -490,7 +490,7 @@ static int fou_create(struct net *net, struct fou_cfg *cfg, ...@@ -490,7 +490,7 @@ static int fou_create(struct net *net, struct fou_cfg *cfg,
sk->sk_user_data = fou; sk->sk_user_data = fou;
fou->sock = sock; fou->sock = sock;
udp_set_convert_csum(sk, true); inet_inc_convert_csum(sk);
sk->sk_allocation = GFP_ATOMIC; sk->sk_allocation = GFP_ATOMIC;
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include <net/route.h> #include <net/route.h>
#include <net/xfrm.h> #include <net/xfrm.h>
#include <net/compat.h> #include <net/compat.h>
#include <net/checksum.h>
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
#include <net/transp_v6.h> #include <net/transp_v6.h>
#endif #endif
...@@ -45,14 +46,6 @@ ...@@ -45,14 +46,6 @@
#include <linux/errqueue.h> #include <linux/errqueue.h>
#include <asm/uaccess.h> #include <asm/uaccess.h>
#define IP_CMSG_PKTINFO 1
#define IP_CMSG_TTL 2
#define IP_CMSG_TOS 4
#define IP_CMSG_RECVOPTS 8
#define IP_CMSG_RETOPTS 16
#define IP_CMSG_PASSSEC 32
#define IP_CMSG_ORIGDSTADDR 64
/* /*
* SOL_IP control messages. * SOL_IP control messages.
*/ */
...@@ -104,6 +97,20 @@ static void ip_cmsg_recv_retopts(struct msghdr *msg, struct sk_buff *skb) ...@@ -104,6 +97,20 @@ static void ip_cmsg_recv_retopts(struct msghdr *msg, struct sk_buff *skb)
put_cmsg(msg, SOL_IP, IP_RETOPTS, opt->optlen, opt->__data); put_cmsg(msg, SOL_IP, IP_RETOPTS, opt->optlen, opt->__data);
} }
static void ip_cmsg_recv_checksum(struct msghdr *msg, struct sk_buff *skb,
int offset)
{
__wsum csum = skb->csum;
if (skb->ip_summed != CHECKSUM_COMPLETE)
return;
if (offset != 0)
csum = csum_sub(csum, csum_partial(skb->data, offset, 0));
put_cmsg(msg, SOL_IP, IP_CHECKSUM, sizeof(__wsum), &csum);
}
static void ip_cmsg_recv_security(struct msghdr *msg, struct sk_buff *skb) static void ip_cmsg_recv_security(struct msghdr *msg, struct sk_buff *skb)
{ {
char *secdata; char *secdata;
...@@ -144,47 +151,73 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb) ...@@ -144,47 +151,73 @@ static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb)
put_cmsg(msg, SOL_IP, IP_ORIGDSTADDR, sizeof(sin), &sin); put_cmsg(msg, SOL_IP, IP_ORIGDSTADDR, sizeof(sin), &sin);
} }
void ip_cmsg_recv(struct msghdr *msg, struct sk_buff *skb) void ip_cmsg_recv_offset(struct msghdr *msg, struct sk_buff *skb,
int offset)
{ {
struct inet_sock *inet = inet_sk(skb->sk); struct inet_sock *inet = inet_sk(skb->sk);
unsigned int flags = inet->cmsg_flags; unsigned int flags = inet->cmsg_flags;
/* Ordered by supposed usage frequency */ /* Ordered by supposed usage frequency */
if (flags & 1) if (flags & IP_CMSG_PKTINFO) {
ip_cmsg_recv_pktinfo(msg, skb); ip_cmsg_recv_pktinfo(msg, skb);
if ((flags >>= 1) == 0)
return;
if (flags & 1) flags &= ~IP_CMSG_PKTINFO;
if (!flags)
return;
}
if (flags & IP_CMSG_TTL) {
ip_cmsg_recv_ttl(msg, skb); ip_cmsg_recv_ttl(msg, skb);
if ((flags >>= 1) == 0)
return;
if (flags & 1) flags &= ~IP_CMSG_TTL;
if (!flags)
return;
}
if (flags & IP_CMSG_TOS) {
ip_cmsg_recv_tos(msg, skb); ip_cmsg_recv_tos(msg, skb);
if ((flags >>= 1) == 0)
return;
if (flags & 1) flags &= ~IP_CMSG_TOS;
if (!flags)
return;
}
if (flags & IP_CMSG_RECVOPTS) {
ip_cmsg_recv_opts(msg, skb); ip_cmsg_recv_opts(msg, skb);
if ((flags >>= 1) == 0)
return;
if (flags & 1) flags &= ~IP_CMSG_RECVOPTS;
if (!flags)
return;
}
if (flags & IP_CMSG_RETOPTS) {
ip_cmsg_recv_retopts(msg, skb); ip_cmsg_recv_retopts(msg, skb);
if ((flags >>= 1) == 0)
return;
if (flags & 1) flags &= ~IP_CMSG_RETOPTS;
if (!flags)
return;
}
if (flags & IP_CMSG_PASSSEC) {
ip_cmsg_recv_security(msg, skb); ip_cmsg_recv_security(msg, skb);
if ((flags >>= 1) == 0) flags &= ~IP_CMSG_PASSSEC;
return; if (!flags)
if (flags & 1) return;
}
if (flags & IP_CMSG_ORIGDSTADDR) {
ip_cmsg_recv_dstaddr(msg, skb); ip_cmsg_recv_dstaddr(msg, skb);
flags &= ~IP_CMSG_ORIGDSTADDR;
if (!flags)
return;
}
if (flags & IP_CMSG_CHECKSUM)
ip_cmsg_recv_checksum(msg, skb, offset);
} }
EXPORT_SYMBOL(ip_cmsg_recv); EXPORT_SYMBOL(ip_cmsg_recv_offset);
int ip_cmsg_send(struct net *net, struct msghdr *msg, struct ipcm_cookie *ipc, int ip_cmsg_send(struct net *net, struct msghdr *msg, struct ipcm_cookie *ipc,
bool allow_ipv6) bool allow_ipv6)
...@@ -522,6 +555,7 @@ static int do_ip_setsockopt(struct sock *sk, int level, ...@@ -522,6 +555,7 @@ static int do_ip_setsockopt(struct sock *sk, int level,
case IP_MULTICAST_ALL: case IP_MULTICAST_ALL:
case IP_MULTICAST_LOOP: case IP_MULTICAST_LOOP:
case IP_RECVORIGDSTADDR: case IP_RECVORIGDSTADDR:
case IP_CHECKSUM:
if (optlen >= sizeof(int)) { if (optlen >= sizeof(int)) {
if (get_user(val, (int __user *) optval)) if (get_user(val, (int __user *) optval))
return -EFAULT; return -EFAULT;
...@@ -619,6 +653,19 @@ static int do_ip_setsockopt(struct sock *sk, int level, ...@@ -619,6 +653,19 @@ static int do_ip_setsockopt(struct sock *sk, int level,
else else
inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR; inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR;
break; break;
case IP_CHECKSUM:
if (val) {
if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) {
inet_inc_convert_csum(sk);
inet->cmsg_flags |= IP_CMSG_CHECKSUM;
}
} else {
if (inet->cmsg_flags & IP_CMSG_CHECKSUM) {
inet_dec_convert_csum(sk);
inet->cmsg_flags &= ~IP_CMSG_CHECKSUM;
}
}
break;
case IP_TOS: /* This sets both TOS and Precedence */ case IP_TOS: /* This sets both TOS and Precedence */
if (sk->sk_type == SOCK_STREAM) { if (sk->sk_type == SOCK_STREAM) {
val &= ~INET_ECN_MASK; val &= ~INET_ECN_MASK;
...@@ -1222,6 +1269,9 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1222,6 +1269,9 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
case IP_RECVORIGDSTADDR: case IP_RECVORIGDSTADDR:
val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0; val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0;
break; break;
case IP_CHECKSUM:
val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0;
break;
case IP_TOS: case IP_TOS:
val = inet->tos; val = inet->tos;
break; break;
......
...@@ -1329,7 +1329,7 @@ int udp_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg, ...@@ -1329,7 +1329,7 @@ int udp_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
*addr_len = sizeof(*sin); *addr_len = sizeof(*sin);
} }
if (inet->cmsg_flags) if (inet->cmsg_flags)
ip_cmsg_recv(msg, skb); ip_cmsg_recv_offset(msg, skb, sizeof(struct udphdr));
err = copied; err = copied;
if (flags & MSG_TRUNC) if (flags & MSG_TRUNC)
...@@ -1806,7 +1806,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -1806,7 +1806,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
if (sk != NULL) { if (sk != NULL) {
int ret; int ret;
if (udp_sk(sk)->convert_csum && uh->check && !IS_UDPLITE(sk)) if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk))
skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check, skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check,
inet_compute_pseudo); inet_compute_pseudo);
......
...@@ -63,7 +63,7 @@ void setup_udp_tunnel_sock(struct net *net, struct socket *sock, ...@@ -63,7 +63,7 @@ void setup_udp_tunnel_sock(struct net *net, struct socket *sock,
inet_sk(sk)->mc_loop = 0; inet_sk(sk)->mc_loop = 0;
/* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */ /* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */
udp_set_convert_csum(sk, true); inet_inc_convert_csum(sk);
rcu_assign_sk_user_data(sk, cfg->sk_user_data); rcu_assign_sk_user_data(sk, cfg->sk_user_data);
......
...@@ -909,7 +909,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -909,7 +909,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
goto csum_error; goto csum_error;
} }
if (udp_sk(sk)->convert_csum && uh->check && !IS_UDPLITE(sk)) if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk))
skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check, skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check,
ip6_compute_pseudo); ip6_compute_pseudo);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment