| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- // SPDX-License-Identifier: GPL-2.0-only
- #include <linux/file.h>
- #include <linux/net.h>
- #include <linux/rcupdate.h>
- #include <linux/tcp.h>
- #include <net/ip.h>
- #include <net/psp.h>
- #include "psp.h"
- struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
- {
- struct psp_dev *psd = NULL;
- struct dst_entry *dst;
- rcu_read_lock();
- dst = __sk_dst_get(sk);
- if (dst) {
- psd = rcu_dereference(dst_dev_rcu(dst)->psp_dev);
- if (psd && !psp_dev_tryget(psd))
- psd = NULL;
- }
- rcu_read_unlock();
- return psd;
- }
- static struct sk_buff *
- psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
- {
- struct psp_assoc *pas;
- bool good;
- rcu_read_lock();
- pas = psp_skb_get_assoc_rcu(skb);
- good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
- rcu_read_unlock();
- if (!good) {
- sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PSP_OUTPUT);
- return NULL;
- }
- return skb;
- }
- struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
- {
- struct psp_assoc *pas;
- lockdep_assert_held(&psd->lock);
- pas = kzalloc_flex(*pas, drv_data, psd->caps->assoc_drv_spc,
- GFP_KERNEL_ACCOUNT);
- if (!pas)
- return NULL;
- pas->psd = psd;
- pas->dev_id = psd->id;
- pas->generation = psd->generation;
- psp_dev_get(psd);
- refcount_set(&pas->refcnt, 1);
- list_add_tail(&pas->assocs_list, &psd->active_assocs);
- return pas;
- }
- static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
- {
- struct psp_dev *psd = pas->psd;
- size_t sz;
- lockdep_assert_held(&psd->lock);
- sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
- return kmemdup(pas, sz, GFP_KERNEL);
- }
- static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
- struct netlink_ext_ack *extack)
- {
- return psd->ops->tx_key_add(psd, pas, extack);
- }
- void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
- {
- if (pas->tx.spi)
- psd->ops->tx_key_del(psd, pas);
- list_del(&pas->assocs_list);
- }
- static void psp_assoc_free(struct work_struct *work)
- {
- struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
- struct psp_dev *psd = pas->psd;
- mutex_lock(&psd->lock);
- if (psd->ops)
- psp_dev_tx_key_del(psd, pas);
- mutex_unlock(&psd->lock);
- psp_dev_put(psd);
- kfree(pas);
- }
- static void psp_assoc_free_queue(struct rcu_head *head)
- {
- struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);
- INIT_WORK(&pas->work, psp_assoc_free);
- schedule_work(&pas->work);
- }
- /**
- * psp_assoc_put() - release a reference on a PSP association
- * @pas: association to release
- */
- void psp_assoc_put(struct psp_assoc *pas)
- {
- if (pas && refcount_dec_and_test(&pas->refcnt))
- call_rcu(&pas->rcu, psp_assoc_free_queue);
- }
- void psp_sk_assoc_free(struct sock *sk)
- {
- struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);
- rcu_assign_pointer(sk->psp_assoc, NULL);
- psp_assoc_put(pas);
- }
- int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
- struct psp_key_parsed *key,
- struct netlink_ext_ack *extack)
- {
- int err;
- memcpy(&pas->rx, key, sizeof(*key));
- lock_sock(sk);
- if (psp_sk_assoc(sk)) {
- NL_SET_ERR_MSG(extack, "Socket already has PSP state");
- err = -EBUSY;
- goto exit_unlock;
- }
- refcount_inc(&pas->refcnt);
- rcu_assign_pointer(sk->psp_assoc, pas);
- err = 0;
- exit_unlock:
- release_sock(sk);
- return err;
- }
- static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
- {
- struct psp_skb_ext *pse;
- struct sk_buff *skb;
- skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
- pse = skb_ext_find(skb, SKB_EXT_PSP);
- if (!psp_pse_matches_pas(pse, pas))
- return -EBUSY;
- }
- skb_queue_walk(&sk->sk_receive_queue, skb) {
- pse = skb_ext_find(skb, SKB_EXT_PSP);
- if (!psp_pse_matches_pas(pse, pas))
- return -EBUSY;
- }
- return 0;
- }
- int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
- u32 version, struct psp_key_parsed *key,
- struct netlink_ext_ack *extack)
- {
- struct inet_connection_sock *icsk;
- struct psp_assoc *pas, *dummy;
- int err;
- lock_sock(sk);
- pas = psp_sk_assoc(sk);
- if (!pas) {
- NL_SET_ERR_MSG(extack, "Socket has no Rx key");
- err = -EINVAL;
- goto exit_unlock;
- }
- if (pas->psd != psd) {
- NL_SET_ERR_MSG(extack, "Rx key from different device");
- err = -EINVAL;
- goto exit_unlock;
- }
- if (pas->version != version) {
- NL_SET_ERR_MSG(extack,
- "PSP version mismatch with existing state");
- err = -EINVAL;
- goto exit_unlock;
- }
- if (pas->tx.spi) {
- NL_SET_ERR_MSG(extack, "Tx key already set");
- err = -EBUSY;
- goto exit_unlock;
- }
- err = psp_sock_recv_queue_check(sk, pas);
- if (err) {
- NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
- goto exit_unlock;
- }
- /* Pass a fake association to drivers to make sure they don't
- * try to store pointers to it. For re-keying we'll need to
- * re-allocate the assoc structures.
- */
- dummy = psp_assoc_dummy(pas);
- if (!dummy) {
- err = -ENOMEM;
- goto exit_unlock;
- }
- memcpy(&dummy->tx, key, sizeof(*key));
- err = psp_dev_tx_key_add(psd, dummy, extack);
- if (err)
- goto exit_free_dummy;
- memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
- memcpy(&pas->tx, key, sizeof(*key));
- WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
- tcp_write_collapse_fence(sk);
- pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;
- icsk = inet_csk(sk);
- icsk->icsk_ext_hdr_len += psp_sk_overhead(sk);
- icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);
- exit_free_dummy:
- kfree(dummy);
- exit_unlock:
- release_sock(sk);
- return err;
- }
- void psp_assocs_key_rotated(struct psp_dev *psd)
- {
- struct psp_assoc *pas, *next;
- /* Mark the stale associations as invalid, they will no longer
- * be able to Rx any traffic.
- */
- list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list) {
- pas->generation |= ~PSP_GEN_VALID_MASK;
- psd->stats.stales++;
- }
- list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
- list_splice_init(&psd->active_assocs, &psd->prev_assocs);
- /* TODO: we should inform the sockets that got shut down */
- }
- void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
- {
- struct psp_assoc *pas = psp_sk_assoc(sk);
- if (pas)
- refcount_inc(&pas->refcnt);
- rcu_assign_pointer(tw->psp_assoc, pas);
- tw->tw_validate_xmit_skb = psp_validate_xmit;
- }
- void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
- {
- struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);
- rcu_assign_pointer(tw->psp_assoc, NULL);
- psp_assoc_put(pas);
- }
- void psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb)
- {
- struct psp_assoc *pas;
- rcu_read_lock();
- pas = psp_sk_get_assoc_rcu(sk);
- if (pas && pas->tx.spi)
- skb->decrypted = 1;
- rcu_read_unlock();
- }
- EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);
|