mirror of
https://github.com/saymrwulf/NTT-learning.git
synced 2026-07-03 03:37:34 +00:00
Replace early notebook plots with teaching players
This commit is contained in:
parent
74700b8457
commit
f152567e06
9 changed files with 462 additions and 24 deletions
|
|
@ -48,7 +48,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 3 | Prediction Check For n=4\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace\nfrom ntt_learning.visuals import interactive_trace\n\ntrace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)\ndisplay(interactive_trace(trace, title=\"Check your n=4 prediction\"))\n"
|
||||
"source": "# MANDATORY | difficulty 3 | Prediction Check For n=4\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace\n\ntrace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)\ndisplay(butterfly_story_player(trace))\ndisplay(interactive_trace(trace, title=\"Check your n=4 prediction\"))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 3 | Trace The Exact n=4 Paper Example\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi\nfrom ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [1, 2, 3, 4]\nmodulus = 7681\npsi = 1925\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"raw CT output (BO):\", trace.raw_output)\nprint(\"bit-reversed back to NO:\", trace.normal_order_output)\nprint(\"direct NTT_psi:\", forward_ntt_psi(signal, modulus, psi))\ndisplay(plot_trace_overview(trace, title=\"CT overview for [1,2,3,4]\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for [1,2,3,4]\"))\ndisplay(interactive_trace(trace, title=\"CT forward trace\"))\n"
|
||||
"source": "# MANDATORY | difficulty 3 | Trace The Exact n=4 Paper Example\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [1, 2, 3, 4]\nmodulus = 7681\npsi = 1925\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"raw CT output (BO):\", trace.raw_output)\nprint(\"bit-reversed back to NO:\", trace.normal_order_output)\nprint(\"direct NTT_psi:\", forward_ntt_psi(signal, modulus, psi))\ndisplay(butterfly_story_player(trace))\ndisplay(plot_trace_overview(trace, title=\"CT overview for [1,2,3,4]\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for [1,2,3,4]\"))\ndisplay(interactive_trace(trace, title=\"CT forward trace\"))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -102,7 +102,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 3 | Go One Stage Deeper With n=8\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi\nfrom ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [0, 1, 2, 3, 4, 5, 6, 7]\nmodulus = 97\npsi = find_psi(8, modulus)\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"psi:\", psi)\nprint(\"BO output:\", trace.raw_output)\nprint(\"NO output:\", trace.normal_order_output)\ndisplay(plot_trace_overview(trace, title=\"Three CT stages for n=8\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for n=8\"))\ndisplay(interactive_trace(trace, title=\"n=8 CT stage explorer\"))\n"
|
||||
"source": "# MANDATORY | difficulty 3 | Go One Stage Deeper With n=8\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [0, 1, 2, 3, 4, 5, 6, 7]\nmodulus = 97\npsi = find_psi(8, modulus)\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"psi:\", psi)\nprint(\"BO output:\", trace.raw_output)\nprint(\"NO output:\", trace.normal_order_output)\ndisplay(butterfly_story_player(trace))\ndisplay(plot_trace_overview(trace, title=\"Three CT stages for n=8\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for n=8\"))\ndisplay(interactive_trace(trace, title=\"n=8 CT stage explorer\"))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 2 | Run The Prediction Check\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, plot_wraparound\n\nleft = [3, 0, 2, 1]\nright = [1, 4, 0, 2]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic result:\", negacyclic_multiply(left, right, n=4))\ndisplay(plot_convolution_grid(left, right, title=\"Prediction check grid\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Prediction check fold\"))\n"
|
||||
"source": "# MANDATORY | difficulty 2 | Run The Prediction Check\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, plot_wraparound, schoolbook_diagonal_player, wraparound_comparison_player\n\nleft = [3, 0, 2, 1]\nright = [1, 4, 0, 2]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic result:\", negacyclic_multiply(left, right, n=4))\ndisplay(schoolbook_diagonal_player(left, right))\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_convolution_grid(left, right, title=\"Prediction check grid\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Prediction check fold\"))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -74,7 +74,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 2 | A Second Visual Drill\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound\n\nraw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])\nprint(\"raw convolution:\", raw)\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Trace one tail coefficient by eye\"))\n"
|
||||
"source": "# MANDATORY | difficulty 2 | A Second Visual Drill\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nraw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])\nprint(\"raw convolution:\", raw)\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Trace one tail coefficient by eye\"))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
|
|||
|
|
@ -44,11 +44,11 @@
|
|||
"role": "mandatory",
|
||||
"difficulty": 2,
|
||||
"kind": "demo",
|
||||
"title": "See The Product Grid And The Diagonal Sums"
|
||||
"title": "Play The Diagonal Sweep"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 2 | See The Product Grid And The Diagonal Sums\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nfor row in convolution_contributions(left, right):\n print(row)\n\nfig = plot_convolution_grid(left, right, title=\"Schoolbook products for [1,2,3,4] * [5,6,7,8]\")\ndisplay(fig)\n"
|
||||
"source": "# MANDATORY | difficulty 2 | Play The Diagonal Sweep\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, schoolbook_diagonal_player\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nfor row in convolution_contributions(left, right):\n print(row)\n\ndisplay(schoolbook_diagonal_player(left, right))\nfig = plot_convolution_grid(left, right, title=\"Schoolbook products for [1,2,3,4] * [5,6,7,8]\")\ndisplay(fig)\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -70,11 +70,11 @@
|
|||
"role": "mandatory",
|
||||
"difficulty": 2,
|
||||
"kind": "demo",
|
||||
"title": "See Cyclic And Negacyclic Folding Side By Side"
|
||||
"title": "Play The Wraparound Step By Step"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 2 | See Cyclic And Negacyclic Folding Side By Side\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions\nfrom ntt_learning.visuals import plot_wraparound\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic in x^4 + 1:\", negacyclic_multiply(left, right, n=4))\nprint(\"cyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=False):\n print(row)\nprint(\"negacyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=True):\n print(row)\n\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Cyclic folding into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negacyclic folding into x^4 + 1\"))\n"
|
||||
"source": "# MANDATORY | difficulty 2 | Play The Wraparound Step By Step\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic in x^4 + 1:\", negacyclic_multiply(left, right, n=4))\nprint(\"cyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=False):\n print(row)\nprint(\"negacyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=True):\n print(row)\n\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Cyclic folding into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negacyclic folding into x^4 + 1\"))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 3 | Compare The Two Fold Rules\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound\n\nraw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])\nprint(\"raw convolution:\", raw)\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Positive wrap into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negative wrap into x^4 + 1\"))\n"
|
||||
"source": "# MANDATORY | difficulty 3 | Compare The Two Fold Rules\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nraw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])\nprint(\"raw convolution:\", raw)\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Positive wrap into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negative wrap into x^4 + 1\"))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 3 | Interactive Signal Explorer\n\nimport ipywidgets as widgets\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\n\nmodulus = 17\npsi = find_psi(4, modulus)\n\ndef preview(a0=1, a1=2, a2=3, a3=4):\n signal = [a0, a1, a2, a3]\n spectrum = forward_ntt_psi(signal, modulus, psi)\n print(\"signal:\", signal)\n print(\"spectrum:\", spectrum)\n print(\"inverse:\", inverse_ntt_psi(spectrum, modulus, psi))\n\ndisplay(\n widgets.interact(\n preview,\n a0=(0, 16),\n a1=(0, 16),\n a2=(0, 16),\n a3=(0, 16),\n )\n)\n"
|
||||
"source": "# MANDATORY | difficulty 3 | Interactive Signal Explorer\n\nimport ipywidgets as widgets\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\nfrom ntt_learning.visuals import direct_ntt_player\n\nmodulus = 17\npsi = find_psi(4, modulus)\n\ndef preview(a0=1, a1=2, a2=3, a3=4):\n signal = [a0, a1, a2, a3]\n spectrum = forward_ntt_psi(signal, modulus, psi)\n print(\"signal:\", signal)\n print(\"spectrum:\", spectrum)\n print(\"inverse:\", inverse_ntt_psi(spectrum, modulus, psi))\n display(direct_ntt_player(signal, modulus, psi))\n\ndisplay(\n widgets.interact(\n preview,\n a0=(0, 16),\n a1=(0, 16),\n a2=(0, 16),\n a3=(0, 16),\n )\n)\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 2 | Inspect \u03c9, \u03c8, And The Direct Transform Matrix\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix\nfrom ntt_learning.visuals import plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap\n\nmodulus = 17\nn = 4\nomega = find_primitive_root(n, modulus)\npsi = find_psi(n, modulus)\n\nprint(\"omega:\", omega)\nprint(\"psi:\", psi)\nprint(\"exponent grid:\")\nfor row in ntt_psi_exponent_grid(n):\n print(row)\nprint(\"NTT_psi matrix:\")\nfor row in ntt_psi_matrix(n, modulus, psi):\n print(row)\n\ndisplay(plot_ntt_psi_exponent_heatmap(n, title=\"Exponents 2ij + i for n=4\"))\ndisplay(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title=\"Concrete NTT_psi matrix in Z_17\"))\n"
|
||||
"source": "# MANDATORY | difficulty 2 | Inspect \u03c9, \u03c8, And The Direct Transform Matrix\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix\nfrom ntt_learning.visuals import direct_ntt_player, plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap\n\nmodulus = 17\nn = 4\nomega = find_primitive_root(n, modulus)\npsi = find_psi(n, modulus)\n\nprint(\"omega:\", omega)\nprint(\"psi:\", psi)\nprint(\"exponent grid:\")\nfor row in ntt_psi_exponent_grid(n):\n print(row)\nprint(\"NTT_psi matrix:\")\nfor row in ntt_psi_matrix(n, modulus, psi):\n print(row)\n\ndisplay(direct_ntt_player([1, 2, 3, 4], modulus, psi))\ndisplay(plot_ntt_psi_exponent_heatmap(n, title=\"Exponents 2ij + i for n=4\"))\ndisplay(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title=\"Concrete NTT_psi matrix in Z_17\"))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
|
@ -74,7 +74,7 @@
|
|||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": "# MANDATORY | difficulty 2 | Run A Direct NTT\u03c8 / INTT\u03c8 Round Trip\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\n\nsignal = [1, 2, 3, 4]\nmodulus = 17\npsi = find_psi(len(signal), modulus)\nspectrum = forward_ntt_psi(signal, modulus, psi)\n\nprint(\"signal:\", signal)\nprint(\"spectrum:\", spectrum)\nprint(\"inverse recovery:\", inverse_ntt_psi(spectrum, modulus, psi))\n"
|
||||
"source": "# MANDATORY | difficulty 2 | Run A Direct NTT\u03c8 / INTT\u03c8 Round Trip\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\nfrom ntt_learning.visuals import direct_ntt_player\n\nsignal = [1, 2, 3, 4]\nmodulus = 17\npsi = find_psi(len(signal), modulus)\nspectrum = forward_ntt_psi(signal, modulus, psi)\n\nprint(\"signal:\", signal)\nprint(\"spectrum:\", spectrum)\nprint(\"inverse recovery:\", inverse_ntt_psi(spectrum, modulus, psi))\ndisplay(direct_ntt_player(signal, modulus, psi))\n"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from html import escape
|
||||
from typing import Sequence
|
||||
|
||||
import ipywidgets as widgets
|
||||
|
|
@ -16,6 +17,7 @@ from .toy_ntt import (
|
|||
TransformTrace,
|
||||
base_multiply_pair,
|
||||
bit_reversed_order,
|
||||
convolution_contributions,
|
||||
forward_ntt_psi,
|
||||
inverse_ntt_psi,
|
||||
ntt_psi_exponent_grid,
|
||||
|
|
@ -26,6 +28,426 @@ from .toy_ntt import (
|
|||
wraparound_contributions,
|
||||
)
|
||||
|
||||
SVG_BG = "#f7f3ea"
|
||||
SVG_PANEL = "#fffdf8"
|
||||
SVG_INK = "#1f2933"
|
||||
SVG_SOFT = "#d8e2dc"
|
||||
SVG_ACCENT = "#ef476f"
|
||||
SVG_HILITE = "#ffd166"
|
||||
SVG_GOOD = "#06d6a0"
|
||||
SVG_WARN = "#f08a5d"
|
||||
SVG_BLUE = "#118ab2"
|
||||
|
||||
|
||||
def _widget_chrome(title: str, subtitle: str, view: widgets.Widget) -> widgets.Widget:
|
||||
header = widgets.HTML(
|
||||
f"""
|
||||
<div style="
|
||||
background: linear-gradient(135deg, #f7ede2 0%, #f5cac3 45%, #84a59d 100%);
|
||||
color: #102a43;
|
||||
padding: 14px 16px;
|
||||
border-radius: 14px 14px 0 0;
|
||||
font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
|
||||
box-shadow: inset 0 0 0 1px rgba(16,42,67,0.12);
|
||||
">
|
||||
<div style="font-size: 18px; font-weight: 800;">{escape(title)}</div>
|
||||
<div style="font-size: 12px; margin-top: 4px;">{escape(subtitle)}</div>
|
||||
</div>
|
||||
"""
|
||||
)
|
||||
box = widgets.VBox(
|
||||
[header, view],
|
||||
layout=widgets.Layout(border="1px solid #d9d9d9", overflow="hidden"),
|
||||
)
|
||||
return box
|
||||
|
||||
|
||||
def _player_widget(
|
||||
*,
|
||||
title: str,
|
||||
subtitle: str,
|
||||
frames: Sequence[str],
|
||||
captions: Sequence[str],
|
||||
width: str = "100%",
|
||||
) -> widgets.Widget:
|
||||
if not frames:
|
||||
raise ValueError("player widget requires at least one frame")
|
||||
|
||||
play = widgets.Play(value=0, min=0, max=len(frames) - 1, step=1, interval=1100, description="Play")
|
||||
slider = widgets.IntSlider(
|
||||
value=0,
|
||||
min=0,
|
||||
max=len(frames) - 1,
|
||||
step=1,
|
||||
description="Step",
|
||||
continuous_update=False,
|
||||
layout=widgets.Layout(width="420px"),
|
||||
)
|
||||
widgets.jslink((play, "value"), (slider, "value"))
|
||||
|
||||
frame_html = widgets.HTML(layout=widgets.Layout(width=width))
|
||||
caption_html = widgets.HTML(layout=widgets.Layout(width=width))
|
||||
|
||||
def render(index: int) -> None:
|
||||
frame_html.value = frames[index]
|
||||
caption_html.value = f"""
|
||||
<div style="
|
||||
background: #fffdf8;
|
||||
border-top: 1px solid #e8dcc9;
|
||||
padding: 12px 16px;
|
||||
font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
|
||||
color: #243b53;
|
||||
font-size: 13px;
|
||||
line-height: 1.45;
|
||||
">
|
||||
<strong>Frame {index + 1} of {len(frames)}</strong><br>
|
||||
{escape(captions[index])}
|
||||
</div>
|
||||
"""
|
||||
|
||||
slider.observe(lambda change: render(change["new"]), names="value")
|
||||
render(0)
|
||||
|
||||
controls = widgets.HBox(
|
||||
[play, slider],
|
||||
layout=widgets.Layout(padding="10px 14px", align_items="center"),
|
||||
)
|
||||
return _widget_chrome(title, subtitle, widgets.VBox([controls, frame_html, caption_html]))
|
||||
|
||||
|
||||
def _svg_box(x: float, y: float, w: float, h: float, label: str, value: str, *, fill: str, stroke: str, stroke_width: float = 2.0) -> str:
|
||||
return f"""
|
||||
<rect x="{x}" y="{y}" width="{w}" height="{h}" rx="10" fill="{fill}" stroke="{stroke}" stroke-width="{stroke_width}"></rect>
|
||||
<text x="{x + w / 2}" y="{y + 18}" text-anchor="middle" font-size="11" font-family="Menlo, monospace" fill="{SVG_INK}">{escape(label)}</text>
|
||||
<text x="{x + w / 2}" y="{y + h / 2 + 10}" text-anchor="middle" font-size="22" font-weight="700" font-family="Menlo, monospace" fill="{SVG_INK}">{escape(value)}</text>
|
||||
"""
|
||||
|
||||
|
||||
def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "400", fill: str = SVG_INK, anchor: str = "start") -> str:
|
||||
return f'<text x="{x}" y="{y}" text-anchor="{anchor}" font-size="{size}" font-weight="{weight}" font-family="Avenir Next, Trebuchet MS, sans-serif" fill="{fill}">{escape(text)}</text>'
|
||||
|
||||
|
||||
def _convolution_frame_svg(
|
||||
left: Sequence[int],
|
||||
right: Sequence[int],
|
||||
diagonal_index: int,
|
||||
*,
|
||||
title: str,
|
||||
) -> str:
|
||||
rows = convolution_contributions(left, right)
|
||||
current = rows[diagonal_index]
|
||||
grid = pairwise_product_grid(left, right)
|
||||
cell = 68
|
||||
left_margin = 78
|
||||
top_margin = 78
|
||||
bottom_y = top_margin + len(left) * cell + 70
|
||||
width = left_margin + len(right) * cell + 120
|
||||
height = bottom_y + 110
|
||||
|
||||
parts = [
|
||||
f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
|
||||
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
|
||||
_svg_text(34, 48, title, size=22, weight="800"),
|
||||
_svg_text(34, 68, f"Diagonal {diagonal_index}: every highlighted cell lands in output coefficient y{diagonal_index}", size=13, fill="#486581"),
|
||||
]
|
||||
|
||||
for row in range(len(left)):
|
||||
parts.append(_svg_text(54, top_margin + row * cell + 42, f"a{row}", size=12, anchor="middle"))
|
||||
for col in range(len(right)):
|
||||
parts.append(_svg_text(left_margin + col * cell + 28, 64, f"b{col}", size=12, anchor="middle"))
|
||||
|
||||
for row, row_values in enumerate(grid):
|
||||
for col, value in enumerate(row_values):
|
||||
active = row + col == diagonal_index
|
||||
done = row + col < diagonal_index
|
||||
fill = SVG_HILITE if active else "#e6fffa" if done else "#f8f9fa"
|
||||
stroke = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9"
|
||||
parts.append(
|
||||
_svg_box(
|
||||
left_margin + col * cell,
|
||||
top_margin + row * cell,
|
||||
56,
|
||||
56,
|
||||
f"a{row}·b{col}",
|
||||
str(value),
|
||||
fill=fill,
|
||||
stroke=stroke,
|
||||
stroke_width=3.2 if active else 1.8,
|
||||
)
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
output_index = int(row["output_index"])
|
||||
fill = SVG_HILITE if output_index == diagonal_index else "#d9f0ff" if output_index < diagonal_index else "#f1f5f9"
|
||||
stroke = SVG_ACCENT if output_index == diagonal_index else SVG_BLUE if output_index < diagonal_index else "#cbd2d9"
|
||||
parts.append(
|
||||
_svg_box(
|
||||
left_margin + output_index * cell * 0.92,
|
||||
bottom_y,
|
||||
60,
|
||||
56,
|
||||
f"y{output_index}",
|
||||
str(row["total"]),
|
||||
fill=fill,
|
||||
stroke=stroke,
|
||||
stroke_width=3.2 if output_index == diagonal_index else 1.8,
|
||||
)
|
||||
)
|
||||
|
||||
terms = [f"{term['left_value']}×{term['right_value']}={term['product']}" for term in current["terms"]]
|
||||
equation = " + ".join(terms) if terms else "0"
|
||||
parts.append(_svg_text(34, height - 34, f"Current diagonal sum: {equation} = {current['total']}", size=15, weight="700"))
|
||||
parts.append("</svg>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def schoolbook_diagonal_player(left: Sequence[int], right: Sequence[int]) -> widgets.Widget:
|
||||
"""Interactive diagonal-by-diagonal walkthrough of schoolbook multiplication."""
|
||||
rows = convolution_contributions(left, right)
|
||||
frames = [
|
||||
_convolution_frame_svg(left, right, diagonal_index=index, title="Schoolbook Multiplication As A Moving Diagonal")
|
||||
for index in range(len(rows))
|
||||
]
|
||||
captions = [
|
||||
f"Only the highlighted products contribute to y{row['output_index']}. Watch the diagonal sweep across the grid instead of imagining the sum in your head."
|
||||
for row in rows
|
||||
]
|
||||
return _player_widget(
|
||||
title="Convolution Diagonal Player",
|
||||
subtitle="Press play. The highlighted diagonal is the coefficient being formed right now.",
|
||||
frames=frames,
|
||||
captions=captions,
|
||||
)
|
||||
|
||||
|
||||
def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> str:
|
||||
cyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=False)
|
||||
negacyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=True)
|
||||
flat = [(index, coefficient) for index, coefficient in enumerate(coefficients)]
|
||||
current_index, current_value = flat[step]
|
||||
cell = 58
|
||||
width = max(980, len(coefficients) * cell + 200)
|
||||
height = 390
|
||||
top_y = 78
|
||||
cyclic_y = 210
|
||||
neg_y = 300
|
||||
|
||||
parts = [
|
||||
f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
|
||||
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
|
||||
_svg_text(34, 48, "Wraparound Comparison Player", size=22, weight="800"),
|
||||
_svg_text(34, 68, f"Current source term: x^{current_index} with coefficient {current_value}", size=13, fill="#486581"),
|
||||
_svg_text(34, top_y - 14, "Raw convolution tail", size=14, weight="700"),
|
||||
_svg_text(34, cyclic_y - 14, "Cyclic fold into x^n - 1", size=14, weight="700", fill=SVG_BLUE),
|
||||
_svg_text(34, neg_y - 14, "Negacyclic fold into x^n + 1", size=14, weight="700", fill=SVG_ACCENT),
|
||||
]
|
||||
|
||||
for index, coefficient in enumerate(coefficients):
|
||||
fill = SVG_HILITE if index == current_index else "#f1f5f9" if index > current_index else "#d8f3dc"
|
||||
stroke = SVG_ACCENT if index == current_index else "#cbd2d9" if index > current_index else SVG_GOOD
|
||||
parts.append(_svg_box(48 + index * cell, top_y, 48, 48, f"x^{index}", str(coefficient), fill=fill, stroke=stroke))
|
||||
|
||||
for slot, row in enumerate(cyclic_rows):
|
||||
parts.append(_svg_box(48 + slot * cell, cyclic_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#e0fbfc", stroke=SVG_BLUE))
|
||||
for slot, row in enumerate(negacyclic_rows):
|
||||
parts.append(_svg_box(48 + slot * cell, neg_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#ffe8d6", stroke=SVG_ACCENT))
|
||||
|
||||
wraps, slot = divmod(current_index, n)
|
||||
cyclic_label = f"+ wrap {wraps}"
|
||||
neg_label = ("-" if wraps % 2 else "+") + f" wrap {wraps}"
|
||||
source_x = 72 + current_index * cell
|
||||
target_x = 72 + slot * cell
|
||||
|
||||
parts.append(f'<line x1="{source_x}" y1="{top_y + 48}" x2="{target_x}" y2="{cyclic_y}" stroke="{SVG_BLUE}" stroke-width="4" marker-end="url(#arrow-blue)"></line>')
|
||||
parts.append(f'<line x1="{source_x}" y1="{top_y + 48}" x2="{target_x}" y2="{neg_y}" stroke="{SVG_ACCENT}" stroke-width="4" marker-end="url(#arrow-red)"></line>')
|
||||
parts.append(f"""
|
||||
<defs>
|
||||
<marker id="arrow-blue" markerWidth="10" markerHeight="10" refX="7" refY="3" orient="auto">
|
||||
<polygon points="0 0, 8 3, 0 6" fill="{SVG_BLUE}"></polygon>
|
||||
</marker>
|
||||
<marker id="arrow-red" markerWidth="10" markerHeight="10" refX="7" refY="3" orient="auto">
|
||||
<polygon points="0 0, 8 3, 0 6" fill="{SVG_ACCENT}"></polygon>
|
||||
</marker>
|
||||
</defs>
|
||||
""")
|
||||
parts.append(_svg_text((source_x + target_x) / 2, cyclic_y - 18, cyclic_label, size=12, fill=SVG_BLUE, anchor="middle"))
|
||||
parts.append(_svg_text((source_x + target_x) / 2, neg_y - 18, neg_label, size=12, fill=SVG_ACCENT, anchor="middle"))
|
||||
parts.append("</svg>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def wraparound_comparison_player(coefficients: Sequence[int], n: int) -> widgets.Widget:
|
||||
"""Interactive comparison of cyclic and negacyclic folding, one source term at a time."""
|
||||
frames = [_wrap_compare_frame_svg(coefficients, n, step) for step in range(len(coefficients))]
|
||||
captions = []
|
||||
for index, coefficient in enumerate(coefficients):
|
||||
wraps, slot = divmod(index, n)
|
||||
neg_sign = "-" if wraps % 2 else "+"
|
||||
captions.append(
|
||||
f"x^{index} with coefficient {coefficient} lands in slot {slot}. Cyclic folding keeps a + sign; negacyclic folding uses {neg_sign} after {wraps} wrap(s)."
|
||||
)
|
||||
return _player_widget(
|
||||
title="Wraparound Step Player",
|
||||
subtitle="Play one raw coefficient at a time and compare x^n-1 with x^n+1 side by side.",
|
||||
frames=frames,
|
||||
captions=captions,
|
||||
)
|
||||
|
||||
|
||||
def _direct_ntt_frame_svg(values: Sequence[int], modulus: int, psi: int, output_index: int, input_index: int) -> str:
|
||||
exponents = ntt_psi_exponent_grid(len(values))
|
||||
matrix = ntt_psi_matrix(len(values), modulus, psi)
|
||||
cell = 62
|
||||
width = 1040
|
||||
height = 470
|
||||
grid_x = 260
|
||||
grid_y = 84
|
||||
boxes_y = 360
|
||||
contributions = []
|
||||
for row, value in enumerate(values):
|
||||
exponent = exponents[output_index][row]
|
||||
factor = matrix[output_index][row]
|
||||
product = (value * factor) % modulus
|
||||
contributions.append((exponent, factor, product))
|
||||
|
||||
partial = sum(product for _, _, product in contributions[: input_index + 1]) % modulus
|
||||
final = sum(product for _, _, product in contributions) % modulus
|
||||
|
||||
parts = [
|
||||
f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
|
||||
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
|
||||
_svg_text(34, 48, "Direct NTTψ Contribution Player", size=22, weight="800"),
|
||||
_svg_text(34, 68, f"Building output slot j={output_index}, currently consuming input i={input_index}", size=13, fill="#486581"),
|
||||
_svg_text(34, 108, "signal", size=14, weight="700"),
|
||||
_svg_text(grid_x, 68, "matrix cell = psi^(2ij + i)", size=14, weight="700"),
|
||||
]
|
||||
|
||||
for index, value in enumerate(values):
|
||||
fill = SVG_HILITE if index == input_index else "#d8f3dc" if index < input_index else "#f1f5f9"
|
||||
stroke = SVG_ACCENT if index == input_index else SVG_GOOD if index < input_index else "#cbd2d9"
|
||||
parts.append(_svg_box(36, 128 + index * 60, 92, 48, f"a{index}", str(value), fill=fill, stroke=stroke))
|
||||
|
||||
for row in range(len(values)):
|
||||
parts.append(_svg_text(grid_x - 18, grid_y + row * cell + 34, f"i={row}", size=12, anchor="end"))
|
||||
for col in range(len(values)):
|
||||
parts.append(_svg_text(grid_x + col * cell + 28, grid_y - 12, f"j={col}", size=12, anchor="middle"))
|
||||
|
||||
for col in range(len(values)):
|
||||
for row in range(len(values)):
|
||||
active = col == output_index and row == input_index
|
||||
done = col == output_index and row < input_index
|
||||
fill = SVG_HILITE if active else "#d8f3dc" if done else "#f1f5f9"
|
||||
stroke = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9"
|
||||
parts.append(
|
||||
_svg_box(
|
||||
grid_x + col * cell,
|
||||
grid_y + row * cell,
|
||||
56,
|
||||
56,
|
||||
f"{exponents[col][row]}",
|
||||
str(matrix[col][row]),
|
||||
fill=fill,
|
||||
stroke=stroke,
|
||||
stroke_width=3.0 if active else 1.6,
|
||||
)
|
||||
)
|
||||
|
||||
for index, (_, factor, product) in enumerate(contributions):
|
||||
fill = SVG_HILITE if index == input_index else "#d8f3dc" if index < input_index else "#f1f5f9"
|
||||
stroke = SVG_ACCENT if index == input_index else SVG_GOOD if index < input_index else "#cbd2d9"
|
||||
parts.append(_svg_box(520 + index * 90, boxes_y, 74, 54, f"a{index}·w", str(product), fill=fill, stroke=stroke))
|
||||
parts.append(_svg_text(556 + index * 90, boxes_y - 10, f"factor {factor}", size=11, anchor="middle"))
|
||||
|
||||
parts.append(_svg_text(34, height - 60, f"Current partial sum for j={output_index}: {partial} mod {modulus}", size=15, weight="700"))
|
||||
parts.append(_svg_text(34, height - 34, f"Completed output slot y{output_index}: {final} mod {modulus}", size=15, weight="700", fill=SVG_ACCENT))
|
||||
parts.append("</svg>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def direct_ntt_player(values: Sequence[int], modulus: int, psi: int) -> widgets.Widget:
|
||||
"""Interactive walkthrough of the direct NTT_psi matrix multiplication."""
|
||||
frames = []
|
||||
captions = []
|
||||
for output_index in range(len(values)):
|
||||
for input_index in range(len(values)):
|
||||
exponent = 2 * input_index * output_index + input_index
|
||||
factor = pow(psi, exponent, modulus)
|
||||
product = (values[input_index] * factor) % modulus
|
||||
frames.append(_direct_ntt_frame_svg(values, modulus, psi, output_index, input_index))
|
||||
captions.append(
|
||||
f"Output slot j={output_index}: multiply a{input_index}={values[input_index]} by psi^{exponent}={factor} to contribute {product} mod {modulus}."
|
||||
)
|
||||
return _player_widget(
|
||||
title="Direct NTTψ Player",
|
||||
subtitle="Press play and watch the transform build one contribution at a time.",
|
||||
frames=frames,
|
||||
captions=captions,
|
||||
)
|
||||
|
||||
|
||||
def _butterfly_story_frame_svg(trace: TransformTrace, stage_index: int, pair_index: int) -> str:
|
||||
stage = trace.stages[stage_index]
|
||||
left, right = stage.pairings[pair_index]
|
||||
zeta = stage.zetas[pair_index]
|
||||
width = max(820, len(stage.input_values) * 120)
|
||||
height = 340
|
||||
input_y = 110
|
||||
output_y = 240
|
||||
spacing = 92
|
||||
start_x = 60
|
||||
parts = [
|
||||
f'<svg viewBox="0 0 {width} {height}" width="100%" style="background:{SVG_BG}; border-radius:0 0 14px 14px;">',
|
||||
f'<rect x="18" y="18" width="{width - 36}" height="{height - 36}" rx="18" fill="{SVG_PANEL}" stroke="#e8dcc9" stroke-width="2"></rect>',
|
||||
_svg_text(34, 48, f"{trace.algorithm.upper()} Butterfly Player", size=22, weight="800"),
|
||||
_svg_text(34, 68, f"Stage {stage.stage_index}, active pair ({left}, {right}), zeta={zeta}", size=13, fill="#486581"),
|
||||
]
|
||||
|
||||
for index, value in enumerate(stage.input_values):
|
||||
active = index in (left, right)
|
||||
fill = SVG_HILITE if active else "#f1f5f9"
|
||||
stroke = SVG_ACCENT if active else "#cbd2d9"
|
||||
parts.append(_svg_box(start_x + index * spacing, input_y, 64, 52, f"in {index}", str(value), fill=fill, stroke=stroke, stroke_width=3.0 if active else 1.8))
|
||||
|
||||
for index, value in enumerate(stage.output_values):
|
||||
active = index in (left, right)
|
||||
fill = "#d8f3dc" if active else "#f1f5f9"
|
||||
stroke = SVG_GOOD if active else "#cbd2d9"
|
||||
parts.append(_svg_box(start_x + index * spacing, output_y, 64, 52, f"out {index}", str(value), fill=fill, stroke=stroke, stroke_width=3.0 if active else 1.8))
|
||||
|
||||
for index in range(len(stage.input_values)):
|
||||
x = start_x + index * spacing + 32
|
||||
color = SVG_ACCENT if index in (left, right) else "#d9d9d9"
|
||||
width_line = 4 if index in (left, right) else 1.8
|
||||
parts.append(f'<line x1="{x}" y1="{input_y + 52}" x2="{x}" y2="{output_y}" stroke="{color}" stroke-width="{width_line}"></line>')
|
||||
|
||||
x_left = start_x + left * spacing + 32
|
||||
x_right = start_x + right * spacing + 32
|
||||
parts.append(f'<line x1="{x_left}" y1="{input_y + 62}" x2="{x_right}" y2="{input_y + 62}" stroke="{SVG_ACCENT}" stroke-width="4"></line>')
|
||||
parts.append(_svg_text((x_left + x_right) / 2, 196, f"pair ({left}, {right})", size=12, weight="700", fill=SVG_ACCENT, anchor="middle"))
|
||||
parts.append(_svg_text((x_left + x_right) / 2, 214, f"inputs -> outputs = ({stage.input_values[left]}, {stage.input_values[right]}) -> ({stage.output_values[left]}, {stage.output_values[right]})", size=11, anchor="middle"))
|
||||
parts.append(_svg_text((x_left + x_right) / 2, 232, stage.note, size=11, anchor="middle", fill="#486581"))
|
||||
parts.append("</svg>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def butterfly_story_player(trace: TransformTrace) -> widgets.Widget:
|
||||
"""Interactive pair-by-pair walkthrough of a butterfly trace."""
|
||||
frames = []
|
||||
captions = []
|
||||
for stage_index, stage in enumerate(trace.stages):
|
||||
for pair_index, pair in enumerate(stage.pairings):
|
||||
left, right = pair
|
||||
frames.append(_butterfly_story_frame_svg(trace, stage_index, pair_index))
|
||||
captions.append(
|
||||
f"Stage {stage.stage_index}: the active pair is {pair}. Watch only these two wires. Everything else is parked until its own butterfly fires."
|
||||
)
|
||||
return _player_widget(
|
||||
title="Butterfly Pair Player",
|
||||
subtitle="Play pair by pair. This is the local machine the learner needs to internalize.",
|
||||
frames=frames,
|
||||
captions=captions,
|
||||
)
|
||||
|
||||
|
||||
def _value_colors(values: Sequence[int]) -> list[str]:
|
||||
colors = []
|
||||
|
|
|
|||
|
|
@ -455,12 +455,12 @@ def build_bundle_01() -> None:
|
|||
"mandatory",
|
||||
2,
|
||||
"demo",
|
||||
"See The Product Grid And The Diagonal Sums",
|
||||
"Play The Diagonal Sweep",
|
||||
"""
|
||||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution
|
||||
from ntt_learning.visuals import plot_convolution_grid
|
||||
from ntt_learning.visuals import plot_convolution_grid, schoolbook_diagonal_player
|
||||
|
||||
left = [1, 2, 3, 4]
|
||||
right = [5, 6, 7, 8]
|
||||
|
|
@ -470,6 +470,7 @@ def build_bundle_01() -> None:
|
|||
for row in convolution_contributions(left, right):
|
||||
print(row)
|
||||
|
||||
display(schoolbook_diagonal_player(left, right))
|
||||
fig = plot_convolution_grid(left, right, title="Schoolbook products for [1,2,3,4] * [5,6,7,8]")
|
||||
display(fig)
|
||||
""",
|
||||
|
|
@ -492,12 +493,12 @@ def build_bundle_01() -> None:
|
|||
"mandatory",
|
||||
2,
|
||||
"demo",
|
||||
"See Cyclic And Negacyclic Folding Side By Side",
|
||||
"Play The Wraparound Step By Step",
|
||||
"""
|
||||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions
|
||||
from ntt_learning.visuals import plot_wraparound
|
||||
from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
|
||||
|
||||
left = [1, 2, 3, 4]
|
||||
right = [5, 6, 7, 8]
|
||||
|
|
@ -512,6 +513,7 @@ def build_bundle_01() -> None:
|
|||
for row in wraparound_contributions(raw, n=4, negacyclic=True):
|
||||
print(row)
|
||||
|
||||
display(wraparound_comparison_player(raw, n=4))
|
||||
display(plot_wraparound(raw, n=4, negacyclic=False, title="Cyclic folding into x^4 - 1"))
|
||||
display(plot_wraparound(raw, n=4, negacyclic=True, title="Negacyclic folding into x^4 + 1"))
|
||||
""",
|
||||
|
|
@ -599,7 +601,7 @@ def build_bundle_01() -> None:
|
|||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution
|
||||
from ntt_learning.visuals import plot_convolution_grid, plot_wraparound
|
||||
from ntt_learning.visuals import plot_convolution_grid, plot_wraparound, schoolbook_diagonal_player, wraparound_comparison_player
|
||||
|
||||
left = [3, 0, 2, 1]
|
||||
right = [1, 4, 0, 2]
|
||||
|
|
@ -607,6 +609,8 @@ def build_bundle_01() -> None:
|
|||
|
||||
print("raw convolution:", raw)
|
||||
print("negacyclic result:", negacyclic_multiply(left, right, n=4))
|
||||
display(schoolbook_diagonal_player(left, right))
|
||||
display(wraparound_comparison_player(raw, n=4))
|
||||
display(plot_convolution_grid(left, right, title="Prediction check grid"))
|
||||
display(plot_wraparound(raw, n=4, negacyclic=True, title="Prediction check fold"))
|
||||
""",
|
||||
|
|
@ -636,10 +640,11 @@ def build_bundle_01() -> None:
|
|||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import schoolbook_convolution
|
||||
from ntt_learning.visuals import plot_wraparound
|
||||
from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
|
||||
|
||||
raw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])
|
||||
print("raw convolution:", raw)
|
||||
display(wraparound_comparison_player(raw, n=4))
|
||||
display(plot_wraparound(raw, n=4, negacyclic=True, title="Trace one tail coefficient by eye"))
|
||||
""",
|
||||
),
|
||||
|
|
@ -815,10 +820,11 @@ def build_bundle_01() -> None:
|
|||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import schoolbook_convolution
|
||||
from ntt_learning.visuals import plot_wraparound
|
||||
from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
|
||||
|
||||
raw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])
|
||||
print("raw convolution:", raw)
|
||||
display(wraparound_comparison_player(raw, n=4))
|
||||
display(plot_wraparound(raw, n=4, negacyclic=False, title="Positive wrap into x^4 - 1"))
|
||||
display(plot_wraparound(raw, n=4, negacyclic=True, title="Negative wrap into x^4 + 1"))
|
||||
""",
|
||||
|
|
@ -926,7 +932,7 @@ def build_bundle_02() -> None:
|
|||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix
|
||||
from ntt_learning.visuals import plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap
|
||||
from ntt_learning.visuals import direct_ntt_player, plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap
|
||||
|
||||
modulus = 17
|
||||
n = 4
|
||||
|
|
@ -942,6 +948,7 @@ def build_bundle_02() -> None:
|
|||
for row in ntt_psi_matrix(n, modulus, psi):
|
||||
print(row)
|
||||
|
||||
display(direct_ntt_player([1, 2, 3, 4], modulus, psi))
|
||||
display(plot_ntt_psi_exponent_heatmap(n, title="Exponents 2ij + i for n=4"))
|
||||
display(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title="Concrete NTT_psi matrix in Z_17"))
|
||||
""",
|
||||
|
|
@ -962,7 +969,10 @@ def build_bundle_02() -> None:
|
|||
"demo",
|
||||
"Run A Direct NTTψ / INTTψ Round Trip",
|
||||
"""
|
||||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi
|
||||
from ntt_learning.visuals import direct_ntt_player
|
||||
|
||||
signal = [1, 2, 3, 4]
|
||||
modulus = 17
|
||||
|
|
@ -972,6 +982,7 @@ def build_bundle_02() -> None:
|
|||
print("signal:", signal)
|
||||
print("spectrum:", spectrum)
|
||||
print("inverse recovery:", inverse_ntt_psi(spectrum, modulus, psi))
|
||||
display(direct_ntt_player(signal, modulus, psi))
|
||||
""",
|
||||
),
|
||||
code(
|
||||
|
|
@ -1079,6 +1090,7 @@ def build_bundle_02() -> None:
|
|||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi
|
||||
from ntt_learning.visuals import direct_ntt_player
|
||||
|
||||
modulus = 17
|
||||
psi = find_psi(4, modulus)
|
||||
|
|
@ -1089,6 +1101,7 @@ def build_bundle_02() -> None:
|
|||
print("signal:", signal)
|
||||
print("spectrum:", spectrum)
|
||||
print("inverse:", inverse_ntt_psi(spectrum, modulus, psi))
|
||||
display(direct_ntt_player(signal, modulus, psi))
|
||||
|
||||
display(
|
||||
widgets.interact(
|
||||
|
|
@ -1449,7 +1462,7 @@ def build_bundle_03() -> None:
|
|||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi
|
||||
from ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview
|
||||
from ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview
|
||||
|
||||
signal = [1, 2, 3, 4]
|
||||
modulus = 7681
|
||||
|
|
@ -1459,6 +1472,7 @@ def build_bundle_03() -> None:
|
|||
print("raw CT output (BO):", trace.raw_output)
|
||||
print("bit-reversed back to NO:", trace.normal_order_output)
|
||||
print("direct NTT_psi:", forward_ntt_psi(signal, modulus, psi))
|
||||
display(butterfly_story_player(trace))
|
||||
display(plot_trace_overview(trace, title="CT overview for [1,2,3,4]"))
|
||||
display(plot_butterfly_network(trace, title="Full CT network for [1,2,3,4]"))
|
||||
display(interactive_trace(trace, title="CT forward trace"))
|
||||
|
|
@ -1505,7 +1519,7 @@ def build_bundle_03() -> None:
|
|||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi
|
||||
from ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview
|
||||
from ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview
|
||||
|
||||
signal = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
modulus = 97
|
||||
|
|
@ -1515,6 +1529,7 @@ def build_bundle_03() -> None:
|
|||
print("psi:", psi)
|
||||
print("BO output:", trace.raw_output)
|
||||
print("NO output:", trace.normal_order_output)
|
||||
display(butterfly_story_player(trace))
|
||||
display(plot_trace_overview(trace, title="Three CT stages for n=8"))
|
||||
display(plot_butterfly_network(trace, title="Full CT network for n=8"))
|
||||
display(interactive_trace(trace, title="n=8 CT stage explorer"))
|
||||
|
|
@ -1604,9 +1619,10 @@ def build_bundle_03() -> None:
|
|||
from IPython.display import display
|
||||
|
||||
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace
|
||||
from ntt_learning.visuals import interactive_trace
|
||||
from ntt_learning.visuals import butterfly_story_player, interactive_trace
|
||||
|
||||
trace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)
|
||||
display(butterfly_story_player(trace))
|
||||
display(interactive_trace(trace, title="Check your n=4 prediction"))
|
||||
""",
|
||||
),
|
||||
|
|
|
|||
Loading…
Reference in a new issue