diff options
Diffstat (limited to 'net/ipv4')
-rw-r--r-- | net/ipv4/cipso_ipv4.c | 1 | ||||
-rw-r--r-- | net/ipv4/devinet.c | 9 | ||||
-rw-r--r-- | net/ipv4/icmp.c | 3 | ||||
-rw-r--r-- | net/ipv4/ip_forward.c | 2 | ||||
-rw-r--r-- | net/ipv4/route.c | 135 | ||||
-rw-r--r-- | net/ipv4/sysctl_net_ipv4.c | 12 | ||||
-rw-r--r-- | net/ipv4/tcp_input.c | 4 | ||||
-rw-r--r-- | net/ipv4/tcp_timer.c | 4 | ||||
-rw-r--r-- | net/ipv4/udp.c | 241 | ||||
-rw-r--r-- | net/ipv4/udp_impl.h | 4 | ||||
-rw-r--r-- | net/ipv4/udplite.c | 14 |
11 files changed, 312 insertions, 117 deletions
diff --git a/net/ipv4/cipso_ipv4.c b/net/ipv4/cipso_ipv4.c index 2e78f6bd977..e52799047a5 100644 --- a/net/ipv4/cipso_ipv4.c +++ b/net/ipv4/cipso_ipv4.c @@ -490,7 +490,6 @@ int cipso_v4_doi_add(struct cipso_v4_doi *doi_def) } atomic_set(&doi_def->refcount, 1); - INIT_RCU_HEAD(&doi_def->rcu); spin_lock(&cipso_v4_doi_list_lock); if (cipso_v4_doi_search(doi_def->doi) != NULL) diff --git a/net/ipv4/devinet.c b/net/ipv4/devinet.c index 56fce3ab6c5..0bff576d291 100644 --- a/net/ipv4/devinet.c +++ b/net/ipv4/devinet.c @@ -112,13 +112,7 @@ static inline void devinet_sysctl_unregister(struct in_device *idev) static struct in_ifaddr *inet_alloc_ifa(void) { - struct in_ifaddr *ifa = kzalloc(sizeof(*ifa), GFP_KERNEL); - - if (ifa) { - INIT_RCU_HEAD(&ifa->rcu_head); - } - - return ifa; + return kzalloc(sizeof(struct in_ifaddr), GFP_KERNEL); } static void inet_rcu_free_ifa(struct rcu_head *head) @@ -161,7 +155,6 @@ static struct in_device *inetdev_init(struct net_device *dev) in_dev = kzalloc(sizeof(*in_dev), GFP_KERNEL); if (!in_dev) goto out; - INIT_RCU_HEAD(&in_dev->rcu_head); memcpy(&in_dev->cnf, dev_net(dev)->ipv4.devconf_dflt, sizeof(in_dev->cnf)); in_dev->cnf.sysctl = NULL; diff --git a/net/ipv4/icmp.c b/net/ipv4/icmp.c index 72b2de76f1c..e9d6ea0b49c 100644 --- a/net/ipv4/icmp.c +++ b/net/ipv4/icmp.c @@ -976,9 +976,10 @@ int icmp_rcv(struct sk_buff *skb) struct net *net = dev_net(rt->u.dst.dev); if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { + struct sec_path *sp = skb_sec_path(skb); int nh; - if (!(skb->sp && skb->sp->xvec[skb->sp->len - 1]->props.flags & + if (!(sp && sp->xvec[sp->len - 1]->props.flags & XFRM_STATE_ICMP)) goto drop; diff --git a/net/ipv4/ip_forward.c b/net/ipv4/ip_forward.c index 450016b89a1..df3fe50bbf0 100644 --- a/net/ipv4/ip_forward.c +++ b/net/ipv4/ip_forward.c @@ -106,7 +106,7 @@ int ip_forward(struct sk_buff *skb) * We now generate an ICMP HOST REDIRECT giving the route * we calculated. */ - if (rt->rt_flags&RTCF_DOREDIRECT && !opt->srr && !skb->sp) + if (rt->rt_flags&RTCF_DOREDIRECT && !opt->srr && !skb_sec_path(skb)) ip_rt_send_redirect(skb); skb->priority = rt_tos2priority(iph->tos); diff --git a/net/ipv4/route.c b/net/ipv4/route.c index 2ea6dcc3e2c..e59b4dcf677 100644 --- a/net/ipv4/route.c +++ b/net/ipv4/route.c @@ -129,6 +129,7 @@ static int ip_rt_mtu_expires __read_mostly = 10 * 60 * HZ; static int ip_rt_min_pmtu __read_mostly = 512 + 20 + 20; static int ip_rt_min_advmss __read_mostly = 256; static int ip_rt_secret_interval __read_mostly = 10 * 60 * HZ; +static int rt_chain_length_max __read_mostly = 20; static void rt_worker_func(struct work_struct *work); static DECLARE_DELAYED_WORK(expires_work, rt_worker_func); @@ -145,6 +146,7 @@ static struct dst_entry *ipv4_negative_advice(struct dst_entry *dst); static void ipv4_link_failure(struct sk_buff *skb); static void ip_rt_update_pmtu(struct dst_entry *dst, u32 mtu); static int rt_garbage_collect(struct dst_ops *ops); +static void rt_emergency_hash_rebuild(struct net *net); static struct dst_ops ipv4_dst_ops = { @@ -201,6 +203,7 @@ const __u8 ip_tos2prio[16] = { struct rt_hash_bucket { struct rtable *chain; }; + #if defined(CONFIG_SMP) || defined(CONFIG_DEBUG_SPINLOCK) || \ defined(CONFIG_PROVE_LOCKING) /* @@ -674,6 +677,20 @@ static inline u32 rt_score(struct rtable *rt) return score; } +static inline bool rt_caching(const struct net *net) +{ + return net->ipv4.current_rt_cache_rebuild_count <= + net->ipv4.sysctl_rt_cache_rebuild_count; +} + +static inline bool compare_hash_inputs(const struct flowi *fl1, + const struct flowi *fl2) +{ + return (__force u32)(((fl1->nl_u.ip4_u.daddr ^ fl2->nl_u.ip4_u.daddr) | + (fl1->nl_u.ip4_u.saddr ^ fl2->nl_u.ip4_u.saddr) | + (fl1->iif ^ fl2->iif)) == 0); +} + static inline int compare_keys(struct flowi *fl1, struct flowi *fl2) { return ((__force u32)((fl1->nl_u.ip4_u.daddr ^ fl2->nl_u.ip4_u.daddr) | @@ -753,11 +770,24 @@ static void rt_do_flush(int process_context) } } +/* + * While freeing expired entries, we compute average chain length + * and standard deviation, using fixed-point arithmetic. + * This to have an estimation of rt_chain_length_max + * rt_chain_length_max = max(elasticity, AVG + 4*SD) + * We use 3 bits for frational part, and 29 (or 61) for magnitude. + */ + +#define FRACT_BITS 3 +#define ONE (1UL << FRACT_BITS) + static void rt_check_expire(void) { static unsigned int rover; unsigned int i = rover, goal; struct rtable *rth, **rthp; + unsigned long length = 0, samples = 0; + unsigned long sum = 0, sum2 = 0; u64 mult; mult = ((u64)ip_rt_gc_interval) << rt_hash_log; @@ -766,6 +796,7 @@ static void rt_check_expire(void) goal = (unsigned int)mult; if (goal > rt_hash_mask) goal = rt_hash_mask + 1; + length = 0; for (; goal > 0; goal--) { unsigned long tmo = ip_rt_gc_timeout; @@ -775,6 +806,8 @@ static void rt_check_expire(void) if (need_resched()) cond_resched(); + samples++; + if (*rthp == NULL) continue; spin_lock_bh(rt_hash_lock_addr(i)); @@ -789,11 +822,29 @@ static void rt_check_expire(void) if (time_before_eq(jiffies, rth->u.dst.expires)) { tmo >>= 1; rthp = &rth->u.dst.rt_next; + /* + * Only bump our length if the hash + * inputs on entries n and n+1 are not + * the same, we only count entries on + * a chain with equal hash inputs once + * so that entries for different QOS + * levels, and other non-hash input + * attributes don't unfairly skew + * the length computation + */ + if ((*rthp == NULL) || + !compare_hash_inputs(&(*rthp)->fl, + &rth->fl)) + length += ONE; continue; } } else if (!rt_may_expire(rth, tmo, ip_rt_gc_timeout)) { tmo >>= 1; rthp = &rth->u.dst.rt_next; + if ((*rthp == NULL) || + !compare_hash_inputs(&(*rthp)->fl, + &rth->fl)) + length += ONE; continue; } @@ -802,6 +853,15 @@ static void rt_check_expire(void) rt_free(rth); } spin_unlock_bh(rt_hash_lock_addr(i)); + sum += length; + sum2 += length*length; + } + if (samples) { + unsigned long avg = sum / samples; + unsigned long sd = int_sqrt(sum2 / samples - avg*avg); + rt_chain_length_max = max_t(unsigned long, + ip_rt_gc_elasticity, + (avg + 4*sd) >> FRACT_BITS); } rover = i; } @@ -851,6 +911,26 @@ static void rt_secret_rebuild(unsigned long __net) mod_timer(&net->ipv4.rt_secret_timer, jiffies + ip_rt_secret_interval); } +static void rt_secret_rebuild_oneshot(struct net *net) +{ + del_timer_sync(&net->ipv4.rt_secret_timer); + rt_cache_invalidate(net); + if (ip_rt_secret_interval) { + net->ipv4.rt_secret_timer.expires += ip_rt_secret_interval; + add_timer(&net->ipv4.rt_secret_timer); + } +} + +static void rt_emergency_hash_rebuild(struct net *net) +{ + if (net_ratelimit()) { + printk(KERN_WARNING "Route hash chain too long!\n"); + printk(KERN_WARNING "Adjust your secret_interval!\n"); + } + + rt_secret_rebuild_oneshot(net); +} + /* Short description of GC goals. @@ -989,6 +1069,7 @@ out: return 0; static int rt_intern_hash(unsigned hash, struct rtable *rt, struct rtable **rp) { struct rtable *rth, **rthp; + struct rtable *rthi; unsigned long now; struct rtable *cand, **candp; u32 min_score; @@ -1002,7 +1083,13 @@ restart: candp = NULL; now = jiffies; + if (!rt_caching(dev_net(rt->u.dst.dev))) { + rt_drop(rt); + return 0; + } + rthp = &rt_hash_table[hash].chain; + rthi = NULL; spin_lock_bh(rt_hash_lock_addr(hash)); while ((rth = *rthp) != NULL) { @@ -1048,6 +1135,17 @@ restart: chain_length++; rthp = &rth->u.dst.rt_next; + + /* + * check to see if the next entry in the chain + * contains the same hash input values as rt. If it does + * This is where we will insert into the list, instead of + * at the head. This groups entries that differ by aspects not + * relvant to the hash function together, which we use to adjust + * our chain length + */ + if (*rthp && compare_hash_inputs(&(*rthp)->fl, &rt->fl)) + rthi = rth; } if (cand) { @@ -1061,6 +1159,16 @@ restart: *candp = cand->u.dst.rt_next; rt_free(cand); } + } else { + if (chain_length > rt_chain_length_max) { + struct net *net = dev_net(rt->u.dst.dev); + int num = ++net->ipv4.current_rt_cache_rebuild_count; + if (!rt_caching(dev_net(rt->u.dst.dev))) { + printk(KERN_WARNING "%s: %d rebuilds is over limit, route caching disabled\n", + rt->u.dst.dev->name, num); + } + rt_emergency_hash_rebuild(dev_net(rt->u.dst.dev)); + } } /* Try to bind route to arp only if it is output @@ -1098,7 +1206,11 @@ restart: } } - rt->u.dst.rt_next = rt_hash_table[hash].chain; + if (rthi) + rt->u.dst.rt_next = rthi->u.dst.rt_next; + else + rt->u.dst.rt_next = rt_hash_table[hash].chain; + #if RT_CACHE_DEBUG >= 2 if (rt->u.dst.rt_next) { struct rtable *trt; @@ -1114,7 +1226,11 @@ restart: * previous writes to rt are comitted to memory * before making rt visible to other CPUS. */ - rcu_assign_pointer(rt_hash_table[hash].chain, rt); + if (rthi) + rcu_assign_pointer(rthi->u.dst.rt_next, rt); + else + rcu_assign_pointer(rt_hash_table[hash].chain, rt); + spin_unlock_bh(rt_hash_lock_addr(hash)); *rp = rt; return 0; @@ -1217,6 +1333,9 @@ void ip_rt_redirect(__be32 old_gw, __be32 daddr, __be32 new_gw, || ipv4_is_zeronet(new_gw)) goto reject_redirect; + if (!rt_caching(net)) + goto reject_redirect; + if (!IN_DEV_SHARED_MEDIA(in_dev)) { if (!inet_addr_onlink(in_dev, new_gw, old_gw)) goto reject_redirect; @@ -1267,7 +1386,6 @@ void ip_rt_redirect(__be32 old_gw, __be32 daddr, __be32 new_gw, /* Copy all the information. */ *rt = *rth; - INIT_RCU_HEAD(&rt->u.dst.rcu_head); rt->u.dst.__use = 1; atomic_set(&rt->u.dst.__refcnt, 1); rt->u.dst.child = NULL; @@ -1280,7 +1398,9 @@ void ip_rt_redirect(__be32 old_gw, __be32 daddr, __be32 new_gw, rt->u.dst.path = &rt->u.dst; rt->u.dst.neighbour = NULL; rt->u.dst.hh = NULL; +#ifdef CONFIG_XFRM rt->u.dst.xfrm = NULL; +#endif rt->rt_genid = rt_genid(net); rt->rt_flags |= RTCF_REDIRECTED; @@ -2130,6 +2250,10 @@ int ip_route_input(struct sk_buff *skb, __be32 daddr, __be32 saddr, struct net *net; net = dev_net(dev); + + if (!rt_caching(net)) + goto skip_cache; + tos &= IPTOS_RT_MASK; hash = rt_hash(daddr, saddr, iif, rt_genid(net)); @@ -2154,6 +2278,7 @@ int ip_route_input(struct sk_buff *skb, __be32 daddr, __be32 saddr, } rcu_read_unlock(); +skip_cache: /* Multicast recognition logic is moved from route cache to here. The problem was that too many Ethernet cards have broken/missing hardware multicast filters :-( As result the host on multicasting @@ -2539,6 +2664,9 @@ int __ip_route_output_key(struct net *net, struct rtable **rp, unsigned hash; struct rtable *rth; + if (!rt_caching(net)) + goto slow_output; + hash = rt_hash(flp->fl4_dst, flp->fl4_src, flp->oif, rt_genid(net)); rcu_read_lock_bh(); @@ -2563,6 +2691,7 @@ int __ip_route_output_key(struct net *net, struct rtable **rp, } rcu_read_unlock_bh(); +slow_output: return ip_route_output_slow(net, rp, flp); } diff --git a/net/ipv4/sysctl_net_ipv4.c b/net/ipv4/sysctl_net_ipv4.c index 1bb10df8ce7..0cc8d31f9ac 100644 --- a/net/ipv4/sysctl_net_ipv4.c +++ b/net/ipv4/sysctl_net_ipv4.c @@ -795,6 +795,14 @@ static struct ctl_table ipv4_net_table[] = { .mode = 0644, .proc_handler = &proc_dointvec }, + { + .ctl_name = CTL_UNNUMBERED, + .procname = "rt_cache_rebuild_count", + .data = &init_net.ipv4.sysctl_rt_cache_rebuild_count, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = &proc_dointvec + }, { } }; @@ -827,8 +835,12 @@ static __net_init int ipv4_sysctl_init_net(struct net *net) &net->ipv4.sysctl_icmp_ratelimit; table[5].data = &net->ipv4.sysctl_icmp_ratemask; + table[6].data = + &net->ipv4.sysctl_rt_cache_rebuild_count; } + net->ipv4.sysctl_rt_cache_rebuild_count = 4; + net->ipv4.ipv4_hdr = register_net_sysctl_table(net, net_ipv4_ctl_path, table); if (net->ipv4.ipv4_hdr == NULL) diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index d77c0d29e23..04909e4b3c4 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -2346,9 +2346,9 @@ static void DBGUNDO(struct sock *sk, const char *msg) #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) else if (sk->sk_family == AF_INET6) { struct ipv6_pinfo *np = inet6_sk(sk); - printk(KERN_DEBUG "Undo %s " NIP6_FMT "/%u c%u l%u ss%u/%u p%u\n", + printk(KERN_DEBUG "Undo %s %pI6/%u c%u l%u ss%u/%u p%u\n", msg, - NIP6(np->daddr), ntohs(inet->dport), + &np->daddr, ntohs(inet->dport), tp->snd_cwnd, tcp_left_out(tp), tp->snd_ssthresh, tp->prior_ssthresh, tp->packets_out); diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c index 6b6dff1164b..979c9d604eb 100644 --- a/net/ipv4/tcp_timer.c +++ b/net/ipv4/tcp_timer.c @@ -306,8 +306,8 @@ static void tcp_retransmit_timer(struct sock *sk) #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) else if (sk->sk_family == AF_INET6) { struct ipv6_pinfo *np = inet6_sk(sk); - LIMIT_NETDEBUG(KERN_DEBUG "TCP: Treason uncloaked! Peer " NIP6_FMT ":%u/%u shrinks window %u:%u. Repaired.\n", - NIP6(np->daddr), ntohs(inet->dport), + LIMIT_NETDEBUG(KERN_DEBUG "TCP: Treason uncloaked! Peer %pI6:%u/%u shrinks window %u:%u. Repaired.\n", + &np->daddr, ntohs(inet->dport), inet->num, tp->snd_una, tp->snd_nxt); } #endif diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 2095abc3cab..f760b86e753 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -81,6 +81,8 @@ #include <asm/uaccess.h> #include <asm/ioctls.h> #include <linux/bootmem.h> +#include <linux/highmem.h> +#include <linux/swap.h> #include <linux/types.h> #include <linux/fcntl.h> #include <linux/module.h> @@ -104,12 +106,8 @@ #include <net/xfrm.h> #include "udp_impl.h" -/* - * Snmp MIB for the UDP layer - */ - -struct hlist_head udp_hash[UDP_HTABLE_SIZE]; -DEFINE_RWLOCK(udp_hash_lock); +struct udp_table udp_table; +EXPORT_SYMBOL(udp_table); int sysctl_udp_mem[3] __read_mostly; int sysctl_udp_rmem_min __read_mostly; @@ -123,7 +121,7 @@ atomic_t udp_memory_allocated; EXPORT_SYMBOL(udp_memory_allocated); static int udp_lib_lport_inuse(struct net *net, __u16 num, - const struct hlist_head udptable[], + const struct udp_hslot *hslot, struct sock *sk, int (*saddr_comp)(const struct sock *sk1, const struct sock *sk2)) @@ -131,7 +129,7 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num, struct sock *sk2; struct hlist_node *node; - sk_for_each(sk2, node, &udptable[udp_hashfn(net, num)]) + sk_for_each(sk2, node, &hslot->head) if (net_eq(sock_net(sk2), net) && sk2 != sk && sk2->sk_hash == num && @@ -154,12 +152,11 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, int (*saddr_comp)(const struct sock *sk1, const struct sock *sk2 ) ) { - struct hlist_head *udptable = sk->sk_prot->h.udp_hash; + struct udp_hslot *hslot; + struct udp_table *udptable = sk->sk_prot->h.udp_table; int error = 1; struct net *net = sock_net(sk); - write_lock_bh(&udp_hash_lock); - if (!snum) { int low, high, remaining; unsigned rand; @@ -171,26 +168,34 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, rand = net_random(); snum = first = rand % remaining + low; rand |= 1; - while (udp_lib_lport_inuse(net, snum, udptable, sk, - saddr_comp)) { + for (;;) { + hslot = &udptable->hash[udp_hashfn(net, snum)]; + spin_lock_bh(&hslot->lock); + if (!udp_lib_lport_inuse(net, snum, hslot, sk, saddr_comp)) + break; + spin_unlock_bh(&hslot->lock); do { snum = snum + rand; } while (snum < low || snum > high); if (snum == first) goto fail; } - } else if (udp_lib_lport_inuse(net, snum, udptable, sk, saddr_comp)) - goto fail; - + } else { + hslot = &udptable->hash[udp_hashfn(net, snum)]; + spin_lock_bh(&hslot->lock); + if (udp_lib_lport_inuse(net, snum, hslot, sk, saddr_comp)) + goto fail_unlock; + } inet_sk(sk)->num = snum; sk->sk_hash = snum; if (sk_unhashed(sk)) { - sk_add_node(sk, &udptable[udp_hashfn(net, snum)]); + sk_add_node_rcu(sk, &hslot->head); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); } error = 0; +fail_unlock: + spin_unlock_bh(&hslot->lock); fail: - write_unlock_bh(&udp_hash_lock); return error; } @@ -208,63 +213,89 @@ int udp_v4_get_port(struct sock *sk, unsigned short snum) return udp_lib_get_port(sk, snum, ipv4_rcv_saddr_equal); } +static inline int compute_score(struct sock *sk, struct net *net, __be32 saddr, + unsigned short hnum, + __be16 sport, __be32 daddr, __be16 dport, int dif) +{ + int score = -1; + + if (net_eq(sock_net(sk), net) && sk->sk_hash == hnum && + !ipv6_only_sock(sk)) { + struct inet_sock *inet = inet_sk(sk); + + score = (sk->sk_family == PF_INET ? 1 : 0); + if (inet->rcv_saddr) { + if (inet->rcv_saddr != daddr) + return -1; + score += 2; + } + if (inet->daddr) { + if (inet->daddr != saddr) + return -1; + score += 2; + } + if (inet->dport) { + if (inet->dport != sport) + return -1; + score += 2; + } + if (sk->sk_bound_dev_if) { + if (sk->sk_bound_dev_if != dif) + return -1; + score += 2; + } + } + return score; +} + /* UDP is nearly always wildcards out the wazoo, it makes no sense to try * harder than this. -DaveM */ static struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, __be32 daddr, __be16 dport, - int dif, struct hlist_head udptable[]) + int dif, struct udp_table *udptable) { - struct sock *sk, *result = NULL; - struct hlist_node *node; + struct sock *sk, *result; + struct hlist_node *node, *next; unsigned short hnum = ntohs(dport); - int badness = -1; - - read_lock(&udp_hash_lock); - sk_for_each(sk, node, &udptable[udp_hashfn(net, hnum)]) { - struct inet_sock *inet = inet_sk(sk); - - if (net_eq(sock_net(sk), net) && sk->sk_hash == hnum && - !ipv6_only_sock(sk)) { - int score = (sk->sk_family == PF_INET ? 1 : 0); - if (inet->rcv_saddr) { - if (inet->rcv_saddr != daddr) - continue; - score+=2; - } - if (inet->daddr) { - if (inet->daddr != saddr) - continue; - score+=2; - } - if (inet->dport) { - if (inet->dport != sport) - continue; - score+=2; - } - if (sk->sk_bound_dev_if) { - if (sk->sk_bound_dev_if != dif) - continue; - score+=2; - } - if (score == 9) { - result = sk; - break; - } else if (score > badness) { - result = sk; - badness = score; - } + unsigned int hash = udp_hashfn(net, hnum); + struct udp_hslot *hslot = &udptable->hash[hash]; + int score, badness; + + rcu_read_lock(); +begin: + result = NULL; + badness = -1; + sk_for_each_rcu_safenext(sk, node, &hslot->head, next) { + /* + * lockless reader, and SLAB_DESTROY_BY_RCU items: + * We must check this item was not moved to another chain + */ + if (udp_hashfn(net, sk->sk_hash) != hash) + goto begin; + score = compute_score(sk, net, saddr, hnum, sport, + daddr, dport, dif); + if (score > badness) { + result = sk; + badness = score; } } - if (result) - sock_hold(result); - read_unlock(&udp_hash_lock); + if (result) { + if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt))) + result = NULL; + else if (unlikely(compute_score(result, net, saddr, hnum, sport, + daddr, dport, dif) < badness)) { + sock_put(result); + goto begin; + } + } + rcu_read_unlock(); return result; } static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, __be16 sport, __be16 dport, - struct hlist_head udptable[]) + struct udp_table *udptable) { struct sock *sk; const struct iphdr *iph = ip_hdr(skb); @@ -280,7 +311,7 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, __be32 daddr, __be16 dport, int dif) { - return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif, udp_hash); + return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif, &udp_table); } EXPORT_SYMBOL_GPL(udp4_lib_lookup); @@ -323,7 +354,7 @@ found: * to find the appropriate port. */ -void __udp4_lib_err(struct sk_buff *skb, u32 info, struct hlist_head udptable[]) +void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) { struct inet_sock *inet; struct iphdr *iph = (struct iphdr*)skb->data; @@ -392,7 +423,7 @@ out: void udp_err(struct sk_buff *skb, u32 info) { - __udp4_lib_err(skb, info, udp_hash); + __udp4_lib_err(skb, info, &udp_table); } /* @@ -933,6 +964,21 @@ int udp_disconnect(struct sock *sk, int flags) return 0; } +void udp_lib_unhash(struct sock *sk) +{ + struct udp_table *udptable = sk->sk_prot->h.udp_table; + unsigned int hash = udp_hashfn(sock_net(sk), sk->sk_hash); + struct udp_hslot *hslot = &udptable->hash[hash]; + + spin_lock_bh(&hslot->lock); + if (sk_del_node_init_rcu(sk)) { + inet_sk(sk)->num = 0; + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + } + spin_unlock_bh(&hslot->lock); +} +EXPORT_SYMBOL(udp_lib_unhash); + static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) { int is_udplite = IS_UDPLITE(sk); @@ -1071,13 +1117,14 @@ drop: static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, struct udphdr *uh, __be32 saddr, __be32 daddr, - struct hlist_head udptable[]) + struct udp_table *udptable) { struct sock *sk; + struct udp_hslot *hslot = &udptable->hash[udp_hashfn(net, ntohs(uh->dest))]; int dif; - read_lock(&udp_hash_lock); - sk = sk_head(&udptable[udp_hashfn(net, ntohs(uh->dest))]); + spin_lock(&hslot->lock); + sk = sk_head(&hslot->head); dif = skb->dev->ifindex; sk = udp_v4_mcast_next(sk, uh->dest, daddr, uh->source, saddr, dif); if (sk) { @@ -1102,7 +1149,7 @@ static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, } while (sknext); } else kfree_skb(skb); - read_unlock(&udp_hash_lock); + spin_unlock(&hslot->lock); return 0; } @@ -1148,7 +1195,7 @@ static inline int udp4_csum_init(struct sk_buff *skb, struct udphdr *uh, * All we need to do is get the socket, and then do a checksum. */ -int __udp4_lib_rcv(struct sk_buff *skb, struct hlist_head udptable[], +int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, int proto) { struct sock *sk; @@ -1246,7 +1293,7 @@ drop: int udp_rcv(struct sk_buff *skb) { - return __udp4_lib_rcv(skb, udp_hash, IPPROTO_UDP); + return __udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP); } void udp_destroy_sock(struct sock *sk) @@ -1488,7 +1535,8 @@ struct proto udp_prot = { .sysctl_wmem = &sysctl_udp_wmem_min, .sysctl_rmem = &sysctl_udp_rmem_min, .obj_size = sizeof(struct udp_sock), - .h.udp_hash = udp_hash, + .slab_flags = SLAB_DESTROY_BY_RCU, + .h.udp_table = &udp_table, #ifdef CONFIG_COMPAT .compat_setsockopt = compat_udp_setsockopt, .compat_getsockopt = compat_udp_getsockopt, @@ -1498,20 +1546,23 @@ struct proto udp_prot = { /* ------------------------------------------------------------------------ */ #ifdef CONFIG_PROC_FS -static struct sock *udp_get_first(struct seq_file *seq) +static struct sock *udp_get_first(struct seq_file *seq, int start) { struct sock *sk; struct udp_iter_state *state = seq->private; struct net *net = seq_file_net(seq); - for (state->bucket = 0; state->bucket < UDP_HTABLE_SIZE; ++state->bucket) { + for (state->bucket = start; state->bucket < UDP_HTABLE_SIZE; ++state->bucket) { struct hlist_node *node; - sk_for_each(sk, node, state->hashtable + state->bucket) { + struct udp_hslot *hslot = &state->udp_table->hash[state->bucket]; + spin_lock_bh(&hslot->lock); + sk_for_each(sk, node, &hslot->head) { if (!net_eq(sock_net(sk), net)) continue; if (sk->sk_family == state->family) goto found; } + spin_unlock_bh(&hslot->lock); } sk = NULL; found: @@ -1525,20 +1576,18 @@ static struct sock *udp_get_next(struct seq_file *seq, struct sock *sk) do { sk = sk_next(sk); -try_again: - ; } while (sk && (!net_eq(sock_net(sk), net) || sk->sk_family != state->family)); - if (!sk && ++state->bucket < UDP_HTABLE_SIZE) { - sk = sk_head(state->hashtable + state->bucket); - goto try_again; + if (!sk) { + spin_unlock_bh(&state->udp_table->hash[state->bucket].lock); + return udp_get_first(seq, state->bucket + 1); } return sk; } static struct sock *udp_get_idx(struct seq_file *seq, loff_t pos) { - struct sock *sk = udp_get_first(seq); + struct sock *sk = udp_get_first(seq, 0); if (sk) while (pos && (sk = udp_get_next(seq, sk)) != NULL) @@ -1547,9 +1596,7 @@ static struct sock *udp_get_idx(struct seq_file *seq, loff_t pos) } static void *udp_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(udp_hash_lock) { - read_lock(&udp_hash_lock); return *pos ? udp_get_idx(seq, *pos-1) : SEQ_START_TOKEN; } @@ -1567,9 +1614,11 @@ static void *udp_seq_next(struct seq_file *seq, void *v, loff_t *pos) } static void udp_seq_stop(struct seq_file *seq, void *v) - __releases(udp_hash_lock) { - read_unlock(&udp_hash_lock); + struct udp_iter_state *state = seq->private; + + if (state->bucket < UDP_HTABLE_SIZE) + spin_unlock_bh(&state->udp_table->hash[state->bucket].lock); } static int udp_seq_open(struct inode *inode, struct file *file) @@ -1585,7 +1634,7 @@ static int udp_seq_open(struct inode *inode, struct file *file) s = ((struct seq_file *)file->private_data)->private; s->family = afinfo->family; - s->hashtable = afinfo->hashtable; + s->udp_table = afinfo->udp_table; return err; } @@ -1657,7 +1706,7 @@ int udp4_seq_show(struct seq_file *seq, void *v) static struct udp_seq_afinfo udp4_seq_afinfo = { .name = "udp", .family = AF_INET, - .hashtable = udp_hash, + .udp_table = &udp_table, .seq_fops = { .owner = THIS_MODULE, }, @@ -1692,16 +1741,28 @@ void udp4_proc_exit(void) } #endif /* CONFIG_PROC_FS */ +void __init udp_table_init(struct udp_table *table) +{ + int i; + + for (i = 0; i < UDP_HTABLE_SIZE; i++) { + INIT_HLIST_HEAD(&table->hash[i].head); + spin_lock_init(&table->hash[i].lock); + } +} + void __init udp_init(void) { - unsigned long limit; + unsigned long nr_pages, limit; + udp_table_init(&udp_table); /* Set the pressure threshold up by the same strategy of TCP. It is a * fraction of global memory that is up to 1/2 at 256 MB, decreasing * toward zero with the amount of memory, with a floor of 128 pages. */ - limit = min(nr_all_pages, 1UL<<(28-PAGE_SHIFT)) >> (20-PAGE_SHIFT); - limit = (limit * (nr_all_pages >> (20-PAGE_SHIFT))) >> (PAGE_SHIFT-11); + nr_pages = totalram_pages - totalhigh_pages; + limit = min(nr_pages, 1UL<<(28-PAGE_SHIFT)) >> (20-PAGE_SHIFT); + limit = (limit * (nr_pages >> (20-PAGE_SHIFT))) >> (PAGE_SHIFT-11); limit = max(limit, 128UL); sysctl_udp_mem[0] = limit / 4 * 3; sysctl_udp_mem[1] = limit; @@ -1712,8 +1773,6 @@ void __init udp_init(void) } EXPORT_SYMBOL(udp_disconnect); -EXPORT_SYMBOL(udp_hash); -EXPORT_SYMBOL(udp_hash_lock); EXPORT_SYMBOL(udp_ioctl); EXPORT_SYMBOL(udp_prot); EXPORT_SYMBOL(udp_sendmsg); diff --git a/net/ipv4/udp_impl.h b/net/ipv4/udp_impl.h index 2e9bad2fa1b..9f4a6165f72 100644 --- a/net/ipv4/udp_impl.h +++ b/net/ipv4/udp_impl.h @@ -5,8 +5,8 @@ #include <net/protocol.h> #include <net/inet_common.h> -extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int ); -extern void __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []); +extern int __udp4_lib_rcv(struct sk_buff *, struct udp_table *, int ); +extern void __udp4_lib_err(struct sk_buff *, u32, struct udp_table *); extern int udp_v4_get_port(struct sock *sk, unsigned short snum); diff --git a/net/ipv4/udplite.c b/net/ipv4/udplite.c index 3c807964da9..c784891cb7e 100644 --- a/net/ipv4/udplite.c +++ b/net/ipv4/udplite.c @@ -12,16 +12,17 @@ */ #include "udp_impl.h" -struct hlist_head udplite_hash[UDP_HTABLE_SIZE]; +struct udp_table udplite_table; +EXPORT_SYMBOL(udplite_table); static int udplite_rcv(struct sk_buff *skb) { - return __udp4_lib_rcv(skb, udplite_hash, IPPROTO_UDPLITE); + return __udp4_lib_rcv(skb, &udplite_table, IPPROTO_UDPLITE); } static void udplite_err(struct sk_buff *skb, u32 info) { - __udp4_lib_err(skb, info, udplite_hash); + __udp4_lib_err(skb, info, &udplite_table); } static struct net_protocol udplite_protocol = { @@ -50,7 +51,8 @@ struct proto udplite_prot = { .unhash = udp_lib_unhash, .get_port = udp_v4_get_port, .obj_size = sizeof(struct udp_sock), - .h.udp_hash = udplite_hash, + .slab_flags = SLAB_DESTROY_BY_RCU, + .h.udp_table = &udplite_table, #ifdef CONFIG_COMPAT .compat_setsockopt = compat_udp_setsockopt, .compat_getsockopt = compat_udp_getsockopt, @@ -71,7 +73,7 @@ static struct inet_protosw udplite4_protosw = { static struct udp_seq_afinfo udplite4_seq_afinfo = { .name = "udplite", .family = AF_INET, - .hashtable = udplite_hash, + .udp_table = &udplite_table, .seq_fops = { .owner = THIS_MODULE, }, @@ -108,6 +110,7 @@ static inline int udplite4_proc_init(void) void __init udplite4_register(void) { + udp_table_init(&udplite_table); if (proto_register(&udplite_prot, 1)) goto out_register_err; @@ -126,5 +129,4 @@ out_register_err: printk(KERN_CRIT "%s: Cannot add UDP-Lite protocol.\n", __func__); } -EXPORT_SYMBOL(udplite_hash); EXPORT_SYMBOL(udplite_prot); |