summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/linux/net.h9
-rw-r--r--include/net/sock.h20
-rw-r--r--net/core/sock.c11
-rw-r--r--net/socket.c80
4 files changed, 61 insertions, 59 deletions
diff --git a/include/linux/net.h b/include/linux/net.h
index 8b012430f49d..04bd681473db 100644
--- a/include/linux/net.h
+++ b/include/linux/net.h
@@ -89,9 +89,11 @@ struct page;
struct kiocb;
struct sockaddr;
struct msghdr;
+struct module;
struct proto_ops {
int family;
+ struct module *owner;
int (*release) (struct socket *sock);
int (*bind) (struct socket *sock,
struct sockaddr *umyaddr,
@@ -127,8 +129,6 @@ struct proto_ops {
int offset, size_t size, int flags);
};
-struct module;
-
struct net_proto_family {
int family;
int (*create)(struct socket *sock, int protocol);
@@ -140,9 +140,6 @@ struct net_proto_family {
struct module *owner;
};
-extern int net_family_get(int family);
-extern void net_family_put(int family);
-
struct iovec;
extern int sock_wake_async(struct socket *sk, int how, int band);
@@ -227,7 +224,7 @@ SOCKCALL_WRAP(name, mmap, (struct file *file, struct socket *sock, struct vm_are
\
static struct proto_ops name##_ops = { \
.family = fam, \
- \
+ .owner = THIS_MODULE, \
.release = __lock_##name##_release, \
.bind = __lock_##name##_bind, \
.connect = __lock_##name##_connect, \
diff --git a/include/net/sock.h b/include/net/sock.h
index 4f5b79f30880..53639300bc3c 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -43,7 +43,7 @@
#include <linux/config.h>
#include <linux/timer.h>
#include <linux/cache.h>
-
+#include <linux/module.h>
#include <linux/netdevice.h>
#include <linux/skbuff.h> /* struct sk_buff */
#include <linux/security.h>
@@ -197,6 +197,7 @@ struct sock {
void *user_data;
/* Callbacks */
+ struct module *owner;
void (*state_change)(struct sock *sk);
void (*data_ready)(struct sock *sk,int bytes);
void (*write_space)(struct sock *sk);
@@ -270,6 +271,23 @@ struct proto {
} stats[NR_CPUS];
};
+static __inline__ void sk_set_owner(struct sock *sk, struct module *owner)
+{
+ /*
+ * One should use sk_set_owner just once, after struct sock creation,
+ * be it shortly after sk_alloc or after a function that returns a new
+ * struct sock (and that down the call chain called sk_alloc), e.g. the
+ * IPv4 and IPv6 modules share tcp_create_openreq_child, so if
+ * tcp_create_openreq_child called sk_set_owner IPv6 would have to
+ * change the ownership of this struct sock, with one not needed
+ * transient sk_set_owner call.
+ */
+ if (unlikely(sk->owner != NULL))
+ BUG();
+ sk->owner = owner;
+ __module_get(owner);
+}
+
/* Called with local bh disabled */
static __inline__ void sock_prot_inc_use(struct proto *prot)
{
diff --git a/net/core/sock.c b/net/core/sock.c
index 29fdc583abe3..82e8942061db 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -591,8 +591,6 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab)
{
struct sock *sk = NULL;
- if (!net_family_get(family))
- goto out;
if (!slab)
slab = sk_cachep;
sk = kmem_cache_alloc(slab, priority);
@@ -604,16 +602,14 @@ struct sock *sk_alloc(int family, int priority, int zero_it, kmem_cache_t *slab)
sock_lock_init(sk);
}
sk->slab = slab;
- } else
- net_family_put(family);
-out:
+ }
return sk;
}
void sk_free(struct sock *sk)
{
struct sk_filter *filter;
- const int family = sk->family;
+ struct module *owner = sk->owner;
if (sk->destruct)
sk->destruct(sk);
@@ -628,7 +624,7 @@ void sk_free(struct sock *sk)
printk(KERN_DEBUG "sk_free: optmem leakage (%d bytes) detected.\n", atomic_read(&sk->omem_alloc));
kmem_cache_free(sk->slab, sk);
- net_family_put(family);
+ module_put(owner);
}
void __init sk_init(void)
@@ -1112,6 +1108,7 @@ void sock_init_data(struct socket *sock, struct sock *sk)
sk->rcvlowat = 1;
sk->rcvtimeo = MAX_SCHEDULE_TIMEOUT;
sk->sndtimeo = MAX_SCHEDULE_TIMEOUT;
+ sk->owner = NULL;
atomic_set(&sk->refcnt, 1);
}
diff --git a/net/socket.c b/net/socket.c
index d95971cd3333..9e49a7bd82f1 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -141,36 +141,6 @@ static struct file_operations socket_file_ops = {
static struct net_proto_family *net_families[NPROTO];
-static __inline__ void net_family_bug(int family)
-{
- printk(KERN_ERR "%d is not yet sock_registered!\n", family);
- BUG();
-}
-
-int net_family_get(int family)
-{
- struct net_proto_family *prot = net_families[family];
- int rc = 1;
-
- barrier();
- if (likely(prot != NULL))
- rc = try_module_get(prot->owner);
- else
- net_family_bug(family);
- return rc;
-}
-
-void net_family_put(int family)
-{
- struct net_proto_family *prot = net_families[family];
-
- barrier();
- if (likely(prot != NULL))
- module_put(prot->owner);
- else
- net_family_bug(family);
-}
-
#if defined(CONFIG_SMP) || defined(CONFIG_PREEMPT)
static atomic_t net_family_lockct = ATOMIC_INIT(0);
static spinlock_t net_family_lock = SPIN_LOCK_UNLOCKED;
@@ -535,11 +505,11 @@ struct file_operations bad_sock_fops = {
void sock_release(struct socket *sock)
{
if (sock->ops) {
- const int family = sock->ops->family;
+ struct module *owner = sock->ops->owner;
sock->ops->release(sock);
sock->ops = NULL;
- net_family_put(family);
+ module_put(owner);
}
if (sock->fasync_list)
@@ -1091,19 +1061,37 @@ int sock_create(int family, int type, int protocol, struct socket **res)
sock->type = type;
+ /*
+ * We will call the ->create function, that possibly is in a loadable
+ * module, so we have to bump that loadable module refcnt first.
+ */
i = -EAFNOSUPPORT;
- if (!net_family_get(family))
- goto out_release;
-
- if ((i = net_families[family]->create(sock, protocol)) < 0)
+ if (!try_module_get(net_families[family]->owner))
goto out_release;
+ if ((i = net_families[family]->create(sock, protocol)) < 0)
+ goto out_module_put;
+ /*
+ * Now to bump the refcnt of the [loadable] module that owns this
+ * socket at sock_release time we decrement its refcnt.
+ */
+ if (!try_module_get(sock->ops->owner)) {
+ sock->ops = NULL;
+ goto out_module_put;
+ }
+ /*
+ * Now that we're done with the ->create function, the [loadable]
+ * module can have its refcnt decremented
+ */
+ module_put(net_families[family]->owner);
*res = sock;
security_socket_post_create(sock, family, type, protocol);
out:
net_family_read_unlock();
return i;
+out_module_put:
+ module_put(net_families[family]->owner);
out_release:
sock_release(sock);
goto out;
@@ -1288,28 +1276,30 @@ asmlinkage long sys_accept(int fd, struct sockaddr *upeer_sockaddr, int *upeer_a
if (err)
goto out_release;
- err = -EAFNOSUPPORT;
- if (!net_family_get(sock->ops->family))
- goto out_release;
+ /*
+ * We don't need try_module_get here, as the listening socket (sock)
+ * has the protocol module (sock->ops->owner) held.
+ */
+ __module_get(sock->ops->owner);
err = sock->ops->accept(sock, newsock, sock->file->f_flags);
if (err < 0)
- goto out_family_put;
+ goto out_module_put;
if (upeer_sockaddr) {
if(newsock->ops->getname(newsock, (struct sockaddr *)address, &len, 2)<0) {
err = -ECONNABORTED;
- goto out_family_put;
+ goto out_module_put;
}
err = move_addr_to_user(address, len, upeer_sockaddr, upeer_addrlen);
if (err < 0)
- goto out_family_put;
+ goto out_module_put;
}
/* File flags are not inherited via accept() unlike another OSes. */
if ((err = sock_map_fd(newsock)) < 0)
- goto out_family_put;
+ goto out_module_put;
security_socket_post_accept(sock, newsock);
@@ -1317,8 +1307,8 @@ out_put:
sockfd_put(sock);
out:
return err;
-out_family_put:
- net_family_put(sock->ops->family);
+out_module_put:
+ module_put(sock->ops->owner);
out_release:
sock_release(newsock);
goto out_put;