mldsa_kunit.c 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. // SPDX-License-Identifier: GPL-2.0-or-later
  2. /*
  3. * KUnit tests and benchmark for ML-DSA
  4. *
  5. * Copyright 2025 Google LLC
  6. */
  7. #include <crypto/mldsa.h>
  8. #include <kunit/test.h>
  9. #include <linux/random.h>
  10. #include <linux/unaligned.h>
  11. #define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */
  12. /* ML-DSA parameters that the tests use */
  13. static const struct {
  14. int sig_len;
  15. int pk_len;
  16. int k;
  17. int lambda;
  18. int gamma1;
  19. int beta;
  20. int omega;
  21. } params[] = {
  22. [MLDSA44] = {
  23. .sig_len = MLDSA44_SIGNATURE_SIZE,
  24. .pk_len = MLDSA44_PUBLIC_KEY_SIZE,
  25. .k = 4,
  26. .lambda = 128,
  27. .gamma1 = 1 << 17,
  28. .beta = 78,
  29. .omega = 80,
  30. },
  31. [MLDSA65] = {
  32. .sig_len = MLDSA65_SIGNATURE_SIZE,
  33. .pk_len = MLDSA65_PUBLIC_KEY_SIZE,
  34. .k = 6,
  35. .lambda = 192,
  36. .gamma1 = 1 << 19,
  37. .beta = 196,
  38. .omega = 55,
  39. },
  40. [MLDSA87] = {
  41. .sig_len = MLDSA87_SIGNATURE_SIZE,
  42. .pk_len = MLDSA87_PUBLIC_KEY_SIZE,
  43. .k = 8,
  44. .lambda = 256,
  45. .gamma1 = 1 << 19,
  46. .beta = 120,
  47. .omega = 75,
  48. },
  49. };
  50. #include "mldsa-testvecs.h"
  51. static void do_mldsa_and_assert_success(struct kunit *test,
  52. const struct mldsa_testvector *tv)
  53. {
  54. int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
  55. tv->msg_len, tv->pk, tv->pk_len);
  56. KUNIT_ASSERT_EQ(test, err, 0);
  57. }
  58. static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len)
  59. {
  60. u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL);
  61. KUNIT_ASSERT_NOT_NULL(test, dst);
  62. return memcpy(dst, src, len);
  63. }
  64. /*
  65. * Test that changing coefficients in a valid signature's z vector results in
  66. * the following behavior from mldsa_verify():
  67. *
  68. * * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e.
  69. * absolute value >= gamma1 - beta, corresponding to the verifier detecting
  70. * the out-of-range coefficient and rejecting the signature as malformed
  71. *
  72. * * -EKEYREJECTED if a coefficient is changed to a different in-range value,
  73. * i.e. absolute value < gamma1 - beta, corresponding to the verifier
  74. * continuing to the "real" signature check and that check failing
  75. */
  76. static void test_mldsa_z_range(struct kunit *test,
  77. const struct mldsa_testvector *tv)
  78. {
  79. u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
  80. const int lambda = params[tv->alg].lambda;
  81. const s32 gamma1 = params[tv->alg].gamma1;
  82. const int beta = params[tv->alg].beta;
  83. /*
  84. * We just modify the first coefficient. The coefficient is gamma1
  85. * minus either the first 18 or 20 bits of the u32, depending on gamma1.
  86. *
  87. * The layout of ML-DSA signatures is ctilde || z || h. ctilde is
  88. * lambda / 4 bytes, so z starts at &sig[lambda / 4].
  89. */
  90. u8 *z_ptr = &sig[lambda / 4];
  91. const u32 z_data = get_unaligned_le32(z_ptr);
  92. const u32 mask = (gamma1 << 1) - 1;
  93. /* These are the four boundaries of the out-of-range values. */
  94. const s32 out_of_range_coeffs[] = {
  95. -gamma1 + 1,
  96. -(gamma1 - beta),
  97. gamma1,
  98. gamma1 - beta,
  99. };
  100. /*
  101. * These are the two boundaries of the valid range, along with 0. We
  102. * assume that none of these matches the original coefficient.
  103. */
  104. const s32 in_range_coeffs[] = {
  105. -(gamma1 - beta - 1),
  106. 0,
  107. gamma1 - beta - 1,
  108. };
  109. /* Initially the signature is valid. */
  110. do_mldsa_and_assert_success(test, tv);
  111. /* Test some out-of-range coefficients. */
  112. for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) {
  113. const s32 c = out_of_range_coeffs[i];
  114. put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
  115. z_ptr);
  116. KUNIT_ASSERT_EQ(test, -EBADMSG,
  117. mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
  118. tv->msg_len, tv->pk, tv->pk_len));
  119. }
  120. /* Test some in-range coefficients. */
  121. for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) {
  122. const s32 c = in_range_coeffs[i];
  123. put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
  124. z_ptr);
  125. KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
  126. mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
  127. tv->msg_len, tv->pk, tv->pk_len));
  128. }
  129. }
  130. /* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */
  131. static void test_mldsa_bad_hints(struct kunit *test,
  132. const struct mldsa_testvector *tv)
  133. {
  134. const int omega = params[tv->alg].omega;
  135. const int k = params[tv->alg].k;
  136. u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
  137. /* Pointer to the encoded hint vector in the signature */
  138. u8 *hintvec = &sig[tv->sig_len - omega - k];
  139. u8 h;
  140. /* Initially the signature is valid. */
  141. do_mldsa_and_assert_success(test, tv);
  142. /* Cumulative hint count exceeds omega */
  143. memcpy(sig, tv->sig, tv->sig_len);
  144. hintvec[omega + k - 1] = omega + 1;
  145. KUNIT_ASSERT_EQ(test, -EBADMSG,
  146. mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
  147. tv->msg_len, tv->pk, tv->pk_len));
  148. /* Cumulative hint count decreases */
  149. memcpy(sig, tv->sig, tv->sig_len);
  150. KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1);
  151. hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1;
  152. KUNIT_ASSERT_EQ(test, -EBADMSG,
  153. mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
  154. tv->msg_len, tv->pk, tv->pk_len));
  155. /*
  156. * Hint indices out of order. To test this, swap hintvec[0] and
  157. * hintvec[1]. This assumes that the original valid signature had at
  158. * least two nonzero hints in the first element (asserted below).
  159. */
  160. memcpy(sig, tv->sig, tv->sig_len);
  161. KUNIT_ASSERT_GE(test, hintvec[omega], 2);
  162. h = hintvec[0];
  163. hintvec[0] = hintvec[1];
  164. hintvec[1] = h;
  165. KUNIT_ASSERT_EQ(test, -EBADMSG,
  166. mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
  167. tv->msg_len, tv->pk, tv->pk_len));
  168. /*
  169. * Extra hint indices given. For this test to work, the original valid
  170. * signature must have fewer than omega nonzero hints (asserted below).
  171. */
  172. memcpy(sig, tv->sig, tv->sig_len);
  173. KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega);
  174. hintvec[omega - 1] = 0xff;
  175. KUNIT_ASSERT_EQ(test, -EBADMSG,
  176. mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
  177. tv->msg_len, tv->pk, tv->pk_len));
  178. }
  179. static void test_mldsa_mutation(struct kunit *test,
  180. const struct mldsa_testvector *tv)
  181. {
  182. const int sig_len = tv->sig_len;
  183. const int msg_len = tv->msg_len;
  184. const int pk_len = tv->pk_len;
  185. const int num_iter = 200;
  186. u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len);
  187. u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len);
  188. u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len);
  189. /* Initially the signature is valid. */
  190. do_mldsa_and_assert_success(test, tv);
  191. /* Changing any bit in the signature should invalidate the signature */
  192. for (int i = 0; i < num_iter; i++) {
  193. size_t pos = get_random_u32_below(sig_len);
  194. u8 b = 1 << get_random_u32_below(8);
  195. sig[pos] ^= b;
  196. KUNIT_ASSERT_NE(test, 0,
  197. mldsa_verify(tv->alg, sig, sig_len, msg,
  198. msg_len, pk, pk_len));
  199. sig[pos] ^= b;
  200. }
  201. /* Changing any bit in the message should invalidate the signature */
  202. for (int i = 0; i < num_iter; i++) {
  203. size_t pos = get_random_u32_below(msg_len);
  204. u8 b = 1 << get_random_u32_below(8);
  205. msg[pos] ^= b;
  206. KUNIT_ASSERT_NE(test, 0,
  207. mldsa_verify(tv->alg, sig, sig_len, msg,
  208. msg_len, pk, pk_len));
  209. msg[pos] ^= b;
  210. }
  211. /* Changing any bit in the public key should invalidate the signature */
  212. for (int i = 0; i < num_iter; i++) {
  213. size_t pos = get_random_u32_below(pk_len);
  214. u8 b = 1 << get_random_u32_below(8);
  215. pk[pos] ^= b;
  216. KUNIT_ASSERT_NE(test, 0,
  217. mldsa_verify(tv->alg, sig, sig_len, msg,
  218. msg_len, pk, pk_len));
  219. pk[pos] ^= b;
  220. }
  221. /* All changes should have been undone. */
  222. KUNIT_ASSERT_EQ(test, 0,
  223. mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk,
  224. pk_len));
  225. }
  226. static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv)
  227. {
  228. /* Valid signature */
  229. KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len);
  230. KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len);
  231. do_mldsa_and_assert_success(test, tv);
  232. /* Signature too short */
  233. KUNIT_ASSERT_EQ(test, -EBADMSG,
  234. mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg,
  235. tv->msg_len, tv->pk, tv->pk_len));
  236. /* Signature too long */
  237. KUNIT_ASSERT_EQ(test, -EBADMSG,
  238. mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg,
  239. tv->msg_len, tv->pk, tv->pk_len));
  240. /* Public key too short */
  241. KUNIT_ASSERT_EQ(test, -EBADMSG,
  242. mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
  243. tv->msg_len, tv->pk, tv->pk_len - 1));
  244. /* Public key too long */
  245. KUNIT_ASSERT_EQ(test, -EBADMSG,
  246. mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
  247. tv->msg_len, tv->pk, tv->pk_len + 1));
  248. /*
  249. * Message too short. Error is EKEYREJECTED because it gets rejected by
  250. * the "real" signature check rather than the well-formedness checks.
  251. */
  252. KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
  253. mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
  254. tv->msg_len - 1, tv->pk, tv->pk_len));
  255. /*
  256. * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be
  257. * accessed out of bounds. However, ML-DSA just hashes the message and
  258. * doesn't handle different message lengths differently anyway.
  259. */
  260. /* Test the validity checks on the z vector. */
  261. test_mldsa_z_range(test, tv);
  262. /* Test the validity checks on the hint vector. */
  263. test_mldsa_bad_hints(test, tv);
  264. /* Test randomly mutating the inputs. */
  265. test_mldsa_mutation(test, tv);
  266. }
  267. static void test_mldsa44(struct kunit *test)
  268. {
  269. test_mldsa(test, &mldsa44_testvector);
  270. }
  271. static void test_mldsa65(struct kunit *test)
  272. {
  273. test_mldsa(test, &mldsa65_testvector);
  274. }
  275. static void test_mldsa87(struct kunit *test)
  276. {
  277. test_mldsa(test, &mldsa87_testvector);
  278. }
  /* Return a reduced into [0, m).  Every caller in this file passes m > 0. */
