diff options
| -rw-r--r-- | include/linux/net.h | 9 | ||||
| -rw-r--r-- | include/net/sock.h | 20 | ||||
| -rw-r--r-- | net/core/sock.c | 11 | ||||
| -rw-r--r-- | net/socket.c | 80 |
4 files changed, 61 insertions, 59 deletions
diff --git a/include/linux/net.h b/include/linux/net.h index 8b012430f49d..04bd681473db 100644 --- a/include/linux/net.h +++ b/include/linux/net.h @@ -89,9 +89,11 @@ struct page; struct kiocb; struct sockaddr; struct msghdr; +struct module; struct proto_ops { int family; + struct module *owner; int (*release) (struct socket *sock); int (*bind) (struct socket *sock, struct sockaddr *umyaddr, @@ -127,8 +129,6 @@ struct proto_ops { int offset, size_t size, int flags); }; -struct module; - struct net_proto_family { int family; int (*create)(struct socket *sock, int protocol); @@ -140,9 +140,6 @@ struct net_proto_family { struct module *owner; }; -extern int net_family_get(int family); -extern void net_family_put(int family); - struct iovec; extern int sock_wake_async(struct socket *sk, int how, int band); @@ -227,7 +224,7 @@ SOCKCALL_WRAP(name, mmap, (struct file *file, struct socket *sock, struct vm_are \ static struct proto_ops name##_ops = { \ .family = fam, \ - \ + .owner = THIS_MODULE, \ .release = __lock_##name##_release, \ .bind = __lock_##name##_bind, \ .connect = __lock_##name##_connect, \ diff --git a/include/net/sock.h b/include/net/sock.h index 4f5b79f30880..53639300bc3c 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -43,7 +43,7 @@ #include <linux/config.h> #include <linux/timer.h> #include <linux/cache.h> - +#include <linux/module.h> #include <linux/netdevice.h> #include <linux/skbuff.h> /* struct sk_buff */ #include <linux/security.h> @@ -197,6 +197,7 @@ struct sock { void *user_data; /* Callbacks */ + struct module *owner; void (*state_change)(struct sock *sk); void (*data_ready)(struct sock *sk,int bytes); void (*write_space)(struct sock *sk); @@ -270,6 +271,23 @@ struct proto { } stats[NR_CPUS]; }; +static __inline__ void sk_set_owner(struct sock *sk, struct module *owner) +{ + /* + * One should use sk_set_owner just once, after struct sock creation, + * be it shortly after sk_alloc or after a function that returns a new + * struct sock (and that down the call chain called sk_alloc), e.g. the + * IPv4 and IPv6 modules share tcp_create_openreq_child, so if + * tcp_create_openreq_child called sk_set_owner IPv6 would have to + * change the ownership of this struct sock, with one not needed + * transient sk_set_owner call. + */ + if (unlikely(sk->owner != NULL)) + BUG(); + sk->owner = owner; + __module_get(owner); +} + /* Called with local bh disabled */ static __inline__ void sock_prot_inc_use(struct proto *prot) { diff --git a/net/core/sock.c b/net/core/sock.c index 29fdc583abe3..82e8942061db 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -591,8 +591,6 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab) { struct sock *sk = NULL; - if (!net_family_get(family)) - goto out; if (!slab) slab = sk_cachep; sk = kmem_cache_alloc(slab, priority); @@ -604,16 +602,14 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab) sock_lock_init(sk); } sk->slab = slab; - } else - net_family_put(family); -out: + } return sk; } void sk_free(struct sock *sk) { struct sk_filter *filter; - const int family = sk->family; + struct module *owner = sk->owner; if (sk->destruct) sk->destruct(sk); @@ -628,7 +624,7 @@ void sk_free(struct sock *sk) printk(KERN_DEBUG "sk_free: optmem leakage (%d bytes) detected.\n", atomic_read(&sk->omem_alloc)); kmem_cache_free(sk->slab, sk); - net_family_put(family); + module_put(owner); } void __init sk_init(void) @@ -1112,6 +1108,7 @@ void sock_init_data(struct socket *sock, struct sock *sk) sk->rcvlowat = 1; sk->rcvtimeo = MAX_SCHEDULE_TIMEOUT; sk->sndtimeo = MAX_SCHEDULE_TIMEOUT; + sk->owner = NULL; atomic_set(&sk->refcnt, 1); } diff --git a/net/socket.c b/net/socket.c index d95971cd3333..9e49a7bd82f1 100644 --- a/net/socket.c +++ b/net/socket.c @@ -141,36 +141,6 @@ static struct file_operations socket_file_ops = { static struct net_proto_family *net_families[NPROTO]; -static __inline__ void net_family_bug(int family) -{ - printk(KERN_ERR "%d is not yet sock_registered!\n", family); - BUG(); -} - -int net_family_get(int family) -{ - struct net_proto_family *prot = net_families[family]; - int rc = 1; - - barrier(); - if (likely(prot != NULL)) - rc = try_module_get(prot->owner); - else - net_family_bug(family); - return rc; -} - -void net_family_put(int family) -{ - struct net_proto_family *prot = net_families[family]; - - barrier(); - if (likely(prot != NULL)) - module_put(prot->owner); - else - net_family_bug(family); -} - #if defined(CONFIG_SMP) || defined(CONFIG_PREEMPT) static atomic_t net_family_lockct = ATOMIC_INIT(0); static spinlock_t net_family_lock = SPIN_LOCK_UNLOCKED; @@ -535,11 +505,11 @@ struct file_operations bad_sock_fops = { void sock_release(struct socket *sock) { if (sock->ops) { - const int family = sock->ops->family; + struct module *owner = sock->ops->owner; sock->ops->release(sock); sock->ops = NULL; - net_family_put(family); + module_put(owner); } if (sock->fasync_list) @@ -1091,19 +1061,37 @@ int sock_create(int family, int type, int protocol, struct socket **res) sock->type = type; + /* + * We will call the ->create function, that possibly is in a loadable + * module, so we have to bump that loadable module refcnt first. + */ i = -EAFNOSUPPORT; - if (!net_family_get(family)) - goto out_release; - - if ((i = net_families[family]->create(sock, protocol)) < 0) + if (!try_module_get(net_families[family]->owner)) goto out_release; + if ((i = net_families[family]->create(sock, protocol)) < 0) + goto out_module_put; + /* + * Now to bump the refcnt of the [loadable] module that owns this + * socket at sock_release time we decrement its refcnt. + */ + if (!try_module_get(sock->ops->owner)) { + sock->ops = NULL; + goto out_module_put; + } + /* + * Now that we're done with the ->create function, the [loadable] + * module can have its refcnt decremented + */ + module_put(net_families[family]->owner); *res = sock; security_socket_post_create(sock, family, type, protocol); out: net_family_read_unlock(); return i; +out_module_put: + module_put(net_families[family]->owner); out_release: sock_release(sock); goto out; @@ -1288,28 +1276,30 @@ asmlinkage long sys_accept(int fd, struct sockaddr *upeer_sockaddr, int *upeer_a if (err) goto out_release; - err = -EAFNOSUPPORT; - if (!net_family_get(sock->ops->family)) - goto out_release; + /* + * We don't need try_module_get here, as the listening socket (sock) + * has the protocol module (sock->ops->owner) held. + */ + __module_get(sock->ops->owner); err = sock->ops->accept(sock, newsock, sock->file->f_flags); if (err < 0) - goto out_family_put; + goto out_module_put; if (upeer_sockaddr) { if(newsock->ops->getname(newsock, (struct sockaddr *)address, &len, 2)<0) { err = -ECONNABORTED; - goto out_family_put; + goto out_module_put; } err = move_addr_to_user(address, len, upeer_sockaddr, upeer_addrlen); if (err < 0) - goto out_family_put; + goto out_module_put; } /* File flags are not inherited via accept() unlike another OSes. */ if ((err = sock_map_fd(newsock)) < 0) - goto out_family_put; + goto out_module_put; security_socket_post_accept(sock, newsock); @@ -1317,8 +1307,8 @@ out_put: sockfd_put(sock); out: return err; -out_family_put: - net_family_put(sock->ops->family); +out_module_put: + module_put(sock->ops->owner); out_release: sock_release(newsock); goto out_put; |
