diff --git a/notebooks/butterfly_mechanics/03_fast_forward_ct/lab.ipynb b/notebooks/butterfly_mechanics/03_fast_forward_ct/lab.ipynb index 8004e6c..6917614 100644 --- a/notebooks/butterfly_mechanics/03_fast_forward_ct/lab.ipynb +++ b/notebooks/butterfly_mechanics/03_fast_forward_ct/lab.ipynb @@ -48,7 +48,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 3 | Prediction Check For n=4\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace\nfrom ntt_learning.visuals import interactive_trace\n\ntrace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)\ndisplay(interactive_trace(trace, title=\"Check your n=4 prediction\"))\n" + "source": "# MANDATORY | difficulty 3 | Prediction Check For n=4\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace\n\ntrace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925)\ndisplay(butterfly_story_player(trace))\ndisplay(interactive_trace(trace, title=\"Check your n=4 prediction\"))\n" }, { "cell_type": "markdown", diff --git a/notebooks/butterfly_mechanics/03_fast_forward_ct/lecture.ipynb b/notebooks/butterfly_mechanics/03_fast_forward_ct/lecture.ipynb index fab970f..50a6656 100644 --- a/notebooks/butterfly_mechanics/03_fast_forward_ct/lecture.ipynb +++ b/notebooks/butterfly_mechanics/03_fast_forward_ct/lecture.ipynb @@ -62,7 +62,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 3 | Trace The Exact n=4 Paper Example\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi\nfrom ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [1, 2, 3, 4]\nmodulus = 7681\npsi = 1925\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"raw CT output (BO):\", trace.raw_output)\nprint(\"bit-reversed back to NO:\", trace.normal_order_output)\nprint(\"direct NTT_psi:\", forward_ntt_psi(signal, modulus, psi))\ndisplay(plot_trace_overview(trace, title=\"CT overview for [1,2,3,4]\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for [1,2,3,4]\"))\ndisplay(interactive_trace(trace, title=\"CT forward trace\"))\n" + "source": "# MANDATORY | difficulty 3 | Trace The Exact n=4 Paper Example\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [1, 2, 3, 4]\nmodulus = 7681\npsi = 1925\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"raw CT output (BO):\", trace.raw_output)\nprint(\"bit-reversed back to NO:\", trace.normal_order_output)\nprint(\"direct NTT_psi:\", forward_ntt_psi(signal, modulus, psi))\ndisplay(butterfly_story_player(trace))\ndisplay(plot_trace_overview(trace, title=\"CT overview for [1,2,3,4]\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for [1,2,3,4]\"))\ndisplay(interactive_trace(trace, title=\"CT forward trace\"))\n" }, { "cell_type": "markdown", @@ -102,7 +102,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 3 | Go One Stage Deeper With n=8\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi\nfrom ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [0, 1, 2, 3, 4, 5, 6, 7]\nmodulus = 97\npsi = find_psi(8, modulus)\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"psi:\", psi)\nprint(\"BO output:\", trace.raw_output)\nprint(\"NO output:\", trace.normal_order_output)\ndisplay(plot_trace_overview(trace, title=\"Three CT stages for n=8\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for n=8\"))\ndisplay(interactive_trace(trace, title=\"n=8 CT stage explorer\"))\n" + "source": "# MANDATORY | difficulty 3 | Go One Stage Deeper With n=8\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi\nfrom ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview\n\nsignal = [0, 1, 2, 3, 4, 5, 6, 7]\nmodulus = 97\npsi = find_psi(8, modulus)\ntrace = fast_ntt_psi_ct_trace(signal, modulus, psi)\n\nprint(\"psi:\", psi)\nprint(\"BO output:\", trace.raw_output)\nprint(\"NO output:\", trace.normal_order_output)\ndisplay(butterfly_story_player(trace))\ndisplay(plot_trace_overview(trace, title=\"Three CT stages for n=8\"))\ndisplay(plot_butterfly_network(trace, title=\"Full CT network for n=8\"))\ndisplay(interactive_trace(trace, title=\"n=8 CT stage explorer\"))\n" }, { "cell_type": "code", diff --git a/notebooks/foundations/01_convolution_to_toy_ntt/lab.ipynb b/notebooks/foundations/01_convolution_to_toy_ntt/lab.ipynb index 326b210..2bc5b86 100644 --- a/notebooks/foundations/01_convolution_to_toy_ntt/lab.ipynb +++ b/notebooks/foundations/01_convolution_to_toy_ntt/lab.ipynb @@ -48,7 +48,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 2 | Run The Prediction Check\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, plot_wraparound\n\nleft = [3, 0, 2, 1]\nright = [1, 4, 0, 2]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic result:\", negacyclic_multiply(left, right, n=4))\ndisplay(plot_convolution_grid(left, right, title=\"Prediction check grid\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Prediction check fold\"))\n" + "source": "# MANDATORY | difficulty 2 | Run The Prediction Check\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, plot_wraparound, schoolbook_diagonal_player, wraparound_comparison_player\n\nleft = [3, 0, 2, 1]\nright = [1, 4, 0, 2]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic result:\", negacyclic_multiply(left, right, n=4))\ndisplay(schoolbook_diagonal_player(left, right))\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_convolution_grid(left, right, title=\"Prediction check grid\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Prediction check fold\"))\n" }, { "cell_type": "markdown", @@ -74,7 +74,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 2 | A Second Visual Drill\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound\n\nraw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])\nprint(\"raw convolution:\", raw)\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Trace one tail coefficient by eye\"))\n" + "source": "# MANDATORY | difficulty 2 | A Second Visual Drill\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nraw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2])\nprint(\"raw convolution:\", raw)\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Trace one tail coefficient by eye\"))\n" }, { "cell_type": "markdown", diff --git a/notebooks/foundations/01_convolution_to_toy_ntt/lecture.ipynb b/notebooks/foundations/01_convolution_to_toy_ntt/lecture.ipynb index d31bda9..ee04fa2 100644 --- a/notebooks/foundations/01_convolution_to_toy_ntt/lecture.ipynb +++ b/notebooks/foundations/01_convolution_to_toy_ntt/lecture.ipynb @@ -44,11 +44,11 @@ "role": "mandatory", "difficulty": 2, "kind": "demo", - "title": "See The Product Grid And The Diagonal Sums" + "title": "Play The Diagonal Sweep" } }, "outputs": [], - "source": "# MANDATORY | difficulty 2 | See The Product Grid And The Diagonal Sums\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nfor row in convolution_contributions(left, right):\n print(row)\n\nfig = plot_convolution_grid(left, right, title=\"Schoolbook products for [1,2,3,4] * [5,6,7,8]\")\ndisplay(fig)\n" + "source": "# MANDATORY | difficulty 2 | Play The Diagonal Sweep\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution\nfrom ntt_learning.visuals import plot_convolution_grid, schoolbook_diagonal_player\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nfor row in convolution_contributions(left, right):\n print(row)\n\ndisplay(schoolbook_diagonal_player(left, right))\nfig = plot_convolution_grid(left, right, title=\"Schoolbook products for [1,2,3,4] * [5,6,7,8]\")\ndisplay(fig)\n" }, { "cell_type": "markdown", @@ -70,11 +70,11 @@ "role": "mandatory", "difficulty": 2, "kind": "demo", - "title": "See Cyclic And Negacyclic Folding Side By Side" + "title": "Play The Wraparound Step By Step" } }, "outputs": [], - "source": "# MANDATORY | difficulty 2 | See Cyclic And Negacyclic Folding Side By Side\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions\nfrom ntt_learning.visuals import plot_wraparound\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic in x^4 + 1:\", negacyclic_multiply(left, right, n=4))\nprint(\"cyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=False):\n print(row)\nprint(\"negacyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=True):\n print(row)\n\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Cyclic folding into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negacyclic folding into x^4 + 1\"))\n" + "source": "# MANDATORY | difficulty 2 | Play The Wraparound Step By Step\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nleft = [1, 2, 3, 4]\nright = [5, 6, 7, 8]\nraw = schoolbook_convolution(left, right)\n\nprint(\"raw convolution:\", raw)\nprint(\"negacyclic in x^4 + 1:\", negacyclic_multiply(left, right, n=4))\nprint(\"cyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=False):\n print(row)\nprint(\"negacyclic folding rows:\")\nfor row in wraparound_contributions(raw, n=4, negacyclic=True):\n print(row)\n\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Cyclic folding into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negacyclic folding into x^4 + 1\"))\n" }, { "cell_type": "markdown", diff --git a/notebooks/foundations/01_convolution_to_toy_ntt/studio.ipynb b/notebooks/foundations/01_convolution_to_toy_ntt/studio.ipynb index 7fd3770..2a869c6 100644 --- a/notebooks/foundations/01_convolution_to_toy_ntt/studio.ipynb +++ b/notebooks/foundations/01_convolution_to_toy_ntt/studio.ipynb @@ -48,7 +48,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 3 | Compare The Two Fold Rules\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound\n\nraw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])\nprint(\"raw convolution:\", raw)\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Positive wrap into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negative wrap into x^4 + 1\"))\n" + "source": "# MANDATORY | difficulty 3 | Compare The Two Fold Rules\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import schoolbook_convolution\nfrom ntt_learning.visuals import plot_wraparound, wraparound_comparison_player\n\nraw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8])\nprint(\"raw convolution:\", raw)\ndisplay(wraparound_comparison_player(raw, n=4))\ndisplay(plot_wraparound(raw, n=4, negacyclic=False, title=\"Positive wrap into x^4 - 1\"))\ndisplay(plot_wraparound(raw, n=4, negacyclic=True, title=\"Negative wrap into x^4 + 1\"))\n" }, { "cell_type": "markdown", diff --git a/notebooks/foundations/02_negative_wrapped_ntt/lab.ipynb b/notebooks/foundations/02_negative_wrapped_ntt/lab.ipynb index 405cd48..6a1def9 100644 --- a/notebooks/foundations/02_negative_wrapped_ntt/lab.ipynb +++ b/notebooks/foundations/02_negative_wrapped_ntt/lab.ipynb @@ -62,7 +62,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 3 | Interactive Signal Explorer\n\nimport ipywidgets as widgets\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\n\nmodulus = 17\npsi = find_psi(4, modulus)\n\ndef preview(a0=1, a1=2, a2=3, a3=4):\n signal = [a0, a1, a2, a3]\n spectrum = forward_ntt_psi(signal, modulus, psi)\n print(\"signal:\", signal)\n print(\"spectrum:\", spectrum)\n print(\"inverse:\", inverse_ntt_psi(spectrum, modulus, psi))\n\ndisplay(\n widgets.interact(\n preview,\n a0=(0, 16),\n a1=(0, 16),\n a2=(0, 16),\n a3=(0, 16),\n )\n)\n" + "source": "# MANDATORY | difficulty 3 | Interactive Signal Explorer\n\nimport ipywidgets as widgets\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\nfrom ntt_learning.visuals import direct_ntt_player\n\nmodulus = 17\npsi = find_psi(4, modulus)\n\ndef preview(a0=1, a1=2, a2=3, a3=4):\n signal = [a0, a1, a2, a3]\n spectrum = forward_ntt_psi(signal, modulus, psi)\n print(\"signal:\", signal)\n print(\"spectrum:\", spectrum)\n print(\"inverse:\", inverse_ntt_psi(spectrum, modulus, psi))\n display(direct_ntt_player(signal, modulus, psi))\n\ndisplay(\n widgets.interact(\n preview,\n a0=(0, 16),\n a1=(0, 16),\n a2=(0, 16),\n a3=(0, 16),\n )\n)\n" }, { "cell_type": "markdown", diff --git a/notebooks/foundations/02_negative_wrapped_ntt/lecture.ipynb b/notebooks/foundations/02_negative_wrapped_ntt/lecture.ipynb index fd54389..b9c23eb 100644 --- a/notebooks/foundations/02_negative_wrapped_ntt/lecture.ipynb +++ b/notebooks/foundations/02_negative_wrapped_ntt/lecture.ipynb @@ -48,7 +48,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 2 | Inspect \u03c9, \u03c8, And The Direct Transform Matrix\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix\nfrom ntt_learning.visuals import plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap\n\nmodulus = 17\nn = 4\nomega = find_primitive_root(n, modulus)\npsi = find_psi(n, modulus)\n\nprint(\"omega:\", omega)\nprint(\"psi:\", psi)\nprint(\"exponent grid:\")\nfor row in ntt_psi_exponent_grid(n):\n print(row)\nprint(\"NTT_psi matrix:\")\nfor row in ntt_psi_matrix(n, modulus, psi):\n print(row)\n\ndisplay(plot_ntt_psi_exponent_heatmap(n, title=\"Exponents 2ij + i for n=4\"))\ndisplay(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title=\"Concrete NTT_psi matrix in Z_17\"))\n" + "source": "# MANDATORY | difficulty 2 | Inspect \u03c9, \u03c8, And The Direct Transform Matrix\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix\nfrom ntt_learning.visuals import direct_ntt_player, plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap\n\nmodulus = 17\nn = 4\nomega = find_primitive_root(n, modulus)\npsi = find_psi(n, modulus)\n\nprint(\"omega:\", omega)\nprint(\"psi:\", psi)\nprint(\"exponent grid:\")\nfor row in ntt_psi_exponent_grid(n):\n print(row)\nprint(\"NTT_psi matrix:\")\nfor row in ntt_psi_matrix(n, modulus, psi):\n print(row)\n\ndisplay(direct_ntt_player([1, 2, 3, 4], modulus, psi))\ndisplay(plot_ntt_psi_exponent_heatmap(n, title=\"Exponents 2ij + i for n=4\"))\ndisplay(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title=\"Concrete NTT_psi matrix in Z_17\"))\n" }, { "cell_type": "markdown", @@ -74,7 +74,7 @@ } }, "outputs": [], - "source": "# MANDATORY | difficulty 2 | Run A Direct NTT\u03c8 / INTT\u03c8 Round Trip\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\n\nsignal = [1, 2, 3, 4]\nmodulus = 17\npsi = find_psi(len(signal), modulus)\nspectrum = forward_ntt_psi(signal, modulus, psi)\n\nprint(\"signal:\", signal)\nprint(\"spectrum:\", spectrum)\nprint(\"inverse recovery:\", inverse_ntt_psi(spectrum, modulus, psi))\n" + "source": "# MANDATORY | difficulty 2 | Run A Direct NTT\u03c8 / INTT\u03c8 Round Trip\n\nfrom IPython.display import display\n\nfrom ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi\nfrom ntt_learning.visuals import direct_ntt_player\n\nsignal = [1, 2, 3, 4]\nmodulus = 17\npsi = find_psi(len(signal), modulus)\nspectrum = forward_ntt_psi(signal, modulus, psi)\n\nprint(\"signal:\", signal)\nprint(\"spectrum:\", spectrum)\nprint(\"inverse recovery:\", inverse_ntt_psi(spectrum, modulus, psi))\ndisplay(direct_ntt_player(signal, modulus, psi))\n" }, { "cell_type": "code", diff --git a/ntt_learning/visuals.py b/ntt_learning/visuals.py index 06da6f0..02e3b40 100644 --- a/ntt_learning/visuals.py +++ b/ntt_learning/visuals.py @@ -2,6 +2,7 @@ from __future__ import annotations +from html import escape from typing import Sequence import ipywidgets as widgets @@ -16,6 +17,7 @@ from .toy_ntt import ( TransformTrace, base_multiply_pair, bit_reversed_order, + convolution_contributions, forward_ntt_psi, inverse_ntt_psi, ntt_psi_exponent_grid, @@ -26,6 +28,426 @@ from .toy_ntt import ( wraparound_contributions, ) +SVG_BG = "#f7f3ea" +SVG_PANEL = "#fffdf8" +SVG_INK = "#1f2933" +SVG_SOFT = "#d8e2dc" +SVG_ACCENT = "#ef476f" +SVG_HILITE = "#ffd166" +SVG_GOOD = "#06d6a0" +SVG_WARN = "#f08a5d" +SVG_BLUE = "#118ab2" + + +def _widget_chrome(title: str, subtitle: str, view: widgets.Widget) -> widgets.Widget: + header = widgets.HTML( + f""" +
+
{escape(title)}
+
{escape(subtitle)}
+
+ """ + ) + box = widgets.VBox( + [header, view], + layout=widgets.Layout(border="1px solid #d9d9d9", overflow="hidden"), + ) + return box + + +def _player_widget( + *, + title: str, + subtitle: str, + frames: Sequence[str], + captions: Sequence[str], + width: str = "100%", +) -> widgets.Widget: + if not frames: + raise ValueError("player widget requires at least one frame") + + play = widgets.Play(value=0, min=0, max=len(frames) - 1, step=1, interval=1100, description="Play") + slider = widgets.IntSlider( + value=0, + min=0, + max=len(frames) - 1, + step=1, + description="Step", + continuous_update=False, + layout=widgets.Layout(width="420px"), + ) + widgets.jslink((play, "value"), (slider, "value")) + + frame_html = widgets.HTML(layout=widgets.Layout(width=width)) + caption_html = widgets.HTML(layout=widgets.Layout(width=width)) + + def render(index: int) -> None: + frame_html.value = frames[index] + caption_html.value = f""" +
+ Frame {index + 1} of {len(frames)}
+ {escape(captions[index])} +
+ """ + + slider.observe(lambda change: render(change["new"]), names="value") + render(0) + + controls = widgets.HBox( + [play, slider], + layout=widgets.Layout(padding="10px 14px", align_items="center"), + ) + return _widget_chrome(title, subtitle, widgets.VBox([controls, frame_html, caption_html])) + + +def _svg_box(x: float, y: float, w: float, h: float, label: str, value: str, *, fill: str, stroke: str, stroke_width: float = 2.0) -> str: + return f""" + + {escape(label)} + {escape(value)} + """ + + +def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "400", fill: str = SVG_INK, anchor: str = "start") -> str: + return f'{escape(text)}' + + +def _convolution_frame_svg( + left: Sequence[int], + right: Sequence[int], + diagonal_index: int, + *, + title: str, +) -> str: + rows = convolution_contributions(left, right) + current = rows[diagonal_index] + grid = pairwise_product_grid(left, right) + cell = 68 + left_margin = 78 + top_margin = 78 + bottom_y = top_margin + len(left) * cell + 70 + width = left_margin + len(right) * cell + 120 + height = bottom_y + 110 + + parts = [ + f'', + f'', + _svg_text(34, 48, title, size=22, weight="800"), + _svg_text(34, 68, f"Diagonal {diagonal_index}: every highlighted cell lands in output coefficient y{diagonal_index}", size=13, fill="#486581"), + ] + + for row in range(len(left)): + parts.append(_svg_text(54, top_margin + row * cell + 42, f"a{row}", size=12, anchor="middle")) + for col in range(len(right)): + parts.append(_svg_text(left_margin + col * cell + 28, 64, f"b{col}", size=12, anchor="middle")) + + for row, row_values in enumerate(grid): + for col, value in enumerate(row_values): + active = row + col == diagonal_index + done = row + col < diagonal_index + fill = SVG_HILITE if active else "#e6fffa" if done else "#f8f9fa" + stroke = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9" + parts.append( + _svg_box( + left_margin + col * cell, + top_margin + row * cell, + 56, + 56, + f"a{row}·b{col}", + str(value), + fill=fill, + stroke=stroke, + stroke_width=3.2 if active else 1.8, + ) + ) + + for row in rows: + output_index = int(row["output_index"]) + fill = SVG_HILITE if output_index == diagonal_index else "#d9f0ff" if output_index < diagonal_index else "#f1f5f9" + stroke = SVG_ACCENT if output_index == diagonal_index else SVG_BLUE if output_index < diagonal_index else "#cbd2d9" + parts.append( + _svg_box( + left_margin + output_index * cell * 0.92, + bottom_y, + 60, + 56, + f"y{output_index}", + str(row["total"]), + fill=fill, + stroke=stroke, + stroke_width=3.2 if output_index == diagonal_index else 1.8, + ) + ) + + terms = [f"{term['left_value']}×{term['right_value']}={term['product']}" for term in current["terms"]] + equation = " + ".join(terms) if terms else "0" + parts.append(_svg_text(34, height - 34, f"Current diagonal sum: {equation} = {current['total']}", size=15, weight="700")) + parts.append("") + return "".join(parts) + + +def schoolbook_diagonal_player(left: Sequence[int], right: Sequence[int]) -> widgets.Widget: + """Interactive diagonal-by-diagonal walkthrough of schoolbook multiplication.""" + rows = convolution_contributions(left, right) + frames = [ + _convolution_frame_svg(left, right, diagonal_index=index, title="Schoolbook Multiplication As A Moving Diagonal") + for index in range(len(rows)) + ] + captions = [ + f"Only the highlighted products contribute to y{row['output_index']}. Watch the diagonal sweep across the grid instead of imagining the sum in your head." + for row in rows + ] + return _player_widget( + title="Convolution Diagonal Player", + subtitle="Press play. The highlighted diagonal is the coefficient being formed right now.", + frames=frames, + captions=captions, + ) + + +def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> str: + cyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=False) + negacyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=True) + flat = [(index, coefficient) for index, coefficient in enumerate(coefficients)] + current_index, current_value = flat[step] + cell = 58 + width = max(980, len(coefficients) * cell + 200) + height = 390 + top_y = 78 + cyclic_y = 210 + neg_y = 300 + + parts = [ + f'', + f'', + _svg_text(34, 48, "Wraparound Comparison Player", size=22, weight="800"), + _svg_text(34, 68, f"Current source term: x^{current_index} with coefficient {current_value}", size=13, fill="#486581"), + _svg_text(34, top_y - 14, "Raw convolution tail", size=14, weight="700"), + _svg_text(34, cyclic_y - 14, "Cyclic fold into x^n - 1", size=14, weight="700", fill=SVG_BLUE), + _svg_text(34, neg_y - 14, "Negacyclic fold into x^n + 1", size=14, weight="700", fill=SVG_ACCENT), + ] + + for index, coefficient in enumerate(coefficients): + fill = SVG_HILITE if index == current_index else "#f1f5f9" if index > current_index else "#d8f3dc" + stroke = SVG_ACCENT if index == current_index else "#cbd2d9" if index > current_index else SVG_GOOD + parts.append(_svg_box(48 + index * cell, top_y, 48, 48, f"x^{index}", str(coefficient), fill=fill, stroke=stroke)) + + for slot, row in enumerate(cyclic_rows): + parts.append(_svg_box(48 + slot * cell, cyclic_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#e0fbfc", stroke=SVG_BLUE)) + for slot, row in enumerate(negacyclic_rows): + parts.append(_svg_box(48 + slot * cell, neg_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#ffe8d6", stroke=SVG_ACCENT)) + + wraps, slot = divmod(current_index, n) + cyclic_label = f"+ wrap {wraps}" + neg_label = ("-" if wraps % 2 else "+") + f" wrap {wraps}" + source_x = 72 + current_index * cell + target_x = 72 + slot * cell + + parts.append(f'') + parts.append(f'') + parts.append(f""" + + + + + + + + + """) + parts.append(_svg_text((source_x + target_x) / 2, cyclic_y - 18, cyclic_label, size=12, fill=SVG_BLUE, anchor="middle")) + parts.append(_svg_text((source_x + target_x) / 2, neg_y - 18, neg_label, size=12, fill=SVG_ACCENT, anchor="middle")) + parts.append("") + return "".join(parts) + + +def wraparound_comparison_player(coefficients: Sequence[int], n: int) -> widgets.Widget: + """Interactive comparison of cyclic and negacyclic folding, one source term at a time.""" + frames = [_wrap_compare_frame_svg(coefficients, n, step) for step in range(len(coefficients))] + captions = [] + for index, coefficient in enumerate(coefficients): + wraps, slot = divmod(index, n) + neg_sign = "-" if wraps % 2 else "+" + captions.append( + f"x^{index} with coefficient {coefficient} lands in slot {slot}. Cyclic folding keeps a + sign; negacyclic folding uses {neg_sign} after {wraps} wrap(s)." + ) + return _player_widget( + title="Wraparound Step Player", + subtitle="Play one raw coefficient at a time and compare x^n-1 with x^n+1 side by side.", + frames=frames, + captions=captions, + ) + + +def _direct_ntt_frame_svg(values: Sequence[int], modulus: int, psi: int, output_index: int, input_index: int) -> str: + exponents = ntt_psi_exponent_grid(len(values)) + matrix = ntt_psi_matrix(len(values), modulus, psi) + cell = 62 + width = 1040 + height = 470 + grid_x = 260 + grid_y = 84 + boxes_y = 360 + contributions = [] + for row, value in enumerate(values): + exponent = exponents[output_index][row] + factor = matrix[output_index][row] + product = (value * factor) % modulus + contributions.append((exponent, factor, product)) + + partial = sum(product for _, _, product in contributions[: input_index + 1]) % modulus + final = sum(product for _, _, product in contributions) % modulus + + parts = [ + f'', + f'', + _svg_text(34, 48, "Direct NTTψ Contribution Player", size=22, weight="800"), + _svg_text(34, 68, f"Building output slot j={output_index}, currently consuming input i={input_index}", size=13, fill="#486581"), + _svg_text(34, 108, "signal", size=14, weight="700"), + _svg_text(grid_x, 68, "matrix cell = psi^(2ij + i)", size=14, weight="700"), + ] + + for index, value in enumerate(values): + fill = SVG_HILITE if index == input_index else "#d8f3dc" if index < input_index else "#f1f5f9" + stroke = SVG_ACCENT if index == input_index else SVG_GOOD if index < input_index else "#cbd2d9" + parts.append(_svg_box(36, 128 + index * 60, 92, 48, f"a{index}", str(value), fill=fill, stroke=stroke)) + + for row in range(len(values)): + parts.append(_svg_text(grid_x - 18, grid_y + row * cell + 34, f"i={row}", size=12, anchor="end")) + for col in range(len(values)): + parts.append(_svg_text(grid_x + col * cell + 28, grid_y - 12, f"j={col}", size=12, anchor="middle")) + + for col in range(len(values)): + for row in range(len(values)): + active = col == output_index and row == input_index + done = col == output_index and row < input_index + fill = SVG_HILITE if active else "#d8f3dc" if done else "#f1f5f9" + stroke = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9" + parts.append( + _svg_box( + grid_x + col * cell, + grid_y + row * cell, + 56, + 56, + f"{exponents[col][row]}", + str(matrix[col][row]), + fill=fill, + stroke=stroke, + stroke_width=3.0 if active else 1.6, + ) + ) + + for index, (_, factor, product) in enumerate(contributions): + fill = SVG_HILITE if index == input_index else "#d8f3dc" if index < input_index else "#f1f5f9" + stroke = SVG_ACCENT if index == input_index else SVG_GOOD if index < input_index else "#cbd2d9" + parts.append(_svg_box(520 + index * 90, boxes_y, 74, 54, f"a{index}·w", str(product), fill=fill, stroke=stroke)) + parts.append(_svg_text(556 + index * 90, boxes_y - 10, f"factor {factor}", size=11, anchor="middle")) + + parts.append(_svg_text(34, height - 60, f"Current partial sum for j={output_index}: {partial} mod {modulus}", size=15, weight="700")) + parts.append(_svg_text(34, height - 34, f"Completed output slot y{output_index}: {final} mod {modulus}", size=15, weight="700", fill=SVG_ACCENT)) + parts.append("") + return "".join(parts) + + +def direct_ntt_player(values: Sequence[int], modulus: int, psi: int) -> widgets.Widget: + """Interactive walkthrough of the direct NTT_psi matrix multiplication.""" + frames = [] + captions = [] + for output_index in range(len(values)): + for input_index in range(len(values)): + exponent = 2 * input_index * output_index + input_index + factor = pow(psi, exponent, modulus) + product = (values[input_index] * factor) % modulus + frames.append(_direct_ntt_frame_svg(values, modulus, psi, output_index, input_index)) + captions.append( + f"Output slot j={output_index}: multiply a{input_index}={values[input_index]} by psi^{exponent}={factor} to contribute {product} mod {modulus}." + ) + return _player_widget( + title="Direct NTTψ Player", + subtitle="Press play and watch the transform build one contribution at a time.", + frames=frames, + captions=captions, + ) + + +def _butterfly_story_frame_svg(trace: TransformTrace, stage_index: int, pair_index: int) -> str: + stage = trace.stages[stage_index] + left, right = stage.pairings[pair_index] + zeta = stage.zetas[pair_index] + width = max(820, len(stage.input_values) * 120) + height = 340 + input_y = 110 + output_y = 240 + spacing = 92 + start_x = 60 + parts = [ + f'', + f'', + _svg_text(34, 48, f"{trace.algorithm.upper()} Butterfly Player", size=22, weight="800"), + _svg_text(34, 68, f"Stage {stage.stage_index}, active pair ({left}, {right}), zeta={zeta}", size=13, fill="#486581"), + ] + + for index, value in enumerate(stage.input_values): + active = index in (left, right) + fill = SVG_HILITE if active else "#f1f5f9" + stroke = SVG_ACCENT if active else "#cbd2d9" + parts.append(_svg_box(start_x + index * spacing, input_y, 64, 52, f"in {index}", str(value), fill=fill, stroke=stroke, stroke_width=3.0 if active else 1.8)) + + for index, value in enumerate(stage.output_values): + active = index in (left, right) + fill = "#d8f3dc" if active else "#f1f5f9" + stroke = SVG_GOOD if active else "#cbd2d9" + parts.append(_svg_box(start_x + index * spacing, output_y, 64, 52, f"out {index}", str(value), fill=fill, stroke=stroke, stroke_width=3.0 if active else 1.8)) + + for index in range(len(stage.input_values)): + x = start_x + index * spacing + 32 + color = SVG_ACCENT if index in (left, right) else "#d9d9d9" + width_line = 4 if index in (left, right) else 1.8 + parts.append(f'') + + x_left = start_x + left * spacing + 32 + x_right = start_x + right * spacing + 32 + parts.append(f'') + parts.append(_svg_text((x_left + x_right) / 2, 196, f"pair ({left}, {right})", size=12, weight="700", fill=SVG_ACCENT, anchor="middle")) + parts.append(_svg_text((x_left + x_right) / 2, 214, f"inputs -> outputs = ({stage.input_values[left]}, {stage.input_values[right]}) -> ({stage.output_values[left]}, {stage.output_values[right]})", size=11, anchor="middle")) + parts.append(_svg_text((x_left + x_right) / 2, 232, stage.note, size=11, anchor="middle", fill="#486581")) + parts.append("") + return "".join(parts) + + +def butterfly_story_player(trace: TransformTrace) -> widgets.Widget: + """Interactive pair-by-pair walkthrough of a butterfly trace.""" + frames = [] + captions = [] + for stage_index, stage in enumerate(trace.stages): + for pair_index, pair in enumerate(stage.pairings): + left, right = pair + frames.append(_butterfly_story_frame_svg(trace, stage_index, pair_index)) + captions.append( + f"Stage {stage.stage_index}: the active pair is {pair}. Watch only these two wires. Everything else is parked until its own butterfly fires." + ) + return _player_widget( + title="Butterfly Pair Player", + subtitle="Play pair by pair. This is the local machine the learner needs to internalize.", + frames=frames, + captions=captions, + ) + def _value_colors(values: Sequence[int]) -> list[str]: colors = [] diff --git a/tools/render_notebooks.py b/tools/render_notebooks.py index 45b90b7..a3dee6c 100644 --- a/tools/render_notebooks.py +++ b/tools/render_notebooks.py @@ -455,12 +455,12 @@ def build_bundle_01() -> None: "mandatory", 2, "demo", - "See The Product Grid And The Diagonal Sums", + "Play The Diagonal Sweep", """ from IPython.display import display from ntt_learning.toy_ntt import convolution_contributions, schoolbook_convolution - from ntt_learning.visuals import plot_convolution_grid + from ntt_learning.visuals import plot_convolution_grid, schoolbook_diagonal_player left = [1, 2, 3, 4] right = [5, 6, 7, 8] @@ -470,6 +470,7 @@ def build_bundle_01() -> None: for row in convolution_contributions(left, right): print(row) + display(schoolbook_diagonal_player(left, right)) fig = plot_convolution_grid(left, right, title="Schoolbook products for [1,2,3,4] * [5,6,7,8]") display(fig) """, @@ -492,12 +493,12 @@ def build_bundle_01() -> None: "mandatory", 2, "demo", - "See Cyclic And Negacyclic Folding Side By Side", + "Play The Wraparound Step By Step", """ from IPython.display import display from ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution, wraparound_contributions - from ntt_learning.visuals import plot_wraparound + from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player left = [1, 2, 3, 4] right = [5, 6, 7, 8] @@ -512,6 +513,7 @@ def build_bundle_01() -> None: for row in wraparound_contributions(raw, n=4, negacyclic=True): print(row) + display(wraparound_comparison_player(raw, n=4)) display(plot_wraparound(raw, n=4, negacyclic=False, title="Cyclic folding into x^4 - 1")) display(plot_wraparound(raw, n=4, negacyclic=True, title="Negacyclic folding into x^4 + 1")) """, @@ -599,7 +601,7 @@ def build_bundle_01() -> None: from IPython.display import display from ntt_learning.toy_ntt import negacyclic_multiply, schoolbook_convolution - from ntt_learning.visuals import plot_convolution_grid, plot_wraparound + from ntt_learning.visuals import plot_convolution_grid, plot_wraparound, schoolbook_diagonal_player, wraparound_comparison_player left = [3, 0, 2, 1] right = [1, 4, 0, 2] @@ -607,6 +609,8 @@ def build_bundle_01() -> None: print("raw convolution:", raw) print("negacyclic result:", negacyclic_multiply(left, right, n=4)) + display(schoolbook_diagonal_player(left, right)) + display(wraparound_comparison_player(raw, n=4)) display(plot_convolution_grid(left, right, title="Prediction check grid")) display(plot_wraparound(raw, n=4, negacyclic=True, title="Prediction check fold")) """, @@ -636,10 +640,11 @@ def build_bundle_01() -> None: from IPython.display import display from ntt_learning.toy_ntt import schoolbook_convolution - from ntt_learning.visuals import plot_wraparound + from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player raw = schoolbook_convolution([2, 5, 0, 1], [1, 0, 3, 2]) print("raw convolution:", raw) + display(wraparound_comparison_player(raw, n=4)) display(plot_wraparound(raw, n=4, negacyclic=True, title="Trace one tail coefficient by eye")) """, ), @@ -815,10 +820,11 @@ def build_bundle_01() -> None: from IPython.display import display from ntt_learning.toy_ntt import schoolbook_convolution - from ntt_learning.visuals import plot_wraparound + from ntt_learning.visuals import plot_wraparound, wraparound_comparison_player raw = schoolbook_convolution([1, 2, 3, 4], [5, 6, 7, 8]) print("raw convolution:", raw) + display(wraparound_comparison_player(raw, n=4)) display(plot_wraparound(raw, n=4, negacyclic=False, title="Positive wrap into x^4 - 1")) display(plot_wraparound(raw, n=4, negacyclic=True, title="Negative wrap into x^4 + 1")) """, @@ -926,7 +932,7 @@ def build_bundle_02() -> None: from IPython.display import display from ntt_learning.toy_ntt import find_primitive_root, find_psi, ntt_psi_exponent_grid, ntt_psi_matrix - from ntt_learning.visuals import plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap + from ntt_learning.visuals import direct_ntt_player, plot_ntt_psi_exponent_heatmap, plot_ntt_psi_matrix_heatmap modulus = 17 n = 4 @@ -942,6 +948,7 @@ def build_bundle_02() -> None: for row in ntt_psi_matrix(n, modulus, psi): print(row) + display(direct_ntt_player([1, 2, 3, 4], modulus, psi)) display(plot_ntt_psi_exponent_heatmap(n, title="Exponents 2ij + i for n=4")) display(plot_ntt_psi_matrix_heatmap(n, modulus, psi, title="Concrete NTT_psi matrix in Z_17")) """, @@ -962,7 +969,10 @@ def build_bundle_02() -> None: "demo", "Run A Direct NTTψ / INTTψ Round Trip", """ + from IPython.display import display + from ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi + from ntt_learning.visuals import direct_ntt_player signal = [1, 2, 3, 4] modulus = 17 @@ -972,6 +982,7 @@ def build_bundle_02() -> None: print("signal:", signal) print("spectrum:", spectrum) print("inverse recovery:", inverse_ntt_psi(spectrum, modulus, psi)) + display(direct_ntt_player(signal, modulus, psi)) """, ), code( @@ -1079,6 +1090,7 @@ def build_bundle_02() -> None: from IPython.display import display from ntt_learning.toy_ntt import find_psi, forward_ntt_psi, inverse_ntt_psi + from ntt_learning.visuals import direct_ntt_player modulus = 17 psi = find_psi(4, modulus) @@ -1089,6 +1101,7 @@ def build_bundle_02() -> None: print("signal:", signal) print("spectrum:", spectrum) print("inverse:", inverse_ntt_psi(spectrum, modulus, psi)) + display(direct_ntt_player(signal, modulus, psi)) display( widgets.interact( @@ -1449,7 +1462,7 @@ def build_bundle_03() -> None: from IPython.display import display from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, forward_ntt_psi - from ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview + from ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview signal = [1, 2, 3, 4] modulus = 7681 @@ -1459,6 +1472,7 @@ def build_bundle_03() -> None: print("raw CT output (BO):", trace.raw_output) print("bit-reversed back to NO:", trace.normal_order_output) print("direct NTT_psi:", forward_ntt_psi(signal, modulus, psi)) + display(butterfly_story_player(trace)) display(plot_trace_overview(trace, title="CT overview for [1,2,3,4]")) display(plot_butterfly_network(trace, title="Full CT network for [1,2,3,4]")) display(interactive_trace(trace, title="CT forward trace")) @@ -1505,7 +1519,7 @@ def build_bundle_03() -> None: from IPython.display import display from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi - from ntt_learning.visuals import interactive_trace, plot_butterfly_network, plot_trace_overview + from ntt_learning.visuals import butterfly_story_player, interactive_trace, plot_butterfly_network, plot_trace_overview signal = [0, 1, 2, 3, 4, 5, 6, 7] modulus = 97 @@ -1515,6 +1529,7 @@ def build_bundle_03() -> None: print("psi:", psi) print("BO output:", trace.raw_output) print("NO output:", trace.normal_order_output) + display(butterfly_story_player(trace)) display(plot_trace_overview(trace, title="Three CT stages for n=8")) display(plot_butterfly_network(trace, title="Full CT network for n=8")) display(interactive_trace(trace, title="n=8 CT stage explorer")) @@ -1604,9 +1619,10 @@ def build_bundle_03() -> None: from IPython.display import display from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace - from ntt_learning.visuals import interactive_trace + from ntt_learning.visuals import butterfly_story_player, interactive_trace trace = fast_ntt_psi_ct_trace([1, 2, 3, 4], 7681, 1925) + display(butterfly_story_player(trace)) display(interactive_trace(trace, title="Check your n=4 prediction")) """, ),