dilithium hls try w C-Sim, Co-Sim, Synthesis, Impl ok

This commit is contained in:
oho 2025-12-12 19:37:56 +01:00
parent ab38cd1ada
commit ccbe56a514
4 changed files with 1399 additions and 612 deletions

View file

@ -9,33 +9,34 @@
typedef ap_uint<1> bit; typedef ap_uint<1> bit;
typedef ap_uint<8> ap_logn_t; typedef ap_uint<8> ap_logn_t;
typedef ap_int<24> coeff_t; typedef ap_int<32> coeff_t;
typedef ap_int<48> double_coeff_t; typedef ap_int<64> double_coeff_t;
// Internal streaming types // Internal streaming types (original design)
struct coeff_t_stream { struct coeff_t_stream
{
coeff_t value; coeff_t value;
bit last; bit last;
}; };
struct coeff_t_stream_big {
struct coeff_t_stream_big
{
double_coeff_t value; double_coeff_t value;
bit last; bit last;
}; };
// External AXI4-Stream element types (top-level ports) // External AXI4-Stream element types (only used on top-level ports)
typedef ap_axiu<24,0,0,0> coeff_axis_t; typedef ap_axiu<32,0,0,0> coeff_axis_t;
typedef ap_axiu<48,0,0,0> coeff_axis_big_t; typedef ap_axiu<64,0,0,0> coeff_axis_big_t;
#define N 128 #define N 128
#define Nt 256 #define Nt 256
#define logN 7 #define logN 7
// Modulus and twiddle constants extern coeff_t q, w_n;
extern coeff_t q;
extern coeff_t inv_n;
// Top-level function prototype // Top-level function now uses AXI4-Stream types for DMA compatibility
int poly_mult(hls::stream<coeff_axis_big_t> &input, int poly_mult_dil (hls::stream<coeff_axis_big_t> &input,
hls::stream<coeff_axis_t> &output); hls::stream<coeff_axis_t> &output);
#endif #endif

View file

@ -1,46 +1,78 @@
// pm_test.cpp
#include "test_case.h" #include "test_case.h"
int main() { int main()
{
// Top-level AXI4-Stream ports for the DUT // Top-level AXI4-Stream ports for the DUT
hls::stream<coeff_axis_big_t> in_data; hls::stream<coeff_axis_big_t> in_data;
hls::stream<coeff_axis_t> out_data; hls::stream<coeff_axis_t> out_data;
coeff_axis_big_t local_stream1; coeff_axis_big_t local_stream1;
coeff_axis_t local_stream2; coeff_axis_t local_stream2;
coeff_t actual_outputs[Nt];
int i; int i;
// Write input stimuli into input AXI4-Stream coeff_t actual_outputs[Nt];
for (i = 0; i < Nt; i++) { coeff_t golden_outputs[Nt]; // NEW: buffer for golden result
coeff_t val1 = input1_vals[i];
double_coeff_t val2 = (double_coeff_t) input2_vals[i] << 24; // -------------------------------------------------------------------------
// Pack two 24-bit values into one 48-bit word (stored in ap_axiu<48>) // Write stimulus into input AXI4-Stream
local_stream1.data = (ap_uint<48>) ((ap_uint<48>) val1 | (ap_uint<48>) val2); // -------------------------------------------------------------------------
local_stream1.keep = -1; for (i = 0; i < Nt; i++)
{
coeff_t val1 = input1_vals[i];
coeff_t val2 = input2_vals[i];
// Packing: 2×32-bit coeffs into one 64-bit word
ap_uint<64> word = 0;
word |= (ap_uint<32>)val1; // low 32 bits
word |= (ap_uint<64>)(ap_uint<32>)val2 << 32; // high 32 bits
local_stream1.data = word;
local_stream1.keep = -1; // 0xFF for 64-bit TDATA
local_stream1.strb = -1; local_stream1.strb = -1;
local_stream1.last = (i == Nt - 1) ? 1 : 0; local_stream1.last = (i == Nt - 1) ? 1 : 0;
in_data.write(local_stream1); in_data.write(local_stream1);
} }
// -------------------------------------------------------------------------
// Call DUT // Call DUT
poly_mult(in_data, out_data); // -------------------------------------------------------------------------
poly_mult_dil(in_data, out_data);
// Read results from output AXI4-Stream // -------------------------------------------------------------------------
for (i = 0; i < Nt; i++) { // Read result from output AXI4-Stream
// -------------------------------------------------------------------------
for (i = 0; i < Nt; i++)
{
local_stream2 = out_data.read(); local_stream2 = out_data.read();
actual_outputs[i] = (coeff_t) local_stream2.data; actual_outputs[i] = (coeff_t)local_stream2.data;
// (Optionally check local_stream2.last here) // local_stream2.last could be checked here if you want
} }
// -------------------------------------------------------------------------
// Compute golden result (software negacyclic product)
// -------------------------------------------------------------------------
golden_poly_mult_dil(golden_outputs, input1_vals, input2_vals);
// -------------------------------------------------------------------------
// Compare against golden output // Compare against golden output
// -------------------------------------------------------------------------
int ret_val = 0; int ret_val = 0;
for (i = 0; i < Nt; i++) { for (i = 0; i < Nt; i++)
if (output_vals[i] != actual_outputs[i]) { {
if (golden_outputs[i] != actual_outputs[i])
{
ret_val++; ret_val++;
std::cout << "Mismatch at index " << i std::cout << "Mismatch at i = " << i
<< ": expected " << output_vals[i] << " golden = " << golden_outputs[i]
<< ", got " << actual_outputs[i] << std::endl; << " hw = " << actual_outputs[i]
<< std::endl;
break; break;
} }
} }
return ret_val; return ret_val;
} }

