nsm.c 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. // SPDX-License-Identifier: GPL-2.0
  2. /*
  3. * Amazon Nitro Secure Module driver.
  4. *
  5. * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
  6. *
  7. * The Nitro Secure Module implements commands via CBOR over virtio.
  8. * This driver exposes a raw-message ioctl on /dev/nsm that user
  9. * space can use to issue these commands.
  10. */
  11. #include <linux/file.h>
  12. #include <linux/fs.h>
  13. #include <linux/interrupt.h>
  14. #include <linux/hw_random.h>
  15. #include <linux/miscdevice.h>
  16. #include <linux/module.h>
  17. #include <linux/mutex.h>
  18. #include <linux/slab.h>
  19. #include <linux/string.h>
  20. #include <linux/uaccess.h>
  21. #include <linux/uio.h>
  22. #include <linux/virtio_config.h>
  23. #include <linux/virtio_ids.h>
  24. #include <linux/virtio.h>
  25. #include <linux/wait.h>
  26. #include <uapi/linux/nsm.h>
  /* How long to wait for the NSM to answer on the virtqueue, in milliseconds (2 minutes). */
#define NSM_DEFAULT_TIMEOUT_MSECS (2 * 60 * 1000)
  29. /* Maximum length input data */
  30. struct nsm_data_req {
  31. u32 len;
  32. u8 data[NSM_REQUEST_MAX_SIZE];
  33. };
  34. /* Maximum length output data */
  35. struct nsm_data_resp {
  36. u32 len;
  37. u8 data[NSM_RESPONSE_MAX_SIZE];
  38. };
  39. /* Full NSM request/response message */
  40. struct nsm_msg {
  41. struct nsm_data_req req;
  42. struct nsm_data_resp resp;
  43. };
  44. struct nsm {
  45. struct virtio_device *vdev;
  46. struct virtqueue *vq;
  47. struct mutex lock;
  48. struct completion cmd_done;
  49. struct miscdevice misc;
  50. struct hwrng hwrng;
  51. struct work_struct misc_init;
  52. struct nsm_msg msg;
  53. };
  54. /* NSM device ID */
  55. static const struct virtio_device_id id_table[] = {
  56. { VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID },
  57. { 0 },
  58. };
  59. static struct nsm *file_to_nsm(struct file *file)
  60. {
  61. return container_of(file->private_data, struct nsm, misc);
  62. }
  63. static struct nsm *hwrng_to_nsm(struct hwrng *rng)
  64. {
  65. return container_of(rng, struct nsm, hwrng);
  66. }
  /*
 * Minimal CBOR constants used to build and parse NSM messages.
 *
 * The top three bits of an item's initial byte select its major type; the
 * low five bits either hold a short length (0..CBOR_SHORT_SIZE_MAX_VALUE)
 * or say how many big-endian length bytes follow (1, 2, 4 or 8).
 *
 * NOTE(review): per RFC 8949, 0x40 is major type 2 (byte string); a CBOR
 * array is 0x80. The "ARRAY" name here refers to a byte array.
 */
enum {
	CBOR_TYPE_MASK = 0xE0,
	CBOR_TYPE_MAP = 0xA0,
	CBOR_TYPE_TEXT = 0x60,
	CBOR_TYPE_ARRAY = 0x40,

	CBOR_HEADER_SIZE_SHORT = 1,
	CBOR_SHORT_SIZE_MAX_VALUE = 23,

