mirror of
https://github.com/saymrwulf/pqc-accelerate.git
synced 2026-05-14 20:48:07 +00:00
228 lines
7.1 KiB
C++
228 lines
7.1 KiB
C++
// polymult.cpp
|
||
#include "ntt.h"
|
||
|
||
|
||
// n^{-1} mod q, for n = 256, q = 8380417
|
||
// Computed so that (256 * DILITHIUM_INV_N) % DILITHIUM_Q == 1.
|
||
static const coeff_t DILITHIUM_INV_N = (coeff_t)8347681;
|
||
|
||
// ---------------------------------------------------------------------
|
||
// Modular arithmetic
|
||
// ---------------------------------------------------------------------
|
||
|
||
// Reduce a wide intermediate (sum or product) to the canonical range [0, q-1].
//
// Uses a plain wide modulo: simple and correct, though not the cheapest
// reduction in hardware. C++ '%' truncates toward zero, so a negative x leaves
// a remainder in (-q, 0]; a single +q folds it back into range.
coeff_t mod_q(wide_t x)
{
#pragma HLS inline
    wide_t rem = x % (wide_t)DILITHIUM_Q;
    if (rem >= 0)
        return (coeff_t)rem;
    return (coeff_t)(rem + DILITHIUM_Q);
}
|
||
|
||
// Modular addition (a + b) mod q.
// Expects both operands in [0, q-1]: the sum is then below 2q, so one
// conditional subtraction of q restores the canonical range.
coeff_t mod_add(coeff_t a, coeff_t b)
{
#pragma HLS inline
    wide_t sum = (wide_t)a + (wide_t)b;
    if (sum < DILITHIUM_Q)
        return (coeff_t)sum;
    return (coeff_t)(sum - DILITHIUM_Q);
}
|
||
|
||
// Modular subtraction (a - b) mod q.
// Expects both operands in [0, q-1]: the difference then lies in (-q, q),
// so one conditional addition of q restores the canonical range.
coeff_t mod_sub(coeff_t a, coeff_t b)
{
#pragma HLS inline
    wide_t diff = (wide_t)a - (wide_t)b;
    if (diff >= 0)
        return (coeff_t)diff;
    return (coeff_t)(diff + DILITHIUM_Q);
}
|
||
|
||
// One step of the butterfly: t = zeta * b, then (a+t, a-t)
|
||
static void butterfly(coeff_t &a, coeff_t &b, coeff_t zeta)
|
||
{
|
||
#pragma HLS inline
|
||
wide_t prod = (wide_t)zeta * (wide_t)b;
|
||
coeff_t t = mod_q(prod);
|
||
coeff_t a0 = a;
|
||
a = mod_add(a0, t);
|
||
b = mod_sub(a0, t);
|
||
}
|
||
|
||
// ---------------------------------------------------------------------
|
||
// Forward NTT in-place on a[0..255] for R_q = Z_q[x]/(x^256 - 1)
|
||
// Standard Cooley–Tukey style iteration over layers.
|
||
// The zetas[] table and its traversal pattern must match the Dilithium spec.
|
||
// ---------------------------------------------------------------------
|
||
void ntt(coeff_t a[DILITHIUM_N])
|
||
{
|
||
#pragma HLS inline off
|
||
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
|
||
|
||
unsigned k = 0;
|
||
// len goes 128,64,32,...,1
|
||
for (unsigned len = DILITHIUM_N >> 1; len >= 1; len >>= 1) {
|
||
#pragma HLS loop_tripcount min=1 max=8
|
||
for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) {
|
||
#pragma HLS loop_tripcount min=1 max=256
|
||
coeff_t zeta = dilithium_zetas[k++];
|
||
for (unsigned j = start; j < start + len; j++) {
|
||
#pragma HLS PIPELINE II=1
|
||
coeff_t &aj = a[j];
|
||
coeff_t &ajl = a[j + len];
|
||
butterfly(aj, ajl, zeta);
|
||
}
|
||
}
|
||
if (len == 1) break; // avoid unsigned wrap
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------
|
||
// Inverse NTT (Gentleman–Sande style) on a[0..255], followed by scaling
|
||
// by n^{-1}. Pattern mirrors standard Dilithium invNTT.
|
||
// ---------------------------------------------------------------------
|
||
void inv_ntt(coeff_t a[DILITHIUM_N])
|
||
{
|
||
#pragma HLS inline off
|
||
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
|
||
|
||
unsigned k = 0;
|
||
// len goes 1,2,4,...,128
|
||
for (unsigned len = 1; len < DILITHIUM_N; len <<= 1) {
|
||
#pragma HLS loop_tripcount min=1 max=8
|
||
for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) {
|
||
#pragma HLS loop_tripcount min=1 max=256
|
||
coeff_t zeta = dilithium_zetas_inv[k++];
|
||
for (unsigned j = start; j < start + len; j++) {
|
||
#pragma HLS PIPELINE II=1
|
||
coeff_t u = a[j];
|
||
coeff_t v = a[j + len];
|
||
|
||
coeff_t a_sum = mod_add(u, v);
|
||
coeff_t a_diff = mod_sub(u, v);
|
||
|
||
// Multiply the "difference" branch by zeta
|
||
wide_t prod = (wide_t)zeta * (wide_t)a_diff;
|
||
coeff_t t = mod_q(prod);
|
||
|
||
a[j] = a_sum;
|
||
a[j+len] = t;
|
||
}
|
||
}
|
||
}
|
||
|
||
// Final scaling by n^{-1} mod q
|
||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||
#pragma HLS PIPELINE II=1
|
||
wide_t prod = (wide_t)a[i] * (wide_t)DILITHIUM_INV_N;
|
||
a[i] = mod_q(prod);
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------
|
||
// Polynomial multiplication in R_q = Z_q[x]/(x^256 - 1):
|
||
// c = a * b (using NTT, pointwise multiply, INTT)
|
||
// ---------------------------------------------------------------------
|
||
static void poly_mult_internal(coeff_t a[DILITHIUM_N],
|
||
coeff_t b[DILITHIUM_N],
|
||
coeff_t c[DILITHIUM_N])
|
||
{
|
||
#pragma HLS inline off
|
||
#pragma HLS DATAFLOW
|
||
|
||
// In-place NTT
|
||
ntt(a);
|
||
ntt(b);
|
||
|
||
// Pointwise multiply
|
||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||
#pragma HLS PIPELINE II=1
|
||
wide_t prod = (wide_t)a[i] * (wide_t)b[i];
|
||
c[i] = mod_q(prod);
|
||
}
|
||
|
||
// In-place inverse NTT on c
|
||
inv_ntt(c);
|
||
}
|
||
|
||
// ---------------------------------------------------------------------
|
||
// AXI4-Stream <-> internal coefficient buffer
|
||
// Protocol:
|
||
//
|
||
// Input stream (axis_in):
|
||
// - First 256 beats : a[0..255], last = 0
|
||
// - Next 256 beats : b[0..255], last = 1 on the final beat
|
||
//
|
||
// Output stream (axis_out):
|
||
// - 256 beats with c[0..255], last = 1 on the final beat
|
||
// ---------------------------------------------------------------------
|
||
static void axis_read_polys(hls::stream<coeff_axis_t> &axis_in,
|
||
coeff_t a[DILITHIUM_N],
|
||
coeff_t b[DILITHIUM_N])
|
||
{
|
||
#pragma HLS inline off
|
||
coeff_axis_t w;
|
||
|
||
// Read a
|
||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||
#pragma HLS PIPELINE II=1
|
||
w = axis_in.read();
|
||
a[i] = (coeff_t)w.data;
|
||
}
|
||
|
||
// Read b
|
||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||
#pragma HLS PIPELINE II=1
|
||
w = axis_in.read();
|
||
b[i] = (coeff_t)w.data;
|
||
}
|
||
}
|
||
|
||
// Streams c[0..DILITHIUM_N-1] out on axis_out, one 32-bit coefficient per beat.
// keep and strb are set to all ones (-1) on every beat; last is raised only on
// the final beat.
static void axis_write_poly(hls::stream<coeff_axis_t> &axis_out,
                            coeff_t c[DILITHIUM_N])
{
#pragma HLS inline off

    coeff_axis_t beat;
    beat.keep = -1;
    beat.strb = -1;

    for (unsigned idx = 0; idx < DILITHIUM_N; idx++) {
#pragma HLS PIPELINE II=1
        beat.data = ap_int<32>(c[idx]);
        beat.last = (idx + 1 == DILITHIUM_N) ? 1 : 0;
        axis_out.write(beat);
    }
}
|
||
|
||
// ---------------------------------------------------------------------
|
||
// Top-level HLS function: poly_mult_dilithium
|
||
//
|
||
// This is your second accelerator, to be instantiated alongside the
|
||
// Kyber NTT IP in the same overlay. It has its own AXI4-Stream ports,
|
||
// but you can reuse the DMA driver pattern from Kyber by adjusting
|
||
// the packing (no 2x16-bit packing here, just one 32-bit coeff per beat).
|
||
// ---------------------------------------------------------------------
|
||
int poly_mult(hls::stream<coeff_axis_t> &axis_in,
|
||
hls::stream<coeff_axis_t> &axis_out)
|
||
{
|
||
#pragma HLS INTERFACE axis register port=axis_in
|
||
#pragma HLS INTERFACE axis register port=axis_out
|
||
#pragma HLS INTERFACE s_axilite port=return bundle=CTRL_BUS
|
||
#pragma HLS DATAFLOW
|
||
|
||
coeff_t a[DILITHIUM_N];
|
||
coeff_t b[DILITHIUM_N];
|
||
coeff_t c[DILITHIUM_N];
|
||
|
||
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
|
||
#pragma HLS ARRAY_PARTITION variable=b cyclic factor=4 dim=1
|
||
#pragma HLS ARRAY_PARTITION variable=c cyclic factor=4 dim=1
|
||
|
||
axis_read_polys(axis_in, a, b);
|
||
poly_mult_internal(a, b, c);
|
||
axis_write_poly(axis_out, c);
|
||
|
||
return 0;
|
||
}
|