summaryrefslogtreecommitdiff
path: root/lib/crypto/tests/mldsa_kunit.c
blob: 67f8f93e3dc6b6dcb4acfce452b52a84c3f9dc6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * KUnit tests and benchmark for ML-DSA
 *
 * Copyright 2025 Google LLC
 */
#include <crypto/mldsa.h>
#include <kunit/test.h>
#include <linux/minmax.h>
#include <linux/random.h>
#include <linux/unaligned.h>

/* The ML-DSA prime modulus q = 2^23 - 2^13 + 1 = 8380417 */
#define Q ((1 << 23) - (1 << 13) + 1)

/*
 * ML-DSA parameters that the tests use, indexed by algorithm ID (MLDSA44,
 * MLDSA65, MLDSA87).  Only the parameters that the tests in this file actually
 * read are listed.
 */
static const struct {
	int sig_len;	/* expected signature length in bytes */
	int pk_len;	/* expected public key length in bytes */
	int k;		/* number of cumulative hint counts ending the signature */
	int lambda;	/* ctilde, the first signature field, is lambda / 4 bytes */
	int gamma1;	/* z coefficients are encoded as gamma1 - coefficient */
	int beta;	/* valid z coefficients have absolute value < gamma1 - beta */
	int omega;	/* maximum total hint count; also the number of index bytes */
} params[] = {
	[MLDSA44] = {
		.sig_len = MLDSA44_SIGNATURE_SIZE,
		.pk_len = MLDSA44_PUBLIC_KEY_SIZE,
		.k = 4,
		.lambda = 128,
		.gamma1 = 1 << 17,
		.beta = 78,
		.omega = 80,
	},
	[MLDSA65] = {
		.sig_len = MLDSA65_SIGNATURE_SIZE,
		.pk_len = MLDSA65_PUBLIC_KEY_SIZE,
		.k = 6,
		.lambda = 192,
		.gamma1 = 1 << 19,
		.beta = 196,
		.omega = 55,
	},
	[MLDSA87] = {
		.sig_len = MLDSA87_SIGNATURE_SIZE,
		.pk_len = MLDSA87_PUBLIC_KEY_SIZE,
		.k = 8,
		.lambda = 256,
		.gamma1 = 1 << 19,
		.beta = 120,
		.omega = 75,
	},
};

#include "mldsa-testvecs.h"

/*
 * Verify the unmodified signature of the given test vector, and abort the
 * current test case unless mldsa_verify() reports success (0).
 */
static void do_mldsa_and_assert_success(struct kunit *test,
					const struct mldsa_testvector *tv)
{
	int err;

	err = mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg, tv->msg_len,
			   tv->pk, tv->pk_len);
	KUNIT_ASSERT_EQ(test, err, 0);
}

/*
 * Return a test-managed copy of the @len bytes at @src.  The copy is allocated
 * with kunit_kmalloc(); the test case is aborted if that allocation fails, so
 * the returned pointer is never NULL.
 */
static u8 *kunit_kmemdup_or_fail(struct kunit *test, const u8 *src, size_t len)
{
	u8 *dst;

	dst = kunit_kmalloc(test, len, GFP_KERNEL);
	KUNIT_ASSERT_NOT_NULL(test, dst);
	memcpy(dst, src, len);
	return dst;
}

/*
 * Test that changing coefficients in a valid signature's z vector results in
 * the following behavior from mldsa_verify():
 *
 *  * -EBADMSG if a coefficient is changed to have an out-of-range value, i.e.
 *    absolute value >= gamma1 - beta, corresponding to the verifier detecting
 *    the out-of-range coefficient and rejecting the signature as malformed
 *
 *  * -EKEYREJECTED if a coefficient is changed to a different in-range value,
 *    i.e. absolute value < gamma1 - beta, corresponding to the verifier
 *    continuing to the "real" signature check and that check failing
 */
