mldsa.c 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. // SPDX-License-Identifier: GPL-2.0-or-later
  2. /*
  3. * Support for verifying ML-DSA signatures
  4. *
  5. * Copyright 2025 Google LLC
  6. */
#include <crypto/mldsa.h>
#include <crypto/sha3.h>
#include <kunit/visibility.h>
#include <linux/array_size.h>
#include <linux/export.h>
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/unaligned.h>
#include "fips-mldsa.h"
  16. #define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */
  17. #define QINV_MOD_2_32 58728449 /* Multiplicative inverse of q mod 2^32 */
  18. #define N 256 /* Number of components per ring element */
  19. #define D 13 /* Number of bits dropped from the public key vector t */
  20. #define RHO_LEN 32 /* Length of the public random seed in bytes */
  21. #define MAX_W1_ENCODED_LEN 192 /* Max encoded length of one element of w'_1 */
/*
 * The zetas array in Montgomery form, i.e. with extra factor of 2^32.
 * Reference: FIPS 204 Section 7.5 "NTT and NTT^-1"
 *
 * Entry i is 1753^BitRev8(i) * 2^32 mod q, where 1753 is the zeta of
 * FIPS 204 and BitRev8(i) reverses the 8 bits of i. Each value is stored as
 * its centered representative in [-(q - 1) / 2, (q - 1) / 2] rather than in
 * [0, q - 1], which keeps Zq_mult() products small.
 *
 * Entry 0 (1753^0 * 2^32 mod q, i.e. 1 in Montgomery form) is not read:
 * ntt() consumes entries 1..255 in increasing order, and
 * invntt_and_mul_2_32() consumes entries 255..1 in decreasing order, negated.
 *
 * Generated by the following Python code:
 * q=8380417; [a%q - q*(a%q > q//2) for a in [1753**(int(f'{i:08b}'[::-1], 2)) << 32 for i in range(256)]]
 */