	CBOR_LONG_SIZE_U8 = 24,
	CBOR_LONG_SIZE_U16 = 25,
	CBOR_LONG_SIZE_U32 = 26,
	CBOR_LONG_SIZE_U64 = 27,
};
  77. static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size)
  78. {
  79. if (cbor_object_size == 0 || cbor_object == NULL)
  80. return false;
  81. return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY;
  82. }
  83. static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array)
  84. {
  85. u8 cbor_short_size;
  86. void *array_len_p;
  87. u64 array_len;
  88. u64 array_offset;
  89. if (!cbor_object_is_array(cbor_object, cbor_object_size))
  90. return -EFAULT;
  91. cbor_short_size = (cbor_object[0] & 0x1F);
  92. /* Decoding byte array length */
  93. array_offset = CBOR_HEADER_SIZE_SHORT;
  94. if (cbor_short_size >= CBOR_LONG_SIZE_U8)
  95. array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8);
  96. if (cbor_object_size < array_offset)
  97. return -EFAULT;
  98. array_len_p = &cbor_object[1];
  99. switch (cbor_short_size) {
  100. case CBOR_SHORT_SIZE_MAX_VALUE: /* short encoding */
  101. array_len = cbor_short_size;
  102. break;
  103. case CBOR_LONG_SIZE_U8:
  104. array_len = *(u8 *)array_len_p;
  105. break;
  106. case CBOR_LONG_SIZE_U16:
  107. array_len = be16_to_cpup((__be16 *)array_len_p);
  108. break;
  109. case CBOR_LONG_SIZE_U32:
  110. array_len = be32_to_cpup((__be32 *)array_len_p);
  111. break;
  112. case CBOR_LONG_SIZE_U64:
  113. array_len = be64_to_cpup((__be64 *)array_len_p);
  114. break;
  115. }
  116. if (cbor_object_size < array_offset)
  117. return -EFAULT;
  118. if (cbor_object_size - array_offset < array_len)
  119. return -EFAULT;
  120. if (array_len > INT_MAX)
  121. return -EFAULT;
  122. *cbor_array = cbor_object + array_offset;
  123. return array_len;
  124. }
  125. /* Copy the request of a raw message to kernel space */
  126. static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req,
  127. struct nsm_raw *raw)
  128. {
  129. /* Verify the user input size. */
  130. if (raw->request.len > sizeof(req->data))
  131. return -EMSGSIZE;
  132. /* Copy the request payload */
  133. if (copy_from_user(req->data, u64_to_user_ptr(raw->request.addr),
  134. raw->request.len))
  135. return -EFAULT;
  136. req->len = raw->request.len;
  137. return 0;
  138. }
  139. /* Copy the response of a raw message back to user-space */
  140. static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp,
  141. struct nsm_raw *raw)
  142. {
  143. /* Truncate any message that does not fit. */
  144. raw->response.len = min_t(u64, raw->response.len, resp->len);
  145. /* Copy the response content to user space */
  146. if (copy_to_user(u64_to_user_ptr(raw->response.addr),
  147. resp->data, raw->response.len))
  148. return -EFAULT;
  149. return 0;
  150. }
  151. /* Virtqueue interrupt handler */
  152. static void nsm_vq_callback(struct virtqueue *vq)
  153. {
  154. struct nsm *nsm = vq->vdev->priv;
  155. complete(&nsm->cmd_done);
  156. }
  157. /* Forward a message to the NSM device and wait for the response from it */
  158. static int nsm_sendrecv_msg_locked(struct nsm *nsm)
  159. {
  160. struct device *dev = &nsm->vdev->dev;
  161. struct scatterlist sg_in, sg_out;
  162. struct nsm_msg *msg = &nsm->msg;
  163. struct virtqueue *vq = nsm->vq;
  164. unsigned int len;
  165. void *queue_buf;
  166. bool kicked;
  167. int rc;
  168. /* Initialize scatter-gather lists with request and response buffers. */
  169. sg_init_one(&sg_out, msg->req.data, msg->req.len);
  170. sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data));
  171. init_completion(&nsm->cmd_done);
  172. /* Add the request buffer (read by the device). */
  173. rc = virtqueue_add_outbuf(vq, &sg_out, 1, msg->req.data, GFP_KERNEL);
  174. if (rc)
  175. return rc;
  176. /* Add the response buffer (written by the device). */
  177. rc = virtqueue_add_inbuf(vq, &sg_in, 1, msg->resp.data, GFP_KERNEL);
  178. if (rc)
  179. goto cleanup;
  180. kicked = virtqueue_kick(vq);
  181. if (!kicked) {
  182. /* Cannot kick the virtqueue. */
  183. rc = -EIO;
  184. goto cleanup;
  185. }
  186. /* If the kick succeeded, wait for the device's response. */
  187. if (!wait_for_completion_io_timeout(&nsm->cmd_done,
  188. msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) {
  189. rc = -ETIMEDOUT;
  190. goto cleanup;
  191. }
  192. queue_buf = virtqueue_get_buf(vq, &len);
  193. if (!queue_buf || (queue_buf != msg->req.data)) {
  194. dev_err(dev, "wrong request buffer.");
  195. rc = -ENODATA;
  196. goto cleanup;
  197. }
  198. queue_buf = virtqueue_get_buf(vq, &len);
  199. if (!queue_buf || (queue_buf != msg->resp.data)) {
  200. dev_err(dev, "wrong response buffer.");
  201. rc = -ENODATA;
  202. goto cleanup;
  203. }
  204. msg->resp.len = len;
  205. rc = 0;
  206. cleanup:
  207. if (rc) {
  208. /* Clean the virtqueue. */
  209. while (virtqueue_get_buf(vq, &len) != NULL)
  210. ;
  211. }
  212. return rc;
  213. }
  214. static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req)
  215. {
  216. /*
  217. * 69 # text(9)
  218. * 47657452616E646F6D # "GetRandom"
  219. */
  220. const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom"),
  221. 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' };
  222. memcpy(req->data, request, sizeof(request));
  223. req->len = sizeof(request);
  224. return 0;
  225. }
  226. static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp,
  227. void *out, size_t max)
  228. {
  229. /*
  230. * A1 # map(1)
  231. * 69 # text(9) - Name of field
  232. * 47657452616E646F6D # "GetRandom"
  233. * A1 # map(1) - The field itself
  234. * 66 # text(6)
  235. * 72616E646F6D # "random"
  236. * # The rest of the response is random data
  237. */
  238. const u8 response[] = { CBOR_TYPE_MAP + 1,
  239. CBOR_TYPE_TEXT + strlen("GetRandom"),
  240. 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm',
  241. CBOR_TYPE_MAP + 1,
  242. CBOR_TYPE_TEXT + strlen("random"),
  243. 'r', 'a', 'n', 'd', 'o', 'm' };
  244. struct device *dev = &nsm->vdev->dev;
  245. u8 *rand_data = NULL;
  246. u8 *resp_ptr = resp->data;
  247. u64 resp_len = resp->len;
  248. int rc;
  249. if ((resp->len < sizeof(response) + 1) ||
  250. (memcmp(resp_ptr, response, sizeof(response)) != 0)) {
  251. dev_err(dev, "Invalid response for GetRandom");
  252. return -EFAULT;
  253. }
  254. resp_ptr += sizeof(response);
  255. resp_len -= sizeof(response);
  256. rc = cbor_object_get_array(resp_ptr, resp_len, &rand_data);
  257. if (rc < 0) {
  258. dev_err(dev, "GetRandom: Invalid CBOR encoding\n");
  259. return rc;
  260. }
  261. rc = min_t(size_t, rc, max);
  262. memcpy(out, rand_data, rc);
  263. return rc;
  264. }
  265. /*
  266. * HwRNG implementation
  267. */
  268. static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait)
  269. {
  270. struct nsm *nsm = hwrng_to_nsm(rng);
  271. struct device *dev = &nsm->vdev->dev;
  272. int rc = 0;
  273. /* NSM always needs to wait for a response */
  274. if (!wait)
  275. return 0;
  276. mutex_lock(&nsm->lock);
  277. rc = fill_req_get_random(nsm, &nsm->msg.req);
  278. if (rc != 0)
  279. goto out;
  280. rc = nsm_sendrecv_msg_locked(nsm);
  281. if (rc != 0)
  282. goto out;
  283. rc = parse_resp_get_random(nsm, &nsm->msg.resp, data, max);
  284. if (rc < 0)
  285. goto out;
  286. dev_dbg(dev, "RNG: returning rand bytes = %d", rc);
  287. out:
  288. mutex_unlock(&nsm->lock);
  289. return rc;
  290. }
  291. static long nsm_dev_ioctl(struct file *file, unsigned int cmd,
  292. unsigned long arg)
  293. {
  294. void __user *argp = u64_to_user_ptr((u64)arg);
  295. struct nsm *nsm = file_to_nsm(file);
  296. struct nsm_raw raw;
  297. int r = 0;
  298. if (cmd != NSM_IOCTL_RAW)
  299. return -EINVAL;
  300. if (_IOC_SIZE(cmd) != sizeof(raw))
  301. return -EINVAL;
  302. /* Copy user argument struct to kernel argument struct */
  303. r = -EFAULT;
  304. if (copy_from_user(&raw, argp, _IOC_SIZE(cmd)))
  305. goto out;
  306. mutex_lock(&nsm->lock);
  307. /* Convert kernel argument struct to device request */
  308. r = fill_req_raw(nsm, &nsm->msg.req, &raw);
  309. if (r)
  310. goto out;
  311. /* Send message to NSM and read reply */
  312. r = nsm_sendrecv_msg_locked(nsm);
  313. if (r)
  314. goto out;
  315. /* Parse device response into kernel argument struct */
  316. r = parse_resp_raw(nsm, &nsm->msg.resp, &raw);
  317. if (r)
  318. goto out;
  319. /* Copy kernel argument struct back to user argument struct */
  320. r = -EFAULT;
  321. if (copy_to_user(argp, &raw, sizeof(raw)))
  322. goto out;
  323. r = 0;
  324. out:
  325. mutex_unlock(&nsm->lock);
  326. return r;
  327. }
  328. static int nsm_device_init_vq(struct virtio_device *vdev)
  329. {
  330. struct virtqueue *vq = virtio_find_single_vq(vdev,
  331. nsm_vq_callback, "nsm.vq.0");
  332. struct nsm *nsm = vdev->priv;
  333. if (IS_ERR(vq))
  334. return PTR_ERR(vq);
  335. nsm->vq = vq;
  336. return 0;
  337. }
  338. static const struct file_operations nsm_dev_fops = {
  339. .unlocked_ioctl = nsm_dev_ioctl,
  340. .compat_ioctl = compat_ptr_ioctl,
  341. };
  342. /* Handler for probing the NSM device */
  343. static int nsm_device_probe(struct virtio_device *vdev)
  344. {
  345. struct device *dev = &vdev->dev;
  346. struct nsm *nsm;
  347. int rc;
  348. nsm = devm_kzalloc(&vdev->dev, sizeof(*nsm), GFP_KERNEL);
  349. if (!nsm)
  350. return -ENOMEM;
  351. vdev->priv = nsm;
  352. nsm->vdev = vdev;
  353. rc = nsm_device_init_vq(vdev);
  354. if (rc) {
  355. dev_err(dev, "queue failed to initialize: %d.\n", rc);
  356. goto err_init_vq;
  357. }
  358. mutex_init(&nsm->lock);
  359. /* Register as hwrng provider */
  360. nsm->hwrng = (struct hwrng) {
  361. .read = nsm_rng_read,
  362. .name = "nsm-hwrng",
  363. .quality = 1000,
  364. };
  365. rc = hwrng_register(&nsm->hwrng);
  366. if (rc) {
  367. dev_err(dev, "RNG initialization error: %d.\n", rc);
  368. goto err_hwrng;
  369. }
  370. /* Register /dev/nsm device node */
  371. nsm->misc = (struct miscdevice) {
  372. .minor = MISC_DYNAMIC_MINOR,
  373. .name = "nsm",
  374. .fops = &nsm_dev_fops,
  375. .mode = 0666,
  376. };
  377. rc = misc_register(&nsm->misc);
  378. if (rc) {
  379. dev_err(dev, "misc device registration error: %d.\n", rc);
  380. goto err_misc;
  381. }
  382. return 0;
  383. err_misc:
  384. hwrng_unregister(&nsm->hwrng);
  385. err_hwrng:
  386. vdev->config->del_vqs(vdev);
  387. err_init_vq:
  388. return rc;
  389. }
  390. /* Handler for removing the NSM device */
  391. static void nsm_device_remove(struct virtio_device *vdev)
  392. {
  393. struct nsm *nsm = vdev->priv;
  394. hwrng_unregister(&nsm->hwrng);
  395. vdev->config->del_vqs(vdev);
  396. misc_deregister(&nsm->misc);
  397. }
  398. /* NSM device configuration structure */
  399. static struct virtio_driver virtio_nsm_driver = {
  400. .feature_table = 0,
  401. .feature_table_size = 0,
  402. .feature_table_legacy = 0,
  403. .feature_table_size_legacy = 0,
  404. .driver.name = KBUILD_MODNAME,
  405. .id_table = id_table,
  406. .probe = nsm_device_probe,
  407. .remove = nsm_device_remove,
  408. };
  409. module_virtio_driver(virtio_nsm_driver);
  410. MODULE_DEVICE_TABLE(virtio, id_table);
  411. MODULE_DESCRIPTION("Virtio NSM driver");
  412. MODULE_LICENSE("GPL");