summaryrefslogtreecommitdiffstats
path: root/net/bridge/br_multicast.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/bridge/br_multicast.c')
-rw-r--r--net/bridge/br_multicast.c130
1 files changed, 78 insertions, 52 deletions
diff --git a/net/bridge/br_multicast.c b/net/bridge/br_multicast.c
index 543b3262d00..030a002ff8e 100644
--- a/net/bridge/br_multicast.c
+++ b/net/bridge/br_multicast.c
@@ -33,11 +33,13 @@
#include "br_private.h"
+#define mlock_dereference(X, br) \
+ rcu_dereference_protected(X, lockdep_is_held(&br->multicast_lock))
+
#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
-static inline int ipv6_is_local_multicast(const struct in6_addr *addr)
+static inline int ipv6_is_transient_multicast(const struct in6_addr *addr)
{
- if (ipv6_addr_is_multicast(addr) &&
- IPV6_ADDR_MC_SCOPE(addr) <= IPV6_ADDR_SCOPE_LINKLOCAL)
+ if (ipv6_addr_is_multicast(addr) && IPV6_ADDR_MC_FLAG_TRANSIENT(addr))
return 1;
return 0;
}
@@ -135,7 +137,7 @@ static struct net_bridge_mdb_entry *br_mdb_ip6_get(
struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge *br,
struct sk_buff *skb)
{
- struct net_bridge_mdb_htable *mdb = br->mdb;
+ struct net_bridge_mdb_htable *mdb = rcu_dereference(br->mdb);
struct br_ip ip;
if (br->multicast_disabled)
@@ -229,13 +231,13 @@ static void br_multicast_group_expired(unsigned long data)
if (!netif_running(br->dev) || timer_pending(&mp->timer))
goto out;
- if (!hlist_unhashed(&mp->mglist))
- hlist_del_init(&mp->mglist);
+ mp->mglist = false;
if (mp->ports)
goto out;
- mdb = br->mdb;
+ mdb = mlock_dereference(br->mdb, br);
+
hlist_del_rcu(&mp->hlist[mdb->ver]);
mdb->size--;
@@ -249,16 +251,20 @@ out:
static void br_multicast_del_pg(struct net_bridge *br,
struct net_bridge_port_group *pg)
{
- struct net_bridge_mdb_htable *mdb = br->mdb;
+ struct net_bridge_mdb_htable *mdb;
struct net_bridge_mdb_entry *mp;
struct net_bridge_port_group *p;
- struct net_bridge_port_group **pp;
+ struct net_bridge_port_group __rcu **pp;
+
+ mdb = mlock_dereference(br->mdb, br);
mp = br_mdb_ip_get(mdb, &pg->addr);
if (WARN_ON(!mp))
return;
- for (pp = &mp->ports; (p = *pp); pp = &p->next) {
+ for (pp = &mp->ports;
+ (p = mlock_dereference(*pp, br)) != NULL;
+ pp = &p->next) {
if (p != pg)
continue;
@@ -268,7 +274,7 @@ static void br_multicast_del_pg(struct net_bridge *br,
del_timer(&p->query_timer);
call_rcu_bh(&p->rcu, br_multicast_free_pg);
- if (!mp->ports && hlist_unhashed(&mp->mglist) &&
+ if (!mp->ports && !mp->mglist &&
netif_running(br->dev))
mod_timer(&mp->timer, jiffies);
@@ -294,10 +300,10 @@ out:
spin_unlock(&br->multicast_lock);
}
-static int br_mdb_rehash(struct net_bridge_mdb_htable **mdbp, int max,
+static int br_mdb_rehash(struct net_bridge_mdb_htable __rcu **mdbp, int max,
int elasticity)
{
- struct net_bridge_mdb_htable *old = *mdbp;
+ struct net_bridge_mdb_htable *old = rcu_dereference_protected(*mdbp, 1);
struct net_bridge_mdb_htable *mdb;
int err;
@@ -428,7 +434,6 @@ static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br,
eth = eth_hdr(skb);
memcpy(eth->h_source, br->dev->dev_addr, 6);
- ipv6_eth_mc_map(group, eth->h_dest);
eth->h_proto = htons(ETH_P_IPV6);
skb_put(skb, sizeof(*eth));
@@ -440,8 +445,10 @@ static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br,
ip6h->payload_len = htons(8 + sizeof(*mldq));
ip6h->nexthdr = IPPROTO_HOPOPTS;
ip6h->hop_limit = 1;
- ipv6_addr_set(&ip6h->saddr, 0, 0, 0, 0);
+ ipv6_dev_get_saddr(dev_net(br->dev), br->dev, &ip6h->daddr, 0,
+ &ip6h->saddr);
ipv6_addr_set(&ip6h->daddr, htonl(0xff020000), 0, 0, htonl(1));
+ ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
hopopt = (u8 *)(ip6h + 1);
hopopt[0] = IPPROTO_ICMPV6; /* next hdr */
@@ -520,7 +527,7 @@ static void br_multicast_group_query_expired(unsigned long data)
struct net_bridge *br = mp->br;
spin_lock(&br->multicast_lock);
- if (!netif_running(br->dev) || hlist_unhashed(&mp->mglist) ||
+ if (!netif_running(br->dev) || !mp->mglist ||
mp->queries_sent >= br->multicast_last_member_count)
goto out;
@@ -569,7 +576,7 @@ static struct net_bridge_mdb_entry *br_multicast_get_group(
struct net_bridge *br, struct net_bridge_port *port,
struct br_ip *group, int hash)
{
- struct net_bridge_mdb_htable *mdb = br->mdb;
+ struct net_bridge_mdb_htable *mdb;
struct net_bridge_mdb_entry *mp;
struct hlist_node *p;
unsigned count = 0;
@@ -577,6 +584,7 @@ static struct net_bridge_mdb_entry *br_multicast_get_group(
int elasticity;
int err;
+ mdb = rcu_dereference_protected(br->mdb, 1);
hlist_for_each_entry(mp, p, &mdb->mhash[hash], hlist[mdb->ver]) {
count++;
if (unlikely(br_ip_equal(group, &mp->addr)))
@@ -642,13 +650,16 @@ static struct net_bridge_mdb_entry *br_multicast_new_group(
struct net_bridge *br, struct net_bridge_port *port,
struct br_ip *group)
{
- struct net_bridge_mdb_htable *mdb = br->mdb;
+ struct net_bridge_mdb_htable *mdb;
struct net_bridge_mdb_entry *mp;
int hash;
+ int err;
+ mdb = rcu_dereference_protected(br->mdb, 1);
if (!mdb) {
- if (br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0))
- return NULL;
+ err = br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0);
+ if (err)
+ return ERR_PTR(err);
goto rehash;
}
@@ -660,7 +671,7 @@ static struct net_bridge_mdb_entry *br_multicast_new_group(
case -EAGAIN:
rehash:
- mdb = br->mdb;
+ mdb = rcu_dereference_protected(br->mdb, 1);
hash = br_ip_hash(mdb, group);
break;
@@ -670,7 +681,7 @@ rehash:
mp = kzalloc(sizeof(*mp), GFP_ATOMIC);
if (unlikely(!mp))
- goto out;
+ return ERR_PTR(-ENOMEM);
mp->br = br;
mp->addr = *group;
@@ -692,7 +703,7 @@ static int br_multicast_add_group(struct net_bridge *br,
{
struct net_bridge_mdb_entry *mp;
struct net_bridge_port_group *p;
- struct net_bridge_port_group **pp;
+ struct net_bridge_port_group __rcu **pp;
unsigned long now = jiffies;
int err;
@@ -703,16 +714,18 @@ static int br_multicast_add_group(struct net_bridge *br,
mp = br_multicast_new_group(br, port, group);
err = PTR_ERR(mp);
- if (unlikely(IS_ERR(mp) || !mp))
+ if (IS_ERR(mp))
goto err;
if (!port) {
- hlist_add_head(&mp->mglist, &br->mglist);
+ mp->mglist = true;
mod_timer(&mp->timer, now + br->multicast_membership_interval);
goto out;
}
- for (pp = &mp->ports; (p = *pp); pp = &p->next) {
+ for (pp = &mp->ports;
+ (p = mlock_dereference(*pp, br)) != NULL;
+ pp = &p->next) {
if (p->port == port)
goto found;
if ((unsigned long)p->port < (unsigned long)port)
@@ -767,11 +780,11 @@ static int br_ip6_multicast_add_group(struct net_bridge *br,
{
struct br_ip br_group;
- if (ipv6_is_local_multicast(group))
+ if (!ipv6_is_transient_multicast(group))
return 0;
ipv6_addr_copy(&br_group.u.ip6, group);
- br_group.proto = htons(ETH_P_IP);
+ br_group.proto = htons(ETH_P_IPV6);
return br_multicast_add_group(br, port, &br_group);
}
@@ -1000,18 +1013,19 @@ static int br_ip6_multicast_mld2_report(struct net_bridge *br,
nsrcs = skb_header_pointer(skb,
len + offsetof(struct mld2_grec,
- grec_mca),
+ grec_nsrcs),
sizeof(_nsrcs), &_nsrcs);
if (!nsrcs)
return -EINVAL;
if (!pskb_may_pull(skb,
len + sizeof(*grec) +
- sizeof(struct in6_addr) * (*nsrcs)))
+ sizeof(struct in6_addr) * ntohs(*nsrcs)))
return -EINVAL;
grec = (struct mld2_grec *)(skb->data + len);
- len += sizeof(*grec) + sizeof(struct in6_addr) * (*nsrcs);
+ len += sizeof(*grec) +
+ sizeof(struct in6_addr) * ntohs(*nsrcs);
/* We treat these as MLDv1 reports for now. */
switch (grec->grec_type) {
@@ -1106,7 +1120,7 @@ static int br_ip4_multicast_query(struct net_bridge *br,
struct net_bridge_mdb_entry *mp;
struct igmpv3_query *ih3;
struct net_bridge_port_group *p;
- struct net_bridge_port_group **pp;
+ struct net_bridge_port_group __rcu **pp;
unsigned long max_delay;
unsigned long now = jiffies;
__be32 group;
@@ -1145,23 +1159,25 @@ static int br_ip4_multicast_query(struct net_bridge *br,
if (!group)
goto out;
- mp = br_mdb_ip4_get(br->mdb, group);
+ mp = br_mdb_ip4_get(mlock_dereference(br->mdb, br), group);
if (!mp)
goto out;
max_delay *= br->multicast_last_member_count;
- if (!hlist_unhashed(&mp->mglist) &&
+ if (mp->mglist &&
(timer_pending(&mp->timer) ?
time_after(mp->timer.expires, now + max_delay) :
try_to_del_timer_sync(&mp->timer) >= 0))
mod_timer(&mp->timer, now + max_delay);
- for (pp = &mp->ports; (p = *pp); pp = &p->next) {
+ for (pp = &mp->ports;
+ (p = mlock_dereference(*pp, br)) != NULL;
+ pp = &p->next) {
if (timer_pending(&p->timer) ?
time_after(p->timer.expires, now + max_delay) :
try_to_del_timer_sync(&p->timer) >= 0)
- mod_timer(&mp->timer, now + max_delay);
+ mod_timer(&p->timer, now + max_delay);
}
out:
@@ -1178,7 +1194,8 @@ static int br_ip6_multicast_query(struct net_bridge *br,
struct mld_msg *mld = (struct mld_msg *) icmp6_hdr(skb);
struct net_bridge_mdb_entry *mp;
struct mld2_query *mld2q;
- struct net_bridge_port_group *p, **pp;
+ struct net_bridge_port_group *p;
+ struct net_bridge_port_group __rcu **pp;
unsigned long max_delay;
unsigned long now = jiffies;
struct in6_addr *group = NULL;
@@ -1214,22 +1231,24 @@ static int br_ip6_multicast_query(struct net_bridge *br,
if (!group)
goto out;
- mp = br_mdb_ip6_get(br->mdb, group);
+ mp = br_mdb_ip6_get(mlock_dereference(br->mdb, br), group);
if (!mp)
goto out;
max_delay *= br->multicast_last_member_count;
- if (!hlist_unhashed(&mp->mglist) &&
+ if (mp->mglist &&
(timer_pending(&mp->timer) ?
time_after(mp->timer.expires, now + max_delay) :
try_to_del_timer_sync(&mp->timer) >= 0))
mod_timer(&mp->timer, now + max_delay);
- for (pp = &mp->ports; (p = *pp); pp = &p->next) {
+ for (pp = &mp->ports;
+ (p = mlock_dereference(*pp, br)) != NULL;
+ pp = &p->next) {
if (timer_pending(&p->timer) ?
time_after(p->timer.expires, now + max_delay) :
try_to_del_timer_sync(&p->timer) >= 0)
- mod_timer(&mp->timer, now + max_delay);
+ mod_timer(&p->timer, now + max_delay);
}
out:
@@ -1254,7 +1273,7 @@ static void br_multicast_leave_group(struct net_bridge *br,
timer_pending(&br->multicast_querier_timer))
goto out;
- mdb = br->mdb;
+ mdb = mlock_dereference(br->mdb, br);
mp = br_mdb_ip_get(mdb, group);
if (!mp)
goto out;
@@ -1264,7 +1283,7 @@ static void br_multicast_leave_group(struct net_bridge *br,
br->multicast_last_member_interval;
if (!port) {
- if (!hlist_unhashed(&mp->mglist) &&
+ if (mp->mglist &&
(timer_pending(&mp->timer) ?
time_after(mp->timer.expires, time) :
try_to_del_timer_sync(&mp->timer) >= 0)) {
@@ -1277,7 +1296,9 @@ static void br_multicast_leave_group(struct net_bridge *br,
goto out;
}
- for (p = mp->ports; p; p = p->next) {
+ for (p = mlock_dereference(mp->ports, br);
+ p != NULL;
+ p = mlock_dereference(p->next, br)) {
if (p->port != port)
continue;
@@ -1320,7 +1341,7 @@ static void br_ip6_multicast_leave_group(struct net_bridge *br,
{
struct br_ip br_group;
- if (ipv6_is_local_multicast(group))
+ if (!ipv6_is_transient_multicast(group))
return;
ipv6_addr_copy(&br_group.u.ip6, group);
@@ -1633,7 +1654,7 @@ void br_multicast_stop(struct net_bridge *br)
del_timer_sync(&br->multicast_query_timer);
spin_lock_bh(&br->multicast_lock);
- mdb = br->mdb;
+ mdb = mlock_dereference(br->mdb, br);
if (!mdb)
goto out;
@@ -1737,6 +1758,7 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val)
{
struct net_bridge_port *port;
int err = 0;
+ struct net_bridge_mdb_htable *mdb;
spin_lock(&br->multicast_lock);
if (br->multicast_disabled == !val)
@@ -1749,15 +1771,16 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val)
if (!netif_running(br->dev))
goto unlock;
- if (br->mdb) {
- if (br->mdb->old) {
+ mdb = mlock_dereference(br->mdb, br);
+ if (mdb) {
+ if (mdb->old) {
err = -EEXIST;
rollback:
br->multicast_disabled = !!val;
goto unlock;
}
- err = br_mdb_rehash(&br->mdb, br->mdb->max,
+ err = br_mdb_rehash(&br->mdb, mdb->max,
br->hash_elasticity);
if (err)
goto rollback;
@@ -1782,6 +1805,7 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
{
int err = -ENOENT;
u32 old;
+ struct net_bridge_mdb_htable *mdb;
spin_lock(&br->multicast_lock);
if (!netif_running(br->dev))
@@ -1790,7 +1814,9 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
err = -EINVAL;
if (!is_power_of_2(val))
goto unlock;
- if (br->mdb && val < br->mdb->size)
+
+ mdb = mlock_dereference(br->mdb, br);
+ if (mdb && val < mdb->size)
goto unlock;
err = 0;
@@ -1798,8 +1824,8 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
old = br->hash_max;
br->hash_max = val;
- if (br->mdb) {
- if (br->mdb->old) {
+ if (mdb) {
+ if (mdb->old) {
err = -EEXIST;
rollback:
br->hash_max = old;