static const s32 zetas_times_2_32[N] = {
	-4186625, 25847,    -2608894, -518909,  237124,   -777960,  -876248,
	466468,   1826347,  2353451,  -359251,  -2091905, 3119733,  -2884855,
	3111497,  2680103,  2725464,  1024112,  -1079900, 3585928,  -549488,
	-1119584, 2619752,  -2108549, -2118186, -3859737, -1399561, -3277672,
	1757237,  -19422,   4010497,  280005,   2706023,  95776,    3077325,
	3530437,  -1661693, -3592148, -2537516, 3915439,  -3861115, -3043716,
	3574422,  -2867647, 3539968,  -300467,  2348700,  -539299,  -1699267,
	-1643818, 3505694,  -3821735, 3507263,  -2140649, -1600420, 3699596,
	811944,   531354,   954230,   3881043,  3900724,  -2556880, 2071892,
	-2797779, -3930395, -1528703, -3677745, -3041255, -1452451, 3475950,
	2176455,  -1585221, -1257611, 1939314,  -4083598, -1000202, -3190144,
	-3157330, -3632928, 126922,   3412210,  -983419,  2147896,  2715295,
	-2967645, -3693493, -411027,  -2477047, -671102,  -1228525, -22981,
	-1308169, -381987,  1349076,  1852771,  -1430430, -3343383, 264944,
	508951,   3097992,  44288,    -1100098, 904516,   3958618,  -3724342,
	-8578,    1653064,  -3249728, 2389356,  -210977,  759969,   -1316856,
	189548,   -3553272, 3159746,  -1851402, -2409325, -177440,  1315589,
	1341330,  1285669,  -1584928, -812732,  -1439742, -3019102, -3881060,
	-3628969, 3839961,  2091667,  3407706,  2316500,  3817976,  -3342478,
	2244091,  -2446433, -3562462, 266997,   2434439,  -1235728, 3513181,
	-3520352, -3759364, -1197226, -3193378, 900702,   1859098,  909542,
	819034,   495491,   -1613174, -43260,   -522500,  -655327,  -3122442,
	2031748,  3207046,  -3556995, -525098,  -768622,  -3595838, 342297,
	286988,   -2437823, 4108315,  3437287,  -3342277, 1735879,  203044,
	2842341,  2691481,  -2590150, 1265009,  4055324,  1247620,  2486353,
	1595974,  -3767016, 1250494,  2635921,  -3548272, -2994039, 1869119,
	1903435,  -1050970, -1333058, 1237275,  -3318210, -1430225, -451100,
	1312455,  3306115,  -1962642, -1279661, 1917081,  -2546312, -1374803,
	1500165,  777191,   2235880,  3406031,  -542412,  -2831860, -1671176,
	-1846953, -2584293, -3724270, 594136,   -3776993, -2013608, 2432395,
	2454455,  -164721,  1957272,  3369112,  185531,   -1207385, -3183426,
	162844,   1616392,  3014001,  810149,   1652634,  -3694233, -1799107,
	-3038916, 3523897,  3866901,  269760,   2213111,  -975884,  1717735,
	472078,   -426683,  1723600,  -1803090, 1910376,  -1667432, -1104333,
	-260646,  -3833893, -2939036, -2235985, -420899,  -2286327, 183443,
	-976891,  1612842,  -3545687, -554416,  3919660,  -48306,   -1362209,
	3937738,  1400424,  -846154,  1976782
};
  67. /* Reference: FIPS 204 Section 4 "Parameter Sets" */
  68. static const struct mldsa_parameter_set {
  69. u8 k; /* num rows in the matrix A */
  70. u8 l; /* num columns in the matrix A */
  71. u8 ctilde_len; /* length of commitment hash ctilde in bytes; lambda/4 */
  72. u8 omega; /* max num of 1's in the hint vector h */
  73. u8 tau; /* num of +-1's in challenge c */
  74. u8 beta; /* tau times eta */
  75. u16 pk_len; /* length of public keys in bytes */
  76. u16 sig_len; /* length of signatures in bytes */
  77. s32 gamma1; /* coefficient range of y */
  78. } mldsa_parameter_sets[] = {
  79. [MLDSA44] = {
  80. .k = 4,
  81. .l = 4,
  82. .ctilde_len = 32,
  83. .omega = 80,
  84. .tau = 39,
  85. .beta = 78,
  86. .pk_len = MLDSA44_PUBLIC_KEY_SIZE,
  87. .sig_len = MLDSA44_SIGNATURE_SIZE,
  88. .gamma1 = 1 << 17,
  89. },
  90. [MLDSA65] = {
  91. .k = 6,
  92. .l = 5,
  93. .ctilde_len = 48,
  94. .omega = 55,
  95. .tau = 49,
  96. .beta = 196,
  97. .pk_len = MLDSA65_PUBLIC_KEY_SIZE,
  98. .sig_len = MLDSA65_SIGNATURE_SIZE,
  99. .gamma1 = 1 << 19,
  100. },
  101. [MLDSA87] = {
  102. .k = 8,
  103. .l = 7,
  104. .ctilde_len = 64,
  105. .omega = 75,
  106. .tau = 60,
  107. .beta = 120,
  108. .pk_len = MLDSA87_PUBLIC_KEY_SIZE,
  109. .sig_len = MLDSA87_SIGNATURE_SIZE,
  110. .gamma1 = 1 << 19,
  111. },
  112. };
/*
 * An element of the ring R_q (normal form) or the ring T_q (NTT form). It
 * consists of N integers mod q: either the polynomial coefficients of the R_q
 * element or the components of the T_q element. In either case, whether they
 * are fully reduced to [0, q - 1] varies in the different parts of the code.
 *
 * The values are signed and may be negative between reductions (e.g. after
 * ntt() or Zq_mult()), so "mod q" here means "congruent mod q to the intended
 * value"; the comments at each point of use give the actual bound.
 */
struct mldsa_ring_elem {
	s32 x[N]; /* x[i]: the i-th coefficient (R_q) or component (T_q) */
};
/*
 * Scratch state for one call to mldsa_verify(), allocated as a single heap
 * buffer. Its size depends only on the parameter set: the fixed-size fields
 * below, then l ring elements for z, then k * N bytes for the hint vector h
 * (mldsa_verify() locates h at &z[l]).
 */
