psp_main.c 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. #include <linux/bitfield.h>
  3. #include <linux/list.h>
  4. #include <linux/netdevice.h>
  5. #include <linux/xarray.h>
  6. #include <net/net_namespace.h>
  7. #include <net/psp.h>
  8. #include <net/udp.h>
  9. #include "psp.h"
  10. #include "psp-nl-gen.h"
  11. DEFINE_XARRAY_ALLOC1(psp_devs);
  12. struct mutex psp_devs_lock;
  13. /**
  14. * DOC: PSP locking
  15. *
  16. * psp_devs_lock protects the psp_devs xarray.
  17. * Ordering is take the psp_devs_lock and then the instance lock.
  18. * Each instance is protected by RCU, and has a refcount.
  19. * When driver unregisters the instance gets flushed, but struct sticks around.
  20. */
  21. /**
  22. * psp_dev_check_access() - check if user in a given net ns can access PSP dev
  23. * @psd: PSP device structure user is trying to access
  24. * @net: net namespace user is in
  25. *
  26. * Return: 0 if PSP device should be visible in @net, errno otherwise.
  27. */
  28. int psp_dev_check_access(struct psp_dev *psd, struct net *net)
  29. {
  30. if (dev_net(psd->main_netdev) == net)
  31. return 0;
  32. return -ENOENT;
  33. }
  34. /**
  35. * psp_dev_create() - create and register PSP device
  36. * @netdev: main netdevice
  37. * @psd_ops: driver callbacks
  38. * @psd_caps: device capabilities
  39. * @priv_ptr: back-pointer to driver private data
  40. *
  41. * Return: pointer to allocated PSP device, or ERR_PTR.
  42. */
  43. struct psp_dev *
  44. psp_dev_create(struct net_device *netdev,
  45. struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps,
  46. void *priv_ptr)
  47. {
  48. struct psp_dev *psd;
  49. static u32 last_id;
  50. int err;
  51. if (WARN_ON(!psd_caps->versions ||
  52. !psd_ops->set_config ||
  53. !psd_ops->key_rotate ||
  54. !psd_ops->rx_spi_alloc ||
  55. !psd_ops->tx_key_add ||
  56. !psd_ops->tx_key_del ||
  57. !psd_ops->get_stats))
  58. return ERR_PTR(-EINVAL);
  59. psd = kzalloc_obj(*psd);
  60. if (!psd)
  61. return ERR_PTR(-ENOMEM);
  62. psd->main_netdev = netdev;
  63. psd->ops = psd_ops;
  64. psd->caps = psd_caps;
  65. psd->drv_priv = priv_ptr;
  66. mutex_init(&psd->lock);
  67. INIT_LIST_HEAD(&psd->active_assocs);
  68. INIT_LIST_HEAD(&psd->prev_assocs);
  69. INIT_LIST_HEAD(&psd->stale_assocs);
  70. refcount_set(&psd->refcnt, 1);
  71. mutex_lock(&psp_devs_lock);
  72. err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b,
  73. &last_id, GFP_KERNEL);
  74. if (err) {
  75. mutex_unlock(&psp_devs_lock);
  76. kfree(psd);
  77. return ERR_PTR(err);
  78. }
  79. mutex_lock(&psd->lock);
  80. mutex_unlock(&psp_devs_lock);
  81. psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF);
  82. rcu_assign_pointer(netdev->psp_dev, psd);
  83. mutex_unlock(&psd->lock);
  84. return psd;
  85. }
  86. EXPORT_SYMBOL(psp_dev_create);
  87. void psp_dev_free(struct psp_dev *psd)
  88. {
  89. mutex_lock(&psp_devs_lock);
  90. xa_erase(&psp_devs, psd->id);
  91. mutex_unlock(&psp_devs_lock);
  92. mutex_destroy(&psd->lock);
  93. kfree_rcu(psd, rcu);
  94. }
  95. /**
  96. * psp_dev_unregister() - unregister PSP device
  97. * @psd: PSP device structure
  98. */
  99. void psp_dev_unregister(struct psp_dev *psd)
  100. {
  101. struct psp_assoc *pas, *next;
  102. mutex_lock(&psp_devs_lock);
  103. mutex_lock(&psd->lock);
  104. psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF);
  105. /* Wait until psp_dev_free() to call xa_erase() to prevent a
  106. * different psd from being added to the xarray with this id, while
  107. * there are still references to this psd being held.
  108. */
  109. xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
  110. mutex_unlock(&psp_devs_lock);
  111. list_splice_init(&psd->active_assocs, &psd->prev_assocs);
  112. list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
  113. list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list)
  114. psp_dev_tx_key_del(psd, pas);
  115. rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);
  116. psd->ops = NULL;
  117. psd->drv_priv = NULL;
  118. mutex_unlock(&psd->lock);
  119. psp_dev_put(psd);
  120. }
  121. EXPORT_SYMBOL(psp_dev_unregister);
  122. unsigned int psp_key_size(u32 version)
  123. {
  124. switch (version) {
  125. case PSP_VERSION_HDR0_AES_GCM_128:
  126. case PSP_VERSION_HDR0_AES_GMAC_128:
  127. return 16;
  128. case PSP_VERSION_HDR0_AES_GCM_256:
  129. case PSP_VERSION_HDR0_AES_GMAC_256:
  130. return 32;
  131. default:
  132. return 0;
  133. }
  134. }
  135. EXPORT_SYMBOL(psp_key_size);
  136. static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi,
  137. u8 ver, unsigned int udp_len, __be16 sport)
  138. {
  139. struct udphdr *uh = udp_hdr(skb);
  140. struct psphdr *psph = (struct psphdr *)(uh + 1);
  141. const struct sock *sk = skb->sk;
  142. uh->dest = htons(PSP_DEFAULT_UDP_PORT);
  143. /* A bit of theory: Selection of the source port.
  144. *
  145. * We need some entropy, so that multiple flows use different
  146. * source ports for better RSS spreading at the receiver.
  147. *
  148. * We also need that all packets belonging to one TCP flow
  149. * use the same source port through their duration,
  150. * so that all these packets land in the same receive queue.
  151. *
  152. * udp_flow_src_port() is using sk_txhash, inherited from
  153. * skb_set_hash_from_sk() call in __tcp_transmit_skb().
  154. * This field is subject to reshuffling, thanks to
  155. * sk_rethink_txhash() calls in various TCP functions.
  156. *
  157. * Instead, use sk->sk_hash which is constant through
  158. * the whole flow duration.
  159. */
  160. if (likely(sk)) {
  161. u32 hash = sk->sk_hash;
  162. int min, max;
  163. /* These operations are cheap, no need to cache the result
  164. * in another socket field.
  165. */
  166. inet_get_local_port_range(net, &min, &max);
  167. /* Since this is being sent on the wire obfuscate hash a bit
  168. * to minimize possibility that any useful information to an
  169. * attacker is leaked. Only upper 16 bits are relevant in the
  170. * computation for 16 bit port value because we use a
  171. * reciprocal divide.
  172. */
  173. hash ^= hash << 16;
  174. uh->source = htons((((u64)hash * (max - min)) >> 32) + min);
  175. } else {
  176. uh->source = udp_flow_src_port(net, skb, 0, 0, false);
  177. }
  178. uh->check = 0;
  179. uh->len = htons(udp_len);
  180. psph->nexthdr = IPPROTO_TCP;
  181. psph->hdrlen = PSP_HDRLEN_NOOPT;
  182. psph->crypt_offset = 0;
  183. psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) |
  184. FIELD_PREP(PSPHDR_VERFL_ONE, 1);
  185. psph->spi = spi;
  186. memset(&psph->iv, 0, sizeof(psph->iv));
  187. }
  188. /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling
  189. * them in.
  190. */
  191. bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
  192. u8 ver, __be16 sport)
  193. {
  194. u32 network_len = skb_network_header_len(skb);
  195. u32 ethr_len = skb_mac_header_len(skb);
  196. u32 bufflen = ethr_len + network_len;
  197. if (skb_cow_head(skb, PSP_ENCAP_HLEN))
  198. return false;
  199. skb_push(skb, PSP_ENCAP_HLEN);
  200. skb->mac_header -= PSP_ENCAP_HLEN;
  201. skb->network_header -= PSP_ENCAP_HLEN;
  202. skb->transport_header -= PSP_ENCAP_HLEN;
  203. memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen);
  204. if (skb->protocol == htons(ETH_P_IP)) {
  205. ip_hdr(skb)->protocol = IPPROTO_UDP;
  206. be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN);
  207. ip_hdr(skb)->check = 0;
  208. ip_hdr(skb)->check =
  209. ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl);
  210. } else if (skb->protocol == htons(ETH_P_IPV6)) {
  211. ipv6_hdr(skb)->nexthdr = IPPROTO_UDP;
  212. be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN);
  213. } else {
  214. return false;
  215. }
  216. skb_set_inner_ipproto(skb, IPPROTO_TCP);
  217. skb_set_inner_transport_header(skb, skb_transport_offset(skb) +
  218. PSP_ENCAP_HLEN);
  219. skb->encapsulation = 1;
  220. psp_write_headers(net, skb, spi, ver,
  221. skb->len - skb_transport_offset(skb), sport);
  222. return true;
  223. }
  224. EXPORT_SYMBOL(psp_dev_encapsulate);
  225. /* Receive handler for PSP packets.
  226. *
  227. * Presently it accepts only already-authenticated packets and does not
  228. * support optional fields, such as virtualization cookies. The caller should
  229. * ensure that skb->data is pointing to the mac header, and that skb->mac_len
  230. * is set. This function does not currently adjust skb->csum (CHECKSUM_COMPLETE
  231. * is not supported).
  232. */
  233. int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
  234. {
  235. int l2_hlen = 0, l3_hlen, encap;
  236. struct psp_skb_ext *pse;
  237. struct psphdr *psph;
  238. struct ethhdr *eth;
  239. struct udphdr *uh;
  240. __be16 proto;
  241. bool is_udp;
  242. eth = (struct ethhdr *)skb->data;
  243. proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
  244. if (proto == htons(ETH_P_IP))
  245. l3_hlen = sizeof(struct iphdr);
  246. else if (proto == htons(ETH_P_IPV6))
  247. l3_hlen = sizeof(struct ipv6hdr);
  248. else
  249. return -EINVAL;
  250. if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
  251. return -EINVAL;
  252. if (proto == htons(ETH_P_IP)) {
  253. struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
  254. is_udp = iph->protocol == IPPROTO_UDP;
  255. l3_hlen = iph->ihl * 4;
  256. if (l3_hlen != sizeof(struct iphdr) &&
  257. !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
  258. return -EINVAL;
  259. } else {
  260. struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
  261. is_udp = ipv6h->nexthdr == IPPROTO_UDP;
  262. }
  263. if (unlikely(!is_udp))
  264. return -EINVAL;
  265. uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
  266. if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
  267. return -EINVAL;
  268. pse = skb_ext_add(skb, SKB_EXT_PSP);
  269. if (!pse)
  270. return -EINVAL;
  271. psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
  272. sizeof(struct udphdr));
  273. pse->spi = psph->spi;
  274. pse->dev_id = dev_id;
  275. pse->generation = generation;
  276. pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);
  277. encap = PSP_ENCAP_HLEN;
  278. encap += strip_icv ? PSP_TRL_SIZE : 0;
  279. if (proto == htons(ETH_P_IP)) {
  280. struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
  281. iph->protocol = psph->nexthdr;
  282. iph->tot_len = htons(ntohs(iph->tot_len) - encap);
  283. iph->check = 0;
  284. iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
  285. } else {
  286. struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
  287. ipv6h->nexthdr = psph->nexthdr;
  288. ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
  289. }
  290. memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen);
  291. skb_pull(skb, PSP_ENCAP_HLEN);
  292. if (strip_icv)
  293. pskb_trim(skb, skb->len - PSP_TRL_SIZE);
  294. return 0;
  295. }
  296. EXPORT_SYMBOL(psp_dev_rcv);
  297. static int __init psp_init(void)
  298. {
  299. mutex_init(&psp_devs_lock);
  300. return genl_register_family(&psp_nl_family);
  301. }
  302. subsys_initcall(psp_init);