summaryrefslogtreecommitdiff
path: root/net/mpls/af_mpls.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/mpls/af_mpls.c')
-rw-r--r--net/mpls/af_mpls.c321
1 files changed, 199 insertions, 122 deletions
diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c
index 25c88cba5c48..580aac112dd2 100644
--- a/net/mpls/af_mpls.c
+++ b/net/mpls/af_mpls.c
@@ -75,16 +75,23 @@ static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
struct nlmsghdr *nlh, struct net *net, u32 portid,
unsigned int nlm_flags);
-static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
+static struct mpls_route *mpls_route_input(struct net *net, unsigned int index)
{
- struct mpls_route *rt = NULL;
+ struct mpls_route __rcu **platform_label;
- if (index < net->mpls.platform_labels) {
- struct mpls_route __rcu **platform_label =
- rcu_dereference_rtnl(net->mpls.platform_label);
- rt = rcu_dereference_rtnl(platform_label[index]);
- }
- return rt;
+ platform_label = mpls_dereference(net, net->mpls.platform_label);
+ return mpls_dereference(net, platform_label[index]);
+}
+
+static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index)
+{
+ struct mpls_route __rcu **platform_label;
+
+ if (index >= net->mpls.platform_labels)
+ return NULL;
+
+ platform_label = rcu_dereference(net->mpls.platform_label);
+ return rcu_dereference(platform_label[index]);
}
bool mpls_output_possible(const struct net_device *dev)
@@ -129,25 +136,26 @@ bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
}
EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
-void mpls_stats_inc_outucastpkts(struct net_device *dev,
+void mpls_stats_inc_outucastpkts(struct net *net,
+ struct net_device *dev,
const struct sk_buff *skb)
{
struct mpls_dev *mdev;
if (skb->protocol == htons(ETH_P_MPLS_UC)) {
- mdev = mpls_dev_get(dev);
+ mdev = mpls_dev_rcu(dev);
if (mdev)
MPLS_INC_STATS_LEN(mdev, skb->len,
tx_packets,
tx_bytes);
} else if (skb->protocol == htons(ETH_P_IP)) {
- IP_UPD_PO_STATS(dev_net(dev), IPSTATS_MIB_OUT, skb->len);
+ IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len);
#if IS_ENABLED(CONFIG_IPV6)
} else if (skb->protocol == htons(ETH_P_IPV6)) {
- struct inet6_dev *in6dev = __in6_dev_get(dev);
+ struct inet6_dev *in6dev = in6_dev_rcu(dev);
if (in6dev)
- IP6_UPD_PO_STATS(dev_net(dev), in6dev,
+ IP6_UPD_PO_STATS(net, in6dev,
IPSTATS_MIB_OUT, skb->len);
#endif
}
@@ -342,7 +350,7 @@ static bool mpls_egress(struct net *net, struct mpls_route *rt,
static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
struct packet_type *pt, struct net_device *orig_dev)
{
- struct net *net = dev_net(dev);
+ struct net *net = dev_net_rcu(dev);
struct mpls_shim_hdr *hdr;
const struct mpls_nh *nh;
struct mpls_route *rt;
@@ -357,7 +365,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
/* Careful this entire function runs inside of an rcu critical section */
- mdev = mpls_dev_get(dev);
+ mdev = mpls_dev_rcu(dev);
if (!mdev)
goto drop;
@@ -434,7 +442,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
dec.ttl -= 1;
if (unlikely(!new_header_size && dec.bos)) {
/* Penultimate hop popping */
- if (!mpls_egress(dev_net(out_dev), rt, skb, dec))
+ if (!mpls_egress(net, rt, skb, dec))
goto err;
} else {
bool bos;
@@ -451,7 +459,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
}
}
- mpls_stats_inc_outucastpkts(out_dev, skb);
+ mpls_stats_inc_outucastpkts(net, out_dev, skb);
/* If via wasn't specified then send out using device address */
if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC)
@@ -466,7 +474,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
return 0;
tx_err:
- out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL;
+ out_mdev = out_dev ? mpls_dev_rcu(out_dev) : NULL;
if (out_mdev)
MPLS_INC_STATS(out_mdev, tx_errors);
goto drop;
@@ -530,10 +538,23 @@ static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels)
return rt;
}
+static void mpls_rt_free_rcu(struct rcu_head *head)
+{
+ struct mpls_route *rt;
+
+ rt = container_of(head, struct mpls_route, rt_rcu);
+
+ change_nexthops(rt) {
+ netdev_put(nh->nh_dev, &nh->nh_dev_tracker);
+ } endfor_nexthops(rt);
+
+ kfree(rt);
+}
+
static void mpls_rt_free(struct mpls_route *rt)
{
if (rt)
- kfree_rcu(rt, rt_rcu);
+ call_rcu(&rt->rt_rcu, mpls_rt_free_rcu);
}
static void mpls_notify_route(struct net *net, unsigned index,
@@ -557,10 +578,8 @@ static void mpls_route_update(struct net *net, unsigned index,
struct mpls_route __rcu **platform_label;
struct mpls_route *rt;
- ASSERT_RTNL();
-
- platform_label = rtnl_dereference(net->mpls.platform_label);
- rt = rtnl_dereference(platform_label[index]);
+ platform_label = mpls_dereference(net, net->mpls.platform_label);
+ rt = mpls_dereference(net, platform_label[index]);
rcu_assign_pointer(platform_label[index], new);
mpls_notify_route(net, index, rt, new, info);
@@ -569,24 +588,23 @@ static void mpls_route_update(struct net *net, unsigned index,
mpls_rt_free(rt);
}
-static unsigned find_free_label(struct net *net)
+static unsigned int find_free_label(struct net *net)
{
- struct mpls_route __rcu **platform_label;
- size_t platform_labels;
- unsigned index;
+ unsigned int index;
- platform_label = rtnl_dereference(net->mpls.platform_label);
- platform_labels = net->mpls.platform_labels;
- for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
+ for (index = MPLS_LABEL_FIRST_UNRESERVED;
+ index < net->mpls.platform_labels;
index++) {
- if (!rtnl_dereference(platform_label[index]))
+ if (!mpls_route_input(net, index))
return index;
}
+
return LABEL_NOT_SPECIFIED;
}
#if IS_ENABLED(CONFIG_INET)
static struct net_device *inet_fib_lookup_dev(struct net *net,
+ struct mpls_nh *nh,
const void *addr)
{
struct net_device *dev;
@@ -599,14 +617,14 @@ static struct net_device *inet_fib_lookup_dev(struct net *net,
return ERR_CAST(rt);
dev = rt->dst.dev;
- dev_hold(dev);
-
+ netdev_hold(dev, &nh->nh_dev_tracker, GFP_KERNEL);
ip_rt_put(rt);
return dev;
}
#else
static struct net_device *inet_fib_lookup_dev(struct net *net,
+ struct mpls_nh *nh,
const void *addr)
{
return ERR_PTR(-EAFNOSUPPORT);
@@ -615,6 +633,7 @@ static struct net_device *inet_fib_lookup_dev(struct net *net,
#if IS_ENABLED(CONFIG_IPV6)
static struct net_device *inet6_fib_lookup_dev(struct net *net,
+ struct mpls_nh *nh,
const void *addr)
{
struct net_device *dev;
@@ -631,13 +650,14 @@ static struct net_device *inet6_fib_lookup_dev(struct net *net,
return ERR_CAST(dst);
dev = dst->dev;
- dev_hold(dev);
+ netdev_hold(dev, &nh->nh_dev_tracker, GFP_KERNEL);
dst_release(dst);
return dev;
}
#else
static struct net_device *inet6_fib_lookup_dev(struct net *net,
+ struct mpls_nh *nh,
const void *addr)
{
return ERR_PTR(-EAFNOSUPPORT);
@@ -653,16 +673,17 @@ static struct net_device *find_outdev(struct net *net,
if (!oif) {
switch (nh->nh_via_table) {
case NEIGH_ARP_TABLE:
- dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh));
+ dev = inet_fib_lookup_dev(net, nh, mpls_nh_via(rt, nh));
break;
case NEIGH_ND_TABLE:
- dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh));
+ dev = inet6_fib_lookup_dev(net, nh, mpls_nh_via(rt, nh));
break;
case NEIGH_LINK_TABLE:
break;
}
} else {
- dev = dev_get_by_index(net, oif);
+ dev = netdev_get_by_index(net, oif,
+ &nh->nh_dev_tracker, GFP_KERNEL);
}
if (!dev)
@@ -671,8 +692,7 @@ static struct net_device *find_outdev(struct net *net,
if (IS_ERR(dev))
return dev;
- /* The caller is holding rtnl anyways, so release the dev reference */
- dev_put(dev);
+ nh->nh_dev = dev;
return dev;
}
@@ -686,20 +706,17 @@ static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt,
dev = find_outdev(net, rt, nh, oif);
if (IS_ERR(dev)) {
err = PTR_ERR(dev);
- dev = NULL;
goto errout;
}
/* Ensure this is a supported device */
err = -EINVAL;
- if (!mpls_dev_get(dev))
- goto errout;
+ if (!mpls_dev_get(net, dev))
+ goto errout_put;
if ((nh->nh_via_table == NEIGH_LINK_TABLE) &&
(dev->addr_len != nh->nh_via_alen))
- goto errout;
-
- nh->nh_dev = dev;
+ goto errout_put;
if (!(dev->flags & IFF_UP)) {
nh->nh_flags |= RTNH_F_DEAD;
@@ -713,6 +730,9 @@ static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt,
return 0;
+errout_put:
+ netdev_put(nh->nh_dev, &nh->nh_dev_tracker);
+ nh->nh_dev = NULL;
errout:
return err;
}
@@ -890,7 +910,8 @@ static int mpls_nh_build_multi(struct mpls_route_config *cfg,
struct nlattr *nla_via, *nla_newdst;
int remaining = cfg->rc_mp_len;
int err = 0;
- u8 nhs = 0;
+
+ rt->rt_nhn = 0;
change_nexthops(rt) {
int attrlen;
@@ -926,11 +947,9 @@ static int mpls_nh_build_multi(struct mpls_route_config *cfg,
rt->rt_nhn_alive--;
rtnh = rtnh_next(rtnh, &remaining);
- nhs++;
+ rt->rt_nhn++;
} endfor_nexthops(rt);
- rt->rt_nhn = nhs;
-
return 0;
errout:
@@ -940,30 +959,28 @@ errout:
static bool mpls_label_ok(struct net *net, unsigned int *index,
struct netlink_ext_ack *extack)
{
- bool is_ok = true;
-
/* Reserved labels may not be set */
if (*index < MPLS_LABEL_FIRST_UNRESERVED) {
NL_SET_ERR_MSG(extack,
"Invalid label - must be MPLS_LABEL_FIRST_UNRESERVED or higher");
- is_ok = false;
+ return false;
}
/* The full 20 bit range may not be supported. */
- if (is_ok && *index >= net->mpls.platform_labels) {
+ if (*index >= net->mpls.platform_labels) {
NL_SET_ERR_MSG(extack,
"Label >= configured maximum in platform_labels");
- is_ok = false;
+ return false;
}
*index = array_index_nospec(*index, net->mpls.platform_labels);
- return is_ok;
+
+ return true;
}
static int mpls_route_add(struct mpls_route_config *cfg,
struct netlink_ext_ack *extack)
{
- struct mpls_route __rcu **platform_label;
struct net *net = cfg->rc_nlinfo.nl_net;
struct mpls_route *rt, *old;
int err = -EINVAL;
@@ -991,8 +1008,7 @@ static int mpls_route_add(struct mpls_route_config *cfg,
}
err = -EEXIST;
- platform_label = rtnl_dereference(net->mpls.platform_label);
- old = rtnl_dereference(platform_label[index]);
+ old = mpls_route_input(net, index);
if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
goto errout;
@@ -1103,7 +1119,7 @@ static int mpls_fill_stats_af(struct sk_buff *skb,
struct mpls_dev *mdev;
struct nlattr *nla;
- mdev = mpls_dev_get(dev);
+ mdev = mpls_dev_rcu(dev);
if (!mdev)
return -ENODATA;
@@ -1123,7 +1139,7 @@ static size_t mpls_get_stats_af_size(const struct net_device *dev)
{
struct mpls_dev *mdev;
- mdev = mpls_dev_get(dev);
+ mdev = mpls_dev_rcu(dev);
if (!mdev)
return 0;
@@ -1264,23 +1280,32 @@ static int mpls_netconf_get_devconf(struct sk_buff *in_skb,
if (err < 0)
goto errout;
- err = -EINVAL;
- if (!tb[NETCONFA_IFINDEX])
+ if (!tb[NETCONFA_IFINDEX]) {
+ err = -EINVAL;
goto errout;
+ }
ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]);
- dev = __dev_get_by_index(net, ifindex);
- if (!dev)
- goto errout;
-
- mdev = mpls_dev_get(dev);
- if (!mdev)
- goto errout;
- err = -ENOBUFS;
skb = nlmsg_new(mpls_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL);
- if (!skb)
+ if (!skb) {
+ err = -ENOBUFS;
goto errout;
+ }
+
+ rcu_read_lock();
+
+ dev = dev_get_by_index_rcu(net, ifindex);
+ if (!dev) {
+ err = -EINVAL;
+ goto errout_unlock;
+ }
+
+ mdev = mpls_dev_rcu(dev);
+ if (!mdev) {
+ err = -EINVAL;
+ goto errout_unlock;
+ }
err = mpls_netconf_fill_devconf(skb, mdev,
NETLINK_CB(in_skb).portid,
@@ -1289,12 +1314,19 @@ static int mpls_netconf_get_devconf(struct sk_buff *in_skb,
if (err < 0) {
/* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */
WARN_ON(err == -EMSGSIZE);
- kfree_skb(skb);
- goto errout;
+ goto errout_unlock;
}
+
err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid);
+
+ rcu_read_unlock();
errout:
return err;
+
+errout_unlock:
+ rcu_read_unlock();
+ kfree_skb(skb);
+ goto errout;
}
static int mpls_netconf_dump_devconf(struct sk_buff *skb,
@@ -1326,7 +1358,7 @@ static int mpls_netconf_dump_devconf(struct sk_buff *skb,
rcu_read_lock();
for_each_netdev_dump(net, dev, ctx->ifindex) {
- mdev = mpls_dev_get(dev);
+ mdev = mpls_dev_rcu(dev);
if (!mdev)
continue;
err = mpls_netconf_fill_devconf(skb, mdev,
@@ -1438,8 +1470,6 @@ static struct mpls_dev *mpls_add_dev(struct net_device *dev)
int err = -ENOMEM;
int i;
- ASSERT_RTNL();
-
mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
if (!mdev)
return ERR_PTR(err);
@@ -1481,16 +1511,15 @@ static void mpls_dev_destroy_rcu(struct rcu_head *head)
static int mpls_ifdown(struct net_device *dev, int event)
{
- struct mpls_route __rcu **platform_label;
struct net *net = dev_net(dev);
- unsigned index;
+ unsigned int index;
- platform_label = rtnl_dereference(net->mpls.platform_label);
for (index = 0; index < net->mpls.platform_labels; index++) {
- struct mpls_route *rt = rtnl_dereference(platform_label[index]);
+ struct mpls_route *rt;
bool nh_del = false;
u8 alive = 0;
+ rt = mpls_route_input(net, index);
if (!rt)
continue;
@@ -1524,8 +1553,12 @@ static int mpls_ifdown(struct net_device *dev, int event)
change_nexthops(rt) {
unsigned int nh_flags = nh->nh_flags;
- if (nh->nh_dev != dev)
+ if (nh->nh_dev != dev) {
+ if (nh_del)
+ netdev_hold(nh->nh_dev, &nh->nh_dev_tracker,
+ GFP_KERNEL);
goto next;
+ }
switch (event) {
case NETDEV_DOWN:
@@ -1557,15 +1590,14 @@ next:
static void mpls_ifup(struct net_device *dev, unsigned int flags)
{
- struct mpls_route __rcu **platform_label;
struct net *net = dev_net(dev);
- unsigned index;
+ unsigned int index;
u8 alive;
- platform_label = rtnl_dereference(net->mpls.platform_label);
for (index = 0; index < net->mpls.platform_labels; index++) {
- struct mpls_route *rt = rtnl_dereference(platform_label[index]);
+ struct mpls_route *rt;
+ rt = mpls_route_input(net, index);
if (!rt)
continue;
@@ -1592,28 +1624,33 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
void *ptr)
{
struct net_device *dev = netdev_notifier_info_to_dev(ptr);
+ struct net *net = dev_net(dev);
struct mpls_dev *mdev;
unsigned int flags;
int err;
+ mutex_lock(&net->mpls.platform_mutex);
+
if (event == NETDEV_REGISTER) {
mdev = mpls_add_dev(dev);
- if (IS_ERR(mdev))
- return notifier_from_errno(PTR_ERR(mdev));
+ if (IS_ERR(mdev)) {
+ err = PTR_ERR(mdev);
+ goto err;
+ }
- return NOTIFY_OK;
+ goto out;
}
- mdev = mpls_dev_get(dev);
+ mdev = mpls_dev_get(net, dev);
if (!mdev)
- return NOTIFY_OK;
+ goto out;
switch (event) {
case NETDEV_DOWN:
err = mpls_ifdown(dev, event);
if (err)
- return notifier_from_errno(err);
+ goto err;
break;
case NETDEV_UP:
flags = netif_get_flags(dev);
@@ -1629,14 +1666,15 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
} else {
err = mpls_ifdown(dev, event);
if (err)
- return notifier_from_errno(err);
+ goto err;
}
break;
case NETDEV_UNREGISTER:
err = mpls_ifdown(dev, event);
if (err)
- return notifier_from_errno(err);
- mdev = mpls_dev_get(dev);
+ goto err;
+
+ mdev = mpls_dev_get(net, dev);
if (mdev) {
mpls_dev_sysctl_unregister(dev, mdev);
RCU_INIT_POINTER(dev->mpls_ptr, NULL);
@@ -1644,16 +1682,23 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
}
break;
case NETDEV_CHANGENAME:
- mdev = mpls_dev_get(dev);
+ mdev = mpls_dev_get(net, dev);
if (mdev) {
mpls_dev_sysctl_unregister(dev, mdev);
err = mpls_dev_sysctl_register(dev, mdev);
if (err)
- return notifier_from_errno(err);
+ goto err;
}
break;
}
+
+out:
+ mutex_unlock(&net->mpls.platform_mutex);
return NOTIFY_OK;
+
+err:
+ mutex_unlock(&net->mpls.platform_mutex);
+ return notifier_from_errno(err);
}
static struct notifier_block mpls_dev_notifier = {
@@ -1928,6 +1973,7 @@ errout:
static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
struct netlink_ext_ack *extack)
{
+ struct net *net = sock_net(skb->sk);
struct mpls_route_config *cfg;
int err;
@@ -1939,7 +1985,9 @@ static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
if (err < 0)
goto out;
+ mutex_lock(&net->mpls.platform_mutex);
err = mpls_route_del(cfg, extack);
+ mutex_unlock(&net->mpls.platform_mutex);
out:
kfree(cfg);
@@ -1950,6 +1998,7 @@ out:
static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
struct netlink_ext_ack *extack)
{
+ struct net *net = sock_net(skb->sk);
struct mpls_route_config *cfg;
int err;
@@ -1961,7 +2010,9 @@ static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
if (err < 0)
goto out;
+ mutex_lock(&net->mpls.platform_mutex);
err = mpls_route_add(cfg, extack);
+ mutex_unlock(&net->mpls.platform_mutex);
out:
kfree(cfg);
@@ -2124,7 +2175,7 @@ static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh,
if (i == RTA_OIF) {
ifindex = nla_get_u32(tb[i]);
- filter->dev = __dev_get_by_index(net, ifindex);
+ filter->dev = dev_get_by_index_rcu(net, ifindex);
if (!filter->dev)
return -ENODEV;
filter->filter_set = 1;
@@ -2162,20 +2213,19 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
struct net *net = sock_net(skb->sk);
struct mpls_route __rcu **platform_label;
struct fib_dump_filter filter = {
- .rtnl_held = true,
+ .rtnl_held = false,
};
unsigned int flags = NLM_F_MULTI;
size_t platform_labels;
unsigned int index;
+ int err;
- ASSERT_RTNL();
+ rcu_read_lock();
if (cb->strict_check) {
- int err;
-
err = mpls_valid_fib_dump_req(net, nlh, &filter, cb);
if (err < 0)
- return err;
+ goto err;
/* for MPLS, there is only 1 table with fixed type and flags.
* If either are set in the filter then return nothing.
@@ -2183,14 +2233,14 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) ||
(filter.rt_type && filter.rt_type != RTN_UNICAST) ||
filter.flags)
- return skb->len;
+ goto unlock;
}
index = cb->args[0];
if (index < MPLS_LABEL_FIRST_UNRESERVED)
index = MPLS_LABEL_FIRST_UNRESERVED;
- platform_label = rtnl_dereference(net->mpls.platform_label);
+ platform_label = rcu_dereference(net->mpls.platform_label);
platform_labels = net->mpls.platform_labels;
if (filter.filter_set)
@@ -2199,7 +2249,7 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
for (; index < platform_labels; index++) {
struct mpls_route *rt;
- rt = rtnl_dereference(platform_label[index]);
+ rt = rcu_dereference(platform_label[index]);
if (!rt)
continue;
@@ -2214,7 +2264,13 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
}
cb->args[0] = index;
+unlock:
+ rcu_read_unlock();
return skb->len;
+
+err:
+ rcu_read_unlock();
+ return err;
}
static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
@@ -2345,18 +2401,20 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
u32 portid = NETLINK_CB(in_skb).portid;
u32 in_label = LABEL_NOT_SPECIFIED;
struct nlattr *tb[RTA_MAX + 1];
+ struct mpls_route *rt = NULL;
u32 labels[MAX_NEW_LABELS];
struct mpls_shim_hdr *hdr;
unsigned int hdr_size = 0;
const struct mpls_nh *nh;
struct net_device *dev;
- struct mpls_route *rt;
struct rtmsg *rtm, *r;
struct nlmsghdr *nlh;
struct sk_buff *skb;
u8 n_labels;
int err;
+ mutex_lock(&net->mpls.platform_mutex);
+
err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack);
if (err < 0)
goto errout;
@@ -2378,7 +2436,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
}
}
- rt = mpls_route_input_rcu(net, in_label);
+ if (in_label < net->mpls.platform_labels)
+ rt = mpls_route_input(net, in_label);
if (!rt) {
err = -ENETUNREACH;
goto errout;
@@ -2399,7 +2458,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
goto errout_free;
}
- return rtnl_unicast(skb, net, portid);
+ err = rtnl_unicast(skb, net, portid);
+ goto errout;
}
if (tb[RTA_NEWDST]) {
@@ -2491,12 +2551,14 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
err = rtnl_unicast(skb, net, portid);
errout:
+ mutex_unlock(&net->mpls.platform_mutex);
return err;
nla_put_failure:
nlmsg_cancel(skb, nlh);
err = -EMSGSIZE;
errout_free:
+ mutex_unlock(&net->mpls.platform_mutex);
kfree_skb(skb);
return err;
}
@@ -2519,10 +2581,13 @@ static int resize_platform_label_table(struct net *net, size_t limit)
/* In case the predefined labels need to be populated */
if (limit > MPLS_LABEL_IPV4NULL) {
struct net_device *lo = net->loopback_dev;
+
rt0 = mpls_rt_alloc(1, lo->addr_len, 0);
if (IS_ERR(rt0))
goto nort0;
+
rt0->rt_nh->nh_dev = lo;
+ netdev_hold(lo, &rt0->rt_nh->nh_dev_tracker, GFP_KERNEL);
rt0->rt_protocol = RTPROT_KERNEL;
rt0->rt_payload_type = MPT_IPV4;
rt0->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT;
@@ -2533,10 +2598,13 @@ static int resize_platform_label_table(struct net *net, size_t limit)
}
if (limit > MPLS_LABEL_IPV6NULL) {
struct net_device *lo = net->loopback_dev;
+
rt2 = mpls_rt_alloc(1, lo->addr_len, 0);
if (IS_ERR(rt2))
goto nort2;
+
rt2->rt_nh->nh_dev = lo;
+ netdev_hold(lo, &rt2->rt_nh->nh_dev_tracker, GFP_KERNEL);
rt2->rt_protocol = RTPROT_KERNEL;
rt2->rt_payload_type = MPT_IPV6;
rt2->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT;
@@ -2546,9 +2614,10 @@ static int resize_platform_label_table(struct net *net, size_t limit)
lo->addr_len);
}
- rtnl_lock();
+ mutex_lock(&net->mpls.platform_mutex);
+
/* Remember the original table */
- old = rtnl_dereference(net->mpls.platform_label);
+ old = mpls_dereference(net, net->mpls.platform_label);
old_limit = net->mpls.platform_labels;
/* Free any labels beyond the new table */
@@ -2579,7 +2648,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
net->mpls.platform_labels = limit;
rcu_assign_pointer(net->mpls.platform_label, labels);
- rtnl_unlock();
+ mutex_unlock(&net->mpls.platform_mutex);
mpls_rt_free(rt2);
mpls_rt_free(rt0);
@@ -2652,12 +2721,13 @@ static const struct ctl_table mpls_table[] = {
},
};
-static int mpls_net_init(struct net *net)
+static __net_init int mpls_net_init(struct net *net)
{
size_t table_size = ARRAY_SIZE(mpls_table);
struct ctl_table *table;
int i;
+ mutex_init(&net->mpls.platform_mutex);
net->mpls.platform_labels = 0;
net->mpls.platform_label = NULL;
net->mpls.ip_ttl_propagate = 1;
@@ -2683,7 +2753,7 @@ static int mpls_net_init(struct net *net)
return 0;
}
-static void mpls_net_exit(struct net *net)
+static __net_exit void mpls_net_exit(struct net *net)
{
struct mpls_route __rcu **platform_label;
size_t platform_labels;
@@ -2703,16 +2773,20 @@ static void mpls_net_exit(struct net *net)
* As such no additional rcu synchronization is necessary when
* freeing the platform_label table.
*/
- rtnl_lock();
- platform_label = rtnl_dereference(net->mpls.platform_label);
+ mutex_lock(&net->mpls.platform_mutex);
+
+ platform_label = mpls_dereference(net, net->mpls.platform_label);
platform_labels = net->mpls.platform_labels;
+
for (index = 0; index < platform_labels; index++) {
- struct mpls_route *rt = rtnl_dereference(platform_label[index]);
- RCU_INIT_POINTER(platform_label[index], NULL);
+ struct mpls_route *rt;
+
+ rt = mpls_dereference(net, platform_label[index]);
mpls_notify_route(net, index, rt, NULL, NULL);
mpls_rt_free(rt);
}
- rtnl_unlock();
+
+ mutex_unlock(&net->mpls.platform_mutex);
kvfree(platform_label);
}
@@ -2729,12 +2803,15 @@ static struct rtnl_af_ops mpls_af_ops __read_mostly = {
};
static const struct rtnl_msg_handler mpls_rtnl_msg_handlers[] __initdata_or_module = {
- {THIS_MODULE, PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, 0},
- {THIS_MODULE, PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, 0},
- {THIS_MODULE, PF_MPLS, RTM_GETROUTE, mpls_getroute, mpls_dump_routes, 0},
+ {THIS_MODULE, PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL,
+ RTNL_FLAG_DOIT_UNLOCKED},
+ {THIS_MODULE, PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL,
+ RTNL_FLAG_DOIT_UNLOCKED},
+ {THIS_MODULE, PF_MPLS, RTM_GETROUTE, mpls_getroute, mpls_dump_routes,
+ RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED},
{THIS_MODULE, PF_MPLS, RTM_GETNETCONF,
mpls_netconf_get_devconf, mpls_netconf_dump_devconf,
- RTNL_FLAG_DUMP_UNLOCKED},
+ RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED},
};
static int __init mpls_init(void)