struct mldsa_verification_workspace {
	/* SHAKE context for computing c, mu, and ctildeprime */
	struct shake_ctx shake;
	/*
	 * The fields in this union are used in their order of declaration, so
	 * they can share storage: tr is absorbed into 'shake' before mu is
	 * squeezed, and mu is absorbed before the main loop starts. Within each
	 * row of the main loop, 'block' is consumed inside rej_ntt_poly() and
	 * 'w1_encoded' is absorbed right after encode_w1() fills it, so those
	 * two can alternate. ctildeprime is only written after the loop.
	 */
	union {
		/* The hash of the public key */
		u8 tr[64];
		/* The message representative mu */
		u8 mu[64];
		/* Temporary space for rej_ntt_poly() */
		u8 block[SHAKE128_BLOCK_SIZE + 1];
		/* Encoded element of w'_1 */
		u8 w1_encoded[MAX_W1_ENCODED_LEN];
		/* The commitment hash. Real length is params->ctilde_len */
		u8 ctildeprime[64];
	};
	/*
	 * SHAKE context for generating elements of the matrix A. It must be
	 * separate from 'shake', which holds the partially absorbed
	 * ctildeprime hash (mu || w1Encode(...)) while A is being generated.
	 */
	struct shake_ctx a_shake;
	/*
	 * An element of the matrix A generated from the public seed, or an
	 * element of the vector t_1 decoded from the public key and pre-scaled
	 * by 2^d. Both are in NTT form. To reduce memory usage, we generate
	 * or decode these elements only as needed.
	 * (Each A element is fully consumed before t_1 is decoded for the same
	 * row, and the next row regenerates A, so they can share storage.)
	 */
	union {
		struct mldsa_ring_elem a;
		struct mldsa_ring_elem t1_scaled;
	};
	/* The challenge c, generated from ctilde; in NTT form for all rows */
	struct mldsa_ring_elem c;
	/* A temporary element used during calculations */
	struct mldsa_ring_elem tmp;
	/* The following fields are variable-length: */
	/* The signer's response vector, in NTT form after decode_z() */
	struct mldsa_ring_elem z[/* l */];
	/* The signer's hint vector, one byte (0 or 1) per coefficient */
	/* u8 h[k * N]; */
};
  160. /*
  161. * Compute a * b * 2^-32 mod q. a * b must be in the range [-2^31 * q, 2^31 * q
  162. * - 1] before reduction. The return value is in the range [-q + 1, q - 1].
  163. *
  164. * To reduce mod q efficiently, this uses Montgomery reduction with R=2^32.
  165. * That's where the factor of 2^-32 comes from. The caller must include a
  166. * factor of 2^32 at some point to compensate for that.
  167. *
  168. * To keep the input and output ranges very close to symmetric, this
  169. * specifically does a "signed" Montgomery reduction. That is, when computing
  170. * d = c * q^-1 mod 2^32, this chooses a representative in [S32_MIN, S32_MAX]
  171. * rather than [0, U32_MAX], i.e. s32 rather than u32. This matters in the
  172. * wider multiplication d * Q when d keeps its value via sign extension.
  173. *
  174. * Reference: FIPS 204 Appendix A "Montgomery Multiplication". But, it doesn't
  175. * explain it properly: it has an off-by-one error in the upper end of the input
  176. * range, it doesn't clarify that the signed version should be used, and it
  177. * gives an unnecessarily large output range. A better citation is perhaps the
  178. * Dilithium reference code, which functionally matches the below code and
  179. * merely has the (benign) off-by-one error in its documentation.
  180. */
  181. static inline s32 Zq_mult(s32 a, s32 b)
  182. {
  183. /* Compute the unreduced product c. */
  184. s64 c = (s64)a * b;
  185. /*
  186. * Compute d = c * q^-1 mod 2^32. Generate a signed result, as
  187. * explained above, but do the actual multiplication using an unsigned
  188. * type to avoid signed integer overflow which is undefined behavior.
  189. */
  190. s32 d = (u32)c * QINV_MOD_2_32;
  191. /*
  192. * Compute e = c - d * q. This makes the low 32 bits zero, since
  193. * c - (c * q^-1) * q mod 2^32
  194. * = c - c * (q^-1 * q) mod 2^32
  195. * = c - c * 1 mod 2^32
  196. * = c - c mod 2^32
  197. * = 0 mod 2^32
  198. */
  199. s64 e = c - (s64)d * Q;
  200. /* Finally, return e * 2^-32. */
  201. return e >> 32;
  202. }
  203. /*
  204. * Convert @w to its number-theoretically-transformed representation in-place.
  205. * Reference: FIPS 204 Algorithm 41, NTT
  206. *
  207. * To prevent intermediate overflows, all input coefficients must have absolute
  208. * value < q. All output components have absolute value < 9*q.
  209. */
  210. static void ntt(struct mldsa_ring_elem *w)
  211. {
  212. int m = 0; /* index in zetas_times_2_32 */
  213. for (int len = 128; len >= 1; len /= 2) {
  214. for (int start = 0; start < 256; start += 2 * len) {
  215. const s32 z = zetas_times_2_32[++m];
  216. for (int j = start; j < start + len; j++) {
  217. s32 t = Zq_mult(z, w->x[j + len]);
  218. w->x[j + len] = w->x[j] - t;
  219. w->x[j] += t;
  220. }
  221. }
  222. }
  223. }
  224. /*
  225. * Convert @w from its number-theoretically-transformed representation in-place.
  226. * Reference: FIPS 204 Algorithm 42, NTT^-1
  227. *
  228. * This also multiplies the coefficients by 2^32, undoing an extra factor of
  229. * 2^-32 introduced earlier, and reduces the coefficients to [0, q - 1].
  230. */
  231. static void invntt_and_mul_2_32(struct mldsa_ring_elem *w)
  232. {
  233. int m = 256; /* index in zetas_times_2_32 */
  234. /* Prevent intermediate overflows. */
  235. for (int j = 0; j < 256; j++)
  236. w->x[j] %= Q;
  237. for (int len = 1; len < 256; len *= 2) {
  238. for (int start = 0; start < 256; start += 2 * len) {
  239. const s32 z = -zetas_times_2_32[--m];
  240. for (int j = start; j < start + len; j++) {
  241. s32 t = w->x[j];
  242. w->x[j] = t + w->x[j + len];
  243. w->x[j + len] = Zq_mult(z, t - w->x[j + len]);
  244. }
  245. }
  246. }
  247. /*
  248. * Multiply by 2^32 * 256^-1. 2^32 cancels the factor of 2^-32 from
  249. * earlier Montgomery multiplications. 256^-1 is for NTT^-1. This
  250. * itself uses Montgomery multiplication, so *another* 2^32 is needed.
  251. * Thus the actual multiplicand is 2^32 * 2^32 * 256^-1 mod q = 41978.
  252. *
  253. * Finally, also reduce from [-q + 1, q - 1] to [0, q - 1].
  254. */
  255. for (int j = 0; j < 256; j++) {
  256. w->x[j] = Zq_mult(w->x[j], 41978);
  257. w->x[j] += (w->x[j] >> 31) & Q;
  258. }
  259. }
  260. /*
  261. * Decode an element of t_1, i.e. the high d bits of t = A*s_1 + s_2.
  262. * Reference: FIPS 204 Algorithm 23, pkDecode.
  263. * Also multiply it by 2^d and convert it to NTT form.
  264. */
  265. static const u8 *decode_t1_elem(struct mldsa_ring_elem *out,
  266. const u8 *t1_encoded)
  267. {
  268. for (int j = 0; j < N; j += 4, t1_encoded += 5) {
  269. u32 v = get_unaligned_le32(t1_encoded);
  270. out->x[j + 0] = ((v >> 0) & 0x3ff) << D;
  271. out->x[j + 1] = ((v >> 10) & 0x3ff) << D;
  272. out->x[j + 2] = ((v >> 20) & 0x3ff) << D;
  273. out->x[j + 3] = ((v >> 30) | (t1_encoded[4] << 2)) << D;
  274. static_assert(0x3ff << D < Q); /* All coefficients < q. */
  275. }
  276. ntt(out);
  277. return t1_encoded; /* Return updated pointer. */
  278. }
  279. /*
  280. * Decode the signer's response vector 'z' from the signature.
  281. * Reference: FIPS 204 Algorithm 27, sigDecode.
  282. *
  283. * This also validates that the coefficients of z are in range, corresponding
  284. * the infinity norm check at the end of Algorithm 8, ML-DSA.Verify_internal.
  285. *
  286. * Finally, this also converts z to NTT form.
  287. */
  288. static bool decode_z(struct mldsa_ring_elem z[/* l */], int l, s32 gamma1,
  289. int beta, const u8 **sig_ptr)
  290. {
  291. const u8 *sig = *sig_ptr;
  292. for (int i = 0; i < l; i++) {
  293. if (l == 4) { /* ML-DSA-44? */
  294. /* 18-bit coefficients: decode 4 from 9 bytes. */
  295. for (int j = 0; j < N; j += 4, sig += 9) {
  296. u64 v = get_unaligned_le64(sig);
  297. z[i].x[j + 0] = (v >> 0) & 0x3ffff;
  298. z[i].x[j + 1] = (v >> 18) & 0x3ffff;
  299. z[i].x[j + 2] = (v >> 36) & 0x3ffff;
  300. z[i].x[j + 3] = (v >> 54) | (sig[8] << 10);
  301. }
  302. } else {
  303. /* 20-bit coefficients: decode 4 from 10 bytes. */
  304. for (int j = 0; j < N; j += 4, sig += 10) {
  305. u64 v = get_unaligned_le64(sig);
  306. z[i].x[j + 0] = (v >> 0) & 0xfffff;
  307. z[i].x[j + 1] = (v >> 20) & 0xfffff;
  308. z[i].x[j + 2] = (v >> 40) & 0xfffff;
  309. z[i].x[j + 3] =
  310. (v >> 60) |
  311. (get_unaligned_le16(&sig[8]) << 4);
  312. }
  313. }
  314. for (int j = 0; j < N; j++) {
  315. z[i].x[j] = gamma1 - z[i].x[j];
  316. if (z[i].x[j] <= -(gamma1 - beta) ||
  317. z[i].x[j] >= gamma1 - beta)
  318. return false;
  319. }
  320. ntt(&z[i]);
  321. }
  322. *sig_ptr = sig; /* Return updated pointer. */
  323. return true;
  324. }
  325. /*
  326. * Decode the signer's hint vector 'h' from the signature.
  327. * Reference: FIPS 204 Algorithm 21, HintBitUnpack
  328. *
  329. * Note that there are several ways in which the hint vector can be malformed.
  330. */
  331. static bool decode_hint_vector(u8 h[/* k * N */], int k, int omega, const u8 *y)
  332. {
  333. int index = 0;
  334. memset(h, 0, k * N);
  335. for (int i = 0; i < k; i++) {
  336. int count = y[omega + i]; /* num 1's in elems 0 through i */
  337. int prev = -1;
  338. /* Cumulative count mustn't decrease or exceed omega. */
  339. if (count < index || count > omega)
  340. return false;
  341. for (; index < count; index++) {
  342. if (prev >= y[index]) /* Coefficients out of order? */
  343. return false;
  344. prev = y[index];
  345. h[i * N + y[index]] = 1;
  346. }
  347. }
  348. return mem_is_zero(&y[index], omega - index);
  349. }
  350. /*
  351. * Expand @seed into an element of R_q @c with coefficients in {-1, 0, 1},
  352. * exactly @tau of them nonzero. Reference: FIPS 204 Algorithm 29, SampleInBall
  353. */
  354. static void sample_in_ball(struct mldsa_ring_elem *c, const u8 *seed,
  355. size_t seed_len, int tau, struct shake_ctx *shake)
  356. {
  357. u64 signs;
  358. u8 j;
  359. shake256_init(shake);
  360. shake_update(shake, seed, seed_len);
  361. shake_squeeze(shake, (u8 *)&signs, sizeof(signs));
  362. le64_to_cpus(&signs);
  363. *c = (struct mldsa_ring_elem){};
  364. for (int i = N - tau; i < N; i++, signs >>= 1) {
  365. do {
  366. shake_squeeze(shake, &j, 1);
  367. } while (j > i);
  368. c->x[i] = c->x[j];
  369. c->x[j] = 1 - 2 * (s32)(signs & 1);
  370. }
  371. }
  372. /*
  373. * Expand the public seed @rho and @row_and_column into an element of T_q @out.
  374. * Reference: FIPS 204 Algorithm 30, RejNTTPoly
  375. *
  376. * @shake and @block are temporary space used by the expansion. @block has
  377. * space for one SHAKE128 block, plus an extra byte to allow reading a u32 from
  378. * the final 3-byte group without reading out-of-bounds.
  379. */
  380. static void rej_ntt_poly(struct mldsa_ring_elem *out, const u8 rho[RHO_LEN],
  381. __le16 row_and_column, struct shake_ctx *shake,
  382. u8 block[SHAKE128_BLOCK_SIZE + 1])
  383. {
  384. shake128_init(shake);
  385. shake_update(shake, rho, RHO_LEN);
  386. shake_update(shake, (u8 *)&row_and_column, sizeof(row_and_column));
  387. for (int i = 0; i < N;) {
  388. shake_squeeze(shake, block, SHAKE128_BLOCK_SIZE);
  389. block[SHAKE128_BLOCK_SIZE] = 0; /* for KMSAN */
  390. static_assert(SHAKE128_BLOCK_SIZE % 3 == 0);
  391. for (int j = 0; j < SHAKE128_BLOCK_SIZE && i < N; j += 3) {
  392. u32 x = get_unaligned_le32(&block[j]) & 0x7fffff;
  393. if (x < Q) /* Ignore values >= q. */
  394. out->x[i++] = x;
  395. }
  396. }
  397. }
  398. /*
  399. * Return the HighBits of r adjusted according to hint h
  400. * Reference: FIPS 204 Algorithm 40, UseHint
  401. *
  402. * This is needed because of the public key compression in ML-DSA.
  403. *
  404. * h is either 0 or 1, r is in [0, q - 1], and gamma2 is either (q - 1) / 88 or
  405. * (q - 1) / 32. Except when invoked via the unit test interface, gamma2 is a
  406. * compile-time constant, so compilers will optimize the code accordingly.
  407. */
  408. static __always_inline s32 use_hint(u8 h, s32 r, const s32 gamma2)
  409. {
  410. const s32 m = (Q - 1) / (2 * gamma2); /* 44 or 16, compile-time const */
  411. s32 r1;
  412. /*
  413. * Handle the special case where r - (r mod+- (2 * gamma2)) == q - 1,
  414. * i.e. r >= q - gamma2. This is also exactly where the computation of
  415. * r1 below would produce 'm' and would need a correction.
  416. */
  417. if (r >= Q - gamma2)
  418. return h == 0 ? 0 : m - 1;
  419. /*
  420. * Compute the (non-hint-adjusted) HighBits r1 as:
  421. *
  422. * r1 = (r - (r mod+- (2 * gamma2))) / (2 * gamma2)
  423. * = floor((r + gamma2 - 1) / (2 * gamma2))
  424. *
  425. * Note that when '2 * gamma2' is a compile-time constant, compilers
  426. * optimize the division to a reciprocal multiplication and shift.
  427. */
  428. r1 = (u32)(r + gamma2 - 1) / (2 * gamma2);
  429. /*
  430. * Return the HighBits r1:
  431. * + 0 if the hint is 0;
  432. * + 1 (mod m) if the hint is 1 and the LowBits are positive;
  433. * - 1 (mod m) if the hint is 1 and the LowBits are negative or 0.
  434. *
  435. * r1 is in (and remains in) [0, m - 1]. Note that when 'm' is a
  436. * compile-time constant, compilers optimize the '% m' accordingly.
  437. */
  438. if (h == 0)
  439. return r1;
  440. if (r > r1 * (2 * gamma2))
  441. return (u32)(r1 + 1) % m;
  442. return (u32)(r1 + m - 1) % m;
  443. }
  444. static __always_inline void use_hint_elem(struct mldsa_ring_elem *w,
  445. const u8 h[N], const s32 gamma2)
  446. {
  447. for (int j = 0; j < N; j++)
  448. w->x[j] = use_hint(h[j], w->x[j], gamma2);
  449. }
