diff --git a/notebooks/butterfly_mechanics/03_fast_forward_ct/lab.ipynb b/notebooks/butterfly_mechanics/03_fast_forward_ct/lab.ipynb
index 8004e6c..6917614 100644
--- a/notebooks/butterfly_mechanics/03_fast_forward_ct/lab.ipynb
+++ b/notebooks/butterfly_mechanics/03_fast_forward_ct/lab.ipynb
@@ -48,7 +48,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 3 | Prediction Check For n=4\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace\nfrom ntt_learning.visuals import interactive_trace\n\ntrace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)\ndisplay(interactive_trace(trace, title=\"Check your n=4 prediction\"))\n"
+ "source": "# MANDATORY | difficulty 3 | Prediction Check For n=4\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace\n\ntrace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)\ndisplay(butterfly_story_player(trace))\ndisplay(interactive_trace(trace, title=\"Check your n=4 prediction\"))\n"
},
{
"cell_type": "markdown",
diff --git a/notebooks/butterfly_mechanics/03_fast_forward_ct/lecture.ipynb b/notebooks/butterfly_mechanics/03_fast_forward_ct/lecture.ipynb
index fab970f..50a6656 100644
--- a/notebooks/butterfly_mechanics/03_fast_forward_ct/lecture.ipynb
+++ b/notebooks/butterfly_mechanics/03_fast_forward_ct/lecture.ipynb
@@ -62,7 +62,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 3 | Trace The Exact n=4 Paper Example\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi\nfrom ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [1, 2, 3, 4]\nmodulus = 7681\npsi = 1925\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"raw CT output (BO):\", trace.raw_output)\nprint(\"bit-reversed back to NO:\", trace.normal_order_output)\nprint(\"direct NTT_psi:\", forward_ntt_psi(signal, modulus, psi))\ndisplay(plot_trace_overview(trace, title=\"CT overview for [1,2,3,4]\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for [1,2,3,4]\"))\ndisplay(interactive_trace(trace, title=\"CT forward trace\"))\n"
+ "source": "# MANDATORY | difficulty 3 | Trace The Exact n=4 Paper Example\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [1, 2, 3, 4]\nmodulus = 7681\npsi = 1925\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"raw CT output (BO):\", trace.raw_output)\nprint(\"bit-reversed back to NO:\", trace.normal_order_output)\nprint(\"direct NTT_psi:\", forward_ntt_psi(signal, modulus, psi))\ndisplay(butterfly_story_player(trace))\ndisplay(plot_trace_overview(trace, title=\"CT overview for [1,2,3,4]\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for [1,2,3,4]\"))\ndisplay(interactive_trace(trace, title=\"CT forward trace\"))\n"
},
{
"cell_type": "markdown",
@@ -102,7 +102,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 3 | Go One Stage Deeper With n=8\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi\nfrom ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [0, 1, 2, 3, 4, 5, 6, 7]\nmodulus = 97\npsi = find_psi(8, modulus)\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"psi:\", psi)\nprint(\"BO output:\", trace.raw_output)\nprint(\"NO output:\", trace.normal_order_output)\ndisplay(plot_trace_overview(trace, title=\"Three CT stages for n=8\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for n=8\"))\ndisplay(interactive_trace(trace, title=\"n=8 CT stage explorer\"))\n"
+ "source": "# MANDATORY | difficulty 3 | Go One Stage Deeper With n=8\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [0, 1, 2, 3, 4, 5, 6, 7]\nmodulus = 97\npsi = find_psi(8, modulus)\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"psi:\", psi)\nprint(\"BO output:\", trace.raw_output)\nprint(\"NO output:\", trace.normal_order_output)\ndisplay(butterfly_story_player(trace))\ndisplay(plot_trace_overview(trace, title=\"Three CT stages for n=8\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for n=8\"))\ndisplay(interactive_trace(trace, title=\"n=8 CT stage explorer\"))\n"
},
{
"cell_type": "code",
diff --git a/notebooks/foundations/01_convolution_to_toy_ntt/lab.ipynb b/notebooks/foundations/01_convolution_to_toy_ntt/lab.ipynb
index 326b210..2bc5b86 100644
--- a/notebooks/foundations/01_convolution_to_toy_ntt/lab.ipynb
+++ b/notebooks/foundations/01_convolution_to_toy_ntt/lab.ipynb
@@ -48,7 +48,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 2 | Run The Prediction Check\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, plot_wraparound\n\nleft = [3, 0, 2, 1]\nright = [1, 4, 0, 2]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic result:\", negacyclic_multiply(left, right, n=4))\ndisplay(plot_convolution_grid(left, right, title=\"Prediction check grid\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Prediction check fold\"))\n"
+ "source": "# MANDATORY | difficulty 2 | Run The Prediction Check\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, plot_wraparound, schoolbook_diagonal_player, wraparound_comparison_player\n\nleft = [3, 0, 2, 1]\nright = [1, 4, 0, 2]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic result:\", negacyclic_multiply(left, right, n=4))\ndisplay(schoolbook_diagonal_player(left, right))\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_convolution_grid(left, right, title=\"Prediction check grid\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Prediction check fold\"))\n"
},
{
"cell_type": "markdown",
@@ -74,7 +74,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 2 | A Second Visual Drill\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound\n\nraw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])\nprint(\"raw convolution:\", raw)\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Trace one tail coefficient by eye\"))\n"
+ "source": "# MANDATORY | difficulty 2 | A Second Visual Drill\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nraw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])\nprint(\"raw convolution:\", raw)\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Trace one tail coefficient by eye\"))\n"
},
{
"cell_type": "markdown",
diff --git a/notebooks/foundations/01_convolution_to_toy_ntt/lecture.ipynb b/notebooks/foundations/01_convolution_to_toy_ntt/lecture.ipynb
index d31bda9..ee04fa2 100644
--- a/notebooks/foundations/01_convolution_to_toy_ntt/lecture.ipynb
+++ b/notebooks/foundations/01_convolution_to_toy_ntt/lecture.ipynb
@@ -44,11 +44,11 @@
"role": "mandatory",
"difficulty": 2,
"kind": "demo",
- "title": "See The Product Grid And The Diagonal Sums"
+ "title": "Play The Diagonal Sweep"
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 2 | See The Product Grid And The Diagonal Sums\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nfor row in convolution_contributions(left, right):\n print(row)\n\nfig = plot_convolution_grid(left, right, title=\"Schoolbook products for [1,2,3,4] * [5,6,7,8]\")\ndisplay(fig)\n"
+ "source": "# MANDATORY | difficulty 2 | Play The Diagonal Sweep\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, schoolbook_diagonal_player\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nfor row in convolution_contributions(left, right):\n print(row)\n\ndisplay(schoolbook_diagonal_player(left, right))\nfig = plot_convolution_grid(left, right, title=\"Schoolbook products for [1,2,3,4] * [5,6,7,8]\")\ndisplay(fig)\n"
},
{
"cell_type": "markdown",
@@ -70,11 +70,11 @@
"role": "mandatory",
"difficulty": 2,
"kind": "demo",
- "title": "See Cyclic And Negacyclic Folding Side By Side"
+ "title": "Play The Wraparound Step By Step"
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 2 | See Cyclic And Negacyclic Folding Side By Side\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions\nfrom ntt_learning.visuals import plot_wraparound\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic in x^4 + 1:\", negacyclic_multiply(left, right, n=4))\nprint(\"cyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=False):\n print(row)\nprint(\"negacyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=True):\n print(row)\n\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Cyclic folding into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negacyclic folding into x^4 + 1\"))\n"
+ "source": "# MANDATORY | difficulty 2 | Play The Wraparound Step By Step\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic in x^4 + 1:\", negacyclic_multiply(left, right, n=4))\nprint(\"cyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=False):\n print(row)\nprint(\"negacyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=True):\n print(row)\n\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Cyclic folding into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negacyclic folding into x^4 + 1\"))\n"
},
{
"cell_type": "markdown",
diff --git a/notebooks/foundations/01_convolution_to_toy_ntt/studio.ipynb b/notebooks/foundations/01_convolution_to_toy_ntt/studio.ipynb
index 7fd3770..2a869c6 100644
--- a/notebooks/foundations/01_convolution_to_toy_ntt/studio.ipynb
+++ b/notebooks/foundations/01_convolution_to_toy_ntt/studio.ipynb
@@ -48,7 +48,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 3 | Compare The Two Fold Rules\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound\n\nraw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])\nprint(\"raw convolution:\", raw)\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Positive wrap into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negative wrap into x^4 + 1\"))\n"
+ "source": "# MANDATORY | difficulty 3 | Compare The Two Fold Rules\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nraw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])\nprint(\"raw convolution:\", raw)\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Positive wrap into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negative wrap into x^4 + 1\"))\n"
},
{
"cell_type": "markdown",
diff --git a/notebooks/foundations/02_negative_wrapped_ntt/lab.ipynb b/notebooks/foundations/02_negative_wrapped_ntt/lab.ipynb
index 405cd48..6a1def9 100644
--- a/notebooks/foundations/02_negative_wrapped_ntt/lab.ipynb
+++ b/notebooks/foundations/02_negative_wrapped_ntt/lab.ipynb
@@ -62,7 +62,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 3 | Interactive Signal Explorer\n\nimport ipywidgets as widgets\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\n\nmodulus = 17\npsi = find_psi(4, modulus)\n\ndef preview(a0=1, a1=2, a2=3, a3=4):\n signal = [a0, a1, a2, a3]\n spectrum = forward_ntt_psi(signal, modulus, psi)\n print(\"signal:\", signal)\n print(\"spectrum:\", spectrum)\n print(\"inverse:\", inverse_ntt_psi(spectrum, modulus, psi))\n\ndisplay(\n widgets.interact(\n preview,\n a0=(0, 16),\n a1=(0, 16),\n a2=(0, 16),\n a3=(0, 16),\n )\n)\n"
+ "source": "# MANDATORY | difficulty 3 | Interactive Signal Explorer\n\nimport ipywidgets as widgets\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\nfrom ntt_learning.visuals import direct_ntt_player\n\nmodulus = 17\npsi = find_psi(4, modulus)\n\ndef preview(a0=1, a1=2, a2=3, a3=4):\n signal = [a0, a1, a2, a3]\n spectrum = forward_ntt_psi(signal, modulus, psi)\n print(\"signal:\", signal)\n print(\"spectrum:\", spectrum)\n print(\"inverse:\", inverse_ntt_psi(spectrum, modulus, psi))\n display(direct_ntt_player(signal, modulus, psi))\n\ndisplay(\n widgets.interact(\n preview,\n a0=(0, 16),\n a1=(0, 16),\n a2=(0, 16),\n a3=(0, 16),\n )\n)\n"
},
{
"cell_type": "markdown",
diff --git a/notebooks/foundations/02_negative_wrapped_ntt/lecture.ipynb b/notebooks/foundations/02_negative_wrapped_ntt/lecture.ipynb
index fd54389..b9c23eb 100644
--- a/notebooks/foundations/02_negative_wrapped_ntt/lecture.ipynb
+++ b/notebooks/foundations/02_negative_wrapped_ntt/lecture.ipynb
@@ -48,7 +48,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 2 | Inspect \u03c9, \u03c8, And The Direct Transform Matrix\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix\nfrom ntt_learning.visuals import plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap\n\nmodulus = 17\nn = 4\nomega = find_primitive_root(n, modulus)\npsi = find_psi(n, modulus)\n\nprint(\"omega:\", omega)\nprint(\"psi:\", psi)\nprint(\"exponent grid:\")\nfor row in ntt_psi_exponent_grid(n):\n print(row)\nprint(\"NTT_psi matrix:\")\nfor row in ntt_psi_matrix(n, modulus, psi):\n print(row)\n\ndisplay(plot_ntt_psi_exponent_heatmap(n, title=\"Exponents 2ij + i for n=4\"))\ndisplay(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title=\"Concrete NTT_psi matrix in Z_17\"))\n"
+ "source": "# MANDATORY | difficulty 2 | Inspect \u03c9, \u03c8, And The Direct Transform Matrix\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix\nfrom ntt_learning.visuals import direct_ntt_player, plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap\n\nmodulus = 17\nn = 4\nomega = find_primitive_root(n, modulus)\npsi = find_psi(n, modulus)\n\nprint(\"omega:\", omega)\nprint(\"psi:\", psi)\nprint(\"exponent grid:\")\nfor row in ntt_psi_exponent_grid(n):\n print(row)\nprint(\"NTT_psi matrix:\")\nfor row in ntt_psi_matrix(n, modulus, psi):\n print(row)\n\ndisplay(direct_ntt_player([1, 2, 3, 4], modulus, psi))\ndisplay(plot_ntt_psi_exponent_heatmap(n, title=\"Exponents 2ij + i for n=4\"))\ndisplay(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title=\"Concrete NTT_psi matrix in Z_17\"))\n"
},
{
"cell_type": "markdown",
@@ -74,7 +74,7 @@
}
},
"outputs": [],
- "source": "# MANDATORY | difficulty 2 | Run A Direct NTT\u03c8 / INTT\u03c8 Round Trip\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\n\nsignal = [1, 2, 3, 4]\nmodulus = 17\npsi = find_psi(len(signal), modulus)\nspectrum = forward_ntt_psi(signal, modulus, psi)\n\nprint(\"signal:\", signal)\nprint(\"spectrum:\", spectrum)\nprint(\"inverse recovery:\", inverse_ntt_psi(spectrum, modulus, psi))\n"
+ "source": "# MANDATORY | difficulty 2 | Run A Direct NTT\u03c8 / INTT\u03c8 Round Trip\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\nfrom ntt_learning.visuals import direct_ntt_player\n\nsignal = [1, 2, 3, 4]\nmodulus = 17\npsi = find_psi(len(signal), modulus)\nspectrum = forward_ntt_psi(signal, modulus, psi)\n\nprint(\"signal:\", signal)\nprint(\"spectrum:\", spectrum)\nprint(\"inverse recovery:\", inverse_ntt_psi(spectrum, modulus, psi))\ndisplay(direct_ntt_player(signal, modulus, psi))\n"
},
{
"cell_type": "code",
diff --git a/ntt_learning/visuals.py b/ntt_learning/visuals.py
index 06da6f0..02e3b40 100644
--- a/ntt_learning/visuals.py
+++ b/ntt_learning/visuals.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+from html import escape
from typing import Sequence
import ipywidgets as widgets
@@ -16,6 +17,7 @@ from .toy_ntt import (
TransformTrace,
base_multiply_pair,
bit_reversed_order,
+ convolution_contributions,
forward_ntt_psi,
inverse_ntt_psi,
ntt_psi_exponent_grid,
@@ -26,6 +28,426 @@ from .toy_ntt import (
wraparound_contributions,
)
+SVG_BG = "#f7f3ea"
+SVG_PANEL = "#fffdf8"
+SVG_INK = "#1f2933"
+SVG_SOFT = "#d8e2dc"
+SVG_ACCENT = "#ef476f"
+SVG_HILITE = "#ffd166"
+SVG_GOOD = "#06d6a0"
+SVG_WARN = "#f08a5d"
+SVG_BLUE = "#118ab2"
+
+
+def _widget_chrome(title: str, subtitle: str, view: widgets.Widget) -> widgets.Widget:
+ header = widgets.HTML(
+ f"""
+
+
{escape(title)}
+
{escape(subtitle)}
+
+ """
+ )
+ box = widgets.VBox(
+ [header, view],
+ layout=widgets.Layout(border="1px solid #d9d9d9", overflow="hidden"),
+ )
+ return box
+
+
+def _player_widget(
+ *,
+ title: str,
+ subtitle: str,
+ frames: Sequence[str],
+ captions: Sequence[str],
+ width: str = "100%",
+) -> widgets.Widget:
+ if not frames:
+ raise ValueError("player widget requires at least one frame")
+
+ play = widgets.Play(value=0, min=0, max=len(frames) - 1, step=1, interval=1100, description="Play")
+ slider = widgets.IntSlider(
+ value=0,
+ min=0,
+ max=len(frames) - 1,
+ step=1,
+ description="Step",
+ continuous_update=False,
+ layout=widgets.Layout(width="420px"),
+ )
+ widgets.jslink((play, "value"), (slider, "value"))
+
+ frame_html = widgets.HTML(layout=widgets.Layout(width=width))
+ caption_html = widgets.HTML(layout=widgets.Layout(width=width))
+
+ def render(index: int) -> None:
+ frame_html.value = frames[index]
+ caption_html.value = f"""
+
+ Frame {index + 1} of {len(frames)}
+ {escape(captions[index])}
+
+ """
+
+ slider.observe(lambda change: render(change["new"]), names="value")
+ render(0)
+
+ controls = widgets.HBox(
+ [play, slider],
+ layout=widgets.Layout(padding="10px 14px", align_items="center"),
+ )
+ return _widget_chrome(title, subtitle, widgets.VBox([controls, frame_html, caption_html]))
+
+
+def _svg_box(x: float, y: float, w: float, h: float, label: str, value: str, *, fill: str, stroke: str, stroke_width: float = 2.0) -> str:
+ return f"""
+
+ {escape(label)}
+ {escape(value)}
+ """
+
+
+def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "400", fill: str = SVG_INK, anchor: str = "start") -> str:
+ return f'{escape(text)}'
+
+
+def _convolution_frame_svg(
+ left: Sequence[int],
+ right: Sequence[int],
+ diagonal_index: int,
+ *,
+ title: str,
+) -> str:
+ rows = convolution_contributions(left, right)
+ current = rows[diagonal_index]
+ grid = pairwise_product_grid(left, right)
+ cell = 68
+ left_margin = 78
+ top_margin = 78
+ bottom_y = top_margin + len(left) * cell + 70
+ width = left_margin + len(right) * cell + 120
+ height = bottom_y + 110
+
+ parts = [
+ f'")
+ return "".join(parts)
+
+
+def schoolbook_diagonal_player(left: Sequence[int], right: Sequence[int]) -> widgets.Widget:
+ """Interactive diagonal-by-diagonal walkthrough of schoolbook multiplication."""
+ rows = convolution_contributions(left, right)
+ frames = [
+ _convolution_frame_svg(left, right, diagonal_index=index, title="Schoolbook Multiplication As A Moving Diagonal")
+ for index in range(len(rows))
+ ]
+ captions = [
+ f"Only the highlighted products contribute to y{row['output_index']}. Watch the diagonal sweep across the grid instead of imagining the sum in your head."
+ for row in rows
+ ]
+ return _player_widget(
+ title="Convolution Diagonal Player",
+ subtitle="Press play. The highlighted diagonal is the coefficient being formed right now.",
+ frames=frames,
+ captions=captions,
+ )
+
+
+def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> str:
+ cyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=False)
+ negacyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=True)
+ flat = [(index, coefficient) for index, coefficient in enumerate(coefficients)]
+ current_index, current_value = flat[step]
+ cell = 58
+ width = max(980, len(coefficients) * cell + 200)
+ height = 390
+ top_y = 78
+ cyclic_y = 210
+ neg_y = 300
+
+ parts = [
+ f'")
+ return "".join(parts)
+
+
+def wraparound_comparison_player(coefficients: Sequence[int], n: int) -> widgets.Widget:
+ """Interactive comparison of cyclic and negacyclic folding, one source term at a time."""
+ frames = [_wrap_compare_frame_svg(coefficients, n, step) for step in range(len(coefficients))]
+ captions = []
+ for index, coefficient in enumerate(coefficients):
+ wraps, slot = divmod(index, n)
+ neg_sign = "-" if wraps % 2 else "+"
+ captions.append(
+ f"x^{index} with coefficient {coefficient} lands in slot {slot}. Cyclic folding keeps a + sign; negacyclic folding uses {neg_sign} after {wraps} wrap(s)."
+ )
+ return _player_widget(
+ title="Wraparound Step Player",
+ subtitle="Play one raw coefficient at a time and compare x^n-1 with x^n+1 side by side.",
+ frames=frames,
+ captions=captions,
+ )
+
+
+def _direct_ntt_frame_svg(values: Sequence[int], modulus: int, psi: int, output_index: int, input_index: int) -> str:
+ exponents = ntt_psi_exponent_grid(len(values))
+ matrix = ntt_psi_matrix(len(values), modulus, psi)
+ cell = 62
+ width = 1040
+ height = 470
+ grid_x = 260
+ grid_y = 84
+ boxes_y = 360
+ contributions = []
+ for row, value in enumerate(values):
+ exponent = exponents[output_index][row]
+ factor = matrix[output_index][row]
+ product = (value * factor) % modulus
+ contributions.append((exponent, factor, product))
+
+ partial = sum(product for _, _, product in contributions[: input_index + 1]) % modulus
+ final = sum(product for _, _, product in contributions) % modulus
+
+ parts = [
+ f'")
+ return "".join(parts)
+
+
+def direct_ntt_player(values: Sequence[int], modulus: int, psi: int) -> widgets.Widget:
+ """Interactive walkthrough of the direct NTT_psi matrix multiplication."""
+ frames = []
+ captions = []
+ for output_index in range(len(values)):
+ for input_index in range(len(values)):
+ exponent = 2 * input_index * output_index + input_index
+ factor = pow(psi, exponent, modulus)
+ product = (values[input_index] * factor) % modulus
+ frames.append(_direct_ntt_frame_svg(values, modulus, psi, output_index, input_index))
+ captions.append(
+ f"Output slot j={output_index}: multiply a{input_index}={values[input_index]} by psi^{exponent}={factor} to contribute {product} mod {modulus}."
+ )
+ return _player_widget(
+ title="Direct NTTψ Player",
+ subtitle="Press play and watch the transform build one contribution at a time.",
+ frames=frames,
+ captions=captions,
+ )
+
+
+def _butterfly_story_frame_svg(trace: TransformTrace, stage_index: int, pair_index: int) -> str:
+ stage = trace.stages[stage_index]
+ left, right = stage.pairings[pair_index]
+ zeta = stage.zetas[pair_index]
+ width = max(820, len(stage.input_values) * 120)
+ height = 340
+ input_y = 110
+ output_y = 240
+ spacing = 92
+ start_x = 60
+ parts = [
+ f'")
+ return "".join(parts)
+
+
+def butterfly_story_player(trace: TransformTrace) -> widgets.Widget:
+ """Interactive pair-by-pair walkthrough of a butterfly trace."""
+ frames = []
+ captions = []
+ for stage_index, stage in enumerate(trace.stages):
+ for pair_index, pair in enumerate(stage.pairings):
+ left, right = pair
+ frames.append(_butterfly_story_frame_svg(trace, stage_index, pair_index))
+ captions.append(
+ f"Stage {stage.stage_index}: the active pair is {pair}. Watch only these two wires. Everything else is parked until its own butterfly fires."
+ )
+ return _player_widget(
+ title="Butterfly Pair Player",
+ subtitle="Play pair by pair. This is the local machine the learner needs to internalize.",
+ frames=frames,
+ captions=captions,
+ )
+
def _value_colors(values: Sequence[int]) -> list[str]:
colors = []
diff --git a/tools/render_notebooks.py b/tools/render_notebooks.py
index 45b90b7..a3dee6c 100644
--- a/tools/render_notebooks.py
+++ b/tools/render_notebooks.py
@@ -455,12 +455,12 @@ def build_bundle_01() -> None:
"mandatory",
2,
"demo",
- "See The Product Grid And The Diagonal Sums",
+ "Play The Diagonal Sweep",
"""
from IPython.display import display
from ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution
- from ntt_learning.visuals import plot_convolution_grid
+ from ntt_learning.visuals import plot_convolution_grid, schoolbook_diagonal_player
left = [1, 2, 3, 4]
right = [5, 6, 7, 8]
@@ -470,6 +470,7 @@ def build_bundle_01() -> None:
for row in convolution_contributions(left, right):
print(row)
+ display(schoolbook_diagonal_player(left, right))
fig = plot_convolution_grid(left, right, title="Schoolbook products for [1,2,3,4] * [5,6,7,8]")
display(fig)
""",
@@ -492,12 +493,12 @@ def build_bundle_01() -> None:
"mandatory",
2,
"demo",
- "See Cyclic And Negacyclic Folding Side By Side",
+ "Play The Wraparound Step By Step",
"""
from IPython.display import display
from ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions
- from ntt_learning.visuals import plot_wraparound
+ from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
left = [1, 2, 3, 4]
right = [5, 6, 7, 8]
@@ -512,6 +513,7 @@ def build_bundle_01() -> None:
for row in wraparound_contributions(raw, n=4, negacyclic=True):
print(row)
+ display(wraparound_comparison_player(raw, n=4))
display(plot_wraparound(raw, n=4, negacyclic=False, title="Cyclic folding into x^4 - 1"))
display(plot_wraparound(raw, n=4, negacyclic=True, title="Negacyclic folding into x^4 + 1"))
""",
@@ -599,7 +601,7 @@ def build_bundle_01() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution
- from ntt_learning.visuals import plot_convolution_grid, plot_wraparound
+ from ntt_learning.visuals import plot_convolution_grid, plot_wraparound, schoolbook_diagonal_player, wraparound_comparison_player
left = [3, 0, 2, 1]
right = [1, 4, 0, 2]
@@ -607,6 +609,8 @@ def build_bundle_01() -> None:
print("raw convolution:", raw)
print("negacyclic result:", negacyclic_multiply(left, right, n=4))
+ display(schoolbook_diagonal_player(left, right))
+ display(wraparound_comparison_player(raw, n=4))
display(plot_convolution_grid(left, right, title="Prediction check grid"))
display(plot_wraparound(raw, n=4, negacyclic=True, title="Prediction check fold"))
""",
@@ -636,10 +640,11 @@ def build_bundle_01() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import schoolbook_convolution
- from ntt_learning.visuals import plot_wraparound
+ from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
raw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])
print("raw convolution:", raw)
+ display(wraparound_comparison_player(raw, n=4))
display(plot_wraparound(raw, n=4, negacyclic=True, title="Trace one tail coefficient by eye"))
""",
),
@@ -815,10 +820,11 @@ def build_bundle_01() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import schoolbook_convolution
- from ntt_learning.visuals import plot_wraparound
+ from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player
raw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])
print("raw convolution:", raw)
+ display(wraparound_comparison_player(raw, n=4))
display(plot_wraparound(raw, n=4, negacyclic=False, title="Positive wrap into x^4 - 1"))
display(plot_wraparound(raw, n=4, negacyclic=True, title="Negative wrap into x^4 + 1"))
""",
@@ -926,7 +932,7 @@ def build_bundle_02() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix
- from ntt_learning.visuals import plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap
+ from ntt_learning.visuals import direct_ntt_player, plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap
modulus = 17
n = 4
@@ -942,6 +948,7 @@ def build_bundle_02() -> None:
for row in ntt_psi_matrix(n, modulus, psi):
print(row)
+ display(direct_ntt_player([1, 2, 3, 4], modulus, psi))
display(plot_ntt_psi_exponent_heatmap(n, title="Exponents 2ij + i for n=4"))
display(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title="Concrete NTT_psi matrix in Z_17"))
""",
@@ -962,7 +969,10 @@ def build_bundle_02() -> None:
"demo",
"Run A Direct NTTψ / INTTψ Round Trip",
"""
+ from IPython.display import display
+
from ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi
+ from ntt_learning.visuals import direct_ntt_player
signal = [1, 2, 3, 4]
modulus = 17
@@ -972,6 +982,7 @@ def build_bundle_02() -> None:
print("signal:", signal)
print("spectrum:", spectrum)
print("inverse recovery:", inverse_ntt_psi(spectrum, modulus, psi))
+ display(direct_ntt_player(signal, modulus, psi))
""",
),
code(
@@ -1079,6 +1090,7 @@ def build_bundle_02() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi
+ from ntt_learning.visuals import direct_ntt_player
modulus = 17
psi = find_psi(4, modulus)
@@ -1089,6 +1101,7 @@ def build_bundle_02() -> None:
print("signal:", signal)
print("spectrum:", spectrum)
print("inverse:", inverse_ntt_psi(spectrum, modulus, psi))
+ display(direct_ntt_player(signal, modulus, psi))
display(
widgets.interact(
@@ -1449,7 +1462,7 @@ def build_bundle_03() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi
- from ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview
+ from ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview
signal = [1, 2, 3, 4]
modulus = 7681
@@ -1459,6 +1472,7 @@ def build_bundle_03() -> None:
print("raw CT output (BO):", trace.raw_output)
print("bit-reversed back to NO:", trace.normal_order_output)
print("direct NTT_psi:", forward_ntt_psi(signal, modulus, psi))
+ display(butterfly_story_player(trace))
display(plot_trace_overview(trace, title="CT overview for [1,2,3,4]"))
display(plot_butterfly_network(trace, title="Full CT network for [1,2,3,4]"))
display(interactive_trace(trace, title="CT forward trace"))
@@ -1505,7 +1519,7 @@ def build_bundle_03() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi
- from ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview
+ from ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview
signal = [0, 1, 2, 3, 4, 5, 6, 7]
modulus = 97
@@ -1515,6 +1529,7 @@ def build_bundle_03() -> None:
print("psi:", psi)
print("BO output:", trace.raw_output)
print("NO output:", trace.normal_order_output)
+ display(butterfly_story_player(trace))
display(plot_trace_overview(trace, title="Three CT stages for n=8"))
display(plot_butterfly_network(trace, title="Full CT network for n=8"))
display(interactive_trace(trace, title="n=8 CT stage explorer"))
@@ -1604,9 +1619,10 @@ def build_bundle_03() -> None:
from IPython.display import display
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace
- from ntt_learning.visuals import interactive_trace
+ from ntt_learning.visuals import butterfly_story_player, interactive_trace
trace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)
+ display(butterfly_story_player(trace))
display(interactive_trace(trace, title="Check your n=4 prediction"))
""",
),