| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438 |
- // SPDX-License-Identifier: GPL-2.0-or-later
- /*
- * KUnit tests and benchmark for ML-DSA
- *
- * Copyright 2025 Google LLC
- */
- #include <crypto/mldsa.h>
- #include <kunit/test.h>
- #include <linux/random.h>
- #include <linux/unaligned.h>
- #define Q 8380417 /* The prime q = 2^23 - 2^13 + 1 */
- /* ML-DSA parameters that the tests use */
- static const struct {
- int sig_len;
- int pk_len;
- int k;
- int lambda;
- int gamma1;
- int beta;
- int omega;
- } params[] = {
- [MLDSA44] = {
- .sig_len = MLDSA44_SIGNATURE_SIZE,
- .pk_len = MLDSA44_PUBLIC_KEY_SIZE,
- .k = 4,
- .lambda = 128,
- .gamma1 = 1 << 17,
- .beta = 78,
- .omega = 80,
- },
- [MLDSA65] = {
- .sig_len = MLDSA65_SIGNATURE_SIZE,
- .pk_len = MLDSA65_PUBLIC_KEY_SIZE,
- .k = 6,
- .lambda = 192,
- .gamma1 = 1 << 19,
- .beta = 196,
- .omega = 55,
- },
- [MLDSA87] = {
- .sig_len = MLDSA87_SIGNATURE_SIZE,
- .pk_len = MLDSA87_PUBLIC_KEY_SIZE,
- .k = 8,
- .lambda = 256,
- .gamma1 = 1 << 19,
- .beta = 120,
- .omega = 75,
- },
- };
- #include "mldsa-testvecs.h"
- static void do_mldsa_and_assert_success(struct kunit *test,
- const struct mldsa_testvector *tv)
- {
- int err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len);
- KUNIT_ASSERT_EQ(test, err, 0);
- }
- static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len)
- {
- u8 *dst = kunit_kmalloc(test, len, GFP_KERNEL);
- KUNIT_ASSERT_NOT_NULL(test, dst);
- return memcpy(dst, src, len);
- }
- /*
- * Test that changing coefficients in a valid signature's z vector results in
- * the following behavior from mldsa_verify():
- *
- * * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e.
- * absolute value >= gamma1 - beta, corresponding to the verifier detecting
- * the out-of-range coefficient and rejecting the signature as malformed
- *
- * * -EKEYREJECTED if a coefficient is changed to a different in-range value,
- * i.e. absolute value < gamma1 - beta, corresponding to the verifier
- * continuing to the "real" signature check and that check failing
- */
- static void test_mldsa_z_range(struct kunit *test,
- const struct mldsa_testvector *tv)
- {
- u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
- const int lambda = params[tv->alg].lambda;
- const s32 gamma1 = params[tv->alg].gamma1;
- const int beta = params[tv->alg].beta;
- /*
- * We just modify the first coefficient. The coefficient is gamma1
- * minus either the first 18 or 20 bits of the u32, depending on gamma1.
- *
- * The layout of ML-DSA signatures is ctilde || z || h. ctilde is
- * lambda / 4 bytes, so z starts at &sig[lambda / 4].
- */
- u8 *z_ptr = &sig[lambda / 4];
- const u32 z_data = get_unaligned_le32(z_ptr);
- const u32 mask = (gamma1 << 1) - 1;
- /* These are the four boundaries of the out-of-range values. */
- const s32 out_of_range_coeffs[] = {
- -gamma1 + 1,
- -(gamma1 - beta),
- gamma1,
- gamma1 - beta,
- };
- /*
- * These are the two boundaries of the valid range, along with 0. We
- * assume that none of these matches the original coefficient.
- */
- const s32 in_range_coeffs[] = {
- -(gamma1 - beta - 1),
- 0,
- gamma1 - beta - 1,
- };
- /* Initially the signature is valid. */
- do_mldsa_and_assert_success(test, tv);
- /* Test some out-of-range coefficients. */
- for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) {
- const s32 c = out_of_range_coeffs[i];
- put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
- z_ptr);
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len));
- }
- /* Test some in-range coefficients. */
- for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) {
- const s32 c = in_range_coeffs[i];
- put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
- z_ptr);
- KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
- mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len));
- }
- }
- /* Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG. */
- static void test_mldsa_bad_hints(struct kunit *test,
- const struct mldsa_testvector *tv)
- {
- const int omega = params[tv->alg].omega;
- const int k = params[tv->alg].k;
- u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
- /* Pointer to the encoded hint vector in the signature */
- u8 *hintvec = &sig[tv->sig_len - omega - k];
- u8 h;
- /* Initially the signature is valid. */
- do_mldsa_and_assert_success(test, tv);
- /* Cumulative hint count exceeds omega */
- memcpy(sig, tv->sig, tv->sig_len);
- hintvec[omega + k - 1] = omega + 1;
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len));
- /* Cumulative hint count decreases */
- memcpy(sig, tv->sig, tv->sig_len);
- KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1);
- hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1;
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len));
- /*
- * Hint indices out of order. To test this, swap hintvec[0] and
- * hintvec[1]. This assumes that the original valid signature had at
- * least two nonzero hints in the first element (asserted below).
- */
- memcpy(sig, tv->sig, tv->sig_len);
- KUNIT_ASSERT_GE(test, hintvec[omega], 2);
- h = hintvec[0];
- hintvec[0] = hintvec[1];
- hintvec[1] = h;
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len));
- /*
- * Extra hint indices given. For this test to work, the original valid
- * signature must have fewer than omega nonzero hints (asserted below).
- */
- memcpy(sig, tv->sig, tv->sig_len);
- KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega);
- hintvec[omega - 1] = 0xff;
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len));
- }
- static void test_mldsa_mutation(struct kunit *test,
- const struct mldsa_testvector *tv)
- {
- const int sig_len = tv->sig_len;
- const int msg_len = tv->msg_len;
- const int pk_len = tv->pk_len;
- const int num_iter = 200;
- u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len);
- u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len);
- u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len);
- /* Initially the signature is valid. */
- do_mldsa_and_assert_success(test, tv);
- /* Changing any bit in the signature should invalidate the signature */
- for (int i = 0; i < num_iter; i++) {
- size_t pos = get_random_u32_below(sig_len);
- u8 b = 1 << get_random_u32_below(8);
- sig[pos] ^= b;
- KUNIT_ASSERT_NE(test, 0,
- mldsa_verify(tv->alg, sig, sig_len, msg,
- msg_len, pk, pk_len));
- sig[pos] ^= b;
- }
- /* Changing any bit in the message should invalidate the signature */
- for (int i = 0; i < num_iter; i++) {
- size_t pos = get_random_u32_below(msg_len);
- u8 b = 1 << get_random_u32_below(8);
- msg[pos] ^= b;
- KUNIT_ASSERT_NE(test, 0,
- mldsa_verify(tv->alg, sig, sig_len, msg,
- msg_len, pk, pk_len));
- msg[pos] ^= b;
- }
- /* Changing any bit in the public key should invalidate the signature */
- for (int i = 0; i < num_iter; i++) {
- size_t pos = get_random_u32_below(pk_len);
- u8 b = 1 << get_random_u32_below(8);
- pk[pos] ^= b;
- KUNIT_ASSERT_NE(test, 0,
- mldsa_verify(tv->alg, sig, sig_len, msg,
- msg_len, pk, pk_len));
- pk[pos] ^= b;
- }
- /* All changes should have been undone. */
- KUNIT_ASSERT_EQ(test, 0,
- mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk,
- pk_len));
- }
- static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv)
- {
- /* Valid signature */
- KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len);
- KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len);
- do_mldsa_and_assert_success(test, tv);
- /* Signature too short */
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len));
- /* Signature too long */
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len));
- /* Public key too short */
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len - 1));
- /* Public key too long */
- KUNIT_ASSERT_EQ(test, -EBADMSG,
- mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
- tv->msg_len, tv->pk, tv->pk_len + 1));
- /*
- * Message too short. Error is EKEYREJECTED because it gets rejected by
- * the "real" signature check rather than the well-formedness checks.
- */
- KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
- mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
- tv->msg_len - 1, tv->pk, tv->pk_len));
- /*
- * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be
- * accessed out of bounds. However, ML-DSA just hashes the message and
- * doesn't handle different message lengths differently anyway.
- */
- /* Test the validity checks on the z vector. */
- test_mldsa_z_range(test, tv);
- /* Test the validity checks on the hint vector. */
- test_mldsa_bad_hints(test, tv);
- /* Test randomly mutating the inputs. */
- test_mldsa_mutation(test, tv);
- }
- static void test_mldsa44(struct kunit *test)
- {
- test_mldsa(test, &mldsa44_testvector);
- }
- static void test_mldsa65(struct kunit *test)
- {
- test_mldsa(test, &mldsa65_testvector);
- }
- static void test_mldsa87(struct kunit *test)
- {
- test_mldsa(test, &mldsa87_testvector);
- }
/*
 * Non-negative remainder: reduce @a into the range [0, @m). Every caller in
 * this file passes a positive @m.
 */