static void test_mldsa_z_range(struct kunit *test,
			       const struct mldsa_testvector *tv)
{
	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
	const int lambda = params[tv->alg].lambda;
	const s32 gamma1 = params[tv->alg].gamma1;
	const int beta = params[tv->alg].beta;
	/*
	 * We just modify the first coefficient.  The coefficient is gamma1
	 * minus either the first 18 or 20 bits of the u32, depending on gamma1.
	 *
	 * The layout of ML-DSA signatures is ctilde || z || h.  ctilde is
	 * lambda / 4 bytes, so z starts at &sig[lambda / 4].
	 */
	u8 *z_ptr = &sig[lambda / 4];
	const u32 z_data = get_unaligned_le32(z_ptr);
	const u32 mask = (gamma1 << 1) - 1;
	/* These are the four boundaries of the out-of-range values. */
	const s32 out_of_range_coeffs[] = {
		-gamma1 + 1,
		-(gamma1 - beta),
		gamma1,
		gamma1 - beta,
	};
	/*
	 * These are the two boundaries of the valid range, along with 0.  We
	 * assume that none of these matches the original coefficient.
	 */
	const s32 in_range_coeffs[] = {
		-(gamma1 - beta - 1),
		0,
		gamma1 - beta - 1,
	};

	/* Initially the signature is valid. */
	do_mldsa_and_assert_success(test, tv);

	/* Test some out-of-range coefficients. */
	for (int i = 0; i < ARRAY_SIZE(out_of_range_coeffs); i++) {
		const s32 c = out_of_range_coeffs[i];

		put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
				   z_ptr);
		KUNIT_ASSERT_EQ(test, -EBADMSG,
				mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
					     tv->msg_len, tv->pk, tv->pk_len));
	}

	/* Test some in-range coefficients. */
	for (int i = 0; i < ARRAY_SIZE(in_range_coeffs); i++) {
		const s32 c = in_range_coeffs[i];

		put_unaligned_le32((z_data & ~mask) | (mask & (gamma1 - c)),
				   z_ptr);
		KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
				mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
					     tv->msg_len, tv->pk, tv->pk_len));
	}
}

/*
 * Test that mldsa_verify() rejects malformed hint vectors with -EBADMSG.
 *
 * The encoded hint vector is the last omega + k bytes of the signature.  As
 * used below, its first omega bytes hold hint indices and its last k bytes
 * hold cumulative hint counts, so hintvec[omega] is the number of hints in the
 * first element and hintvec[omega + k - 1] is the total number of hints.
 *
 * Each case starts over from a fresh copy of the valid signature and then
 * corrupts the hint vector in one specific way.  Assumptions that a case makes
 * about the test vector are asserted before the corruption is applied.
 */
static void test_mldsa_bad_hints(struct kunit *test,
				 const struct mldsa_testvector *tv)
{
	const int omega = params[tv->alg].omega;
	const int k = params[tv->alg].k;
	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, tv->sig_len);
	/* Pointer to the encoded hint vector in the signature */
	u8 *hintvec = &sig[tv->sig_len - omega - k];

	/* Initially the signature is valid. */
	do_mldsa_and_assert_success(test, tv);

	/* Cumulative hint count exceeds omega */
	memcpy(sig, tv->sig, tv->sig_len);
	hintvec[omega + k - 1] = omega + 1;
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len));

	/*
	 * Cumulative hint count decreases.  The second-to-last count must be
	 * at least 1 so that the last count can be made smaller than it.
	 */
	memcpy(sig, tv->sig, tv->sig_len);
	KUNIT_ASSERT_GE(test, hintvec[omega + k - 2], 1);
	hintvec[omega + k - 1] = hintvec[omega + k - 2] - 1;
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len));

	/*
	 * Hint indices out of order.  To test this, swap hintvec[0] and
	 * hintvec[1].  This assumes that the original valid signature had at
	 * least two nonzero hints in the first element (asserted below).
	 */
	memcpy(sig, tv->sig, tv->sig_len);
	KUNIT_ASSERT_GE(test, hintvec[omega], 2);
	swap(hintvec[0], hintvec[1]);
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len));

	/*
	 * Extra hint indices given.  For this test to work, the original valid
	 * signature must have fewer than omega nonzero hints (asserted below),
	 * so that the last index byte is not in use and can be made nonzero.
	 */
	memcpy(sig, tv->sig, tv->sig_len);
	KUNIT_ASSERT_LT(test, hintvec[omega + k - 1], omega);
	hintvec[omega - 1] = 0xff;
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, sig, tv->sig_len, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len));
}

/*
 * Flip randomly chosen single bits in private copies of the signature, the
 * message, and the public key, one bit at a time, and check that
 * mldsa_verify() fails for every mutated input.  Each flipped bit is restored
 * before the next one is chosen, so at the end the inputs must verify again.
 */
static void test_mldsa_mutation(struct kunit *test,
				const struct mldsa_testvector *tv)
{
	const int sig_len = tv->sig_len;
	const int msg_len = tv->msg_len;
	const int pk_len = tv->pk_len;
	const int flips_per_input = 200;
	u8 *sig = kunit_kmemdup_or_fail(test, tv->sig, sig_len);
	u8 *msg = kunit_kmemdup_or_fail(test, tv->msg, msg_len);
	u8 *pk = kunit_kmemdup_or_fail(test, tv->pk, pk_len);
	int n;

	/* Initially the signature is valid. */
	do_mldsa_and_assert_success(test, tv);

	/* A single flipped bit anywhere in the signature must be detected. */
	for (n = 0; n < flips_per_input; n++) {
		const size_t byte_idx = get_random_u32_below(sig_len);
		const u8 bit = 1 << get_random_u32_below(8);

		sig[byte_idx] ^= bit;
		KUNIT_ASSERT_NE(test, 0,
				mldsa_verify(tv->alg, sig, sig_len, msg,
					     msg_len, pk, pk_len));
		sig[byte_idx] ^= bit;	/* undo */
	}

	/* A single flipped bit anywhere in the message must be detected. */
	for (n = 0; n < flips_per_input; n++) {
		const size_t byte_idx = get_random_u32_below(msg_len);
		const u8 bit = 1 << get_random_u32_below(8);

		msg[byte_idx] ^= bit;
		KUNIT_ASSERT_NE(test, 0,
				mldsa_verify(tv->alg, sig, sig_len, msg,
					     msg_len, pk, pk_len));
		msg[byte_idx] ^= bit;	/* undo */
	}

	/* A single flipped bit anywhere in the public key must be detected. */
	for (n = 0; n < flips_per_input; n++) {
		const size_t byte_idx = get_random_u32_below(pk_len);
		const u8 bit = 1 << get_random_u32_below(8);

		pk[byte_idx] ^= bit;
		KUNIT_ASSERT_NE(test, 0,
				mldsa_verify(tv->alg, sig, sig_len, msg,
					     msg_len, pk, pk_len));
		pk[byte_idx] ^= bit;	/* undo */
	}

	/* Every flip was undone, so the copies must verify again. */
	KUNIT_ASSERT_EQ(test, 0,
			mldsa_verify(tv->alg, sig, sig_len, msg, msg_len, pk,
				     pk_len));
}

/*
 * Run all the tests for one ML-DSA parameter set, using the given test vector.
 *
 * First the test vector itself is checked: its signature and public key must
 * have the lengths listed in params[] and the signature must verify.  Then
 * wrong signature and public key lengths must be rejected with -EBADMSG, and a
 * truncated message must be rejected with -EKEYREJECTED.  Finally the z
 * vector, hint vector, and random mutation tests are run on the same vector.
 */
static void test_mldsa(struct kunit *test, const struct mldsa_testvector *tv)
{
	/* The test vector has the expected sizes and a valid signature */
	KUNIT_ASSERT_EQ(test, tv->sig_len, params[tv->alg].sig_len);
	KUNIT_ASSERT_EQ(test, tv->pk_len, params[tv->alg].pk_len);
	do_mldsa_and_assert_success(test, tv);

	/* Signature too short */
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, tv->sig, tv->sig_len - 1, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len));

	/* Signature too long */
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, tv->sig, tv->sig_len + 1, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len));

	/* Public key too short */
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len - 1));

	/* Public key too long */
	KUNIT_ASSERT_EQ(test, -EBADMSG,
			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
				     tv->msg_len, tv->pk, tv->pk_len + 1));

	/*
	 * Message too short.  Error is EKEYREJECTED because it gets rejected by
	 * the "real" signature check rather than the well-formedness checks.
	 */
	KUNIT_ASSERT_EQ(test, -EKEYREJECTED,
			mldsa_verify(tv->alg, tv->sig, tv->sig_len, tv->msg,
				     tv->msg_len - 1, tv->pk, tv->pk_len));
	/*
	 * Can't simply try (tv->msg, tv->msg_len + 1) too, as tv->msg would be
	 * accessed out of bounds.  However, ML-DSA just hashes the message and
	 * doesn't handle different message lengths differently anyway.
	 */

	/* Test the validity checks on the z vector. */
	test_mldsa_z_range(test, tv);

	/* Test the validity checks on the hint vector. */
	test_mldsa_bad_hints(test, tv);

	/* Test randomly mutating the inputs. */
	test_mldsa_mutation(test, tv);
}

/* KUnit test case: run test_mldsa() on mldsa44_testvector. */
static void test_mldsa44(struct kunit *test)
{
	test_mldsa(test, &mldsa44_testvector);
}

/* KUnit test case: run test_mldsa() on mldsa65_testvector. */
static void test_mldsa65(struct kunit *test)
{
	test_mldsa(test, &mldsa65_testvector);
}

/* KUnit test case: run test_mldsa() on mldsa87_testvector. */
static void test_mldsa87(struct kunit *test)
{
	test_mldsa(test, &mldsa87_testvector);
}

/*
 * Return a mod m as a value in [0, m), for positive m.  This differs from C's
 * % operator, whose result is negative or zero when a is negative.
 */
static s32 mod(s32 a, s32 m)
{
	const s32 rem = a % m;

	return rem < 0 ? rem + m : rem;
}

/*
 * Return the representative of a mod m that is centered around zero: the
 * result of mod(a, m), shifted down by m when it is greater than m / 2.
 */
static s32 symmetric_mod(s32 a, s32 m)
{
	const s32 rem = mod(a, m);

	return rem > m / 2 ? rem - m : rem;
}

/*
 * Reference implementation of Decompose (FIPS 204 Algorithm 36), written for
 * obviousness rather than speed.
 *
 * Splits r mod Q into a low part *r0 and a high part *r1 such that
 * r mod Q == *r1 * (2 * gamma2) + *r0, except in the special case where the
 * difference between r mod Q and the low part is Q - 1: then the high part is
 * 0 and the low part is decremented by 1.  @r0 and @r1 must not alias.
 */
static void decompose_ref(s32 r, s32 gamma2, s32 *r0, s32 *r1)
{
	const s32 alpha = 2 * gamma2;
	const s32 rplus = mod(r, Q);
	s32 low = symmetric_mod(rplus, alpha);
	s32 high;

	if (rplus - low == Q - 1) {
		high = 0;
		low--;
	} else {
		high = (rplus - low) / alpha;
	}

	*r0 = low;
	*r1 = high;
}

/*
 * Reference implementation of UseHint (FIPS 204 Algorithm 40), written for
 * obviousness rather than speed.
 *
 * Returns the high part of r as computed by decompose_ref().  If the hint h is
 * 1, the high part is first moved one step up (when the low part is positive)
 * or one step down (otherwise), wrapping around modulo (Q - 1) / (2 * gamma2).
 */
static s32 use_hint_ref(u8 h, s32 r, s32 gamma2)
{
	const s32 m = (Q - 1) / (2 * gamma2);
	s32 r0, r1;

	decompose_ref(r, gamma2, &r0, &r1);
	if (h != 1)
		return r1;

	return mod(r0 > 0 ? r1 + 1 : r1 - 1, m);
}

/*
 * Test that for all possible inputs, mldsa_use_hint() gives the same output as
 * a mechanical translation of the pseudocode from FIPS 204.
 */
static void test_mldsa_use_hint(struct kunit *test)
{
	for (int i = 0; i < 2; i++) {
		const s32 gamma2 = (Q - 1) / (i == 0 ? 88 : 32);

		for (u8 h = 0; h < 2; h++) {
			for (s32 r = 0; r < Q; r++) {
				KUNIT_ASSERT_EQ(test,
						mldsa_use_hint(h, r, gamma2),
						use_hint_ref(h, r, gamma2));
			}
		}
	}
}

/*
 * Measure how many verifications of the given test vector's signature can be
 * done per second, and log the result.  The test case is skipped unless
 * CONFIG_CRYPTO_LIB_BENCHMARK is enabled.
 */
static void benchmark_mldsa(struct kunit *test,
			    const struct mldsa_testvector *tv)
{
	const int warmup_niter = 200;
	const int benchmark_niter = 200;
	u64 start_ns, elapsed_ns;
	int i;

	if (!IS_ENABLED(CONFIG_CRYPTO_LIB_BENCHMARK))
		kunit_skip(test, "not enabled");

	/* Untimed warm-up iterations */
	for (i = 0; i < warmup_niter; i++)
		do_mldsa_and_assert_success(test, tv);

	/* Timed iterations */
	start_ns = ktime_get_ns();
	for (i = 0; i < benchmark_niter; i++)
		do_mldsa_and_assert_success(test, tv);
	elapsed_ns = ktime_get_ns() - start_ns;
	if (!elapsed_ns)
		elapsed_ns = 1;	/* never divide by zero */

	kunit_info(test, "%llu ops/s",
		   div64_u64((u64)benchmark_niter * NSEC_PER_SEC, elapsed_ns));
}

/* KUnit test case: run benchmark_mldsa() on mldsa44_testvector. */
static void benchmark_mldsa44(struct kunit *test)
{
	benchmark_mldsa(test, &mldsa44_testvector);
}

/* KUnit test case: run benchmark_mldsa() on mldsa65_testvector. */
static void benchmark_mldsa65(struct kunit *test)
{
	benchmark_mldsa(test, &mldsa65_testvector);
}

/* KUnit test case: run benchmark_mldsa() on mldsa87_testvector. */
static void benchmark_mldsa87(struct kunit *test)
{
	benchmark_mldsa(test, &mldsa87_testvector);
}

/*
 * The test cases of the suite: one functional test per test vector, the
 * exhaustive mldsa_use_hint() comparison, and one benchmark per test vector.
 * The benchmark cases skip themselves unless CONFIG_CRYPTO_LIB_BENCHMARK is
 * enabled.  The empty entry terminates the array.
 */
static struct kunit_case mldsa_kunit_cases[] = {
	KUNIT_CASE(test_mldsa44),
	KUNIT_CASE(test_mldsa65),
	KUNIT_CASE(test_mldsa87),
	KUNIT_CASE(test_mldsa_use_hint),
	KUNIT_CASE(benchmark_mldsa44),
	KUNIT_CASE(benchmark_mldsa65),
	KUNIT_CASE(benchmark_mldsa87),
	{},
};

/* The "mldsa" KUnit suite, made up of the test cases listed above. */
static struct kunit_suite mldsa_kunit_suite = {
	.name = "mldsa",
	.test_cases = mldsa_kunit_cases,
};
kunit_test_suite(mldsa_kunit_suite);

MODULE_DESCRIPTION("KUnit tests and benchmark for ML-DSA");
/*
 * NOTE(review): presumably this namespace import is what gives the tests
 * access to mldsa_use_hint(), which is called directly by
 * test_mldsa_use_hint() -- confirm against the export in the ML-DSA library.
 */
MODULE_IMPORT_NS("EXPORTED_FOR_KUNIT_TESTING");
MODULE_LICENSE("GPL");