diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index 3b34e549fb79..5e36f45e906b 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -262,6 +262,28 @@ struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth) } EXPORT_SYMBOL_GPL(l2tp_tunnel_get_nth); +struct l2tp_tunnel *l2tp_tunnel_get_next(const struct net *net, unsigned long *key) +{ + struct l2tp_net *pn = l2tp_pernet(net); + struct l2tp_tunnel *tunnel = NULL; + + rcu_read_lock_bh(); +again: + tunnel = idr_get_next_ul(&pn->l2tp_tunnel_idr, key); + if (tunnel) { + if (refcount_inc_not_zero(&tunnel->ref_count)) { + rcu_read_unlock_bh(); + return tunnel; + } + (*key)++; + goto again; + } + rcu_read_unlock_bh(); + + return NULL; +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_get_next); + struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, u32 session_id) { const struct l2tp_net *pn = l2tp_pernet(net); @@ -352,6 +374,110 @@ struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth) } EXPORT_SYMBOL_GPL(l2tp_session_get_nth); +static struct l2tp_session *l2tp_v2_session_get_next(const struct net *net, + u16 tid, + unsigned long *key) +{ + struct l2tp_net *pn = l2tp_pernet(net); + struct l2tp_session *session = NULL; + + /* Start searching within the range of the tid */ + if (*key == 0) + *key = l2tp_v2_session_key(tid, 0); + + rcu_read_lock_bh(); +again: + session = idr_get_next_ul(&pn->l2tp_v2_session_idr, key); + if (session) { + struct l2tp_tunnel *tunnel = READ_ONCE(session->tunnel); + + /* ignore sessions with id 0 as they are internal for pppol2tp */ + if (session->session_id == 0) { + (*key)++; + goto again; + } + + if (tunnel && tunnel->tunnel_id == tid && + refcount_inc_not_zero(&session->ref_count)) { + rcu_read_unlock_bh(); + return session; + } + + (*key)++; + if (tunnel->tunnel_id == tid) + goto again; + } + rcu_read_unlock_bh(); + + return NULL; +} + +static struct l2tp_session *l2tp_v3_session_get_next(const struct net *net, + u32 tid, struct sock *sk, + unsigned long *key) +{ + struct l2tp_net *pn = l2tp_pernet(net); + struct l2tp_session *session = NULL; + + rcu_read_lock_bh(); +again: + session = idr_get_next_ul(&pn->l2tp_v3_session_idr, key); + if (session && !hash_hashed(&session->hlist)) { + struct l2tp_tunnel *tunnel = READ_ONCE(session->tunnel); + + if (tunnel && tunnel->tunnel_id == tid && + refcount_inc_not_zero(&session->ref_count)) { + rcu_read_unlock_bh(); + return session; + } + + (*key)++; + goto again; + } + + /* If we get here and session is non-NULL, the IDR entry may be one + * where the session_id collides with one in another tunnel. Check + * session_htable for a match. There can only be one session of a given + * ID per tunnel so we can return as soon as a match is found. + */ + if (session && hash_hashed(&session->hlist)) { + unsigned long hkey = l2tp_v3_session_hashkey(sk, session->session_id); + u32 sid = session->session_id; + + hash_for_each_possible_rcu(pn->l2tp_v3_session_htable, session, + hlist, hkey) { + struct l2tp_tunnel *tunnel = READ_ONCE(session->tunnel); + + if (session->session_id == sid && + tunnel && tunnel->tunnel_id == tid && + refcount_inc_not_zero(&session->ref_count)) { + rcu_read_unlock_bh(); + return session; + } + } + + /* If no match found, the colliding session ID isn't in our + * tunnel so try the next session ID. + */ + (*key)++; + goto again; + } + + rcu_read_unlock_bh(); + + return NULL; +} + +struct l2tp_session *l2tp_session_get_next(const struct net *net, struct sock *sk, int pver, + u32 tunnel_id, unsigned long *key) +{ + if (pver == L2TP_HDR_VER_2) + return l2tp_v2_session_get_next(net, tunnel_id, key); + else + return l2tp_v3_session_get_next(net, tunnel_id, sk, key); +} +EXPORT_SYMBOL_GPL(l2tp_session_get_next); + /* Lookup a session by interface name. * This is very inefficient but is only used by management interfaces. */ diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h index c907687705b9..cc464982a7d9 100644 --- a/net/l2tp/l2tp_core.h +++ b/net/l2tp/l2tp_core.h @@ -220,12 +220,15 @@ void l2tp_session_dec_refcount(struct l2tp_session *session); */ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id); struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth); +struct l2tp_tunnel *l2tp_tunnel_get_next(const struct net *net, unsigned long *key); struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, u32 session_id); struct l2tp_session *l2tp_v2_session_get(const struct net *net, u16 tunnel_id, u16 session_id); struct l2tp_session *l2tp_session_get(const struct net *net, struct sock *sk, int pver, u32 tunnel_id, u32 session_id); struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth); +struct l2tp_session *l2tp_session_get_next(const struct net *net, struct sock *sk, int pver, + u32 tunnel_id, unsigned long *key); struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, const char *ifname);