diff options
Diffstat (limited to 'net/ipv4/udp.c')
-rw-r--r-- | net/ipv4/udp.c | 123 |
1 files changed, 52 insertions, 71 deletions
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 57e26fa6618..eacf4cfef14 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -108,9 +108,6 @@ * Snmp MIB for the UDP layer */ -DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly; -EXPORT_SYMBOL(udp_stats_in6); - struct hlist_head udp_hash[UDP_HTABLE_SIZE]; DEFINE_RWLOCK(udp_hash_lock); @@ -125,14 +122,23 @@ EXPORT_SYMBOL(sysctl_udp_wmem_min); atomic_t udp_memory_allocated; EXPORT_SYMBOL(udp_memory_allocated); -static inline int __udp_lib_lport_inuse(struct net *net, __u16 num, - const struct hlist_head udptable[]) +static int udp_lib_lport_inuse(struct net *net, __u16 num, + const struct hlist_head udptable[], + struct sock *sk, + int (*saddr_comp)(const struct sock *sk1, + const struct sock *sk2)) { - struct sock *sk; + struct sock *sk2; struct hlist_node *node; - sk_for_each(sk, node, &udptable[udp_hashfn(net, num)]) - if (net_eq(sock_net(sk), net) && sk->sk_hash == num) + sk_for_each(sk2, node, &udptable[udp_hashfn(net, num)]) + if (net_eq(sock_net(sk2), net) && + sk2 != sk && + sk2->sk_hash == num && + (!sk2->sk_reuse || !sk->sk_reuse) && + (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if + || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && + (*saddr_comp)(sk, sk2)) return 1; return 0; } @@ -149,83 +155,37 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum, const struct sock *sk2 ) ) { struct hlist_head *udptable = sk->sk_prot->h.udp_hash; - struct hlist_node *node; - struct hlist_head *head; - struct sock *sk2; int error = 1; struct net *net = sock_net(sk); write_lock_bh(&udp_hash_lock); if (!snum) { - int i, low, high, remaining; - unsigned rover, best, best_size_so_far; + int low, high, remaining; + unsigned rand; + unsigned short first; inet_get_local_port_range(&low, &high); remaining = (high - low) + 1; - best_size_so_far = UINT_MAX; - best = rover = net_random() % remaining + low; - - /* 1st pass: look for empty (or shortest) hash chain */ - for (i = 0; i < UDP_HTABLE_SIZE; i++) { - int size = 0; - - head = &udptable[udp_hashfn(net, rover)]; - if (hlist_empty(head)) - goto gotit; - - sk_for_each(sk2, node, head) { - if (++size >= best_size_so_far) - goto next; - } - best_size_so_far = size; - best = rover; - next: - /* fold back if end of range */ - if (++rover > high) - rover = low + ((rover - low) - & (UDP_HTABLE_SIZE - 1)); - - - } - - /* 2nd pass: find hole in shortest hash chain */ - rover = best; - for (i = 0; i < (1 << 16) / UDP_HTABLE_SIZE; i++) { - if (! __udp_lib_lport_inuse(net, rover, udptable)) - goto gotit; - rover += UDP_HTABLE_SIZE; - if (rover > high) - rover = low + ((rover - low) - & (UDP_HTABLE_SIZE - 1)); + rand = net_random(); + snum = first = rand % remaining + low; + rand |= 1; + while (udp_lib_lport_inuse(net, snum, udptable, sk, + saddr_comp)) { + do { + snum = snum + rand; + } while (snum < low || snum > high); + if (snum == first) + goto fail; } - - - /* All ports in use! */ + } else if (udp_lib_lport_inuse(net, snum, udptable, sk, saddr_comp)) goto fail; -gotit: - snum = rover; - } else { - head = &udptable[udp_hashfn(net, snum)]; - - sk_for_each(sk2, node, head) - if (sk2->sk_hash == snum && - sk2 != sk && - net_eq(sock_net(sk2), net) && - (!sk2->sk_reuse || !sk->sk_reuse) && - (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if - || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && - (*saddr_comp)(sk, sk2) ) - goto fail; - } - inet_sk(sk)->num = snum; sk->sk_hash = snum; if (sk_unhashed(sk)) { - head = &udptable[udp_hashfn(net, snum)]; - sk_add_node(sk, head); + sk_add_node(sk, &udptable[udp_hashfn(net, snum)]); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); } error = 0; @@ -302,6 +262,28 @@ static struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, return result; } +static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, + __be16 sport, __be16 dport, + struct hlist_head udptable[]) +{ + struct sock *sk; + const struct iphdr *iph = ip_hdr(skb); + + if (unlikely(sk = skb_steal_sock(skb))) + return sk; + else + return __udp4_lib_lookup(dev_net(skb->dst->dev), iph->saddr, sport, + iph->daddr, dport, inet_iif(skb), + udptable); +} + +struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, + __be32 daddr, __be16 dport, int dif) +{ + return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif, udp_hash); +} +EXPORT_SYMBOL_GPL(udp4_lib_lookup); + static inline struct sock *udp_v4_mcast_next(struct sock *sk, __be16 loc_port, __be32 loc_addr, __be16 rmt_port, __be32 rmt_addr, @@ -1201,8 +1183,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct hlist_head udptable[], return __udp4_lib_mcast_deliver(net, skb, uh, saddr, daddr, udptable); - sk = __udp4_lib_lookup(net, saddr, uh->source, daddr, - uh->dest, inet_iif(skb), udptable); + sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable); if (sk != NULL) { int ret = udp_queue_rcv_skb(sk, skb); |