psp_nl.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. #include <linux/ethtool.h>
  3. #include <linux/skbuff.h>
  4. #include <linux/xarray.h>
  5. #include <net/genetlink.h>
  6. #include <net/psp.h>
  7. #include <net/sock.h>
  8. #include "psp-nl-gen.h"
  9. #include "psp.h"
  10. /* Netlink helpers */
  11. static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
  12. {
  13. struct sk_buff *rsp;
  14. void *hdr;
  15. rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
  16. if (!rsp)
  17. return NULL;
  18. hdr = genlmsg_iput(rsp, info);
  19. if (!hdr) {
  20. nlmsg_free(rsp);
  21. return NULL;
  22. }
  23. return rsp;
  24. }
  25. static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
  26. {
  27. /* Note that this *only* works with a single message per skb! */
  28. nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
  29. return genlmsg_reply(rsp, info);
  30. }
  31. /* Device stuff */
  32. static struct psp_dev *
  33. psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
  34. {
  35. struct psp_dev *psd;
  36. int err;
  37. mutex_lock(&psp_devs_lock);
  38. psd = xa_load(&psp_devs, nla_get_u32(dev_id));
  39. if (!psd) {
  40. mutex_unlock(&psp_devs_lock);
  41. return ERR_PTR(-ENODEV);
  42. }
  43. mutex_lock(&psd->lock);
  44. mutex_unlock(&psp_devs_lock);
  45. err = psp_dev_check_access(psd, net);
  46. if (err) {
  47. mutex_unlock(&psd->lock);
  48. return ERR_PTR(err);
  49. }
  50. return psd;
  51. }
  52. int psp_device_get_locked(const struct genl_split_ops *ops,
  53. struct sk_buff *skb, struct genl_info *info)
  54. {
  55. if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
  56. return -EINVAL;
  57. info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
  58. info->attrs[PSP_A_DEV_ID]);
  59. return PTR_ERR_OR_ZERO(info->user_ptr[0]);
  60. }
  61. void
  62. psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
  63. struct genl_info *info)
  64. {
  65. struct socket *socket = info->user_ptr[1];
  66. struct psp_dev *psd = info->user_ptr[0];
  67. mutex_unlock(&psd->lock);
  68. if (socket)
  69. sockfd_put(socket);
  70. }
  71. static int
  72. psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
  73. const struct genl_info *info)
  74. {
  75. void *hdr;
  76. hdr = genlmsg_iput(rsp, info);
  77. if (!hdr)
  78. return -EMSGSIZE;
  79. if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
  80. nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
  81. nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
  82. nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
  83. goto err_cancel_msg;
  84. genlmsg_end(rsp, hdr);
  85. return 0;
  86. err_cancel_msg:
  87. genlmsg_cancel(rsp, hdr);
  88. return -EMSGSIZE;
  89. }
  90. void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
  91. {
  92. struct genl_info info;
  93. struct sk_buff *ntf;
  94. if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
  95. PSP_NLGRP_MGMT))
  96. return;
  97. ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
  98. if (!ntf)
  99. return;
  100. genl_info_init_ntf(&info, &psp_nl_family, cmd);
  101. if (psp_nl_dev_fill(psd, ntf, &info)) {
  102. nlmsg_free(ntf);
  103. return;
  104. }
  105. genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
  106. 0, PSP_NLGRP_MGMT, GFP_KERNEL);
  107. }
  108. int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
  109. {
  110. struct psp_dev *psd = info->user_ptr[0];
  111. struct sk_buff *rsp;
  112. int err;
  113. rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
  114. if (!rsp)
  115. return -ENOMEM;
  116. err = psp_nl_dev_fill(psd, rsp, info);
  117. if (err)
  118. goto err_free_msg;
  119. return genlmsg_reply(rsp, info);
  120. err_free_msg:
  121. nlmsg_free(rsp);
  122. return err;
  123. }
  124. static int
  125. psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
  126. struct psp_dev *psd)
  127. {
  128. if (psp_dev_check_access(psd, sock_net(rsp->sk)))
  129. return 0;
  130. return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
  131. }
  132. int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
  133. {
  134. struct psp_dev *psd;
  135. int err = 0;
  136. mutex_lock(&psp_devs_lock);
  137. xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
  138. mutex_lock(&psd->lock);
  139. err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
  140. mutex_unlock(&psd->lock);
  141. if (err)
  142. break;
  143. }
  144. mutex_unlock(&psp_devs_lock);
  145. return err;
  146. }
  147. int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
  148. {
  149. struct psp_dev *psd = info->user_ptr[0];
  150. struct psp_dev_config new_config;
  151. struct sk_buff *rsp;
  152. int err;
  153. memcpy(&new_config, &psd->config, sizeof(new_config));
  154. if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
  155. new_config.versions =
  156. nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
  157. if (new_config.versions & ~psd->caps->versions) {
  158. NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
  159. return -EINVAL;
  160. }
  161. } else {
  162. NL_SET_ERR_MSG(info->extack, "No settings present");
  163. return -EINVAL;
  164. }
  165. rsp = psp_nl_reply_new(info);
  166. if (!rsp)
  167. return -ENOMEM;
  168. if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
  169. err = psd->ops->set_config(psd, &new_config, info->extack);
  170. if (err)
  171. goto err_free_rsp;
  172. memcpy(&psd->config, &new_config, sizeof(new_config));
  173. }
  174. psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
  175. return psp_nl_reply_send(rsp, info);
  176. err_free_rsp:
  177. nlmsg_free(rsp);
  178. return err;
  179. }
  180. int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
  181. {
  182. struct psp_dev *psd = info->user_ptr[0];
  183. struct genl_info ntf_info;
  184. struct sk_buff *ntf, *rsp;
  185. u8 prev_gen;
  186. int err;
  187. rsp = psp_nl_reply_new(info);
  188. if (!rsp)
  189. return -ENOMEM;
  190. genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
  191. ntf = psp_nl_reply_new(&ntf_info);
  192. if (!ntf) {
  193. err = -ENOMEM;
  194. goto err_free_rsp;
  195. }
  196. if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
  197. nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
  198. err = -EMSGSIZE;
  199. goto err_free_ntf;
  200. }
  201. /* suggest the next gen number, driver can override */
  202. prev_gen = psd->generation;
  203. psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;
  204. err = psd->ops->key_rotate(psd, info->extack);
  205. if (err)
  206. goto err_free_ntf;
  207. WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
  208. psd->generation & ~PSP_GEN_VALID_MASK);
  209. psp_assocs_key_rotated(psd);
  210. psd->stats.rotations++;
  211. nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
  212. genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
  213. 0, PSP_NLGRP_USE, GFP_KERNEL);
  214. return psp_nl_reply_send(rsp, info);
  215. err_free_ntf:
  216. nlmsg_free(ntf);
  217. err_free_rsp:
  218. nlmsg_free(rsp);
  219. return err;
  220. }
  221. /* Key etc. */
  222. int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
  223. struct sk_buff *skb, struct genl_info *info)
  224. {
  225. struct socket *socket;
  226. struct psp_dev *psd;
  227. struct nlattr *id;
  228. int fd, err;
  229. if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
  230. return -EINVAL;
  231. fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
  232. socket = sockfd_lookup(fd, &err);
  233. if (!socket)
  234. return err;
  235. if (!sk_is_tcp(socket->sk)) {
  236. NL_SET_ERR_MSG_ATTR(info->extack,
  237. info->attrs[PSP_A_ASSOC_SOCK_FD],
  238. "Unsupported socket family and type");
  239. err = -EOPNOTSUPP;
  240. goto err_sock_put;
  241. }
  242. psd = psp_dev_get_for_sock(socket->sk);
  243. if (psd) {
  244. err = psp_dev_check_access(psd, genl_info_net(info));
  245. if (err) {
  246. psp_dev_put(psd);
  247. psd = NULL;
  248. }
  249. }
  250. if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
  251. err = -EINVAL;
  252. goto err_sock_put;
  253. }
  254. id = info->attrs[PSP_A_ASSOC_DEV_ID];
  255. if (psd) {
  256. mutex_lock(&psd->lock);
  257. if (id && psd->id != nla_get_u32(id)) {
  258. mutex_unlock(&psd->lock);
  259. NL_SET_ERR_MSG_ATTR(info->extack, id,
  260. "Device id vs socket mismatch");
  261. err = -EINVAL;
  262. goto err_psd_put;
  263. }
  264. psp_dev_put(psd);
  265. } else {
  266. psd = psp_device_get_and_lock(genl_info_net(info), id);
  267. if (IS_ERR(psd)) {
  268. err = PTR_ERR(psd);
  269. goto err_sock_put;
  270. }
  271. }
  272. info->user_ptr[0] = psd;
  273. info->user_ptr[1] = socket;
  274. return 0;
  275. err_psd_put:
  276. psp_dev_put(psd);
  277. err_sock_put:
  278. sockfd_put(socket);
  279. return err;
  280. }
  281. static int
  282. psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
  283. unsigned int key_sz)
  284. {
  285. struct nlattr *nest = info->attrs[attr];
  286. struct nlattr *tb[PSP_A_KEYS_SPI + 1];
  287. u32 spi;
  288. int err;
  289. err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
  290. psp_keys_nl_policy, info->extack);
  291. if (err)
  292. return err;
  293. if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
  294. NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
  295. return -EINVAL;
  296. if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
  297. NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
  298. "incorrect key length");
  299. return -EINVAL;
  300. }
  301. spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
  302. if (!(spi & PSP_SPI_KEY_ID)) {
  303. NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
  304. "invalid SPI: lower 31b must be non-zero");
  305. return -EINVAL;
  306. }
  307. key->spi = cpu_to_be32(spi);
  308. memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
  309. return 0;
  310. }
  311. static int
  312. psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
  313. struct psp_key_parsed *key)
  314. {
  315. int key_sz = psp_key_size(version);
  316. void *nest;
  317. nest = nla_nest_start(skb, attr);
  318. if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
  319. nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
  320. nla_nest_cancel(skb, nest);
  321. return -EMSGSIZE;
  322. }
  323. nla_nest_end(skb, nest);
  324. return 0;
  325. }
  326. int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
  327. {
  328. struct socket *socket = info->user_ptr[1];
  329. struct psp_dev *psd = info->user_ptr[0];
  330. struct psp_key_parsed key;
  331. struct psp_assoc *pas;
  332. struct sk_buff *rsp;
  333. u32 version;
  334. int err;
  335. if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
  336. return -EINVAL;
  337. version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
  338. if (!(psd->caps->versions & (1 << version))) {
  339. NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
  340. return -EOPNOTSUPP;
  341. }
  342. rsp = psp_nl_reply_new(info);
  343. if (!rsp)
  344. return -ENOMEM;
  345. pas = psp_assoc_create(psd);
  346. if (!pas) {
  347. err = -ENOMEM;
  348. goto err_free_rsp;
  349. }
  350. pas->version = version;
  351. err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
  352. if (err)
  353. goto err_free_pas;
  354. if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
  355. psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
  356. err = -EMSGSIZE;
  357. goto err_free_pas;
  358. }
  359. err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
  360. if (err) {
  361. NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
  362. goto err_free_pas;
  363. }
  364. psp_assoc_put(pas);
  365. return psp_nl_reply_send(rsp, info);
  366. err_free_pas:
  367. psp_assoc_put(pas);
  368. err_free_rsp:
  369. nlmsg_free(rsp);
  370. return err;
  371. }
  372. int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
  373. {
  374. struct socket *socket = info->user_ptr[1];
  375. struct psp_dev *psd = info->user_ptr[0];
  376. struct psp_key_parsed key;
  377. struct sk_buff *rsp;
  378. unsigned int key_sz;
  379. u32 version;
  380. int err;
  381. if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
  382. GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
  383. return -EINVAL;
  384. version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
  385. if (!(psd->caps->versions & (1 << version))) {
  386. NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
  387. return -EOPNOTSUPP;
  388. }
  389. key_sz = psp_key_size(version);
  390. if (!key_sz)
  391. return -EINVAL;
  392. err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
  393. if (err < 0)
  394. return err;
  395. rsp = psp_nl_reply_new(info);
  396. if (!rsp)
  397. return -ENOMEM;
  398. err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
  399. info->extack);
  400. if (err)
  401. goto err_free_msg;
  402. return psp_nl_reply_send(rsp, info);
  403. err_free_msg:
  404. nlmsg_free(rsp);
  405. return err;
  406. }
  407. static int
  408. psp_nl_stats_fill(struct psp_dev *psd, struct sk_buff *rsp,
  409. const struct genl_info *info)
  410. {
  411. unsigned int required_cnt = sizeof(struct psp_dev_stats) / sizeof(u64);
  412. struct psp_dev_stats stats;
  413. void *hdr;
  414. int i;
  415. memset(&stats, 0xff, sizeof(stats));
  416. psd->ops->get_stats(psd, &stats);
  417. for (i = 0; i < required_cnt; i++)
  418. if (WARN_ON_ONCE(stats.required[i] == ETHTOOL_STAT_NOT_SET))
  419. return -EOPNOTSUPP;
  420. hdr = genlmsg_iput(rsp, info);
  421. if (!hdr)
  422. return -EMSGSIZE;
  423. if (nla_put_u32(rsp, PSP_A_STATS_DEV_ID, psd->id) ||
  424. nla_put_uint(rsp, PSP_A_STATS_KEY_ROTATIONS,
  425. psd->stats.rotations) ||
  426. nla_put_uint(rsp, PSP_A_STATS_STALE_EVENTS, psd->stats.stales) ||
  427. nla_put_uint(rsp, PSP_A_STATS_RX_PACKETS, stats.rx_packets) ||
  428. nla_put_uint(rsp, PSP_A_STATS_RX_BYTES, stats.rx_bytes) ||
  429. nla_put_uint(rsp, PSP_A_STATS_RX_AUTH_FAIL, stats.rx_auth_fail) ||
  430. nla_put_uint(rsp, PSP_A_STATS_RX_ERROR, stats.rx_error) ||
  431. nla_put_uint(rsp, PSP_A_STATS_RX_BAD, stats.rx_bad) ||
  432. nla_put_uint(rsp, PSP_A_STATS_TX_PACKETS, stats.tx_packets) ||
  433. nla_put_uint(rsp, PSP_A_STATS_TX_BYTES, stats.tx_bytes) ||
  434. nla_put_uint(rsp, PSP_A_STATS_TX_ERROR, stats.tx_error))
  435. goto err_cancel_msg;
  436. genlmsg_end(rsp, hdr);
  437. return 0;
  438. err_cancel_msg:
  439. genlmsg_cancel(rsp, hdr);
  440. return -EMSGSIZE;
  441. }
  442. int psp_nl_get_stats_doit(struct sk_buff *skb, struct genl_info *info)
  443. {
  444. struct psp_dev *psd = info->user_ptr[0];
  445. struct sk_buff *rsp;
  446. int err;
  447. rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
  448. if (!rsp)
  449. return -ENOMEM;
  450. err = psp_nl_stats_fill(psd, rsp, info);
  451. if (err)
  452. goto err_free_msg;
  453. return genlmsg_reply(rsp, info);
  454. err_free_msg:
  455. nlmsg_free(rsp);
  456. return err;
  457. }
  458. static int
  459. psp_nl_stats_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
  460. struct psp_dev *psd)
  461. {
  462. if (psp_dev_check_access(psd, sock_net(rsp->sk)))
  463. return 0;
  464. return psp_nl_stats_fill(psd, rsp, genl_info_dump(cb));
  465. }
  466. int psp_nl_get_stats_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
  467. {
  468. struct psp_dev *psd;
  469. int err = 0;
  470. mutex_lock(&psp_devs_lock);
  471. xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
  472. mutex_lock(&psd->lock);
  473. err = psp_nl_stats_get_dumpit_one(rsp, cb, psd);
  474. mutex_unlock(&psd->lock);
  475. if (err)
  476. break;
  477. }
  478. mutex_unlock(&psp_devs_lock);
  479. return err;
  480. }