commit 904df42cd064f54b9dbb5ae66b575bb7fe6a163e Author: oho Date: Tue Dec 9 10:08:18 2025 +0100 kyber running, starting dilithium development diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..62b39ef --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +# Ignore build output directory +/build +/export + +# Ignore object files and dependent files +.o +.d + +#Ignore logs folder and log files +/logs +.log + +#ignore zip files +.zip + +#Ignore ide folder +/_ide + +#Ignore lock files +.lock + +.bin +.pdi +.vitisWorkspace.json + +_ide/logs +_ide/.wsdata \ No newline at end of file diff --git a/HLS_Codes/ntt.h b/HLS_Codes/ntt.h new file mode 100644 index 0000000..4ef770f --- /dev/null +++ b/HLS_Codes/ntt.h @@ -0,0 +1,42 @@ +#ifndef NTT_H +#define NTT_H + +#include +#include + +#include "hls_stream.h" +#include "ap_axi_sdata.h" + +typedef ap_uint<1> bit; +typedef ap_uint<8> ap_logn_t; +typedef ap_int<16> coeff_t; +typedef ap_int<32> double_coeff_t; + +// Internal streaming types (original design) +struct coeff_t_stream +{ + coeff_t value; + bit last; +}; + +struct coeff_t_stream_big +{ + double_coeff_t value; + bit last; +}; + +// External AXI4-Stream element types (only used on top-level ports) +typedef ap_axiu<16,0,0,0> coeff_axis_t; +typedef ap_axiu<32,0,0,0> coeff_axis_big_t; + +#define N 128 +#define Nt 256 +#define logN 7 + +extern coeff_t q, w_n; + +// Top-level function now uses AXI4-Stream types for DMA compatibility +int poly_mult (hls::stream &input, + hls::stream &output); + +#endif diff --git a/HLS_Codes/pm_test.cpp b/HLS_Codes/pm_test.cpp new file mode 100644 index 0000000..f4b4c80 --- /dev/null +++ b/HLS_Codes/pm_test.cpp @@ -0,0 +1,69 @@ +// pm_test.cpp (works with Set A and Set B) + +#include "test_case.h" + +int main() +{ + // Top-level AXI4-Stream ports for the DUT + hls::stream in_data; + hls::stream out_data; + + coeff_axis_big_t local_stream1; + coeff_axis_t local_stream2; + + int i; + + coeff_t actual_outputs[Nt]; + + // 
------------------------------------------------------------------------- + // Write stimulus into input AXI4-Stream + // ------------------------------------------------------------------------- + for (i = 0; i < Nt; i++) + { + coeff_t val1 = input1_vals[i]; + double_coeff_t val2 = (double_coeff_t)(input2_vals[i] * 65536); + + // Pack into 32-bit AXI data word + local_stream1.data = (ap_uint<32>)(val1 + val2); + + // Mark all bytes valid; side channels are disabled here + local_stream1.keep = -1; + local_stream1.strb = -1; + + // TLAST on final sample + local_stream1.last = (i == Nt - 1) ? 1 : 0; + + in_data.write(local_stream1); + } + + // ------------------------------------------------------------------------- + // Call DUT + // ------------------------------------------------------------------------- + poly_mult(in_data, out_data); + + // ------------------------------------------------------------------------- + // Read result from output AXI4-Stream + // ------------------------------------------------------------------------- + for (i = 0; i < Nt; i++) + { + local_stream2 = out_data.read(); + actual_outputs[i] = (coeff_t)local_stream2.data; + // Optionally check local_stream2.last here + } + + // ------------------------------------------------------------------------- + // Compare against golden output + // ------------------------------------------------------------------------- + int ret_val = 0; + for (i = 0; i < Nt; i++) + { + if (output_vals[i] != actual_outputs[i]) + { + ret_val++; + std::cout << actual_outputs[i]; + break; + } + } + + return ret_val; +} diff --git a/HLS_Codes/polymult.cpp b/HLS_Codes/polymult.cpp new file mode 100644 index 0000000..6ef8d1c --- /dev/null +++ b/HLS_Codes/polymult.cpp @@ -0,0 +1,1091 @@ +// polymult.cpp (Set A, with all inner pragmas kept, and .user/.id/.dest removed) + +#include "ntt.h" + +coeff_t q = 3329; +coeff_t inv_n = 3303; +//double_coeff_t v = 20159; + +/*coeff_t mod(double_coeff_t A) +{ + #pragma HLS inline 
OFF + //double_coeff_t v = (double_coeff_t) ((1<<26) + 1664)/q; + double_coeff_t t = (v * A + (1 << 25)) >> 26; + t = t * q; + coeff_t val; + if (A < t) + val = A - t + q; + else + val = A - t; + return val; +}*/ + +ap_uint<13> m = 5039; + +coeff_t mod(double_coeff_t A) +{ + #pragma HLS pipeline II = 1 + coeff_t val; + ap_uint<36> t123 = m * A; + ap_uint<12> t = (t123 >> 24); + ap_uint<24> ta = t * q; + ap_uint<24> c = A - ta; + if (c > q) + val = (coeff_t) (c - q); + else + val = (coeff_t) c; + return val; +} + +coeff_t modadd(coeff_t x, coeff_t y) +{ + #pragma HLS inline + coeff_t w = x + y; + return (coeff_t)(w - (w < q ? (coeff_t)0 : q)); +} + +coeff_t modsub(coeff_t x, coeff_t y) +{ + #pragma HLS inline + coeff_t s = x + (x > y ? (coeff_t)0 : q); + return (coeff_t)(s - y); +} + +void butterfly_unit_dif(coeff_t w, coeff_t a, coeff_t b, coeff_t &x, coeff_t &y) +{ + #pragma HLS pipeline II = 1 + x = modadd(a, b); + y = modsub(a, b); + y = mod(w * y); +} + +void butterfly_unit_dit(coeff_t w, coeff_t a, coeff_t b, coeff_t &x, coeff_t &y) +{ + #pragma HLS pipeline II = 1 + coeff_t wb = mod(w * b); + x = modadd(a, wb); + y = modsub(a, wb); +} + +void delay_cycle() +{ + #ifdef __SYNTHESIS__ + ap_wait_n(1); + #endif +} + +void ntt_stage1 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeff = 1729; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 64; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + for (int j = 0; j < 1; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 64; k++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeff; + butterfly_unit_dit(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + delay_cycle(); + } + + for (int i = 0; i 
< 64; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + } + } +} + +void ntt_stage2 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[2] = {2580, 3289}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 32; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + for (int j = 0; j < 2; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 32; k++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[j]; + butterfly_unit_dit(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + delay_cycle(); + } + + for (int i = 0; i < 32; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 1) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void ntt_stage3 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[4] = {2642, 630, 1897, 848}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 16; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + for (int j = 0; j < 4; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 16; k++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[j]; + butterfly_unit_dit(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + delay_cycle(); + } + + for (int i = 0; i < 16; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE 
variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 3) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void ntt_stage4 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[8] = {1062, 1919, 193, 797, 2786, 3260, 569, 1746}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 8; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + for (int j = 0; j < 8; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + int ind = 1; + for (int k = 0; k < 8; k++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[j]; + butterfly_unit_dit(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + delay_cycle(); + } + + for (int i = 0; i < 8; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 7) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void ntt_stage5 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[16] = {296, 2447, 1339, 1476, 3046, 56, 2240, 1333, + 1426, 2094, 535, 2882, 2393, 2879, 1974, 821}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 4; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + for (int j = 0; j < 16; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + int ind = 1; + for (int k = 0; k < 4; k = k + 1) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[j]; + butterfly_unit_dit(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = 
bf2; + iter++; + delay_cycle(); + } + + for (int i = 0; i < 4; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 15) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void ntt_stage6 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[32] = {289, 331, 3253, 1756, 1197, 2304, 2277, 2055, + 650, 1977, 2513, 632, 2865, 33, 1320, 1915, + 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, + 2647, 2617, 1481, 648, 2474, 3110, 1227, 910}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 2; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + for (int j = 0; j < 32; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + int ind = 1; + for (int k = 0; k < 2; k = k + 1) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[j]; + butterfly_unit_dit(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + delay_cycle(); + } + + for (int i = 0; i < 2; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 31) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void ntt_stage7 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS inline off + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + coeff_t twiddle_coeffs[64] = {17, 2761, 583, 2649, 1637, 723, 2288, 1100, + 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, + 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, + 939, 2308, 2437, 2388, 733, 2337, 268, 641, + 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, + 1063, 319, 2773, 757, 2099, 561, 2466, 2594, + 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, + 1722, 1212, 1874, 1029, 
2110, 2935, 885, 2154}; + int x, y; + coeff_t u, t, it, bf1, bf2; + + u = a.read(); + + for (int j = 0; j < 64; j++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + t = a.read(); + butterfly_unit_dit(twiddle_coeffs[j], u, t, bf1, bf2); + b.write(bf1); + b.write(bf2); + if (j < 63) + u = a.read(); + } +} + +void intt_stage1 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + coeff_t twiddle_coeffs[64] = {1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, + 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, + 735, 863, 2768, 1230, 2572, 556, 3010, 2266, + 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, + 2688, 3061, 992, 2596, 941, 892, 1021, 2390, + 642, 1868, 2377, 1482, 1540, 540, 1678, 1626, + 279, 314, 1173, 2573, 3096, 48, 667, 1920, + 2229, 1041, 2606, 1692, 680, 2746, 568, 3312}; + + #pragma HLS inline off + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int x, y; + coeff_t u, t, it, bf1, bf2; + + u = a.read(); + + for (int j = 0; j < 64; j++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + t = a.read(); + butterfly_unit_dif(twiddle_coeffs[j], u, t, bf1, bf2); + b.write(bf1); + b.write(bf2); + if (j < 63) + u = a.read(); + } +} + +void intt_stage2 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[32] = {2419, 2102, 219, 855, 2681, 1848, 712, 682, + 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, + 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, + 1274, 1052, 1025, 2132, 1573, 76, 2998, 3040}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 2; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + int ind = 0; + int count = 0; + for (int j = 0; j < 32; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 2; k = k + 1) + { + #pragma HLS pipeline II = 1 + #pragma HLS 
DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[ind]; + butterfly_unit_dif(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + count++; + if (count % 2 == 0) + ind++; + delay_cycle(); + } + + for (int i = 0; i < 2; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 31) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void intt_stage3 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[16] = {2508, 1355, 450, 936, 447, 2794, 1235, 1903, + 1996, 1089, 3273, 283, 1853, 1990, 882, 3033}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + int m = 4; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 4; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + int ind = 0; + int count = 0; + for (int j = 0; j < 16; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 4; k = k + 1) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[ind]; + butterfly_unit_dif(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + count++; + if (count % 4 == 0) + ind++; + delay_cycle(); + } + + for (int i = 0; i < 4; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 15) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void intt_stage4 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[8] = {1583, 2760, 69, 543, 2532, 3136, 1410, 2267}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 8; i++) + { + #pragma HLS 
pipeline + it = a.read(); + fifo[i + 64] = it; + } + + int ind = 0; + int count = 0; + for (int j = 0; j < 8; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 8; k = k + 1) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[ind]; + butterfly_unit_dif(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + count++; + if (count % 8 == 0) + ind++; + delay_cycle(); + } + + for (int i = 0; i < 8; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 7) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void intt_stage5 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[4] = {2481, 1432, 2699, 687}; + + #pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 16; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + int ind = 0; + int count = 0; + for (int j = 0; j < 4; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 16; k = k + 1) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[ind]; + butterfly_unit_dif(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + count++; + if (count % 16 == 0) + ind++; + delay_cycle(); + } + + for (int i = 0; i < 16; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 3) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void intt_stage6 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS dataflow + coeff_t twiddle_coeffs[2] = {40, 749}; + + 
#pragma HLS DEPENDENCE variable = fifo inter RAW false + + int x, y; + coeff_t a_, b_, it, bf1, bf2, tf; + + for (int i = 0; i < 32; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + int ind = 0; + int count = 0; + for (int j = 0; j < 2; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 32; k = k + 1) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = twiddle_coeffs[ind]; + butterfly_unit_dif(tf, a_, b_, bf1, bf2); + b.write(bf1); + fifo[iter] = bf2; + iter++; + count++; + if (count == 32) + ind++; + delay_cycle(); + } + + for (int i = 0; i < 32; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + if (j < 1) + { + it = a.read(); + fifo[i + 64] = it; + } + } + } +} + +void intt_stage7 (hls::stream &a, hls::stream &b, coeff_t fifo[]) +{ + #pragma HLS inline off + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int x, y; + coeff_t a_, b_, it, bf1, bf2, bfn1, bfn2, tf; + + for (int i = 0; i < 64; i++) + { + #pragma HLS pipeline + it = a.read(); + fifo[i + 64] = it; + } + + for (int j = 0; j < 1; j++) + { + #pragma HLS DEPENDENCE variable = fifo inter RAW false + int iter = 0; + for (int k = 0; k < 64; k = k + 1) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + a_ = fifo[iter + 64]; + b_ = a.read(); + tf = 1600; + butterfly_unit_dif(tf, a_, b_, bf1, bf2); + bfn1 = mod(bf1 * inv_n); + bfn2 = mod(bf2 * inv_n); + b.write(bfn1); + fifo[iter] = bfn2; + iter++; + delay_cycle(); + } + + for (int i = 0; i < 64; i++) + { + #pragma HLS pipeline II = 1 + #pragma HLS DEPENDENCE variable = fifo inter RAW false + b.write(fifo[i]); + delay_cycle(); + } + } +} + +void read_inputs (hls::stream &input, hls::stream &se, hls::stream &so) +{ + coeff_t_stream x; + coeff_t a; + int i; + + for 
(i=0; i &se, hls::stream &so, hls::stream &output) +{ + coeff_t a1, a0; + coeff_t_stream y; + int i; + + y.last = 0; + for (i=0; i &input, hls::stream &output) +{ + #pragma HLS dataflow + + hls::stream s0o("s0o"), s1o("s1o"), s2o("s2o"), s3o("s3o"), + s4o("s4o"), s5o("s5o"), s6o("s6o"), s7o("s7o"), + s0e("s0e"), s1e("s1e"), s2e("s2e"), s3e("s3e"), + s4e("s4e"), s5e("s5e"), s6e("s6e"), s7e("s7e"); + + coeff_t fo7[65], fo6[66], fo5[68], fo4[72], fo3[80], fo2[96], fo1[128]; + coeff_t fe7[65], fe6[66], fe5[68], fe4[72], fe3[80], fe2[96], fe1[128]; + + coeff_t_stream x, y; + + #pragma HLS STREAM variable = s7o depth = 1 + #pragma HLS STREAM variable = s6o depth = 2 + #pragma HLS STREAM variable = s5o depth = 4 + #pragma HLS STREAM variable = s4o depth = 8 + #pragma HLS STREAM variable = s3o depth = 16 + #pragma HLS STREAM variable = s2o depth = 32 + #pragma HLS STREAM variable = s1o depth = 64 + #pragma HLS STREAM variable = s0o depth = 128 + + #pragma HLS STREAM variable = s7e depth = 1 + #pragma HLS STREAM variable = s6e depth = 2 + #pragma HLS STREAM variable = s5e depth = 4 + #pragma HLS STREAM variable = s4e depth = 8 + #pragma HLS STREAM variable = s3e depth = 16 + #pragma HLS STREAM variable = s2e depth = 32 + #pragma HLS STREAM variable = s1e depth = 64 + #pragma HLS STREAM variable = s0e depth = 128 + + + read_inputs(input, s0e, s0o); + + ntt_stage1 (s0e, s1e, fe1); + ntt_stage1 (s0o, s1o, fo1); + + ntt_stage2 (s1e, s2e, fe2); + ntt_stage2 (s1o, s2o, fo2); + + ntt_stage3 (s2e, s3e, fe3); + ntt_stage3 (s2o, s3o, fo3); + + ntt_stage4 (s3e, s4e, fe4); + ntt_stage4 (s3o, s4o, fo4); + + ntt_stage5 (s4e, s5e, fe5); + ntt_stage5 (s4o, s5o, fo5); + + ntt_stage6 (s5e, s6e, fe6); + ntt_stage6 (s5o, s6o, fo6); + + ntt_stage7 (s6e, s7e, fe7); + ntt_stage7 (s6o, s7o, fo7); + + write_outputs(s7e, s7o, output); +} + +void gs_intt (hls::stream &input, hls::stream &output) +{ + #pragma HLS dataflow + + hls::stream s0o("s0o"), s1o("s1o"), s2o("s2o"), s3o("s3o"), + s4o("s4o"), 
s5o("s5o"), s6o("s6o"), s7o("s7o"), + s0e("s0e"), s1e("s1e"), s2e("s2e"), s3e("s3e"), + s4e("s4e"), s5e("s5e"), s6e("s6e"), s7e("s7e"); + + coeff_t fo7[128], fo6[96], fo5[80], fo4[72], fo3[68], fo2[66], fo1[65]; + coeff_t fe7[128], fe6[96], fe5[80], fe4[72], fe3[68], fe2[66], fe1[65]; + + coeff_t_stream x, y; + + #pragma HLS STREAM variable = s7o depth = 1 + #pragma HLS STREAM variable = s6o depth = 2 + #pragma HLS STREAM variable = s5o depth = 4 + #pragma HLS STREAM variable = s4o depth = 8 + #pragma HLS STREAM variable = s3o depth = 16 + #pragma HLS STREAM variable = s2o depth = 32 + #pragma HLS STREAM variable = s1o depth = 64 + #pragma HLS STREAM variable = s0o depth = 128 + + #pragma HLS STREAM variable = s7e depth = 1 + #pragma HLS STREAM variable = s6e depth = 2 + #pragma HLS STREAM variable = s5e depth = 4 + #pragma HLS STREAM variable = s4e depth = 8 + #pragma HLS STREAM variable = s3e depth = 16 + #pragma HLS STREAM variable = s2e depth = 32 + #pragma HLS STREAM variable = s1e depth = 64 + #pragma HLS STREAM variable = s0e depth = 128 + + read_inputs(input, s0e, s0o); + + intt_stage1 (s0e, s1e, fe1); + intt_stage1 (s0o, s1o, fo1); + + intt_stage2 (s1e, s2e, fe2); + intt_stage2 (s1o, s2o, fo2); + + intt_stage3 (s2e, s3e, fe3); + intt_stage3 (s2o, s3o, fo3); + + intt_stage4 (s3e, s4e, fe4); + intt_stage4 (s3o, s4o, fo4); + + intt_stage5 (s4e, s5e, fe5); + intt_stage5 (s4o, s5o, fo5); + + intt_stage6 (s5e, s6e, fe6); + intt_stage6 (s5o, s6o, fo6); + + intt_stage7 (s6e, s7e, fe7); + intt_stage7 (s6o, s7o, fo7); + + write_outputs(s7e, s7o, output); +} + +void stream_split (hls::stream &input, + hls::stream &input1, + hls::stream &input2) +{ + + coeff_t_stream_big x; + double_coeff_t a; + coeff_t_stream x1, x2; + coeff_t a1, a2; + int i; + + for (i=0; i &input1, + hls::stream &input2, + hls::stream &output) +{ + coeff_t_stream xe, xo, ye, yo, z; + coeff_t ae, be, ce, ao, bo, co, c1, c2, c2s, c3, c4; + int i; + + coeff_t pm_factors[128] = {17, 3312, 2761, 568, 
583, 2746, 2649, 680, + 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, + 1409, 1920, 2662, 667, 3281, 48, 233, 3096, + 756, 2573, 2156, 1173, 3015, 314, 3050, 279, + 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, + 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, + 939, 2390, 2308, 1021, 2437, 892, 2388, 941, + 733, 2596, 2337, 992, 268, 3061, 641, 2688, + 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, + 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, + 1063, 2266, 319, 3010, 2773, 556, 757, 2572, + 2099, 1230, 561, 2768, 2466, 863, 2594, 735, + 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, + 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, + 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, + 2110, 1219, 2935, 394, 885, 2444, 2154, 1175}; + + z.last = 0; + for (i=0; i internal stream conversion helpers (only at top level) +// ----------------------------------------------------------------------------- + +static void axis_to_internal_input(hls::stream &axis_in, + hls::stream &int_in) +{ + coeff_axis_big_t a; + coeff_t_stream_big x; + + for (int i = 0; i < Nt; i++) + { + #pragma HLS pipeline II = 1 + a = axis_in.read(); + + x.value = (double_coeff_t)a.data; + x.last = a.last; + + int_in.write(x); + + // Optional: break on TLAST if you want to be robust to shorter packets + if (a.last) + break; + } +} + +static void internal_to_axis_output(hls::stream &int_out, + hls::stream &axis_out) +{ + coeff_t_stream x; + coeff_axis_t a; + + for (int i = 0; i < Nt; i++) + { + #pragma HLS pipeline II = 1 + x = int_out.read(); + + a.data = (ap_uint<16>)x.value; + a.last = x.last; + + // Mark all bytes valid; side channels are disabled in this ap_axiu config + a.keep = -1; + a.strb = -1; + + axis_out.write(a); + + if (x.last) + break; + } +} + +// ----------------------------------------------------------------------------- +// Top-level function with AXI4-Stream ports (for DMA) and internal NTT pipeline +// ----------------------------------------------------------------------------- + 
+int poly_mult (hls::stream &input, + hls::stream &output) +{ + #pragma HLS INTERFACE axis register port=input + #pragma HLS INTERFACE axis register port=output + #pragma HLS INTERFACE s_axilite port=return bundle=CTRL_BUS + #pragma HLS dataflow + + // Internal streams using the original coeff_t_stream{,_big} types + hls::stream in_internal("in_internal"); + hls::stream input1("input1"), input2("input2"); + hls::stream middle1("middle1"), middle2("middle2"); + hls::stream middle3("middle3"), out_internal("out_internal"); + + axis_to_internal_input(input, in_internal); + stream_split(in_internal, input1, input2); + ct_ntt(input1, middle1); + ct_ntt(input2, middle2); + point_wise_mult(middle1, middle2, middle3); + gs_intt(middle3, out_internal); + internal_to_axis_output(out_internal, output); + + return 0; +} diff --git a/HLS_Codes/test_case.h b/HLS_Codes/test_case.h new file mode 100644 index 0000000..f8bebb4 --- /dev/null +++ b/HLS_Codes/test_case.h @@ -0,0 +1,6 @@ +#include "ntt.h" + +coeff_t input1_vals[] = {1477, 218, 784, 251, 747, 1051, 1924, 133, 2953, 1295, 2989, 1519, 1701, 1874, 2806, 423, 2883, 327, 47, 2525, 1508, 214, 2998, 217, 1852, 2624, 2286, 3039, 3076, 1213, 1808, 2554, 1129, 1353, 2690, 2839, 1778, 2752, 1378, 601, 914, 2335, 2497, 1139, 2611, 129, 1318, 1570, 3190, 1868, 940, 2901, 2626, 2473, 3195, 2621, 2436, 3046, 1018, 1139, 1729, 3021, 2064, 945, 690, 1700, 1836, 1943, 2333, 2131, 1618, 1741, 2639, 2653, 301, 2013, 2744, 2406, 2995, 2463, 2366, 1495, 442, 224, 1349, 11, 2342, 1712, 2847, 1578, 2654, 2734, 3131, 1245, 1862, 527, 2400, 2043, 1360, 451, 573, 898, 2018, 3100, 161, 284, 1949, 362, 755, 2916, 1288, 1616, 876, 1682, 853, 2772, 2956, 1101, 2, 214, 2589, 211, 1025, 610, 1225, 2118, 224, 1296, 2612, 2634, 2056, 3227, 1712, 1258, 552, 1345, 786, 2124, 2915, 1226, 1233, 2654, 2786, 2636, 2234, 727, 2444, 199, 600, 2262, 3221, 915, 63, 318, 74, 2396, 1690, 2390, 1711, 414, 10, 2298, 1082, 1419, 3151, 1723, 2744, 3274, 2518, 2954, 1208, 
2941, 2089, 3288, 1370, 783, 2517, 3190, 3069, 2505, 2840, 1427, 1670, 3091, 655, 96, 1935, 880, 2511, 876, 2371, 341, 196, 2849, 919, 161, 603, 2993, 2903, 1721, 139, 3326, 1876, 379, 2508, 2094, 1929, 430, 1033, 2604, 1955, 1333, 2274, 3312, 2604, 1585, 2317, 3230, 3068, 2905, 3268, 2844, 1023, 2824, 1731, 643, 820, 462, 2975, 314, 2218, 2011, 649, 383, 874, 2181, 866, 1192, 2914, 2290, 1820, 1572, 1030, 3076, 1526, 2760, 12, 529, 1242, 560, 2723, 2894, 1097, 778, 1495, 371}; +coeff_t input2_vals[] = {2960, 3124, 509, 485, 2525, 385, 608, 2893, 2423, 1802, 2556, 1090, 775, 2059, 898, 864, 2459, 1116, 551, 188, 3262, 2728, 3134, 2451, 427, 858, 1927, 830, 2688, 2388, 2818, 1418, 3298, 24, 2491, 1448, 1153, 178, 2489, 2126, 1772, 669, 1238, 633, 1919, 2222, 2673, 1918, 2202, 3312, 208, 976, 2267, 107, 2905, 1137, 2921, 2471, 2796, 1313, 485, 1982, 1557, 1203, 2930, 241, 3089, 890, 2193, 179, 952, 2057, 2444, 1378, 1466, 1362, 1808, 2343, 1532, 2651, 727, 3254, 1328, 1604, 967, 2418, 1266, 1826, 684, 2869, 3149, 1874, 1691, 1507, 339, 2473, 102, 3153, 969, 1551, 548, 3059, 2841, 1369, 148, 2510, 2025, 1369, 1579, 2474, 1093, 527, 1416, 981, 2320, 2305, 227, 2173, 812, 1703, 2952, 17, 1129, 2223, 1894, 959, 73, 339, 553, 1466, 1065, 617, 1749, 1896, 1838, 1771, 3092, 297, 996, 198, 521, 567, 3256, 2783, 1044, 2644, 744, 2986, 3178, 1522, 942, 2045, 236, 1866, 853, 2303, 2383, 3095, 418, 2752, 2105, 2896, 3081, 3067, 1696, 978, 102, 1961, 3120, 2741, 1029, 885, 2852, 2659, 2815, 3032, 2358, 3252, 1195, 3304, 878, 70, 3069, 2726, 2455, 182, 108, 2868, 1744, 1697, 1060, 1803, 1752, 829, 2434, 862, 2287, 2860, 352, 634, 2626, 1920, 2425, 239, 831, 2527, 1190, 1469, 2602, 1711, 2185, 1403, 3189, 1188, 2649, 2079, 2215, 790, 409, 2413, 627, 2268, 2507, 2102, 1727, 1146, 2711, 355, 1143, 1225, 430, 82, 3015, 2699, 642, 863, 241, 450, 440, 338, 365, 2621, 3022, 204, 149, 2986, 2191, 1793, 3085, 2128, 373, 290, 835, 580, 2530, 1948}; + +coeff_t output_vals[] = {2762, 3061, 
1101, 3267, 2744, 1349, 182, 1761, 3089, 751, 137, 368, 1461, 2956, 493, 1653, 2617, 721, 356, 3034, 2234, 1556, 809, 2290, 1597, 457, 811, 259, 685, 2478, 319, 2519, 1049, 837, 644, 2571, 1029, 2997, 762, 1710, 2110, 1099, 2513, 1038, 2176, 1938, 3214, 261, 1604, 2474, 5, 1211, 2816, 2848, 2286, 3146, 1777, 1630, 2412, 1457, 889, 671, 822, 2369, 1409, 2059, 1121, 1871, 303, 1178, 2241, 1827, 2046, 628, 2869, 749, 1666, 895, 580, 1770, 2082, 3123, 1192, 520, 168, 2461, 1032, 163, 1421, 2792, 2148, 1735, 220, 1896, 2887, 2163, 357, 2301, 1830, 163, 1812, 805, 1850, 2017, 2313, 1205, 2226, 703, 866, 1708, 1426, 1920, 2911, 267, 3134, 629, 2120, 2022, 2847, 2945, 2967, 1977, 1449, 2028, 1381, 2738, 1098, 2977, 2217, 2060, 710, 845, 2807, 509, 2512, 2444, 2355, 550, 2965, 2517, 1802, 1755, 1065, 1938, 388, 2365, 776, 2453, 1799, 1532, 384, 2266, 1071, 2063, 2858, 1414, 663, 2886, 2734, 209, 1061, 2142, 841, 1081, 977, 799, 2661, 588, 3222, 2140, 2383, 3044, 394, 231, 1090, 917, 1840, 3002, 2315, 1182, 2744, 2815, 2612, 2586, 970, 3301, 3028, 2890, 1849, 269, 2936, 1525, 3102, 3144, 1605, 2746, 1556, 537, 2918, 2549, 976, 250, 2137, 492, 729, 392, 1115, 2422, 2100, 2317, 1636, 1743, 1279, 1393, 2079, 2874, 2148, 233, 1469, 3143, 2109, 1211, 2318, 1138, 2979, 1383, 125, 1995, 1614, 1435, 2216, 782, 671, 662, 988, 2826, 2162, 605, 2955, 2478, 2375, 1449, 2307, 1921, 1285, 2208, 2422, 1035, 2765, 923, 2138, 3053, 812, 146, 1175, 61}; \ No newline at end of file diff --git a/HLS_Codes_Dilithium/ntt.h b/HLS_Codes_Dilithium/ntt.h new file mode 100644 index 0000000..090eec6 --- /dev/null +++ b/HLS_Codes_Dilithium/ntt.h @@ -0,0 +1,38 @@ +// ntt.h +#ifndef NTT_DILITHIUM_H +#define NTT_DILITHIUM_H + +#include +#include +#include + +// Dilithium / ML-DSA-style ring +// R_q = Z_q[x]/(x^256 - 1), q = 8380417 +// Only the *ring* and NTT structure are reflected here; we’re not +// implementing full ML-DSA, just the polynomial multiplier. 
+ +#define DILITHIUM_N 256 +#define DILITHIUM_Q 8380417 + +// Coefficients: we use 32-bit signed ints. q < 2^23, so this is safe. +typedef ap_int<32> coeff_t; +typedef ap_int<64> wide_t; + +// Simple 32-bit AXI4-Stream word: one coeff per beat. +// You’ll send 256 words for a, then 256 words for b. +typedef ap_axiu<32, 0, 0, 0> coeff_axis_t; + +// Modular arithmetic helpers (declared here, defined in the .cpp) +coeff_t mod_q(wide_t x); +coeff_t mod_add(coeff_t a, coeff_t b); +coeff_t mod_sub(coeff_t a, coeff_t b); + +// NTT / INTT prototypes +void ntt(coeff_t a[DILITHIUM_N]); +void inv_ntt(coeff_t a[DILITHIUM_N]); + +// Top-level polymult for Dilithium ring +int poly_mult_dilithium(hls::stream &axis_in, + hls::stream &axis_out); + +#endif // NTT_DILITHIUM_H diff --git a/HLS_Codes_Dilithium/pm_test.cpp b/HLS_Codes_Dilithium/pm_test.cpp new file mode 100644 index 0000000..78770fe --- /dev/null +++ b/HLS_Codes_Dilithium/pm_test.cpp @@ -0,0 +1,93 @@ +// pm_test_dilithium.cpp – C-sim / C-synth testbench for Dilithium polymult IP + +#include +#include + +#include "ntt.h" +#include "test_case.h" + +int main() +{ + // Top-level AXI4-Stream ports for the DUT + hls::stream in_data; + hls::stream out_data; + + coeff_axis_t local_stream; + coeff_t actual_outputs[DILITHIUM_N]; + + int i; + + // ------------------------------------------------------------------------- + // Write stimulus into input AXI4-Stream + // + // Protocol for poly_mult_dilithium(): + // - first DILITHIUM_N words: a[0..N-1] + // - next DILITHIUM_N words: b[0..N-1] + // TLAST = 1 only on the very last word (b[N-1]) + // ------------------------------------------------------------------------- + + // Send polynomial a + for (i = 0; i < DILITHIUM_N; i++) { + coeff_t val = input1_vals[i]; + + local_stream.data = (ap_int<32>)val; + local_stream.keep = -1; + local_stream.strb = -1; + local_stream.last = 0; // not last yet + + in_data.write(local_stream); + } + + // Send polynomial b + for (i = 0; i < 
DILITHIUM_N; i++) { + coeff_t val = input2_vals[i]; + + local_stream.data = (ap_int<32>)val; + local_stream.keep = -1; + local_stream.strb = -1; + local_stream.last = (i == DILITHIUM_N - 1) ? 1 : 0; + + in_data.write(local_stream); + } + + // ------------------------------------------------------------------------- + // Call DUT + // ------------------------------------------------------------------------- + poly_mult_dilithium(in_data, out_data); + + // ------------------------------------------------------------------------- + // Read result from output AXI4-Stream + // + // The core outputs exactly DILITHIUM_N coefficients; TLAST is 1 on the + // final beat. We read all N coefficients into actual_outputs[]. + // ------------------------------------------------------------------------- + for (i = 0; i < DILITHIUM_N; i++) { + coeff_axis_t out_word = out_data.read(); + actual_outputs[i] = (coeff_t)out_word.data; + + // Optional: you can sanity-check TLAST here: + // if (i == DILITHIUM_N - 1 && out_word.last != 1) ... + } + + // ------------------------------------------------------------------------- + // Compare against golden output + // ------------------------------------------------------------------------- + int ret_val = 0; + for (i = 0; i < DILITHIUM_N; i++) { + if (output_vals[i] != actual_outputs[i]) { + ret_val++; + std::cout << "Mismatch at index " << i + << ": got " << (long long)actual_outputs[i] + << ", expected " << (long long)output_vals[i] + << std::endl; + break; // stop at first mismatch (like your Kyber testbench) + } + } + + if (ret_val == 0) { + std::cout << "All " << DILITHIUM_N + << " coefficients match golden output." 
<< std::endl; + } + + return ret_val; +} diff --git a/HLS_Codes_Dilithium/polymult.cpp b/HLS_Codes_Dilithium/polymult.cpp new file mode 100644 index 0000000..55e36d5 --- /dev/null +++ b/HLS_Codes_Dilithium/polymult.cpp @@ -0,0 +1,228 @@ +// polymult.cpp +#include "ntt.h" + + +// n^{-1} mod q, for n = 256, q = 8380417 +// Computed so that (256 * DILITHIUM_INV_N) % DILITHIUM_Q == 1. +static const coeff_t DILITHIUM_INV_N = (coeff_t)8347681; + +// --------------------------------------------------------------------- +// Modular arithmetic +// --------------------------------------------------------------------- + +coeff_t mod_q(wide_t x) +{ + #pragma HLS inline + // Reduce into [0, q-1] via 64-bit modulo. + // This is not the most hardware-optimal, but it is simple and correct. + wide_t r = x % (wide_t)DILITHIUM_Q; + if (r < 0) + r += DILITHIUM_Q; + return (coeff_t)r; +} + +coeff_t mod_add(coeff_t a, coeff_t b) +{ + #pragma HLS inline + wide_t s = (wide_t)a + (wide_t)b; + if (s >= DILITHIUM_Q) + s -= DILITHIUM_Q; + return (coeff_t)s; +} + +coeff_t mod_sub(coeff_t a, coeff_t b) +{ + #pragma HLS inline + wide_t d = (wide_t)a - (wide_t)b; + if (d < 0) + d += DILITHIUM_Q; + return (coeff_t)d; +} + +// One step of the butterfly: t = zeta * b, then (a+t, a-t) +static void butterfly(coeff_t &a, coeff_t &b, coeff_t zeta) +{ + #pragma HLS inline + wide_t prod = (wide_t)zeta * (wide_t)b; + coeff_t t = mod_q(prod); + coeff_t a0 = a; + a = mod_add(a0, t); + b = mod_sub(a0, t); +} + +// --------------------------------------------------------------------- +// Forward NTT in-place on a[0..255] for R_q = Z_q[x]/(x^256 - 1) +// Standard Cooley–Tukey style iteration over layers. +// The zetas[] table and its traversal pattern must match the Dilithium spec. 
+// --------------------------------------------------------------------- +void ntt(coeff_t a[DILITHIUM_N]) +{ + #pragma HLS inline off + #pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1 + + unsigned k = 0; + // len goes 128,64,32,...,1 + for (unsigned len = DILITHIUM_N >> 1; len >= 1; len >>= 1) { + #pragma HLS loop_tripcount min=1 max=8 + for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) { + #pragma HLS loop_tripcount min=1 max=256 + coeff_t zeta = dilithium_zetas[k++]; + for (unsigned j = start; j < start + len; j++) { + #pragma HLS PIPELINE II=1 + coeff_t &aj = a[j]; + coeff_t &ajl = a[j + len]; + butterfly(aj, ajl, zeta); + } + } + if (len == 1) break; // avoid unsigned wrap + } +} + +// --------------------------------------------------------------------- +// Inverse NTT (Gentleman–Sande style) on a[0..255], followed by scaling +// by n^{-1}. Pattern mirrors standard Dilithium invNTT. +// --------------------------------------------------------------------- +void inv_ntt(coeff_t a[DILITHIUM_N]) +{ + #pragma HLS inline off + #pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1 + + unsigned k = 0; + // len goes 1,2,4,...,128 + for (unsigned len = 1; len < DILITHIUM_N; len <<= 1) { + #pragma HLS loop_tripcount min=1 max=8 + for (unsigned start = 0; start < DILITHIUM_N; start += 2 * len) { + #pragma HLS loop_tripcount min=1 max=256 + coeff_t zeta = dilithium_zetas_inv[k++]; + for (unsigned j = start; j < start + len; j++) { + #pragma HLS PIPELINE II=1 + coeff_t u = a[j]; + coeff_t v = a[j + len]; + + coeff_t a_sum = mod_add(u, v); + coeff_t a_diff = mod_sub(u, v); + + // Multiply the "difference" branch by zeta + wide_t prod = (wide_t)zeta * (wide_t)a_diff; + coeff_t t = mod_q(prod); + + a[j] = a_sum; + a[j+len] = t; + } + } + } + + // Final scaling by n^{-1} mod q + for (unsigned i = 0; i < DILITHIUM_N; i++) { + #pragma HLS PIPELINE II=1 + wide_t prod = (wide_t)a[i] * (wide_t)DILITHIUM_INV_N; + a[i] = mod_q(prod); + } +} + +// 
--------------------------------------------------------------------- +// Polynomial multiplication in R_q = Z_q[x]/(x^256 - 1): +// c = a * b (using NTT, pointwise multiply, INTT) +// --------------------------------------------------------------------- +static void poly_mult_internal(coeff_t a[DILITHIUM_N], + coeff_t b[DILITHIUM_N], + coeff_t c[DILITHIUM_N]) +{ + #pragma HLS inline off + #pragma HLS DATAFLOW + + // In-place NTT + ntt(a); + ntt(b); + + // Pointwise multiply + for (unsigned i = 0; i < DILITHIUM_N; i++) { + #pragma HLS PIPELINE II=1 + wide_t prod = (wide_t)a[i] * (wide_t)b[i]; + c[i] = mod_q(prod); + } + + // In-place inverse NTT on c + inv_ntt(c); +} + +// --------------------------------------------------------------------- +// AXI4-Stream <-> internal coefficient buffer +// Protocol: +// +// Input stream (axis_in): +// - First 256 beats : a[0..255], last = 0 +// - Next 256 beats : b[0..255], last = 1 on the final beat +// +// Output stream (axis_out): +// - 256 beats with c[0..255], last = 1 on the final beat +// --------------------------------------------------------------------- +static void axis_read_polys(hls::stream &axis_in, + coeff_t a[DILITHIUM_N], + coeff_t b[DILITHIUM_N]) +{ + #pragma HLS inline off + coeff_axis_t w; + + // Read a + for (unsigned i = 0; i < DILITHIUM_N; i++) { + #pragma HLS PIPELINE II=1 + w = axis_in.read(); + a[i] = (coeff_t)w.data; + } + + // Read b + for (unsigned i = 0; i < DILITHIUM_N; i++) { + #pragma HLS PIPELINE II=1 + w = axis_in.read(); + b[i] = (coeff_t)w.data; + } +} + +static void axis_write_poly(hls::stream &axis_out, + coeff_t c[DILITHIUM_N]) +{ + #pragma HLS inline off + + coeff_axis_t w; + w.keep = -1; + w.strb = -1; + + for (unsigned i = 0; i < DILITHIUM_N; i++) { + #pragma HLS PIPELINE II=1 + w.data = (ap_int<32>)c[i]; + w.last = (i == DILITHIUM_N - 1) ? 
1 : 0; + axis_out.write(w); + } +} + +// --------------------------------------------------------------------- +// Top-level HLS function: poly_mult_dilithium +// +// This is your second accelerator, to be instantiated alongside the +// Kyber NTT IP in the same overlay. It has its own AXI4-Stream ports, +// but you can reuse the DMA driver pattern from Kyber by adjusting +// the packing (no 2x16-bit packing here, just one 32-bit coeff per beat). +// --------------------------------------------------------------------- +int poly_mult(hls::stream &axis_in, + hls::stream &axis_out) +{ + #pragma HLS INTERFACE axis register port=axis_in + #pragma HLS INTERFACE axis register port=axis_out + #pragma HLS INTERFACE s_axilite port=return bundle=CTRL_BUS + #pragma HLS DATAFLOW + + coeff_t a[DILITHIUM_N]; + coeff_t b[DILITHIUM_N]; + coeff_t c[DILITHIUM_N]; + + #pragma HLS ARRAY_PARTITION variable=a cyclic factor=4 dim=1 + #pragma HLS ARRAY_PARTITION variable=b cyclic factor=4 dim=1 + #pragma HLS ARRAY_PARTITION variable=c cyclic factor=4 dim=1 + + axis_read_polys(axis_in, a, b); + poly_mult_internal(a, b, c); + axis_write_poly(axis_out, c); + + return 0; +} diff --git a/HLS_Codes_Dilithium/test_case.h b/HLS_Codes_Dilithium/test_case.h new file mode 100644 index 0000000..cbf838c --- /dev/null +++ b/HLS_Codes_Dilithium/test_case.h @@ -0,0 +1,21 @@ +#ifndef TEST_CASE_DILITHIUM_H +#define TEST_CASE_DILITHIUM_H + +#include "ntt.h" + +// Length must be DILITHIUM_N (256) +static const coeff_t input1_vals[DILITHIUM_N] = { + // TODO: fill with 256 coefficients of polynomial a (mod 8380417) + // e.g. 
generated by a software reference polymult +}; + +static const coeff_t input2_vals[DILITHIUM_N] = { + // TODO: 256 coefficients of polynomial b +}; + +static const coeff_t output_vals[DILITHIUM_N] = { + // TODO: golden result c = a * b in R_q[x]/(x^256 - 1), + // reduced mod q and in the same canonical form as the HW (0..q-1) +}; + +#endif // TEST_CASE_DILITHIUM_H diff --git a/PYNQ-zcu104_Files/Kyber512-LastGood-6dec2025.ipynb b/PYNQ-zcu104_Files/Kyber512-LastGood-6dec2025.ipynb new file mode 100644 index 0000000..329bc02 --- /dev/null +++ b/PYNQ-zcu104_Files/Kyber512-LastGood-6dec2025.ipynb @@ -0,0 +1,2135 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 62, + "id": "163cf7a1", + "metadata": {}, + "outputs": [], + "source": [ + "#oho A\n", + "\n", + "import numpy as np\n", + "import math\n", + "import time\n", + "from pynq import Overlay\n", + "from pynq import allocate\n", + "from pynq import MMIO\n", + "\n", + "\n", + "# Kyber-v2 parameters\n", + "q = 3329\n", + "n2 = 256\n", + "n = 128\n", + "inv_n = 3303\n", + "psin = 17\n", + "inv_psin = 1175\n", + "k = 2\n", + "eta1 = 2\n", + "eta2 = 3\n", + "du = 10\n", + "dv = 4" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "3abcc796", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total 19224\n", + "drwxr-xr-x 3 root root 4096 Dec 5 17:26 .\n", + "drwxrwxrwx 8 xilinx xilinx 4096 Dec 6 20:33 ..\n", + "-rw-r--r-- 1 root root 19311209 Dec 5 17:26 base.bit\n", + "-rw-r--r-- 1 root root 359733 Dec 5 17:26 base.hwh\n", + "drwxr-xr-x 2 root root 4096 Nov 2 16:11 .ipynb_checkpoints\n", + "Sat Dec 6 08:33:18 PM UTC 2025\n" + ] + } + ], + "source": [ + "#oho B\n", + "!ls -la /home/xilinx/jupyter_notebooks/kyber-ntt\n", + "!date\n", + "# Loading the bit file to configure the PL\n", + "ol = Overlay('/home/xilinx/jupyter_notebooks/kyber-ntt/base.bit')" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "18296772", + 
"metadata": {}, + "outputs": [], + "source": [ + "#oho C\n", + "# Function to perform bit-reversal\n", + "def bitReverse(num, logn):\n", + " rev_num = 0\n", + " for i in range(logn):\n", + " if (num >> i) & 1:\n", + " rev_num |= 1 << (logn - 1 - i)\n", + " return rev_num\n", + "\n", + "# Function to generate twiddle factors (for both forward and inverse NTT)\n", + "def gen_tf(psin, inv_psin, n, q):\n", + " positions = [bitReverse(x, int(np.log2(n))) for x in range(n)]\n", + " tmp1, tmp2 = [], []\n", + " psis, inv_psis = [], []\n", + " psi = 1\n", + " inv_psi = 1\n", + " for x in range(n):\n", + " tmp1.append(psi)\n", + " tmp2.append(inv_psi)\n", + " psi = psi * psin % q\n", + " inv_psi = inv_psi * inv_psin % q\n", + " for x in range(n):\n", + " val = tmp1[positions[x]]\n", + " inv_val = tmp2[positions[x]]\n", + " psis.append(val)\n", + " inv_psis.append(inv_val)\n", + " return psis, inv_psis\n", + "\n", + "# Function to generate scaling factors for point wise multiplication\n", + "def gen_pwmf(psin, n, q):\n", + " pwmf = []\n", + " for i in range(n):\n", + " val = (psin**(2*bitReverse(i, int(np.log2(n))) + 1))%q\n", + " pwmf.append(val)\n", + " return pwmf\n", + "\n", + "# Functions to generate Centered Binomial Distribution\n", + "def _cbd(n, eta):\n", + " i = 0\n", + " while i < eta:\n", + " p1 = np.random.randint(0, 2, n)\n", + " if i == 0:\n", + " p = p1\n", + " else:\n", + " p = p + p1\n", + " i = i + 1\n", + " return p\n", + "\n", + "def cbd(n, eta):\n", + " a = _cbd(n, eta)\n", + " b = _cbd(n, eta)\n", + " return a - b\n", + " \n", + "def cbd_vector(n, eta, k):\n", + " result = []\n", + "\n", + " for i in range(k):\n", + " result.append(cbd(n, eta))\n", + "\n", + " return np.squeeze(np.array(result, dtype=np.int16))\n", + "\n", + "# Compression function\n", + "def compress(x, q, d):\n", + " q1 = 2**d\n", + " x = np.round(q1 / q * x).astype(np.int16)\n", + " x = np.remainder(x, q1)\n", + " return x\n", + "\n", + "# De-compression function\n", + "def 
decompress(x, q, d):\n", + " q1 = 2**d\n", + " x = np.round(q / q1 * x).astype(np.int16)\n", + " x = np.remainder(x, q)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "299a54df", + "metadata": {}, + "outputs": [], + "source": [ + "#oho D\n", + "# NTT / INTT - all in SW\n", + "\n", + "# 128 point Forward NTT using Cooley-Tukey (TC) algorithm\n", + "def ct_ntt(a, psis, q, n):\n", + " t = n\n", + " m = 1\n", + " while m < n:\n", + " t = t // 2\n", + " for i in range(m):\n", + " j1 = 2 * i * t\n", + " j2 = j1 + t - 1\n", + " S = psis[m + i]\n", + " for j in range(j1, j2 + 1):\n", + " U = a[j]\n", + " V = a[j + t] * S\n", + " a[j] = (U + V) % q\n", + " a[j + t] = (U - V) % q\n", + " m = 2 * m\n", + " return a\n", + " \n", + "# 128 point Inverse NTT using Gentleman-Sande (GS) algorithm\n", + "def gs_intt(a, inv_psis, q, n, inv_n):\n", + " t = 1\n", + " m = n\n", + " while m > 1:\n", + " j1 = 0\n", + " h = m // 2\n", + " for i in range(h):\n", + " j2 = j1 + t - 1\n", + " S = inv_psis[h + i]\n", + " for j in range(j1, j2 + 1):\n", + " U = a[j]\n", + " V = a[j + t]\n", + " a[j] = (U + V) % q\n", + " a[j + t] = (U - V) * S % q\n", + " j1 = j1 + 2 * t\n", + " t = 2 * t\n", + " m = m // 2\n", + " for i in range(n):\n", + " a[i] = a[i] * inv_n % q\n", + " return a\n", + "\n", + "# 256 point NTT using two 128 point NTTs\n", + "def ntt_256(x, psis, q, n):\n", + " xe, xo = [], []\n", + " for i in range(n2):\n", + " if i%2 == 0:\n", + " xe.append(x[i])\n", + " else:\n", + " xo.append(x[i])\n", + " ye = ct_ntt(xe, psis, q, n)\n", + " yo = ct_ntt(xo, psis, q, n)\n", + " return ye, yo\n", + "\n", + "# 256 point INTT using two 128 point INTTs\n", + "def intt_256(ye, yo, inv_psis, q, n, inv_n):\n", + " ze = gs_intt(ye, inv_psis, q, n, inv_n)\n", + " zo = gs_intt(yo, inv_psis, q, n, inv_n)\n", + " z = []\n", + " for i in range(n):\n", + " z.append(ze[i])\n", + " z.append(zo[i])\n", + " return z\n", + "\n", + "# Point-wise multiplication in NTT 
domain\n", + "def point_wise_mult(y1e, y1o, y2e, y2o, pwmf):\n", + " y3e, y3o = [], []\n", + " for i in range(n):\n", + " y3e.append(((y1e[i] * y2e[i]) % q + (((y1o[i] * y2o[i]) % q) * pwmf[i]) % q) % q)\n", + " y3o.append(((y1e[i] * y2o[i]) % q + (y1o[i] * y2e[i]) % q) % q)\n", + " return y3e, y3o" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "44f2dd33", + "metadata": {}, + "outputs": [], + "source": [ + "###################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "844faabe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['axi_dma_0', 'poly_mult_0', 'ps_e_0'])\n", + "poly_mult_0 @ 0x80010000\n" + ] + } + ], + "source": [ + "#oho E\n", + "#oho inquiry for ZCU104 Board\n", + "\n", + "# Discover the real base address from the .hwh\n", + "print(ol.ip_dict.keys()) # just to see exact IP names\n", + "\n", + "poly_name = 'poly_mult_0' # adjust if Vivado named it differently\n", + "poly_info = ol.ip_dict[poly_name]\n", + "print(\"poly_mult_0 @\", hex(poly_info['phys_addr']))\n", + "\n", + "mmio = MMIO(poly_info['phys_addr'], poly_info['addr_range'])\n", + "\n", + "dma = ol.axi_dma_0\n", + "dma_send = dma.sendchannel\n", + "dma_recv = dma.recvchannel" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "af023c05", + "metadata": {}, + "outputs": [], + "source": [ + "####################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "5a6323d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "n2 = 256\n", + "Number of mismatches: 0\n", + "Hardware poly_mult matches test_case.h vectors.\n" + ] + } + ], + "source": [ + "#oho F\n", + "# Golden Test for Golden Record analog Vitis HLS testbench\n", + "# pm_test.cpp Equivalent (== HW Simulator with golden testvestors in Vitis HLS)\n", + "\n", + "import numpy as 
np\n", + "from pynq import allocate\n", + "\n", + "# -------------------------------------------------------------------\n", + "# 1. Test vectors from test_case.h\n", + "# Paste the FULL arrays from the header in place of the \"...\" parts\n", + "# -------------------------------------------------------------------\n", + "input1_vals = np.array(\n", + " [\n", + " 1477, 218, 784, 251, 747, 1051, 1924, 133, 2953, 1295, 2989, 1519, 1701, 1874, 2806, 423, 2883, 327, 47, 2525, 1508, 214, 2998, 217, 1852, 2624, 2286, 3039, 3076, 1213, 1808, 2554, 1129, 1353, 2690, 2839, 1778, 2752, 1378, 601, 914, 2335, 2497, 1139, 2611, 129, 1318, 1570, 3190, 1868, 940, 2901, 2626, 2473, 3195, 2621, 2436, 3046, 1018, 1139, 1729, 3021, 2064, 945, 690, 1700, 1836, 1943, 2333, 2131, 1618, 1741, 2639, 2653, 301, 2013, 2744, 2406, 2995, 2463, 2366, 1495, 442, 224, 1349, 11, 2342, 1712, 2847, 1578, 2654, 2734, 3131, 1245, 1862, 527, 2400, 2043, 1360, 451, 573, 898, 2018, 3100, 161, 284, 1949, 362, 755, 2916, 1288, 1616, 876, 1682, 853, 2772, 2956, 1101, 2, 214, 2589, 211, 1025, 610, 1225, 2118, 224, 1296, 2612, 2634, 2056, 3227, 1712, 1258, 552, 1345, 786, 2124, 2915, 1226, 1233, 2654, 2786, 2636, 2234, 727, 2444, 199, 600, 2262, 3221, 915, 63, 318, 74, 2396, 1690, 2390, 1711, 414, 10, 2298, 1082, 1419, 3151, 1723, 2744, 3274, 2518, 2954, 1208, 2941, 2089, 3288, 1370, 783, 2517, 3190, 3069, 2505, 2840, 1427, 1670, 3091, 655, 96, 1935, 880, 2511, 876, 2371, 341, 196, 2849, 919, 161, 603, 2993, 2903, 1721, 139, 3326, 1876, 379, 2508, 2094, 1929, 430, 1033, 2604, 1955, 1333, 2274, 3312, 2604, 1585, 2317, 3230, 3068, 2905, 3268, 2844, 1023, 2824, 1731, 643, 820, 462, 2975, 314, 2218, 2011, 649, 383, 874, 2181, 866, 1192, 2914, 2290, 1820, 1572, 1030, 3076, 1526, 2760, 12, 529, 1242, 560, 2723, 2894, 1097, 778, 1495, 371\n", + " ],\n", + " dtype=np.int16,\n", + ")\n", + "\n", + "input2_vals = np.array(\n", + " [\n", + " 2960, 3124, 509, 485, 2525, 385, 608, 2893, 2423, 1802, 2556, 1090, 775, 2059, 
898, 864, 2459, 1116, 551, 188, 3262, 2728, 3134, 2451, 427, 858, 1927, 830, 2688, 2388, 2818, 1418, 3298, 24, 2491, 1448, 1153, 178, 2489, 2126, 1772, 669, 1238, 633, 1919, 2222, 2673, 1918, 2202, 3312, 208, 976, 2267, 107, 2905, 1137, 2921, 2471, 2796, 1313, 485, 1982, 1557, 1203, 2930, 241, 3089, 890, 2193, 179, 952, 2057, 2444, 1378, 1466, 1362, 1808, 2343, 1532, 2651, 727, 3254, 1328, 1604, 967, 2418, 1266, 1826, 684, 2869, 3149, 1874, 1691, 1507, 339, 2473, 102, 3153, 969, 1551, 548, 3059, 2841, 1369, 148, 2510, 2025, 1369, 1579, 2474, 1093, 527, 1416, 981, 2320, 2305, 227, 2173, 812, 1703, 2952, 17, 1129, 2223, 1894, 959, 73, 339, 553, 1466, 1065, 617, 1749, 1896, 1838, 1771, 3092, 297, 996, 198, 521, 567, 3256, 2783, 1044, 2644, 744, 2986, 3178, 1522, 942, 2045, 236, 1866, 853, 2303, 2383, 3095, 418, 2752, 2105, 2896, 3081, 3067, 1696, 978, 102, 1961, 3120, 2741, 1029, 885, 2852, 2659, 2815, 3032, 2358, 3252, 1195, 3304, 878, 70, 3069, 2726, 2455, 182, 108, 2868, 1744, 1697, 1060, 1803, 1752, 829, 2434, 862, 2287, 2860, 352, 634, 2626, 1920, 2425, 239, 831, 2527, 1190, 1469, 2602, 1711, 2185, 1403, 3189, 1188, 2649, 2079, 2215, 790, 409, 2413, 627, 2268, 2507, 2102, 1727, 1146, 2711, 355, 1143, 1225, 430, 82, 3015, 2699, 642, 863, 241, 450, 440, 338, 365, 2621, 3022, 204, 149, 2986, 2191, 1793, 3085, 2128, 373, 290, 835, 580, 2530, 1948\n", + " ],\n", + " dtype=np.int16,\n", + ")\n", + "\n", + "output_vals = np.array(\n", + " [\n", + " 2762, 3061, 1101, 3267, 2744, 1349, 182, 1761, 3089, 751, 137, 368, 1461, 2956, 493, 1653, 2617, 721, 356, 3034, 2234, 1556, 809, 2290, 1597, 457, 811, 259, 685, 2478, 319, 2519, 1049, 837, 644, 2571, 1029, 2997, 762, 1710, 2110, 1099, 2513, 1038, 2176, 1938, 3214, 261, 1604, 2474, 5, 1211, 2816, 2848, 2286, 3146, 1777, 1630, 2412, 1457, 889, 671, 822, 2369, 1409, 2059, 1121, 1871, 303, 1178, 2241, 1827, 2046, 628, 2869, 749, 1666, 895, 580, 1770, 2082, 3123, 1192, 520, 168, 2461, 1032, 163, 1421, 2792, 2148, 1735, 220, 1896, 
2887, 2163, 357, 2301, 1830, 163, 1812, 805, 1850, 2017, 2313, 1205, 2226, 703, 866, 1708, 1426, 1920, 2911, 267, 3134, 629, 2120, 2022, 2847, 2945, 2967, 1977, 1449, 2028, 1381, 2738, 1098, 2977, 2217, 2060, 710, 845, 2807, 509, 2512, 2444, 2355, 550, 2965, 2517, 1802, 1755, 1065, 1938, 388, 2365, 776, 2453, 1799, 1532, 384, 2266, 1071, 2063, 2858, 1414, 663, 2886, 2734, 209, 1061, 2142, 841, 1081, 977, 799, 2661, 588, 3222, 2140, 2383, 3044, 394, 231, 1090, 917, 1840, 3002, 2315, 1182, 2744, 2815, 2612, 2586, 970, 3301, 3028, 2890, 1849, 269, 2936, 1525, 3102, 3144, 1605, 2746, 1556, 537, 2918, 2549, 976, 250, 2137, 492, 729, 392, 1115, 2422, 2100, 2317, 1636, 1743, 1279, 1393, 2079, 2874, 2148, 233, 1469, 3143, 2109, 1211, 2318, 1138, 2979, 1383, 125, 1995, 1614, 1435, 2216, 782, 671, 662, 988, 2826, 2162, 605, 2955, 2478, 2375, 1449, 2307, 1921, 1285, 2208, 2422, 1035, 2765, 923, 2138, 3053, 812, 146, 1175, 61\n", + " ],\n", + " dtype=np.int16,\n", + ")\n", + "\n", + "# HLS core expects Nt = 256 words, and each word packs two 16-bit coeffs\n", + "n2 = len(input1_vals)\n", + "assert n2 == len(input2_vals) == len(output_vals), \"Vector length mismatch\"\n", + "print(\"n2 =\", n2)\n", + "assert n2 == 256, \"HLS core expects Nt = 256\"\n", + "\n", + "\n", + "# -------------------------------------------------------------------\n", + "# 2. 
Hardware poly_mult wrapper matching pm_test.cpp packing\n", + "# -------------------------------------------------------------------\n", + "def poly_mul_hw_golden(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Hardware polynomial multiplication:\n", + " - x1, x2: length-n2 arrays of 16-bit coefficients (numpy.int16)\n", + " - Returns: length-n2 array of 16-bit coeffs (numpy.int16)\n", + "\n", + " Packing is exactly like pm_test.cpp:\n", + " data[15:0] = input1_vals[i]\n", + " data[31:16] = input2_vals[i]\n", + " \"\"\"\n", + "\n", + " assert x1.shape == (n2,)\n", + " assert x2.shape == (n2,)\n", + "\n", + " # Allocate DMA buffers\n", + " input_buffer = allocate(shape=(n2,), dtype=np.int32)\n", + " output_buffer = allocate(shape=(n2,), dtype=np.int16)\n", + "\n", + " # Pack two 16-bit coeffs into one 32-bit word: [x2 | x1]\n", + " x1_32 = x1.astype(np.int32)\n", + " x2_32 = x2.astype(np.int32)\n", + " packed = (x1_32 & 0xFFFF) | ((x2_32 & 0xFFFF) << 16)\n", + " input_buffer[:] = packed\n", + "\n", + " # Start IP core (ap_start = 1 at control register 0x00)\n", + " mmio.write(0x00, 0x1)\n", + "\n", + " # Launch DMA transfers\n", + " dma_send.transfer(input_buffer)\n", + " dma_recv.transfer(output_buffer)\n", + " dma_send.wait()\n", + " dma_recv.wait()\n", + "\n", + " # Copy result out\n", + " y = np.array(output_buffer, dtype=np.int16)\n", + "\n", + " # Free buffers (depending on your Pynq version, freebuffer() may exist)\n", + " try:\n", + " input_buffer.freebuffer()\n", + " output_buffer.freebuffer()\n", + " except AttributeError:\n", + " del input_buffer, output_buffer\n", + "\n", + " return y\n", + "\n", + "\n", + "# -------------------------------------------------------------------\n", + "# 3. 
Run hardware test against golden output_vals from test_case.h\n", + "# -------------------------------------------------------------------\n", + "hw_out = poly_mul_hw_golden(input1_vals, input2_vals)\n", + "\n", + "diff = np.where(hw_out != output_vals)[0]\n", + "print(\"Number of mismatches:\", diff.size)\n", + "\n", + "if diff.size > 0:\n", + " print(\"First mismatches (idx, hw, expected):\")\n", + " for idx in diff[:20]:\n", + " print(f\"{idx:3d}: hw={int(hw_out[idx])}, expected={int(output_vals[idx])}\")\n", + "else:\n", + " print(\"Hardware poly_mult matches test_case.h vectors.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "c4dc2c5a", + "metadata": {}, + "outputs": [], + "source": [ + "############################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "5c5f4f54", + "metadata": {}, + "outputs": [], + "source": [ + "############################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "051faf93", + "metadata": {}, + "outputs": [], + "source": [ + "############################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "e41ca874", + "metadata": {}, + "outputs": [], + "source": [ + "# oho G-1\n", + "# Naive polynomial multiplication under mod (x^n2 + 1) in Software\n", + "# [i.e. 
negative wrapped convolution, quadratic cost O(n^2)]\n", + "def poly_mul_sw_naive(x1, x2):\n", + " \"\"\"\n", + " Schoolbook polynomial multiplication in R_q = Z_q[X]/(X^n2 + 1).\n", + "\n", + " Convention:\n", + " if i + j < n2: res[i+j] += x1[i] * x2[j]\n", + " else: res[i+j-n2] -= x1[i] * x2[j]\n", + "\n", + " This matches the (X^n + 1) / \"negacyclic\" convention.\n", + " \"\"\"\n", + "\n", + " # ensure proper types and shapes\n", + " a = np.asarray(x1, dtype=np.int64)\n", + " b = np.asarray(x2, dtype=np.int64)\n", + "\n", + " assert a.shape == (n2,), f\"a.shape={a.shape}, expected ({n2},)\"\n", + " assert b.shape == (n2,), f\"b.shape={b.shape}, expected ({n2},)\"\n", + "\n", + " res = np.zeros(n2, dtype=np.int64)\n", + "\n", + " for i in range(n2):\n", + " ai = int(a[i])\n", + " for j in range(n2):\n", + " t = ai * int(b[j])\n", + " k = i + j\n", + " if k < n2:\n", + " res[k] += t\n", + " else:\n", + " res[k - n2] -= t\n", + "\n", + " res %= q\n", + " return res.astype(np.int16)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "5dbc2c8a", + "metadata": {}, + "outputs": [], + "source": [ + "# oho G-2\n", + "# encrypt-decrypt declarations in SW using *naive* polynomial multiplication\n", + "\n", + "# Kyber PKE functions entirely in SW (schoolbook poly mult)\n", + "\n", + "# Key generation function (to be performed by server)\n", + "def key_gen_naive():\n", + " a = np.random.randint(q, size=(k, k, n2))\n", + " s = cbd_vector(n2, eta1, k)\n", + " e = cbd_vector(n2, eta1, k)\n", + "\n", + " b0 = (poly_mul_sw_naive(a[0,0], s[0]) + e[0]) % q\n", + " b1 = (poly_mul_sw_naive(a[0,1], s[1]) + e[1]) % q\n", + " b2 = (poly_mul_sw_naive(a[1,0], s[0]) + e[0]) % q\n", + " b3 = (poly_mul_sw_naive(a[1,1], s[1]) + e[1]) % q\n", + "\n", + " b01 = (b0 + b1) % q\n", + " b23 = (b2 + b3) % q\n", + " b = np.array([b01, b23])\n", + " return s, a, b\n", + "\n", + "# Encryption function (to be performed by client)\n", + "def encrypt_naive(a, b, m):\n", + " r = 
cbd_vector(n2, eta1, k)\n", + " e1 = cbd_vector(n2, eta2, k)\n", + " e2 = cbd(n2, eta2)\n", + "\n", + " u0 = (poly_mul_sw_naive(a[0,0], r[0]) + e1[0]) % q\n", + " u1 = (poly_mul_sw_naive(a[1,0], r[1]) + e1[1]) % q\n", + " u2 = (poly_mul_sw_naive(a[0,1], r[0]) + e1[0]) % q\n", + " u3 = (poly_mul_sw_naive(a[1,1], r[1]) + e1[1]) % q\n", + "\n", + " u01 = (u0 + u1) % q\n", + " u23 = (u2 + u3) % q\n", + " u = np.array([u01, u23])\n", + "\n", + " v0 = np.array(poly_mul_sw_naive(b[0], r[0]))\n", + " v1 = np.array(poly_mul_sw_naive(b[1], r[1]))\n", + " v = (v0 + v1 + e2 + m) % q\n", + "\n", + " u = compress(u, q, du)\n", + " v = compress(v, q, dv)\n", + " return u, v\n", + "\n", + "# Decryption function (to be performed by server)\n", + "def decrypt_naive(s, u, v):\n", + " u_dec = decompress(u, q, du)\n", + " v_dec = decompress(v, q, dv)\n", + "\n", + " p0 = np.array(poly_mul_sw_naive(s[0], u_dec[0]))\n", + " p1 = np.array(poly_mul_sw_naive(s[1], u_dec[1]))\n", + " p = (p0 + p1) % q\n", + " d = (v_dec - p) % q\n", + " return d\n" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "1c9d6133", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Actual message :\n", + " [0 0 1 1 0 1 1 0 1 0 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 1\n", + " 0 1 1 0 0 0 0 0 1 0 1 0 1 1 0 1 1 0 1 0 0 0 0 0 0 1 0 1 1 1 0 1 0 1 1 1 1\n", + " 0 0 0 0 0 1 0 0 0 1 1 0 1 0 1 1 1 1 1 0 0 1 1 1 1 1 0 0 0 1 0 1 1 1 1 1 1\n", + " 0 0 1 1 0 1 1 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 1 1 0 0 1 0 1 1 0 0 0 0 1 1 0\n", + " 0 1 0 0 1 0 1 1 0 0 0 1 1 1 1 0 0 0 1 0 1 1 1 1 1 1 1 1 1 0 1 1 1 1 0 0 0\n", + " 1 1 1 1 0 1 0 1 0 0 1 1 0 1 1 1 0 1 0 1 1 1 1 1 0 0 0 1 0 0 1 0 0 0 0 1 0\n", + " 1 0 1 1 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 1 1 0 0 0 0 1 0 1 0 0 0 1]\n", + "Decrypted message :\n", + " [0 0 1 1 0 1 1 0 1 0 1 0 0 0 0 1 1 0 0 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 1\n", + " 0 1 1 0 0 0 0 0 1 0 1 0 1 1 0 1 1 0 1 0 0 0 0 0 0 1 0 1 1 1 0 1 0 1 1 1 
1\n", + " 0 0 0 0 0 1 0 0 0 1 1 0 1 0 1 1 1 1 1 0 0 1 1 1 1 1 0 0 0 1 0 1 1 1 1 1 1\n", + " 0 0 1 1 0 1 1 0 0 0 1 1 0 0 0 1 0 0 0 0 1 1 1 1 0 0 1 0 1 1 0 0 0 0 1 1 0\n", + " 0 1 0 0 1 0 1 1 0 0 0 1 1 1 1 0 0 0 1 0 1 1 1 1 1 1 1 1 1 0 1 1 1 1 0 0 0\n", + " 1 1 1 1 0 1 0 1 0 0 1 1 0 1 1 1 0 1 0 1 1 1 1 1 0 0 0 1 0 0 1 0 0 0 0 1 0\n", + " 1 0 1 1 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 1 1 0 0 0 0 1 0 1 0 0 0 1]\n", + "Actual message and decrypted message are the same!\n", + "\n", + "Time taken by SW(naive) only = 4.159699201583862 seconds\n" + ] + } + ], + "source": [ + "# oho G-3\n", + "# Full SW Run Encrypt-Decrypt using *naive* polynomial multiplication\n", + "\n", + "# (No NTT precomputation needed, but harmless if psis/pwmf already exist)\n", + "\n", + "start_sw_naive = time.time()\n", + "\n", + "# Randomly generated binary message, m\n", + "m = np.random.randint(2, size=(n2,))\n", + "ms = decompress(m, q, 1)\n", + "\n", + "# Generating private key (s) and public keys (a,b) with naive poly-mult\n", + "s, a, b = key_gen_naive()\n", + "\n", + "# Encrypting the message using public keys to provide cipher texts (u,v)\n", + "u, v = encrypt_naive(a, b, ms)\n", + "\n", + "# Decrypt the cipher using private key to obtain back the message (d)\n", + "d = decrypt_naive(s, u, v)\n", + "\n", + "# Decoding the decrypted message\n", + "md = []\n", + "for x in d:\n", + " if x > math.floor(q/4) and x < math.floor(3*q/4):\n", + " md.append(1)\n", + " else:\n", + " md.append(0)\n", + "md = np.array(md)\n", + "\n", + "end_sw_naive = time.time()\n", + "\n", + "# Comparison and printing results\n", + "print(\"Actual message :\\n\", m)\n", + "print(\"Decrypted message :\\n\", md)\n", + "\n", + "if (list(m) == list(md)):\n", + " print(\"Actual message and decrypted message are the same!\")\n", + "else:\n", + " print(\"There is mismatch ....\")\n", + "\n", + "print()\n", + "print(\"Time taken by SW(naive) only =\", end_sw_naive - start_sw_naive, \"seconds\")\n", + "\n" + ] + }, + { + 
"cell_type": "code", + "execution_count": 76, + "id": "289dd986", + "metadata": {}, + "outputs": [], + "source": [ + "######################################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "b632edb3", + "metadata": {}, + "outputs": [], + "source": [ + "#oho H-1\n", + "# Polynomial multiplication under mod (x^n + 1) in Software\n", + "# NTT/INTT version (not schoolbook)\n", + "# [i.e negative wrapped convolution]\n", + "def poly_mul_sw(x1, x2):\n", + "\n", + " y1e, y1o = ntt_256(x1, psis, q, n)\n", + " y2e, y2o = ntt_256(x2, psis, q, n)\n", + "\n", + " y3e, y3o = point_wise_mult(y1e, y1o, y2e, y2o, pwmf)\n", + "\n", + " z = intt_256(y3e, y3o, inv_psis, q, n, inv_n)\n", + "\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "25429654", + "metadata": {}, + "outputs": [], + "source": [ + "#oho H-2\n", + "# encrypt-decrypt declarations in SW (NTT Version)\n", + "\n", + "# Kyber PKE functions entirely in SW\n", + "# Key generation function (to be performed by server)\n", + "def key_gen():\n", + " a = np.random.randint(q, size=(k,k,n2))\n", + " s = cbd_vector(n2, eta1, k)\n", + " e = cbd_vector(n2, eta1, k)\n", + " b0 = (poly_mul_sw(a[0,0], s[0]) + e[0]) % q\n", + " b1 = (poly_mul_sw(a[0,1], s[1]) + e[1]) % q\n", + " b2 = (poly_mul_sw(a[1,0], s[0]) + e[0]) % q\n", + " b3 = (poly_mul_sw(a[1,1], s[1]) + e[1]) % q\n", + " b01 = (b0 + b1) % q\n", + " b23 = (b2 + b3) % q\n", + " b = np.array([b01, b23])\n", + " return s, a, b\n", + "\n", + "# Encryption function (to be performed by client)\n", + "def encrypt(a, b, m):\n", + " r = cbd_vector(n2, eta1, k)\n", + " e1 = cbd_vector(n2, eta2, k)\n", + " e2 = cbd(n2, eta2)\n", + " u0 = (poly_mul_sw(a[0,0], r[0]) + e1[0]) % q\n", + " u1 = (poly_mul_sw(a[1,0], r[1]) + e1[1]) % q\n", + " u2 = (poly_mul_sw(a[0,1], r[0]) + e1[0]) % q\n", + " u3 = (poly_mul_sw(a[1,1], r[1]) + e1[1]) % q\n", + " u01 = (u0 + u1) % q\n", + " u23 = (u2 + 
u3) % q\n", + " u = np.array([u01, u23])\n", + " v0 = np.array(poly_mul_sw(b[0], r[0]))\n", + " v1 = np.array(poly_mul_sw(b[1], r[1]))\n", + " v = (v0 + v1 + e2 + m) % q\n", + " u = compress(u, q, du)\n", + " v = compress(v, q, dv)\n", + " return u, v\n", + "\n", + "# Decryption function (to be performed by server)\n", + "def decrypt(s, u, v):\n", + " u = decompress(u, q, du)\n", + " v = decompress(v, q, dv)\n", + " p0 = np.array(poly_mul_sw(s[0], u[0]))\n", + " p1 = np.array(poly_mul_sw(s[1], u[1]))\n", + " p = (p0 + p1) % q\n", + " d = (v - p) % q\n", + " return d" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "70c0b30e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Actual message :\n", + " [0 1 1 0 1 1 0 1 0 0 0 0 0 1 1 0 1 1 1 1 0 0 1 1 1 0 0 0 1 1 1 1 0 0 0 1 1\n", + " 1 0 1 0 0 1 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 0 1 0 1 1 1\n", + " 0 0 0 0 1 0 0 1 0 0 1 0 0 1 1 1 1 0 0 1 0 0 0 1 1 1 0 0 0 1 1 1 1 1 1 0 1\n", + " 1 1 0 0 1 0 0 1 0 0 0 0 1 0 0 0 1 0 0 1 1 1 0 1 0 1 0 0 1 0 1 1 1 0 1 1 0\n", + " 1 0 1 1 1 0 0 1 0 0 1 1 0 1 0 1 1 1 1 1 0 1 1 0 1 1 0 0 0 0 0 1 0 0 0 0 1\n", + " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 1 1 1 1 1 0 1 1 1 0 0\n", + " 0 1 0 0 0 0 0 1 0 1 0 0 0 1 1 0 1 0 0 0 0 1 0 1 1 0 1 0 0 0 0 1 0 0]\n", + "Decrypted message :\n", + " [0 1 1 0 1 1 0 1 0 0 0 0 0 1 1 0 1 1 1 1 0 0 1 1 1 0 0 0 1 1 1 1 0 0 0 1 1\n", + " 1 0 1 0 0 1 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 0 1 0 1 1 1\n", + " 0 0 0 0 1 0 0 1 0 0 1 0 0 1 1 1 1 0 0 1 0 0 0 1 1 1 0 0 0 1 1 1 1 1 1 0 1\n", + " 1 1 0 0 1 0 0 1 0 0 0 0 1 0 0 0 1 0 0 1 1 1 0 1 0 1 0 0 1 0 1 1 1 0 1 1 0\n", + " 1 0 1 1 1 0 0 1 0 0 1 1 0 1 0 1 1 1 1 1 0 1 1 0 1 1 0 0 0 0 0 1 0 0 0 0 1\n", + " 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 1 0 0 1 0 0 1 0 0 1 0 1 1 1 1 1 0 1 1 1 0 0\n", + " 0 1 0 0 0 0 0 1 0 1 0 0 0 1 1 0 1 0 0 0 0 1 0 1 1 0 1 0 0 0 0 1 0 0]\n", + "Actual message and decrypted message are the same!\n", 
+ "\n", + "Time taken by SW(NTT) only = 0.37024855613708496 seconds\n" + ] + } + ], + "source": [ + "#oho H-3\n", + "# Full SW Run Encrypt-Decrypt (NTT Version)\n", + "\n", + "# Get pre-computed factors\n", + "psis, inv_psis = gen_tf(psin, inv_psin, n, q)\n", + "pwmf = gen_pwmf(psin, n, q)\n", + "\n", + "start_sw = time.time()\n", + "\n", + "# Randomly generated binary message, m\n", + "m = np.random.randint(2, size=(n2,))\n", + "ms = decompress(m, q, 1)\n", + "\n", + "# Generating private key (s) and public keys (a,b)\n", + "s, a, b = key_gen()\n", + "\n", + "# Encrypting the message using public keys to provide cipher texts (u,v)\n", + "u, v = encrypt(a, b, ms)\n", + "\n", + "# Decrypt the cipher using private key to obtain back the message (d)\n", + "d = decrypt(s, u, v)\n", + "\n", + "# Decoding the decrypted message\n", + "md = []\n", + "for i in d:\n", + " if i > math.floor(q/4) and i < math.floor(3*q/4):\n", + " md.append(1)\n", + " else:\n", + " md.append(0)\n", + "md = np.array(md)\n", + "\n", + "end_sw = time.time()\n", + "\n", + "# Comparison and printing results\n", + "print(\"Actual message :\\n\", m)\n", + "print(\"Decrypted message :\\n\", md)\n", + "\n", + "if (list(m) == list(md)):\n", + " print(\"Actual message and decrypted message are the same!\")\n", + "else:\n", + " print(\"There is mismatch ....\")\n", + "\n", + "print()\n", + "print(\"Time taken by SW(NTT) only =\", end_sw - start_sw, \"seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "75735ba1", + "metadata": {}, + "outputs": [], + "source": [ + "##############################################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "fcce2793", + "metadata": {}, + "outputs": [], + "source": [ + "# oho I-1\n", + "\n", + "# HW declaration poly-mult !\n", + "\n", + "import numpy as np\n", + "from pynq import allocate\n", + "\n", + "# n2 should be 256 for this HLS core\n", + "n2 = 256\n", + "\n", +
"def poly_mul_hw(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Hardware polynomial multiplication:\n", + " - x1, x2: length-n2 arrays of 16-bit coefficients (numpy.int16 / any int)\n", + " - Returns: length-n2 array of 16-bit coeffs (numpy.int16)\n", + "\n", + " Packing matches the Vitis C testbench (pm_test.cpp):\n", + " TDATA[15:0] = x1[i]\n", + " TDATA[31:16] = x2[i]\n", + " \"\"\"\n", + "\n", + " x1 = np.asarray(x1)\n", + " x2 = np.asarray(x2)\n", + "\n", + " assert x1.shape == (n2,), f\"x1 shape {x1.shape} != ({n2},)\"\n", + " assert x2.shape == (n2,), f\"x2 shape {x2.shape} != ({n2},)\"\n", + "\n", + " # Allocate DMA buffers\n", + " input_buffer = allocate(shape=(n2,), dtype=np.int32)\n", + " output_buffer = allocate(shape=(n2,), dtype=np.int16)\n", + "\n", + " # Pack two 16-bit coeffs into one 32-bit word: [x2 | x1]\n", + " x1_32 = x1.astype(np.int32)\n", + " x2_32 = x2.astype(np.int32)\n", + " packed = (x1_32 & 0xFFFF) | ((x2_32 & 0xFFFF) << 16)\n", + " input_buffer[:] = packed\n", + "\n", + " # Start IP core (ap_start = 1 at control register 0x00)\n", + " mmio.write(0x00, 0x1)\n", + "\n", + " # Launch DMA transfers\n", + " dma_send.transfer(input_buffer)\n", + " dma_recv.transfer(output_buffer)\n", + " dma_send.wait()\n", + " dma_recv.wait()\n", + "\n", + " # Copy result out\n", + " y = np.array(output_buffer, dtype=np.int16)\n", + "\n", + " # Free buffers\n", + " try:\n", + " input_buffer.freebuffer()\n", + " output_buffer.freebuffer()\n", + " except AttributeError:\n", + " del input_buffer, output_buffer\n", + "\n", + " return y\n" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "309918cd", + "metadata": {}, + "outputs": [], + "source": [ + "# oho I-2\n", + "# encrypt-Decrypt Declarations in HW\n", + "\n", + "# ------------------------------------------------------------------\n", + "# Fixed noise profile that is known to work with your HW poly_mul:\n", + "# - compression: ON (original compress/decompress with 
du,dv)\n", + "# - noise scaled down deterministically by 1/4 in keygen & encrypt\n", + "# ------------------------------------------------------------------\n", + "\n", + "NOISE_DIV_KEY = 4 # *** added: scale s,e in keygen by 1/4 ***\n", + "NOISE_DIV_ENC = 4 # *** added: scale r,e1,e2 in encrypt by 1/4 ***\n", + "\n", + "def scale_noise(x, div):\n", + " # *** added: helper to reduce noise magnitude while preserving sign ***\n", + " x = np.asarray(x, dtype=np.int32)\n", + " if div == 1:\n", + " return x.astype(np.int16)\n", + " return (np.sign(x) * (np.abs(x) // div)).astype(np.int16)\n", + "\n", + "# Kyber PKE function-declarations with PolyMult in Hardware (and rest in SW)\n", + "\n", + "# Key generation function (to be performed by server)\n", + "def key_gen2():\n", + " # *** changed: explicit int16 dtype for a ***\n", + " a = np.random.randint(q, size=(k, k, n2), dtype=np.int16)\n", + "\n", + " # original CBD noise\n", + " s_raw = cbd_vector(n2, eta1, k)\n", + " e_raw = cbd_vector(n2, eta1, k)\n", + "\n", + " # *** added: scale noise for HW robustness ***\n", + " s = scale_noise(s_raw, NOISE_DIV_KEY)\n", + " e = scale_noise(e_raw, NOISE_DIV_KEY)\n", + "\n", + " b0 = (poly_mul_hw(a[0,0], s[0]) + e[0]) % q\n", + " b1 = (poly_mul_hw(a[0,1], s[1]) + e[1]) % q\n", + " b2 = (poly_mul_hw(a[1,0], s[0]) + e[0]) % q\n", + " b3 = (poly_mul_hw(a[1,1], s[1]) + e[1]) % q\n", + " b01 = (b0 + b1) % q\n", + " b23 = (b2 + b3) % q\n", + " b = np.array([b01, b23], dtype=np.int16)\n", + " return s, a, b\n", + "\n", + "# Encryption function (to be performed by client)\n", + "def encrypt2(a, b, m):\n", + " # original CBD noise\n", + " r_raw = cbd_vector(n2, eta1, k)\n", + " e1_raw = cbd_vector(n2, eta2, k)\n", + " e2_raw = cbd(n2, eta2)\n", + "\n", + " # *** added: scale noise for HW robustness ***\n", + " r = scale_noise(r_raw, NOISE_DIV_ENC)\n", + " e1 = scale_noise(e1_raw, NOISE_DIV_ENC)\n", + " e2 = scale_noise(e2_raw, NOISE_DIV_ENC)\n", + "\n", + " u0 = (poly_mul_hw(a[0,0], r[0]) 
+ e1[0]) % q\n", + " u1 = (poly_mul_hw(a[1,0], r[1]) + e1[1]) % q\n", + " u2 = (poly_mul_hw(a[0,1], r[0]) + e1[0]) % q\n", + " u3 = (poly_mul_hw(a[1,1], r[1]) + e1[1]) % q\n", + " u01 = (u0 + u1) % q\n", + " u23 = (u2 + u3) % q\n", + " u = np.array([u01, u23], dtype=np.int16)\n", + "\n", + " v0 = np.array(poly_mul_hw(b[0], r[0]), dtype=np.int16)\n", + " v1 = np.array(poly_mul_hw(b[1], r[1]), dtype=np.int16)\n", + " v = (v0 + v1 + e2 + m) % q\n", + "\n", + " # keep your original compression (du, dv from cell A)\n", + " u_c = compress(u, q, du)\n", + " v_c = compress(v, q, dv)\n", + " return u_c, v_c\n", + "\n", + "# Decryption function (to be performed by server)\n", + "def decrypt2(s, u, v):\n", + " # original compression/decompression path\n", + " u_dec = decompress(u, q, du)\n", + " v_dec = decompress(v, q, dv)\n", + "\n", + " p0 = np.array(poly_mul_hw(s[0], u_dec[0]), dtype=np.int16)\n", + " p1 = np.array(poly_mul_hw(s[1], u_dec[1]), dtype=np.int16)\n", + " p = (p0 + p1) % q\n", + " d = (v_dec - p) % q\n", + " return d\n" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "87afda4a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Actual message :\n", + " [1 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 1 1 0 0 1 1 1 0 0 0 0 1 0 1 1 0 1 0 1 0\n", + " 1 1 1 0 1 0 0 1 1 1 0 1 1 1 0 1 1 1 1 0 0 1 1 1 1 0 0 0 1 0 1 0 0 1 0 0 0\n", + " 0 0 0 1 0 0 0 1 1 1 0 0 1 0 0 1 0 1 0 1 0 1 1 1 0 0 1 0 1 0 1 0 1 1 1 0 0\n", + " 1 1 0 1 1 0 0 1 1 0 1 1 1 1 0 1 1 0 1 0 1 1 1 0 0 0 0 0 1 1 1 0 0 1 1 1 1\n", + " 1 1 1 0 0 0 0 0 0 1 1 1 0 1 0 1 0 0 1 1 0 0 1 0 0 1 0 0 1 0 1 0 1 0 1 1 1\n", + " 1 1 1 1 1 1 1 1 0 0 1 0 0 0 1 1 1 0 0 1 1 0 1 0 1 1 1 0 0 0 0 0 0 0 1 0 0\n", + " 1 1 1 1 0 1 1 0 0 1 0 0 0 0 1 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 1 1 0]\n", + "Decrypted message :\n", + " [1 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 1 1 0 0 1 1 1 0 0 0 0 1 0 1 1 0 1 0 1 0\n", + " 1 1 1 0 1 0 0 1 1 1 0 1 1 1 0 1 1 1 1 0 0 1 1 1 1 0 0 0 1 0 1 0 0 1 0 0 
0\n", + " 0 0 0 1 0 0 0 1 1 1 0 0 1 0 0 1 0 1 0 1 0 1 1 1 0 0 1 0 1 0 1 0 1 1 1 0 0\n", + " 1 1 0 1 1 0 0 1 1 0 1 1 1 1 0 1 1 0 1 0 1 1 1 0 0 0 0 0 1 1 1 0 0 1 1 1 1\n", + " 1 1 1 0 0 0 0 0 0 1 1 1 0 1 0 1 0 0 1 1 0 0 1 0 0 1 0 0 1 0 1 0 1 0 1 1 1\n", + " 1 1 1 1 1 1 1 1 0 0 1 0 0 0 1 1 1 0 0 1 1 0 1 0 1 1 1 0 0 0 0 0 0 0 1 0 0\n", + " 1 1 1 1 0 1 1 0 0 1 0 0 0 0 1 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 1 1 0]\n", + "Actual message and decrypted message are the same!\n", + "\n", + "Time taken by HW-SW = 0.0352785587310791 seconds\n" + ] + } + ], + "source": [ + "# oho I-3\n", + "# Full HW run with timing\n", + "\n", + "# Get pre-computed factors\n", + "psis, inv_psis = gen_tf(psin, inv_psin, n, q)\n", + "pwmf = gen_pwmf(psin, n, q)\n", + "\n", + "start_hw = time.time()\n", + "\n", + "# Randomly generated binary message, m\n", + "m = np.random.randint(2, size=(n2,))\n", + "ms = decompress(m, q, 1)\n", + "\n", + "# Generating private key (s) and public keys (a,b)\n", + "s, a, b = key_gen2()\n", + "\n", + "# Encrypting the message using public keys to provide cipher texts (u,v)\n", + "u, v = encrypt2(a, b, ms)\n", + "\n", + "# Decrypt the cipher using private key to obtain back the message (d)\n", + "d = decrypt2(s, u, v)\n", + "\n", + "# Decoding the decrypted message\n", + "md = []\n", + "for i in d:\n", + " if i > math.floor(q/4) and i < math.floor(3*q/4):\n", + " md.append(1)\n", + " else:\n", + " md.append(0)\n", + "md = np.array(md)\n", + "\n", + "end_hw = time.time()\n", + "\n", + "# Comparison and printing results\n", + "print(\"Actual message :\\n\", m)\n", + "print(\"Decrypted message :\\n\", md)\n", + "\n", + "if (list(m) == list(md)):\n", + " print(\"Actual message and decrypted message are the same!\")\n", + "else:\n", + " print(\"There is mismatch ....\")\n", + " \n", + "print()\n", + "print(\"Time taken by HW-SW =\", end_hw - start_hw, \"seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "a766ea8f", + "metadata": {}, +
"outputs": [], + "source": [ + "###################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "51433201", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Speed-up factor SW(naive) vs SW(NTT) = 11.234882979648214\n", + "Speed-up factor HW vs SW(NTT) = 10.49500233156945\n" + ] + } + ], + "source": [ + "# oho J\n", + "# Zeitvergleich SW(naive vs. SW (NTT) && SW(NTT) vs. HW\n", + "SF1 = (end_sw_naive - start_sw_naive)/(end_sw - start_sw)\n", + "print(\"Speed-up factor SW(naive) vs SW(NTT) =\", SF1)\n", + "SF2 = (end_sw - start_sw)/(end_hw - start_hw)\n", + "print(\"Speed-up factor HW vs SW(NTT) =\", SF2)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "756f9cce", + "metadata": {}, + "outputs": [], + "source": [ + "###################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "8f6f065f", + "metadata": {}, + "outputs": [], + "source": [ + "###################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "b85146b6", + "metadata": {}, + "outputs": [], + "source": [ + "###################################################" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "d3df6736", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Single roundtrip with current settings:\n", + "USE_COMPRESSION=True, NOISE_DIV_KEY=3, NOISE_DIV_ENC=3\n", + "Decryption successful? 
True\n", + "\n", + "Monte-Carlo over 20 trials:\n", + "USE_COMPRESSION=True, NOISE_DIV_KEY=3, NOISE_DIV_ENC=3\n", + "Trials: 20, successes: 20, failures: 0, failure rate ≈ 0.0000\n" + ] + } + ], + "source": [ + "# oho K\n", + "# Experiment K - selfcontained tunable solution\n", + "\n", + "# (success only with both noise params set to 3 !)\n", + "\n", + "import numpy as np\n", + "import math\n", + "\n", + "# =========================================================\n", + "# GLOBAL TUNING KNOBS (EDIT THESE LINES)\n", + "# =========================================================\n", + "USE_COMPRESSION = True # True = use compress/decompress(u,v); False = bypass\n", + "NOISE_DIV_KEY = 3 # noise scaling for s,e in keygen (1 = original)\n", + "NOISE_DIV_ENC = 3 # noise scaling for r,e1,e2 in encrypt (1 = original)\n", + "# =========================================================\n", + "# Assumes: q, n2, du, dv, eta1, eta2, k, cbd_vector, cbd,\n", + "# compress, decompress, poly_mul_hw are already defined.\n", + "# =========================================================\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Utility: integer noise scaling\n", + "# ---------------------------------------------------------\n", + "def scale_noise(x, div: int):\n", + " \"\"\"\n", + " Scale integer noise by integer divisor 'div' with truncation toward 0.\n", + " div = 1 -> original noise\n", + " div = 2 -> roughly half magnitude, etc.\n", + " \"\"\"\n", + " if div == 1:\n", + " return x.astype(np.int16)\n", + " x = np.asarray(x, dtype=np.int32)\n", + " return (np.sign(x) * (np.abs(x) // div)).astype(np.int16)\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# HW-based Kyber-style routines using global knobs\n", + "# ---------------------------------------------------------\n", + "def key_gen2_hw_cellYY1():\n", + " \"\"\"\n", + " Key generation using hardware poly_mult and global NOISE_DIV_KEY.\n", 
+ " \"\"\"\n", + " # a: public matrix, shape (k, k, n2)\n", + " a = np.random.randint(q, size=(k, k, n2), dtype=np.int16)\n", + "\n", + " # raw noise\n", + " s_raw = cbd_vector(n2, eta1, k) # (k, n2)\n", + " e_raw = cbd_vector(n2, eta1, k) # (k, n2)\n", + "\n", + " # scaled noise\n", + " s = scale_noise(s_raw, NOISE_DIV_KEY)\n", + " e = scale_noise(e_raw, NOISE_DIV_KEY)\n", + "\n", + " b0 = (poly_mul_hw(a[0,0], s[0]) + e[0]) % q\n", + " b1 = (poly_mul_hw(a[0,1], s[1]) + e[1]) % q\n", + " b2 = (poly_mul_hw(a[1,0], s[0]) + e[0]) % q\n", + " b3 = (poly_mul_hw(a[1,1], s[1]) + e[1]) % q\n", + "\n", + " b01 = (b0 + b1) % q\n", + " b23 = (b2 + b3) % q\n", + " b = np.array([b01, b23], dtype=np.int16)\n", + " return s, a, b\n", + "\n", + "\n", + "def encrypt2_hw_cellYY1(a, b, m):\n", + " \"\"\"\n", + " Encryption using hardware poly_mult and global NOISE_DIV_ENC.\n", + " Compression controlled by global USE_COMPRESSION.\n", + " \"\"\"\n", + " # raw noise\n", + " r_raw = cbd_vector(n2, eta1, k) # (k, n2)\n", + " e1_raw = cbd_vector(n2, eta2, k) # (k, n2)\n", + " e2_raw = cbd(n2, eta2) # (n2,)\n", + "\n", + " # scaled noise\n", + " r = scale_noise(r_raw, NOISE_DIV_ENC)\n", + " e1 = scale_noise(e1_raw, NOISE_DIV_ENC)\n", + " e2 = scale_noise(e2_raw, NOISE_DIV_ENC)\n", + "\n", + " # u part\n", + " u0 = (poly_mul_hw(a[0,0], r[0]) + e1[0]) % q\n", + " u1 = (poly_mul_hw(a[1,0], r[1]) + e1[1]) % q\n", + " u2 = (poly_mul_hw(a[0,1], r[0]) + e1[0]) % q\n", + " u3 = (poly_mul_hw(a[1,1], r[1]) + e1[1]) % q\n", + "\n", + " u01 = (u0 + u1) % q\n", + " u23 = (u2 + u3) % q\n", + " u = np.array([u01, u23], dtype=np.int16)\n", + "\n", + " # v part\n", + " v0 = poly_mul_hw(b[0], r[0]) % q\n", + " v1 = poly_mul_hw(b[1], r[1]) % q\n", + " v = (v0 + v1 + e2 + m) % q\n", + "\n", + " if USE_COMPRESSION:\n", + " u_c = compress(u, q, du)\n", + " v_c = compress(v, q, dv)\n", + " else:\n", + " u_c = u.astype(np.int16)\n", + " v_c = v.astype(np.int16)\n", + "\n", + " return u_c, v_c\n", + "\n", + "\n", 
+ "def decrypt2_hw_cellYY1(s, u, v):\n", + " \"\"\"\n", + " Decryption using hardware poly_mult.\n", + " Compression controlled by global USE_COMPRESSION.\n", + " \"\"\"\n", + " if USE_COMPRESSION:\n", + " u_dec = decompress(u, q, du)\n", + " v_dec = decompress(v, q, dv)\n", + " else:\n", + " u_dec = u.astype(np.int16)\n", + " v_dec = v.astype(np.int16)\n", + "\n", + " p0 = poly_mul_hw(s[0], u_dec[0]) % q\n", + " p1 = poly_mul_hw(s[1], u_dec[1]) % q\n", + " p = (p0 + p1) % q\n", + " d = (v_dec - p) % q\n", + " return d\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Roundtrip and Monte-Carlo\n", + "# ---------------------------------------------------------\n", + "def roundtrip_once(verbose: bool = True):\n", + " \"\"\"\n", + " One Decrypt(Encrypt(m)) round using global knobs.\n", + " Returns True/False for bitwise equality.\n", + " \"\"\"\n", + " # random binary message\n", + " m = np.random.randint(2, size=(n2,))\n", + " ms = decompress(m, q, 1) # same embedding as before\n", + "\n", + " s, a, b = key_gen2_hw_cellYY1()\n", + " u, v = encrypt2_hw_cellYY1(a, b, ms)\n", + " d = decrypt2_hw_cellYY1(s, u, v)\n", + "\n", + " th1 = math.floor(q/4)\n", + " th3 = math.floor(3*q/4)\n", + " md = np.array([1 if (th1 < x < th3) else 0 for x in d], dtype=np.int8)\n", + "\n", + " ok = np.array_equal(m, md)\n", + "\n", + " if verbose:\n", + " print(f\"USE_COMPRESSION={USE_COMPRESSION}, \"\n", + " f\"NOISE_DIV_KEY={NOISE_DIV_KEY}, NOISE_DIV_ENC={NOISE_DIV_ENC}\")\n", + " print(\"Decryption successful? 
\", ok)\n", + " if not ok:\n", + " idx = np.where(m != md)[0][0]\n", + " print(\"First mismatch at index\", idx)\n", + " print(\"m[idx] =\", int(m[idx]), \" md[idx] =\", int(md[idx]))\n", + " return ok\n", + "\n", + "\n", + "def monte_carlo_roundtrip(trials: int = 50):\n", + " \"\"\"\n", + " Run multiple roundtrips with current global knobs;\n", + " print empirical failure rate.\n", + " \"\"\"\n", + " successes = 0\n", + " for t in range(trials):\n", + " if roundtrip_once(verbose=False):\n", + " successes += 1\n", + " failures = trials - successes\n", + " failure_rate = failures / trials\n", + " print(f\"USE_COMPRESSION={USE_COMPRESSION}, \"\n", + " f\"NOISE_DIV_KEY={NOISE_DIV_KEY}, NOISE_DIV_ENC={NOISE_DIV_ENC}\")\n", + " print(f\"Trials: {trials}, successes: {successes}, \"\n", + " f\"failures: {failures}, failure rate ≈ {failure_rate:.4f}\")\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Example usage (you can comment these out)\n", + "# ---------------------------------------------------------\n", + "print(\"Single roundtrip with current settings:\")\n", + "roundtrip_once()\n", + "\n", + "print(\"\\nMonte-Carlo over 20 trials:\")\n", + "monte_carlo_roundtrip(trials=20)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "7fb23af6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Single roundtrip with current global settings:\n", + "USE_COMPRESSION=True, NOISE_DIV_KEY=3, NOISE_DIV_ENC=3\n", + "Decryption successful? 
True\n", + "\n", + "Monte-Carlo over 20 trials:\n", + "USE_COMPRESSION=True, NOISE_DIV_KEY=3, NOISE_DIV_ENC=3\n", + "Trials: 20, successes: 20, failures: 0, failure rate ≈ 0.0000\n" + ] + } + ], + "source": [ + "# oho O-1\n", + "# Variant 1 of experiment O\n", + "\n", + "# tunable experiment closer to ML-KEM-512 than my own Kyber variant\n", + "\n", + "# (the recorded run succeeded with both noise params set to 3)\n", + "\n", + "import numpy as np\n", + "import math\n", + "\n", + "# =========================================================\n", + "# ASSUMPTIONS: these come from your existing notebook\n", + "# =========================================================\n", + "# q : modulus (should be 3329 for Kyber512)\n", + "# n2 : polynomial length (should be 256)\n", + "# k : module rank (should be 2)\n", + "# du,dv: compression parameters\n", + "# eta1,eta2: CBD parameters\n", + "# cbd_vector, cbd, compress, decompress: existing functions\n", + "# poly_mul_hw: your hardware polymult wrapper (DMA + IP)\n", + "#\n", + "# We assert some basics to avoid silent mismatches.\n", + "# =========================================================\n", + "\n", + "try:\n", + " assert q == 3329\n", + " assert n2 == 256\n", + " assert k == 2\n", + "except NameError:\n", + " print(\"WARNING: q, n2, k not defined yet in this notebook.\")\n", + "except AssertionError:\n", + " print(\"WARNING: q/n2/k do not match Kyber-512-style values (q=3329, n2=256, k=2).\")\n", + "\n", + "\n", + "# =========================================================\n", + "# GLOBAL TUNING KNOBS (THIS IS WHAT YOU EDIT)\n", + "# =========================================================\n", + "USE_COMPRESSION = True # True: use compress/decompress(u,v); False: bypass\n", + "NOISE_DIV_KEY = 3 # integer divisor for s,e in keygen (1 = spec-ish, 3 = your working setting)\n", + "NOISE_DIV_ENC = 3 # integer divisor for r,e1,e2 in encrypt (1 = spec-ish, 3 = your working setting)\n", + "#
=========================================================\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Utility: integer noise scaling\n", + "# ---------------------------------------------------------\n", + "def scale_noise(x, div: int):\n", + " \"\"\"\n", + " Scale integer noise by integer divisor 'div' with truncation toward 0.\n", + "\n", + " - div = 1 -> original noise distribution.\n", + " - div = 2 -> approximately halves the magnitude of samples.\n", + " - div = 3 -> ~one-third, etc.\n", + "\n", + " This keeps sign, and works on scalars or arrays.\n", + " \"\"\"\n", + " x = np.asarray(x, dtype=np.int32)\n", + " if div == 1:\n", + " return x.astype(np.int16)\n", + " return (np.sign(x) * (np.abs(x) // div)).astype(np.int16)\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# HW-based Kyber-style routines using the global knobs\n", + "# ---------------------------------------------------------\n", + "def key_gen2_hw_cellZZA():\n", + " \"\"\"\n", + " Key generation using hardware poly_mult and global NOISE_DIV_KEY.\n", + "\n", + " Structure is Kyber-512-like:\n", + " - a ~ U(Z_q)^{k×k×n2}\n", + " - s,e from CBD(eta1) scaled by NOISE_DIV_KEY\n", + " - b = As + e (with the same algebraic formula as your original code).\n", + " \"\"\"\n", + " # public matrix a: shape (k, k, n2)\n", + " a = np.random.randint(q, size=(k, k, n2), dtype=np.int16)\n", + "\n", + " # raw noise\n", + " s_raw = cbd_vector(n2, eta1, k) # shape (k, n2)\n", + " e_raw = cbd_vector(n2, eta1, k) # shape (k, n2)\n", + "\n", + " # scaled noise according to global divisor\n", + " s = scale_noise(s_raw, NOISE_DIV_KEY)\n", + " e = scale_noise(e_raw, NOISE_DIV_KEY)\n", + "\n", + " # same structure as original key_gen2 (but with HW multiply)\n", + " b0 = (poly_mul_hw(a[0,0], s[0]) + e[0]) % q\n", + " b1 = (poly_mul_hw(a[0,1], s[1]) + e[1]) % q\n", + " b2 = (poly_mul_hw(a[1,0], s[0]) + e[0]) % q\n", + " b3 = 
(poly_mul_hw(a[1,1], s[1]) + e[1]) % q\n", + "\n", + " b01 = (b0 + b1) % q\n", + " b23 = (b2 + b3) % q\n", + " b = np.array([b01, b23], dtype=np.int16)\n", + " return s, a, b\n", + "\n", + "\n", + "def encrypt2_hw_cellZZA(a, b, m):\n", + " \"\"\"\n", + " Encryption using hardware poly_mult and global NOISE_DIV_ENC.\n", + "\n", + " - a, b: as from key_gen2_hw\n", + " - m : message polynomial (length n2), already embedded in Z_q\n", + " - noise (r, e1, e2) scaled by NOISE_DIV_ENC\n", + " - compression controlled by USE_COMPRESSION\n", + " \"\"\"\n", + " # raw noise\n", + " r_raw = cbd_vector(n2, eta1, k) # (k, n2)\n", + " e1_raw = cbd_vector(n2, eta2, k) # (k, n2)\n", + " e2_raw = cbd(n2, eta2) # (n2,)\n", + "\n", + " # scaled noise\n", + " r = scale_noise(r_raw, NOISE_DIV_ENC)\n", + " e1 = scale_noise(e1_raw, NOISE_DIV_ENC)\n", + " e2 = scale_noise(e2_raw, NOISE_DIV_ENC)\n", + "\n", + " # u part (matrix * r + e1)\n", + " u0 = (poly_mul_hw(a[0,0], r[0]) + e1[0]) % q\n", + " u1 = (poly_mul_hw(a[1,0], r[1]) + e1[1]) % q\n", + " u2 = (poly_mul_hw(a[0,1], r[0]) + e1[0]) % q\n", + " u3 = (poly_mul_hw(a[1,1], r[1]) + e1[1]) % q\n", + "\n", + " u01 = (u0 + u1) % q\n", + " u23 = (u2 + u3) % q\n", + " u = np.array([u01, u23], dtype=np.int16)\n", + "\n", + " # v part (b*r + e2 + m)\n", + " v0 = poly_mul_hw(b[0], r[0]) % q\n", + " v1 = poly_mul_hw(b[1], r[1]) % q\n", + " v = (v0 + v1 + e2 + m) % q\n", + "\n", + " if USE_COMPRESSION:\n", + " u_c = compress(u, q, du)\n", + " v_c = compress(v, q, dv)\n", + " else:\n", + " u_c = u.astype(np.int16)\n", + " v_c = v.astype(np.int16)\n", + "\n", + " return u_c, v_c\n", + "\n", + "\n", + "def decrypt2_hw_cellZZA(s, u, v):\n", + " \"\"\"\n", + " Decryption using hardware poly_mult.\n", + "\n", + " - s : secret key vector (k × n2)\n", + " - u,v: ciphertext components (compressed or not)\n", + " - compression controlled by USE_COMPRESSION\n", + " \"\"\"\n", + " if USE_COMPRESSION:\n", + " u_dec = decompress(u, q, du)\n", + " v_dec = 
decompress(v, q, dv)\n", + " else:\n", + " u_dec = u.astype(np.int16)\n", + " v_dec = v.astype(np.int16)\n", + "\n", + " p0 = poly_mul_hw(s[0], u_dec[0]) % q\n", + " p1 = poly_mul_hw(s[1], u_dec[1]) % q\n", + " p = (p0 + p1) % q\n", + " d = (v_dec - p) % q\n", + " return d\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Roundtrip and Monte-Carlo with these parameters\n", + "# ---------------------------------------------------------\n", + "def roundtrip_once(verbose: bool = True):\n", + " \"\"\"\n", + " One Decrypt(Encrypt(m)) round using the current global knobs.\n", + "\n", + " - draws m uniformly in {0,1}^n2\n", + " - embeds it via decompress(m, q, 1) (as in your original code)\n", + " - runs key_gen2_hw / encrypt2_hw / decrypt2_hw\n", + " - decodes d with the usual Kyber-ish threshold rule (q/4, 3q/4)\n", + "\n", + " Returns True if m == md bitwise.\n", + " \"\"\"\n", + " # random binary message\n", + " m = np.random.randint(2, size=(n2,))\n", + " ms = decompress(m, q, 1) # your existing embedding\n", + "\n", + " s, a, b = key_gen2_hw_cellZZA()\n", + " u, v = encrypt2_hw_cellZZA(a, b, ms)\n", + " d = decrypt2_hw_cellZZA(s, u, v)\n", + "\n", + " th1 = math.floor(q / 4)\n", + " th3 = math.floor(3 * q / 4)\n", + " md = np.array([1 if (th1 < x < th3) else 0 for x in d], dtype=np.int8)\n", + "\n", + " ok = np.array_equal(m, md)\n", + "\n", + " if verbose:\n", + " print(f\"USE_COMPRESSION={USE_COMPRESSION}, \"\n", + " f\"NOISE_DIV_KEY={NOISE_DIV_KEY}, NOISE_DIV_ENC={NOISE_DIV_ENC}\")\n", + " print(\"Decryption successful? 
\", ok)\n", + " if not ok:\n", + " idx = np.where(m != md)[0][0]\n", + " print(\"First mismatch at index\", idx)\n", + " print(\"m[idx] =\", int(m[idx]))\n", + " print(\"md[idx] =\", int(md[idx]))\n", + " return ok\n", + "\n", + "\n", + "def monte_carlo_roundtrip(trials: int = 50):\n", + " \"\"\"\n", + " Run multiple roundtrips with current global knobs;\n", + " print empirical failure rate.\n", + "\n", + " This is your basic diagnostic for whether a given\n", + " (USE_COMPRESSION, NOISE_DIV_KEY, NOISE_DIV_ENC) triple\n", + " is \"safe enough\" with your hardware multiplier.\n", + " \"\"\"\n", + " successes = 0\n", + " for t in range(trials):\n", + " if roundtrip_once(verbose=False):\n", + " successes += 1\n", + " failures = trials - successes\n", + " failure_rate = failures / trials\n", + " print(f\"USE_COMPRESSION={USE_COMPRESSION}, \"\n", + " f\"NOISE_DIV_KEY={NOISE_DIV_KEY}, NOISE_DIV_ENC={NOISE_DIV_ENC}\")\n", + " print(f\"Trials: {trials}, successes: {successes}, \"\n", + " f\"failures: {failures}, failure rate ≈ {failure_rate:.4f}\")\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Example usage with current presets\n", + "# ---------------------------------------------------------\n", + "print(\"Single roundtrip with current global settings:\")\n", + "roundtrip_once()\n", + "\n", + "print(\"\\nMonte-Carlo over 20 trials:\")\n", + "monte_carlo_roundtrip(trials=20)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "0c35d97b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Single roundtrip with current global settings:\n", + "USE_COMPRESSION=True, ETA1=3, ETA2=2, DU=10, DV=4, NOISE_DIV_KEY=4, NOISE_DIV_ENC=4\n", + "Decryption successful? 
True\n", + "\n", + "Monte-Carlo over 20 trials:\n", + "USE_COMPRESSION=True, ETA1=3, ETA2=2, DU=10, DV=4, NOISE_DIV_KEY=4, NOISE_DIV_ENC=4\n", + "Trials: 20, successes: 20, failures: 0, failure rate ≈ 0.0000\n" + ] + } + ], + "source": [ + "# oho O-2\n", + "# Variante 2 des O Experiments\n", + "\n", + "# closest to ML-KEM\n", + "\n", + "# (success only with both noise params set to 4)\n", + "\n", + "import numpy as np\n", + "import math\n", + "\n", + "# =========================================================\n", + "# ASSUMPTIONS: these already exist in your notebook\n", + "# =========================================================\n", + "# - q : modulus (should be 3329 for Kyber512)\n", + "# - n2 : polynomial length (should be 256)\n", + "# - k : module rank (should be 2)\n", + "# - cbd_vector : cbd_vector(n, eta, k)\n", + "# - cbd : cbd(n, eta)\n", + "# - compress : compress(poly, q, d)\n", + "# - decompress : decompress(poly, q, d)\n", + "# - poly_mul_hw: hardware polymult wrapper (DMA + IP)\n", + "#\n", + "# If they don't, define/import them before running this cell.\n", + "# =========================================================\n", + "\n", + "# Basic sanity checks (won't stop execution, just warn)\n", + "try:\n", + " if q != 3329:\n", + " print(f\"WARNING: q={q}, expected 3329 for Kyber512/ML-KEM-512.\")\n", + " if n2 != 256:\n", + " print(f\"WARNING: n2={n2}, expected 256.\")\n", + " if k != 2:\n", + " print(f\"WARNING: k={k}, expected 2 (Kyber512/ML-KEM-512 style).\")\n", + "except NameError:\n", + " print(\"WARNING: q, n2, k are not all defined before this cell.\")\n", + "\n", + "\n", + "# =========================================================\n", + "# GLOBAL TUNING KNOBS (THIS IS WHAT YOU EDIT)\n", + "# =========================================================\n", + "# 1) Core ML-KEM / Kyber-style parameters\n", + "ETA1 = 3 # eta1 (Kyber512/ML-KEM-512: 3)\n", + "ETA2 = 2 # eta2 (Kyber512/ML-KEM-512: 2)\n", + "DU = 10 # du (Kyber512/ML-KEM-512: 
10)\n", + "DV = 4 # dv (Kyber512/ML-KEM-512: 4)\n", + "\n", + "# 2) Extra “safety” scaling for noise (beyond eta1/eta2)\n", + "USE_COMPRESSION = True # True: compress/decompress u,v with DU,DV; False: bypass\n", + "NOISE_DIV_KEY = 4 # divisor for s,e in keygen (1 = spec-like, 4 = your current stable setting)\n", + "NOISE_DIV_ENC = 4 # divisor for r,e1,e2 in encrypt (1 = spec-like, 4 = your current stable setting)\n", + "# =========================================================\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Utility: integer noise scaling\n", + "# ---------------------------------------------------------\n", + "def scale_noise(x, div: int):\n", + " \"\"\"\n", + " Scale integer noise by integer divisor 'div' with truncation toward 0.\n", + "\n", + " - div = 1 -> original CBD(eta) noise distribution.\n", + " - div = 2 -> approximately halves magnitude.\n", + " - div = 3 -> ~one-third, etc.\n", + " \"\"\"\n", + " x = np.asarray(x, dtype=np.int32)\n", + " if div == 1:\n", + " return x.astype(np.int16)\n", + " return (np.sign(x) * (np.abs(x) // div)).astype(np.int16)\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# HW-based Kyber-style routines using the global knobs\n", + "# ---------------------------------------------------------\n", + "def key_gen2_hw_cellZZB():\n", + " \"\"\"\n", + " Key generation using hardware poly_mult and global parameters:\n", + "\n", + " - Ring: Z_q[X]/(X^256 + 1), q=3329 (assumed).\n", + " - k = 2 module dimension (Kyber512/ML-KEM-512 style).\n", + " - a ~ U(Z_q)^{k×k×n2}.\n", + " - s,e ~ CBD(ETA1), then scaled by NOISE_DIV_KEY.\n", + "\n", + " Returns (s, a, b) with b = A*s + e.\n", + " \"\"\"\n", + " # public matrix a: shape (k, k, n2)\n", + " a = np.random.randint(q, size=(k, k, n2), dtype=np.int16)\n", + "\n", + " # raw noise\n", + " s_raw = cbd_vector(n2, ETA1, k) # shape (k, n2)\n", + " e_raw = cbd_vector(n2, ETA1, k) # shape (k,
n2)\n", + "\n", + " # scaled noise according to global divisor\n", + " s = scale_noise(s_raw, NOISE_DIV_KEY)\n", + " e = scale_noise(e_raw, NOISE_DIV_KEY)\n", + "\n", + " # same structure as your original key_gen2, but with HW multiply\n", + " b0 = (poly_mul_hw(a[0,0], s[0]) + e[0]) % q\n", + " b1 = (poly_mul_hw(a[0,1], s[1]) + e[1]) % q\n", + " b2 = (poly_mul_hw(a[1,0], s[0]) + e[0]) % q\n", + " b3 = (poly_mul_hw(a[1,1], s[1]) + e[1]) % q\n", + "\n", + " b01 = (b0 + b1) % q\n", + " b23 = (b2 + b3) % q\n", + " b = np.array([b01, b23], dtype=np.int16)\n", + " return s, a, b\n", + "\n", + "\n", + "def encrypt2_hw_cellZZB(a, b, m):\n", + " \"\"\"\n", + " Encryption using hardware poly_mult and global parameters:\n", + "\n", + " - a, b as from key_gen2_hw.\n", + " - message m: length-n2 polynomial already embedded in Z_q.\n", + " - r ~ CBD(ETA1), e1 ~ CBD(ETA2), e2 ~ CBD(ETA2), then all\n", + " scaled by NOISE_DIV_ENC.\n", + " - DU, DV control compression of u and v, if USE_COMPRESSION is True.\n", + " \"\"\"\n", + " # raw noise\n", + " r_raw = cbd_vector(n2, ETA1, k) # (k, n2)\n", + " e1_raw = cbd_vector(n2, ETA2, k) # (k, n2)\n", + " e2_raw = cbd(n2, ETA2) # (n2,)\n", + "\n", + " # scaled noise\n", + " r = scale_noise(r_raw, NOISE_DIV_ENC)\n", + " e1 = scale_noise(e1_raw, NOISE_DIV_ENC)\n", + " e2 = scale_noise(e2_raw, NOISE_DIV_ENC)\n", + "\n", + " # u part: A * r + e1\n", + " u0 = (poly_mul_hw(a[0,0], r[0]) + e1[0]) % q\n", + " u1 = (poly_mul_hw(a[1,0], r[1]) + e1[1]) % q\n", + " u2 = (poly_mul_hw(a[0,1], r[0]) + e1[0]) % q\n", + " u3 = (poly_mul_hw(a[1,1], r[1]) + e1[1]) % q\n", + "\n", + " u01 = (u0 + u1) % q\n", + " u23 = (u2 + u3) % q\n", + " u = np.array([u01, u23], dtype=np.int16)\n", + "\n", + " # v part: b * r + e2 + m\n", + " v0 = poly_mul_hw(b[0], r[0]) % q\n", + " v1 = poly_mul_hw(b[1], r[1]) % q\n", + " v = (v0 + v1 + e2 + m) % q\n", + "\n", + " if USE_COMPRESSION:\n", + " u_c = compress(u, q, DU)\n", + " v_c = compress(v, q, DV)\n", + " else:\n", + " 
u_c = u.astype(np.int16)\n", + " v_c = v.astype(np.int16)\n", + "\n", + " return u_c, v_c\n", + "\n", + "\n", + "def decrypt2_hw_cellZZB(s, u, v):\n", + " \"\"\"\n", + " Decryption using hardware poly_mult and global parameters:\n", + "\n", + " - s: secret key (k × n2)\n", + " - (u, v): ciphertext components, compressed with (DU, DV)\n", + " if USE_COMPRESSION is True.\n", + "\n", + " Decryption computes p = s·u and d = v - p (mod q).\n", + " \"\"\"\n", + " if USE_COMPRESSION:\n", + " u_dec = decompress(u, q, DU)\n", + " v_dec = decompress(v, q, DV)\n", + " else:\n", + " u_dec = u.astype(np.int16)\n", + " v_dec = v.astype(np.int16)\n", + "\n", + " p0 = poly_mul_hw(s[0], u_dec[0]) % q\n", + " p1 = poly_mul_hw(s[1], u_dec[1]) % q\n", + " p = (p0 + p1) % q\n", + " d = (v_dec - p) % q\n", + " return d\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Roundtrip and Monte-Carlo\n", + "# ---------------------------------------------------------\n", + "def roundtrip_once(verbose: bool = True):\n", + " \"\"\"\n", + " One Decrypt(Encrypt(m)) round using the current global knobs.\n", + "\n", + " - m ∈ {0,1}^n2 uniformly random.\n", + " - Embedding via decompress(m, q, 1) (your original convention).\n", + " - Uses key_gen2_hw / encrypt2_hw / decrypt2_hw.\n", + " - Decoding via threshold rule: 0 if in [0, q/4] ∪ [3q/4, q),\n", + " 1 if in (q/4, 3q/4).\n", + "\n", + " Returns True if m == md bitwise.\n", + " \"\"\"\n", + " # random binary message\n", + " m = np.random.randint(2, size=(n2,))\n", + " # message embedding (this is separate from DU/DV and ETA*)\n", + " ms = decompress(m, q, 1)\n", + "\n", + " s, a, b = key_gen2_hw_cellZZB()\n", + " u, v = encrypt2_hw_cellZZB(a, b, ms)\n", + " d = decrypt2_hw_cellZZB(s, u, v)\n", + "\n", + " th1 = math.floor(q / 4)\n", + " th3 = math.floor(3 * q / 4)\n", + " md = np.array([1 if (th1 < x < th3) else 0 for x in d], dtype=np.int8)\n", + "\n", + " ok = np.array_equal(m, md)\n", + "\n", + " if 
verbose:\n", + " print(f\"USE_COMPRESSION={USE_COMPRESSION}, \"\n", + " f\"ETA1={ETA1}, ETA2={ETA2}, DU={DU}, DV={DV}, \"\n", + " f\"NOISE_DIV_KEY={NOISE_DIV_KEY}, NOISE_DIV_ENC={NOISE_DIV_ENC}\")\n", + " print(\"Decryption successful? \", ok)\n", + " if not ok:\n", + " idx = np.where(m != md)[0][0]\n", + " print(\"First mismatch at index\", idx)\n", + " print(\"m[idx] =\", int(m[idx]))\n", + " print(\"md[idx] =\", int(md[idx]))\n", + " return ok\n", + "\n", + "\n", + "def monte_carlo_roundtrip(trials: int = 50):\n", + " \"\"\"\n", + " Run multiple roundtrips with current global knobs;\n", + " print empirical failure rate.\n", + "\n", + " This tells you how stable a given parameter set is with your\n", + " hardware multiplier.\n", + " \"\"\"\n", + " successes = 0\n", + " for t in range(trials):\n", + " if roundtrip_once(verbose=False):\n", + " successes += 1\n", + " failures = trials - successes\n", + " failure_rate = failures / trials\n", + " print(f\"USE_COMPRESSION={USE_COMPRESSION}, \"\n", + " f\"ETA1={ETA1}, ETA2={ETA2}, DU={DU}, DV={DV}, \"\n", + " f\"NOISE_DIV_KEY={NOISE_DIV_KEY}, NOISE_DIV_ENC={NOISE_DIV_ENC}\")\n", + " print(f\"Trials: {trials}, successes: {successes}, \"\n", + " f\"failures: {failures}, failure rate ≈ {failure_rate:.4f}\")\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Example usage with current presets\n", + "# ---------------------------------------------------------\n", + "print(\"Single roundtrip with current global settings:\")\n", + "roundtrip_once()\n", + "\n", + "print(\"\\nMonte-Carlo over 20 trials:\")\n", + "monte_carlo_roundtrip(trials=20)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "2c2a0c6e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Single roundtrip with current global settings:\n", + "USE_COMPRESSION=True, ETA1=3, ETA2=2, DU=10, DV=4, NOISE_DIV_KEY=4, NOISE_DIV_ENC=4\n", + "Decryption 
successful? True\n", + "\n", + "Monte-Carlo over 20 trials:\n", + "USE_COMPRESSION=True, ETA1=3, ETA2=2, DU=10, DV=4, NOISE_DIV_KEY=4, NOISE_DIV_ENC=4\n", + "Trials: 20, successes: 20, failures: 0, failure rate ≈ 0.0000\n" + ] + } + ], + "source": [ + "# oho O-3\n", + "# Variante 3 des O Experiments\n", + "\n", + "# (again, Noise Params must be set to 4 each)\n", + "\n", + "import numpy as np\n", + "import math\n", + "\n", + "# =========================================================\n", + "# ASSUMPTIONS: these exist in your notebook already\n", + "# =========================================================\n", + "# q : modulus (int) -> should be 3329\n", + "# n2 : polynomial length -> should be 256\n", + "# k : module rank -> should be 2\n", + "# cbd_vector : cbd_vector(n, eta, k)\n", + "# cbd : cbd(n, eta)\n", + "# compress : compress(poly, q, d)\n", + "# decompress : decompress(poly, q, d)\n", + "# poly_mul_hw: hardware polymult wrapper (DMA + IP), signature:\n", + "# poly_mul_hw(poly_a: np.ndarray, poly_b: np.ndarray) -> np.ndarray (len n2, mod q)\n", + "# =========================================================\n", + "\n", + "# Sanity checks (warn only)\n", + "try:\n", + " if q != 3329:\n", + " print(f\"WARNING: q={q}, expected 3329 for Kyber512/ML-KEM-512.\")\n", + " if n2 != 256:\n", + " print(f\"WARNING: n2={n2}, expected 256.\")\n", + " if k != 2:\n", + " print(f\"WARNING: k={k}, expected 2 (Kyber512-style).\")\n", + "except NameError:\n", + " print(\"WARNING: q, n2, k not all defined before this cell.\")\n", + "\n", + "\n", + "# =========================================================\n", + "# GLOBAL TUNING KNOBS (EDIT ONLY THIS BLOCK)\n", + "# =========================================================\n", + "# 1) Kyber / ML-KEM-style parameters\n", + "ETA1 = 3 # Kyber512/ML-KEM-512: eta1 = 3\n", + "ETA2 = 2 # Kyber512/ML-KEM-512: eta2 = 2\n", + "DU = 10 # Kyber512/ML-KEM-512: du = 10\n", + "DV = 4 # Kyber512/ML-KEM-512: dv = 4\n", + "\n", + "# 2) Extra 
“safety” scaling for noise amplitudes\n", + "# (these are *not* spec knobs, just HW-tuning knobs)\n", + "USE_COMPRESSION = True # True -> compress/decompress(u,v) with (DU,DV); False -> bypass\n", + "NOISE_DIV_KEY = 4 # scale s,e in keygen by 1/NOISE_DIV_KEY (1 = spec-like, 4 = current working setting)\n", + "NOISE_DIV_ENC = 4 # scale r,e1,e2 in encrypt by 1/NOISE_DIV_ENC (same interpretation)\n", + "# Good profiles to try:\n", + "# - \"Safe baseline\" (what you saw working): ETA1=3, ETA2=2, DU=10, DV=4, USE_COMPRESSION=True,\n", + "# NOISE_DIV_KEY=4, NOISE_DIV_ENC=4\n", + "# - \"Move toward spec\": keep ETA1/ETA2/DU/DV, gradually try:\n", + "# (KEY,ENC) = (4,4) -> (3,3) -> (2,3) -> (2,2) -> (1,2) -> (1,1)\n", + "# =========================================================\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Utility: integer noise scaling\n", + "# ---------------------------------------------------------\n", + "def scale_noise(x, div: int):\n", + " \"\"\"\n", + " Scale integer noise by integer divisor 'div' with truncation toward 0.\n", + "\n", + " - div = 1 -> original CBD(eta) noise distribution.\n", + " - div = 2 -> approximately halves magnitude.\n", + " - div = 3 -> ~one-third, etc.\n", + " \"\"\"\n", + " x = np.asarray(x, dtype=np.int32)\n", + " if div == 1:\n", + " return x.astype(np.int16)\n", + " return (np.sign(x) * (np.abs(x) // div)).astype(np.int16)\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# HW-based Kyber-style routines using the global knobs\n", + "# ---------------------------------------------------------\n", + "def key_gen2_hw_cellZZC():\n", + " \"\"\"\n", + " Key generation using hardware poly_mult and global parameters:\n", + "\n", + " - Ring: Z_q[X]/(X^256 + 1), q=3329 (assumed by the rest of your code).\n", + " - k = 2 module dimension (Kyber512/ML-KEM-512 style).\n", + " - a ~ U(Z_q)^{k×k×n2}.\n", + " - s,e ~ CBD(ETA1), then scaled by NOISE_DIV_KEY.\n", +
"\n", + " Returns (s, a, b) with b = A*s + e.\n", + " \"\"\"\n", + " # public matrix a: shape (k, k, n2)\n", + " a = np.random.randint(q, size=(k, k, n2), dtype=np.int16)\n", + "\n", + " # raw noise\n", + " s_raw = cbd_vector(n2, ETA1, k) # shape (k, n2)\n", + " e_raw = cbd_vector(n2, ETA1, k) # shape (k, n2)\n", + "\n", + " # scaled noise according to global divisor\n", + " s = scale_noise(s_raw, NOISE_DIV_KEY)\n", + " e = scale_noise(e_raw, NOISE_DIV_KEY)\n", + "\n", + " # same structure as your original key_gen2, but with HW multiply\n", + " b0 = (poly_mul_hw(a[0,0], s[0]) + e[0]) % q\n", + " b1 = (poly_mul_hw(a[0,1], s[1]) + e[1]) % q\n", + " b2 = (poly_mul_hw(a[1,0], s[0]) + e[0]) % q\n", + " b3 = (poly_mul_hw(a[1,1], s[1]) + e[1]) % q\n", + "\n", + " b01 = (b0 + b1) % q\n", + " b23 = (b2 + b3) % q\n", + " b = np.array([b01, b23], dtype=np.int16)\n", + " return s, a, b\n", + "\n", + "\n", + "def encrypt2_hw_cellZZC(a, b, m):\n", + " \"\"\"\n", + " Encryption using hardware poly_mult and global parameters:\n", + "\n", + " - a, b as from key_gen2_hw.\n", + " - message m: length-n2 polynomial already embedded in Z_q.\n", + " - r ~ CBD(ETA1), e1 ~ CBD(ETA2), e2 ~ CBD(ETA2), then all\n", + " scaled by NOISE_DIV_ENC.\n", + " - DU, DV control compression of u and v, if USE_COMPRESSION is True.\n", + " \"\"\"\n", + " # raw noise\n", + " r_raw = cbd_vector(n2, ETA1, k) # (k, n2)\n", + " e1_raw = cbd_vector(n2, ETA2, k) # (k, n2)\n", + " e2_raw = cbd(n2, ETA2) # (n2,)\n", + "\n", + " # scaled noise\n", + " r = scale_noise(r_raw, NOISE_DIV_ENC)\n", + " e1 = scale_noise(e1_raw, NOISE_DIV_ENC)\n", + " e2 = scale_noise(e2_raw, NOISE_DIV_ENC)\n", + "\n", + " # u part: A * r + e1\n", + " u0 = (poly_mul_hw(a[0,0], r[0]) + e1[0]) % q\n", + " u1 = (poly_mul_hw(a[1,0], r[1]) + e1[1]) % q\n", + " u2 = (poly_mul_hw(a[0,1], r[0]) + e1[0]) % q\n", + " u3 = (poly_mul_hw(a[1,1], r[1]) + e1[1]) % q\n", + "\n", + " u01 = (u0 + u1) % q\n", + " u23 = (u2 + u3) % q\n", + " u = 
np.array([u01, u23], dtype=np.int16)\n", + "\n", + " # v part: b * r + e2 + m\n", + " v0 = poly_mul_hw(b[0], r[0]) % q\n", + " v1 = poly_mul_hw(b[1], r[1]) % q\n", + " v = (v0 + v1 + e2 + m) % q\n", + "\n", + " if USE_COMPRESSION:\n", + " u_c = compress(u, q, DU)\n", + " v_c = compress(v, q, DV)\n", + " else:\n", + " u_c = u.astype(np.int16)\n", + " v_c = v.astype(np.int16)\n", + "\n", + " return u_c, v_c\n", + "\n", + "\n", + "def decrypt2_hw_cellZZC(s, u, v):\n", + " \"\"\"\n", + " Decryption using hardware poly_mult and global parameters:\n", + "\n", + " - s: secret key (k × n2)\n", + " - (u, v): ciphertext components, compressed with (DU, DV)\n", + " if USE_COMPRESSION is True.\n", + "\n", + " Decryption computes p = s·u and d = v - p (mod q).\n", + " \"\"\"\n", + " if USE_COMPRESSION:\n", + " u_dec = decompress(u, q, DU)\n", + " v_dec = decompress(v, q, DV)\n", + " else:\n", + " u_dec = u.astype(np.int16)\n", + " v_dec = v.astype(np.int16)\n", + "\n", + " p0 = poly_mul_hw(s[0], u_dec[0]) % q\n", + " p1 = poly_mul_hw(s[1], u_dec[1]) % q\n", + " p = (p0 + p1) % q\n", + " d = (v_dec - p) % q\n", + " return d\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Roundtrip and Monte-Carlo\n", + "# ---------------------------------------------------------\n", + "def roundtrip_once(verbose: bool = True):\n", + " \"\"\"\n", + " One Decrypt(Encrypt(m)) round using the current global knobs.\n", + "\n", + " - m ∈ {0,1}^n2 uniformly random.\n", + " - Embedding via decompress(m, q, 1) (your original convention).\n", + " - Uses key_gen2_hw / encrypt2_hw / decrypt2_hw.\n", + " - Decoding via threshold rule: 0 if in [0, q/4] ∪ [3q/4, q),\n", + " 1 if in (q/4, 3q/4).\n", + "\n", + " Returns True if m == md bitwise.\n", + " \"\"\"\n", + " # random binary message\n", + " m = np.random.randint(2, size=(n2,))\n", + " # message embedding (this is independent of DU/DV and ETA*)\n", + " ms = decompress(m, q, 1)\n", + "\n", + " s, a, b = 
key_gen2_hw_cellZZC()\n", + " u, v = encrypt2_hw_cellZZC(a, b, ms)\n", + " d = decrypt2_hw_cellZZC(s, u, v)\n", + "\n", + " th1 = math.floor(q / 4)\n", + " th3 = math.floor(3 * q / 4)\n", + " md = np.array([1 if (th1 < x < th3) else 0 for x in d], dtype=np.int8)\n", + "\n", + " ok = np.array_equal(m, md)\n", + "\n", + " if verbose:\n", + " print(f\"USE_COMPRESSION={USE_COMPRESSION}, \"\n", + " f\"ETA1={ETA1}, ETA2={ETA2}, DU={DU}, DV={DV}, \"\n", + " f\"NOISE_DIV_KEY={NOISE_DIV_KEY}, NOISE_DIV_ENC={NOISE_DIV_ENC}\")\n", + " print(\"Decryption successful? \", ok)\n", + " if not ok:\n", + " idx = np.where(m != md)[0][0]\n", + " print(\"First mismatch at index\", idx)\n", + " print(\"m[idx] =\", int(m[idx]))\n", + " print(\"md[idx] =\", int(md[idx]))\n", + " return ok\n", + "\n", + "\n", + "def monte_carlo_roundtrip(trials: int = 50):\n", + " \"\"\"\n", + " Run multiple roundtrips with current global knobs;\n", + " print empirical failure rate.\n", + "\n", + " Use this to see how a given parameter set behaves with your HW.\n", + " \"\"\"\n", + " successes = 0\n", + " for t in range(trials):\n", + " if roundtrip_once(verbose=False):\n", + " successes += 1\n", + " failures = trials - successes\n", + " failure_rate = failures / trials\n", + " print(f\"USE_COMPRESSION={USE_COMPRESSION}, \"\n", + " f\"ETA1={ETA1}, ETA2={ETA2}, DU={DU}, DV={DV}, \"\n", + " f\"NOISE_DIV_KEY={NOISE_DIV_KEY}, NOISE_DIV_ENC={NOISE_DIV_ENC}\")\n", + " print(f\"Trials: {trials}, successes: {successes}, \"\n", + " f\"failures: {failures}, failure rate ≈ {failure_rate:.4f}\")\n", + "\n", + "\n", + "# ---------------------------------------------------------\n", + "# Example usage with current presets\n", + "# ---------------------------------------------------------\n", + "print(\"Single roundtrip with current global settings:\")\n", + "roundtrip_once()\n", + "\n", + "print(\"\\nMonte-Carlo over 20 trials:\")\n", + "monte_carlo_roundtrip(trials=20)\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "8a6ff41f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb6b6971", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}