static s32 mod(s32 a, s32 m)
{
	s32 r = a % m;

	return r < 0 ? r + m : r;
}
/*
 * Centered remainder: reduce @a modulo @m into the range
 * -m/2 < result <= m/2 (with integer division m/2).
 */
static s32 symmetric_mod(s32 a, s32 m)
{
	s32 r = mod(a, m);

	if (r <= m / 2)
		return r;
	return r - m;
}
- /* Mechanical, inefficient translation of FIPS 204 Algorithm 36, Decompose */
- static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1)
- {
- s32 rplus = mod(r, Q);
- *r0 = symmetric_mod(rplus, 2 * gamma2);
- if (rplus - *r0 == Q - 1) {
- *r1 = 0;
- *r0 = *r0 - 1;
- } else {
- *r1 = (rplus - *r0) / (2 * gamma2);
- }
- }
- /* Mechanical, inefficient translation of FIPS 204 Algorithm 40, UseHint */
- static s32 use_hint_ref(u8 h, s32 r, s32 gamma2)
- {
- s32 m = (Q - 1) / (2 * gamma2);
- s32 r0, r1;
- decompose_ref(r, gamma2, &r0, &r1);
- if (h == 1 && r0 > 0)
- return mod(r1 + 1, m);
- if (h == 1 && r0 <= 0)
- return mod(r1 - 1, m);
- return r1;
- }
- /*
- * Test that for all possible inputs, mldsa_use_hint() gives the same output as
- * a mechanical translation of the pseudocode from FIPS 204.
- */
- static void test_mldsa_use_hint(struct kunit *test)
- {
- for (int i = 0; i < 2; i++) {
- const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32);
- for (u8 h = 0; h < 2; h++) {
- for (s32 r = 0; r < Q; r++) {
- KUNIT_ASSERT_EQ(test,
- mldsa_use_hint(h, r, gamma2),
- use_hint_ref(h, r, gamma2));
- }
- }
- }
- }
- static void benchmark_mldsa(struct kunit *test,
- const struct mldsa_testvector *tv)
- {
- const int warmup_niter = 200;
- const int benchmark_niter = 200;
- u64 t0, t1;
- if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK))
- kunit_skip(test, "not enabled");
- for (int i = 0; i < warmup_niter; i++)
- do_mldsa_and_assert_success(test, tv);
- t0 = ktime_get_ns();
- for (int i = 0; i < benchmark_niter; i++)
- do_mldsa_and_assert_success(test, tv);
- t1 = ktime_get_ns();
- kunit_info(test, "%llu ops/s",
- div64_u64((u64)benchmark_niter * NSEC_PER_SEC,
- t1 - t0 ?: 1));
- }
- static void benchmark_mldsa44(struct kunit *test)
- {
- benchmark_mldsa(test, &mldsa44_testvector);
- }
- static void benchmark_mldsa65(struct kunit *test)
- {
- benchmark_mldsa(test, &mldsa65_testvector);
- }
- static void benchmark_mldsa87(struct kunit *test)
- {
- benchmark_mldsa(test, &mldsa87_testvector);
- }
- static struct kunit_case mldsa_kunit_cases[] = {
- KUNIT_CASE(test_mldsa44),
- KUNIT_CASE(test_mldsa65),
- KUNIT_CASE(test_mldsa87),
- KUNIT_CASE(test_mldsa_use_hint),
- KUNIT_CASE(benchmark_mldsa44),
- KUNIT_CASE(benchmark_mldsa65),
- KUNIT_CASE(benchmark_mldsa87),
- {},
- };
- static struct kunit_suite mldsa_kunit_suite = {
- .name = "mldsa",
- .test_cases = mldsa_kunit_cases,
- };
- kunit_test_suite(mldsa_kunit_suite);
- MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA");
- MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING");
- MODULE_LICENSE("GPL");
|