diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r-- | net/netlink/af_netlink.c | 61 |
1 files changed, 53 insertions, 8 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 59dc7d14060..2a233ffcf61 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -106,6 +106,7 @@ struct nl_pid_hash { struct netlink_table { struct nl_pid_hash hash; struct hlist_head mc_list; + unsigned long *listeners; unsigned int nl_nonroot; unsigned int groups; struct module *module; @@ -122,7 +123,7 @@ static void netlink_destroy_callback(struct netlink_callback *cb); static DEFINE_RWLOCK(nl_table_lock); static atomic_t nl_table_users = ATOMIC_INIT(0); -static struct notifier_block *netlink_chain; +static ATOMIC_NOTIFIER_HEAD(netlink_chain); static u32 netlink_group_mask(u32 group) { @@ -296,6 +297,24 @@ static inline int nl_pid_hash_dilute(struct nl_pid_hash *hash, int len) static const struct proto_ops netlink_ops; +static void +netlink_update_listeners(struct sock *sk) +{ + struct netlink_table *tbl = &nl_table[sk->sk_protocol]; + struct hlist_node *node; + unsigned long mask; + unsigned int i; + + for (i = 0; i < NLGRPSZ(tbl->groups)/sizeof(unsigned long); i++) { + mask = 0; + sk_for_each_bound(sk, node, &tbl->mc_list) + mask |= nlk_sk(sk)->groups[i]; + tbl->listeners[i] = mask; + } + /* this function is only called with the netlink table "grabbed", which + * makes sure updates are visible before bind or setsockopt return. */ +} + static int netlink_insert(struct sock *sk, u32 pid) { struct nl_pid_hash *hash = &nl_table[sk->sk_protocol].hash; @@ -450,18 +469,21 @@ static int netlink_release(struct socket *sock) .protocol = sk->sk_protocol, .pid = nlk->pid, }; - notifier_call_chain(&netlink_chain, NETLINK_URELEASE, &n); + atomic_notifier_call_chain(&netlink_chain, + NETLINK_URELEASE, &n); } if (nlk->module) module_put(nlk->module); + netlink_table_grab(); if (nlk->flags & NETLINK_KERNEL_SOCKET) { - netlink_table_grab(); + kfree(nl_table[sk->sk_protocol].listeners); nl_table[sk->sk_protocol].module = NULL; nl_table[sk->sk_protocol].registered = 0; - netlink_table_ungrab(); - } + } else if (nlk->subscriptions) + netlink_update_listeners(sk); + netlink_table_ungrab(); kfree(nlk->groups); nlk->groups = NULL; @@ -589,6 +611,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len hweight32(nladdr->nl_groups) - hweight32(nlk->groups[0])); nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; + netlink_update_listeners(sk); netlink_table_ungrab(); return 0; @@ -807,6 +830,17 @@ retry: return netlink_sendskb(sk, skb, ssk->sk_protocol); } +int netlink_has_listeners(struct sock *sk, unsigned int group) +{ + int res = 0; + + BUG_ON(!(nlk_sk(sk)->flags & NETLINK_KERNEL_SOCKET)); + if (group - 1 < nl_table[sk->sk_protocol].groups) + res = test_bit(group - 1, nl_table[sk->sk_protocol].listeners); + return res; +} +EXPORT_SYMBOL_GPL(netlink_has_listeners); + static __inline__ int netlink_broadcast_deliver(struct sock *sk, struct sk_buff *skb) { struct netlink_sock *nlk = nlk_sk(sk); @@ -1011,6 +1045,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, else __clear_bit(val - 1, nlk->groups); netlink_update_subscriptions(sk, subscriptions); + netlink_update_listeners(sk); netlink_table_ungrab(); err = 0; break; @@ -1237,6 +1272,7 @@ netlink_kernel_create(int unit, unsigned int groups, struct socket *sock; struct sock *sk; struct netlink_sock *nlk; + unsigned long *listeners = NULL; if (!nl_table) return NULL; @@ -1250,6 +1286,13 @@ netlink_kernel_create(int unit, unsigned int groups, if (__netlink_create(sock, unit) < 0) goto out_sock_release; + if (groups < 32) + groups = 32; + + listeners = kzalloc(NLGRPSZ(groups), GFP_KERNEL); + if (!listeners) + goto out_sock_release; + sk = sock->sk; sk->sk_data_ready = netlink_data_ready; if (input) @@ -1262,7 +1305,8 @@ netlink_kernel_create(int unit, unsigned int groups, nlk->flags |= NETLINK_KERNEL_SOCKET; netlink_table_grab(); - nl_table[unit].groups = groups < 32 ? 32 : groups; + nl_table[unit].groups = groups; + nl_table[unit].listeners = listeners; nl_table[unit].module = module; nl_table[unit].registered = 1; netlink_table_ungrab(); @@ -1270,6 +1314,7 @@ netlink_kernel_create(int unit, unsigned int groups, return sk; out_sock_release: + kfree(listeners); sock_release(sock); return NULL; } @@ -1651,12 +1696,12 @@ static struct file_operations netlink_seq_fops = { int netlink_register_notifier(struct notifier_block *nb) { - return notifier_chain_register(&netlink_chain, nb); + return atomic_notifier_chain_register(&netlink_chain, nb); } int netlink_unregister_notifier(struct notifier_block *nb) { - return notifier_chain_unregister(&netlink_chain, nb); + return atomic_notifier_chain_unregister(&netlink_chain, nb); } static const struct proto_ops netlink_ops = { |