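// pqc-accelerate/HLS_Codes_Dilithium/polymult.cpp
//
// NOTE: the text that originally stood here was repository-viewer page
// residue, not source code ("228 lines", "7.1 KiB", "C++",
// "Raw Blame History"), followed by a viewer warning:
//   "This file contains ambiguous Unicode characters. This file contains
//    Unicode characters that might be confused with other characters. If you
//    think that this is intentional, you can safely ignore this warning. Use
//    the Escape button to reveal them."
// It is preserved as a comment so that the file remains compilable.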

// polymult.cpp
#include "ntt.h"
// n^{-1} mod q, for n = 256, q = 8380417
// Computed so that (256 * DILITHIUM_INV_N) % DILITHIUM_Q == 1.
// Check: 256 * 8347681 = 2137006336 = 255 * 8380417 + 1.
// (Assumes DILITHIUM_Q, declared outside this file, equals 8380417.)
// Used by inv_ntt() for the final 1/n scaling pass.
static const coeff_t DILITHIUM_INV_N = (coeff_t)8347681;
// ---------------------------------------------------------------------
// Modular arithmetic
// ---------------------------------------------------------------------
// Reduces a wide (possibly negative) value to its canonical residue modulo
// DILITHIUM_Q, i.e. into the range [0, DILITHIUM_Q - 1].
// A full-width % is used: simple and correct, although not the cheapest
// reduction to build in hardware.
coeff_t mod_q(wide_t x)
{
#pragma HLS inline
    const wide_t modulus = static_cast<wide_t>(DILITHIUM_Q);
    wide_t residue = x % modulus;
    // % keeps the sign of the dividend, so a negative remainder is lifted
    // back into the non-negative range.
    if (residue < 0) {
        residue += DILITHIUM_Q;
    }
    return static_cast<coeff_t>(residue);
}
// Modular addition: returns a + b with at most one subtraction of
// DILITHIUM_Q applied. The result is canonical (in [0, DILITHIUM_Q - 1])
// when 0 <= a + b < 2 * DILITHIUM_Q, e.g. when both operands are already
// reduced; in every case it stays congruent to a + b modulo DILITHIUM_Q.
coeff_t mod_add(coeff_t a, coeff_t b)
{
#pragma HLS inline
    // Accumulate in the wide type so the sum cannot overflow coeff_t.
    wide_t total = static_cast<wide_t>(a);
    total += static_cast<wide_t>(b);
    if (total >= DILITHIUM_Q) {
        total -= DILITHIUM_Q;
    }
    return static_cast<coeff_t>(total);
}
// Modular subtraction: returns a - b with at most one addition of
// DILITHIUM_Q applied. The result is canonical (in [0, DILITHIUM_Q - 1])
// when -DILITHIUM_Q <= a - b < DILITHIUM_Q, e.g. when both operands are
// already reduced; in every case it stays congruent to a - b modulo
// DILITHIUM_Q.
coeff_t mod_sub(coeff_t a, coeff_t b)
{
#pragma HLS inline
    // Subtract in the wide type so a negative difference is representable.
    wide_t diff = static_cast<wide_t>(a);
    diff -= static_cast<wide_t>(b);
    if (diff < 0) {
        diff += DILITHIUM_Q;
    }
    return static_cast<coeff_t>(diff);
}
// Forward butterfly, in place on the pair (a, b):
//   t = zeta * b mod q;  a <- a + t;  b <- a_old - t.
// The original value of a is latched first because both outputs need it.
static void butterfly(coeff_t &a, coeff_t &b, coeff_t zeta)
{
#pragma HLS inline
    const coeff_t upper = a;
    const wide_t scaled = static_cast<wide_t>(zeta) * static_cast<wide_t>(b);
    const coeff_t twisted = mod_q(scaled);
    a = mod_add(upper, twisted);
    b = mod_sub(upper, twisted);
}
// ---------------------------------------------------------------------
// Forward NTT, in place, on a[0..DILITHIUM_N-1].
//
// Cooley-Tukey style iteration over layers: len = 128, 64, ..., 1. Each
// block of 2*len coefficients consumes one twiddle from dilithium_zetas[]
// (255 twiddles in total for N = 256) and applies butterfly() to the pairs
// (a[j], a[j + len]).
//
// Dilithium works in R_q = Z_q[x]/(x^256 + 1); this layer/block structure
// is the one used for that ring, so dilithium_zetas[] and its traversal
// order must match the Dilithium spec.
//
// NOTE(review): the table is indexed from 0 (k++), whereas the reference
// implementation pre-increments and leaves entry 0 unused; also, mod_q()
// is a plain % reduction, so the table is assumed to hold plain
// (non-Montgomery) residues. Confirm both against the table in ntt.h.
// ---------------------------------------------------------------------
void ntt(coeff_t a[DILITHIUM_N])
{
#pragma HLS inline off
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
    unsigned k = 0;
    // len goes 128,64,32,...,1. The final "len >>= 1" takes 1 to 0, which
    // fails "len >= 1", so the loop ends on its own: a right shift cannot
    // wrap and no explicit break is required.
    for (unsigned len = DILITHIUM_N >> 1; len >= 1; len >>= 1) {
#pragma HLS loop_tripcount min=8 max=8
        // Number of blocks per layer: 1 (len = 128) up to 128 (len = 1).
        for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) {
#pragma HLS loop_tripcount min=1 max=128
            const coeff_t zeta = dilithium_zetas[k++];
            // Butterflies per block: len, i.e. 1 up to 128.
            for (unsigned j = start; j < start + len; j++) {
#pragma HLS PIPELINE II=1
#pragma HLS loop_tripcount min=1 max=128
                butterfly(a[j], a[j + len], zeta);
            }
        }
    }
}
// ---------------------------------------------------------------------
// Inverse NTT, in place, on a[0..DILITHIUM_N-1], followed by scaling every
// coefficient by n^{-1} mod q (DILITHIUM_INV_N).
//
// Gentleman-Sande style iteration over layers: len = 1, 2, 4, ..., 128.
// Each block of 2*len coefficients consumes one twiddle from
// dilithium_zetas_inv[] (255 in total for N = 256). For a pair
// (u, v) = (a[j], a[j + len]) the update is
//   a[j]       <- u + v            (mod q)
//   a[j + len] <- zeta * (u - v)   (mod q)
//
// NOTE(review): the table is indexed from 0 (k++) and, because mod_q() is a
// plain % reduction, is assumed to hold plain (non-Montgomery) inverse
// twiddles in the order this traversal expects. Confirm against ntt.h.
// ---------------------------------------------------------------------
void inv_ntt(coeff_t a[DILITHIUM_N])
{
#pragma HLS inline off
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
    unsigned k = 0;
    // len goes 1,2,4,...,128
    for (unsigned len = 1; len < DILITHIUM_N; len <<= 1) {
#pragma HLS loop_tripcount min=8 max=8
        // Number of blocks per layer: 128 (len = 1) down to 1 (len = 128).
        for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) {
#pragma HLS loop_tripcount min=1 max=128
            const coeff_t zeta = dilithium_zetas_inv[k++];
            // Butterflies per block: len, i.e. 1 up to 128.
            for (unsigned j = start; j < start + len; j++) {
#pragma HLS PIPELINE II=1
#pragma HLS loop_tripcount min=1 max=128
                // Read both inputs before either location is overwritten.
                const coeff_t u = a[j];
                const coeff_t v = a[j + len];
                const coeff_t a_sum = mod_add(u, v);
                const coeff_t a_diff = mod_sub(u, v);
                // Only the "difference" branch is multiplied by zeta.
                const wide_t prod = (wide_t)zeta * (wide_t)a_diff;
                a[j] = a_sum;
                a[j + len] = mod_q(prod);
            }
        }
    }
    // Final scaling by n^{-1} mod q
    for (unsigned i = 0; i < DILITHIUM_N; i++) {
#pragma HLS PIPELINE II=1
        const wide_t prod = (wide_t)a[i] * (wide_t)DILITHIUM_INV_N;
        a[i] = mod_q(prod);
    }
}
// ---------------------------------------------------------------------
// Polynomial multiplication via the NTT:
//   c = INTT( NTT(a) .* NTT(b) )
// Dilithium's ring is R_q = Z_q[x]/(x^256 + 1); which ring this product
// actually lives in is determined by the twiddle tables used by ntt() and
// inv_ntt(), which are declared outside this file.
//
// Side effect: a and b are transformed in place, so on return they hold
// their NTT images rather than the caller's original coefficients.
//
// NOTE(review): this DATAFLOW region contains a bare loop and arrays that
// are updated in place (a and b by ntt(), c by inv_ntt()). HLS dataflow
// expects function-call-only regions with a single producer and a single
// consumer per array - confirm that the tool accepts this region without
// dataflow-check warnings.
// ---------------------------------------------------------------------
static void poly_mult_internal(coeff_t a[DILITHIUM_N],
                               coeff_t b[DILITHIUM_N],
                               coeff_t c[DILITHIUM_N])
{
#pragma HLS inline off
#pragma HLS DATAFLOW
    // Move both operands into the NTT domain (in place).
    ntt(a);
    ntt(b);
    // Coefficient-wise product in the NTT domain, reduced modulo q.
    for (unsigned idx = 0; idx < DILITHIUM_N; idx++) {
#pragma HLS PIPELINE II=1
        const wide_t lhs = static_cast<wide_t>(a[idx]);
        const wide_t rhs = static_cast<wide_t>(b[idx]);
        const wide_t product = lhs * rhs;
        c[idx] = mod_q(product);
    }
    // Return to the coefficient domain (in place on c, includes 1/n scaling).
    inv_ntt(c);
}
// ---------------------------------------------------------------------
// AXI4-Stream <-> internal coefficient buffer
// Protocol:
//
// Input stream (axis_in):
// - First 256 beats : a[0..255], last = 0
// - Next 256 beats : b[0..255], last = 1 on the final beat
//
// Output stream (axis_out):
// - 256 beats with c[0..255], last = 1 on the final beat
// ---------------------------------------------------------------------
// Fills a[] and then b[] from the input stream: the first DILITHIUM_N beats
// go to a, the next DILITHIUM_N beats go to b. Only the data field of each
// beat is used; the last flag (and any other sideband field) is not
// inspected, so framing errors on the stream are not detected here.
static void axis_read_polys(hls::stream<coeff_axis_t> &axis_in,
                            coeff_t a[DILITHIUM_N],
                            coeff_t b[DILITHIUM_N])
{
#pragma HLS inline off
    coeff_axis_t beat;
    // First polynomial.
    for (unsigned idx = 0; idx < DILITHIUM_N; idx++) {
#pragma HLS PIPELINE II=1
        beat = axis_in.read();
        a[idx] = static_cast<coeff_t>(beat.data);
    }
    // Second polynomial.
    for (unsigned idx = 0; idx < DILITHIUM_N; idx++) {
#pragma HLS PIPELINE II=1
        beat = axis_in.read();
        b[idx] = static_cast<coeff_t>(beat.data);
    }
}
// Emits c[0..DILITHIUM_N-1] on the output stream, one coefficient per beat,
// raising last only on the final beat.
// keep and strb are assigned -1 once before the loop (presumably all-ones,
// i.e. every byte valid - confirm against the coeff_axis_t definition); no
// other sideband field is assigned here.
static void axis_write_poly(hls::stream<coeff_axis_t> &axis_out,
                            coeff_t c[DILITHIUM_N])
{
#pragma HLS inline off
    coeff_axis_t beat;
    beat.keep = -1;
    beat.strb = -1;
    for (unsigned idx = 0; idx < DILITHIUM_N; idx++) {
#pragma HLS PIPELINE II=1
        beat.data = (ap_int<32>)c[idx];
        // Mark end-of-packet on the last coefficient only.
        if (idx == DILITHIUM_N - 1) {
            beat.last = 1;
        } else {
            beat.last = 0;
        }
        axis_out.write(beat);
    }
}
// ---------------------------------------------------------------------
// Top-level HLS function: poly_mult
//
// Reads two polynomials from axis_in (axis_read_polys), multiplies them
// (poly_mult_internal) and writes the product to axis_out
// (axis_write_poly). Always returns 0.
//
// This is your second accelerator, to be instantiated alongside the
// Kyber NTT IP in the same overlay. It has its own AXI4-Stream ports,
// but you can reuse the DMA driver pattern from Kyber by adjusting
// the packing (no 2x16-bit packing here, just one 32-bit coeff per beat).
//
// NOTE(review): a and b are written by axis_read_polys and then updated in
// place inside poly_mult_internal (via ntt). HLS dataflow expects every
// array to have a single producer and a single consumer - confirm that the
// tool accepts this region without dataflow-check warnings.
// ---------------------------------------------------------------------
int poly_mult(hls::stream<coeff_axis_t> &axis_in,
hls::stream<coeff_axis_t> &axis_out)
{
// Both streams are registered AXI4-Stream ports; the return value is
// exposed through the AXI4-Lite bundle CTRL_BUS.
#pragma HLS INTERFACE axis register port=axis_in
#pragma HLS INTERFACE axis register port=axis_out
#pragma HLS INTERFACE s_axilite port=return bundle=CTRL_BUS
#pragma HLS DATAFLOW
// Local coefficient buffers: the two operands and the product.
coeff_t a[DILITHIUM_N];
coeff_t b[DILITHIUM_N];
coeff_t c[DILITHIUM_N];
// Same cyclic factor-4 partitioning that ntt() and inv_ntt() request on
// their array argument.
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
#pragma HLS ARRAY_PARTITION variable=b cyclic factor=4 dim=1
#pragma HLS ARRAY_PARTITION variable=c cyclic factor=4 dim=1
// Stage 1: stream in -> a, b.
axis_read_polys(axis_in, a, b);
// Stage 2: c = a * b (a and b are overwritten with their NTT images).
poly_mult_internal(a, b, c);
// Stage 3: c -> stream out.
axis_write_poly(axis_out, c);
return 0;
}