#if IS_ENABLED(CONFIG_CRYPTO_LIB_MLDSA_KUNIT_TEST)
/*
 * Allow the __always_inline function use_hint() to be unit-tested.
 *
 * This out-of-line wrapper is built and exported only when the ML-DSA KUnit
 * test is enabled. Unlike the calls made from mldsa_verify(), @gamma2 is a
 * runtime value here, so use_hint() is not specialized to a constant.
 */
s32 mldsa_use_hint(u8 h, s32 r, s32 gamma2)
{
	return use_hint(h, r, gamma2);
}
EXPORT_SYMBOL_IF_KUNIT(mldsa_use_hint);
#endif
  458. /*
  459. * Encode one element of the commitment vector w'_1 into a byte string.
  460. * Reference: FIPS 204 Algorithm 28, w1Encode.
  461. * Return the number of bytes used: 192 for ML-DSA-44 and 128 for the others.
  462. */
  463. static size_t encode_w1(u8 out[MAX_W1_ENCODED_LEN],
  464. const struct mldsa_ring_elem *w1, int k)
  465. {
  466. size_t pos = 0;
  467. static_assert(N * 6 / 8 == MAX_W1_ENCODED_LEN);
  468. if (k == 4) { /* ML-DSA-44? */
  469. /* 6 bits per coefficient. Pack 4 at a time. */
  470. for (int j = 0; j < N; j += 4) {
  471. u32 v = (w1->x[j + 0] << 0) | (w1->x[j + 1] << 6) |
  472. (w1->x[j + 2] << 12) | (w1->x[j + 3] << 18);
  473. out[pos++] = v >> 0;
  474. out[pos++] = v >> 8;
  475. out[pos++] = v >> 16;
  476. }
  477. } else {
  478. /* 4 bits per coefficient. Pack 2 at a time. */
  479. for (int j = 0; j < N; j += 2)
  480. out[pos++] = w1->x[j] | (w1->x[j + 1] << 4);
  481. }
  482. return pos;
  483. }
  484. int mldsa_verify(enum mldsa_alg alg, const u8 *sig, size_t sig_len,
  485. const u8 *msg, size_t msg_len, const u8 *pk, size_t pk_len)
  486. {
  487. const struct mldsa_parameter_set *params = &mldsa_parameter_sets[alg];
  488. const int k = params->k, l = params->l;
  489. /* For now this just does pure ML-DSA with an empty context string. */
  490. static const u8 msg_prefix[2] = { /* dom_sep= */ 0, /* ctx_len= */ 0 };
  491. const u8 *ctilde; /* The signer's commitment hash */
  492. const u8 *t1_encoded = &pk[RHO_LEN]; /* Next encoded element of t_1 */
  493. u8 *h; /* The signer's hint vector, length k * N */
  494. size_t w1_enc_len;
  495. /* Validate the public key and signature lengths. */
  496. if (pk_len != params->pk_len || sig_len != params->sig_len)
  497. return -EBADMSG;
  498. /*
  499. * Allocate the workspace, including variable-length fields. Its size
  500. * depends only on the ML-DSA parameter set, not the other inputs.
  501. *
  502. * For freeing it, use kfree_sensitive() rather than kfree(). This is
  503. * mainly to comply with FIPS 204 Section 3.6.3 "Intermediate Values".
  504. * In reality it's a bit gratuitous, as this is a public key operation.
  505. */
  506. struct mldsa_verification_workspace *ws __free(kfree_sensitive) =
  507. kmalloc(sizeof(*ws) + (l * sizeof(ws->z[0])) + (k * N),
  508. GFP_KERNEL);
  509. if (!ws)
  510. return -ENOMEM;
  511. h = (u8 *)&ws->z[l];
  512. /* Decode the signature. Reference: FIPS 204 Algorithm 27, sigDecode */
  513. ctilde = sig;
  514. sig += params->ctilde_len;
  515. if (!decode_z(ws->z, l, params->gamma1, params->beta, &sig))
  516. return -EBADMSG;
  517. if (!decode_hint_vector(h, k, params->omega, sig))
  518. return -EBADMSG;
  519. /* Recreate the challenge c from the signer's commitment hash. */
  520. sample_in_ball(&ws->c, ctilde, params->ctilde_len, params->tau,
  521. &ws->shake);
  522. ntt(&ws->c);
  523. /* Compute the message representative mu. */
  524. shake256(pk, pk_len, ws->tr, sizeof(ws->tr));
  525. shake256_init(&ws->shake);
  526. shake_update(&ws->shake, ws->tr, sizeof(ws->tr));
  527. shake_update(&ws->shake, msg_prefix, sizeof(msg_prefix));
  528. shake_update(&ws->shake, msg, msg_len);
  529. shake_squeeze(&ws->shake, ws->mu, sizeof(ws->mu));
  530. /* Start computing ctildeprime = H(mu || w1Encode(w'_1)). */
  531. shake256_init(&ws->shake);
  532. shake_update(&ws->shake, ws->mu, sizeof(ws->mu));
  533. /*
  534. * Compute the commitment w'_1 from A, z, c, t_1, and h.
  535. *
  536. * The computation is the same for each of the k rows. Just do each row
  537. * before moving on to the next, resulting in only one loop over k.
  538. */
  539. for (int i = 0; i < k; i++) {
  540. /*
  541. * tmp = NTT(A) * NTT(z) * 2^-32
  542. * To reduce memory use, generate each element of NTT(A)
  543. * on-demand. Note that each element is used only once.
  544. */
  545. ws->tmp = (struct mldsa_ring_elem){};
  546. for (int j = 0; j < l; j++) {
  547. rej_ntt_poly(&ws->a, pk /* rho is first field of pk */,
  548. cpu_to_le16((i << 8) | j), &ws->a_shake,
  549. ws->block);
  550. for (int n = 0; n < N; n++)
  551. ws->tmp.x[n] +=
  552. Zq_mult(ws->a.x[n], ws->z[j].x[n]);
  553. }
  554. /* All components of tmp now have abs value < l*q. */
  555. /* Decode the next element of t_1. */
  556. t1_encoded = decode_t1_elem(&ws->t1_scaled, t1_encoded);
  557. /*
  558. * tmp -= NTT(c) * NTT(t_1 * 2^d) * 2^-32
  559. *
  560. * Taking a conservative bound for the output of ntt(), the
  561. * multiplicands can have absolute value up to 9*q. That
  562. * corresponds to a product with absolute value 81*q^2. That is
  563. * within the limits of Zq_mult() which needs < ~256*q^2.
  564. */
  565. for (int j = 0; j < N; j++)
  566. ws->tmp.x[j] -= Zq_mult(ws->c.x[j], ws->t1_scaled.x[j]);
  567. /* All components of tmp now have abs value < (l+1)*q. */
  568. /* tmp = w'_Approx = NTT^-1(tmp) * 2^32 */
  569. invntt_and_mul_2_32(&ws->tmp);
  570. /* All coefficients of tmp are now in [0, q - 1]. */
  571. /*
  572. * tmp = w'_1 = UseHint(h, w'_Approx)
  573. * For efficiency, set gamma2 to a compile-time constant.
  574. */
  575. if (k == 4)
  576. use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 88);
  577. else
  578. use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 32);
  579. /* Encode and hash the next element of w'_1. */
  580. w1_enc_len = encode_w1(ws->w1_encoded, &ws->tmp, k);
  581. shake_update(&ws->shake, ws->w1_encoded, w1_enc_len);
  582. }
  583. /* Finish computing ctildeprime. */
  584. shake_squeeze(&ws->shake, ws->ctildeprime, params->ctilde_len);
  585. /* Verify that ctilde == ctildeprime. */
  586. if (memcmp(ws->ctildeprime, ctilde, params->ctilde_len) != 0)
  587. return -EKEYREJECTED;
  588. /* ||z||_infinity < gamma1 - beta was already checked in decode_z(). */
  589. return 0;
  590. }
  591. EXPORT_SYMBOL_GPL(mldsa_verify);
  592. #ifdef CONFIG_CRYPTO_FIPS
  593. static int __init mldsa_mod_init(void)
  594. {
  595. if (fips_enabled) {
  596. /*
  597. * FIPS cryptographic algorithm self-test. As per the FIPS
  598. * Implementation Guidance, testing any ML-DSA parameter set
  599. * satisfies the test requirement for all of them, and only a
  600. * positive test is required.
  601. */
  602. int err = mldsa_verify(MLDSA65, fips_test_mldsa65_signature,
  603. sizeof(fips_test_mldsa65_signature),
  604. fips_test_mldsa65_message,
  605. sizeof(fips_test_mldsa65_message),
  606. fips_test_mldsa65_public_key,
  607. sizeof(fips_test_mldsa65_public_key));
  608. if (err)
  609. panic("mldsa: FIPS self-test failed; err=%pe\n",
  610. ERR_PTR(err));
  611. }
  612. return 0;
  613. }
  614. subsys_initcall(mldsa_mod_init);
  615. static void __exit mldsa_mod_exit(void)
  616. {
  617. }
  618. module_exit(mldsa_mod_exit);
  619. #endif /* CONFIG_CRYPTO_FIPS */
  620. MODULE_DESCRIPTION("ML-DSA signature verification");
  621. MODULE_LICENSE("GPL");