mirror of
https://github.com/saymrwulf/pqc-accelerate.git
synced 2026-05-14 20:48:07 +00:00
1566 lines
54 KiB
Text
1566 lines
54 KiB
Text
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "33632a5e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# oho A\n",
|
||
"# Install and import liboqs-python (software Dilithium)\n",
|
||
"# If you don't have liboqs / liboqs-python yet, uncomment the next lines.\n",
|
||
"# On a ZCU104 this will compile on ARM; it may take a while.\n",
|
||
"#\n",
|
||
"#!git clone --depth=1 https://github.com/open-quantum-safe/liboqs-python.git\n",
|
||
"#%cd liboqs-python\n",
|
||
"#!pip install .\n",
|
||
"#%cd -"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"id": "6093010b",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Enabled signature mechanisms:\n",
|
||
"('Dilithium2', 'Dilithium3', 'Dilithium5', 'ML-DSA-44', 'ML-DSA-65', 'ML-DSA-87', 'Falcon-512', 'Falcon-1024', 'Falcon-padded-512', 'Falcon-padded-1024', 'SPHINCS+-SHA2-128f-simple', 'SPHINCS+-SHA2-128s-simple', 'SPHINCS+-SHA2-192f-simple', 'SPHINCS+-SHA2-192s-simple', 'SPHINCS+-SHA2-256f-simple', 'SPHINCS+-SHA2-256s-simple', 'SPHINCS+-SHAKE-128f-simple', 'SPHINCS+-SHAKE-128s-simple', 'SPHINCS+-SHAKE-192f-simple', 'SPHINCS+-SHAKE-192s-simple', 'SPHINCS+-SHAKE-256f-simple', 'SPHINCS+-SHAKE-256s-simple', 'MAYO-1', 'MAYO-2', 'MAYO-3', 'MAYO-5', 'cross-rsdp-128-balanced', 'cross-rsdp-128-fast', 'cross-rsdp-128-small', 'cross-rsdp-192-balanced', 'cross-rsdp-192-fast', 'cross-rsdp-192-small', 'cross-rsdp-256-balanced', 'cross-rsdp-256-fast', 'cross-rsdp-256-small', 'cross-rsdpg-128-balanced', 'cross-rsdpg-128-fast', 'cross-rsdpg-128-small', 'cross-rsdpg-192-balanced', 'cross-rsdpg-192-fast', 'cross-rsdpg-192-small', 'cross-rsdpg-256-balanced', 'cross-rsdpg-256-fast', 'cross-rsdpg-256-small', 'OV-Is', 'OV-Ip', 'OV-III', 'OV-V', 'OV-Is-pkc', 'OV-Ip-pkc', 'OV-III-pkc', 'OV-V-pkc', 'OV-Is-pkc-skc', 'OV-Ip-pkc-skc', 'OV-III-pkc-skc', 'OV-V-pkc-skc', 'SNOVA_24_5_4', 'SNOVA_24_5_4_SHAKE', 'SNOVA_24_5_4_esk', 'SNOVA_24_5_4_SHAKE_esk', 'SNOVA_37_17_2', 'SNOVA_25_8_3', 'SNOVA_56_25_2', 'SNOVA_49_11_3', 'SNOVA_37_8_4', 'SNOVA_24_5_5', 'SNOVA_60_10_4', 'SNOVA_29_6_5')\n",
|
||
"Using algorithm: ML-DSA-44\n",
|
||
"Software-only Dilithium verify OK? True\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Cell 1: imports and a software-only Dilithium (liboqs) sanity check\n",
|
||
"\n",
|
||
"\n",
|
||
"import numpy as np\n",
|
||
"import math\n",
|
||
"import time\n",
|
||
"from pynq import Overlay\n",
|
||
"from pynq import allocate\n",
|
||
"from pynq import MMIO\n",
|
||
"\n",
|
||
"import oqs\n",
|
||
"\n",
|
||
"# Check which signature mechanisms are available\n",
|
||
"print(\"Enabled signature mechanisms:\")\n",
|
||
"print(oqs.get_enabled_sig_mechanisms())\n",
|
||
"\n",
|
||
"# Prefer the ML-DSA name; fall back to legacy Dilithium2 if needed\n",
|
||
"if \"ML-DSA-44\" in oqs.get_enabled_sig_mechanisms():\n",
|
||
" DILITHIUM_ALG = \"ML-DSA-44\"\n",
|
||
"elif \"Dilithium2\" in oqs.get_enabled_sig_mechanisms():\n",
|
||
" DILITHIUM_ALG = \"Dilithium2\"\n",
|
||
"else:\n",
|
||
" raise RuntimeError(\"No Dilithium / ML-DSA-44 implementation enabled in liboqs.\")\n",
|
||
"\n",
|
||
"print(\"Using algorithm:\", DILITHIUM_ALG)\n",
|
||
"\n",
|
||
"# Simple software-only keygen / sign / verify\n",
|
||
"message = b\"Hello from ZCU104 + Dilithium\"\n",
|
||
"\n",
|
||
"with oqs.Signature(DILITHIUM_ALG) as signer:\n",
|
||
" public_key = signer.generate_keypair()\n",
|
||
" secret_key = signer.export_secret_key()\n",
|
||
"\n",
|
||
" signature = signer.sign(message)\n",
|
||
" ok = signer.verify(message, signature, public_key)\n",
|
||
"\n",
|
||
"print(\"Software-only Dilithium verify OK?\", ok)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d32bcf5f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d8913f4a",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"id": "088eb969",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"total 19388\n",
|
||
"drwxr-xr-x 3 root root 4096 Dec 13 15:53 .\n",
|
||
"drwxrwxrwx 9 xilinx xilinx 4096 Dec 13 15:57 ..\n",
|
||
"-rw-r--r-- 1 root root 19311209 Dec 13 15:53 base.bit\n",
|
||
"-rw-r--r-- 1 root root 527952 Dec 13 15:52 base.hwh\n",
|
||
"drwxr-xr-x 2 root root 4096 Dec 13 15:23 .ipynb_checkpoints\n",
|
||
"Sat Dec 13 03:57:25 PM UTC 2025\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"#oho B\n",
|
||
"!ls -la /home/xilinx/jupyter_notebooks/kyber_ntt_and_dilithium_ntt\n",
|
||
"!date\n",
|
||
"# Loading the kyber-ntt_and_dilithium_ntt bit file to configure the PL\n",
|
||
"ol = Overlay('/home/xilinx/jupyter_notebooks/kyber_ntt_and_dilithium_ntt/base.bit')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"id": "fc922ce1",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"dict_keys(['axi_dma_0', 'poly_mult_0', 'poly_mult_dil_0', 'axi_dma_1', 'ps_e_0'])\n",
|
||
"poly_mult_dil_0 @ 0x80020000\n",
|
||
"send _max_size: 16383\n",
|
||
"recv _max_size: 16383\n",
|
||
"{'C_DLYTMR_RESOLUTION': '125', 'C_ENABLE_MULTI_CHANNEL': '0', 'C_FAMILY': 'zynquplus', 'C_INCLUDE_MM2S': '1', 'C_INCLUDE_MM2S_DRE': '0', 'C_INCLUDE_MM2S_SF': '1', 'C_INCLUDE_S2MM': '1', 'C_INCLUDE_S2MM_DRE': '0', 'C_INCLUDE_S2MM_SF': '1', 'C_INCLUDE_SG': '0', 'C_INCREASE_THROUGHPUT': '0', 'C_MICRO_DMA': '0', 'C_MM2S_BURST_SIZE': '16', 'C_M_AXIS_MM2S_CNTRL_TDATA_WIDTH': '32', 'C_M_AXIS_MM2S_TDATA_WIDTH': '64', 'C_M_AXI_MM2S_ADDR_WIDTH': '32', 'C_M_AXI_MM2S_DATA_WIDTH': '64', 'C_M_AXI_S2MM_ADDR_WIDTH': '32', 'C_M_AXI_S2MM_DATA_WIDTH': '32', 'C_M_AXI_SG_ADDR_WIDTH': '32', 'C_M_AXI_SG_DATA_WIDTH': '32', 'C_NUM_MM2S_CHANNELS': '1', 'C_NUM_S2MM_CHANNELS': '1', 'C_PRMRY_IS_ACLK_ASYNC': '0', 'C_S2MM_BURST_SIZE': '16', 'C_SG_INCLUDE_STSCNTRL_STRM': '0', 'C_SG_LENGTH_WIDTH': '14', 'C_SG_USE_STSAPP_LENGTH': '0', 'C_S_AXIS_S2MM_STS_TDATA_WIDTH': '32', 'C_S_AXIS_S2MM_TDATA_WIDTH': '32', 'C_S_AXI_LITE_ADDR_WIDTH': '10', 'C_S_AXI_LITE_DATA_WIDTH': '32', 'Component_Name': 'base_axi_dma_1_0', 'c_addr_width': '32', 'c_dlytmr_resolution': '125', 'c_enable_multi_channel': '0', 'c_include_mm2s': '1', 'c_include_mm2s_dre': '0', 'c_include_mm2s_sf': '1', 'c_include_s2mm': '1', 'c_include_s2mm_dre': '0', 'c_include_s2mm_sf': '1', 'c_include_sg': '0', 'c_increase_throughput': '0', 'c_m_axi_mm2s_data_width': '64', 'c_m_axi_s2mm_data_width': '32', 'c_m_axis_mm2s_tdata_width': '64', 'c_micro_dma': '0', 'c_mm2s_burst_size': '16', 'c_num_mm2s_channels': '1', 'c_num_s2mm_channels': '1', 'c_prmry_is_aclk_async': '0', 'c_s2mm_burst_size': '16', 'c_s_axis_s2mm_tdata_width': '32', 'c_sg_include_stscntrl_strm': '0', 'c_sg_length_width': '14', 'c_sg_use_stsapp_length': '0', 'c_single_interface': '0', 'EDK_IPTYPE': 'PERIPHERAL', 'C_BASEADDR': '0x80030000', 'C_HIGHADDR': '0x8003FFFF', 'ADDR_WIDTH': '10', 'ARUSER_WIDTH': '0', 'AWUSER_WIDTH': '0', 'BUSER_WIDTH': '0', 'CLK_DOMAIN': 'base_ps_e_0_0_pl_clk0', 'DATA_WIDTH': '32', 'FREQ_HZ': '99999001', 'HAS_BRESP': '1', 'HAS_BURST': '0', 'HAS_CACHE': '0', 
'HAS_LOCK': '0', 'HAS_PROT': '0', 'HAS_QOS': '0', 'HAS_REGION': '0', 'HAS_RRESP': '1', 'HAS_WSTRB': '0', 'ID_WIDTH': '0', 'INSERT_VIP': '0', 'MAX_BURST_LENGTH': '1', 'NUM_READ_OUTSTANDING': '1', 'NUM_READ_THREADS': '1', 'NUM_WRITE_OUTSTANDING': '1', 'NUM_WRITE_THREADS': '1', 'PHASE': '0.0', 'PROTOCOL': 'AXI4LITE', 'READ_WRITE_MODE': 'READ_WRITE', 'RUSER_BITS_PER_BYTE': '0', 'RUSER_WIDTH': '0', 'SUPPORTS_NARROW_BURST': '0', 'WUSER_BITS_PER_BYTE': '0', 'WUSER_WIDTH': '0', 'HAS_TKEEP': '1', 'HAS_TLAST': '1', 'HAS_TREADY': '1', 'HAS_TSTRB': '0', 'LAYERED_METADATA': 'undef', 'TDATA_NUM_BYTES': '8', 'TDEST_WIDTH': '0', 'TID_WIDTH': '0', 'TUSER_WIDTH': '0'}\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"#oho E\n",
|
||
"# Inquiry: inspect the overlay's IP blocks on the ZCU104 board\n",
|
||
"\n",
|
||
"# Discover the real base address from the .hwh\n",
|
||
"print(ol.ip_dict.keys()) # just to see exact IP names\n",
|
||
"\n",
|
||
"poly_name = 'poly_mult_dil_0' # adjust if Vivado named it differently\n",
|
||
"poly_info = ol.ip_dict[poly_name]\n",
|
||
"print(\"poly_mult_dil_0 @\", hex(poly_info['phys_addr']))\n",
|
||
"mmio = MMIO(poly_info['phys_addr'], poly_info['addr_range'])\n",
|
||
"\n",
|
||
"# NOTE: axi_dma_1 belongs to the Dilithium accelerator, not axi_dma_0!\n",
|
||
"dma = ol.axi_dma_1\n",
|
||
"dma_send = dma.sendchannel\n",
|
||
"dma_recv = dma.recvchannel\n",
|
||
"\n",
|
||
"\n",
|
||
"# 1) What PYNQ thinks the max dma channel size is:\n",
|
||
"print(\"send _max_size:\", dma.sendchannel._max_size)\n",
|
||
"print(\"recv _max_size:\", dma.recvchannel._max_size)\n",
|
||
"\n",
|
||
"# 2) Compare against the raw IP parameters from the .hwh (C_SG_LENGTH_WIDTH=14 -> 2**14 - 1 = 16383):\n",
|
||
"print(ol.ip_dict['axi_dma_1']['parameters'])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "dafae649",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "3572bb8e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"id": "1db1b137",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"\n",
|
||
"# 4. Dilithium ring parameters\n",
|
||
"Q = 8380417\n",
|
||
"N = 256\n",
|
||
"\n",
|
||
"def center_mod_q(x):\n",
|
||
" \"\"\"Map integer x to centered representative in (-Q/2, Q/2].\"\"\"\n",
|
||
" x = int(x) % Q\n",
|
||
" if x > Q // 2:\n",
|
||
" x -= Q\n",
|
||
" return x\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"id": "af55b05f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def hw_poly_mult(a, b):\n",
|
||
" \"\"\"\n",
|
||
" Hardware polynomial multiplication in R_q[X]/(X^N + 1).\n",
|
||
"\n",
|
||
" a, b: 1D numpy arrays of length N with integer entries (Python ints or np.int32/int64).\n",
|
||
" Returns: numpy array length N, centered mod Q.\n",
|
||
" \"\"\"\n",
|
||
" a = np.asarray(a, dtype=np.int64)\n",
|
||
" b = np.asarray(b, dtype=np.int64)\n",
|
||
" assert a.shape == (N,)\n",
|
||
" assert b.shape == (N,)\n",
|
||
"\n",
|
||
" # Allocate DMA buffers\n",
|
||
" # Input: N 64-bit words, each packing two 32-bit coefficients.\n",
|
||
" in_buf = allocate(shape=(N,), dtype=np.uint64)\n",
|
||
" # Output: N 32-bit coefficients\n",
|
||
" out_buf = allocate(shape=(N,), dtype=np.uint32)\n",
|
||
"\n",
|
||
" # Pack two 32-bit coeffs into one 64-bit word\n",
|
||
" # Convention: low 32 bits = a[i], high 32 bits = b[i]\n",
|
||
" for i in range(N):\n",
|
||
" a_i = int(a[i] % Q) & 0xFFFFFFFF\n",
|
||
" b_i = int(b[i] % Q) & 0xFFFFFFFF\n",
|
||
" word = ((b_i << 32) | a_i) # pure Python ints, no NumPy ufuncs\n",
|
||
" in_buf[i] = np.uint64(word)\n",
|
||
"\n",
|
||
"\n",
|
||
" \n",
|
||
" # Start sequence modeled on the Kyber notebook\n",
|
||
" # Start IP core (ap_start = 1 at control register 0x00)\n",
|
||
" mmio.write(0x00, 0x1)\n",
|
||
" \n",
|
||
" \n",
|
||
" \n",
|
||
" # Start DMA transfers\n",
|
||
" dma.sendchannel.transfer(in_buf)\n",
|
||
" dma.recvchannel.transfer(out_buf)\n",
|
||
" dma.sendchannel.wait()\n",
|
||
" dma.recvchannel.wait()\n",
|
||
"\n",
|
||
" # Convert back to centered coefficients\n",
|
||
" c = np.empty(N, dtype=np.int64)\n",
|
||
" for i in range(N):\n",
|
||
" c[i] = center_mod_q(out_buf[i])\n",
|
||
"\n",
|
||
" # Optional: free buffers (PYNQ also cleans them up via GC, but explicit is nice on small boards)\n",
|
||
" in_buf.freebuffer()\n",
|
||
" out_buf.freebuffer()\n",
|
||
"\n",
|
||
" return c\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"id": "cb8a287b",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Cell 3: Software negacyclic polynomial multiplication in R_q[X]/(X^N + 1)\n",
|
||
"\n",
|
||
"def sw_poly_mult(a, b):\n",
|
||
" \"\"\"\n",
|
||
" O(N^2) negacyclic polynomial multiplication:\n",
|
||
" c(X) = a(X) * b(X) mod (X^N + 1, Q)\n",
|
||
" \"\"\"\n",
|
||
" a = np.asarray(a, dtype=np.int64)\n",
|
||
" b = np.asarray(b, dtype=np.int64)\n",
|
||
" assert a.shape == (N,)\n",
|
||
" assert b.shape == (N,)\n",
|
||
"\n",
|
||
" c = np.zeros(N, dtype=np.int64)\n",
|
||
"\n",
|
||
" for i in range(N):\n",
|
||
" acc = 0\n",
|
||
" for j in range(N):\n",
|
||
" k = i - j\n",
|
||
" if k >= 0:\n",
|
||
" acc += int(a[j]) * int(b[k])\n",
|
||
" else:\n",
|
||
" # Wrap with a minus sign: X^N = -1\n",
|
||
" k += N\n",
|
||
" acc -= int(a[j]) * int(b[k])\n",
|
||
" c[i] = center_mod_q(acc)\n",
|
||
"\n",
|
||
" return c\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a1e0e420",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"id": "7dfe972f",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Max |c_sw - c_hw| mod Q = 0\n",
|
||
"Number of mismatching coefficients: 0\n",
|
||
"HW matches SW golden for this test vector.\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Cell 4: Compare HW vs SW polynomial multiplication\n",
|
||
"\n",
|
||
"from numpy.random import default_rng\n",
|
||
"\n",
|
||
"rng = default_rng(12345)\n",
|
||
"\n",
|
||
"# Generate random (roughly centered) coefficients in [-Q//2, Q//2)\n",
|
||
"a = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
|
||
"b = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
|
||
"\n",
|
||
"c_sw = sw_poly_mult(a, b)\n",
|
||
"\n",
|
||
"c_hw = hw_poly_mult(a, b)\n",
|
||
"\n",
|
||
"# Compare modulo Q (in case of small representation differences)\n",
|
||
"diff = (c_sw - c_hw) % Q\n",
|
||
"max_diff = int(np.max(np.abs(diff)))\n",
|
||
"num_mism = int(np.count_nonzero(diff))\n",
|
||
"\n",
|
||
"print(\"Max |c_sw - c_hw| mod Q =\", max_diff)\n",
|
||
"print(\"Number of mismatching coefficients:\", num_mism)\n",
|
||
"\n",
|
||
"if num_mism == 0:\n",
|
||
" print(\"HW matches SW golden for this test vector.\")\n",
|
||
"else:\n",
|
||
" print(\"Mismatch: investigate HLS core / packing / reduction.\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "fa590582",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "9a7991e6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"id": "99023a64",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"=== Software Dilithium using liboqs ===\n",
|
||
"Keygen time (SW): 0.001 s\n",
|
||
"Sign time (SW): 0.005 s\n",
|
||
"Verify time (SW): 0.001 s\n",
|
||
"Verification OK? True\n",
|
||
"\n",
|
||
"=== Hardware-accelerated poly_mult_dil ===\n",
|
||
"Software poly_mult time: 0.282400 s\n",
|
||
"Hardware poly_mult time (incl. DMA): 0.009862 s\n",
|
||
"Mismatch count: 0\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Cell 5: Tie it together\n",
|
||
"\n",
|
||
"import time\n",
|
||
"\n",
|
||
"# 1. Software Dilithium keygen/sign/verify (liboqs)\n",
|
||
"message = b\"Hardware-software co-design demo on ZCU104\"\n",
|
||
"\n",
|
||
"with oqs.Signature(DILITHIUM_ALG) as signer:\n",
|
||
" print(\"\\n=== Software Dilithium using liboqs ===\")\n",
|
||
" t0 = time.time()\n",
|
||
" public_key = signer.generate_keypair()\n",
|
||
" secret_key = signer.export_secret_key()\n",
|
||
" t1 = time.time()\n",
|
||
" print(\"Keygen time (SW): %.3f s\" % (t1 - t0))\n",
|
||
"\n",
|
||
" t0 = time.time()\n",
|
||
" signature = signer.sign(message)\n",
|
||
" t1 = time.time()\n",
|
||
" print(\"Sign time (SW): %.3f s\" % (t1 - t0))\n",
|
||
"\n",
|
||
" t0 = time.time()\n",
|
||
" ok = signer.verify(message, signature, public_key)\n",
|
||
" t1 = time.time()\n",
|
||
" print(\"Verify time (SW): %.3f s\" % (t1 - t0))\n",
|
||
" print(\"Verification OK?\", ok)\n",
|
||
"\n",
|
||
"# 2. Hardware-accelerated polynomial multiplication demo\n",
|
||
"print(\"\\n=== Hardware-accelerated poly_mult_dil ===\")\n",
|
||
"\n",
|
||
"a = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
|
||
"b = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
|
||
"\n",
|
||
"t0 = time.time()\n",
|
||
"c_sw = sw_poly_mult(a, b)\n",
|
||
"t1 = time.time()\n",
|
||
"print(\"Software poly_mult time: %.6f s\" % (t1 - t0))\n",
|
||
"\n",
|
||
"t0 = time.time()\n",
|
||
"c_hw = hw_poly_mult(a, b)\n",
|
||
"t1 = time.time()\n",
|
||
"print(\"Hardware poly_mult time (incl. DMA): %.6f s\" % (t1 - t0))\n",
|
||
"\n",
|
||
"diff = (c_sw - c_hw) % Q\n",
|
||
"num_mism = int(np.count_nonzero(diff))\n",
|
||
"print(\"Mismatch count:\", num_mism)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "33dcd211",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "36fb8ddb",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"###########################################################################################"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "ac4e49fd",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"## Test procedures analogous to the Kyber notebook"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a9ec9ac1",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"id": "818b97e0",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# oho D_dil\n",
|
||
"# 512-point NTT / INTT in pure Python, matching the HLS core\n",
|
||
"# and a SW NTT-based polynomial multiplication in R_q[X]/(X^256 + 1).\n",
|
||
"\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"# Ring parameters (keep them in sync with your hardware)\n",
|
||
"Q = 8380417\n",
|
||
"N = 256\n",
|
||
"N2 = 2 * N\n",
|
||
"\n",
|
||
"# These are exactly the twiddle constants we used in the HLS core\n",
|
||
"NTT_WLEN = [\n",
|
||
" 8380416, 4808194, 4614810, 2883726,\n",
|
||
" 6250525, 7044481, 3241972, 6644104,\n",
|
||
" 1921994\n",
|
||
"]\n",
|
||
"\n",
|
||
"NTT_WLEN_INV = [\n",
|
||
" 8380416, 3572223, 3761513, 5234739,\n",
|
||
" 3764867, 3227876, 6621070, 6125690,\n",
|
||
" 527981\n",
|
||
"]\n",
|
||
"\n",
|
||
"INV_NTT512 = 8364049 # 512^{-1} mod Q\n",
|
||
"\n",
|
||
"def mod_q(x: int) -> int:\n",
|
||
" \"\"\"Reduce integer x modulo Q into [0, Q).\"\"\"\n",
|
||
" r = x % Q\n",
|
||
" if r < 0:\n",
|
||
" r += Q\n",
|
||
" return r\n",
|
||
"\n",
|
||
"def modadd(a: int, b: int) -> int:\n",
|
||
" s = a + b\n",
|
||
" if s >= Q:\n",
|
||
" s -= Q\n",
|
||
" return s\n",
|
||
"\n",
|
||
"def modsub(a: int, b: int) -> int:\n",
|
||
" d = a - b\n",
|
||
" if d < 0:\n",
|
||
" d += Q\n",
|
||
" return d\n",
|
||
"\n",
|
||
"def mul_mod(a: int, b: int) -> int:\n",
|
||
" return mod_q(a * b)\n",
|
||
"\n",
|
||
"def ntt_512_py(a, invert: bool = False):\n",
|
||
" \"\"\"\n",
|
||
" In-place iterative radix-2 NTT of size 512, pure Python.\n",
|
||
"\n",
|
||
" a: iterable of length 512, entries in Z_Q.\n",
|
||
" invert = False: forward NTT\n",
|
||
" invert = True: inverse NTT (includes multiply by 512^{-1} mod Q)\n",
|
||
" \"\"\"\n",
|
||
" # Copy into a plain Python list of ints\n",
|
||
" a = [mod_q(int(x)) for x in a]\n",
|
||
" n2 = N2\n",
|
||
"\n",
|
||
" # Bit-reversal permutation\n",
|
||
" j = 0\n",
|
||
" for i in range(1, n2):\n",
|
||
" bit = n2 >> 1\n",
|
||
" while j & bit:\n",
|
||
" j ^= bit\n",
|
||
" bit >>= 1\n",
|
||
" j ^= bit\n",
|
||
" if i < j:\n",
|
||
" a[i], a[j] = a[j], a[i]\n",
|
||
"\n",
|
||
" length = 2\n",
|
||
" stage = 0\n",
|
||
" while length <= n2:\n",
|
||
" wlen = NTT_WLEN_INV[stage] if invert else NTT_WLEN[stage]\n",
|
||
" half = length >> 1\n",
|
||
"\n",
|
||
" for i in range(0, n2, length):\n",
|
||
" w = 1\n",
|
||
" for k in range(half):\n",
|
||
" u = a[i + k]\n",
|
||
" v = mul_mod(a[i + k + half], w)\n",
|
||
" a[i + k] = modadd(u, v)\n",
|
||
" a[i + k + half] = modsub(u, v)\n",
|
||
" w = mul_mod(w, wlen)\n",
|
||
"\n",
|
||
" length <<= 1\n",
|
||
" stage += 1\n",
|
||
"\n",
|
||
" if invert:\n",
|
||
" # Multiply by 512^{-1} mod Q\n",
|
||
" for i in range(n2):\n",
|
||
" a[i] = mul_mod(a[i], INV_NTT512)\n",
|
||
"\n",
|
||
" return a\n",
|
||
"\n",
|
||
"def poly_mul_sw_ntt(a, b):\n",
|
||
" \"\"\"\n",
|
||
" SW NTT-based polynomial multiplication:\n",
|
||
" c(X) = a(X) * b(X) mod (X^256 + 1, Q)\n",
|
||
" using the same 512-pt NTT+fold as the HLS core.\n",
|
||
"\n",
|
||
" a, b: 1D arrays of length N (can be centered ints).\n",
|
||
" Returns: length-N numpy array, centered mod Q.\n",
|
||
" \"\"\"\n",
|
||
" a = np.asarray(a, dtype=np.int64)\n",
|
||
" b = np.asarray(b, dtype=np.int64)\n",
|
||
" assert a.shape == (N,) and b.shape == (N,)\n",
|
||
"\n",
|
||
" # Map inputs into [0, Q)\n",
|
||
" A = [mod_q(int(x)) for x in a] + [0] * (N2 - N)\n",
|
||
" B = [mod_q(int(x)) for x in b] + [0] * (N2 - N)\n",
|
||
"\n",
|
||
" # Forward NTT\n",
|
||
" A = ntt_512_py(A, invert=False)\n",
|
||
" B = ntt_512_py(B, invert=False)\n",
|
||
"\n",
|
||
" # Pointwise multiply in NTT domain\n",
|
||
" C = [mul_mod(A[i], B[i]) for i in range(N2)]\n",
|
||
"\n",
|
||
" # Inverse NTT\n",
|
||
" C = ntt_512_py(C, invert=True)\n",
|
||
"\n",
|
||
" # Negacyclic fold: c[k] = C[k] - C[k+N] mod Q\n",
|
||
" c = np.empty(N, dtype=np.int64)\n",
|
||
" for k in range(N):\n",
|
||
" val = modsub(C[k], C[k + N])\n",
|
||
" # Use your existing centered representation helper if you have it\n",
|
||
" # Otherwise: center in (-Q/2, Q/2]\n",
|
||
" if 'center_mod_q' in globals():\n",
|
||
" c[k] = center_mod_q(val)\n",
|
||
" else:\n",
|
||
" # fallback centered mapping\n",
|
||
" r = val % Q\n",
|
||
" if r > Q // 2:\n",
|
||
" r -= Q\n",
|
||
" c[k] = r\n",
|
||
"\n",
|
||
" return c\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "12e59835",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"id": "f676296e",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"liboqs available, using algorithm: ML-DSA-44\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# oho F_dil – liboqs setup for Dilithium\n",
|
||
"\n",
|
||
"HAS_OQS = False\n",
|
||
"DILITHIUM_ALG = None\n",
|
||
"oqs_signer = None\n",
|
||
"oqs_pk = None\n",
|
||
"\n",
|
||
"try:\n",
|
||
" import oqs\n",
|
||
"\n",
|
||
" enabled = oqs.get_enabled_sig_mechanisms()\n",
|
||
" # Prefer ML-DSA-44 if present, else fall back to Dilithium2\n",
|
||
" if \"ML-DSA-44\" in enabled:\n",
|
||
" DILITHIUM_ALG = \"ML-DSA-44\"\n",
|
||
" elif \"Dilithium2\" in enabled:\n",
|
||
" DILITHIUM_ALG = \"Dilithium2\"\n",
|
||
" else:\n",
|
||
" print(\"No Dilithium / ML-DSA scheme enabled in liboqs; skipping liboqs benchmarks.\")\n",
|
||
" if DILITHIUM_ALG is not None:\n",
|
||
" oqs_signer = oqs.Signature(DILITHIUM_ALG)\n",
|
||
" oqs_pk = oqs_signer.generate_keypair() # one keypair reused for all timings\n",
|
||
" # sk = oqs_signer.export_secret_key() # not needed explicitly here\n",
|
||
" HAS_OQS = True\n",
|
||
" print(\"liboqs available, using algorithm:\", DILITHIUM_ALG)\n",
|
||
"except ImportError:\n",
|
||
" print(\"liboqs-python (oqs) not installed; skipping liboqs benchmarks.\")\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "91689702",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"id": "4aaa17d9",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Naive vs SW NTT mismatches: 0\n",
|
||
"Naive vs HW mismatches: 0\n",
|
||
"t_naive = 0.286735 s\n",
|
||
"t_sw_ntt = 0.081213 s\n",
|
||
"t_hw_ntt = 0.008756 s\n",
|
||
"Speedup naive / SW NTT: 3.53 x\n",
|
||
"Speedup SW NTT / HW NTT: 9.28 x\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# oho E_dil\n",
|
||
"# Compare SW naive vs SW NTT vs HW NTT for Dilithium ring multiplication.\n",
|
||
"\n",
|
||
"import time\n",
|
||
"from numpy.random import default_rng\n",
|
||
"\n",
|
||
"rng = default_rng(12345)\n",
|
||
"\n",
|
||
"# Alias your existing naive golden for clarity\n",
|
||
"def poly_mul_sw_naive(a, b):\n",
|
||
" # This assumes you already have sw_poly_mult(a,b) as the O(N^2) golden.\n",
|
||
" return sw_poly_mult(a, b)\n",
|
||
"\n",
|
||
"def benchmark_once():\n",
|
||
" # Random coefficients in [-Q//2, Q//2) (numpy's high bound is exclusive)\n",
|
||
" a = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
|
||
" b = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
|
||
"\n",
|
||
" # 1) Naive SW (O(N^2))\n",
|
||
" t0 = time.perf_counter()\n",
|
||
" c_naive = poly_mul_sw_naive(a, b)\n",
|
||
" t1 = time.perf_counter()\n",
|
||
" t_naive = t1 - t0\n",
|
||
"\n",
|
||
" # 2) SW NTT (same algorithm as HW, but in Python)\n",
|
||
" t0 = time.perf_counter()\n",
|
||
" c_sw_ntt = poly_mul_sw_ntt(a, b)\n",
|
||
" t1 = time.perf_counter()\n",
|
||
" t_sw_ntt = t1 - t0\n",
|
||
"\n",
|
||
" # 3) HW NTT (HLS core via DMA)\n",
|
||
" t0 = time.perf_counter()\n",
|
||
" c_hw = hw_poly_mult(a, b)\n",
|
||
" t1 = time.perf_counter()\n",
|
||
" t_hw = t1 - t0\n",
|
||
"\n",
|
||
" # Correctness checks vs naive, modulo Q\n",
|
||
" diff_sw_ntt = (c_naive - c_sw_ntt) % Q\n",
|
||
" diff_hw = (c_naive - c_hw) % Q\n",
|
||
"\n",
|
||
" print(\"Naive vs SW NTT mismatches:\",\n",
|
||
" int(np.count_nonzero(diff_sw_ntt)))\n",
|
||
" print(\"Naive vs HW mismatches:\",\n",
|
||
" int(np.count_nonzero(diff_hw)))\n",
|
||
"\n",
|
||
" print(f\"t_naive = {t_naive:.6f} s\")\n",
|
||
" print(f\"t_sw_ntt = {t_sw_ntt:.6f} s\")\n",
|
||
" print(f\"t_hw_ntt = {t_hw:.6f} s\")\n",
|
||
"\n",
|
||
" if t_sw_ntt > 0 and t_hw > 0:\n",
|
||
" print(\"Speedup naive / SW NTT: %.2f x\" % (t_naive / t_sw_ntt))\n",
|
||
" print(\"Speedup SW NTT / HW NTT: %.2f x\" % (t_sw_ntt / t_hw))\n",
|
||
"\n",
|
||
"benchmark_once()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d500e63b",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"id": "bee9cd02",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Naive vs SW NTT mismatches: 0\n",
|
||
"Naive vs HW mismatches: 0\n",
|
||
"t_naive = 0.293554 s\n",
|
||
"t_sw_ntt = 0.081380 s\n",
|
||
"t_hw_ntt = 0.008733 s\n",
|
||
"Speedup naive / SW NTT: 3.61 x\n",
|
||
"Speedup SW NTT / HW NTT: 9.32 x\n",
|
||
"liboqs Dilithium sign+verify OK? True, time: 0.002986 s\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# oho G_dil – benchmark: SW naive vs SW NTT vs HW NTT vs liboqs Dilithium\n",
|
||
"\n",
|
||
"import time\n",
|
||
"from numpy.random import default_rng\n",
|
||
"\n",
|
||
"rng = default_rng(12345)\n",
|
||
"\n",
|
||
"# Alias: your existing naive golden\n",
|
||
"def poly_mul_sw_naive(a, b):\n",
|
||
" # This assumes sw_poly_mult(a,b) is the O(N^2) negacyclic multiply you already defined.\n",
|
||
" return sw_poly_mult(a, b)\n",
|
||
"\n",
|
||
"def benchmark_once():\n",
|
||
" # Random coefficients in [-Q//2, Q//2) (numpy's high bound is exclusive)\n",
|
||
" a = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
|
||
" b = rng.integers(low=-Q//2, high=Q//2, size=N, dtype=np.int64)\n",
|
||
"\n",
|
||
" # 1) Naive SW (O(N^2))\n",
|
||
" t0 = time.perf_counter()\n",
|
||
" c_naive = poly_mul_sw_naive(a, b)\n",
|
||
" t1 = time.perf_counter()\n",
|
||
" t_naive = t1 - t0\n",
|
||
"\n",
|
||
" # 2) SW NTT (same algorithm as HW, but in Python)\n",
|
||
" t0 = time.perf_counter()\n",
|
||
" c_sw_ntt = poly_mul_sw_ntt(a, b)\n",
|
||
" t1 = time.perf_counter()\n",
|
||
" t_sw_ntt = t1 - t0\n",
|
||
"\n",
|
||
" # 3) HW NTT (HLS core via DMA)\n",
|
||
" t0 = time.perf_counter()\n",
|
||
" c_hw = hw_poly_mult(a, b)\n",
|
||
" t1 = time.perf_counter()\n",
|
||
" t_hw = t1 - t0\n",
|
||
"\n",
|
||
" # Correctness checks vs naive, modulo Q\n",
|
||
" diff_sw_ntt = (c_naive - c_sw_ntt) % Q\n",
|
||
" diff_hw = (c_naive - c_hw) % Q\n",
|
||
"\n",
|
||
" print(\"Naive vs SW NTT mismatches:\",\n",
|
||
" int(np.count_nonzero(diff_sw_ntt)))\n",
|
||
" print(\"Naive vs HW mismatches:\",\n",
|
||
" int(np.count_nonzero(diff_hw)))\n",
|
||
"\n",
|
||
" print(f\"t_naive = {t_naive:.6f} s\")\n",
|
||
" print(f\"t_sw_ntt = {t_sw_ntt:.6f} s\")\n",
|
||
" print(f\"t_hw_ntt = {t_hw:.6f} s\")\n",
|
||
"\n",
|
||
" if t_sw_ntt > 0 and t_hw > 0:\n",
|
||
" print(\"Speedup naive / SW NTT: %.2f x\" % (t_naive / t_sw_ntt))\n",
|
||
" print(\"Speedup SW NTT / HW NTT: %.2f x\" % (t_sw_ntt / t_hw))\n",
|
||
"\n",
|
||
" # 4) liboqs Dilithium sign+verify (full scheme, not just one poly mult)\n",
|
||
" if HAS_OQS and oqs_signer is not None:\n",
|
||
" # Use some message derived from the polynomials just to keep it tied together\n",
|
||
" msg = a.tobytes()[:32] # first 32 bytes as a \"message\"\n",
|
||
" t0 = time.perf_counter()\n",
|
||
" sig = oqs_signer.sign(msg)\n",
|
||
" ok = oqs_signer.verify(msg, sig, oqs_pk)\n",
|
||
" t1 = time.perf_counter()\n",
|
||
" t_oqs = t1 - t0\n",
|
||
" print(f\"liboqs Dilithium sign+verify OK? {ok}, time: {t_oqs:.6f} s\")\n",
|
||
" else:\n",
|
||
" print(\"liboqs Dilithium not available in this environment.\")\n",
|
||
"\n",
|
||
"benchmark_once()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "0ffbb055",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "1f592b8e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "d26c9518",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#########################################################\n",
|
||
"# Fair comparison\n",
|
||
"#########################################################"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"id": "963d4b71",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# oho H_dil – simple R_q vector/matrix utilities and pluggable poly_mul backend\n",
|
||
"\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"Q = 8380417\n",
|
||
"N = 256\n",
|
||
"\n",
|
||
"def poly_mul_naive(a, b):\n",
|
||
" # Your existing O(N^2) golden\n",
|
||
" return sw_poly_mult(a, b)\n",
|
||
"\n",
|
||
"def poly_mul_sw_ntt_backend(a, b):\n",
|
||
" # Your Python NTT-based multiply from the previous cells\n",
|
||
" return poly_mul_sw_ntt(a, b)\n",
|
||
"\n",
|
||
"def poly_mul_hw_backend(a, b):\n",
|
||
" # Your FPGA-based multiply via DMA\n",
|
||
" return hw_poly_mult(a, b)\n",
|
||
"\n",
|
||
"def poly_add(a, b):\n",
|
||
" a = np.asarray(a, dtype=np.int64)\n",
|
||
" b = np.asarray(b, dtype=np.int64)\n",
|
||
" c = (a + b) % Q\n",
|
||
" # Always return centered representatives via center_mod_q\n",
|
||
" return np.array([center_mod_q(int(x)) for x in c], dtype=np.int64)\n",
|
||
"\n",
|
||
"def poly_sub(a, b):\n",
|
||
" a = np.asarray(a, dtype=np.int64)\n",
|
||
" b = np.asarray(b, dtype=np.int64)\n",
|
||
" c = (a - b) % Q\n",
|
||
" return np.array([center_mod_q(int(x)) for x in c], dtype=np.int64)\n",
|
||
"\n",
|
||
"def sample_poly_small(eta=2):\n",
|
||
" # Very crude: coefficients uniform in [-eta, eta]\n",
|
||
" return np.random.randint(-eta, eta+1, size=N, dtype=np.int64)\n",
|
||
"\n",
|
||
"def sample_poly_uniform():\n",
|
||
" # Uniform in [0, Q)\n",
|
||
" return np.random.randint(0, Q, size=N, dtype=np.int64)\n",
|
||
"\n",
|
||
"def vec_poly_add(v1, v2):\n",
|
||
" assert len(v1) == len(v2)\n",
|
||
" return [poly_add(v1[i], v2[i]) for i in range(len(v1))]\n",
|
||
"\n",
|
||
"def vec_poly_sub(v1, v2):\n",
|
||
" assert len(v1) == len(v2)\n",
|
||
" return [poly_sub(v1[i], v2[i]) for i in range(len(v1))]\n",
|
||
"\n",
|
||
"def mat_poly_mul_vec(A, x, poly_mul_func):\n",
|
||
" \"\"\"\n",
|
||
" A: list of k rows, each row is list of l polynomials (shape k x l).\n",
|
||
" x: list of l polynomials (length l).\n",
|
||
" poly_mul_func: function(a,b) -> polynomial of length N.\n",
|
||
" Returns: list of k polynomials (A * x).\n",
|
||
" \"\"\"\n",
|
||
" k = len(A)\n",
|
||
" l = len(A[0])\n",
|
||
" assert len(x) == l\n",
|
||
" result = []\n",
|
||
" for i in range(k):\n",
|
||
" acc = np.zeros(N, dtype=np.int64)\n",
|
||
" for j in range(l):\n",
|
||
" acc = poly_add(acc, poly_mul_func(A[i][j], x[j]))\n",
|
||
" result.append(acc)\n",
|
||
" return result\n",
|
||
"\n",
|
||
"def vec_poly_mul_scalar(v, c_poly, poly_mul_func):\n",
|
||
" \"\"\"\n",
|
||
" v: list of polynomials\n",
|
||
" c_poly: polynomial\n",
|
||
" returns: [v[i] * c_poly] via poly_mul_func\n",
|
||
" \"\"\"\n",
|
||
" return [poly_mul_func(v[i], c_poly) for i in range(len(v))]\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"id": "425937d6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# oho I_dil – toy Dilithium-style keygen / sign / verify with pluggable poly_mul\n",
"\n",
"import hashlib\n",
"\n",
"import numpy as np\n",
"\n",
"# Parameters mimicking Dilithium-II\n",
"K = 4  # rows of A\n",
"L = 4  # cols of A\n",
"\n",
"\n",
"def _toy_challenge(msg_bytes):\n",
"    \"\"\"\n",
"    Derive the toy challenge polynomial c (N coefficients in {-1, 0, 1}) from msg_bytes.\n",
"\n",
"    The whole message is hashed with SHA-256 and the digest seeds a *local* Generator,\n",
"    so every byte of the message affects c and the global np.random state (used for\n",
"    keys and y) is left untouched. NOT secure: it only binds the message into the flow.\n",
"    \"\"\"\n",
"    seed = int.from_bytes(hashlib.sha256(bytes(msg_bytes)).digest(), \"little\")\n",
"    return np.random.default_rng(seed).integers(-1, 2, size=N, dtype=np.int64)\n",
"\n",
"\n",
"def toy_dil_keygen(poly_mul_func):\n",
"    \"\"\"\n",
"    Keygen-like flow:\n",
"      - Sample A (K x L matrix of uniform polynomials)\n",
"      - Sample s1 (L small polys), s2 (K small polys)\n",
"      - t = A*s1 + s2\n",
"    Returns (pk, sk) with pk = (A, t) and sk = (s1, s2).\n",
"    \"\"\"\n",
"    A = [[sample_poly_uniform() for _ in range(L)] for _ in range(K)]\n",
"    s1 = [sample_poly_small() for _ in range(L)]\n",
"    s2 = [sample_poly_small() for _ in range(K)]\n",
"    t = vec_poly_add(mat_poly_mul_vec(A, s1, poly_mul_func), s2)\n",
"    pk = (A, t)\n",
"    sk = (s1, s2)\n",
"    return pk, sk\n",
"\n",
"\n",
"def toy_dil_sign(poly_mul_func, sk, pk, msg_bytes):\n",
"    \"\"\"\n",
"    Sign-like flow:\n",
"      - Sample y (L small polys)\n",
"      - w = A*y\n",
"      - c = _toy_challenge(msg_bytes)  (hash of the full message, local RNG)\n",
"      - z = y + c*s1\n",
"      - r = w - c*s2\n",
"    Returns a toy signature (z, c, r).\n",
"    \"\"\"\n",
"    A, _t = pk\n",
"    s1, s2 = sk\n",
"\n",
"    y = [sample_poly_small() for _ in range(L)]\n",
"    w = mat_poly_mul_vec(A, y, poly_mul_func)\n",
"\n",
"    c_poly = _toy_challenge(msg_bytes)\n",
"\n",
"    z = vec_poly_add(y, vec_poly_mul_scalar(s1, c_poly, poly_mul_func))\n",
"    r = vec_poly_sub(w, vec_poly_mul_scalar(s2, c_poly, poly_mul_func))\n",
"    return (z, c_poly, r)\n",
"\n",
"\n",
"def toy_dil_verify(poly_mul_func, pk, msg_bytes, sig):\n",
"    \"\"\"\n",
"    Verify-like flow:\n",
"      - Recompute c from msg_bytes and reject if it differs from the signature's c\n",
"        (otherwise a signature would verify for any message).\n",
"      - Reconstruct w' = A*z - c*t and require w' == r (mod Q), one poly per row of A.\n",
"    Real Dilithium uses norm bounds and hints instead of exact equality.\n",
"    Returns True/False.\n",
"    \"\"\"\n",
"    A, t = pk\n",
"    z, c_poly, r = sig\n",
"\n",
"    # Message binding: the challenge must be the one derived from this message.\n",
"    expected_c = _toy_challenge(msg_bytes)\n",
"    if not np.array_equal(np.asarray(c_poly, dtype=np.int64) % Q, expected_c % Q):\n",
"        return False\n",
"\n",
"    # w' = A*z - c*t  (= A*y - c*s2 = r, mod Q, for an honest signature)\n",
"    Az = mat_poly_mul_vec(A, z, poly_mul_func)\n",
"    ct = vec_poly_mul_scalar(t, c_poly, poly_mul_func)\n",
"    w_prime = vec_poly_sub(Az, ct)\n",
"\n",
"    # zip() would silently ignore missing/extra rows, so check the count first.\n",
"    if len(w_prime) != len(r):\n",
"        return False\n",
"    return all(\n",
"        np.array_equal(wp % Q, np.asarray(rv, dtype=np.int64) % Q)\n",
"        for wp, rv in zip(w_prime, r)\n",
"    )\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"id": "8ea91187",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"=== Toy Dilithium-style flows (same harness, different poly_mul) ===\n",
|
||
"\n",
|
||
"[toy-naive] message: b'Test message for toy Dilithium'\n",
|
||
"[toy-naive] signature preview:\n",
|
||
" z[0]: [6, 12, 7, 17, ...]\n",
|
||
" c : [0, -1, 1, -1, ...]\n",
|
||
" r[0]: [-657352, 4065526, -3395326, 1460624, ...]\n",
|
||
"[toy-naive] keygen: 4.116052 s\n",
|
||
"[toy-naive] sign: 6.019188 s\n",
|
||
"[toy-naive] verify: 5.229953 s\n",
|
||
"[toy-naive] verify OK? True\n",
|
||
"\n",
|
||
"[toy-sw-ntt] message: b'Test message for toy Dilithium'\n",
|
||
"[toy-sw-ntt] signature preview:\n",
|
||
" z[0]: [6, 12, 7, 17, ...]\n",
|
||
" c : [0, -1, 1, -1, ...]\n",
|
||
" r[0]: [-657352, 4065526, -3395326, 1460624, ...]\n",
|
||
"[toy-sw-ntt] keygen: 1.292702 s\n",
|
||
"[toy-sw-ntt] sign: 1.921250 s\n",
|
||
"[toy-sw-ntt] verify: 1.606101 s\n",
|
||
"[toy-sw-ntt] verify OK? True\n",
|
||
"\n",
|
||
"[toy-hw-ntt] message: b'Test message for toy Dilithium'\n",
|
||
"[toy-hw-ntt] signature preview:\n",
|
||
" z[0]: [6, 12, 7, 17, ...]\n",
|
||
" c : [0, -1, 1, -1, ...]\n",
|
||
" r[0]: [-657352, 4065526, -3395326, 1460624, ...]\n",
|
||
"[toy-hw-ntt] keygen: 0.154008 s\n",
|
||
"[toy-hw-ntt] sign: 0.222962 s\n",
|
||
"[toy-hw-ntt] verify: 0.184250 s\n",
|
||
"[toy-hw-ntt] verify OK? True\n",
|
||
"\n",
|
||
"Speedups (toy flows, sign only):\n",
|
||
" naive / sw-ntt: 3.13 x\n",
|
||
" sw-ntt / hw-ntt: 8.62 x\n",
|
||
"\n",
|
||
"=== liboqs real Dilithium (C implementation) ===\n",
|
||
"\n",
|
||
"[liboqs] message: b'Test message for liboqs Dilithium'\n",
|
||
"[liboqs] signature preview: len=2420, first 16 bytes: 12954eb436346ddf827f047c121627f6\n",
|
||
"[liboqs] keygen: 0.001172 s\n",
|
||
"[liboqs] sign: 0.001966 s\n",
|
||
"[liboqs] verify: 0.000734 s\n",
|
||
"[liboqs] verify OK? True\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(0.0011719209996954305, 0.0019655719997899723, 0.0007336500002566027)"
|
||
]
|
||
},
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# oho J_dil – end-to-end benchmark: toy flows vs liboqs, with message & signature previews\n",
"\n",
"import time\n",
"import binascii\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def preview_poly(poly, n=4):\n",
"    \"\"\"Return a short string preview of the first n coefficients of a polynomial.\"\"\"\n",
"    poly = np.asarray(poly, dtype=np.int64)\n",
"    coeffs = \", \".join(str(int(x)) for x in poly[:n])\n",
"    return f\"[{coeffs}{', ...' if poly.size > n else ''}]\"\n",
"\n",
"\n",
"def preview_sig_toy(sig, n=4):\n",
"    \"\"\"Preview a toy signature (z, c_poly, r) as strings keyed 'z[0]', 'c' and 'r[0]'.\"\"\"\n",
"    z, c_poly, r = sig\n",
"    # z holds L polys and r holds K polys; fall back to zeros if either is empty.\n",
"    z0 = z[0] if len(z) > 0 else np.zeros(N, dtype=np.int64)\n",
"    r0 = r[0] if len(r) > 0 else np.zeros(N, dtype=np.int64)\n",
"    return {\n",
"        \"z[0]\": preview_poly(z0, n),\n",
"        \"c\": preview_poly(c_poly, n),\n",
"        \"r[0]\": preview_poly(r0, n),\n",
"    }\n",
"\n",
"\n",
"def preview_sig_liboqs(sig_bytes, n=16):\n",
"    \"\"\"Preview a liboqs signature: its length plus the first n bytes in hex.\"\"\"\n",
"    sig = bytes(sig_bytes)\n",
"    prefix = sig[:n]\n",
"    return f\"len={len(sig)}, first {n} bytes: {binascii.hexlify(prefix).decode()}\"\n",
"\n",
"\n",
"def run_toy_flow(poly_mul_func, label):\n",
"    \"\"\"\n",
"    Run and time toy keygen, sign and verify with one poly_mul backend.\n",
"\n",
"    Prints the message, a signature preview, per-phase timings and the verify verdict.\n",
"    Returns: (keygen_s, sign_s, verify_s) wall-clock seconds.\n",
"    \"\"\"\n",
"    msg = b\"Test message for toy Dilithium\"\n",
"    print(f\"\\n[{label}] message:\", msg)\n",
"\n",
"    t0 = time.perf_counter()\n",
"    pk, sk = toy_dil_keygen(poly_mul_func)\n",
"    t1 = time.perf_counter()\n",
"    sig = toy_dil_sign(poly_mul_func, sk, pk, msg)\n",
"    t2 = time.perf_counter()\n",
"    ok = toy_dil_verify(poly_mul_func, pk, msg, sig)\n",
"    t3 = time.perf_counter()\n",
"\n",
"    sig_preview = preview_sig_toy(sig, n=4)\n",
"    print(f\"[{label}] signature preview:\")\n",
"    print(f\"  z[0]: {sig_preview['z[0]']}\")\n",
"    print(f\"  c   : {sig_preview['c']}\")\n",
"    print(f\"  r[0]: {sig_preview['r[0]']}\")\n",
"\n",
"    print(f\"[{label}] keygen: {t1 - t0:.6f} s\")\n",
"    print(f\"[{label}] sign:   {t2 - t1:.6f} s\")\n",
"    print(f\"[{label}] verify: {t3 - t2:.6f} s\")\n",
"    print(f\"[{label}] verify OK? {ok}\")\n",
"    return (t1 - t0), (t2 - t1), (t3 - t2)\n",
"\n",
"\n",
"def run_liboqs_flow():\n",
"    \"\"\"\n",
"    Run and time real Dilithium (liboqs, C implementation) keygen, sign and verify.\n",
"\n",
"    Uses HAS_OQS, oqs_signer, oqs and DILITHIUM_ALG from earlier cells.\n",
"    Returns: (keygen_s, sign_s, verify_s), or None when liboqs is unavailable.\n",
"    \"\"\"\n",
"    if not HAS_OQS or oqs_signer is None:\n",
"        print(\"[liboqs] not available.\")\n",
"        return None\n",
"\n",
"    msg = b\"Test message for liboqs Dilithium\"\n",
"    print(\"\\n[liboqs] message:\", msg)\n",
"\n",
"    with oqs.Signature(DILITHIUM_ALG) as signer:\n",
"        # Start timing once the Signature object exists, so keygen measures only\n",
"        # generate_keypair(), comparable to the toy flows. sign() uses the key pair\n",
"        # held by `signer`, so no secret-key export is needed here.\n",
"        t0 = time.perf_counter()\n",
"        pk = signer.generate_keypair()\n",
"        t1 = time.perf_counter()\n",
"        sig = signer.sign(msg)\n",
"        t2 = time.perf_counter()\n",
"        ok = signer.verify(msg, sig, pk)\n",
"        t3 = time.perf_counter()\n",
"\n",
"    sig_preview = preview_sig_liboqs(sig, n=16)\n",
"    print(f\"[liboqs] signature preview: {sig_preview}\")\n",
"    print(f\"[liboqs] keygen: {t1 - t0:.6f} s\")\n",
"    print(f\"[liboqs] sign:   {t2 - t1:.6f} s\")\n",
"    print(f\"[liboqs] verify: {t3 - t2:.6f} s\")\n",
"    print(f\"[liboqs] verify OK? {ok}\")\n",
"    return (t1 - t0), (t2 - t1), (t3 - t2)\n",
"\n",
"\n",
"print(\"=== Toy Dilithium-style flows (same harness, different poly_mul) ===\")\n",
"\n",
"tkg_n, ts_n, tv_n = run_toy_flow(poly_mul_naive, \"toy-naive\")\n",
"tkg_s, ts_s, tv_s = run_toy_flow(poly_mul_sw_ntt_backend, \"toy-sw-ntt\")\n",
"tkg_h, ts_h, tv_h = run_toy_flow(poly_mul_hw_backend, \"toy-hw-ntt\")\n",
"\n",
"print(\"\\nSpeedups (toy flows, sign only):\")\n",
"print(\"  naive / sw-ntt:  %.2f x\" % (ts_n / ts_s))\n",
"print(\"  sw-ntt / hw-ntt: %.2f x\" % (ts_s / ts_h))\n",
"\n",
"print(\"\\n=== liboqs real Dilithium (C implementation) ===\")\n",
"# Bind the result instead of ending the cell on a bare call, so the returned\n",
"# timing tuple is not echoed as an extra Out[] value under the printed report.\n",
"liboqs_times = run_liboqs_flow()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "3071d384",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#######################\n",
|
||
"# Interoperability Test\n",
|
||
"#######################"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"id": "bacd163d",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[interop] generating pk, sk with naive backend\n",
|
||
"\n",
|
||
"[interop] sign=naive, verify=sw-ntt\n",
|
||
" sign backend: poly_mul_naive\n",
|
||
" verify backend: poly_mul_sw_ntt_backend\n",
|
||
" sign time: 6.004031 s\n",
|
||
" verify time: 1.602216 s\n",
|
||
" verify OK? True\n",
|
||
"\n",
|
||
"[interop] sign=sw-ntt, verify=naive\n",
|
||
" sign backend: poly_mul_sw_ntt_backend\n",
|
||
" verify backend: poly_mul_naive\n",
|
||
" sign time: 1.920178 s\n",
|
||
" verify time: 5.216772 s\n",
|
||
" verify OK? True\n",
|
||
"\n",
|
||
"[interop] sign=hw-ntt, verify=sw-ntt\n",
|
||
" sign backend: poly_mul_hw_backend\n",
|
||
" verify backend: poly_mul_sw_ntt_backend\n",
|
||
" sign time: 0.222985 s\n",
|
||
" verify time: 1.599958 s\n",
|
||
" verify OK? True\n",
|
||
"\n",
|
||
"[interop] sign=sw-ntt, verify=hw-ntt\n",
|
||
" sign backend: poly_mul_sw_ntt_backend\n",
|
||
" verify backend: poly_mul_hw_backend\n",
|
||
" sign time: 1.918987 s\n",
|
||
" verify time: 0.183846 s\n",
|
||
" verify OK? True\n",
|
||
"\n",
|
||
"[interop] sign=naive, verify=hw-ntt\n",
|
||
" sign backend: poly_mul_naive\n",
|
||
" verify backend: poly_mul_hw_backend\n",
|
||
" sign time: 5.994883 s\n",
|
||
" verify time: 0.183779 s\n",
|
||
" verify OK? True\n",
|
||
"\n",
|
||
"[interop] sign=hw-ntt, verify=naive\n",
|
||
" sign backend: poly_mul_hw_backend\n",
|
||
" verify backend: poly_mul_naive\n",
|
||
" sign time: 0.225005 s\n",
|
||
" verify time: 5.256011 s\n",
|
||
" verify OK? True\n",
|
||
"\n",
|
||
"[interop] summary:\n",
|
||
" naive ↔ sw-ntt: True\n",
|
||
" hw-ntt ↔ sw-ntt: True\n",
|
||
" hw-ntt ↔ naive: True\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# oho K_dil – interoperability tests between poly_mul backends\n",
"\n",
"import time\n",
"\n",
"\n",
"def run_interop_case(label, poly_mul_sign, poly_mul_verify, pk, sk, msg):\n",
"    \"\"\"\n",
"    Sign msg with one backend and verify the signature with another.\n",
"\n",
"    Prints the backend names and the sign/verify wall-clock times.\n",
"    Returns: the boolean verify verdict.\n",
"    \"\"\"\n",
"    print(f\"\\n[interop] {label}\")\n",
"    sign_start = time.perf_counter()\n",
"    signature = toy_dil_sign(poly_mul_sign, sk, pk, msg)\n",
"    verify_start = time.perf_counter()\n",
"    verdict = toy_dil_verify(poly_mul_verify, pk, msg, signature)\n",
"    verify_end = time.perf_counter()\n",
"\n",
"    print(f\"  sign backend:   {poly_mul_sign.__name__}\")\n",
"    print(f\"  verify backend: {poly_mul_verify.__name__}\")\n",
"    print(f\"  sign time:      {verify_start - sign_start:.6f} s\")\n",
"    print(f\"  verify time:    {verify_end - verify_start:.6f} s\")\n",
"    print(f\"  verify OK?      {verdict}\")\n",
"    return verdict\n",
"\n",
"\n",
"# One message and one keypair for every case, so only the backends differ.\n",
"msg = b\"Interop test message for toy Dilithium\"\n",
"print(\"[interop] generating pk, sk with naive backend\")\n",
"pk, sk = toy_dil_keygen(poly_mul_naive)\n",
"\n",
"# (label, sign backend, verify backend); each backend pair runs in both directions:\n",
"#   1) naive <-> sw-ntt   2) hw-ntt <-> sw-ntt   3) naive <-> hw-ntt\n",
"interop_cases = [\n",
"    (\"sign=naive, verify=sw-ntt\", poly_mul_naive, poly_mul_sw_ntt_backend),\n",
"    (\"sign=sw-ntt, verify=naive\", poly_mul_sw_ntt_backend, poly_mul_naive),\n",
"    (\"sign=hw-ntt, verify=sw-ntt\", poly_mul_hw_backend, poly_mul_sw_ntt_backend),\n",
"    (\"sign=sw-ntt, verify=hw-ntt\", poly_mul_sw_ntt_backend, poly_mul_hw_backend),\n",
"    (\"sign=naive, verify=hw-ntt\", poly_mul_naive, poly_mul_hw_backend),\n",
"    (\"sign=hw-ntt, verify=naive\", poly_mul_hw_backend, poly_mul_naive),\n",
"]\n",
"ok_1a, ok_1b, ok_2a, ok_2b, ok_3a, ok_3b = [\n",
"    run_interop_case(label, sign_fn, verify_fn, pk, sk, msg)\n",
"    for label, sign_fn, verify_fn in interop_cases\n",
"]\n",
"\n",
"print(\"\\n[interop] summary:\")\n",
"print(\"  naive ↔ sw-ntt:\", ok_1a and ok_1b)\n",
"print(\"  hw-ntt ↔ sw-ntt:\", ok_2a and ok_2b)\n",
"print(\"  hw-ntt ↔ naive:\", ok_3a and ok_3b)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "98b90350",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"######################################\n",
|
||
"# Throughput Testing\n",
|
||
"######################################"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"id": "509500db",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# oho L_dil – reusable-buffer variant of the HW poly multiply for throughput\n",
"\n",
"import numpy as np\n",
"from pynq import allocate\n",
"\n",
"\n",
"def make_hw_poly_mult_reusable():\n",
"    \"\"\"\n",
"    Build a multiply function that reuses one pair of preallocated DMA buffers.\n",
"\n",
"    Returns (hw_poly_mult_reuse, in_buf, out_buf):\n",
"      hw_poly_mult_reuse(a, b): length-N int64 result, each output word mapped\n",
"        through center_mod_q.\n",
"      in_buf:  (N,) uint64 buffer; word i = (b[i] mod Q) << 32 | (a[i] mod Q).\n",
"      out_buf: (N,) uint32 buffer that receives the core's output.\n",
"    Uses the globals mmio, dma and center_mod_q set up in earlier cells.\n",
"    \"\"\"\n",
"    in_buf = allocate(shape=(N,), dtype=np.uint64)\n",
"    out_buf = allocate(shape=(N,), dtype=np.uint32)\n",
"\n",
"    def hw_poly_mult_reuse(a, b):\n",
"        a = np.asarray(a, dtype=np.int64)\n",
"        b = np.asarray(b, dtype=np.int64)\n",
"        assert a.shape == (N,) and b.shape == (N,)\n",
"\n",
"        # Pack two 32-bit coeffs into one 64-bit word: low 32 bits = a[i], high = b[i].\n",
"        # Whole-array numpy ops replace a per-coefficient Python loop that otherwise\n",
"        # adds interpreter overhead to every measured HW call. numpy's integer %\n",
"        # returns values in [0, Q) even for negative inputs (like Python's %), and\n",
"        # Q < 2**32, so the 32-bit mask is implicit.\n",
"        a_u = (a % Q).astype(np.uint64)\n",
"        b_u = (b % Q).astype(np.uint64)\n",
"        in_buf[:] = (b_u << np.uint64(32)) | a_u\n",
"\n",
"        # Start hardware core (ap_start = 1)\n",
"        mmio.write(0x00, 0x01)\n",
"\n",
"        # Kick DMA transfers and wait for both channels\n",
"        dma.sendchannel.transfer(in_buf)\n",
"        dma.recvchannel.transfer(out_buf)\n",
"        dma.sendchannel.wait()\n",
"        dma.recvchannel.wait()\n",
"\n",
"        # Unpack into a fresh array (out_buf is reused by the next call).\n",
"        return np.fromiter((center_mod_q(int(v)) for v in out_buf), dtype=np.int64, count=N)\n",
"\n",
"    return hw_poly_mult_reuse, in_buf, out_buf\n",
"\n",
"\n",
"# Create a reusable HW function once\n",
"hw_poly_mult_reuse, hw_in_buf, hw_out_buf = make_hw_poly_mult_reusable()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"id": "f9b43699",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"=== Throughput benchmark with num_ops = 10 ===\n",
|
||
"[naive] total time: 2.871808 s, ops/s: 3.48\n",
|
||
"[sw-ntt] total time: 0.811113 s, ops/s: 12.33\n",
|
||
"[hw-ntt] total time: 0.076152 s, ops/s: 131.32\n",
|
||
"\n",
|
||
"Throughput ratios (ops/s):\n",
|
||
" sw-ntt / naive: 3.54 x\n",
|
||
" hw-ntt / sw-ntt: 10.65 x\n",
|
||
" hw-ntt / naive: 37.71 x\n",
|
||
"=== Throughput benchmark with num_ops = 100 ===\n",
|
||
"[naive] total time: 28.092749 s, ops/s: 3.56\n",
|
||
"[sw-ntt] total time: 8.087545 s, ops/s: 12.36\n",
|
||
"[hw-ntt] total time: 0.763630 s, ops/s: 130.95\n",
|
||
"\n",
|
||
"Throughput ratios (ops/s):\n",
|
||
" sw-ntt / naive: 3.47 x\n",
|
||
" hw-ntt / sw-ntt: 10.59 x\n",
|
||
" hw-ntt / naive: 36.79 x\n",
|
||
"=== Throughput benchmark with num_ops = 1000 ===\n",
|
||
"[naive] total time: 275.062409 s, ops/s: 3.64\n",
|
||
"[sw-ntt] total time: 77.406778 s, ops/s: 12.92\n",
|
||
"[hw-ntt] total time: 7.654999 s, ops/s: 130.63\n",
|
||
"\n",
|
||
"Throughput ratios (ops/s):\n",
|
||
" sw-ntt / naive: 3.55 x\n",
|
||
" hw-ntt / sw-ntt: 10.11 x\n",
|
||
" hw-ntt / naive: 35.93 x\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# oho M_dil – throughput benchmark: many independent multiplies\n",
"\n",
"import time\n",
"from numpy.random import default_rng\n",
"\n",
"rng = default_rng(123456)\n",
"\n",
"\n",
"def throughput_benchmark(num_ops=100):\n",
"    \"\"\"\n",
"    Time num_ops independent polynomial multiplies on each backend.\n",
"\n",
"    Naive SW, SW NTT and HW NTT (reusable DMA buffers) all see the same workload,\n",
"    drawn once from the module-level rng. Prints total time and ops/s per backend,\n",
"    then the pairwise throughput ratios.\n",
"    \"\"\"\n",
"    print(f\"=== Throughput benchmark with num_ops = {num_ops} ===\")\n",
"\n",
"    low, high = -Q // 2, Q // 2\n",
"    A_batch = rng.integers(low=low, high=high, size=(num_ops, N), dtype=np.int64)\n",
"    B_batch = rng.integers(low=low, high=high, size=(num_ops, N), dtype=np.int64)\n",
"\n",
"    def measure(label, mul):\n",
"        # Run every (a, b) pair through `mul`, print the totals, return ops/s.\n",
"        start = time.perf_counter()\n",
"        for a_poly, b_poly in zip(A_batch, B_batch):\n",
"            mul(a_poly, b_poly)\n",
"        elapsed = time.perf_counter() - start\n",
"        rate = num_ops / elapsed\n",
"        print(f\"[{label}] total time: {elapsed:.6f} s, ops/s: {rate:.2f}\")\n",
"        return rate\n",
"\n",
"    naive_rate = measure(\"naive\", poly_mul_sw_naive)      # O(N^2) software\n",
"    sw_ntt_rate = measure(\"sw-ntt\", poly_mul_sw_ntt)      # Python 512-NTT\n",
"    hw_rate = measure(\"hw-ntt\", hw_poly_mult_reuse)       # FPGA via reusable DMA buffers\n",
"\n",
"    print(\"\\nThroughput ratios (ops/s):\")\n",
"    print(\"  sw-ntt / naive:  %.2f x\" % (sw_ntt_rate / naive_rate))\n",
"    print(\"  hw-ntt / sw-ntt: %.2f x\" % (hw_rate / sw_ntt_rate))\n",
"    print(\"  hw-ntt / naive:  %.2f x\" % (hw_rate / naive_rate))\n",
"\n",
"\n",
"# Example: varying load\n",
"throughput_benchmark(num_ops=10)\n",
"throughput_benchmark(num_ops=100)\n",
"# Optional 1000-op run: the naive backend alone takes several minutes at this size.\n",
"#throughput_benchmark(num_ops=1000)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "61736ef2",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.4"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|