diff --git a/ntt_learning/visuals.py b/ntt_learning/visuals.py
index e6212e4..cf8acab 100644
--- a/ntt_learning/visuals.py
+++ b/ntt_learning/visuals.py
@@ -114,6 +114,8 @@ def _player_widget(
background: #fffdf8;
border-top: 1px solid #e8dcc9;
padding: 12px 16px;
+ min-height: 62px;
+ box-sizing: border-box;
font-family: 'Avenir Next', 'Trebuchet MS', sans-serif;
color: #243b53;
font-size: 13px;
@@ -153,25 +155,6 @@ def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "4
return f'{escape(text)}'
-def _svg_badge(
- x: float,
- y: float,
- text: str,
- *,
- fill: str = "#ffffff",
- stroke: str = "#cbd2d9",
- text_fill: str = SVG_INK,
-) -> str:
- width = max(66.0, 10.0 + len(text) * 7.2)
- height = 24.0
- left = x - width / 2
- top = y - height + 6
- return f"""
-
- {escape(text)}
- """
-
-
def _svg_multiline_text(
x: float,
y: float,
@@ -221,12 +204,6 @@ def _svg_canvas_open(width: int, height: int) -> str:
)
-def _wrap_label_geometry(source_x: float, target_x: float, band_top: float, band_bottom: float) -> tuple[float, float]:
- label_x = (source_x + target_x) / 2
- label_y = (band_top + band_bottom) / 2
- return label_x, label_y
-
-
def _html_token(label: str, value: str, *, fill: str, border: str, text: str = SVG_INK) -> str:
return f"""
{escape(label)}
{escape(value)}
@@ -473,68 +452,187 @@ def schoolbook_diagonal_player(left: Sequence[int], right: Sequence[int]) -> wid
)
-def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> str:
+def _wrap_row_html(
+ label: str,
+ cards: Sequence[str],
+ *,
+ background: str,
+ border: str,
+ subtitle: str,
+) -> str:
+ columns = "150px " + " ".join(["minmax(66px, 1fr)"] * len(cards))
+ return f"""
+
+
{escape(label)}
+
{escape(subtitle)}
+
+
+
+ {escape(label)}
+
+ {"".join(cards)}
+
+
+
+ """
+
+
+def _wrap_compare_frame_html(coefficients: Sequence[int], n: int, step: int) -> str:
cyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=False)
negacyclic_rows = wraparound_contributions(coefficients, n=n, negacyclic=True)
flat = [(index, coefficient) for index, coefficient in enumerate(coefficients)]
current_index, current_value = flat[step]
- cell = 58
- width = max(980, len(coefficients) * cell + 200)
- height = 450
- top_y = 118
- cyclic_y = 252
- neg_y = 348
-
- parts = [
- _svg_canvas_open(width, height),
- f'
',
- _svg_text(34, 48, "Wraparound Comparison Player", size=22, weight="800"),
- _svg_text(34, 74, f"Current source term: x^{current_index} with coefficient {current_value}", size=13, fill="#486581"),
- _svg_text(34, top_y - 18, "Raw convolution tail", size=14, weight="700"),
- _svg_text(34, cyclic_y - 18, "Cyclic fold into x^n - 1", size=14, weight="700", fill=SVG_BLUE),
- _svg_text(34, neg_y - 18, "Negacyclic fold into x^n + 1", size=14, weight="700", fill=SVG_ACCENT),
- ]
-
- for index, coefficient in enumerate(coefficients):
- fill = SVG_HILITE if index == current_index else "#f1f5f9" if index > current_index else "#d8f3dc"
- stroke = SVG_ACCENT if index == current_index else "#cbd2d9" if index > current_index else SVG_GOOD
- parts.append(_svg_box(48 + index * cell, top_y, 48, 48, f"x^{index}", str(coefficient), fill=fill, stroke=stroke))
-
- for slot, row in enumerate(cyclic_rows):
- parts.append(_svg_box(48 + slot * cell, cyclic_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#e0fbfc", stroke=SVG_BLUE))
- for slot, row in enumerate(negacyclic_rows):
- parts.append(_svg_box(48 + slot * cell, neg_y, 48, 48, f"slot {slot}", str(row["total"]), fill="#ffe8d6", stroke=SVG_ACCENT))
-
wraps, slot = divmod(current_index, n)
cyclic_label = f"+ wrap {wraps}"
neg_label = ("-" if wraps % 2 else "+") + f" wrap {wraps}"
- source_x = 72 + current_index * cell
- target_x = 72 + slot * cell
- arrow_start_y = top_y + 48
- cyclic_label_x, cyclic_label_y = _wrap_label_geometry(source_x, target_x, arrow_start_y + 10, cyclic_y - 26)
- neg_label_x, neg_label_y = _wrap_label_geometry(source_x, target_x, cyclic_y + 56, neg_y - 26)
- parts.append(f'
')
- parts.append(f'
')
- parts.append(f"""
-
-
-
-
-
-
-
-
- """)
- parts.append(_svg_badge(cyclic_label_x, cyclic_label_y, cyclic_label, fill="#f4fbff", stroke=SVG_BLUE, text_fill=SVG_BLUE))
- parts.append(_svg_badge(neg_label_x, neg_label_y, neg_label, fill="#fff0ec", stroke=SVG_ACCENT, text_fill=SVG_ACCENT))
- parts.append("")
- return "".join(parts)
+ raw_cards = []
+ for index, coefficient in enumerate(coefficients):
+ active = index == current_index
+ done = index < current_index
+ fill = SVG_HILITE if active else "#d8f3dc" if done else "#f1f5f9"
+ border = SVG_ACCENT if active else SVG_GOOD if done else "#cbd2d9"
+ raw_cards.append(_html_token(f"x^{index}", str(coefficient), fill=fill, border=border))
+
+ cyclic_cards = []
+ for row in cyclic_rows:
+ active = row["slot"] == slot
+ fill = "#cfeffd" if active else "#eaf7fc"
+ border = SVG_BLUE if active else "#9dc7d8"
+ cyclic_cards.append(_html_token(f"slot {row['slot']}", str(row["total"]), fill=fill, border=border))
+
+ negacyclic_cards = []
+ for row in negacyclic_rows:
+ active = row["slot"] == slot
+ fill = "#ffe1d6" if active else "#fff0ea"
+ border = SVG_ACCENT if active else "#e6a98d"
+ negacyclic_cards.append(_html_token(f"slot {row['slot']}", str(row["total"]), fill=fill, border=border))
+
+ flow_cards = f"""
+
+
+
Cyclic move
+
+ x^{current_index} -> slot {slot}
+ coefficient {current_value}
+ sign {cyclic_label}
+
+
+
+
Negacyclic move
+
+ x^{current_index} -> slot {slot}
+ coefficient {current_value}
+ sign {neg_label}
+
+
+
+ """
+
+ return f"""
+
+
+
+ Active source term: x^{current_index} = {current_value}
+
+
+ Track where this one term lands in x^n - 1 and x^n + 1.
+
+
+
+ {_wrap_row_html(
+ "Raw convolution tail",
+ raw_cards,
+ background=SVG_PANEL,
+ border="#e8dcc9",
+ subtitle="Highlighted card = the exact term moving right now.",
+ )}
+
+ {flow_cards}
+
+ {_wrap_row_html(
+ "Cyclic fold into x^n - 1",
+ cyclic_cards,
+ background="#f5fbff",
+ border="#c7e0ec",
+ subtitle=f"x^{current_index} lands in slot {slot} with sign {cyclic_label}.",
+ )}
+
+ {_wrap_row_html(
+ "Negacyclic fold into x^n + 1",
+ negacyclic_cards,
+ background="#fff6f1",
+ border="#efc7b8",
+ subtitle=f"x^{current_index} lands in slot {slot} with sign {neg_label}.",
+ )}
+
+ """
def wraparound_comparison_player(coefficients: Sequence[int], n: int) -> widgets.Widget:
"""Interactive comparison of cyclic and negacyclic folding, one source term at a time."""
- frames = [_wrap_compare_frame_svg(coefficients, n, step) for step in range(len(coefficients))]
+ frames = [_wrap_compare_frame_html(coefficients, n, step) for step in range(len(coefficients))]
captions = []
for index, coefficient in enumerate(coefficients):
wraps, slot = divmod(index, n)
diff --git a/tests/test_visuals.py b/tests/test_visuals.py
index e693f35..fc67a2a 100644
--- a/tests/test_visuals.py
+++ b/tests/test_visuals.py
@@ -14,8 +14,7 @@ os.environ.setdefault("MPLCONFIGDIR", str(MPLCONFIGDIR))
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi
from ntt_learning.visuals import (
_convolution_frame_html,
- _wrap_compare_frame_svg,
- _wrap_label_geometry,
+ _wrap_compare_frame_html,
butterfly_story_player,
direct_ntt_player,
schoolbook_diagonal_player,
@@ -70,7 +69,6 @@ class VisualUxTests(unittest.TestCase):
psi = find_psi(order=4, modulus=17)
trace = fast_ntt_psi_ct_trace([1, 2, 3, 4], modulus=17, psi=psi)
players = [
- wraparound_comparison_player([3, 0, 2, 1, 5, 4, 6], n=4),
direct_ntt_player([1, 2, 3, 4], modulus=17, psi=psi),
butterfly_story_player(trace),
]
@@ -100,36 +98,24 @@ class VisualUxTests(unittest.TestCase):
self.assertNotEqual(before, frame_html.value)
self.assertIn("Frame 2 of", caption_html.value)
- def test_wraparound_callouts_do_not_share_title_band(self) -> None:
- cell = 58
- top_y = 118
- cyclic_y = 252
- neg_y = 348
- arrow_start_y = top_y + 48
- current_index = 6
- slot = current_index % 4
- source_x = 72 + current_index * cell
- target_x = 72 + slot * cell
+ def test_wraparound_frame_uses_stable_html_bands(self) -> None:
+ frame = _wrap_compare_frame_html([5, 16, 34, 60, 61, 52, 32], n=4, step=6)
- cyclic_label_x, cyclic_label_y = _wrap_label_geometry(source_x, target_x, arrow_start_y + 10, cyclic_y - 26)
- neg_label_x, neg_label_y = _wrap_label_geometry(source_x, target_x, cyclic_y + 56, neg_y - 26)
-
- self.assertLess(cyclic_label_y, cyclic_y - 28)
- self.assertLess(neg_label_y, neg_y - 28)
- self.assertGreater(cyclic_label_y, top_y + 24)
- self.assertGreater(neg_label_y, cyclic_y + 24)
- self.assertNotAlmostEqual(cyclic_label_y, cyclic_y - 18, delta=1)
- self.assertNotAlmostEqual(neg_label_y, neg_y - 18, delta=1)
- self.assertGreater(cyclic_label_x, 0)
- self.assertGreater(neg_label_x, 0)
-
- def test_wraparound_frame_uses_badge_callouts_for_wrap_labels(self) -> None:
- frame = _wrap_compare_frame_svg([5, 16, 34, 60, 61, 52, 32], n=4, step=6)
-
- self.assertIn('fill="#f4fbff"', frame)
- self.assertIn('fill="#fff0ec"', frame)
+ self.assertNotIn("