kunit.rs 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. // SPDX-License-Identifier: GPL-2.0
  2. //! Procedural macro to run KUnit tests using a user-space like syntax.
  3. //!
  4. //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
  5. use std::ffi::CString;
  6. use proc_macro2::TokenStream;
  7. use quote::{
  8. format_ident,
  9. quote,
  10. ToTokens, //
  11. };
  12. use syn::{
  13. parse_quote,
  14. Error,
  15. Ident,
  16. Item,
  17. ItemMod,
  18. LitCStr,
  19. Result, //
  20. };
  21. pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
  22. if test_suite.to_string().len() > 255 {
  23. return Err(Error::new_spanned(
  24. test_suite,
  25. "test suite names cannot exceed the maximum length of 255 bytes",
  26. ));
  27. }
  28. // We cannot handle modules that defer to another file (e.g. `mod foo;`).
  29. let Some((module_brace, module_items)) = module.content.take() else {
  30. Err(Error::new_spanned(
  31. module,
  32. "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
  33. ))?
  34. };
  35. // Make the entire module gated behind `CONFIG_KUNIT`.
  36. module
  37. .attrs
  38. .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
  39. let mut processed_items = Vec::new();
  40. let mut test_cases = Vec::new();
  41. // Generate the test KUnit test suite and a test case for each `#[test]`.
  42. //
  43. // The code generated for the following test module:
  44. //
  45. // ```
  46. // #[kunit_tests(kunit_test_suit_name)]
  47. // mod tests {
  48. // #[test]
  49. // fn foo() {
  50. // assert_eq!(1, 1);
  51. // }
  52. //
  53. // #[test]
  54. // fn bar() {
  55. // assert_eq!(2, 2);
  56. // }
  57. // }
  58. // ```
  59. //
  60. // Looks like:
  61. //
  62. // ```
  63. // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
  64. // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
  65. //
  66. // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
  67. // ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo),
  68. // ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar),
  69. // ::pin_init::zeroed(),
  70. // ];
  71. //
  72. // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
  73. // ```
  74. //
  75. // Non-function items (e.g. imports) are preserved.
  76. for item in module_items {
  77. let Item::Fn(mut f) = item else {
  78. processed_items.push(item);
  79. continue;
  80. };
  81. // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85.
  82. let before_len = f.attrs.len();
  83. f.attrs.retain(|attr| !attr.path().is_ident("test"));
  84. if f.attrs.len() == before_len {
  85. processed_items.push(Item::Fn(f));
  86. continue;
  87. }
  88. let test = f.sig.ident.clone();
  89. // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too.
  90. let cfg_attrs: Vec<_> = f
  91. .attrs
  92. .iter()
  93. .filter(|attr| attr.path().is_ident("cfg"))
  94. .cloned()
  95. .collect();
  96. // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call
  97. // KUnit instead.
  98. let test_str = test.to_string();
  99. let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL");
  100. processed_items.push(parse_quote! {
  101. #[allow(unused)]
  102. macro_rules! assert {
  103. ($cond:expr $(,)?) => {{
  104. kernel::kunit_assert!(#test_str, #path, 0, $cond);
  105. }}
  106. }
  107. });
  108. processed_items.push(parse_quote! {
  109. #[allow(unused)]
  110. macro_rules! assert_eq {
  111. ($left:expr, $right:expr $(,)?) => {{
  112. kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
  113. }}
  114. }
  115. });
  116. // Add back the test item.
  117. processed_items.push(Item::Fn(f));
  118. let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
  119. let test_cstr = LitCStr::new(
  120. &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
  121. test.span(),
  122. );
  123. processed_items.push(parse_quote! {
  124. unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
  125. (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
  126. // Append any `cfg` attributes the user might have written on their tests so we
  127. // don't attempt to call them when they are `cfg`'d out. An extra `use` is used
  128. // here to reduce the length of the assert message.
  129. #(#cfg_attrs)*
  130. {
  131. (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
  132. use ::kernel::kunit::is_test_result_ok;
  133. assert!(is_test_result_ok(#test()));
  134. }
  135. }
  136. });
  137. test_cases.push(quote!(
  138. ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
  139. ));
  140. }
  141. let num_tests_plus_1 = test_cases.len() + 1;
  142. processed_items.push(parse_quote! {
  143. static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
  144. #(#test_cases,)*
  145. ::pin_init::zeroed(),
  146. ];
  147. });
  148. processed_items.push(parse_quote! {
  149. ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
  150. });
  151. module.content = Some((module_brace, processed_items));
  152. Ok(module.to_token_stream())
  153. }