diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 7f53e022e27e..07449fd00b40 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -598,8 +598,10 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) continue; spin_unlock_bh(&msk->pm.lock); - for (i = 0; i < nr; i++) - __mptcp_subflow_connect(sk, &local->addr, &addrs[i]); + for (i = 0; i < nr; i++) { + if (refcount_inc_not_zero(&local->refcnt)) + __mptcp_subflow_connect(sk, &local->addr, &addrs[i]); + } spin_lock_bh(&msk->pm.lock); } mptcp_pm_nl_check_work_pending(msk); @@ -639,7 +641,8 @@ static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk, if (!mptcp_pm_addr_families_match(sk, &entry->addr, remote)) continue; - if (msk->pm.subflows < subflows_max) { + if (msk->pm.subflows < subflows_max && + refcount_inc_not_zero(&entry->refcnt)) { msk->pm.subflows++; addrs[i++] = entry->addr; } @@ -1088,6 +1091,7 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc entry->ifindex = 0; entry->flags = MPTCP_PM_ADDR_FLAG_IMPLICIT; entry->lsk = NULL; + refcount_set(&entry->refcnt, 1); ret = mptcp_pm_nl_append_new_local_addr(pernet, entry, true); if (ret < 0) kfree(entry); @@ -1327,6 +1331,7 @@ int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info) } *entry = addr; + refcount_set(&entry->refcnt, 1); if (entry->addr.port) { ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry); if (ret) { @@ -1341,7 +1346,8 @@ int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info) goto out_free; } - mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk)); + if (refcount_inc_not_zero(&entry->refcnt)) + mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk)); return 0; out_free: @@ -1480,6 +1486,7 @@ int mptcp_pm_nl_del_addr_doit(struct sk_buff *skb, struct genl_info *info) struct nlattr *attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; struct pm_nl_pernet *pernet = genl_info_pm_nl(info); struct mptcp_pm_addr_entry addr, *entry; + bool release_entry = false; unsigned int addr_max; int ret; @@ -1511,14 +1518,21 @@ int mptcp_pm_nl_del_addr_doit(struct sk_buff *skb, struct genl_info *info) WRITE_ONCE(pernet->local_addr_max, addr_max - 1); } - pernet->addrs--; - list_del_rcu(&entry->list); - __clear_bit(entry->addr.id, pernet->id_bitmap); + if (refcount_dec_not_one(&entry->refcnt) && + refcount_read(&entry->refcnt) == 1) { + pernet->addrs--; + list_del_rcu(&entry->list); + __clear_bit(entry->addr.id, pernet->id_bitmap); + release_entry = true; + } spin_unlock_bh(&pernet->lock); mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), entry); synchronize_rcu(); - __mptcp_pm_release_addr_entry(entry); + if (release_entry) + __mptcp_pm_release_addr_entry(entry); + else + entry->flags |= MPTCP_PM_ADDR_FLAG_IMPLICIT; return ret; }