functions.h 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. /* SPDX-License-Identifier: GPL-2.0-only */
  2. #ifndef __NET_PSP_HELPERS_H
  3. #define __NET_PSP_HELPERS_H
  4. #include <linux/skbuff.h>
  5. #include <linux/rcupdate.h>
  6. #include <linux/udp.h>
  7. #include <net/sock.h>
  8. #include <net/tcp.h>
  9. #include <net/psp/types.h>
  10. struct inet_timewait_sock;
  11. /* Driver-facing API */
  12. struct psp_dev *
  13. psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
  14. struct psp_dev_caps *psd_caps, void *priv_ptr);
  15. void psp_dev_unregister(struct psp_dev *psd);
  16. bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
  17. u8 ver, __be16 sport);
  18. int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv);
  /* Kernel-facing API */

/*
 * Implemented out of line. NOTE(review): by the kernel *_put() naming
 * convention this presumably drops a reference on @pas - confirm against
 * the implementation.
 */
void psp_assoc_put(struct psp_assoc *pas);
  21. static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
  22. {
  23. return pas->drv_data;
  24. }
  25. #if IS_ENABLED(CONFIG_INET_PSP)
  26. unsigned int psp_key_size(u32 version);
  27. void psp_sk_assoc_free(struct sock *sk);
  28. void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
  29. void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
  30. void psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb);
  31. static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
  32. {
  33. return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
  34. }
  35. static inline void
  36. psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
  37. {
  38. struct psp_assoc *pas;
  39. pas = psp_sk_assoc(sk);
  40. if (pas && pas->tx.spi)
  41. skb->decrypted = 1;
  42. }
  43. static inline unsigned long
  44. __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
  45. unsigned long diffs)
  46. {
  47. struct psp_skb_ext *a, *b;
  48. a = skb_ext_find(one, SKB_EXT_PSP);
  49. b = skb_ext_find(two, SKB_EXT_PSP);
  50. diffs |= (!!a) ^ (!!b);
  51. if (!diffs && unlikely(a))
  52. diffs |= memcmp(a, b, sizeof(*a));
  53. return diffs;
  54. }
  55. static inline bool
  56. psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
  57. {
  58. bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
  59. u32 end_seq = TCP_SKB_CB(skb)->end_seq;
  60. u32 seq = TCP_SKB_CB(skb)->seq;
  61. bool pure_fin;
  62. pure_fin = fin && end_seq - seq == 1;
  63. return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
  64. }
  65. static inline bool
  66. psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
  67. {
  68. return pse && pas->rx.spi == pse->spi &&
  69. pas->generation == pse->generation &&
  70. pas->version == pse->version &&
  71. pas->dev_id == pse->dev_id;
  72. }
  73. static inline enum skb_drop_reason
  74. __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
  75. {
  76. struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
  77. if (!pas)
  78. return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
  79. if (likely(psp_pse_matches_pas(pse, pas))) {
  80. if (unlikely(!pas->peer_tx))
  81. pas->peer_tx = 1;
  82. return 0;
  83. }
  84. if (!pse) {
  85. if (!pas->tx.spi ||
  86. (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
  87. return 0;
  88. }
  89. return SKB_DROP_REASON_PSP_INPUT;
  90. }
  91. static inline enum skb_drop_reason
  92. psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
  93. {
  94. return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
  95. }
  96. static inline enum skb_drop_reason
  97. psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
  98. {
  99. return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
  100. }
  101. static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk)
  102. {
  103. struct psp_assoc *pas;
  104. int state;
  105. state = READ_ONCE(sk->sk_state);
  106. if (!sk_is_inet(sk) || state == TCP_NEW_SYN_RECV)
  107. return NULL;
  108. pas = state == TCP_TIME_WAIT ?
  109. rcu_dereference(inet_twsk(sk)->psp_assoc) :
  110. rcu_dereference(sk->psp_assoc);
  111. return pas;
  112. }
  113. static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
  114. {
  115. if (!skb->decrypted || !skb->sk)
  116. return NULL;
  117. return psp_sk_get_assoc_rcu(skb->sk);
  118. }
  119. static inline unsigned int psp_sk_overhead(const struct sock *sk)
  120. {
  121. int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
  122. bool has_psp = rcu_access_pointer(sk->psp_assoc);
  123. return has_psp ? psp_encap : 0;
  124. }
  125. #else
  /* CONFIG_INET_PSP disabled: no-op. */
static inline void psp_sk_assoc_free(struct sock *sk)
{
}
  /* CONFIG_INET_PSP disabled: no-op. */
static inline void
psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
{
}
  /* CONFIG_INET_PSP disabled: no-op. */
static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
{
}
  /* CONFIG_INET_PSP disabled: no-op. */
static inline void
psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb)
{
}
  /* CONFIG_INET_PSP disabled: sockets never carry an association. */
static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
{
	struct psp_assoc *none = NULL;

	return none;
}
  /* CONFIG_INET_PSP disabled: @skb is never flagged; no-op. */
static inline void
psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
{
}
  /*
 * CONFIG_INET_PSP disabled: PSP extensions are not compared, so @one and
 * @two are ignored and the caller's @diffs is returned unchanged.
 */
static inline unsigned long
__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
			unsigned long diffs)
{
	unsigned long unchanged = diffs;

	return unchanged;
}
  144. static inline enum skb_drop_reason
  145. psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
  146. {
  147. return 0;
  148. }
  149. static inline enum skb_drop_reason
  150. psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
  151. {
  152. return 0;
  153. }
  /* CONFIG_INET_PSP disabled: no skb maps to an association. */
static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
{
	struct psp_assoc *none = NULL;

	return none;
}
  /* CONFIG_INET_PSP disabled: no encapsulation overhead. */
static inline unsigned int psp_sk_overhead(const struct sock *sk)
{
	unsigned int overhead = 0;

	return overhead;
}
  162. #endif
  /**
 * psp_skb_coalesce_diff() - check whether two skbs differ in PSP metadata
 * @one: first skb
 * @two: second skb
 *
 * Convenience wrapper around __psp_skb_coalesce_diff() starting from no
 * prior differences.
 *
 * Return: non-zero if the PSP metadata of @one and @two differs.
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	unsigned long no_prior_diffs = 0;

	return __psp_skb_coalesce_diff(one, two, no_prior_diffs);
}
  168. #endif /* __NET_PSP_HELPERS_H */