// polymult.cpp #include "ntt.h" // n^{-1} mod q, for n = 256, q = 8380417 // Computed so that (256 * DILITHIUM_INV_N) % DILITHIUM_Q == 1. static const coeff_t DILITHIUM_INV_N = (coeff_t)8347681; // --------------------------------------------------------------------- // Modular arithmetic // --------------------------------------------------------------------- coeff_t mod_q(wide_t x) { #pragma HLS inline // Reduce into [0, q-1] via 64-bit modulo. // This is not the most hardware-optimal, but it is simple and correct. wide_t r = x % (wide_t)DILITHIUM_Q; if (r < 0) r += DILITHIUM_Q; return (coeff_t)r; } coeff_t mod_add(coeff_t a, coeff_t b) { #pragma HLS inline wide_t s = (wide_t)a + (wide_t)b; if (s >= DILITHIUM_Q) s -= DILITHIUM_Q; return (coeff_t)s; } coeff_t mod_sub(coeff_t a, coeff_t b) { #pragma HLS inline wide_t d = (wide_t)a - (wide_t)b; if (d < 0) d += DILITHIUM_Q; return (coeff_t)d; } // One step of the butterfly: t = zeta * b, then (a+t, a-t) static void butterfly(coeff_t &a, coeff_t &b, coeff_t zeta) { #pragma HLS inline wide_t prod = (wide_t)zeta * (wide_t)b; coeff_t t = mod_q(prod); coeff_t a0 = a; a = mod_add(a0, t); b = mod_sub(a0, t); } // --------------------------------------------------------------------- // Forward NTT in-place on a[0..255] for R_q = Z_q[x]/(x^256 - 1) // Standard Cooley–Tukey style iteration over layers. // The zetas[] table and its traversal pattern must match the Dilithium spec. // --------------------------------------------------------------------- void ntt(coeff_t a[DILITHIUM_N]) { #pragma HLS inline off #pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1 unsigned k = 0; // len goes 128,64,32,...,1 for (unsigned len = DILITHIUM_N >> 1; len >= 1; len >>= 1) { #pragma HLS loop_tripcount min=1 max=8 for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) { #pragma HLS loop_tripcount min=1 max=256 coeff_t zeta = dilithium_zetas[k++]; for (unsigned j = start; j < start + len; j++) { #pragma HLS PIPELINE II=1 coeff_t &aj = a[j]; coeff_t &ajl = a[j + len]; butterfly(aj, ajl, zeta); } } if (len == 1) break; // avoid unsigned wrap } } // --------------------------------------------------------------------- // Inverse NTT (Gentleman–Sande style) on a[0..255], followed by scaling // by n^{-1}. Pattern mirrors standard Dilithium invNTT. // --------------------------------------------------------------------- void inv_ntt(coeff_t a[DILITHIUM_N]) { #pragma HLS inline off #pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1 unsigned k = 0; // len goes 1,2,4,...,128 for (unsigned len = 1; len < DILITHIUM_N; len <<= 1) { #pragma HLS loop_tripcount min=1 max=8 for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) { #pragma HLS loop_tripcount min=1 max=256 coeff_t zeta = dilithium_zetas_inv[k++]; for (unsigned j = start; j < start + len; j++) { #pragma HLS PIPELINE II=1 coeff_t u = a[j]; coeff_t v = a[j + len]; coeff_t a_sum = mod_add(u, v); coeff_t a_diff = mod_sub(u, v); // Multiply the "difference" branch by zeta wide_t prod = (wide_t)zeta * (wide_t)a_diff; coeff_t t = mod_q(prod); a[j] = a_sum; a[j+len] = t; } } } // Final scaling by n^{-1} mod q for (unsigned i = 0; i < DILITHIUM_N; i++) { #pragma HLS PIPELINE II=1 wide_t prod = (wide_t)a[i] * (wide_t)DILITHIUM_INV_N; a[i] = mod_q(prod); } } // --------------------------------------------------------------------- // Polynomial multiplication in R_q = Z_q[x]/(x^256 - 1): // c = a * b (using NTT, pointwise multiply, INTT) // --------------------------------------------------------------------- static void poly_mult_internal(coeff_t a[DILITHIUM_N], coeff_t b[DILITHIUM_N], coeff_t c[DILITHIUM_N]) { #pragma HLS inline off #pragma HLS DATAFLOW // In-place NTT ntt(a); ntt(b); // Pointwise multiply for (unsigned i = 0; i < DILITHIUM_N; i++) { #pragma HLS PIPELINE II=1 wide_t prod = (wide_t)a[i] * (wide_t)b[i]; c[i] = mod_q(prod); } // In-place inverse NTT on c inv_ntt(c); } // --------------------------------------------------------------------- // AXI4-Stream <-> internal coefficient buffer // Protocol: // // Input stream (axis_in): // - First 256 beats : a[0..255], last = 0 // - Next 256 beats : b[0..255], last = 1 on the final beat // // Output stream (axis_out): // - 256 beats with c[0..255], last = 1 on the final beat // --------------------------------------------------------------------- static void axis_read_polys(hls::stream &axis_in, coeff_t a[DILITHIUM_N], coeff_t b[DILITHIUM_N]) { #pragma HLS inline off coeff_axis_t w; // Read a for (unsigned i = 0; i < DILITHIUM_N; i++) { #pragma HLS PIPELINE II=1 w = axis_in.read(); a[i] = (coeff_t)w.data; } // Read b for (unsigned i = 0; i < DILITHIUM_N; i++) { #pragma HLS PIPELINE II=1 w = axis_in.read(); b[i] = (coeff_t)w.data; } } static void axis_write_poly(hls::stream &axis_out, coeff_t c[DILITHIUM_N]) { #pragma HLS inline off coeff_axis_t w; w.keep = -1; w.strb = -1; for (unsigned i = 0; i < DILITHIUM_N; i++) { #pragma HLS PIPELINE II=1 w.data = (ap_int<32>)c[i]; w.last = (i == DILITHIUM_N - 1) ? 1 : 0; axis_out.write(w); } } // --------------------------------------------------------------------- // Top-level HLS function: poly_mult_dilithium // // This is your second accelerator, to be instantiated alongside the // Kyber NTT IP in the same overlay. It has its own AXI4-Stream ports, // but you can reuse the DMA driver pattern from Kyber by adjusting // the packing (no 2x16-bit packing here, just one 32-bit coeff per beat). // --------------------------------------------------------------------- int poly_mult(hls::stream &axis_in, hls::stream &axis_out) { #pragma HLS INTERFACE axis register port=axis_in #pragma HLS INTERFACE axis register port=axis_out #pragma HLS INTERFACE s_axilite port=return bundle=CTRL_BUS #pragma HLS DATAFLOW coeff_t a[DILITHIUM_N]; coeff_t b[DILITHIUM_N]; coeff_t c[DILITHIUM_N]; #pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1 #pragma HLS ARRAY_PARTITION variable=b cyclic factor=4 dim=1 #pragma HLS ARRAY_PARTITION variable=c cyclic factor=4 dim=1 axis_read_polys(axis_in, a, b); poly_mult_internal(a, b, c); axis_write_poly(axis_out, c); return 0; }