summaryrefslogtreecommitdiff
path: root/net/psp/psp_nl.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/psp/psp_nl.c')
-rw-r--r--net/psp/psp_nl.c505
1 files changed, 505 insertions, 0 deletions
diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c
new file mode 100644
index 000000000000..8aaca62744c3
--- /dev/null
+++ b/net/psp/psp_nl.c
@@ -0,0 +1,505 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/skbuff.h>
+#include <linux/xarray.h>
+#include <net/genetlink.h>
+#include <net/psp.h>
+#include <net/sock.h>
+
+#include "psp-nl-gen.h"
+#include "psp.h"
+
+/* Netlink helpers */
+
+static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
+{
+ struct sk_buff *rsp;
+ void *hdr;
+
+ rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!rsp)
+ return NULL;
+
+ hdr = genlmsg_iput(rsp, info);
+ if (!hdr) {
+ nlmsg_free(rsp);
+ return NULL;
+ }
+
+ return rsp;
+}
+
+static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
+{
+ /* Note that this *only* works with a single message per skb! */
+ nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
+
+ return genlmsg_reply(rsp, info);
+}
+
+/* Device stuff */
+
+static struct psp_dev *
+psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
+{
+ struct psp_dev *psd;
+ int err;
+
+ mutex_lock(&psp_devs_lock);
+ psd = xa_load(&psp_devs, nla_get_u32(dev_id));
+ if (!psd) {
+ mutex_unlock(&psp_devs_lock);
+ return ERR_PTR(-ENODEV);
+ }
+
+ mutex_lock(&psd->lock);
+ mutex_unlock(&psp_devs_lock);
+
+ err = psp_dev_check_access(psd, net);
+ if (err) {
+ mutex_unlock(&psd->lock);
+ return ERR_PTR(err);
+ }
+
+ return psd;
+}
+
+int psp_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info)
+{
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
+ return -EINVAL;
+
+ info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
+ info->attrs[PSP_A_DEV_ID]);
+ return PTR_ERR_OR_ZERO(info->user_ptr[0]);
+}
+
+void
+psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
+ struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+
+ mutex_unlock(&psd->lock);
+ if (socket)
+ sockfd_put(socket);
+}
+
+static int
+psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
+ const struct genl_info *info)
+{
+ void *hdr;
+
+ hdr = genlmsg_iput(rsp, info);
+ if (!hdr)
+ return -EMSGSIZE;
+
+ if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
+ nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
+ nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
+ nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
+ goto err_cancel_msg;
+
+ genlmsg_end(rsp, hdr);
+ return 0;
+
+err_cancel_msg:
+ genlmsg_cancel(rsp, hdr);
+ return -EMSGSIZE;
+}
+
+void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
+{
+ struct genl_info info;
+ struct sk_buff *ntf;
+
+ if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
+ PSP_NLGRP_MGMT))
+ return;
+
+ ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!ntf)
+ return;
+
+ genl_info_init_ntf(&info, &psp_nl_family, cmd);
+ if (psp_nl_dev_fill(psd, ntf, &info)) {
+ nlmsg_free(ntf);
+ return;
+ }
+
+ genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
+ 0, PSP_NLGRP_MGMT, GFP_KERNEL);
+}
+
+int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct sk_buff *rsp;
+ int err;
+
+ rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!rsp)
+ return -ENOMEM;
+
+ err = psp_nl_dev_fill(psd, rsp, info);
+ if (err)
+ goto err_free_msg;
+
+ return genlmsg_reply(rsp, info);
+
+err_free_msg:
+ nlmsg_free(rsp);
+ return err;
+}
+
+static int
+psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
+ struct psp_dev *psd)
+{
+ if (psp_dev_check_access(psd, sock_net(rsp->sk)))
+ return 0;
+
+ return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
+}
+
+int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
+{
+ struct psp_dev *psd;
+ int err = 0;
+
+ mutex_lock(&psp_devs_lock);
+ xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
+ mutex_lock(&psd->lock);
+ err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
+ mutex_unlock(&psd->lock);
+ if (err)
+ break;
+ }
+ mutex_unlock(&psp_devs_lock);
+
+ return err;
+}
+
+int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_dev_config new_config;
+ struct sk_buff *rsp;
+ int err;
+
+ memcpy(&new_config, &psd->config, sizeof(new_config));
+
+ if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
+ new_config.versions =
+ nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
+ if (new_config.versions & ~psd->caps->versions) {
+ NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
+ return -EINVAL;
+ }
+ } else {
+ NL_SET_ERR_MSG(info->extack, "No settings present");
+ return -EINVAL;
+ }
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
+ err = psd->ops->set_config(psd, &new_config, info->extack);
+ if (err)
+ goto err_free_rsp;
+
+ memcpy(&psd->config, &new_config, sizeof(new_config));
+ }
+
+ psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct genl_info ntf_info;
+ struct sk_buff *ntf, *rsp;
+ u8 prev_gen;
+ int err;
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
+ ntf = psp_nl_reply_new(&ntf_info);
+ if (!ntf) {
+ err = -ENOMEM;
+ goto err_free_rsp;
+ }
+
+ if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
+ nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
+ err = -EMSGSIZE;
+ goto err_free_ntf;
+ }
+
+ /* suggest the next gen number, driver can override */
+ prev_gen = psd->generation;
+ psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;
+
+ err = psd->ops->key_rotate(psd, info->extack);
+ if (err)
+ goto err_free_ntf;
+
+ WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
+ psd->generation & ~PSP_GEN_VALID_MASK);
+
+ psp_assocs_key_rotated(psd);
+
+ nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
+ genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
+ 0, PSP_NLGRP_USE, GFP_KERNEL);
+ return psp_nl_reply_send(rsp, info);
+
+err_free_ntf:
+ nlmsg_free(ntf);
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+/* Key etc. */
+
+int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket;
+ struct psp_dev *psd;
+ struct nlattr *id;
+ int fd, err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
+ return -EINVAL;
+
+ fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
+ socket = sockfd_lookup(fd, &err);
+ if (!socket)
+ return err;
+
+ if (!sk_is_tcp(socket->sk)) {
+ NL_SET_ERR_MSG_ATTR(info->extack,
+ info->attrs[PSP_A_ASSOC_SOCK_FD],
+ "Unsupported socket family and type");
+ err = -EOPNOTSUPP;
+ goto err_sock_put;
+ }
+
+ psd = psp_dev_get_for_sock(socket->sk);
+ if (psd) {
+ err = psp_dev_check_access(psd, genl_info_net(info));
+ if (err) {
+ psp_dev_put(psd);
+ psd = NULL;
+ }
+ }
+
+ if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
+ err = -EINVAL;
+ goto err_sock_put;
+ }
+
+ id = info->attrs[PSP_A_ASSOC_DEV_ID];
+ if (psd) {
+ mutex_lock(&psd->lock);
+ if (id && psd->id != nla_get_u32(id)) {
+ mutex_unlock(&psd->lock);
+ NL_SET_ERR_MSG_ATTR(info->extack, id,
+ "Device id vs socket mismatch");
+ err = -EINVAL;
+ goto err_psd_put;
+ }
+
+ psp_dev_put(psd);
+ } else {
+ psd = psp_device_get_and_lock(genl_info_net(info), id);
+ if (IS_ERR(psd)) {
+ err = PTR_ERR(psd);
+ goto err_sock_put;
+ }
+ }
+
+ info->user_ptr[0] = psd;
+ info->user_ptr[1] = socket;
+
+ return 0;
+
+err_psd_put:
+ psp_dev_put(psd);
+err_sock_put:
+ sockfd_put(socket);
+ return err;
+}
+
+static int
+psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
+ unsigned int key_sz)
+{
+ struct nlattr *nest = info->attrs[attr];
+ struct nlattr *tb[PSP_A_KEYS_SPI + 1];
+ u32 spi;
+ int err;
+
+ err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
+ psp_keys_nl_policy, info->extack);
+ if (err)
+ return err;
+
+ if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
+ NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
+ return -EINVAL;
+
+ if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
+ NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
+ "incorrect key length");
+ return -EINVAL;
+ }
+
+ spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
+ if (!(spi & PSP_SPI_KEY_ID)) {
+ NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
+ "invalid SPI: lower 31b must be non-zero");
+ return -EINVAL;
+ }
+
+ key->spi = cpu_to_be32(spi);
+ memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
+
+ return 0;
+}
+
+static int
+psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
+ struct psp_key_parsed *key)
+{
+ int key_sz = psp_key_size(version);
+ void *nest;
+
+ nest = nla_nest_start(skb, attr);
+
+ if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
+ nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
+ nla_nest_cancel(skb, nest);
+ return -EMSGSIZE;
+ }
+
+ nla_nest_end(skb, nest);
+
+ return 0;
+}
+
+int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_key_parsed key;
+ struct psp_assoc *pas;
+ struct sk_buff *rsp;
+ u32 version;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
+ return -EINVAL;
+
+ version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
+ if (!(psd->caps->versions & (1 << version))) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
+ return -EOPNOTSUPP;
+ }
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ pas = psp_assoc_create(psd);
+ if (!pas) {
+ err = -ENOMEM;
+ goto err_free_rsp;
+ }
+ pas->version = version;
+
+ err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
+ if (err)
+ goto err_free_pas;
+
+ if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
+ psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
+ err = -EMSGSIZE;
+ goto err_free_pas;
+ }
+
+ err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
+ if (err) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
+ goto err_free_pas;
+ }
+ psp_assoc_put(pas);
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_pas:
+ psp_assoc_put(pas);
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_key_parsed key;
+ struct sk_buff *rsp;
+ unsigned int key_sz;
+ u32 version;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
+ GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
+ return -EINVAL;
+
+ version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
+ if (!(psd->caps->versions & (1 << version))) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
+ return -EOPNOTSUPP;
+ }
+
+ key_sz = psp_key_size(version);
+ if (!key_sz)
+ return -EINVAL;
+
+ err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
+ if (err < 0)
+ return err;
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
+ info->extack);
+ if (err)
+ goto err_free_msg;
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_msg:
+ nlmsg_free(rsp);
+ return err;
+}