diff options
Diffstat (limited to 'net/vmw_vsock')
| -rw-r--r-- | net/vmw_vsock/af_vsock.c | 335 | ||||
| -rw-r--r-- | net/vmw_vsock/hyperv_transport.c | 7 | ||||
| -rw-r--r-- | net/vmw_vsock/virtio_transport.c | 22 | ||||
| -rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 62 | ||||
| -rw-r--r-- | net/vmw_vsock/vmci_transport.c | 28 | ||||
| -rw-r--r-- | net/vmw_vsock/vsock_loopback.c | 22 |
6 files changed, 406 insertions, 70 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index a3505a4dcee0..20ad2b2dc17b 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -83,6 +83,50 @@ * TCP_ESTABLISHED - connected * TCP_CLOSING - disconnecting * TCP_LISTEN - listening + * + * - Namespaces in vsock support two different modes: "local" and "global". + * Each mode defines how the namespace interacts with CIDs. + * Each namespace exposes two sysctl files: + * + * - /proc/sys/net/vsock/ns_mode (read-only) reports the current namespace's + * mode, which is set at namespace creation and immutable thereafter. + * - /proc/sys/net/vsock/child_ns_mode (writable) controls what mode future + * child namespaces will inherit when created. The default is "global". + * + * Changing child_ns_mode only affects newly created namespaces, not the + * current namespace or existing children. At namespace creation, ns_mode + * is inherited from the parent's child_ns_mode. + * + * The init_net mode is "global" and cannot be modified. + * + * The modes affect the allocation and accessibility of CIDs as follows: + * + * - global - access and allocation are all system-wide + * - all CID allocation from global namespaces draw from the same + * system-wide pool. + * - if one global namespace has already allocated some CID, another + * global namespace will not be able to allocate the same CID. + * - global mode AF_VSOCK sockets can reach any VM or socket in any global + * namespace, they are not contained to only their own namespace. + * - AF_VSOCK sockets in a global mode namespace cannot reach VMs or + * sockets in any local mode namespace. + * - local - access and allocation are contained within the namespace + * - CID allocation draws only from a private pool local only to the + * namespace, and does not affect the CIDs available for allocation in any + * other namespace (global or local). + * - VMs in a local namespace do not collide with CIDs in any other local + * namespace or any global namespace. For example, if a VM in a local mode + * namespace is given CID 10, then CID 10 is still available for + * allocation in any other namespace, but not in the same namespace. + * - AF_VSOCK sockets in a local mode namespace can connect only to VMs or + * other sockets within their own namespace. + * - sockets bound to VMADDR_CID_ANY in local namespaces will never resolve + * to any transport that is not compatible with local mode. There is no + * error that propagates to the user (as there is for connection attempts) + * because it is possible for some packet to reach this socket from + * a different transport that *does* support local mode. For + * example, virtio-vsock may not support local mode, but the socket + * may still accept a connection from vhost-vsock which does. */ #include <linux/compat.h> @@ -100,20 +144,31 @@ #include <linux/module.h> #include <linux/mutex.h> #include <linux/net.h> +#include <linux/proc_fs.h> #include <linux/poll.h> #include <linux/random.h> #include <linux/skbuff.h> #include <linux/smp.h> #include <linux/socket.h> #include <linux/stddef.h> +#include <linux/sysctl.h> #include <linux/unistd.h> #include <linux/wait.h> #include <linux/workqueue.h> #include <net/sock.h> #include <net/af_vsock.h> +#include <net/netns/vsock.h> #include <uapi/linux/vm_sockets.h> #include <uapi/asm-generic/ioctls.h> +#define VSOCK_NET_MODE_STR_GLOBAL "global" +#define VSOCK_NET_MODE_STR_LOCAL "local" + +/* 6 chars for "global", 1 for null-terminator, and 1 more for '\n'. + * The newline is added by proc_dostring() for read operations. + */ +#define VSOCK_NET_MODE_STR_MAX 8 + static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr); static void vsock_sk_destruct(struct sock *sk); static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); @@ -235,33 +290,42 @@ static void __vsock_remove_connected(struct vsock_sock *vsk) sock_put(&vsk->sk); } -static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr) +static struct sock *__vsock_find_bound_socket_net(struct sockaddr_vm *addr, + struct net *net) { struct vsock_sock *vsk; list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) { - if (vsock_addr_equals_addr(addr, &vsk->local_addr)) - return sk_vsock(vsk); + struct sock *sk = sk_vsock(vsk); + + if (vsock_addr_equals_addr(addr, &vsk->local_addr) && + vsock_net_check_mode(sock_net(sk), net)) + return sk; if (addr->svm_port == vsk->local_addr.svm_port && (vsk->local_addr.svm_cid == VMADDR_CID_ANY || - addr->svm_cid == VMADDR_CID_ANY)) - return sk_vsock(vsk); + addr->svm_cid == VMADDR_CID_ANY) && + vsock_net_check_mode(sock_net(sk), net)) + return sk; } return NULL; } -static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src, - struct sockaddr_vm *dst) +static struct sock * +__vsock_find_connected_socket_net(struct sockaddr_vm *src, + struct sockaddr_vm *dst, struct net *net) { struct vsock_sock *vsk; list_for_each_entry(vsk, vsock_connected_sockets(src, dst), connected_table) { + struct sock *sk = sk_vsock(vsk); + if (vsock_addr_equals_addr(src, &vsk->remote_addr) && - dst->svm_port == vsk->local_addr.svm_port) { - return sk_vsock(vsk); + dst->svm_port == vsk->local_addr.svm_port && + vsock_net_check_mode(sock_net(sk), net)) { + return sk; } } @@ -304,12 +368,18 @@ void vsock_remove_connected(struct vsock_sock *vsk) } EXPORT_SYMBOL_GPL(vsock_remove_connected); -struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr) +/* Find a bound socket, filtering by namespace and namespace mode. + * + * Use this in transports that are namespace-aware and can provide the + * network namespace context. + */ +struct sock *vsock_find_bound_socket_net(struct sockaddr_vm *addr, + struct net *net) { struct sock *sk; spin_lock_bh(&vsock_table_lock); - sk = __vsock_find_bound_socket(addr); + sk = __vsock_find_bound_socket_net(addr, net); if (sk) sock_hold(sk); @@ -317,15 +387,32 @@ struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr) return sk; } +EXPORT_SYMBOL_GPL(vsock_find_bound_socket_net); + +/* Find a bound socket without namespace filtering. + * + * Use this in transports that lack namespace context. All sockets are + * treated as if in global mode. + */ +struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr) +{ + return vsock_find_bound_socket_net(addr, NULL); +} EXPORT_SYMBOL_GPL(vsock_find_bound_socket); -struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, - struct sockaddr_vm *dst) +/* Find a connected socket, filtering by namespace and namespace mode. + * + * Use this in transports that are namespace-aware and can provide the + * network namespace context. + */ +struct sock *vsock_find_connected_socket_net(struct sockaddr_vm *src, + struct sockaddr_vm *dst, + struct net *net) { struct sock *sk; spin_lock_bh(&vsock_table_lock); - sk = __vsock_find_connected_socket(src, dst); + sk = __vsock_find_connected_socket_net(src, dst, net); if (sk) sock_hold(sk); @@ -333,6 +420,18 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, return sk; } +EXPORT_SYMBOL_GPL(vsock_find_connected_socket_net); + +/* Find a connected socket without namespace filtering. + * + * Use this in transports that lack namespace context. All sockets are + * treated as if in global mode. + */ +struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, + struct sockaddr_vm *dst) +{ + return vsock_find_connected_socket_net(src, dst, NULL); +} EXPORT_SYMBOL_GPL(vsock_find_connected_socket); void vsock_remove_sock(struct vsock_sock *vsk) @@ -528,7 +627,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) if (sk->sk_type == SOCK_SEQPACKET) { if (!new_transport->seqpacket_allow || - !new_transport->seqpacket_allow(remote_cid)) { + !new_transport->seqpacket_allow(vsk, remote_cid)) { module_put(new_transport->module); return -ESOCKTNOSUPPORT; } @@ -676,11 +775,11 @@ out: static int __vsock_bind_connectible(struct vsock_sock *vsk, struct sockaddr_vm *addr) { - static u32 port; + struct net *net = sock_net(sk_vsock(vsk)); struct sockaddr_vm new_addr; - if (!port) - port = get_random_u32_above(LAST_RESERVED_PORT); + if (!net->vsock.port) + net->vsock.port = get_random_u32_above(LAST_RESERVED_PORT); vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port); @@ -689,13 +788,13 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk, unsigned int i; for (i = 0; i < MAX_PORT_RETRIES; i++) { - if (port == VMADDR_PORT_ANY || - port <= LAST_RESERVED_PORT) - port = LAST_RESERVED_PORT + 1; + if (net->vsock.port == VMADDR_PORT_ANY || + net->vsock.port <= LAST_RESERVED_PORT) + net->vsock.port = LAST_RESERVED_PORT + 1; - new_addr.svm_port = port++; + new_addr.svm_port = net->vsock.port++; - if (!__vsock_find_bound_socket(&new_addr)) { + if (!__vsock_find_bound_socket_net(&new_addr, net)) { found = true; break; } @@ -712,7 +811,7 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk, return -EACCES; } - if (__vsock_find_bound_socket(&new_addr)) + if (__vsock_find_bound_socket_net(&new_addr, net)) return -EADDRINUSE; } @@ -1314,7 +1413,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, goto out; } - if (!transport->dgram_allow(remote_addr->svm_cid, + if (!transport->dgram_allow(vsk, remote_addr->svm_cid, remote_addr->svm_port)) { err = -EINVAL; goto out; @@ -1355,7 +1454,7 @@ static int vsock_dgram_connect(struct socket *sock, if (err) goto out; - if (!vsk->transport->dgram_allow(remote_addr->svm_cid, + if (!vsk->transport->dgram_allow(vsk, remote_addr->svm_cid, remote_addr->svm_port)) { err = -EINVAL; goto out; @@ -1585,7 +1684,7 @@ static int vsock_connect(struct socket *sock, struct sockaddr_unsized *addr, * endpoints. */ if (!transport || - !transport->stream_allow(remote_addr->svm_cid, + !transport->stream_allow(vsk, remote_addr->svm_cid, remote_addr->svm_port)) { err = -ENETUNREACH; goto out; @@ -2662,6 +2761,180 @@ static struct miscdevice vsock_device = { .fops = &vsock_device_ops, }; +static int __vsock_net_mode_string(const struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos, + enum vsock_net_mode mode, + enum vsock_net_mode *new_mode) +{ + char data[VSOCK_NET_MODE_STR_MAX] = {0}; + struct ctl_table tmp; + int ret; + + if (!table->data || !table->maxlen || !*lenp) { + *lenp = 0; + return 0; + } + + tmp = *table; + tmp.data = data; + + if (!write) { + const char *p; + + switch (mode) { + case VSOCK_NET_MODE_GLOBAL: + p = VSOCK_NET_MODE_STR_GLOBAL; + break; + case VSOCK_NET_MODE_LOCAL: + p = VSOCK_NET_MODE_STR_LOCAL; + break; + default: + WARN_ONCE(true, "netns has invalid vsock mode"); + *lenp = 0; + return 0; + } + + strscpy(data, p, sizeof(data)); + tmp.maxlen = strlen(p); + } + + ret = proc_dostring(&tmp, write, buffer, lenp, ppos); + if (ret || !write) + return ret; + + if (*lenp >= sizeof(data)) + return -EINVAL; + + if (!strncmp(data, VSOCK_NET_MODE_STR_GLOBAL, sizeof(data))) + *new_mode = VSOCK_NET_MODE_GLOBAL; + else if (!strncmp(data, VSOCK_NET_MODE_STR_LOCAL, sizeof(data))) + *new_mode = VSOCK_NET_MODE_LOCAL; + else + return -EINVAL; + + return 0; +} + +static int vsock_net_mode_string(const struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net; + + if (write) + return -EPERM; + + net = current->nsproxy->net_ns; + + return __vsock_net_mode_string(table, write, buffer, lenp, ppos, + vsock_net_mode(net), NULL); +} + +static int vsock_net_child_mode_string(const struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + enum vsock_net_mode new_mode; + struct net *net; + int ret; + + net = current->nsproxy->net_ns; + + ret = __vsock_net_mode_string(table, write, buffer, lenp, ppos, + vsock_net_child_mode(net), &new_mode); + if (ret) + return ret; + + if (write) + vsock_net_set_child_mode(net, new_mode); + + return 0; +} + +static struct ctl_table vsock_table[] = { + { + .procname = "ns_mode", + .data = &init_net.vsock.mode, + .maxlen = VSOCK_NET_MODE_STR_MAX, + .mode = 0444, + .proc_handler = vsock_net_mode_string + }, + { + .procname = "child_ns_mode", + .data = &init_net.vsock.child_ns_mode, + .maxlen = VSOCK_NET_MODE_STR_MAX, + .mode = 0644, + .proc_handler = vsock_net_child_mode_string + }, +}; + +static int __net_init vsock_sysctl_register(struct net *net) +{ + struct ctl_table *table; + + if (net_eq(net, &init_net)) { + table = vsock_table; + } else { + table = kmemdup(vsock_table, sizeof(vsock_table), GFP_KERNEL); + if (!table) + goto err_alloc; + + table[0].data = &net->vsock.mode; + table[1].data = &net->vsock.child_ns_mode; + } + + net->vsock.sysctl_hdr = register_net_sysctl_sz(net, "net/vsock", table, + ARRAY_SIZE(vsock_table)); + if (!net->vsock.sysctl_hdr) + goto err_reg; + + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static void vsock_sysctl_unregister(struct net *net) +{ + const struct ctl_table *table; + + table = net->vsock.sysctl_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->vsock.sysctl_hdr); + if (!net_eq(net, &init_net)) + kfree(table); +} + +static void vsock_net_init(struct net *net) +{ + if (net_eq(net, &init_net)) + net->vsock.mode = VSOCK_NET_MODE_GLOBAL; + else + net->vsock.mode = vsock_net_child_mode(current->nsproxy->net_ns); + + net->vsock.child_ns_mode = VSOCK_NET_MODE_GLOBAL; +} + +static __net_init int vsock_sysctl_init_net(struct net *net) +{ + vsock_net_init(net); + + if (vsock_sysctl_register(net)) + return -ENOMEM; + + return 0; +} + +static __net_exit void vsock_sysctl_exit_net(struct net *net) +{ + vsock_sysctl_unregister(net); +} + +static struct pernet_operations vsock_sysctl_ops = { + .init = vsock_sysctl_init_net, + .exit = vsock_sysctl_exit_net, +}; + static int __init vsock_init(void) { int err = 0; @@ -2689,10 +2962,17 @@ static int __init vsock_init(void) goto err_unregister_proto; } + if (register_pernet_subsys(&vsock_sysctl_ops)) { + err = -ENOMEM; + goto err_unregister_sock; + } + vsock_bpf_build_proto(); return 0; +err_unregister_sock: + sock_unregister(AF_VSOCK); err_unregister_proto: proto_unregister(&vsock_proto); err_deregister_misc: @@ -2706,6 +2986,7 @@ static void __exit vsock_exit(void) misc_deregister(&vsock_device); sock_unregister(AF_VSOCK); proto_unregister(&vsock_proto); + unregister_pernet_subsys(&vsock_sysctl_ops); } const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk) diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index 432fcbbd14d4..c3010c874308 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -570,7 +570,7 @@ static int hvs_dgram_enqueue(struct vsock_sock *vsk, return -EOPNOTSUPP; } -static bool hvs_dgram_allow(u32 cid, u32 port) +static bool hvs_dgram_allow(struct vsock_sock *vsk, u32 cid, u32 port) { return false; } @@ -745,8 +745,11 @@ static bool hvs_stream_is_active(struct vsock_sock *vsk) return hvs->chan != NULL; } -static bool hvs_stream_allow(u32 cid, u32 port) +static bool hvs_stream_allow(struct vsock_sock *vsk, u32 cid, u32 port) { + if (!vsock_net_mode_global(vsk)) + return false; + if (cid == VMADDR_CID_HOST) return true; diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 8c867023a2e5..3f7ea2db9bd7 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -231,7 +231,7 @@ static int virtio_transport_send_skb_fast_path(struct virtio_vsock *vsock, struc } static int -virtio_transport_send_pkt(struct sk_buff *skb) +virtio_transport_send_pkt(struct sk_buff *skb, struct net *net) { struct virtio_vsock_hdr *hdr; struct virtio_vsock *vsock; @@ -536,7 +536,13 @@ static bool virtio_transport_msgzerocopy_allow(void) return true; } -static bool virtio_transport_seqpacket_allow(u32 remote_cid); +bool virtio_transport_stream_allow(struct vsock_sock *vsk, u32 cid, u32 port) +{ + return vsock_net_mode_global(vsk); +} + +static bool virtio_transport_seqpacket_allow(struct vsock_sock *vsk, + u32 remote_cid); static struct virtio_transport virtio_transport = { .transport = { @@ -593,11 +599,15 @@ static struct virtio_transport virtio_transport = { .can_msgzerocopy = virtio_transport_can_msgzerocopy, }; -static bool virtio_transport_seqpacket_allow(u32 remote_cid) +static bool +virtio_transport_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid) { struct virtio_vsock *vsock; bool seqpacket_allow; + if (!vsock_net_mode_global(vsk)) + return false; + seqpacket_allow = false; rcu_read_lock(); vsock = rcu_dereference(the_virtio_vsock); @@ -660,7 +670,11 @@ static void virtio_transport_rx_work(struct work_struct *work) virtio_vsock_skb_put(skb, payload_len); virtio_transport_deliver_tap_pkt(skb); - virtio_transport_recv_pkt(&virtio_transport, skb); + + /* Force virtio-transport into global mode since it + * does not yet support local-mode namespacing. + */ + virtio_transport_recv_pkt(&virtio_transport, skb, NULL); } } while (!virtqueue_enable_cb(vq)); diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index d3e26025ef58..d017ab318a7e 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -414,7 +414,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, virtio_transport_inc_tx_pkt(vvs, skb); - ret = t_ops->send_pkt(skb); + ret = t_ops->send_pkt(skb, info->net); if (ret < 0) break; @@ -526,6 +526,7 @@ static int virtio_transport_send_credit_update(struct vsock_sock *vsk) struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, .vsk = vsk, + .net = sock_net(sk_vsock(vsk)), }; return virtio_transport_send_pkt_info(vsk, &info); @@ -1055,12 +1056,6 @@ bool virtio_transport_stream_is_active(struct vsock_sock *vsk) } EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); -bool virtio_transport_stream_allow(u32 cid, u32 port) -{ - return true; -} -EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); - int virtio_transport_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr) { @@ -1068,7 +1063,7 @@ int virtio_transport_dgram_bind(struct vsock_sock *vsk, } EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); -bool virtio_transport_dgram_allow(u32 cid, u32 port) +bool virtio_transport_dgram_allow(struct vsock_sock *vsk, u32 cid, u32 port) { return false; } @@ -1079,6 +1074,7 @@ int virtio_transport_connect(struct vsock_sock *vsk) struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_REQUEST, .vsk = vsk, + .net = sock_net(sk_vsock(vsk)), }; return virtio_transport_send_pkt_info(vsk, &info); @@ -1094,6 +1090,7 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) (mode & SEND_SHUTDOWN ? VIRTIO_VSOCK_SHUTDOWN_SEND : 0), .vsk = vsk, + .net = sock_net(sk_vsock(vsk)), }; return virtio_transport_send_pkt_info(vsk, &info); @@ -1120,6 +1117,7 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk, .msg = msg, .pkt_len = len, .vsk = vsk, + .net = sock_net(sk_vsock(vsk)), }; return virtio_transport_send_pkt_info(vsk, &info); @@ -1157,6 +1155,7 @@ static int virtio_transport_reset(struct vsock_sock *vsk, .op = VIRTIO_VSOCK_OP_RST, .reply = !!skb, .vsk = vsk, + .net = sock_net(sk_vsock(vsk)), }; /* Send RST only if the original pkt is not a RST pkt */ @@ -1168,15 +1167,31 @@ static int virtio_transport_reset(struct vsock_sock *vsk, /* Normally packets are associated with a socket. There may be no socket if an * attempt was made to connect to a socket that does not exist. + * + * net refers to the namespace of whoever sent the invalid message. For + * loopback, this is the namespace of the socket. For vhost, this is the + * namespace of the VM (i.e., vhost_vsock). */ static int virtio_transport_reset_no_sock(const struct virtio_transport *t, - struct sk_buff *skb) + struct sk_buff *skb, struct net *net) { struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); struct virtio_vsock_pkt_info info = { .op = VIRTIO_VSOCK_OP_RST, .type = le16_to_cpu(hdr->type), .reply = true, + + /* Set sk owner to socket we are replying to (may be NULL for + * non-loopback). This keeps a reference to the sock and + * sock_net(sk) until the reply skb is freed. + */ + .vsk = vsock_sk(skb->sk), + + /* net is not defined here because we pass it directly to + * t->send_pkt(), instead of relying on + * virtio_transport_send_pkt_info() to pass it. It is not needed + * by virtio_transport_alloc_skb(). + */ }; struct sk_buff *reply; @@ -1195,7 +1210,7 @@ static int virtio_transport_reset_no_sock(const struct virtio_transport *t, if (!reply) return -ENOMEM; - return t->send_pkt(reply); + return t->send_pkt(reply, net); } /* This function should be called with sk_lock held and SOCK_DONE set */ @@ -1479,6 +1494,7 @@ virtio_transport_send_response(struct vsock_sock *vsk, .remote_port = le32_to_cpu(hdr->src_port), .reply = true, .vsk = vsk, + .net = sock_net(sk_vsock(vsk)), }; return virtio_transport_send_pkt_info(vsk, &info); @@ -1521,12 +1537,12 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, int ret; if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) { - virtio_transport_reset_no_sock(t, skb); + virtio_transport_reset_no_sock(t, skb, sock_net(sk)); return -EINVAL; } if (sk_acceptq_is_full(sk)) { - virtio_transport_reset_no_sock(t, skb); + virtio_transport_reset_no_sock(t, skb, sock_net(sk)); return -ENOMEM; } @@ -1534,13 +1550,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, * Subsequent enqueues would lead to a memory leak. */ if (sk->sk_shutdown == SHUTDOWN_MASK) { - virtio_transport_reset_no_sock(t, skb); + virtio_transport_reset_no_sock(t, skb, sock_net(sk)); return -ESHUTDOWN; } child = vsock_create_connected(sk); if (!child) { - virtio_transport_reset_no_sock(t, skb); + virtio_transport_reset_no_sock(t, skb, sock_net(sk)); return -ENOMEM; } @@ -1562,7 +1578,7 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, */ if (ret || vchild->transport != &t->transport) { release_sock(child); - virtio_transport_reset_no_sock(t, skb); + virtio_transport_reset_no_sock(t, skb, sock_net(sk)); sock_put(child); return ret; } @@ -1590,7 +1606,7 @@ static bool virtio_transport_valid_type(u16 type) * lock. */ void virtio_transport_recv_pkt(struct virtio_transport *t, - struct sk_buff *skb) + struct sk_buff *skb, struct net *net) { struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); struct sockaddr_vm src, dst; @@ -1613,24 +1629,24 @@ void virtio_transport_recv_pkt(struct virtio_transport *t, le32_to_cpu(hdr->fwd_cnt)); if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) { - (void)virtio_transport_reset_no_sock(t, skb); + (void)virtio_transport_reset_no_sock(t, skb, net); goto free_pkt; } /* The socket must be in connected or bound table * otherwise send reset back */ - sk = vsock_find_connected_socket(&src, &dst); + sk = vsock_find_connected_socket_net(&src, &dst, net); if (!sk) { - sk = vsock_find_bound_socket(&dst); + sk = vsock_find_bound_socket_net(&dst, net); if (!sk) { - (void)virtio_transport_reset_no_sock(t, skb); + (void)virtio_transport_reset_no_sock(t, skb, net); goto free_pkt; } } if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) { - (void)virtio_transport_reset_no_sock(t, skb); + (void)virtio_transport_reset_no_sock(t, skb, net); sock_put(sk); goto free_pkt; } @@ -1649,7 +1665,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t, */ if (sock_flag(sk, SOCK_DONE) || (sk->sk_state != TCP_LISTEN && vsk->transport != &t->transport)) { - (void)virtio_transport_reset_no_sock(t, skb); + (void)virtio_transport_reset_no_sock(t, skb, net); release_sock(sk); sock_put(sk); goto free_pkt; @@ -1681,7 +1697,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t, kfree_skb(skb); break; default: - (void)virtio_transport_reset_no_sock(t, skb); + (void)virtio_transport_reset_no_sock(t, skb, net); kfree_skb(skb); break; } diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index 7eccd6708d66..a64522be1bad 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -161,7 +161,7 @@ vmci_transport_packet_init(struct vmci_transport_packet *pkt, case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ: case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE: - memcpy(&pkt->u.wait, wait, sizeof(pkt->u.wait)); + pkt->u.wait = *wait; break; case VMCI_TRANSPORT_PACKET_TYPE_REQUEST2: @@ -646,13 +646,17 @@ static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg) return VMCI_SUCCESS; } -static bool vmci_transport_stream_allow(u32 cid, u32 port) +static bool vmci_transport_stream_allow(struct vsock_sock *vsk, u32 cid, + u32 port) { static const u32 non_socket_contexts[] = { VMADDR_CID_LOCAL, }; int i; + if (!vsock_net_mode_global(vsk)) + return false; + BUILD_BUG_ON(sizeof(cid) != sizeof(*non_socket_contexts)); for (i = 0; i < ARRAY_SIZE(non_socket_contexts); i++) { @@ -682,12 +686,10 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg) err = VMCI_SUCCESS; bh_process_pkt = false; - /* Ignore incoming packets from contexts without sockets, or resources - * that aren't vsock implementations. + /* Ignore incoming packets from resources that aren't vsock + * implementations. */ - - if (!vmci_transport_stream_allow(dg->src.context, -1) - || vmci_transport_peer_rid(dg->src.context) != dg->src.resource) + if (vmci_transport_peer_rid(dg->src.context) != dg->src.resource) return VMCI_ERROR_NO_ACCESS; if (VMCI_DG_SIZE(dg) < sizeof(*pkt)) @@ -749,6 +751,12 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg) goto out; } + /* Ignore incoming packets from contexts without sockets. */ + if (!vmci_transport_stream_allow(vsk, dg->src.context, -1)) { + err = VMCI_ERROR_NO_ACCESS; + goto out; + } + /* We do most everything in a work queue, but let's fast path the * notification of reads and writes to help data transfer performance. * We can only do this if there is no process context code executing @@ -1784,8 +1792,12 @@ out: return err; } -static bool vmci_transport_dgram_allow(u32 cid, u32 port) +static bool vmci_transport_dgram_allow(struct vsock_sock *vsk, u32 cid, + u32 port) { + if (!vsock_net_mode_global(vsk)) + return false; + if (cid == VMADDR_CID_HYPERVISOR) { /* Registrations of PBRPC Servers do not modify VMX/Hypervisor * state and are allowed. diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c index bc2ff918b315..8068d1b6e851 100644 --- a/net/vmw_vsock/vsock_loopback.c +++ b/net/vmw_vsock/vsock_loopback.c @@ -26,7 +26,7 @@ static u32 vsock_loopback_get_local_cid(void) return VMADDR_CID_LOCAL; } -static int vsock_loopback_send_pkt(struct sk_buff *skb) +static int vsock_loopback_send_pkt(struct sk_buff *skb, struct net *net) { struct vsock_loopback *vsock = &the_vsock_loopback; int len = skb->len; @@ -46,7 +46,15 @@ static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk) return 0; } -static bool vsock_loopback_seqpacket_allow(u32 remote_cid); +static bool vsock_loopback_seqpacket_allow(struct vsock_sock *vsk, + u32 remote_cid); + +static bool vsock_loopback_stream_allow(struct vsock_sock *vsk, u32 cid, + u32 port) +{ + return true; +} + static bool vsock_loopback_msgzerocopy_allow(void) { return true; @@ -76,7 +84,7 @@ static struct virtio_transport loopback_transport = { .stream_has_space = virtio_transport_stream_has_space, .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, .stream_is_active = virtio_transport_stream_is_active, - .stream_allow = virtio_transport_stream_allow, + .stream_allow = vsock_loopback_stream_allow, .seqpacket_dequeue = virtio_transport_seqpacket_dequeue, .seqpacket_enqueue = virtio_transport_seqpacket_enqueue, @@ -106,9 +114,10 @@ static struct virtio_transport loopback_transport = { .send_pkt = vsock_loopback_send_pkt, }; -static bool vsock_loopback_seqpacket_allow(u32 remote_cid) +static bool +vsock_loopback_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid) { - return true; + return vsock_net_mode_global(vsk); } static void vsock_loopback_work(struct work_struct *work) @@ -130,7 +139,8 @@ static void vsock_loopback_work(struct work_struct *work) */ virtio_transport_consume_skb_sent(skb, false); virtio_transport_deliver_tap_pkt(skb); - virtio_transport_recv_pkt(&loopback_transport, skb); + virtio_transport_recv_pkt(&loopback_transport, skb, + sock_net(skb->sk)); } } |
