mirror of
https://github.com/saymrwulf/pqc-accelerate.git
synced 2026-05-14 20:48:07 +00:00
kyber running, starting dilithium development
This commit is contained in:
commit
904df42cd0
10 changed files with 3750 additions and 0 deletions
27
.gitignore
vendored
Normal file
27
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# Ignore build output directory
|
||||
/build
|
||||
/export
|
||||
|
||||
# Ignore object files and dependent files
|
||||
.o
|
||||
.d
|
||||
|
||||
#Ignore logs folder and log files
|
||||
/logs
|
||||
.log
|
||||
|
||||
#ignore zip files
|
||||
.zip
|
||||
|
||||
#Ignore ide folder
|
||||
/_ide
|
||||
|
||||
#Ignore lock files
|
||||
.lock
|
||||
|
||||
.bin
|
||||
.pdi
|
||||
.vitisWorkspace.json
|
||||
|
||||
_ide/logs
|
||||
_ide/.wsdata
|
||||
42
HLS_Codes/ntt.h
Normal file
42
HLS_Codes/ntt.h
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#ifndef NTT_H
|
||||
#define NTT_H
|
||||
|
||||
#include <ap_int.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "hls_stream.h"
|
||||
#include "ap_axi_sdata.h"
|
||||
|
||||
typedef ap_uint<1> bit;
|
||||
typedef ap_uint<8> ap_logn_t;
|
||||
typedef ap_int<16> coeff_t;
|
||||
typedef ap_int<32> double_coeff_t;
|
||||
|
||||
// Internal streaming types (original design)
|
||||
struct coeff_t_stream
|
||||
{
|
||||
coeff_t value;
|
||||
bit last;
|
||||
};
|
||||
|
||||
struct coeff_t_stream_big
|
||||
{
|
||||
double_coeff_t value;
|
||||
bit last;
|
||||
};
|
||||
|
||||
// External AXI4-Stream element types (only used on top-level ports)
|
||||
typedef ap_axiu<16,0,0,0> coeff_axis_t;
|
||||
typedef ap_axiu<32,0,0,0> coeff_axis_big_t;
|
||||
|
||||
#define N 128
|
||||
#define Nt 256
|
||||
#define logN 7
|
||||
|
||||
extern coeff_t q, w_n;
|
||||
|
||||
// Top-level function now uses AXI4-Stream types for DMA compatibility
|
||||
int poly_mult (hls::stream<coeff_axis_big_t> &input,
|
||||
hls::stream<coeff_axis_t> &output);
|
||||
|
||||
#endif
|
||||
69
HLS_Codes/pm_test.cpp
Normal file
69
HLS_Codes/pm_test.cpp
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
// pm_test.cpp (works with Set A and Set B)
|
||||
|
||||
#include "test_case.h"
|
||||
|
||||
int main()
|
||||
{
|
||||
// Top-level AXI4-Stream ports for the DUT
|
||||
hls::stream<coeff_axis_big_t> in_data;
|
||||
hls::stream<coeff_axis_t> out_data;
|
||||
|
||||
coeff_axis_big_t local_stream1;
|
||||
coeff_axis_t local_stream2;
|
||||
|
||||
int i;
|
||||
|
||||
coeff_t actual_outputs[Nt];
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Write stimulus into input AXI4-Stream
|
||||
// -------------------------------------------------------------------------
|
||||
for (i = 0; i < Nt; i++)
|
||||
{
|
||||
coeff_t val1 = input1_vals[i];
|
||||
double_coeff_t val2 = (double_coeff_t)(input2_vals[i] * 65536);
|
||||
|
||||
// Pack into 32-bit AXI data word
|
||||
local_stream1.data = (ap_uint<32>)(val1 + val2);
|
||||
|
||||
// Mark all bytes valid; side channels are disabled here
|
||||
local_stream1.keep = -1;
|
||||
local_stream1.strb = -1;
|
||||
|
||||
// TLAST on final sample
|
||||
local_stream1.last = (i == Nt - 1) ? 1 : 0;
|
||||
|
||||
in_data.write(local_stream1);
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Call DUT
|
||||
// -------------------------------------------------------------------------
|
||||
poly_mult(in_data, out_data);
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Read result from output AXI4-Stream
|
||||
// -------------------------------------------------------------------------
|
||||
for (i = 0; i < Nt; i++)
|
||||
{
|
||||
local_stream2 = out_data.read();
|
||||
actual_outputs[i] = (coeff_t)local_stream2.data;
|
||||
// Optionally check local_stream2.last here
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Compare against golden output
|
||||
// -------------------------------------------------------------------------
|
||||
int ret_val = 0;
|
||||
for (i = 0; i < Nt; i++)
|
||||
{
|
||||
if (output_vals[i] != actual_outputs[i])
|
||||
{
|
||||
ret_val++;
|
||||
std::cout << actual_outputs[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return ret_val;
|
||||
}
|
||||
1091
HLS_Codes/polymult.cpp
Normal file
1091
HLS_Codes/polymult.cpp
Normal file
File diff suppressed because it is too large
Load diff
6
HLS_Codes/test_case.h
Normal file
6
HLS_Codes/test_case.h
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#include "ntt.h"
|
||||
|
||||
coeff_t input1_vals[] = {1477, 218, 784, 251, 747, 1051, 1924, 133, 2953, 1295, 2989, 1519, 1701, 1874, 2806, 423, 2883, 327, 47, 2525, 1508, 214, 2998, 217, 1852, 2624, 2286, 3039, 3076, 1213, 1808, 2554, 1129, 1353, 2690, 2839, 1778, 2752, 1378, 601, 914, 2335, 2497, 1139, 2611, 129, 1318, 1570, 3190, 1868, 940, 2901, 2626, 2473, 3195, 2621, 2436, 3046, 1018, 1139, 1729, 3021, 2064, 945, 690, 1700, 1836, 1943, 2333, 2131, 1618, 1741, 2639, 2653, 301, 2013, 2744, 2406, 2995, 2463, 2366, 1495, 442, 224, 1349, 11, 2342, 1712, 2847, 1578, 2654, 2734, 3131, 1245, 1862, 527, 2400, 2043, 1360, 451, 573, 898, 2018, 3100, 161, 284, 1949, 362, 755, 2916, 1288, 1616, 876, 1682, 853, 2772, 2956, 1101, 2, 214, 2589, 211, 1025, 610, 1225, 2118, 224, 1296, 2612, 2634, 2056, 3227, 1712, 1258, 552, 1345, 786, 2124, 2915, 1226, 1233, 2654, 2786, 2636, 2234, 727, 2444, 199, 600, 2262, 3221, 915, 63, 318, 74, 2396, 1690, 2390, 1711, 414, 10, 2298, 1082, 1419, 3151, 1723, 2744, 3274, 2518, 2954, 1208, 2941, 2089, 3288, 1370, 783, 2517, 3190, 3069, 2505, 2840, 1427, 1670, 3091, 655, 96, 1935, 880, 2511, 876, 2371, 341, 196, 2849, 919, 161, 603, 2993, 2903, 1721, 139, 3326, 1876, 379, 2508, 2094, 1929, 430, 1033, 2604, 1955, 1333, 2274, 3312, 2604, 1585, 2317, 3230, 3068, 2905, 3268, 2844, 1023, 2824, 1731, 643, 820, 462, 2975, 314, 2218, 2011, 649, 383, 874, 2181, 866, 1192, 2914, 2290, 1820, 1572, 1030, 3076, 1526, 2760, 12, 529, 1242, 560, 2723, 2894, 1097, 778, 1495, 371};
|
||||
coeff_t input2_vals[] = {2960, 3124, 509, 485, 2525, 385, 608, 2893, 2423, 1802, 2556, 1090, 775, 2059, 898, 864, 2459, 1116, 551, 188, 3262, 2728, 3134, 2451, 427, 858, 1927, 830, 2688, 2388, 2818, 1418, 3298, 24, 2491, 1448, 1153, 178, 2489, 2126, 1772, 669, 1238, 633, 1919, 2222, 2673, 1918, 2202, 3312, 208, 976, 2267, 107, 2905, 1137, 2921, 2471, 2796, 1313, 485, 1982, 1557, 1203, 2930, 241, 3089, 890, 2193, 179, 952, 2057, 2444, 1378, 1466, 1362, 1808, 2343, 1532, 2651, 727, 3254, 1328, 1604, 967, 2418, 1266, 1826, 684, 2869, 3149, 1874, 1691, 1507, 339, 2473, 102, 3153, 969, 1551, 548, 3059, 2841, 1369, 148, 2510, 2025, 1369, 1579, 2474, 1093, 527, 1416, 981, 2320, 2305, 227, 2173, 812, 1703, 2952, 17, 1129, 2223, 1894, 959, 73, 339, 553, 1466, 1065, 617, 1749, 1896, 1838, 1771, 3092, 297, 996, 198, 521, 567, 3256, 2783, 1044, 2644, 744, 2986, 3178, 1522, 942, 2045, 236, 1866, 853, 2303, 2383, 3095, 418, 2752, 2105, 2896, 3081, 3067, 1696, 978, 102, 1961, 3120, 2741, 1029, 885, 2852, 2659, 2815, 3032, 2358, 3252, 1195, 3304, 878, 70, 3069, 2726, 2455, 182, 108, 2868, 1744, 1697, 1060, 1803, 1752, 829, 2434, 862, 2287, 2860, 352, 634, 2626, 1920, 2425, 239, 831, 2527, 1190, 1469, 2602, 1711, 2185, 1403, 3189, 1188, 2649, 2079, 2215, 790, 409, 2413, 627, 2268, 2507, 2102, 1727, 1146, 2711, 355, 1143, 1225, 430, 82, 3015, 2699, 642, 863, 241, 450, 440, 338, 365, 2621, 3022, 204, 149, 2986, 2191, 1793, 3085, 2128, 373, 290, 835, 580, 2530, 1948};
|
||||
|
||||
coeff_t output_vals[] = {2762, 3061, 1101, 3267, 2744, 1349, 182, 1761, 3089, 751, 137, 368, 1461, 2956, 493, 1653, 2617, 721, 356, 3034, 2234, 1556, 809, 2290, 1597, 457, 811, 259, 685, 2478, 319, 2519, 1049, 837, 644, 2571, 1029, 2997, 762, 1710, 2110, 1099, 2513, 1038, 2176, 1938, 3214, 261, 1604, 2474, 5, 1211, 2816, 2848, 2286, 3146, 1777, 1630, 2412, 1457, 889, 671, 822, 2369, 1409, 2059, 1121, 1871, 303, 1178, 2241, 1827, 2046, 628, 2869, 749, 1666, 895, 580, 1770, 2082, 3123, 1192, 520, 168, 2461, 1032, 163, 1421, 2792, 2148, 1735, 220, 1896, 2887, 2163, 357, 2301, 1830, 163, 1812, 805, 1850, 2017, 2313, 1205, 2226, 703, 866, 1708, 1426, 1920, 2911, 267, 3134, 629, 2120, 2022, 2847, 2945, 2967, 1977, 1449, 2028, 1381, 2738, 1098, 2977, 2217, 2060, 710, 845, 2807, 509, 2512, 2444, 2355, 550, 2965, 2517, 1802, 1755, 1065, 1938, 388, 2365, 776, 2453, 1799, 1532, 384, 2266, 1071, 2063, 2858, 1414, 663, 2886, 2734, 209, 1061, 2142, 841, 1081, 977, 799, 2661, 588, 3222, 2140, 2383, 3044, 394, 231, 1090, 917, 1840, 3002, 2315, 1182, 2744, 2815, 2612, 2586, 970, 3301, 3028, 2890, 1849, 269, 2936, 1525, 3102, 3144, 1605, 2746, 1556, 537, 2918, 2549, 976, 250, 2137, 492, 729, 392, 1115, 2422, 2100, 2317, 1636, 1743, 1279, 1393, 2079, 2874, 2148, 233, 1469, 3143, 2109, 1211, 2318, 1138, 2979, 1383, 125, 1995, 1614, 1435, 2216, 782, 671, 662, 988, 2826, 2162, 605, 2955, 2478, 2375, 1449, 2307, 1921, 1285, 2208, 2422, 1035, 2765, 923, 2138, 3053, 812, 146, 1175, 61};
|
||||
38
HLS_Codes_Dilithium/ntt.h
Normal file
38
HLS_Codes_Dilithium/ntt.h
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
// ntt.h
|
||||
#ifndef NTT_DILITHIUM_H
|
||||
#define NTT_DILITHIUM_H
|
||||
|
||||
#include <ap_int.h>
|
||||
#include <hls_stream.h>
|
||||
#include <ap_axi_sdata.h>
|
||||
|
||||
// Dilithium / ML-DSA-style ring
|
||||
// R_q = Z_q[x]/(x^256 - 1), q = 8380417
|
||||
// Only the *ring* and NTT structure are reflected here; we’re not
|
||||
// implementing full ML-DSA, just the polynomial multiplier.
|
||||
|
||||
#define DILITHIUM_N 256
|
||||
#define DILITHIUM_Q 8380417
|
||||
|
||||
// Coefficients: we use 32-bit signed ints. q < 2^23, so this is safe.
|
||||
typedef ap_int<32> coeff_t;
|
||||
typedef ap_int<64> wide_t;
|
||||
|
||||
// Simple 32-bit AXI4-Stream word: one coeff per beat.
|
||||
// You’ll send 256 words for a, then 256 words for b.
|
||||
typedef ap_axiu<32, 0, 0, 0> coeff_axis_t;
|
||||
|
||||
// Modular arithmetic helpers (declared here, defined in the .cpp)
|
||||
coeff_t mod_q(wide_t x);
|
||||
coeff_t mod_add(coeff_t a, coeff_t b);
|
||||
coeff_t mod_sub(coeff_t a, coeff_t b);
|
||||
|
||||
// NTT / INTT prototypes
|
||||
void ntt(coeff_t a[DILITHIUM_N]);
|
||||
void inv_ntt(coeff_t a[DILITHIUM_N]);
|
||||
|
||||
// Top-level polymult for Dilithium ring
|
||||
int poly_mult_dilithium(hls::stream<coeff_axis_t> &axis_in,
|
||||
hls::stream<coeff_axis_t> &axis_out);
|
||||
|
||||
#endif // NTT_DILITHIUM_H
|
||||
93
HLS_Codes_Dilithium/pm_test.cpp
Normal file
93
HLS_Codes_Dilithium/pm_test.cpp
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
// pm_test_dilithium.cpp – C-sim / C-synth testbench for Dilithium polymult IP
|
||||
|
||||
#include <iostream>
|
||||
#include <hls_stream.h>
|
||||
|
||||
#include "ntt.h"
|
||||
#include "test_case.h"
|
||||
|
||||
int main()
|
||||
{
|
||||
// Top-level AXI4-Stream ports for the DUT
|
||||
hls::stream<coeff_axis_t> in_data;
|
||||
hls::stream<coeff_axis_t> out_data;
|
||||
|
||||
coeff_axis_t local_stream;
|
||||
coeff_t actual_outputs[DILITHIUM_N];
|
||||
|
||||
int i;
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Write stimulus into input AXI4-Stream
|
||||
//
|
||||
// Protocol for poly_mult_dilithium():
|
||||
// - first DILITHIUM_N words: a[0..N-1]
|
||||
// - next DILITHIUM_N words: b[0..N-1]
|
||||
// TLAST = 1 only on the very last word (b[N-1])
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// Send polynomial a
|
||||
for (i = 0; i < DILITHIUM_N; i++) {
|
||||
coeff_t val = input1_vals[i];
|
||||
|
||||
local_stream.data = (ap_int<32>)val;
|
||||
local_stream.keep = -1;
|
||||
local_stream.strb = -1;
|
||||
local_stream.last = 0; // not last yet
|
||||
|
||||
in_data.write(local_stream);
|
||||
}
|
||||
|
||||
// Send polynomial b
|
||||
for (i = 0; i < DILITHIUM_N; i++) {
|
||||
coeff_t val = input2_vals[i];
|
||||
|
||||
local_stream.data = (ap_int<32>)val;
|
||||
local_stream.keep = -1;
|
||||
local_stream.strb = -1;
|
||||
local_stream.last = (i == DILITHIUM_N - 1) ? 1 : 0;
|
||||
|
||||
in_data.write(local_stream);
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Call DUT
|
||||
// -------------------------------------------------------------------------
|
||||
poly_mult_dilithium(in_data, out_data);
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Read result from output AXI4-Stream
|
||||
//
|
||||
// The core outputs exactly DILITHIUM_N coefficients; TLAST is 1 on the
|
||||
// final beat. We read all N coefficients into actual_outputs[].
|
||||
// -------------------------------------------------------------------------
|
||||
for (i = 0; i < DILITHIUM_N; i++) {
|
||||
coeff_axis_t out_word = out_data.read();
|
||||
actual_outputs[i] = (coeff_t)out_word.data;
|
||||
|
||||
// Optional: you can sanity-check TLAST here:
|
||||
// if (i == DILITHIUM_N - 1 && out_word.last != 1) ...
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Compare against golden output
|
||||
// -------------------------------------------------------------------------
|
||||
int ret_val = 0;
|
||||
for (i = 0; i < DILITHIUM_N; i++) {
|
||||
if (output_vals[i] != actual_outputs[i]) {
|
||||
ret_val++;
|
||||
std::cout << "Mismatch at index " << i
|
||||
<< ": got " << (long long)actual_outputs[i]
|
||||
<< ", expected " << (long long)output_vals[i]
|
||||
<< std::endl;
|
||||
break; // stop at first mismatch (like your Kyber testbench)
|
||||
}
|
||||
}
|
||||
|
||||
if (ret_val == 0) {
|
||||
std::cout << "All " << DILITHIUM_N
|
||||
<< " coefficients match golden output." << std::endl;
|
||||
}
|
||||
|
||||
return ret_val;
|
||||
}
|
||||
228
HLS_Codes_Dilithium/polymult.cpp
Normal file
228
HLS_Codes_Dilithium/polymult.cpp
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
// polymult.cpp
|
||||
#include "ntt.h"
|
||||
|
||||
|
||||
// n^{-1} mod q, for n = 256, q = 8380417
|
||||
// Computed so that (256 * DILITHIUM_INV_N) % DILITHIUM_Q == 1.
|
||||
static const coeff_t DILITHIUM_INV_N = (coeff_t)8347681;
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// Modular arithmetic
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
coeff_t mod_q(wide_t x)
|
||||
{
|
||||
#pragma HLS inline
|
||||
// Reduce into [0, q-1] via 64-bit modulo.
|
||||
// This is not the most hardware-optimal, but it is simple and correct.
|
||||
wide_t r = x % (wide_t)DILITHIUM_Q;
|
||||
if (r < 0)
|
||||
r += DILITHIUM_Q;
|
||||
return (coeff_t)r;
|
||||
}
|
||||
|
||||
coeff_t mod_add(coeff_t a, coeff_t b)
|
||||
{
|
||||
#pragma HLS inline
|
||||
wide_t s = (wide_t)a + (wide_t)b;
|
||||
if (s >= DILITHIUM_Q)
|
||||
s -= DILITHIUM_Q;
|
||||
return (coeff_t)s;
|
||||
}
|
||||
|
||||
coeff_t mod_sub(coeff_t a, coeff_t b)
|
||||
{
|
||||
#pragma HLS inline
|
||||
wide_t d = (wide_t)a - (wide_t)b;
|
||||
if (d < 0)
|
||||
d += DILITHIUM_Q;
|
||||
return (coeff_t)d;
|
||||
}
|
||||
|
||||
// One step of the butterfly: t = zeta * b, then (a+t, a-t)
|
||||
static void butterfly(coeff_t &a, coeff_t &b, coeff_t zeta)
|
||||
{
|
||||
#pragma HLS inline
|
||||
wide_t prod = (wide_t)zeta * (wide_t)b;
|
||||
coeff_t t = mod_q(prod);
|
||||
coeff_t a0 = a;
|
||||
a = mod_add(a0, t);
|
||||
b = mod_sub(a0, t);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// Forward NTT in-place on a[0..255] for R_q = Z_q[x]/(x^256 - 1)
|
||||
// Standard Cooley–Tukey style iteration over layers.
|
||||
// The zetas[] table and its traversal pattern must match the Dilithium spec.
|
||||
// ---------------------------------------------------------------------
|
||||
void ntt(coeff_t a[DILITHIUM_N])
|
||||
{
|
||||
#pragma HLS inline off
|
||||
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
|
||||
|
||||
unsigned k = 0;
|
||||
// len goes 128,64,32,...,1
|
||||
for (unsigned len = DILITHIUM_N >> 1; len >= 1; len >>= 1) {
|
||||
#pragma HLS loop_tripcount min=1 max=8
|
||||
for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) {
|
||||
#pragma HLS loop_tripcount min=1 max=256
|
||||
coeff_t zeta = dilithium_zetas[k++];
|
||||
for (unsigned j = start; j < start + len; j++) {
|
||||
#pragma HLS PIPELINE II=1
|
||||
coeff_t &aj = a[j];
|
||||
coeff_t &ajl = a[j + len];
|
||||
butterfly(aj, ajl, zeta);
|
||||
}
|
||||
}
|
||||
if (len == 1) break; // avoid unsigned wrap
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// Inverse NTT (Gentleman–Sande style) on a[0..255], followed by scaling
|
||||
// by n^{-1}. Pattern mirrors standard Dilithium invNTT.
|
||||
// ---------------------------------------------------------------------
|
||||
void inv_ntt(coeff_t a[DILITHIUM_N])
|
||||
{
|
||||
#pragma HLS inline off
|
||||
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
|
||||
|
||||
unsigned k = 0;
|
||||
// len goes 1,2,4,...,128
|
||||
for (unsigned len = 1; len < DILITHIUM_N; len <<= 1) {
|
||||
#pragma HLS loop_tripcount min=1 max=8
|
||||
for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) {
|
||||
#pragma HLS loop_tripcount min=1 max=256
|
||||
coeff_t zeta = dilithium_zetas_inv[k++];
|
||||
for (unsigned j = start; j < start + len; j++) {
|
||||
#pragma HLS PIPELINE II=1
|
||||
coeff_t u = a[j];
|
||||
coeff_t v = a[j + len];
|
||||
|
||||
coeff_t a_sum = mod_add(u, v);
|
||||
coeff_t a_diff = mod_sub(u, v);
|
||||
|
||||
// Multiply the "difference" branch by zeta
|
||||
wide_t prod = (wide_t)zeta * (wide_t)a_diff;
|
||||
coeff_t t = mod_q(prod);
|
||||
|
||||
a[j] = a_sum;
|
||||
a[j+len] = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final scaling by n^{-1} mod q
|
||||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||||
#pragma HLS PIPELINE II=1
|
||||
wide_t prod = (wide_t)a[i] * (wide_t)DILITHIUM_INV_N;
|
||||
a[i] = mod_q(prod);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// Polynomial multiplication in R_q = Z_q[x]/(x^256 - 1):
|
||||
// c = a * b (using NTT, pointwise multiply, INTT)
|
||||
// ---------------------------------------------------------------------
|
||||
static void poly_mult_internal(coeff_t a[DILITHIUM_N],
|
||||
coeff_t b[DILITHIUM_N],
|
||||
coeff_t c[DILITHIUM_N])
|
||||
{
|
||||
#pragma HLS inline off
|
||||
#pragma HLS DATAFLOW
|
||||
|
||||
// In-place NTT
|
||||
ntt(a);
|
||||
ntt(b);
|
||||
|
||||
// Pointwise multiply
|
||||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||||
#pragma HLS PIPELINE II=1
|
||||
wide_t prod = (wide_t)a[i] * (wide_t)b[i];
|
||||
c[i] = mod_q(prod);
|
||||
}
|
||||
|
||||
// In-place inverse NTT on c
|
||||
inv_ntt(c);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// AXI4-Stream <-> internal coefficient buffer
|
||||
// Protocol:
|
||||
//
|
||||
// Input stream (axis_in):
|
||||
// - First 256 beats : a[0..255], last = 0
|
||||
// - Next 256 beats : b[0..255], last = 1 on the final beat
|
||||
//
|
||||
// Output stream (axis_out):
|
||||
// - 256 beats with c[0..255], last = 1 on the final beat
|
||||
// ---------------------------------------------------------------------
|
||||
static void axis_read_polys(hls::stream<coeff_axis_t> &axis_in,
|
||||
coeff_t a[DILITHIUM_N],
|
||||
coeff_t b[DILITHIUM_N])
|
||||
{
|
||||
#pragma HLS inline off
|
||||
coeff_axis_t w;
|
||||
|
||||
// Read a
|
||||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||||
#pragma HLS PIPELINE II=1
|
||||
w = axis_in.read();
|
||||
a[i] = (coeff_t)w.data;
|
||||
}
|
||||
|
||||
// Read b
|
||||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||||
#pragma HLS PIPELINE II=1
|
||||
w = axis_in.read();
|
||||
b[i] = (coeff_t)w.data;
|
||||
}
|
||||
}
|
||||
|
||||
static void axis_write_poly(hls::stream<coeff_axis_t> &axis_out,
|
||||
coeff_t c[DILITHIUM_N])
|
||||
{
|
||||
#pragma HLS inline off
|
||||
|
||||
coeff_axis_t w;
|
||||
w.keep = -1;
|
||||
w.strb = -1;
|
||||
|
||||
for (unsigned i = 0; i < DILITHIUM_N; i++) {
|
||||
#pragma HLS PIPELINE II=1
|
||||
w.data = (ap_int<32>)c[i];
|
||||
w.last = (i == DILITHIUM_N - 1) ? 1 : 0;
|
||||
axis_out.write(w);
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// Top-level HLS function: poly_mult_dilithium
|
||||
//
|
||||
// This is your second accelerator, to be instantiated alongside the
|
||||
// Kyber NTT IP in the same overlay. It has its own AXI4-Stream ports,
|
||||
// but you can reuse the DMA driver pattern from Kyber by adjusting
|
||||
// the packing (no 2x16-bit packing here, just one 32-bit coeff per beat).
|
||||
// ---------------------------------------------------------------------
|
||||
int poly_mult(hls::stream<coeff_axis_t> &axis_in,
|
||||
hls::stream<coeff_axis_t> &axis_out)
|
||||
{
|
||||
#pragma HLS INTERFACE axis register port=axis_in
|
||||
#pragma HLS INTERFACE axis register port=axis_out
|
||||
#pragma HLS INTERFACE s_axilite port=return bundle=CTRL_BUS
|
||||
#pragma HLS DATAFLOW
|
||||
|
||||
coeff_t a[DILITHIUM_N];
|
||||
coeff_t b[DILITHIUM_N];
|
||||
coeff_t c[DILITHIUM_N];
|
||||
|
||||
#pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1
|
||||
#pragma HLS ARRAY_PARTITION variable=b cyclic factor=4 dim=1
|
||||
#pragma HLS ARRAY_PARTITION variable=c cyclic factor=4 dim=1
|
||||
|
||||
axis_read_polys(axis_in, a, b);
|
||||
poly_mult_internal(a, b, c);
|
||||
axis_write_poly(axis_out, c);
|
||||
|
||||
return 0;
|
||||
}
|
||||
21
HLS_Codes_Dilithium/test_case.h
Normal file
21
HLS_Codes_Dilithium/test_case.h
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
#ifndef TEST_CASE_DILITHIUM_H
|
||||
#define TEST_CASE_DILITHIUM_H
|
||||
|
||||
#include "ntt.h"
|
||||
|
||||
// Length must be DILITHIUM_N (256)
|
||||
static const coeff_t input1_vals[DILITHIUM_N] = {
|
||||
// TODO: fill with 256 coefficients of polynomial a (mod 8380417)
|
||||
// e.g. generated by a software reference polymult
|
||||
};
|
||||
|
||||
static const coeff_t input2_vals[DILITHIUM_N] = {
|
||||
// TODO: 256 coefficients of polynomial b
|
||||
};
|
||||
|
||||
static const coeff_t output_vals[DILITHIUM_N] = {
|
||||
// TODO: golden result c = a * b in R_q[x]/(x^256 - 1),
|
||||
// reduced mod q and in the same canonical form as the HW (0..q-1)
|
||||
};
|
||||
|
||||
#endif // TEST_CASE_DILITHIUM_H
|
||||
2135
PYNQ-zcu104_Files/Kyber512-LastGood-6dec2025.ipynb
Normal file
2135
PYNQ-zcu104_Files/Kyber512-LastGood-6dec2025.ipynb
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue