diff options
Diffstat (limited to 'net/mpls')
| -rw-r--r-- | net/mpls/af_mpls.c | 321 | ||||
| -rw-r--r-- | net/mpls/internal.h | 19 | ||||
| -rw-r--r-- | net/mpls/mpls_iptunnel.c | 6 |
3 files changed, 218 insertions, 128 deletions
diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c index 25c88cba5c48..580aac112dd2 100644 --- a/net/mpls/af_mpls.c +++ b/net/mpls/af_mpls.c @@ -75,16 +75,23 @@ static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt, struct nlmsghdr *nlh, struct net *net, u32 portid, unsigned int nlm_flags); -static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index) +static struct mpls_route *mpls_route_input(struct net *net, unsigned int index) { - struct mpls_route *rt = NULL; + struct mpls_route __rcu **platform_label; - if (index < net->mpls.platform_labels) { - struct mpls_route __rcu **platform_label = - rcu_dereference_rtnl(net->mpls.platform_label); - rt = rcu_dereference_rtnl(platform_label[index]); - } - return rt; + platform_label = mpls_dereference(net, net->mpls.platform_label); + return mpls_dereference(net, platform_label[index]); +} + +static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index) +{ + struct mpls_route __rcu **platform_label; + + if (index >= net->mpls.platform_labels) + return NULL; + + platform_label = rcu_dereference(net->mpls.platform_label); + return rcu_dereference(platform_label[index]); } bool mpls_output_possible(const struct net_device *dev) @@ -129,25 +136,26 @@ bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu) } EXPORT_SYMBOL_GPL(mpls_pkt_too_big); -void mpls_stats_inc_outucastpkts(struct net_device *dev, +void mpls_stats_inc_outucastpkts(struct net *net, + struct net_device *dev, const struct sk_buff *skb) { struct mpls_dev *mdev; if (skb->protocol == htons(ETH_P_MPLS_UC)) { - mdev = mpls_dev_get(dev); + mdev = mpls_dev_rcu(dev); if (mdev) MPLS_INC_STATS_LEN(mdev, skb->len, tx_packets, tx_bytes); } else if (skb->protocol == htons(ETH_P_IP)) { - IP_UPD_PO_STATS(dev_net(dev), IPSTATS_MIB_OUT, skb->len); + IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len); #if IS_ENABLED(CONFIG_IPV6) } else if (skb->protocol == htons(ETH_P_IPV6)) { - struct inet6_dev *in6dev = __in6_dev_get(dev); + struct inet6_dev *in6dev = in6_dev_rcu(dev); if (in6dev) - IP6_UPD_PO_STATS(dev_net(dev), in6dev, + IP6_UPD_PO_STATS(net, in6dev, IPSTATS_MIB_OUT, skb->len); #endif } @@ -342,7 +350,7 @@ static bool mpls_egress(struct net *net, struct mpls_route *rt, static int mpls_forward(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev) { - struct net *net = dev_net(dev); + struct net *net = dev_net_rcu(dev); struct mpls_shim_hdr *hdr; const struct mpls_nh *nh; struct mpls_route *rt; @@ -357,7 +365,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev, /* Careful this entire function runs inside of an rcu critical section */ - mdev = mpls_dev_get(dev); + mdev = mpls_dev_rcu(dev); if (!mdev) goto drop; @@ -434,7 +442,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev, dec.ttl -= 1; if (unlikely(!new_header_size && dec.bos)) { /* Penultimate hop popping */ - if (!mpls_egress(dev_net(out_dev), rt, skb, dec)) + if (!mpls_egress(net, rt, skb, dec)) goto err; } else { bool bos; @@ -451,7 +459,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev, } } - mpls_stats_inc_outucastpkts(out_dev, skb); + mpls_stats_inc_outucastpkts(net, out_dev, skb); /* If via wasn't specified then send out using device address */ if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC) @@ -466,7 +474,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev, return 0; tx_err: - out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL; + out_mdev = out_dev ? mpls_dev_rcu(out_dev) : NULL; if (out_mdev) MPLS_INC_STATS(out_mdev, tx_errors); goto drop; @@ -530,10 +538,23 @@ static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels) return rt; } +static void mpls_rt_free_rcu(struct rcu_head *head) +{ + struct mpls_route *rt; + + rt = container_of(head, struct mpls_route, rt_rcu); + + change_nexthops(rt) { + netdev_put(nh->nh_dev, &nh->nh_dev_tracker); + } endfor_nexthops(rt); + + kfree(rt); +} + static void mpls_rt_free(struct mpls_route *rt) { if (rt) - kfree_rcu(rt, rt_rcu); + call_rcu(&rt->rt_rcu, mpls_rt_free_rcu); } static void mpls_notify_route(struct net *net, unsigned index, @@ -557,10 +578,8 @@ static void mpls_route_update(struct net *net, unsigned index, struct mpls_route __rcu **platform_label; struct mpls_route *rt; - ASSERT_RTNL(); - - platform_label = rtnl_dereference(net->mpls.platform_label); - rt = rtnl_dereference(platform_label[index]); + platform_label = mpls_dereference(net, net->mpls.platform_label); + rt = mpls_dereference(net, platform_label[index]); rcu_assign_pointer(platform_label[index], new); mpls_notify_route(net, index, rt, new, info); @@ -569,24 +588,23 @@ static void mpls_route_update(struct net *net, unsigned index, mpls_rt_free(rt); } -static unsigned find_free_label(struct net *net) +static unsigned int find_free_label(struct net *net) { - struct mpls_route __rcu **platform_label; - size_t platform_labels; - unsigned index; + unsigned int index; - platform_label = rtnl_dereference(net->mpls.platform_label); - platform_labels = net->mpls.platform_labels; - for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels; + for (index = MPLS_LABEL_FIRST_UNRESERVED; + index < net->mpls.platform_labels; index++) { - if (!rtnl_dereference(platform_label[index])) + if (!mpls_route_input(net, index)) return index; } + return LABEL_NOT_SPECIFIED; } #if IS_ENABLED(CONFIG_INET) static struct net_device *inet_fib_lookup_dev(struct net *net, + struct mpls_nh *nh, const void *addr) { struct net_device *dev; @@ -599,14 +617,14 @@ static struct net_device *inet_fib_lookup_dev(struct net *net, return ERR_CAST(rt); dev = rt->dst.dev; - dev_hold(dev); - + netdev_hold(dev, &nh->nh_dev_tracker, GFP_KERNEL); ip_rt_put(rt); return dev; } #else static struct net_device *inet_fib_lookup_dev(struct net *net, + struct mpls_nh *nh, const void *addr) { return ERR_PTR(-EAFNOSUPPORT); @@ -615,6 +633,7 @@ static struct net_device *inet_fib_lookup_dev(struct net *net, #if IS_ENABLED(CONFIG_IPV6) static struct net_device *inet6_fib_lookup_dev(struct net *net, + struct mpls_nh *nh, const void *addr) { struct net_device *dev; @@ -631,13 +650,14 @@ static struct net_device *inet6_fib_lookup_dev(struct net *net, return ERR_CAST(dst); dev = dst->dev; - dev_hold(dev); + netdev_hold(dev, &nh->nh_dev_tracker, GFP_KERNEL); dst_release(dst); return dev; } #else static struct net_device *inet6_fib_lookup_dev(struct net *net, + struct mpls_nh *nh, const void *addr) { return ERR_PTR(-EAFNOSUPPORT); @@ -653,16 +673,17 @@ static struct net_device *find_outdev(struct net *net, if (!oif) { switch (nh->nh_via_table) { case NEIGH_ARP_TABLE: - dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh)); + dev = inet_fib_lookup_dev(net, nh, mpls_nh_via(rt, nh)); break; case NEIGH_ND_TABLE: - dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh)); + dev = inet6_fib_lookup_dev(net, nh, mpls_nh_via(rt, nh)); break; case NEIGH_LINK_TABLE: break; } } else { - dev = dev_get_by_index(net, oif); + dev = netdev_get_by_index(net, oif, + &nh->nh_dev_tracker, GFP_KERNEL); } if (!dev) @@ -671,8 +692,7 @@ static struct net_device *find_outdev(struct net *net, if (IS_ERR(dev)) return dev; - /* The caller is holding rtnl anyways, so release the dev reference */ - dev_put(dev); + nh->nh_dev = dev; return dev; } @@ -686,20 +706,17 @@ static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt, dev = find_outdev(net, rt, nh, oif); if (IS_ERR(dev)) { err = PTR_ERR(dev); - dev = NULL; goto errout; } /* Ensure this is a supported device */ err = -EINVAL; - if (!mpls_dev_get(dev)) - goto errout; + if (!mpls_dev_get(net, dev)) + goto errout_put; if ((nh->nh_via_table == NEIGH_LINK_TABLE) && (dev->addr_len != nh->nh_via_alen)) - goto errout; - - nh->nh_dev = dev; + goto errout_put; if (!(dev->flags & IFF_UP)) { nh->nh_flags |= RTNH_F_DEAD; @@ -713,6 +730,9 @@ static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt, return 0; +errout_put: + netdev_put(nh->nh_dev, &nh->nh_dev_tracker); + nh->nh_dev = NULL; errout: return err; } @@ -890,7 +910,8 @@ static int mpls_nh_build_multi(struct mpls_route_config *cfg, struct nlattr *nla_via, *nla_newdst; int remaining = cfg->rc_mp_len; int err = 0; - u8 nhs = 0; + + rt->rt_nhn = 0; change_nexthops(rt) { int attrlen; @@ -926,11 +947,9 @@ static int mpls_nh_build_multi(struct mpls_route_config *cfg, rt->rt_nhn_alive--; rtnh = rtnh_next(rtnh, &remaining); - nhs++; + rt->rt_nhn++; } endfor_nexthops(rt); - rt->rt_nhn = nhs; - return 0; errout: @@ -940,30 +959,28 @@ errout: static bool mpls_label_ok(struct net *net, unsigned int *index, struct netlink_ext_ack *extack) { - bool is_ok = true; - /* Reserved labels may not be set */ if (*index < MPLS_LABEL_FIRST_UNRESERVED) { NL_SET_ERR_MSG(extack, "Invalid label - must be MPLS_LABEL_FIRST_UNRESERVED or higher"); - is_ok = false; + return false; } /* The full 20 bit range may not be supported. */ - if (is_ok && *index >= net->mpls.platform_labels) { + if (*index >= net->mpls.platform_labels) { NL_SET_ERR_MSG(extack, "Label >= configured maximum in platform_labels"); - is_ok = false; + return false; } *index = array_index_nospec(*index, net->mpls.platform_labels); - return is_ok; + + return true; } static int mpls_route_add(struct mpls_route_config *cfg, struct netlink_ext_ack *extack) { - struct mpls_route __rcu **platform_label; struct net *net = cfg->rc_nlinfo.nl_net; struct mpls_route *rt, *old; int err = -EINVAL; @@ -991,8 +1008,7 @@ static int mpls_route_add(struct mpls_route_config *cfg, } err = -EEXIST; - platform_label = rtnl_dereference(net->mpls.platform_label); - old = rtnl_dereference(platform_label[index]); + old = mpls_route_input(net, index); if ((cfg->rc_nlflags & NLM_F_EXCL) && old) goto errout; @@ -1103,7 +1119,7 @@ static int mpls_fill_stats_af(struct sk_buff *skb, struct mpls_dev *mdev; struct nlattr *nla; - mdev = mpls_dev_get(dev); + mdev = mpls_dev_rcu(dev); if (!mdev) return -ENODATA; @@ -1123,7 +1139,7 @@ static size_t mpls_get_stats_af_size(const struct net_device *dev) { struct mpls_dev *mdev; - mdev = mpls_dev_get(dev); + mdev = mpls_dev_rcu(dev); if (!mdev) return 0; @@ -1264,23 +1280,32 @@ static int mpls_netconf_get_devconf(struct sk_buff *in_skb, if (err < 0) goto errout; - err = -EINVAL; - if (!tb[NETCONFA_IFINDEX]) + if (!tb[NETCONFA_IFINDEX]) { + err = -EINVAL; goto errout; + } ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]); - dev = __dev_get_by_index(net, ifindex); - if (!dev) - goto errout; - - mdev = mpls_dev_get(dev); - if (!mdev) - goto errout; - err = -ENOBUFS; skb = nlmsg_new(mpls_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL); - if (!skb) + if (!skb) { + err = -ENOBUFS; goto errout; + } + + rcu_read_lock(); + + dev = dev_get_by_index_rcu(net, ifindex); + if (!dev) { + err = -EINVAL; + goto errout_unlock; + } + + mdev = mpls_dev_rcu(dev); + if (!mdev) { + err = -EINVAL; + goto errout_unlock; + } err = mpls_netconf_fill_devconf(skb, mdev, NETLINK_CB(in_skb).portid, @@ -1289,12 +1314,19 @@ static int mpls_netconf_get_devconf(struct sk_buff *in_skb, if (err < 0) { /* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */ WARN_ON(err == -EMSGSIZE); - kfree_skb(skb); - goto errout; + goto errout_unlock; } + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); + + rcu_read_unlock(); errout: return err; + +errout_unlock: + rcu_read_unlock(); + kfree_skb(skb); + goto errout; } static int mpls_netconf_dump_devconf(struct sk_buff *skb, @@ -1326,7 +1358,7 @@ static int mpls_netconf_dump_devconf(struct sk_buff *skb, rcu_read_lock(); for_each_netdev_dump(net, dev, ctx->ifindex) { - mdev = mpls_dev_get(dev); + mdev = mpls_dev_rcu(dev); if (!mdev) continue; err = mpls_netconf_fill_devconf(skb, mdev, @@ -1438,8 +1470,6 @@ static struct mpls_dev *mpls_add_dev(struct net_device *dev) int err = -ENOMEM; int i; - ASSERT_RTNL(); - mdev = kzalloc(sizeof(*mdev), GFP_KERNEL); if (!mdev) return ERR_PTR(err); @@ -1481,16 +1511,15 @@ static void mpls_dev_destroy_rcu(struct rcu_head *head) static int mpls_ifdown(struct net_device *dev, int event) { - struct mpls_route __rcu **platform_label; struct net *net = dev_net(dev); - unsigned index; + unsigned int index; - platform_label = rtnl_dereference(net->mpls.platform_label); for (index = 0; index < net->mpls.platform_labels; index++) { - struct mpls_route *rt = rtnl_dereference(platform_label[index]); + struct mpls_route *rt; bool nh_del = false; u8 alive = 0; + rt = mpls_route_input(net, index); if (!rt) continue; @@ -1524,8 +1553,12 @@ static int mpls_ifdown(struct net_device *dev, int event) change_nexthops(rt) { unsigned int nh_flags = nh->nh_flags; - if (nh->nh_dev != dev) + if (nh->nh_dev != dev) { + if (nh_del) + netdev_hold(nh->nh_dev, &nh->nh_dev_tracker, + GFP_KERNEL); goto next; + } switch (event) { case NETDEV_DOWN: @@ -1557,15 +1590,14 @@ next: static void mpls_ifup(struct net_device *dev, unsigned int flags) { - struct mpls_route __rcu **platform_label; struct net *net = dev_net(dev); - unsigned index; + unsigned int index; u8 alive; - platform_label = rtnl_dereference(net->mpls.platform_label); for (index = 0; index < net->mpls.platform_labels; index++) { - struct mpls_route *rt = rtnl_dereference(platform_label[index]); + struct mpls_route *rt; + rt = mpls_route_input(net, index); if (!rt) continue; @@ -1592,28 +1624,33 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event, void *ptr) { struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); struct mpls_dev *mdev; unsigned int flags; int err; + mutex_lock(&net->mpls.platform_mutex); + if (event == NETDEV_REGISTER) { mdev = mpls_add_dev(dev); - if (IS_ERR(mdev)) - return notifier_from_errno(PTR_ERR(mdev)); + if (IS_ERR(mdev)) { + err = PTR_ERR(mdev); + goto err; + } - return NOTIFY_OK; + goto out; } - mdev = mpls_dev_get(dev); + mdev = mpls_dev_get(net, dev); if (!mdev) - return NOTIFY_OK; + goto out; switch (event) { case NETDEV_DOWN: err = mpls_ifdown(dev, event); if (err) - return notifier_from_errno(err); + goto err; break; case NETDEV_UP: flags = netif_get_flags(dev); @@ -1629,14 +1666,15 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event, } else { err = mpls_ifdown(dev, event); if (err) - return notifier_from_errno(err); + goto err; } break; case NETDEV_UNREGISTER: err = mpls_ifdown(dev, event); if (err) - return notifier_from_errno(err); - mdev = mpls_dev_get(dev); + goto err; + + mdev = mpls_dev_get(net, dev); if (mdev) { mpls_dev_sysctl_unregister(dev, mdev); RCU_INIT_POINTER(dev->mpls_ptr, NULL); @@ -1644,16 +1682,23 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event, } break; case NETDEV_CHANGENAME: - mdev = mpls_dev_get(dev); + mdev = mpls_dev_get(net, dev); if (mdev) { mpls_dev_sysctl_unregister(dev, mdev); err = mpls_dev_sysctl_register(dev, mdev); if (err) - return notifier_from_errno(err); + goto err; } break; } + +out: + mutex_unlock(&net->mpls.platform_mutex); return NOTIFY_OK; + +err: + mutex_unlock(&net->mpls.platform_mutex); + return notifier_from_errno(err); } static struct notifier_block mpls_dev_notifier = { @@ -1928,6 +1973,7 @@ errout: static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { + struct net *net = sock_net(skb->sk); struct mpls_route_config *cfg; int err; @@ -1939,7 +1985,9 @@ static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, if (err < 0) goto out; + mutex_lock(&net->mpls.platform_mutex); err = mpls_route_del(cfg, extack); + mutex_unlock(&net->mpls.platform_mutex); out: kfree(cfg); @@ -1950,6 +1998,7 @@ out: static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, struct netlink_ext_ack *extack) { + struct net *net = sock_net(skb->sk); struct mpls_route_config *cfg; int err; @@ -1961,7 +2010,9 @@ static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, if (err < 0) goto out; + mutex_lock(&net->mpls.platform_mutex); err = mpls_route_add(cfg, extack); + mutex_unlock(&net->mpls.platform_mutex); out: kfree(cfg); @@ -2124,7 +2175,7 @@ static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, if (i == RTA_OIF) { ifindex = nla_get_u32(tb[i]); - filter->dev = __dev_get_by_index(net, ifindex); + filter->dev = dev_get_by_index_rcu(net, ifindex); if (!filter->dev) return -ENODEV; filter->filter_set = 1; @@ -2162,20 +2213,19 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb) struct net *net = sock_net(skb->sk); struct mpls_route __rcu **platform_label; struct fib_dump_filter filter = { - .rtnl_held = true, + .rtnl_held = false, }; unsigned int flags = NLM_F_MULTI; size_t platform_labels; unsigned int index; + int err; - ASSERT_RTNL(); + rcu_read_lock(); if (cb->strict_check) { - int err; - err = mpls_valid_fib_dump_req(net, nlh, &filter, cb); if (err < 0) - return err; + goto err; /* for MPLS, there is only 1 table with fixed type and flags. * If either are set in the filter then return nothing. @@ -2183,14 +2233,14 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb) if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) || (filter.rt_type && filter.rt_type != RTN_UNICAST) || filter.flags) - return skb->len; + goto unlock; } index = cb->args[0]; if (index < MPLS_LABEL_FIRST_UNRESERVED) index = MPLS_LABEL_FIRST_UNRESERVED; - platform_label = rtnl_dereference(net->mpls.platform_label); + platform_label = rcu_dereference(net->mpls.platform_label); platform_labels = net->mpls.platform_labels; if (filter.filter_set) @@ -2199,7 +2249,7 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb) for (; index < platform_labels; index++) { struct mpls_route *rt; - rt = rtnl_dereference(platform_label[index]); + rt = rcu_dereference(platform_label[index]); if (!rt) continue; @@ -2214,7 +2264,13 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb) } cb->args[0] = index; +unlock: + rcu_read_unlock(); return skb->len; + +err: + rcu_read_unlock(); + return err; } static inline size_t lfib_nlmsg_size(struct mpls_route *rt) @@ -2345,18 +2401,20 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh, u32 portid = NETLINK_CB(in_skb).portid; u32 in_label = LABEL_NOT_SPECIFIED; struct nlattr *tb[RTA_MAX + 1]; + struct mpls_route *rt = NULL; u32 labels[MAX_NEW_LABELS]; struct mpls_shim_hdr *hdr; unsigned int hdr_size = 0; const struct mpls_nh *nh; struct net_device *dev; - struct mpls_route *rt; struct rtmsg *rtm, *r; struct nlmsghdr *nlh; struct sk_buff *skb; u8 n_labels; int err; + mutex_lock(&net->mpls.platform_mutex); + err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack); if (err < 0) goto errout; @@ -2378,7 +2436,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh, } } - rt = mpls_route_input_rcu(net, in_label); + if (in_label < net->mpls.platform_labels) + rt = mpls_route_input(net, in_label); if (!rt) { err = -ENETUNREACH; goto errout; @@ -2399,7 +2458,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh, goto errout_free; } - return rtnl_unicast(skb, net, portid); + err = rtnl_unicast(skb, net, portid); + goto errout; } if (tb[RTA_NEWDST]) { @@ -2491,12 +2551,14 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh, err = rtnl_unicast(skb, net, portid); errout: + mutex_unlock(&net->mpls.platform_mutex); return err; nla_put_failure: nlmsg_cancel(skb, nlh); err = -EMSGSIZE; errout_free: + mutex_unlock(&net->mpls.platform_mutex); kfree_skb(skb); return err; } @@ -2519,10 +2581,13 @@ static int resize_platform_label_table(struct net *net, size_t limit) /* In case the predefined labels need to be populated */ if (limit > MPLS_LABEL_IPV4NULL) { struct net_device *lo = net->loopback_dev; + rt0 = mpls_rt_alloc(1, lo->addr_len, 0); if (IS_ERR(rt0)) goto nort0; + rt0->rt_nh->nh_dev = lo; + netdev_hold(lo, &rt0->rt_nh->nh_dev_tracker, GFP_KERNEL); rt0->rt_protocol = RTPROT_KERNEL; rt0->rt_payload_type = MPT_IPV4; rt0->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT; @@ -2533,10 +2598,13 @@ static int resize_platform_label_table(struct net *net, size_t limit) } if (limit > MPLS_LABEL_IPV6NULL) { struct net_device *lo = net->loopback_dev; + rt2 = mpls_rt_alloc(1, lo->addr_len, 0); if (IS_ERR(rt2)) goto nort2; + rt2->rt_nh->nh_dev = lo; + netdev_hold(lo, &rt2->rt_nh->nh_dev_tracker, GFP_KERNEL); rt2->rt_protocol = RTPROT_KERNEL; rt2->rt_payload_type = MPT_IPV6; rt2->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT; @@ -2546,9 +2614,10 @@ static int resize_platform_label_table(struct net *net, size_t limit) lo->addr_len); } - rtnl_lock(); + mutex_lock(&net->mpls.platform_mutex); + /* Remember the original table */ - old = rtnl_dereference(net->mpls.platform_label); + old = mpls_dereference(net, net->mpls.platform_label); old_limit = net->mpls.platform_labels; /* Free any labels beyond the new table */ @@ -2579,7 +2648,7 @@ static int resize_platform_label_table(struct net *net, size_t limit) net->mpls.platform_labels = limit; rcu_assign_pointer(net->mpls.platform_label, labels); - rtnl_unlock(); + mutex_unlock(&net->mpls.platform_mutex); mpls_rt_free(rt2); mpls_rt_free(rt0); @@ -2652,12 +2721,13 @@ static const struct ctl_table mpls_table[] = { }, }; -static int mpls_net_init(struct net *net) +static __net_init int mpls_net_init(struct net *net) { size_t table_size = ARRAY_SIZE(mpls_table); struct ctl_table *table; int i; + mutex_init(&net->mpls.platform_mutex); net->mpls.platform_labels = 0; net->mpls.platform_label = NULL; net->mpls.ip_ttl_propagate = 1; @@ -2683,7 +2753,7 @@ static int mpls_net_init(struct net *net) return 0; } -static void mpls_net_exit(struct net *net) +static __net_exit void mpls_net_exit(struct net *net) { struct mpls_route __rcu **platform_label; size_t platform_labels; @@ -2703,16 +2773,20 @@ static void mpls_net_exit(struct net *net) * As such no additional rcu synchronization is necessary when * freeing the platform_label table. */ - rtnl_lock(); - platform_label = rtnl_dereference(net->mpls.platform_label); + mutex_lock(&net->mpls.platform_mutex); + + platform_label = mpls_dereference(net, net->mpls.platform_label); platform_labels = net->mpls.platform_labels; + for (index = 0; index < platform_labels; index++) { - struct mpls_route *rt = rtnl_dereference(platform_label[index]); - RCU_INIT_POINTER(platform_label[index], NULL); + struct mpls_route *rt; + + rt = mpls_dereference(net, platform_label[index]); mpls_notify_route(net, index, rt, NULL, NULL); mpls_rt_free(rt); } - rtnl_unlock(); + + mutex_unlock(&net->mpls.platform_mutex); kvfree(platform_label); } @@ -2729,12 +2803,15 @@ static struct rtnl_af_ops mpls_af_ops __read_mostly = { }; static const struct rtnl_msg_handler mpls_rtnl_msg_handlers[] __initdata_or_module = { - {THIS_MODULE, PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, 0}, - {THIS_MODULE, PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, 0}, - {THIS_MODULE, PF_MPLS, RTM_GETROUTE, mpls_getroute, mpls_dump_routes, 0}, + {THIS_MODULE, PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, + RTNL_FLAG_DOIT_UNLOCKED}, + {THIS_MODULE, PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, + RTNL_FLAG_DOIT_UNLOCKED}, + {THIS_MODULE, PF_MPLS, RTM_GETROUTE, mpls_getroute, mpls_dump_routes, + RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED}, {THIS_MODULE, PF_MPLS, RTM_GETNETCONF, mpls_netconf_get_devconf, mpls_netconf_dump_devconf, - RTNL_FLAG_DUMP_UNLOCKED}, + RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED}, }; static int __init mpls_init(void) diff --git a/net/mpls/internal.h b/net/mpls/internal.h index 83c629529b57..80cb5bbcd946 100644 --- a/net/mpls/internal.h +++ b/net/mpls/internal.h @@ -88,6 +88,7 @@ enum mpls_payload_type { struct mpls_nh { /* next hop label forwarding entry */ struct net_device *nh_dev; + netdevice_tracker nh_dev_tracker; /* nh_flags is accessed under RCU in the packet path; it is * modified handling netdev events with rtnl lock held @@ -184,9 +185,20 @@ static inline struct mpls_entry_decoded mpls_entry_decode(struct mpls_shim_hdr * return result; } -static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev) +#define mpls_dereference(net, p) \ + rcu_dereference_protected( \ + (p), \ + lockdep_is_held(&(net)->mpls.platform_mutex)) + +static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev) +{ + return rcu_dereference(dev->mpls_ptr); +} + +static inline struct mpls_dev *mpls_dev_get(const struct net *net, + const struct net_device *dev) { - return rcu_dereference_rtnl(dev->mpls_ptr); + return mpls_dereference(net, dev->mpls_ptr); } int nla_put_labels(struct sk_buff *skb, int attrtype, u8 labels, @@ -196,7 +208,8 @@ int nla_get_labels(const struct nlattr *nla, u8 max_labels, u8 *labels, bool mpls_output_possible(const struct net_device *dev); unsigned int mpls_dev_mtu(const struct net_device *dev); bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu); -void mpls_stats_inc_outucastpkts(struct net_device *dev, +void mpls_stats_inc_outucastpkts(struct net *net, + struct net_device *dev, const struct sk_buff *skb); #endif /* MPLS_INTERNAL_H */ diff --git a/net/mpls/mpls_iptunnel.c b/net/mpls/mpls_iptunnel.c index 6e73da94af7f..1a1a0eb5b787 100644 --- a/net/mpls/mpls_iptunnel.c +++ b/net/mpls/mpls_iptunnel.c @@ -53,7 +53,7 @@ static int mpls_xmit(struct sk_buff *skb) /* Find the output device */ out_dev = dst->dev; - net = dev_net(out_dev); + net = dev_net_rcu(out_dev); if (!mpls_output_possible(out_dev) || !dst->lwtstate || skb_warn_if_lro(skb)) @@ -128,7 +128,7 @@ static int mpls_xmit(struct sk_buff *skb) bos = false; } - mpls_stats_inc_outucastpkts(out_dev, skb); + mpls_stats_inc_outucastpkts(net, out_dev, skb); if (rt) { if (rt->rt_gw_family == AF_INET6) @@ -153,7 +153,7 @@ static int mpls_xmit(struct sk_buff *skb) return LWTUNNEL_XMIT_DONE; drop: - out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL; + out_mdev = out_dev ? mpls_dev_rcu(out_dev) : NULL; if (out_mdev) MPLS_INC_STATS(out_mdev, tx_errors); kfree_skb(skb); |
