Commit ba5ea614 authored by Linus Lüssing's avatar Linus Lüssing Committed by David S. Miller

bridge: simplify ip_mc_check_igmp() and ipv6_mc_check_mld() calls

This patch refactors ip_mc_check_igmp(), ipv6_mc_check_mld() and
their callers (more precisely, the Linux bridge) to not rely on
the skb_trimmed parameter anymore.

An skb with its tail trimmed to the IP packet length was initially
introduced for the following three reasons:

1) To be able to verify the ICMPv6 checksum.
2) To be able to distinguish the version of an IGMP or MLD query.
   They are distinguishable only by their size.
3) To avoid parsing data for an IGMPv3 or MLDv2 report that is
   beyond the IP packet but still within the skb.

The first case still uses a cloned and potentially trimmed skb to
verfiy. However, there is no need to propagate it to the caller.
For the second and third case explicit IP packet length checks were
added.

This hopefully makes ip_mc_check_igmp() and ipv6_mc_check_mld() easier
to read and verfiy, as well as easier to use.
Signed-off-by: default avatarLinus Lüssing <linus.luessing@c0d3.blue>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 6679cf09
......@@ -18,6 +18,7 @@
#include <linux/skbuff.h>
#include <linux/timer.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/refcount.h>
#include <uapi/linux/igmp.h>
......@@ -106,6 +107,14 @@ struct ip_mc_list {
#define IGMPV3_QQIC(value) IGMPV3_EXP(0x80, 4, 3, value)
#define IGMPV3_MRC(value) IGMPV3_EXP(0x80, 4, 3, value)
static inline int ip_mc_may_pull(struct sk_buff *skb, unsigned int len)
{
if (skb_transport_offset(skb) + ip_transport_len(skb) < len)
return -EINVAL;
return pskb_may_pull(skb, len);
}
extern int ip_check_mc_rcu(struct in_device *dev, __be32 mc_addr, __be32 src_addr, u8 proto);
extern int igmp_rcv(struct sk_buff *);
extern int ip_mc_join_group(struct sock *sk, struct ip_mreqn *imr);
......@@ -130,6 +139,6 @@ extern void ip_mc_unmap(struct in_device *);
extern void ip_mc_remap(struct in_device *);
extern void ip_mc_dec_group(struct in_device *in_dev, __be32 addr);
extern void ip_mc_inc_group(struct in_device *in_dev, __be32 addr);
int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed);
int ip_mc_check_igmp(struct sk_buff *skb);
#endif
......@@ -34,4 +34,9 @@ static inline struct iphdr *ipip_hdr(const struct sk_buff *skb)
{
return (struct iphdr *)skb_transport_header(skb);
}
static inline unsigned int ip_transport_len(const struct sk_buff *skb)
{
return ntohs(ip_hdr(skb)->tot_len) - skb_network_header_len(skb);
}
#endif /* _LINUX_IP_H */
......@@ -104,6 +104,12 @@ static inline struct ipv6hdr *ipipv6_hdr(const struct sk_buff *skb)
return (struct ipv6hdr *)skb_transport_header(skb);
}
static inline unsigned int ipv6_transport_len(const struct sk_buff *skb)
{
return ntohs(ipv6_hdr(skb)->payload_len) + sizeof(struct ipv6hdr) -
skb_network_header_len(skb);
}
/*
This structure contains results of exthdrs parsing
as offsets from skb->nh.
......
......@@ -49,6 +49,7 @@ struct prefix_info {
struct in6_addr prefix;
};
#include <linux/ipv6.h>
#include <linux/netdevice.h>
#include <net/if_inet6.h>
#include <net/ipv6.h>
......@@ -201,6 +202,15 @@ u32 ipv6_addr_label(struct net *net, const struct in6_addr *addr,
/*
* multicast prototypes (mcast.c)
*/
static inline int ipv6_mc_may_pull(struct sk_buff *skb,
unsigned int len)
{
if (skb_transport_offset(skb) + ipv6_transport_len(skb) < len)
return -EINVAL;
return pskb_may_pull(skb, len);
}
int ipv6_sock_mc_join(struct sock *sk, int ifindex,
const struct in6_addr *addr);
int ipv6_sock_mc_drop(struct sock *sk, int ifindex,
......@@ -219,7 +229,7 @@ void ipv6_mc_unmap(struct inet6_dev *idev);
void ipv6_mc_remap(struct inet6_dev *idev);
void ipv6_mc_init_dev(struct inet6_dev *idev);
void ipv6_mc_destroy_dev(struct inet6_dev *idev);
int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed);
int ipv6_mc_check_mld(struct sk_buff *skb);
void addrconf_dad_failure(struct sk_buff *skb, struct inet6_ifaddr *ifp);
bool ipv6_chk_mcast_addr(struct net_device *dev, const struct in6_addr *group,
......
......@@ -674,7 +674,7 @@ static void batadv_mcast_mla_update(struct work_struct *work)
*/
static bool batadv_mcast_is_report_ipv4(struct sk_buff *skb)
{
if (ip_mc_check_igmp(skb, NULL) < 0)
if (ip_mc_check_igmp(skb) < 0)
return false;
switch (igmp_hdr(skb)->type) {
......@@ -741,7 +741,7 @@ static int batadv_mcast_forw_mode_check_ipv4(struct batadv_priv *bat_priv,
*/
static bool batadv_mcast_is_report_ipv6(struct sk_buff *skb)
{
if (ipv6_mc_check_mld(skb, NULL) < 0)
if (ipv6_mc_check_mld(skb) < 0)
return false;
switch (icmp6_hdr(skb)->icmp6_type) {
......
......@@ -938,7 +938,7 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge *br,
for (i = 0; i < num; i++) {
len += sizeof(*grec);
if (!pskb_may_pull(skb, len))
if (!ip_mc_may_pull(skb, len))
return -EINVAL;
grec = (void *)(skb->data + len - sizeof(*grec));
......@@ -946,7 +946,7 @@ static int br_ip4_multicast_igmp3_report(struct net_bridge *br,
type = grec->grec_type;
len += ntohs(grec->grec_nsrcs) * 4;
if (!pskb_may_pull(skb, len))
if (!ip_mc_may_pull(skb, len))
return -EINVAL;
/* We treat this as an IGMPv2 report for now. */
......@@ -985,15 +985,17 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
unsigned int nsrcs_offset;
const unsigned char *src;
struct icmp6hdr *icmp6h;
struct mld2_grec *grec;
unsigned int grec_len;
int i;
int len;
int num;
int err = 0;
if (!pskb_may_pull(skb, sizeof(*icmp6h)))
if (!ipv6_mc_may_pull(skb, sizeof(*icmp6h)))
return -EINVAL;
icmp6h = icmp6_hdr(skb);
......@@ -1003,21 +1005,25 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br,
for (i = 0; i < num; i++) {
__be16 *nsrcs, _nsrcs;
nsrcs = skb_header_pointer(skb,
len + offsetof(struct mld2_grec,
grec_nsrcs),
nsrcs_offset = len + offsetof(struct mld2_grec, grec_nsrcs);
if (skb_transport_offset(skb) + ipv6_transport_len(skb) <
nsrcs_offset + sizeof(_nsrcs))
return -EINVAL;
nsrcs = skb_header_pointer(skb, nsrcs_offset,
sizeof(_nsrcs), &_nsrcs);
if (!nsrcs)
return -EINVAL;
if (!pskb_may_pull(skb,
len + sizeof(*grec) +
sizeof(struct in6_addr) * ntohs(*nsrcs)))
grec_len = sizeof(*grec) +
sizeof(struct in6_addr) * ntohs(*nsrcs);
if (!ipv6_mc_may_pull(skb, len + grec_len))
return -EINVAL;
grec = (struct mld2_grec *)(skb->data + len);
len += sizeof(*grec) +
sizeof(struct in6_addr) * ntohs(*nsrcs);
len += grec_len;
/* We treat these as MLDv1 reports for now. */
switch (grec->grec_type) {
......@@ -1219,6 +1225,7 @@ static void br_ip4_multicast_query(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
unsigned int transport_len = ip_transport_len(skb);
const struct iphdr *iph = ip_hdr(skb);
struct igmphdr *ih = igmp_hdr(skb);
struct net_bridge_mdb_entry *mp;
......@@ -1228,7 +1235,6 @@ static void br_ip4_multicast_query(struct net_bridge *br,
struct br_ip saddr;
unsigned long max_delay;
unsigned long now = jiffies;
unsigned int offset = skb_transport_offset(skb);
__be32 group;
spin_lock(&br->multicast_lock);
......@@ -1238,14 +1244,14 @@ static void br_ip4_multicast_query(struct net_bridge *br,
group = ih->group;
if (skb->len == offset + sizeof(*ih)) {
if (transport_len == sizeof(*ih)) {
max_delay = ih->code * (HZ / IGMP_TIMER_SCALE);
if (!max_delay) {
max_delay = 10 * HZ;
group = 0;
}
} else if (skb->len >= offset + sizeof(*ih3)) {
} else if (transport_len >= sizeof(*ih3)) {
ih3 = igmpv3_query_hdr(skb);
if (ih3->nsrcs)
goto out;
......@@ -1296,6 +1302,7 @@ static int br_ip6_multicast_query(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
unsigned int transport_len = ipv6_transport_len(skb);
const struct ipv6hdr *ip6h = ipv6_hdr(skb);
struct mld_msg *mld;
struct net_bridge_mdb_entry *mp;
......@@ -1315,7 +1322,7 @@ static int br_ip6_multicast_query(struct net_bridge *br,
(port && port->state == BR_STATE_DISABLED))
goto out;
if (skb->len == offset + sizeof(*mld)) {
if (transport_len == sizeof(*mld)) {
if (!pskb_may_pull(skb, offset + sizeof(*mld))) {
err = -EINVAL;
goto out;
......@@ -1581,12 +1588,11 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
struct sk_buff *skb_trimmed = NULL;
const unsigned char *src;
struct igmphdr *ih;
int err;
err = ip_mc_check_igmp(skb, &skb_trimmed);
err = ip_mc_check_igmp(skb);
if (err == -ENOMSG) {
if (!ipv4_is_local_multicast(ip_hdr(skb)->daddr)) {
......@@ -1612,19 +1618,16 @@ static int br_multicast_ipv4_rcv(struct net_bridge *br,
err = br_ip4_multicast_add_group(br, port, ih->group, vid, src);
break;
case IGMPV3_HOST_MEMBERSHIP_REPORT:
err = br_ip4_multicast_igmp3_report(br, port, skb_trimmed, vid);
err = br_ip4_multicast_igmp3_report(br, port, skb, vid);
break;
case IGMP_HOST_MEMBERSHIP_QUERY:
br_ip4_multicast_query(br, port, skb_trimmed, vid);
br_ip4_multicast_query(br, port, skb, vid);
break;
case IGMP_HOST_LEAVE_MESSAGE:
br_ip4_multicast_leave_group(br, port, ih->group, vid, src);
break;
}
if (skb_trimmed && skb_trimmed != skb)
kfree_skb(skb_trimmed);
br_multicast_count(br, port, skb, BR_INPUT_SKB_CB(skb)->igmp,
BR_MCAST_DIR_RX);
......@@ -1637,12 +1640,11 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
struct sk_buff *skb,
u16 vid)
{
struct sk_buff *skb_trimmed = NULL;
const unsigned char *src;
struct mld_msg *mld;
int err;
err = ipv6_mc_check_mld(skb, &skb_trimmed);
err = ipv6_mc_check_mld(skb);
if (err == -ENOMSG) {
if (!ipv6_addr_is_ll_all_nodes(&ipv6_hdr(skb)->daddr))
......@@ -1664,10 +1666,10 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
src);
break;
case ICMPV6_MLD2_REPORT:
err = br_ip6_multicast_mld2_report(br, port, skb_trimmed, vid);
err = br_ip6_multicast_mld2_report(br, port, skb, vid);
break;
case ICMPV6_MGM_QUERY:
err = br_ip6_multicast_query(br, port, skb_trimmed, vid);
err = br_ip6_multicast_query(br, port, skb, vid);
break;
case ICMPV6_MGM_REDUCTION:
src = eth_hdr(skb)->h_source;
......@@ -1675,9 +1677,6 @@ static int br_multicast_ipv6_rcv(struct net_bridge *br,
break;
}
if (skb_trimmed && skb_trimmed != skb)
kfree_skb(skb_trimmed);
br_multicast_count(br, port, skb, BR_INPUT_SKB_CB(skb)->igmp,
BR_MCAST_DIR_RX);
......
......@@ -1544,7 +1544,7 @@ static inline __sum16 ip_mc_validate_checksum(struct sk_buff *skb)
return skb_checksum_simple_validate(skb);
}
static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
static int __ip_mc_check_igmp(struct sk_buff *skb)
{
struct sk_buff *skb_chk;
......@@ -1566,16 +1566,10 @@ static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
if (ret)
goto err;
if (skb_trimmed)
*skb_trimmed = skb_chk;
/* free now unneeded clone */
else if (skb_chk != skb)
kfree_skb(skb_chk);
ret = 0;
err:
if (ret && skb_chk && skb_chk != skb)
if (skb_chk && skb_chk != skb)
kfree_skb(skb_chk);
return ret;
......@@ -1584,7 +1578,6 @@ static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
/**
* ip_mc_check_igmp - checks whether this is a sane IGMP packet
* @skb: the skb to validate
* @skb_trimmed: to store an skb pointer trimmed to IPv4 packet tail (optional)
*
* Checks whether an IPv4 packet is a valid IGMP packet. If so sets
* skb transport header accordingly and returns zero.
......@@ -1594,18 +1587,10 @@ static int __ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
* -ENOMSG: IP header validation succeeded but it is not an IGMP packet.
* -ENOMEM: A memory allocation failure happened.
*
* Optionally, an skb pointer might be provided via skb_trimmed (or set it
* to NULL): After parsing an IGMP packet successfully it will point to
* an skb which has its tail aligned to the IP packet end. This might
* either be the originally provided skb or a trimmed, cloned version if
* the skb frame had data beyond the IP packet. A cloned skb allows us
* to leave the original skb and its full frame unchanged (which might be
* desirable for layer 2 frame jugglers).
*
* Caller needs to set the skb network header and free any returned skb if it
* differs from the provided skb.
*/
int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
int ip_mc_check_igmp(struct sk_buff *skb)
{
int ret = ip_mc_check_iphdr(skb);
......@@ -1615,7 +1600,7 @@ int ip_mc_check_igmp(struct sk_buff *skb, struct sk_buff **skb_trimmed)
if (ip_hdr(skb)->protocol != IPPROTO_IGMP)
return -ENOMSG;
return __ip_mc_check_igmp(skb, skb_trimmed);
return __ip_mc_check_igmp(skb);
}
EXPORT_SYMBOL(ip_mc_check_igmp);
......
......@@ -136,8 +136,7 @@ static inline __sum16 ipv6_mc_validate_checksum(struct sk_buff *skb)
return skb_checksum_validate(skb, IPPROTO_ICMPV6, ip6_compute_pseudo);
}
static int __ipv6_mc_check_mld(struct sk_buff *skb,
struct sk_buff **skb_trimmed)
static int __ipv6_mc_check_mld(struct sk_buff *skb)
{
struct sk_buff *skb_chk = NULL;
......@@ -160,16 +159,10 @@ static int __ipv6_mc_check_mld(struct sk_buff *skb,
if (ret)
goto err;
if (skb_trimmed)
*skb_trimmed = skb_chk;
/* free now unneeded clone */
else if (skb_chk != skb)
kfree_skb(skb_chk);
ret = 0;
err:
if (ret && skb_chk && skb_chk != skb)
if (skb_chk && skb_chk != skb)
kfree_skb(skb_chk);
return ret;
......@@ -178,7 +171,6 @@ static int __ipv6_mc_check_mld(struct sk_buff *skb,
/**
* ipv6_mc_check_mld - checks whether this is a sane MLD packet
* @skb: the skb to validate
* @skb_trimmed: to store an skb pointer trimmed to IPv6 packet tail (optional)
*
* Checks whether an IPv6 packet is a valid MLD packet. If so sets
* skb transport header accordingly and returns zero.
......@@ -188,18 +180,10 @@ static int __ipv6_mc_check_mld(struct sk_buff *skb,
* -ENOMSG: IP header validation succeeded but it is not an MLD packet.
* -ENOMEM: A memory allocation failure happened.
*
* Optionally, an skb pointer might be provided via skb_trimmed (or set it
* to NULL): After parsing an MLD packet successfully it will point to
* an skb which has its tail aligned to the IP packet end. This might
* either be the originally provided skb or a trimmed, cloned version if
* the skb frame had data beyond the IP packet. A cloned skb allows us
* to leave the original skb and its full frame unchanged (which might be
* desirable for layer 2 frame jugglers).
*
* Caller needs to set the skb network header and free any returned skb if it
* differs from the provided skb.
*/
int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed)
int ipv6_mc_check_mld(struct sk_buff *skb)
{
int ret;
......@@ -211,6 +195,6 @@ int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed)
if (ret < 0)
return ret;
return __ipv6_mc_check_mld(skb, skb_trimmed);
return __ipv6_mc_check_mld(skb);
}
EXPORT_SYMBOL(ipv6_mc_check_mld);
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment