summaryrefslogtreecommitdiff
path: root/lib/crypto/mldsa.c
blob: c96fddc4e7dcf94417008f28e0c6dbf87b3746e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Support for verifying ML-DSA signatures
 *
 * Copyright 2025 Google LLC
 */

#include <crypto/mldsa.h>
#include <crypto/sha3.h>
#include <kunit/visibility.h>
#include <linux/array_size.h>
#include <linux/export.h>
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/unaligned.h>
#include "fips-mldsa.h"

/* Constants common to every ML-DSA parameter set.  Reference: FIPS 204. */
enum {
	Q = 8380417,			/* the prime modulus q = 2^23 - 2^13 + 1 */
	QINV_MOD_2_32 = 58728449,	/* q^-1 mod 2^32, for Montgomery reduction */
	N = 256,			/* integers per ring element */
	D = 13,				/* bits dropped from the public key vector t */
	RHO_LEN = 32,			/* bytes in the public random seed rho */
	MAX_W1_ENCODED_LEN = 192,	/* max bytes in one encoded element of w'_1 */
};

/*
 * The zetas array in Montgomery form, i.e. with extra factor of 2^32.
 * Reference: FIPS 204 Section 7.5 "NTT and NTT^-1"
 *
 * Entry i is 1753^brv(i) * 2^32 mod q, where brv() reverses the 8 bits of i
 * and 1753 is the root of unity zeta specified by FIPS 204.  Each value is
 * stored as its representative in [-(q - 1) / 2, (q - 1) / 2].  Because the
 * entries carry the factor 2^32, Zq_mult(zeta, x) yields zeta * x mod q with
 * no leftover Montgomery factor.
 *
 * Entry 0 is never read: ntt() consumes indices 1..255 in increasing order and
 * invntt_and_mul_2_32() consumes them in decreasing order.
 *
 * Generated by the following Python code:
 * q=8380417; [a%q - q*(a%q > q//2) for a in [1753**(int(f'{i:08b}'[::-1], 2)) << 32 for i in range(256)]]
 */
static const s32 zetas_times_2_32[N] = {
	-4186625, 25847,    -2608894, -518909,	237124,	  -777960,  -876248,
	466468,	  1826347,  2353451,  -359251,	-2091905, 3119733,  -2884855,
	3111497,  2680103,  2725464,  1024112,	-1079900, 3585928,  -549488,
	-1119584, 2619752,  -2108549, -2118186, -3859737, -1399561, -3277672,
	1757237,  -19422,   4010497,  280005,	2706023,  95776,    3077325,
	3530437,  -1661693, -3592148, -2537516, 3915439,  -3861115, -3043716,
	3574422,  -2867647, 3539968,  -300467,	2348700,  -539299,  -1699267,
	-1643818, 3505694,  -3821735, 3507263,	-2140649, -1600420, 3699596,
	811944,	  531354,   954230,   3881043,	3900724,  -2556880, 2071892,
	-2797779, -3930395, -1528703, -3677745, -3041255, -1452451, 3475950,
	2176455,  -1585221, -1257611, 1939314,	-4083598, -1000202, -3190144,
	-3157330, -3632928, 126922,   3412210,	-983419,  2147896,  2715295,
	-2967645, -3693493, -411027,  -2477047, -671102,  -1228525, -22981,
	-1308169, -381987,  1349076,  1852771,	-1430430, -3343383, 264944,
	508951,	  3097992,  44288,    -1100098, 904516,	  3958618,  -3724342,
	-8578,	  1653064,  -3249728, 2389356,	-210977,  759969,   -1316856,
	189548,	  -3553272, 3159746,  -1851402, -2409325, -177440,  1315589,
	1341330,  1285669,  -1584928, -812732,	-1439742, -3019102, -3881060,
	-3628969, 3839961,  2091667,  3407706,	2316500,  3817976,  -3342478,
	2244091,  -2446433, -3562462, 266997,	2434439,  -1235728, 3513181,
	-3520352, -3759364, -1197226, -3193378, 900702,	  1859098,  909542,
	819034,	  495491,   -1613174, -43260,	-522500,  -655327,  -3122442,
	2031748,  3207046,  -3556995, -525098,	-768622,  -3595838, 342297,
	286988,	  -2437823, 4108315,  3437287,	-3342277, 1735879,  203044,
	2842341,  2691481,  -2590150, 1265009,	4055324,  1247620,  2486353,
	1595974,  -3767016, 1250494,  2635921,	-3548272, -2994039, 1869119,
	1903435,  -1050970, -1333058, 1237275,	-3318210, -1430225, -451100,
	1312455,  3306115,  -1962642, -1279661, 1917081,  -2546312, -1374803,
	1500165,  777191,   2235880,  3406031,	-542412,  -2831860, -1671176,
	-1846953, -2584293, -3724270, 594136,	-3776993, -2013608, 2432395,
	2454455,  -164721,  1957272,  3369112,	185531,	  -1207385, -3183426,
	162844,	  1616392,  3014001,  810149,	1652634,  -3694233, -1799107,
	-3038916, 3523897,  3866901,  269760,	2213111,  -975884,  1717735,
	472078,	  -426683,  1723600,  -1803090, 1910376,  -1667432, -1104333,
	-260646,  -3833893, -2939036, -2235985, -420899,  -2286327, 183443,
	-976891,  1612842,  -3545687, -554416,	3919660,  -48306,   -1362209,
	3937738,  1400424,  -846154,  1976782
};

/*
 * The three ML-DSA parameter sets, indexed by enum mldsa_alg.
 * Reference: FIPS 204 Section 4 "Parameter Sets".
 *
 * Derived quantities are spelled out as the expressions they come from
 * (ctilde_len = lambda / 4, beta = tau * eta) so each row can be checked
 * against the standard's table at a glance.
 */
static const struct mldsa_parameter_set {
	u8 k;		/* rows of A; also the number of elements in t_1 and h */
	u8 l;		/* columns of A; also the number of elements in z */
	u8 ctilde_len;	/* bytes in the commitment hash ctilde: lambda / 4 */
	u8 omega;	/* maximum number of 1's in the hint vector h */
	u8 tau;		/* number of +-1 coefficients in the challenge c */
	u8 beta;	/* tau * eta */
	u16 pk_len;	/* public key length in bytes */
	u16 sig_len;	/* signature length in bytes */
	s32 gamma1;	/* coefficient range of y */
} mldsa_parameter_sets[] = {
	[MLDSA44] = {
		.k = 4,
		.l = 4,
		.ctilde_len = 128 / 4,	/* lambda = 128 */
		.omega = 80,
		.tau = 39,
		.beta = 39 * 2,		/* eta = 2 */
		.pk_len = MLDSA44_PUBLIC_KEY_SIZE,
		.sig_len = MLDSA44_SIGNATURE_SIZE,
		.gamma1 = 1 << 17,
	},
	[MLDSA65] = {
		.k = 6,
		.l = 5,
		.ctilde_len = 192 / 4,	/* lambda = 192 */
		.omega = 55,
		.tau = 49,
		.beta = 49 * 4,		/* eta = 4 */
		.pk_len = MLDSA65_PUBLIC_KEY_SIZE,
		.sig_len = MLDSA65_SIGNATURE_SIZE,
		.gamma1 = 1 << 19,
	},
	[MLDSA87] = {
		.k = 8,
		.l = 7,
		.ctilde_len = 256 / 4,	/* lambda = 256 */
		.omega = 75,
		.tau = 60,
		.beta = 60 * 2,		/* eta = 2 */
		.pk_len = MLDSA87_PUBLIC_KEY_SIZE,
		.sig_len = MLDSA87_SIGNATURE_SIZE,
		.gamma1 = 1 << 19,
	},
};

/*
 * An element of the ring R_q (normal form) or the ring T_q (NTT form).  It
 * consists of N integers mod q: either the polynomial coefficients of the R_q
 * element or the components of the T_q element.  In either case, whether they
 * are fully reduced to [0, q - 1] varies in the different parts of the code.
 * For example, ntt() may leave values with absolute value up to 9*q, while
 * invntt_and_mul_2_32() reduces every coefficient to [0, q - 1].
 */
struct mldsa_ring_elem {
	s32 x[N];
};

/*
 * Scratch memory for one call to mldsa_verify().  mldsa_verify() allocates
 * sizeof(struct mldsa_verification_workspace), plus l ring elements for the
 * flexible array z[], plus k * N bytes for the hint vector h.  The hint vector
 * is stored immediately after z[l - 1], i.e. starting at &z[l].
 */
struct mldsa_verification_workspace {
	/* SHAKE context for computing c, mu, and ctildeprime */
	struct shake_ctx shake;
	/*
	 * The fields in this union are used in their order of declaration.
	 * At any time only one of them holds live data, so they can share
	 * storage (block and w1_encoded alternate once per row of A).
	 */
	union {
		/* The hash of the public key */
		u8 tr[64];
		/* The message representative mu */
		u8 mu[64];
		/* Temporary space for rej_ntt_poly() */
		u8 block[SHAKE128_BLOCK_SIZE + 1];
		/* Encoded element of w'_1 */
		u8 w1_encoded[MAX_W1_ENCODED_LEN];
		/* The commitment hash.  Real length is params->ctilde_len */
		u8 ctildeprime[64];
	};
	/*
	 * SHAKE context for generating elements of the matrix A.  Kept
	 * separate from 'shake' because 'shake' is still absorbing w'_1 while
	 * A is being expanded.
	 */
	struct shake_ctx a_shake;
	/*
	 * An element of the matrix A generated from the public seed, or an
	 * element of the vector t_1 decoded from the public key and pre-scaled
	 * by 2^d.  Both are in NTT form.  To reduce memory usage, we generate
	 * or decode these elements only as needed.
	 */
	union {
		struct mldsa_ring_elem a;
		struct mldsa_ring_elem t1_scaled;
	};
	/* The challenge c, generated from ctilde */
	struct mldsa_ring_elem c;
	/* A temporary element used during calculations */
	struct mldsa_ring_elem tmp;

	/* The following fields are variable-length: */

	/* The signer's response vector */
	struct mldsa_ring_elem z[/* l */];

	/* The signer's hint vector */
	/* u8 h[k * N]; */
};

/*
 * Return a * b * 2^-32 mod q, as a representative in [-q + 1, q - 1].
 * The exact product a * b must lie in [-2^31 * q, 2^31 * q - 1].
 *
 * This is Montgomery reduction with R = 2^32.  The resulting factor of 2^-32
 * must be cancelled by a factor of 2^32 supplied somewhere else by the caller.
 *
 * The variant used here is "signed": the Montgomery quotient
 * m = (a * b) * q^-1 mod 2^32 is taken as its s32 representative, not its u32
 * one, so it sign-extends in the 64-bit product m * q.  That keeps both the
 * input and output ranges almost symmetric about zero.
 *
 * Reference: FIPS 204 Appendix A "Montgomery Multiplication".  That text has
 * an off-by-one error at the top of the input range, does not say the signed
 * variant is required, and quotes a needlessly wide output range.  The
 * Dilithium reference code matches this function (its comments share the same
 * benign off-by-one).
 */
static inline s32 Zq_mult(s32 a, s32 b)
{
	const s64 prod = (s64)a * b;
	/*
	 * Low 32 bits of prod times q^-1.  The multiplication is done in u32
	 * so that wraparound is well-defined, then reinterpreted as s32.
	 */
	const s32 m = (u32)prod * QINV_MOD_2_32;

	/*
	 * m * q == prod (mod 2^32), so prod - m * q has its low 32 bits clear
	 * and the shift divides by 2^32 exactly.
	 */
	return (prod - (s64)m * Q) >> 32;
}

/*
 * Convert @w to its number-theoretically-transformed representation in-place.
 * Reference: FIPS 204 Algorithm 41, NTT
 *
 * To prevent intermediate overflows, all input coefficients must have absolute
 * value < q.  All output components have absolute value < 9*q.
 */
static void ntt(struct mldsa_ring_elem *w)
{
	int m = 0; /* index in zetas_times_2_32 */

	for (int len = 128; len >= 1; len /= 2) {
		for (int start = 0; start < 256; start += 2 * len) {
			const s32 z = zetas_times_2_32[++m];

			for (int j = start; j < start + len; j++) {
				s32 t = Zq_mult(z, w->x[j + len]);

				w->x[j + len] = w->x[j] - t;
				w->x[j] += t;
			}
		}
	}
}

/*
 * Convert @w from its number-theoretically-transformed representation in-place.
 * Reference: FIPS 204 Algorithm 42, NTT^-1
 *
 * This also multiplies the coefficients by 2^32, undoing an extra factor of
 * 2^-32 introduced earlier, and reduces the coefficients to [0, q - 1].
 */
static void invntt_and_mul_2_32(struct mldsa_ring_elem *w)
{
	int m = 256; /* index in zetas_times_2_32 */

	/* Prevent intermediate overflows. */
	for (int j = 0; j < 256; j++)
		w->x[j] %= Q;

	for (int len = 1; len < 256; len *= 2) {
		for (int start = 0; start < 256; start += 2 * len) {
			const s32 z = -zetas_times_2_32[--m];

			for (int j = start; j < start + len; j++) {
				s32 t = w->x[j];

				w->x[j] = t + w->x[j + len];
				w->x[j + len] = Zq_mult(z, t - w->x[j + len]);
			}
		}
	}
	/*
	 * Multiply by 2^32 * 256^-1.  2^32 cancels the factor of 2^-32 from
	 * earlier Montgomery multiplications.  256^-1 is for NTT^-1.  This
	 * itself uses Montgomery multiplication, so *another* 2^32 is needed.
	 * Thus the actual multiplicand is 2^32 * 2^32 * 256^-1 mod q = 41978.
	 *
	 * Finally, also reduce from [-q + 1, q - 1] to [0, q - 1].
	 */
	for (int j = 0; j < 256; j++) {
		w->x[j] = Zq_mult(w->x[j], 41978);
		w->x[j] += (w->x[j] >> 31) & Q;
	}
}

/*
 * Decode one ring element of t_1, the high bits of t = A*s_1 + s_2, from the
 * public key.  Reference: FIPS 204 Algorithm 23, pkDecode.
 *
 * The element is stored as 256 10-bit little-endian coefficients, 4 per 5
 * bytes.  Each coefficient is scaled by 2^d, which keeps it <= q - 1.  The
 * element is then converted to NTT form in place.
 *
 * Return: pointer to the first byte after the consumed input (N * 10 / 8 bytes).
 */
static const u8 *decode_t1_elem(struct mldsa_ring_elem *out,
				const u8 *t1_encoded)
{
	const u8 *p = t1_encoded;

	/* The largest scaled coefficient must still be < q. */
	static_assert(0x3ff << D < Q);

	for (int j = 0; j < N; j += 4) {
		const u32 low32 = get_unaligned_le32(p);
		const u32 top8 = p[4];

		out->x[j + 0] = (low32 & 0x3ff) << D;
		out->x[j + 1] = ((low32 >> 10) & 0x3ff) << D;
		out->x[j + 2] = ((low32 >> 20) & 0x3ff) << D;
		out->x[j + 3] = ((low32 >> 30) | (top8 << 2)) << D;
		p += 5;
	}
	ntt(out);
	return p;
}

/*
 * Decode the signer's response vector z from *@sig_ptr, check its infinity
 * norm, and convert it to NTT form.
 * Reference: FIPS 204 Algorithm 27, sigDecode.
 *
 * Each coefficient is stored as gamma1 - z.  ML-DSA-44 (l == 4) uses 18 bits
 * per coefficient, packed 4 per 9 bytes.  The other parameter sets use 20
 * bits, packed 4 per 10 bytes.  The range check corresponds to the infinity
 * norm check ||z||_infinity < gamma1 - beta at the end of Algorithm 8,
 * ML-DSA.Verify_internal.  Doing it here allows an early exit.
 *
 * Return: true on success, with *@sig_ptr advanced past z.  Returns false if
 * any coefficient is out of range; *@sig_ptr is then left unchanged.
 */
static bool decode_z(struct mldsa_ring_elem z[/* l */], int l, s32 gamma1,
		     int beta, const u8 **sig_ptr)
{
	const s32 bound = gamma1 - beta;
	const u8 *p = *sig_ptr;

	for (int i = 0; i < l; i++) {
		s32 *x = z[i].x;

		if (l == 4) {
			/* ML-DSA-44: 18-bit coefficients, 4 per 9 bytes. */
			for (int j = 0; j < N; j += 4, p += 9) {
				const u64 bits = get_unaligned_le64(p);

				x[j + 0] = bits & 0x3ffff;
				x[j + 1] = (bits >> 18) & 0x3ffff;
				x[j + 2] = (bits >> 36) & 0x3ffff;
				x[j + 3] = (bits >> 54) | ((u32)p[8] << 10);
			}
		} else {
			/* ML-DSA-65/87: 20-bit coefficients, 4 per 10 bytes. */
			for (int j = 0; j < N; j += 4, p += 10) {
				const u64 bits = get_unaligned_le64(p);

				x[j + 0] = bits & 0xfffff;
				x[j + 1] = (bits >> 20) & 0xfffff;
				x[j + 2] = (bits >> 40) & 0xfffff;
				x[j + 3] = (bits >> 60) |
					   ((u32)get_unaligned_le16(&p[8]) << 4);
			}
		}

		/* Undo the gamma1 offset and enforce |z| < gamma1 - beta. */
		for (int j = 0; j < N; j++) {
			x[j] = gamma1 - x[j];
			if (x[j] <= -bound || x[j] >= bound)
				return false;
		}
		ntt(&z[i]);
	}
	*sig_ptr = p;
	return true;
}

/*
 * Decode the signer's hint vector h into @h: k rows of N bytes, each 0 or 1.
 * Reference: FIPS 204 Algorithm 21, HintBitUnpack
 *
 * @y holds omega + k bytes.  The first @omega bytes list the positions of the
 * 1's, row after row.  Byte omega + i is the cumulative number of 1's in rows
 * 0 through i.  The function returns false for a malformed encoding:
 *  - a cumulative count that decreases or exceeds @omega;
 *  - positions within a row that are not strictly increasing;
 *  - a nonzero byte among the unused position slots.
 */
static bool decode_hint_vector(u8 h[/* k * N */], int k, int omega, const u8 *y)
{
	const u8 *cumulative = &y[omega];
	int pos = 0;

	memset(h, 0, k * N);
	for (int i = 0; i < k; i++) {
		const int end = cumulative[i];
		const int first = pos;
		u8 *row = &h[i * N];

		if (end < pos || end > omega)
			return false;
		for (; pos < end; pos++) {
			/* Positions in a row must be strictly increasing. */
			if (pos > first && y[pos - 1] >= y[pos])
				return false;
			row[y[pos]] = 1;
		}
	}
	/* Unused position slots must be zero. */
	return mem_is_zero(&y[pos], omega - pos);
}

/*
 * Derive from @seed the challenge polynomial @c.  Exactly @tau of its
 * coefficients are +1 or -1 and the rest are 0.  @shake is scratch space for
 * the SHAKE256 instance.
 * Reference: FIPS 204 Algorithm 29, SampleInBall
 */
static void sample_in_ball(struct mldsa_ring_elem *c, const u8 *seed,
			   size_t seed_len, int tau, struct shake_ctx *shake)
{
	u64 sign_bits;

	*c = (struct mldsa_ring_elem){};

	/* The first 8 output bytes give one sign bit per nonzero coefficient. */
	shake256_init(shake);
	shake_update(shake, seed, seed_len);
	shake_squeeze(shake, (u8 *)&sign_bits, sizeof(sign_bits));
	le64_to_cpus(&sign_bits);

	for (int pos = N - tau; pos < N; pos++) {
		u8 swap;

		/* Rejection-sample a byte in [0, pos]. */
		do {
			shake_squeeze(shake, &swap, 1);
		} while (swap > pos);

		/* Move c[swap] up to c[pos], then put the new +-1 at c[swap]. */
		c->x[pos] = c->x[swap];
		c->x[swap] = (sign_bits & 1) ? -1 : 1;
		sign_bits >>= 1;
	}
}

/*
 * Expand the public seed @rho and the two index bytes @row_and_column into an
 * element @out of T_q, sampled uniformly by rejection.
 * Reference: FIPS 204 Algorithm 30, RejNTTPoly
 *
 * @shake and @block are caller-provided scratch space.  @block holds one
 * SHAKE128 block plus one extra byte.  The extra byte lets the final 3-byte
 * candidate be loaded with a 4-byte read without going out of bounds.
 */
static void rej_ntt_poly(struct mldsa_ring_elem *out, const u8 rho[RHO_LEN],
			 __le16 row_and_column, struct shake_ctx *shake,
			 u8 block[SHAKE128_BLOCK_SIZE + 1])
{
	int count = 0;

	/* A block must split evenly into 3-byte candidates. */
	static_assert(SHAKE128_BLOCK_SIZE % 3 == 0);

	shake128_init(shake);
	shake_update(shake, rho, RHO_LEN);
	shake_update(shake, (u8 *)&row_and_column, sizeof(row_and_column));

	while (count < N) {
		const u8 *p;

		shake_squeeze(shake, block, SHAKE128_BLOCK_SIZE);
		/*
		 * Read by the last 4-byte load but masked off below.  It is
		 * initialized only so KMSAN does not report it.
		 */
		block[SHAKE128_BLOCK_SIZE] = 0;

		for (p = block; p < &block[SHAKE128_BLOCK_SIZE] && count < N;
		     p += 3) {
			/* CoeffFromThreeBytes: low 23 bits, little-endian. */
			const u32 cand = get_unaligned_le32(p) & 0x7fffff;

			if (cand < Q) /* values >= q are rejected */
				out->x[count++] = cand;
		}
	}
}

/*
 * Return HighBits(r), adjusted according to the hint bit h.
 * Reference: FIPS 204 Algorithm 40, UseHint
 *
 * The hint compensates for the bits of t dropped from the public key.
 * h is 0 or 1 and r is in [0, q - 1].  gamma2 is (q - 1) / 88 or
 * (q - 1) / 32.  gamma2 is a compile-time constant everywhere except the
 * unit-test entry point, so the divisions and modulo below normally compile to
 * multiply/shift sequences.
 */
static __always_inline s32 use_hint(u8 h, s32 r, const s32 gamma2)
{
	const s32 alpha = 2 * gamma2;
	const s32 m = (Q - 1) / alpha; /* 44 or 16: the number of HighBits values */
	s32 r1;

	/*
	 * Top of the range: r - (r mod+- alpha) == q - 1.  Decompose() then
	 * sets r1 = 0 and makes the low part nonpositive, so a hint of 1 wraps
	 * r1 down to m - 1.  The general formula below would wrongly give m.
	 */
	if (r >= Q - gamma2)
		return h ? m - 1 : 0;

	/* r1 = (r - (r mod+- alpha)) / alpha = floor((r + gamma2 - 1) / alpha) */
	r1 = (u32)(r + gamma2 - 1) / alpha;

	if (!h)
		return r1;

	/*
	 * The low part is r - r1 * alpha.  Step r1 up (mod m) if it is
	 * positive, otherwise step it down (mod m).  r1 stays in [0, m - 1].
	 */
	if (r > r1 * alpha)
		return (u32)(r1 + 1) % m;
	return (u32)(r1 + m - 1) % m;
}

/*
 * Replace each coefficient of @w with use_hint() of that coefficient and the
 * matching hint bit in @h.
 */
static __always_inline void use_hint_elem(struct mldsa_ring_elem *w,
					  const u8 h[N], const s32 gamma2)
{
	s32 *coeffs = w->x;

	for (int n = 0; n < N; n++)
		coeffs[n] = use_hint(h[n], coeffs[n], gamma2);
}

#if IS_ENABLED(CONFIG_CRYPTO_LIB_MLDSA_KUNIT_TEST)
/*
 * Allow the __always_inline function use_hint() to be unit-tested.
 *
 * This out-of-line, KUnit-only export forwards its arguments unchanged.  Here
 * gamma2 is a runtime value rather than a compile-time constant, so this
 * exercises the generic (non-constant-folded) form of use_hint().
 */
s32 mldsa_use_hint(u8 h, s32 r, s32 gamma2)
{
	return use_hint(h, r, gamma2);
}
EXPORT_SYMBOL_IF_KUNIT(mldsa_use_hint);
#endif

/*
 * Serialize one element of the commitment vector w'_1 into @out.
 * Reference: FIPS 204 Algorithm 28, w1Encode.
 *
 * ML-DSA-44 (k == 4) uses 6 bits per coefficient.  ML-DSA-65/87 use 4 bits.
 * Coefficients are packed little-endian.
 *
 * Return: the number of bytes written, 192 for ML-DSA-44 and 128 otherwise.
 */
static size_t encode_w1(u8 out[MAX_W1_ENCODED_LEN],
			const struct mldsa_ring_elem *w1, int k)
{
	const s32 *x = w1->x;
	u8 *p = out;

	static_assert(N * 6 / 8 == MAX_W1_ENCODED_LEN);

	if (k != 4) {
		/* 4-bit coefficients, two per byte. */
		for (int j = 0; j < N; j += 2)
			*p++ = x[j] | (x[j + 1] << 4);
		return p - out;
	}

	/* 6-bit coefficients, four per three bytes. */
	for (int j = 0; j < N; j += 4) {
		const u32 bits = x[j] | (x[j + 1] << 6) | (x[j + 2] << 12) |
				 (x[j + 3] << 18);

		*p++ = bits;
		*p++ = bits >> 8;
		*p++ = bits >> 16;
	}
	return p - out;
}

/*
 * Verify a pure ML-DSA signature with an empty context string.
 * Reference: FIPS 204 Algorithm 3, ML-DSA.Verify, and Algorithm 8,
 * ML-DSA.Verify_internal.
 *
 * @alg: the parameter set (an index into mldsa_parameter_sets[])
 * @sig, @sig_len: the signature; @sig_len must equal the parameter set's length
 * @msg, @msg_len: the signed message
 * @pk, @pk_len: the public key; @pk_len must equal the parameter set's length
 *
 * Return: 0 if the signature is valid.  -EINVAL if @alg is not a known
 * parameter set.  -EBADMSG if a length is wrong, or if the signature is
 * malformed or has out-of-range z coefficients.  -EKEYREJECTED if the
 * commitment hash does not match.  -ENOMEM if the workspace cannot be
 * allocated.
 */
int mldsa_verify(enum mldsa_alg alg, const u8 *sig, size_t sig_len,
		 const u8 *msg, size_t msg_len, const u8 *pk, size_t pk_len)
{
	/* For now this just does pure ML-DSA with an empty context string. */
	static const u8 msg_prefix[2] = { /* dom_sep= */ 0, /* ctx_len= */ 0 };
	const struct mldsa_parameter_set *params;
	int k, l;
	const u8 *ctilde; /* The signer's commitment hash */
	const u8 *t1_encoded = &pk[RHO_LEN]; /* Next encoded element of t_1 */
	u8 *h; /* The signer's hint vector, length k * N */
	size_t w1_enc_len;

	/*
	 * Reject unknown parameter sets before indexing the table with @alg.
	 * The unsigned cast makes negative values compare as out of range too.
	 */
	if ((unsigned int)alg >= ARRAY_SIZE(mldsa_parameter_sets))
		return -EINVAL;
	params = &mldsa_parameter_sets[alg];
	k = params->k;
	l = params->l;

	/* Validate the public key and signature lengths. */
	if (pk_len != params->pk_len || sig_len != params->sig_len)
		return -EBADMSG;

	/*
	 * Allocate the workspace, including variable-length fields.  Its size
	 * depends only on the ML-DSA parameter set, not the other inputs.
	 *
	 * For freeing it, use kfree_sensitive() rather than kfree().  This is
	 * mainly to comply with FIPS 204 Section 3.6.3 "Intermediate Values".
	 * In reality it's a bit gratuitous, as this is a public key operation.
	 */
	struct mldsa_verification_workspace *ws __free(kfree_sensitive) =
		kmalloc(sizeof(*ws) + (l * sizeof(ws->z[0])) + (k * N),
			GFP_KERNEL);
	if (!ws)
		return -ENOMEM;
	/* The hint vector occupies the k * N bytes following z[0..l-1]. */
	h = (u8 *)&ws->z[l];

	/* Decode the signature.  Reference: FIPS 204 Algorithm 27, sigDecode */
	ctilde = sig;
	sig += params->ctilde_len;
	if (!decode_z(ws->z, l, params->gamma1, params->beta, &sig))
		return -EBADMSG;
	if (!decode_hint_vector(h, k, params->omega, sig))
		return -EBADMSG;

	/* Recreate the challenge c from the signer's commitment hash. */
	sample_in_ball(&ws->c, ctilde, params->ctilde_len, params->tau,
		       &ws->shake);
	ntt(&ws->c);

	/*
	 * Compute the message representative mu.  tr and mu share storage;
	 * tr is fully absorbed before mu is squeezed out.
	 */
	shake256(pk, pk_len, ws->tr, sizeof(ws->tr));
	shake256_init(&ws->shake);
	shake_update(&ws->shake, ws->tr, sizeof(ws->tr));
	shake_update(&ws->shake, msg_prefix, sizeof(msg_prefix));
	shake_update(&ws->shake, msg, msg_len);
	shake_squeeze(&ws->shake, ws->mu, sizeof(ws->mu));

	/* Start computing ctildeprime = H(mu || w1Encode(w'_1)). */
	shake256_init(&ws->shake);
	shake_update(&ws->shake, ws->mu, sizeof(ws->mu));

	/*
	 * Compute the commitment w'_1 from A, z, c, t_1, and h.
	 *
	 * The computation is the same for each of the k rows.  Just do each row
	 * before moving on to the next, resulting in only one loop over k.
	 */
	for (int i = 0; i < k; i++) {
		/*
		 * tmp = NTT(A) * NTT(z) * 2^-32
		 * To reduce memory use, generate each element of NTT(A)
		 * on-demand.  Note that each element is used only once.
		 */
		ws->tmp = (struct mldsa_ring_elem){};
		for (int j = 0; j < l; j++) {
			rej_ntt_poly(&ws->a, pk /* rho is first field of pk */,
				     cpu_to_le16((i << 8) | j), &ws->a_shake,
				     ws->block);
			for (int n = 0; n < N; n++)
				ws->tmp.x[n] +=
					Zq_mult(ws->a.x[n], ws->z[j].x[n]);
		}
		/* All components of tmp now have abs value < l*q. */

		/* Decode the next element of t_1. */
		t1_encoded = decode_t1_elem(&ws->t1_scaled, t1_encoded);

		/*
		 * tmp -= NTT(c) * NTT(t_1 * 2^d) * 2^-32
		 *
		 * Taking a conservative bound for the output of ntt(), the
		 * multiplicands can have absolute value up to 9*q.  That
		 * corresponds to a product with absolute value 81*q^2.  That is
		 * within the limits of Zq_mult() which needs < ~256*q^2.
		 */
		for (int j = 0; j < N; j++)
			ws->tmp.x[j] -= Zq_mult(ws->c.x[j], ws->t1_scaled.x[j]);
		/* All components of tmp now have abs value < (l+1)*q. */

		/* tmp = w'_Approx = NTT^-1(tmp) * 2^32 */
		invntt_and_mul_2_32(&ws->tmp);
		/* All coefficients of tmp are now in [0, q - 1]. */

		/*
		 * tmp = w'_1 = UseHint(h, w'_Approx)
		 * For efficiency, set gamma2 to a compile-time constant.
		 */
		if (k == 4)
			use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 88);
		else
			use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 32);

		/* Encode and hash the next element of w'_1. */
		w1_enc_len = encode_w1(ws->w1_encoded, &ws->tmp, k);
		shake_update(&ws->shake, ws->w1_encoded, w1_enc_len);
	}

	/* Finish computing ctildeprime. */
	shake_squeeze(&ws->shake, ws->ctildeprime, params->ctilde_len);

	/*
	 * Verify that ctilde == ctildeprime.  Both values are public, so a
	 * plain memcmp() is sufficient.
	 */
	if (memcmp(ws->ctildeprime, ctilde, params->ctilde_len) != 0)
		return -EKEYREJECTED;
	/* ||z||_infinity < gamma1 - beta was already checked in decode_z(). */
	return 0;
}
EXPORT_SYMBOL_GPL(mldsa_verify);

#ifdef CONFIG_CRYPTO_FIPS
/*
 * FIPS power-on self-test: when FIPS mode is enabled, verify a known-good
 * ML-DSA-65 signature and panic if verification fails.
 */
static int __init mldsa_mod_init(void)
{
	int err;

	if (!fips_enabled)
		return 0;

	/*
	 * As per the FIPS Implementation Guidance, testing any ML-DSA
	 * parameter set satisfies the test requirement for all of them, and
	 * only a positive test is required.
	 */
	err = mldsa_verify(MLDSA65, fips_test_mldsa65_signature,
			   sizeof(fips_test_mldsa65_signature),
			   fips_test_mldsa65_message,
			   sizeof(fips_test_mldsa65_message),
			   fips_test_mldsa65_public_key,
			   sizeof(fips_test_mldsa65_public_key));
	if (err)
		panic("mldsa: FIPS self-test failed; err=%pe\n", ERR_PTR(err));
	return 0;
}
subsys_initcall(mldsa_mod_init);

/* Nothing to tear down; present so the module can be unloaded. */
static void __exit mldsa_mod_exit(void)
{
}
module_exit(mldsa_mod_exit);
#endif /* CONFIG_CRYPTO_FIPS */

MODULE_DESCRIPTION("ML-DSA signature verification");
MODULE_LICENSE("GPL");