diff options
Diffstat (limited to 'net/sunrpc/svcsock.c')
-rw-r--r-- | net/sunrpc/svcsock.c | 407 |
1 files changed, 279 insertions, 128 deletions
diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index ff1f8bf680a..63ae94771b8 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -36,11 +36,13 @@ #include <net/sock.h> #include <net/checksum.h> #include <net/ip.h> +#include <net/ipv6.h> #include <net/tcp_states.h> #include <asm/uaccess.h> #include <asm/ioctls.h> #include <linux/sunrpc/types.h> +#include <linux/sunrpc/clnt.h> #include <linux/sunrpc/xdr.h> #include <linux/sunrpc/svcsock.h> #include <linux/sunrpc/stats.h> @@ -58,10 +60,16 @@ * providing that certain rules are followed: * * SK_CONN, SK_DATA, can be set or cleared at any time. - * after a set, svc_sock_enqueue must be called. + * after a set, svc_sock_enqueue must be called. * after a clear, the socket must be read/accepted * if this succeeds, it must be set again. * SK_CLOSE can set at any time. It is never cleared. + * sk_inuse contains a bias of '1' until SK_DEAD is set. + * so when sk_inuse hits zero, we know the socket is dead + * and no-one is using it. + * SK_DEAD can only be set while SK_BUSY is held which ensures + * no other thread will be using the socket or will try to + * set SK_DEAD. * */ @@ -69,7 +77,8 @@ static struct svc_sock *svc_setup_socket(struct svc_serv *, struct socket *, - int *errp, int pmap_reg); + int *errp, int flags); +static void svc_delete_socket(struct svc_sock *svsk); static void svc_udp_data_ready(struct sock *, int); static int svc_udp_recvfrom(struct svc_rqst *); static int svc_udp_sendto(struct svc_rqst *); @@ -114,6 +123,41 @@ static inline void svc_reclassify_socket(struct socket *sock) } #endif +static char *__svc_print_addr(struct sockaddr *addr, char *buf, size_t len) +{ + switch (addr->sa_family) { + case AF_INET: + snprintf(buf, len, "%u.%u.%u.%u, port=%u", + NIPQUAD(((struct sockaddr_in *) addr)->sin_addr), + htons(((struct sockaddr_in *) addr)->sin_port)); + break; +#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) + case AF_INET6: + snprintf(buf, len, "%x:%x:%x:%x:%x:%x:%x:%x, port=%u", + NIP6(((struct sockaddr_in6 *) addr)->sin6_addr), + htons(((struct sockaddr_in6 *) addr)->sin6_port)); + break; +#endif + default: + snprintf(buf, len, "unknown address type: %d", addr->sa_family); + break; + } + return buf; +} + +/** + * svc_print_addr - Format rq_addr field for printing + * @rqstp: svc_rqst struct containing address to print + * @buf: target buffer for formatted address + * @len: length of target buffer + * + */ +char *svc_print_addr(struct svc_rqst *rqstp, char *buf, size_t len) +{ + return __svc_print_addr(svc_addr(rqstp), buf, len); +} +EXPORT_SYMBOL_GPL(svc_print_addr); + /* * Queue up an idle server thread. Must have pool->sp_lock held. * Note: this is really a stack rather than a queue, so that we only @@ -245,7 +289,7 @@ svc_sock_enqueue(struct svc_sock *svsk) svsk->sk_sk, rqstp); svc_thread_dequeue(pool, rqstp); if (rqstp->rq_sock) - printk(KERN_ERR + printk(KERN_ERR "svc_sock_enqueue: server %p, rq_sock=%p!\n", rqstp, rqstp->rq_sock); rqstp->rq_sock = svsk; @@ -329,8 +373,9 @@ void svc_reserve(struct svc_rqst *rqstp, int space) static inline void svc_sock_put(struct svc_sock *svsk) { - if (atomic_dec_and_test(&svsk->sk_inuse) && - test_bit(SK_DEAD, &svsk->sk_flags)) { + if (atomic_dec_and_test(&svsk->sk_inuse)) { + BUG_ON(! test_bit(SK_DEAD, &svsk->sk_flags)); + dprintk("svc: releasing dead socket\n"); if (svsk->sk_sock->file) sockfd_put(svsk->sk_sock); @@ -402,6 +447,43 @@ svc_wake_up(struct svc_serv *serv) } } +union svc_pktinfo_u { + struct in_pktinfo pkti; +#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) + struct in6_pktinfo pkti6; +#endif +}; + +static void svc_set_cmsg_data(struct svc_rqst *rqstp, struct cmsghdr *cmh) +{ + switch (rqstp->rq_sock->sk_sk->sk_family) { + case AF_INET: { + struct in_pktinfo *pki = CMSG_DATA(cmh); + + cmh->cmsg_level = SOL_IP; + cmh->cmsg_type = IP_PKTINFO; + pki->ipi_ifindex = 0; + pki->ipi_spec_dst.s_addr = rqstp->rq_daddr.addr.s_addr; + cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); + } + break; +#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) + case AF_INET6: { + struct in6_pktinfo *pki = CMSG_DATA(cmh); + + cmh->cmsg_level = SOL_IPV6; + cmh->cmsg_type = IPV6_PKTINFO; + pki->ipi6_ifindex = 0; + ipv6_addr_copy(&pki->ipi6_addr, + &rqstp->rq_daddr.addr6); + cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); + } + break; +#endif + } + return; +} + /* * Generic sendto routine */ @@ -411,9 +493,8 @@ svc_sendto(struct svc_rqst *rqstp, struct xdr_buf *xdr) struct svc_sock *svsk = rqstp->rq_sock; struct socket *sock = svsk->sk_sock; int slen; - char buffer[CMSG_SPACE(sizeof(struct in_pktinfo))]; + char buffer[CMSG_SPACE(sizeof(union svc_pktinfo_u))]; struct cmsghdr *cmh = (struct cmsghdr *)buffer; - struct in_pktinfo *pki = (struct in_pktinfo *)CMSG_DATA(cmh); int len = 0; int result; int size; @@ -421,25 +502,20 @@ svc_sendto(struct svc_rqst *rqstp, struct xdr_buf *xdr) size_t base = xdr->page_base; unsigned int pglen = xdr->page_len; unsigned int flags = MSG_MORE; + char buf[RPC_MAX_ADDRBUFLEN]; slen = xdr->len; if (rqstp->rq_prot == IPPROTO_UDP) { - /* set the source and destination */ - struct msghdr msg; - msg.msg_name = &rqstp->rq_addr; - msg.msg_namelen = sizeof(rqstp->rq_addr); - msg.msg_iov = NULL; - msg.msg_iovlen = 0; - msg.msg_flags = MSG_MORE; - - msg.msg_control = cmh; - msg.msg_controllen = sizeof(buffer); - cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); - cmh->cmsg_level = SOL_IP; - cmh->cmsg_type = IP_PKTINFO; - pki->ipi_ifindex = 0; - pki->ipi_spec_dst.s_addr = rqstp->rq_daddr; + struct msghdr msg = { + .msg_name = &rqstp->rq_addr, + .msg_namelen = rqstp->rq_addrlen, + .msg_control = cmh, + .msg_controllen = sizeof(buffer), + .msg_flags = MSG_MORE, + }; + + svc_set_cmsg_data(rqstp, cmh); if (sock_sendmsg(sock, &msg, 0) < 0) goto out; @@ -476,16 +552,16 @@ svc_sendto(struct svc_rqst *rqstp, struct xdr_buf *xdr) if (xdr->tail[0].iov_len) { result = kernel_sendpage(sock, rqstp->rq_respages[0], ((unsigned long)xdr->tail[0].iov_base) - & (PAGE_SIZE-1), + & (PAGE_SIZE-1), xdr->tail[0].iov_len, 0); if (result > 0) len += result; } out: - dprintk("svc: socket %p sendto([%p %Zu... ], %d) = %d (addr %x)\n", - rqstp->rq_sock, xdr->head[0].iov_base, xdr->head[0].iov_len, xdr->len, len, - rqstp->rq_addr.sin_addr.s_addr); + dprintk("svc: socket %p sendto([%p %Zu... ], %d) = %d (addr %s)\n", + rqstp->rq_sock, xdr->head[0].iov_base, xdr->head[0].iov_len, + xdr->len, len, svc_print_addr(rqstp, buf, sizeof(buf))); return len; } @@ -520,7 +596,7 @@ svc_sock_names(char *buf, struct svc_serv *serv, char *toclose) if (!serv) return 0; - spin_lock(&serv->sv_lock); + spin_lock_bh(&serv->sv_lock); list_for_each_entry(svsk, &serv->sv_permsocks, sk_list) { int onelen = one_sock_name(buf+len, svsk); if (toclose && strcmp(toclose, buf+len) == 0) @@ -528,12 +604,12 @@ svc_sock_names(char *buf, struct svc_serv *serv, char *toclose) else len += onelen; } - spin_unlock(&serv->sv_lock); + spin_unlock_bh(&serv->sv_lock); if (closesk) /* Should unregister with portmap, but you cannot * unregister just one protocol... */ - svc_delete_socket(closesk); + svc_close_socket(closesk); else if (toclose) return -ENOENT; return len; @@ -560,31 +636,22 @@ svc_recv_available(struct svc_sock *svsk) static int svc_recvfrom(struct svc_rqst *rqstp, struct kvec *iov, int nr, int buflen) { - struct msghdr msg; - struct socket *sock; - int len, alen; - - rqstp->rq_addrlen = sizeof(rqstp->rq_addr); - sock = rqstp->rq_sock->sk_sock; - - msg.msg_name = &rqstp->rq_addr; - msg.msg_namelen = sizeof(rqstp->rq_addr); - msg.msg_control = NULL; - msg.msg_controllen = 0; - - msg.msg_flags = MSG_DONTWAIT; + struct svc_sock *svsk = rqstp->rq_sock; + struct msghdr msg = { + .msg_flags = MSG_DONTWAIT, + }; + int len; - len = kernel_recvmsg(sock, &msg, iov, nr, buflen, MSG_DONTWAIT); + len = kernel_recvmsg(svsk->sk_sock, &msg, iov, nr, buflen, + msg.msg_flags); /* sock_recvmsg doesn't fill in the name/namelen, so we must.. - * possibly we should cache this in the svc_sock structure - * at accept time. FIXME */ - alen = sizeof(rqstp->rq_addr); - kernel_getpeername(sock, (struct sockaddr *)&rqstp->rq_addr, &alen); + memcpy(&rqstp->rq_addr, &svsk->sk_remote, svsk->sk_remotelen); + rqstp->rq_addrlen = svsk->sk_remotelen; dprintk("svc: socket %p recvfrom(%p, %Zu) = %d\n", - rqstp->rq_sock, iov[0].iov_base, iov[0].iov_len, len); + svsk, iov[0].iov_base, iov[0].iov_len, len); return len; } @@ -654,6 +721,47 @@ svc_write_space(struct sock *sk) } } +static void svc_udp_get_sender_address(struct svc_rqst *rqstp, + struct sk_buff *skb) +{ + switch (rqstp->rq_sock->sk_sk->sk_family) { + case AF_INET: { + /* this seems to come from net/ipv4/udp.c:udp_recvmsg */ + struct sockaddr_in *sin = svc_addr_in(rqstp); + + sin->sin_family = AF_INET; + sin->sin_port = skb->h.uh->source; + sin->sin_addr.s_addr = skb->nh.iph->saddr; + rqstp->rq_addrlen = sizeof(struct sockaddr_in); + /* Remember which interface received this request */ + rqstp->rq_daddr.addr.s_addr = skb->nh.iph->daddr; + } + break; +#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) + case AF_INET6: { + /* this is derived from net/ipv6/udp.c:udpv6_recvmesg */ + struct sockaddr_in6 *sin6 = svc_addr_in6(rqstp); + + sin6->sin6_family = AF_INET6; + sin6->sin6_port = skb->h.uh->source; + sin6->sin6_flowinfo = 0; + sin6->sin6_scope_id = 0; + if (ipv6_addr_type(&sin6->sin6_addr) & + IPV6_ADDR_LINKLOCAL) + sin6->sin6_scope_id = IP6CB(skb)->iif; + ipv6_addr_copy(&sin6->sin6_addr, + &skb->nh.ipv6h->saddr); + rqstp->rq_addrlen = sizeof(struct sockaddr_in); + /* Remember which interface received this request */ + ipv6_addr_copy(&rqstp->rq_daddr.addr6, + &skb->nh.ipv6h->saddr); + } + break; +#endif + } + return; +} + /* * Receive a datagram from a UDP socket. */ @@ -683,6 +791,11 @@ svc_udp_recvfrom(struct svc_rqst *rqstp) return svc_deferred_recv(rqstp); } + if (test_bit(SK_CLOSE, &svsk->sk_flags)) { + svc_delete_socket(svsk); + return 0; + } + clear_bit(SK_DATA, &svsk->sk_flags); while ((skb = skb_recv_datagram(svsk->sk_sk, 0, 1, &err)) == NULL) { if (err == -EAGAIN) { @@ -698,7 +811,7 @@ svc_udp_recvfrom(struct svc_rqst *rqstp) tv.tv_sec = xtime.tv_sec; tv.tv_usec = xtime.tv_nsec / NSEC_PER_USEC; skb_set_timestamp(skb, &tv); - /* Don't enable netstamp, sunrpc doesn't + /* Don't enable netstamp, sunrpc doesn't need that much accuracy */ } skb_get_timestamp(skb, &svsk->sk_sk->sk_stamp); @@ -712,13 +825,9 @@ svc_udp_recvfrom(struct svc_rqst *rqstp) len = skb->len - sizeof(struct udphdr); rqstp->rq_arg.len = len; - rqstp->rq_prot = IPPROTO_UDP; + rqstp->rq_prot = IPPROTO_UDP; - /* Get sender address */ - rqstp->rq_addr.sin_family = AF_INET; - rqstp->rq_addr.sin_port = skb->h.uh->source; - rqstp->rq_addr.sin_addr.s_addr = skb->nh.iph->saddr; - rqstp->rq_daddr = skb->nh.iph->daddr; + svc_udp_get_sender_address(rqstp, skb); if (skb_is_nonlinear(skb)) { /* we have to copy */ @@ -730,7 +839,7 @@ svc_udp_recvfrom(struct svc_rqst *rqstp) return 0; } local_bh_enable(); - skb_free_datagram(svsk->sk_sk, skb); + skb_free_datagram(svsk->sk_sk, skb); } else { /* we can use it in-place */ rqstp->rq_arg.head[0].iov_base = skb->data + sizeof(struct udphdr); @@ -781,7 +890,7 @@ svc_udp_init(struct svc_sock *svsk) svsk->sk_sendto = svc_udp_sendto; /* initialise setting must have enough space to - * receive and respond to one request. + * receive and respond to one request. * svc_udp_recvfrom will re-adjust if necessary */ svc_sock_setbufsize(svsk->sk_sock, @@ -862,18 +971,36 @@ svc_tcp_data_ready(struct sock *sk, int count) wake_up_interruptible(sk->sk_sleep); } +static inline int svc_port_is_privileged(struct sockaddr *sin) +{ + switch (sin->sa_family) { + case AF_INET: + return ntohs(((struct sockaddr_in *)sin)->sin_port) + < PROT_SOCK; +#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) + case AF_INET6: + return ntohs(((struct sockaddr_in6 *)sin)->sin6_port) + < PROT_SOCK; +#endif + default: + return 0; + } +} + /* * Accept a TCP connection */ static void svc_tcp_accept(struct svc_sock *svsk) { - struct sockaddr_in sin; + struct sockaddr_storage addr; + struct sockaddr *sin = (struct sockaddr *) &addr; struct svc_serv *serv = svsk->sk_server; struct socket *sock = svsk->sk_sock; struct socket *newsock; struct svc_sock *newsvsk; int err, slen; + char buf[RPC_MAX_ADDRBUFLEN]; dprintk("svc: tcp_accept %p sock %p\n", svsk, sock); if (!sock) @@ -894,8 +1021,7 @@ svc_tcp_accept(struct svc_sock *svsk) set_bit(SK_CONN, &svsk->sk_flags); svc_sock_enqueue(svsk); - slen = sizeof(sin); - err = kernel_getpeername(newsock, (struct sockaddr *) &sin, &slen); + err = kernel_getpeername(newsock, sin, &slen); if (err < 0) { if (net_ratelimit()) printk(KERN_WARNING "%s: peername failed (err %d)!\n", @@ -904,27 +1030,30 @@ svc_tcp_accept(struct svc_sock *svsk) } /* Ideally, we would want to reject connections from unauthorized - * hosts here, but when we get encription, the IP of the host won't - * tell us anything. For now just warn about unpriv connections. + * hosts here, but when we get encryption, the IP of the host won't + * tell us anything. For now just warn about unpriv connections. */ - if (ntohs(sin.sin_port) >= 1024) { + if (!svc_port_is_privileged(sin)) { dprintk(KERN_WARNING - "%s: connect from unprivileged port: %u.%u.%u.%u:%d\n", - serv->sv_name, - NIPQUAD(sin.sin_addr.s_addr), ntohs(sin.sin_port)); + "%s: connect from unprivileged port: %s\n", + serv->sv_name, + __svc_print_addr(sin, buf, sizeof(buf))); } - - dprintk("%s: connect from %u.%u.%u.%u:%04x\n", serv->sv_name, - NIPQUAD(sin.sin_addr.s_addr), ntohs(sin.sin_port)); + dprintk("%s: connect from %s\n", serv->sv_name, + __svc_print_addr(sin, buf, sizeof(buf))); /* make sure that a write doesn't block forever when * low on memory */ newsock->sk->sk_sndtimeo = HZ*30; - if (!(newsvsk = svc_setup_socket(serv, newsock, &err, 0))) + if (!(newsvsk = svc_setup_socket(serv, newsock, &err, + (SVC_SOCK_ANONYMOUS | SVC_SOCK_TEMPORARY)))) goto failed; + memcpy(&newsvsk->sk_remote, sin, slen); + newsvsk->sk_remotelen = slen; + svc_sock_received(newsvsk); /* make sure that we don't have too many active connections. * If we have, something must be dropped. @@ -947,11 +1076,9 @@ svc_tcp_accept(struct svc_sock *svsk) "sockets, consider increasing the " "number of nfsd threads\n", serv->sv_name); - printk(KERN_NOTICE "%s: last TCP connect from " - "%u.%u.%u.%u:%d\n", - serv->sv_name, - NIPQUAD(sin.sin_addr.s_addr), - ntohs(sin.sin_port)); + printk(KERN_NOTICE + "%s: last TCP connect from %s\n", + serv->sv_name, buf); } /* * Always select the oldest socket. It's not fair, @@ -1025,7 +1152,7 @@ svc_tcp_recvfrom(struct svc_rqst *rqstp) * on the number of threads which will access the socket. * * rcvbuf just needs to be able to hold a few requests. - * Normally they will be removed from the queue + * Normally they will be removed from the queue * as soon a a complete request arrives. */ svc_sock_setbufsize(svsk->sk_sock, @@ -1050,7 +1177,7 @@ svc_tcp_recvfrom(struct svc_rqst *rqstp) if (len < want) { dprintk("svc: short recvfrom while reading record length (%d of %lu)\n", - len, want); + len, want); svc_sock_received(svsk); return -EAGAIN; /* record header not complete */ } @@ -1176,7 +1303,8 @@ svc_tcp_sendto(struct svc_rqst *rqstp) rqstp->rq_sock->sk_server->sv_name, (sent<0)?"got error":"sent only", sent, xbufp->len); - svc_delete_socket(rqstp->rq_sock); + set_bit(SK_CLOSE, &rqstp->rq_sock->sk_flags); + svc_sock_enqueue(rqstp->rq_sock); sent = -EAGAIN; } return sent; @@ -1207,7 +1335,7 @@ svc_tcp_init(struct svc_sock *svsk) tp->nonagle = 1; /* disable Nagle's algorithm */ /* initialise setting must have enough space to - * receive and respond to one request. + * receive and respond to one request. * svc_tcp_recvfrom will re-adjust if necessary */ svc_sock_setbufsize(svsk->sk_sock, @@ -1216,7 +1344,7 @@ svc_tcp_init(struct svc_sock *svsk) set_bit(SK_CHNGBUF, &svsk->sk_flags); set_bit(SK_DATA, &svsk->sk_flags); - if (sk->sk_state != TCP_ESTABLISHED) + if (sk->sk_state != TCP_ESTABLISHED) set_bit(SK_CLOSE, &svsk->sk_flags); } } @@ -1232,7 +1360,7 @@ svc_sock_update_bufs(struct svc_serv *serv) spin_lock_bh(&serv->sv_lock); list_for_each(le, &serv->sv_permsocks) { - struct svc_sock *svsk = + struct svc_sock *svsk = list_entry(le, struct svc_sock, sk_list); set_bit(SK_CHNGBUF, &svsk->sk_flags); } @@ -1252,7 +1380,7 @@ svc_sock_update_bufs(struct svc_serv *serv) int svc_recv(struct svc_rqst *rqstp, long timeout) { - struct svc_sock *svsk =NULL; + struct svc_sock *svsk = NULL; struct svc_serv *serv = rqstp->rq_server; struct svc_pool *pool = rqstp->rq_pool; int len, i; @@ -1264,11 +1392,11 @@ svc_recv(struct svc_rqst *rqstp, long timeout) rqstp, timeout); if (rqstp->rq_sock) - printk(KERN_ERR + printk(KERN_ERR "svc_recv: service %p, socket not NULL!\n", rqstp); if (waitqueue_active(&rqstp->rq_wait)) - printk(KERN_ERR + printk(KERN_ERR "svc_recv: service %p, wait queue active!\n", rqstp); @@ -1349,7 +1477,7 @@ svc_recv(struct svc_rqst *rqstp, long timeout) svsk->sk_lastrecv = get_seconds(); clear_bit(SK_OLD, &svsk->sk_flags); - rqstp->rq_secure = ntohs(rqstp->rq_addr.sin_port) < 1024; + rqstp->rq_secure = svc_port_is_privileged(svc_addr(rqstp)); rqstp->rq_chandle.defer = svc_defer; if (serv->sv_stats) @@ -1357,7 +1485,7 @@ svc_recv(struct svc_rqst *rqstp, long timeout) return len; } -/* +/* * Drop request */ void @@ -1462,12 +1590,14 @@ svc_age_temp_sockets(unsigned long closure) * Initialize socket for RPC use and create svc_sock struct * XXX: May want to setsockopt SO_SNDBUF and SO_RCVBUF. */ -static struct svc_sock * -svc_setup_socket(struct svc_serv *serv, struct socket *sock, - int *errp, int pmap_register) +static struct svc_sock *svc_setup_socket(struct svc_serv *serv, + struct socket *sock, + int *errp, int flags) { struct svc_sock *svsk; struct sock *inet; + int pmap_register = !(flags & SVC_SOCK_ANONYMOUS); + int is_temporary = flags & SVC_SOCK_TEMPORARY; dprintk("svc: svc_setup_socket %p\n", sock); if (!(svsk = kzalloc(sizeof(*svsk), GFP_KERNEL))) { @@ -1495,7 +1625,7 @@ svc_setup_socket(struct svc_serv *serv, struct socket *sock, svsk->sk_odata = inet->sk_data_ready; svsk->sk_owspace = inet->sk_write_space; svsk->sk_server = serv; - atomic_set(&svsk->sk_inuse, 0); + atomic_set(&svsk->sk_inuse, 1); svsk->sk_lastrecv = get_seconds(); spin_lock_init(&svsk->sk_defer_lock); INIT_LIST_HEAD(&svsk->sk_deferred); @@ -1509,7 +1639,7 @@ svc_setup_socket(struct svc_serv *serv, struct socket *sock, svc_tcp_init(svsk); spin_lock_bh(&serv->sv_lock); - if (!pmap_register) { + if (is_temporary) { set_bit(SK_TEMP, &svsk->sk_flags); list_add(&svsk->sk_list, &serv->sv_tempsocks); serv->sv_tmpcnt++; @@ -1529,8 +1659,6 @@ svc_setup_socket(struct svc_serv *serv, struct socket *sock, dprintk("svc: svc_setup_socket created %p (inet %p)\n", svsk, svsk->sk_sk); - clear_bit(SK_BUSY, &svsk->sk_flags); - svc_sock_enqueue(svsk); return svsk; } @@ -1553,9 +1681,11 @@ int svc_addsock(struct svc_serv *serv, else if (so->state > SS_UNCONNECTED) err = -EISCONN; else { - svsk = svc_setup_socket(serv, so, &err, 1); - if (svsk) + svsk = svc_setup_socket(serv, so, &err, SVC_SOCK_DEFAULTS); + if (svsk) { + svc_sock_received(svsk); err = 0; + } } if (err) { sockfd_put(so); @@ -1569,18 +1699,18 @@ EXPORT_SYMBOL_GPL(svc_addsock); /* * Create socket for RPC service. */ -static int -svc_create_socket(struct svc_serv *serv, int protocol, struct sockaddr_in *sin) +static int svc_create_socket(struct svc_serv *serv, int protocol, + struct sockaddr *sin, int len, int flags) { struct svc_sock *svsk; struct socket *sock; int error; int type; + char buf[RPC_MAX_ADDRBUFLEN]; - dprintk("svc: svc_create_socket(%s, %d, %u.%u.%u.%u:%d)\n", - serv->sv_program->pg_name, protocol, - NIPQUAD(sin->sin_addr.s_addr), - ntohs(sin->sin_port)); + dprintk("svc: svc_create_socket(%s, %d, %s)\n", + serv->sv_program->pg_name, protocol, + __svc_print_addr(sin, buf, sizeof(buf))); if (protocol != IPPROTO_UDP && protocol != IPPROTO_TCP) { printk(KERN_WARNING "svc: only UDP and TCP " @@ -1589,15 +1719,15 @@ svc_create_socket(struct svc_serv *serv, int protocol, struct sockaddr_in *sin) } type = (protocol == IPPROTO_UDP)? SOCK_DGRAM : SOCK_STREAM; - if ((error = sock_create_kern(PF_INET, type, protocol, &sock)) < 0) + error = sock_create_kern(sin->sa_family, type, protocol, &sock); + if (error < 0) return error; svc_reclassify_socket(sock); if (type == SOCK_STREAM) - sock->sk->sk_reuse = 1; /* allow address reuse */ - error = kernel_bind(sock, (struct sockaddr *) sin, - sizeof(*sin)); + sock->sk->sk_reuse = 1; /* allow address reuse */ + error = kernel_bind(sock, sin, len); if (error < 0) goto bummer; @@ -1606,8 +1736,10 @@ svc_create_socket(struct svc_serv *serv, int protocol, struct sockaddr_in *sin) goto bummer; } - if ((svsk = svc_setup_socket(serv, sock, &error, 1)) != NULL) - return 0; + if ((svsk = svc_setup_socket(serv, sock, &error, flags)) != NULL) { + svc_sock_received(svsk); + return ntohs(inet_sk(svsk->sk_sk)->sport); + } bummer: dprintk("svc: svc_create_socket error = %d\n", -error); @@ -1618,7 +1750,7 @@ bummer: /* * Remove a dead socket */ -void +static void svc_delete_socket(struct svc_sock *svsk) { struct svc_serv *serv; @@ -1637,43 +1769,60 @@ svc_delete_socket(struct svc_sock *svsk) if (!test_and_set_bit(SK_DETACHED, &svsk->sk_flags)) list_del_init(&svsk->sk_list); - /* + /* * We used to delete the svc_sock from whichever list * it's sk_ready node was on, but we don't actually * need to. This is because the only time we're called * while still attached to a queue, the queue itself * is about to be destroyed (in svc_destroy). */ - if (!test_and_set_bit(SK_DEAD, &svsk->sk_flags)) + if (!test_and_set_bit(SK_DEAD, &svsk->sk_flags)) { + BUG_ON(atomic_read(&svsk->sk_inuse)<2); + atomic_dec(&svsk->sk_inuse); if (test_bit(SK_TEMP, &svsk->sk_flags)) serv->sv_tmpcnt--; + } - /* This atomic_inc should be needed - svc_delete_socket - * should have the semantic of dropping a reference. - * But it doesn't yet.... - */ - atomic_inc(&svsk->sk_inuse); spin_unlock_bh(&serv->sv_lock); +} + +void svc_close_socket(struct svc_sock *svsk) +{ + set_bit(SK_CLOSE, &svsk->sk_flags); + if (test_and_set_bit(SK_BUSY, &svsk->sk_flags)) + /* someone else will have to effect the close */ + return; + + atomic_inc(&svsk->sk_inuse); + svc_delete_socket(svsk); + clear_bit(SK_BUSY, &svsk->sk_flags); svc_sock_put(svsk); } -/* - * Make a socket for nfsd and lockd +/** + * svc_makesock - Make a socket for nfsd and lockd + * @serv: RPC server structure + * @protocol: transport protocol to use + * @port: port to use + * @flags: requested socket characteristics + * */ -int -svc_makesock(struct svc_serv *serv, int protocol, unsigned short port) +int svc_makesock(struct svc_serv *serv, int protocol, unsigned short port, + int flags) { - struct sockaddr_in sin; + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = INADDR_ANY, + .sin_port = htons(port), + }; dprintk("svc: creating socket proto = %d\n", protocol); - sin.sin_family = AF_INET; - sin.sin_addr.s_addr = INADDR_ANY; - sin.sin_port = htons(port); - return svc_create_socket(serv, protocol, &sin); + return svc_create_socket(serv, protocol, (struct sockaddr *) &sin, + sizeof(sin), flags); } /* - * Handle defer and revisit of requests + * Handle defer and revisit of requests */ static void svc_revisit(struct cache_deferred_req *dreq, int too_many) @@ -1718,7 +1867,8 @@ svc_defer(struct cache_req *req) dr->handle.owner = rqstp->rq_server; dr->prot = rqstp->rq_prot; - dr->addr = rqstp->rq_addr; + memcpy(&dr->addr, &rqstp->rq_addr, rqstp->rq_addrlen); + dr->addrlen = rqstp->rq_addrlen; dr->daddr = rqstp->rq_daddr; dr->argslen = rqstp->rq_arg.len >> 2; memcpy(dr->args, rqstp->rq_arg.head[0].iov_base-skip, dr->argslen<<2); @@ -1742,7 +1892,8 @@ static int svc_deferred_recv(struct svc_rqst *rqstp) rqstp->rq_arg.page_len = 0; rqstp->rq_arg.len = dr->argslen<<2; rqstp->rq_prot = dr->prot; - rqstp->rq_addr = dr->addr; + memcpy(&rqstp->rq_addr, &dr->addr, dr->addrlen); + rqstp->rq_addrlen = dr->addrlen; rqstp->rq_daddr = dr->daddr; rqstp->rq_respages = rqstp->rq_pages; return dr->argslen<<2; @@ -1752,7 +1903,7 @@ static int svc_deferred_recv(struct svc_rqst *rqstp) static struct svc_deferred_req *svc_deferred_dequeue(struct svc_sock *svsk) { struct svc_deferred_req *dr = NULL; - + if (!test_bit(SK_DEFERRED, &svsk->sk_flags)) return NULL; spin_lock_bh(&svsk->sk_defer_lock); |