mshv_regions.c 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Copyright (c) 2025, Microsoft Corporation.
  4. *
  5. * Memory region management for mshv_root module.
  6. *
  7. * Authors: Microsoft Linux virtualization team
  8. */
#include <linux/hmm.h>
#include <linux/hyperv.h>
#include <linux/kref.h>
#include <linux/mm.h>
#include <linux/overflow.h>
#include <linux/vmalloc.h>
#include <asm/mshyperv.h>

#include "mshv_root.h"

/*
 * Granularity used when handling a guest page fault: the faulting page
 * offset is aligned down to this many pages and at most this many pages
 * are faulted in and mapped per fault (see mshv_region_handle_gfn_fault()).
 */
#define MSHV_MAP_FAULT_IN_PAGES PTRS_PER_PMD

  17. /**
  18. * mshv_chunk_stride - Compute stride for mapping guest memory
  19. * @page : The page to check for huge page backing
  20. * @gfn : Guest frame number for the mapping
  21. * @page_count: Total number of pages in the mapping
  22. *
  23. * Determines the appropriate stride (in pages) for mapping guest memory.
  24. * Uses huge page stride if the backing page is huge and the guest mapping
  25. * is properly aligned; otherwise falls back to single page stride.
  26. *
  27. * Return: Stride in pages, or -EINVAL if page order is unsupported.
  28. */
  29. static int mshv_chunk_stride(struct page *page,
  30. u64 gfn, u64 page_count)
  31. {
  32. unsigned int page_order;
  33. /*
  34. * Use single page stride by default. For huge page stride, the
  35. * page must be compound and point to the head of the compound
  36. * page, and both gfn and page_count must be huge-page aligned.
  37. */
  38. if (!PageCompound(page) || !PageHead(page) ||
  39. !IS_ALIGNED(gfn, PTRS_PER_PMD) ||
  40. !IS_ALIGNED(page_count, PTRS_PER_PMD))
  41. return 1;
  42. page_order = folio_order(page_folio(page));
  43. /* The hypervisor only supports 2M huge page */
  44. if (page_order != PMD_ORDER)
  45. return -EINVAL;
  46. return 1 << page_order;
  47. }
  48. /**
  49. * mshv_region_process_chunk - Processes a contiguous chunk of memory pages
  50. * in a region.
  51. * @region : Pointer to the memory region structure.
  52. * @flags : Flags to pass to the handler.
  53. * @page_offset: Offset into the region's pages array to start processing.
  54. * @page_count : Number of pages to process.
  55. * @handler : Callback function to handle the chunk.
  56. *
  57. * This function scans the region's pages starting from @page_offset,
  58. * checking for contiguous present pages of the same size (normal or huge).
  59. * It invokes @handler for the chunk of contiguous pages found. Returns the
  60. * number of pages handled, or a negative error code if the first page is
  61. * not present or the handler fails.
  62. *
  63. * Note: The @handler callback must be able to handle both normal and huge
  64. * pages.
  65. *
  66. * Return: Number of pages handled, or negative error code.
  67. */
  68. static long mshv_region_process_chunk(struct mshv_mem_region *region,
  69. u32 flags,
  70. u64 page_offset, u64 page_count,
  71. int (*handler)(struct mshv_mem_region *region,
  72. u32 flags,
  73. u64 page_offset,
  74. u64 page_count,
  75. bool huge_page))
  76. {
  77. u64 gfn = region->start_gfn + page_offset;
  78. u64 count;
  79. struct page *page;
  80. int stride, ret;
  81. page = region->mreg_pages[page_offset];
  82. if (!page)
  83. return -EINVAL;
  84. stride = mshv_chunk_stride(page, gfn, page_count);
  85. if (stride < 0)
  86. return stride;
  87. /* Start at stride since the first stride is validated */
  88. for (count = stride; count < page_count; count += stride) {
  89. page = region->mreg_pages[page_offset + count];
  90. /* Break if current page is not present */
  91. if (!page)
  92. break;
  93. /* Break if stride size changes */
  94. if (stride != mshv_chunk_stride(page, gfn + count,
  95. page_count - count))
  96. break;
  97. }
  98. ret = handler(region, flags, page_offset, count, stride > 1);
  99. if (ret)
  100. return ret;
  101. return count;
  102. }
  103. /**
  104. * mshv_region_process_range - Processes a range of memory pages in a
  105. * region.
  106. * @region : Pointer to the memory region structure.
  107. * @flags : Flags to pass to the handler.
  108. * @page_offset: Offset into the region's pages array to start processing.
  109. * @page_count : Number of pages to process.
  110. * @handler : Callback function to handle each chunk of contiguous
  111. * pages.
  112. *
  113. * Iterates over the specified range of pages in @region, skipping
  114. * non-present pages. For each contiguous chunk of present pages, invokes
  115. * @handler via mshv_region_process_chunk.
  116. *
  117. * Note: The @handler callback must be able to handle both normal and huge
  118. * pages.
  119. *
  120. * Returns 0 on success, or a negative error code on failure.
  121. */
  122. static int mshv_region_process_range(struct mshv_mem_region *region,
  123. u32 flags,
  124. u64 page_offset, u64 page_count,
  125. int (*handler)(struct mshv_mem_region *region,
  126. u32 flags,
  127. u64 page_offset,
  128. u64 page_count,
  129. bool huge_page))
  130. {
  131. long ret;
  132. if (page_offset + page_count > region->nr_pages)
  133. return -EINVAL;
  134. while (page_count) {
  135. /* Skip non-present pages */
  136. if (!region->mreg_pages[page_offset]) {
  137. page_offset++;
  138. page_count--;
  139. continue;
  140. }
  141. ret = mshv_region_process_chunk(region, flags,
  142. page_offset,
  143. page_count,
  144. handler);
  145. if (ret < 0)
  146. return ret;
  147. page_offset += ret;
  148. page_count -= ret;
  149. }
  150. return 0;
  151. }
  152. struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
  153. u64 uaddr, u32 flags)
  154. {
  155. struct mshv_mem_region *region;
  156. region = vzalloc(sizeof(*region) + sizeof(struct page *) * nr_pages);
  157. if (!region)
  158. return ERR_PTR(-ENOMEM);
  159. region->nr_pages = nr_pages;
  160. region->start_gfn = guest_pfn;
  161. region->start_uaddr = uaddr;
  162. region->hv_map_flags = HV_MAP_GPA_READABLE | HV_MAP_GPA_ADJUSTABLE;
  163. if (flags & BIT(MSHV_SET_MEM_BIT_WRITABLE))
  164. region->hv_map_flags |= HV_MAP_GPA_WRITABLE;
  165. if (flags & BIT(MSHV_SET_MEM_BIT_EXECUTABLE))
  166. region->hv_map_flags |= HV_MAP_GPA_EXECUTABLE;
  167. kref_init(&region->mreg_refcount);
  168. return region;
  169. }
  170. static int mshv_region_chunk_share(struct mshv_mem_region *region,
  171. u32 flags,
  172. u64 page_offset, u64 page_count,
  173. bool huge_page)
  174. {
  175. if (huge_page)
  176. flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;
  177. return hv_call_modify_spa_host_access(region->partition->pt_id,
  178. region->mreg_pages + page_offset,
  179. page_count,
  180. HV_MAP_GPA_READABLE |
  181. HV_MAP_GPA_WRITABLE,
  182. flags, true);
  183. }
  184. int mshv_region_share(struct mshv_mem_region *region)
  185. {
  186. u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_SHARED;
  187. return mshv_region_process_range(region, flags,
  188. 0, region->nr_pages,
  189. mshv_region_chunk_share);
  190. }
  191. static int mshv_region_chunk_unshare(struct mshv_mem_region *region,
  192. u32 flags,
  193. u64 page_offset, u64 page_count,
  194. bool huge_page)
  195. {
  196. if (huge_page)
  197. flags |= HV_MODIFY_SPA_PAGE_HOST_ACCESS_LARGE_PAGE;
  198. return hv_call_modify_spa_host_access(region->partition->pt_id,
  199. region->mreg_pages + page_offset,
  200. page_count, 0,
  201. flags, false);
  202. }
  203. int mshv_region_unshare(struct mshv_mem_region *region)
  204. {
  205. u32 flags = HV_MODIFY_SPA_PAGE_HOST_ACCESS_MAKE_EXCLUSIVE;
  206. return mshv_region_process_range(region, flags,
  207. 0, region->nr_pages,
  208. mshv_region_chunk_unshare);
  209. }
  210. static int mshv_region_chunk_remap(struct mshv_mem_region *region,
  211. u32 flags,
  212. u64 page_offset, u64 page_count,
  213. bool huge_page)
  214. {
  215. if (huge_page)
  216. flags |= HV_MAP_GPA_LARGE_PAGE;
  217. return hv_call_map_gpa_pages(region->partition->pt_id,
  218. region->start_gfn + page_offset,
  219. page_count, flags,
  220. region->mreg_pages + page_offset);
  221. }
  222. static int mshv_region_remap_pages(struct mshv_mem_region *region,
  223. u32 map_flags,
  224. u64 page_offset, u64 page_count)
  225. {
  226. return mshv_region_process_range(region, map_flags,
  227. page_offset, page_count,
  228. mshv_region_chunk_remap);
  229. }
  230. int mshv_region_map(struct mshv_mem_region *region)
  231. {
  232. u32 map_flags = region->hv_map_flags;
  233. return mshv_region_remap_pages(region, map_flags,
  234. 0, region->nr_pages);
  235. }
  236. static void mshv_region_invalidate_pages(struct mshv_mem_region *region,
  237. u64 page_offset, u64 page_count)
  238. {
  239. if (region->mreg_type == MSHV_REGION_TYPE_MEM_PINNED)
  240. unpin_user_pages(region->mreg_pages + page_offset, page_count);
  241. memset(region->mreg_pages + page_offset, 0,
  242. page_count * sizeof(struct page *));
  243. }
  244. void mshv_region_invalidate(struct mshv_mem_region *region)
  245. {
  246. mshv_region_invalidate_pages(region, 0, region->nr_pages);
  247. }
  248. int mshv_region_pin(struct mshv_mem_region *region)
  249. {
  250. u64 done_count, nr_pages;
  251. struct page **pages;
  252. __u64 userspace_addr;
  253. int ret;
  254. for (done_count = 0; done_count < region->nr_pages; done_count += ret) {
  255. pages = region->mreg_pages + done_count;
  256. userspace_addr = region->start_uaddr +
  257. done_count * HV_HYP_PAGE_SIZE;
  258. nr_pages = min(region->nr_pages - done_count,
  259. MSHV_PIN_PAGES_BATCH_SIZE);
  260. /*
  261. * Pinning assuming 4k pages works for large pages too.
  262. * All page structs within the large page are returned.
  263. *
  264. * Pin requests are batched because pin_user_pages_fast
  265. * with the FOLL_LONGTERM flag does a large temporary
  266. * allocation of contiguous memory.
  267. */
  268. ret = pin_user_pages_fast(userspace_addr, nr_pages,
  269. FOLL_WRITE | FOLL_LONGTERM,
  270. pages);
  271. if (ret != nr_pages)
  272. goto release_pages;
  273. }
  274. return 0;
  275. release_pages:
  276. if (ret > 0)
  277. done_count += ret;
  278. mshv_region_invalidate_pages(region, 0, done_count);
  279. return ret < 0 ? ret : -ENOMEM;
  280. }
  281. static int mshv_region_chunk_unmap(struct mshv_mem_region *region,
  282. u32 flags,
  283. u64 page_offset, u64 page_count,
  284. bool huge_page)
  285. {
  286. if (huge_page)
  287. flags |= HV_UNMAP_GPA_LARGE_PAGE;
  288. return hv_call_unmap_gpa_pages(region->partition->pt_id,
  289. region->start_gfn + page_offset,
  290. page_count, flags);
  291. }
  292. static int mshv_region_unmap(struct mshv_mem_region *region)
  293. {
  294. return mshv_region_process_range(region, 0,
  295. 0, region->nr_pages,
  296. mshv_region_chunk_unmap);
  297. }
  298. static void mshv_region_destroy(struct kref *ref)
  299. {
  300. struct mshv_mem_region *region =
  301. container_of(ref, struct mshv_mem_region, mreg_refcount);
  302. struct mshv_partition *partition = region->partition;
  303. int ret;
  304. if (region->mreg_type == MSHV_REGION_TYPE_MEM_MOVABLE)
  305. mshv_region_movable_fini(region);
  306. if (mshv_partition_encrypted(partition)) {
  307. ret = mshv_region_share(region);
  308. if (ret) {
  309. pt_err(partition,
  310. "Failed to regain access to memory, unpinning user pages will fail and crash the host error: %d\n",
  311. ret);
  312. return;
  313. }
  314. }
  315. mshv_region_unmap(region);
  316. mshv_region_invalidate(region);
  317. vfree(region);
  318. }
  319. void mshv_region_put(struct mshv_mem_region *region)
  320. {
  321. kref_put(&region->mreg_refcount, mshv_region_destroy);
  322. }
  323. int mshv_region_get(struct mshv_mem_region *region)
  324. {
  325. return kref_get_unless_zero(&region->mreg_refcount);
  326. }
  327. /**
  328. * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
  329. * @region: Pointer to the memory region structure
  330. * @range: Pointer to the HMM range structure
  331. *
  332. * This function performs the following steps:
  333. * 1. Reads the notifier sequence for the HMM range.
  334. * 2. Acquires a read lock on the memory map.
  335. * 3. Handles HMM faults for the specified range.
  336. * 4. Releases the read lock on the memory map.
  337. * 5. If successful, locks the memory region mutex.
  338. * 6. Verifies if the notifier sequence has changed during the operation.
  339. * If it has, releases the mutex and returns -EBUSY to match with
  340. * hmm_range_fault() return code for repeating.
  341. *
  342. * Return: 0 on success, a negative error code otherwise.
  343. */
  344. static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
  345. struct hmm_range *range)
  346. {
  347. int ret;
  348. range->notifier_seq = mmu_interval_read_begin(range->notifier);
  349. mmap_read_lock(region->mreg_mni.mm);
  350. ret = hmm_range_fault(range);
  351. mmap_read_unlock(region->mreg_mni.mm);
  352. if (ret)
  353. return ret;
  354. mutex_lock(&region->mreg_mutex);
  355. if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
  356. mutex_unlock(&region->mreg_mutex);
  357. cond_resched();
  358. return -EBUSY;
  359. }
  360. return 0;
  361. }
  362. /**
  363. * mshv_region_range_fault - Handle memory range faults for a given region.
  364. * @region: Pointer to the memory region structure.
  365. * @page_offset: Offset of the page within the region.
  366. * @page_count: Number of pages to handle.
  367. *
  368. * This function resolves memory faults for a specified range of pages
  369. * within a memory region. It uses HMM (Heterogeneous Memory Management)
  370. * to fault in the required pages and updates the region's page array.
  371. *
  372. * Return: 0 on success, negative error code on failure.
  373. */
  374. static int mshv_region_range_fault(struct mshv_mem_region *region,
  375. u64 page_offset, u64 page_count)
  376. {
  377. struct hmm_range range = {
  378. .notifier = &region->mreg_mni,
  379. .default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
  380. };
  381. unsigned long *pfns;
  382. int ret;
  383. u64 i;
  384. pfns = kmalloc_array(page_count, sizeof(*pfns), GFP_KERNEL);
  385. if (!pfns)
  386. return -ENOMEM;
  387. range.hmm_pfns = pfns;
  388. range.start = region->start_uaddr + page_offset * HV_HYP_PAGE_SIZE;
  389. range.end = range.start + page_count * HV_HYP_PAGE_SIZE;
  390. do {
  391. ret = mshv_region_hmm_fault_and_lock(region, &range);
  392. } while (ret == -EBUSY);
  393. if (ret)
  394. goto out;
  395. for (i = 0; i < page_count; i++)
  396. region->mreg_pages[page_offset + i] = hmm_pfn_to_page(pfns[i]);
  397. ret = mshv_region_remap_pages(region, region->hv_map_flags,
  398. page_offset, page_count);
  399. mutex_unlock(&region->mreg_mutex);
  400. out:
  401. kfree(pfns);
  402. return ret;
  403. }
  404. bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn)
  405. {
  406. u64 page_offset, page_count;
  407. int ret;
  408. /* Align the page offset to the nearest MSHV_MAP_FAULT_IN_PAGES. */
  409. page_offset = ALIGN_DOWN(gfn - region->start_gfn,
  410. MSHV_MAP_FAULT_IN_PAGES);
  411. /* Map more pages than requested to reduce the number of faults. */
  412. page_count = min(region->nr_pages - page_offset,
  413. MSHV_MAP_FAULT_IN_PAGES);
  414. ret = mshv_region_range_fault(region, page_offset, page_count);
  415. WARN_ONCE(ret,
  416. "p%llu: GPA intercept failed: region %#llx-%#llx, gfn %#llx, page_offset %llu, page_count %llu\n",
  417. region->partition->pt_id, region->start_uaddr,
  418. region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
  419. gfn, page_offset, page_count);
  420. return !ret;
  421. }
  422. /**
  423. * mshv_region_interval_invalidate - Invalidate a range of memory region
  424. * @mni: Pointer to the mmu_interval_notifier structure
  425. * @range: Pointer to the mmu_notifier_range structure
  426. * @cur_seq: Current sequence number for the interval notifier
  427. *
  428. * This function invalidates a memory region by remapping its pages with
  429. * no access permissions. It locks the region's mutex to ensure thread safety
  430. * and updates the sequence number for the interval notifier. If the range
  431. * is blockable, it uses a blocking lock; otherwise, it attempts a non-blocking
  432. * lock and returns false if unsuccessful.
  433. *
  434. * NOTE: Failure to invalidate a region is a serious error, as the pages will
  435. * be considered freed while they are still mapped by the hypervisor.
  436. * Any attempt to access such pages will likely crash the system.
  437. *
  438. * Return: true if the region was successfully invalidated, false otherwise.
  439. */
  440. static bool mshv_region_interval_invalidate(struct mmu_interval_notifier *mni,
  441. const struct mmu_notifier_range *range,
  442. unsigned long cur_seq)
  443. {
  444. struct mshv_mem_region *region = container_of(mni,
  445. struct mshv_mem_region,
  446. mreg_mni);
  447. u64 page_offset, page_count;
  448. unsigned long mstart, mend;
  449. int ret = -EPERM;
  450. mstart = max(range->start, region->start_uaddr);
  451. mend = min(range->end, region->start_uaddr +
  452. (region->nr_pages << HV_HYP_PAGE_SHIFT));
  453. page_offset = HVPFN_DOWN(mstart - region->start_uaddr);
  454. page_count = HVPFN_DOWN(mend - mstart);
  455. if (mmu_notifier_range_blockable(range))
  456. mutex_lock(&region->mreg_mutex);
  457. else if (!mutex_trylock(&region->mreg_mutex))
  458. goto out_fail;
  459. mmu_interval_set_seq(mni, cur_seq);
  460. ret = mshv_region_remap_pages(region, HV_MAP_GPA_NO_ACCESS,
  461. page_offset, page_count);
  462. if (ret)
  463. goto out_unlock;
  464. mshv_region_invalidate_pages(region, page_offset, page_count);
  465. mutex_unlock(&region->mreg_mutex);
  466. return true;
  467. out_unlock:
  468. mutex_unlock(&region->mreg_mutex);
  469. out_fail:
  470. WARN_ONCE(ret,
  471. "Failed to invalidate region %#llx-%#llx (range %#lx-%#lx, event: %u, pages %#llx-%#llx, mm: %#llx): %d\n",
  472. region->start_uaddr,
  473. region->start_uaddr + (region->nr_pages << HV_HYP_PAGE_SHIFT),
  474. range->start, range->end, range->event,
  475. page_offset, page_offset + page_count - 1, (u64)range->mm, ret);
  476. return false;
  477. }
/*
 * Interval notifier operations registered for movable regions by
 * mshv_region_movable_init().
 */
static const struct mmu_interval_notifier_ops mshv_region_mni_ops = {
	.invalidate = mshv_region_interval_invalidate,
};

  481. void mshv_region_movable_fini(struct mshv_mem_region *region)
  482. {
  483. mmu_interval_notifier_remove(&region->mreg_mni);
  484. }
  485. bool mshv_region_movable_init(struct mshv_mem_region *region)
  486. {
  487. int ret;
  488. ret = mmu_interval_notifier_insert(&region->mreg_mni, current->mm,
  489. region->start_uaddr,
  490. region->nr_pages << HV_HYP_PAGE_SHIFT,
  491. &mshv_region_mni_ops);
  492. if (ret)
  493. return false;
  494. mutex_init(&region->mreg_mutex);
  495. return true;
  496. }