static s32 mod(s32 a, s32 m)
{
	const s32 rem = a % m;

	return rem < 0 ? rem + m : rem;
}

/*
 * Return the centered representative of a modulo m (FIPS 204 "mod +/-"):
 * the result lies in (-m/2, m/2] for even m, and in
 * [-(m-1)/2, (m-1)/2] for odd m.
 */
static s32 symmetric_mod(s32 a, s32 m)
{
	const s32 rem = mod(a, m);

	return rem <= m / 2 ? rem : rem - m;
}
  293. /* Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose */
  294. static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1)
  295. {
  296. s32 rplus = mod(r, Q);
  297. *r0 = symmetric_mod(rplus, 2 * gamma2);
  298. if (rplus - *r0 == Q - 1) {
  299. *r1 = 0;
  300. *r0 = *r0 - 1;
  301. } else {
  302. *r1 = (rplus - *r0) / (2 * gamma2);
  303. }
  304. }
  305. /* Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint */
  306. static s32 use_hint_ref(u8 h, s32 r, s32 gamma2)
  307. {
  308. s32 m = (Q - 1) / (2 * gamma2);
  309. s32 r0, r1;
  310. decompose_ref(r, gamma2, &r0, &r1);
  311. if (h == 1 && r0 > 0)
  312. return mod(r1 + 1, m);
  313. if (h == 1 && r0 <= 0)
  314. return mod(r1 - 1, m);
  315. return r1;
  316. }
  317. /*
  318. * Test that for all possible inputs, mldsa_use_hint() gives the same output as
  319. * a mechanical translation of the pseudocode from FIPS 204.
  320. */
  321. static void test_mldsa_use_hint(struct kunit *test)
  322. {
  323. for (int i = 0; i < 2; i++) {
  324. const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32);
  325. for (u8 h = 0; h < 2; h++) {
  326. for (s32 r = 0; r < Q; r++) {
  327. KUNIT_ASSERT_EQ(test,
  328. mldsa_use_hint(h, r, gamma2),
  329. use_hint_ref(h, r, gamma2));
  330. }
  331. }
  332. }
  333. }
  334. static void benchmark_mldsa(struct kunit *test,
  335. const struct mldsa_testvector *tv)
  336. {
  337. const int warmup_niter = 200;
  338. const int benchmark_niter = 200;
  339. u64 t0, t1;
  340. if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK))
  341. kunit_skip(test, "not enabled");
  342. for (int i = 0; i < warmup_niter; i++)
  343. do_mldsa_and_assert_success(test, tv);
  344. t0 = ktime_get_ns();
  345. for (int i = 0; i < benchmark_niter; i++)
  346. do_mldsa_and_assert_success(test, tv);
  347. t1 = ktime_get_ns();
  348. kunit_info(test, "%llu ops/s",
  349. div64_u64((u64)benchmark_niter * NSEC_PER_SEC,
  350. t1 - t0 ?: 1));
  351. }
  352. static void benchmark_mldsa44(struct kunit *test)
  353. {
  354. benchmark_mldsa(test, &mldsa44_testvector);
  355. }
  356. static void benchmark_mldsa65(struct kunit *test)
  357. {
  358. benchmark_mldsa(test, &mldsa65_testvector);
  359. }
  360. static void benchmark_mldsa87(struct kunit *test)
  361. {
  362. benchmark_mldsa(test, &mldsa87_testvector);
  363. }
  364. static struct kunit_case mldsa_kunit_cases[] = {
  365. KUNIT_CASE(test_mldsa44),
  366. KUNIT_CASE(test_mldsa65),
  367. KUNIT_CASE(test_mldsa87),
  368. KUNIT_CASE(test_mldsa_use_hint),
  369. KUNIT_CASE(benchmark_mldsa44),
  370. KUNIT_CASE(benchmark_mldsa65),
  371. KUNIT_CASE(benchmark_mldsa87),
  372. {},
  373. };
  374. static struct kunit_suite mldsa_kunit_suite = {
  375. .name = "mldsa",
  376. .test_cases = mldsa_kunit_cases,
  377. };
  378. kunit_test_suite(mldsa_kunit_suite);
  379. MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA");
  380. MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING");
  381. MODULE_LICENSE("GPL");