File diff suppressed because it is too large Load diff

View file

@ -1,64 +1,60 @@
#include "ntt.h" #include "ntt.h"
// Test input and expected output arrays for Dilithium polynomial multiplication coeff_t input1_vals[] = {1477, 218, 784, 251, 747, 1051, 1924, 133, 2953, 1295, 2989, 1519, 1701, 1874, 2806, 423, 2883, 327, 47, 2525, 1508, 214, 2998, 217, 1852, 2624, 2286, 3039, 3076, 1213, 1808, 2554, 1129, 1353, 2690, 2839, 1778, 2752, 1378, 601, 914, 2335, 2497, 1139, 2611, 129, 1318, 1570, 3190, 1868, 940, 2901, 2626, 2473, 3195, 2621, 2436, 3046, 1018, 1139, 1729, 3021, 2064, 945, 690, 1700, 1836, 1943, 2333, 2131, 1618, 1741, 2639, 2653, 301, 2013, 2744, 2406, 2995, 2463, 2366, 1495, 442, 224, 1349, 11, 2342, 1712, 2847, 1578, 2654, 2734, 3131, 1245, 1862, 527, 2400, 2043, 1360, 451, 573, 898, 2018, 3100, 161, 284, 1949, 362, 755, 2916, 1288, 1616, 876, 1682, 853, 2772, 2956, 1101, 2, 214, 2589, 211, 1025, 610, 1225, 2118, 224, 1296, 2612, 2634, 2056, 3227, 1712, 1258, 552, 1345, 786, 2124, 2915, 1226, 1233, 2654, 2786, 2636, 2234, 727, 2444, 199, 600, 2262, 3221, 915, 63, 318, 74, 2396, 1690, 2390, 1711, 414, 10, 2298, 1082, 1419, 3151, 1723, 2744, 3274, 2518, 2954, 1208, 2941, 2089, 3288, 1370, 783, 2517, 3190, 3069, 2505, 2840, 1427, 1670, 3091, 655, 96, 1935, 880, 2511, 876, 2371, 341, 196, 2849, 919, 161, 603, 2993, 2903, 1721, 139, 3326, 1876, 379, 2508, 2094, 1929, 430, 1033, 2604, 1955, 1333, 2274, 3312, 2604, 1585, 2317, 3230, 3068, 2905, 3268, 2844, 1023, 2824, 1731, 643, 820, 462, 2975, 314, 2218, 2011, 649, 383, 874, 2181, 866, 1192, 2914, 2290, 1820, 1572, 1030, 3076, 1526, 2760, 12, 529, 1242, 560, 2723, 2894, 1097, 778, 1495, 371};
coeff_t input1_vals[] = { coeff_t input2_vals[] = {2960, 3124, 509, 485, 2525, 385, 608, 2893, 2423, 1802, 2556, 1090, 775, 2059, 898, 864, 2459, 1116, 551, 188, 3262, 2728, 3134, 2451, 427, 858, 1927, 830, 2688, 2388, 2818, 1418, 3298, 24, 2491, 1448, 1153, 178, 2489, 2126, 1772, 669, 1238, 633, 1919, 2222, 2673, 1918, 2202, 3312, 208, 976, 2267, 107, 2905, 1137, 2921, 2471, 2796, 1313, 485, 1982, 1557, 1203, 2930, 241, 3089, 890, 2193, 179, 952, 2057, 2444, 1378, 1466, 1362, 1808, 2343, 1532, 2651, 727, 3254, 1328, 1604, 967, 2418, 1266, 1826, 684, 2869, 3149, 1874, 1691, 1507, 339, 2473, 102, 3153, 969, 1551, 548, 3059, 2841, 1369, 148, 2510, 2025, 1369, 1579, 2474, 1093, 527, 1416, 981, 2320, 2305, 227, 2173, 812, 1703, 2952, 17, 1129, 2223, 1894, 959, 73, 339, 553, 1466, 1065, 617, 1749, 1896, 1838, 1771, 3092, 297, 996, 198, 521, 567, 3256, 2783, 1044, 2644, 744, 2986, 3178, 1522, 942, 2045, 236, 1866, 853, 2303, 2383, 3095, 418, 2752, 2105, 2896, 3081, 3067, 1696, 978, 102, 1961, 3120, 2741, 1029, 885, 2852, 2659, 2815, 3032, 2358, 3252, 1195, 3304, 878, 70, 3069, 2726, 2455, 182, 108, 2868, 1744, 1697, 1060, 1803, 1752, 829, 2434, 862, 2287, 2860, 352, 634, 2626, 1920, 2425, 239, 831, 2527, 1190, 1469, 2602, 1711, 2185, 1403, 3189, 1188, 2649, 2079, 2215, 790, 409, 2413, 627, 2268, 2507, 2102, 1727, 1146, 2711, 355, 1143, 1225, 430, 82, 3015, 2699, 642, 863, 241, 450, 440, 338, 365, 2621, 3022, 204, 149, 2986, 2191, 1793, 3085, 2128, 373, 290, 835, 580, 2530, 1948};
1477, 218, 784, 251, 747, 1051, 1924, 133, 2953, 1295, 2989, 1519, 1701, 1874, 2806, 423,
2883, 327, 47, 2525, 1508, 214, 2998, 217, 1852, 2624, 2286, 3039, 3076, 1213, 1808, 2554, coeff_t output_vals[] = {2762, 3061, 1101, 3267, 2744, 1349, 182, 1761, 3089, 751, 137, 368, 1461, 2956, 493, 1653, 2617, 721, 356, 3034, 2234, 1556, 809, 2290, 1597, 457, 811, 259, 685, 2478, 319, 2519, 1049, 837, 644, 2571, 1029, 2997, 762, 1710, 2110, 1099, 2513, 1038, 2176, 1938, 3214, 261, 1604, 2474, 5, 1211, 2816, 2848, 2286, 3146, 1777, 1630, 2412, 1457, 889, 671, 822, 2369, 1409, 2059, 1121, 1871, 303, 1178, 2241, 1827, 2046, 628, 2869, 749, 1666, 895, 580, 1770, 2082, 3123, 1192, 520, 168, 2461, 1032, 163, 1421, 2792, 2148, 1735, 220, 1896, 2887, 2163, 357, 2301, 1830, 163, 1812, 805, 1850, 2017, 2313, 1205, 2226, 703, 866, 1708, 1426, 1920, 2911, 267, 3134, 629, 2120, 2022, 2847, 2945, 2967, 1977, 1449, 2028, 1381, 2738, 1098, 2977, 2217, 2060, 710, 845, 2807, 509, 2512, 2444, 2355, 550, 2965, 2517, 1802, 1755, 1065, 1938, 388, 2365, 776, 2453, 1799, 1532, 384, 2266, 1071, 2063, 2858, 1414, 663, 2886, 2734, 209, 1061, 2142, 841, 1081, 977, 799, 2661, 588, 3222, 2140, 2383, 3044, 394, 231, 1090, 917, 1840, 3002, 2315, 1182, 2744, 2815, 2612, 2586, 970, 3301, 3028, 2890, 1849, 269, 2936, 1525, 3102, 3144, 1605, 2746, 1556, 537, 2918, 2549, 976, 250, 2137, 492, 729, 392, 1115, 2422, 2100, 2317, 1636, 1743, 1279, 1393, 2079, 2874, 2148, 233, 1469, 3143, 2109, 1211, 2318, 1138, 2979, 1383, 125, 1995, 1614, 1435, 2216, 782, 671, 662, 988, 2826, 2162, 605, 2955, 2478, 2375, 1449, 2307, 1921, 1285, 2208, 2422, 1035, 2765, 923, 2138, 3053, 812, 146, 1175, 61};
1129, 1353, 2690, 2839, 1778, 2752, 1378, 601, 914, 2335, 2497, 1139, 2611, 129, 1318, 1570,
3190, 1868, 940, 2901, 2626, 2473, 3195, 2621, 2436, 3046, 1018, 1139, 1729, 3021, 2064, 945, // -------------------------------------------------------------------------
690, 1700, 1836, 1943, 2333, 2131, 1618, 1741, 2639, 2653, 301, 2013, 2744, 2406, 2995, 2463, // Golden model for Dilithium-style negacyclic polynomial multiplication
2366, 1495, 442, 224, 1349, 11, 2342, 1712, 2847, 1578, 2654, 2734, 3131, 1245, 1862, 527, // c(x) = a(x) * b(x) mod (x^Nt + 1, q)
2400, 2043, 1360, 451, 573, 898, 2018, 3100, 161, 284, 1949, 362, 755, 2916, 1288, 1616, 876, // -------------------------------------------------------------------------
1682, 853, 2772, 2956, 1101, 2, 214, 2589, 211, 1025, 610, 1225, 2118, 224, 1296, 2612, 2634,
2056, 3227, 1712, 1258, 552, 1345, 786, 2124, 2915, 1226, 1233, 2654, 2786, 2636, 2234, 727, static inline coeff_t golden_mod_q(long long x)
2444, 199, 600, 2262, 3221, 915, 63, 318, 74, 2396, 1690, 2390, 1711, 414, 10, 2298, 1082, {
1419, 3151, 1723, 2744, 3274, 2518, 2954, 1208, 2941, 2089, 3288, 1370, 783, 2517, 3190, 3069, // If you want to stay in sync with the core's modulus,
2505, 2840, 1427, 1670, 3091, 655, 96, 1935, 880, 2511, 876, 2371, 341, 196, 2849, 919, 161, // you can also replace the next line with: long long q_long = (long long)q;
603, 2993, 2903, 1721, 139, 3326, 1876, 379, 2508, 2094, 1929, 430, 1033, 2604, 1955, 1333, const long long q_long = 8380417LL; // Dilithium q; change if needed
2274, 3312, 2604, 1585, 2317, 3230, 3068, 2905, 3268, 2844, 1023, 2824, 1731, 643, 820, 462,
2975, 314, 2218, 2011, 649, 383, 874, 2181, 866, 1192, 2914, 2290, 1820, 1572, 1030, 3076, long long r = x % q_long;
1526, 2760, 12, 529, 1242, 560, 2723, 2894, 1097, 778, 1495, 371 if (r < 0)
}; r += q_long;
coeff_t input2_vals[] = {
2960, 3124, 509, 485, 2525, 385, 608, 2893, 2423, 1802, 2556, 1090, 775, 2059, 898, 864, return (coeff_t)r;
2459, 1116, 551, 188, 3262, 2728, 3134, 2451, 427, 858, 1927, 830, 2688, 2388, 2818, 1418, }
3298, 24, 2491, 1448, 1153, 178, 2489, 2126, 1772, 669, 1238, 633, 1919, 2222, 2673, 1918,
2202, 3312, 208, 976, 2267, 107, 2905, 1137, 2921, 2471, 2796, 1313, 485, 1982, 1557, 1203, static inline void golden_poly_mult_dil(coeff_t c[Nt],
2930, 241, 3089, 890, 2193, 179, 952, 2057, 2444, 1378, 1466, 1362, 1808, 2343, 1532, 2651, const coeff_t a[Nt],
727, 3254, 1328, 1604, 967, 2418, 1266, 1826, 684, 2869, 3149, 1874, 1691, 1507, 339, 2473, const coeff_t b[Nt])
102, 3153, 969, 1551, 548, 3059, 2841, 1369, 148, 2510, 2025, 1369, 1579, 2474, 1093, 527, {
1416, 981, 2320, 2305, 227, 2173, 812, 1703, 2952, 17, 1129, 2223, 1894, 959, 73, 339, 553, long long acc[Nt];
1466, 1065, 617, 1749, 1896, 1838, 1771, 3092, 297, 996, 198, 521, 567, 3256, 2783, 1044,
2644, 744, 2986, 3178, 1522, 942, 2045, 236, 1866, 853, 2303, 2383, 3095, 418, 2752, 2105, // Zero accumulator
2896, 3081, 3067, 1696, 978, 102, 1961, 3120, 2741, 1029, 885, 2852, 2659, 2815, 3032, 2358, for (int i = 0; i < Nt; i++)
3252, 1195, 3304, 878, 70, 3069, 2726, 2455, 182, 108, 2868, 1744, 1697, 1060, 1803, 1752, acc[i] = 0;
829, 2434, 862, 2287, 2860, 352, 634, 2626, 1920, 2425, 239, 831, 2527, 1190, 1469, 2602,
1711, 2185, 1403, 3189, 1188, 2649, 2079, 2215, 790, 409, 2413, 627, 2268, 2507, 2102, 1727, // Negacyclic convolution: mod (x^Nt + 1)
1146, 2711, 355, 1143, 1225, 430, 82, 3015, 2699, 642, 863, 241, 450, 440, 338, 365, 2621, for (int i = 0; i < Nt; i++)
3022, 204, 149, 2986, 2191, 1793, 3085, 2128, 373, 290, 835, 580, 2530, 1948 {
}; for (int j = 0; j < Nt; j++)
coeff_t output_vals[] = { {
4610776, 120935, 1254126, 1209043, 8250679, 5330432, 735926, 4979294, 3072462, 4438343, 5959108, 4150199, long long prod = (long long)a[i] * (long long)b[j];
4125374, 799268, 2926975, 3345416, 6514953, 832221, 7949483, 5277257, 2590090, 7395643, 8089082, 314198, int idx = i + j;
7811635, 2435420, 7246266, 2153173, 3788177, 4021035, 1833670, 2642681, 3518018, 5099879, 5041326, 2680133,
2294752, 5040790, 6356070, 6817707, 3358789, 4806383, 3327317, 312329, 2347630, 3825407, 1256042, 3557082, if (idx < Nt)
5430887, 645661, 4038266, 5101636, 694984, 5345508, 2538149, 7704469, 608436, 4903777, 1808767, 7001927, {
428238, 7935468, 5373889, 3436349, 4187378, 106705, 4142516, 3600459, 1797819, 5861129, 6751083, 5646281, acc[idx] += prod; // "low" part
6572021, 2630356, 335813, 6149778, 2975343, 768557, 4186842, 811761, 4856709, 2679131, 6203378, 3703724, }
5132796, 3547916, 7657381, 1684219, 7387744, 824164, 6723600, 6802520, 776578, 7448197, 2239319, 6534227, else
1227829, 5033869, 3064815, 286414, 1995671, 6388718, 2560006, 7552570, 6339575, 805335, 3113443, 2864772, {
800943, 4400049, 7451911, 7392281, 6887610, 3918170, 5386836, 5448798, 3923739, 5342517, 7219282, 4135030, acc[idx - Nt] -= prod; // folded back with a minus
4848075, 3867634, 3104028, 1564387, 4163402, 3540475, 7001216, 3682604, 3743336, 7384009, 4377282, 2190376, }
4942710, 7034094, 5629808, 5848759, 2332012, 4648421, 497040, 2470792, 6962362, 3819457, 4432599, 7653714, }
261471, 1350408, 6731269, 5239397, 1550210, 486852, 2250358, 4924861, 6976950, 300934, 806395, 3144929, }
3055825, 4790760, 5436599, 5357177, 6797049, 3711049, 1553304, 639926, 3251638, 2748721, 7813863, 1852468,
3250212, 2591962, 3399977, 3748588, 6728836, 6594725, 1978511, 2616210, 2563868, 1564352, 807960, 2796290, // Final reduction mod q
6274942, 7041857, 7435708, 3382319, 3146709, 5105958, 262779, 6159078, 6653706, 761471, 6987350, 356740, for (int i = 0; i < Nt; i++)
5740292, 4488534, 4422468, 1044504, 807782, 4184739, 661801, 2796031, 806827, 2038358, 1126151, 1109702, {
776038, 5143598, 2021761, 4167720, 2332673, 2619321, 697951, 2289483, 7123747, 1746806, 2804069, 6903563, c[i] = golden_mod_q(acc[i]);
636065, 1071798, 4014993, 7412447, 2790677, 5199202, 3494314, 2470575, 6040629, 6097253, 4627782, 4484757, }
1870653, 5869777, 814072, 4417289, 6604748, 5623468, 6238573, 1810479, 2943024, 2030271, 5815022, 3857202, }
1218369, 691903, 5686355, 419696, 1377088, 6367464, 3803857, 3828148, 2852124, 7792963, 7612924, 5532742,
5453559, 1973636, 5708166, 774499, 3566136, 4292240, 3878276, 6917142, 6084901, 4680272, 2564962, 5851662,
2089634, 5040595, 1341598, 5159646, 3461480, 733187, 797953, 4158070, 6311107, 6386220, 5275160, 4743402
};