diff --git a/ntt_learning/visuals.py b/ntt_learning/visuals.py
index 76f3edb..e6212e4 100644
--- a/ntt_learning/visuals.py
+++ b/ntt_learning/visuals.py
@@ -153,6 +153,25 @@ def _svg_text(x: float, y: float, text: str, *, size: int = 14, weight: str = "4
return f'{escape(text)}'
+def _svg_badge(
+ x: float,
+ y: float,
+ text: str,
+ *,
+ fill: str = "#ffffff",
+ stroke: str = "#cbd2d9",
+ text_fill: str = SVG_INK,
+) -> str:
+ width = max(66.0, 10.0 + len(text) * 7.2)
+ height = 24.0
+ left = x - width / 2
+ top = y - height + 6
+ return f"""
+
+ {escape(text)}
+ """
+
+
def _svg_multiline_text(
x: float,
y: float,
@@ -202,6 +221,12 @@ def _svg_canvas_open(width: int, height: int) -> str:
)
+def _wrap_label_geometry(source_x: float, target_x: float, band_top: float, band_bottom: float) -> tuple[float, float]:
+ label_x = (source_x + target_x) / 2
+ label_y = (band_top + band_bottom) / 2
+ return label_x, label_y
+
+
def _html_token(label: str, value: str, *, fill: str, border: str, text: str = SVG_INK) -> str:
return f"""
')
- parts.append(f'')
+ parts.append(f'')
+ parts.append(f'')
parts.append(f"""
@@ -498,8 +526,8 @@ def _wrap_compare_frame_svg(coefficients: Sequence[int], n: int, step: int) -> s
""")
- parts.append(_svg_text((source_x + target_x) / 2, cyclic_y - 18, cyclic_label, size=12, fill=SVG_BLUE, anchor="middle"))
- parts.append(_svg_text((source_x + target_x) / 2, neg_y - 18, neg_label, size=12, fill=SVG_ACCENT, anchor="middle"))
+ parts.append(_svg_badge(cyclic_label_x, cyclic_label_y, cyclic_label, fill="#f4fbff", stroke=SVG_BLUE, text_fill=SVG_BLUE))
+ parts.append(_svg_badge(neg_label_x, neg_label_y, neg_label, fill="#fff0ec", stroke=SVG_ACCENT, text_fill=SVG_ACCENT))
parts.append("")
return "".join(parts)
diff --git a/tests/test_visuals.py b/tests/test_visuals.py
index 6a7299f..e693f35 100644
--- a/tests/test_visuals.py
+++ b/tests/test_visuals.py
@@ -14,6 +14,8 @@ os.environ.setdefault("MPLCONFIGDIR", str(MPLCONFIGDIR))
from ntt_learning.toy_ntt import fast_ntt_psi_ct_trace, find_psi
from ntt_learning.visuals import (
_convolution_frame_html,
+ _wrap_compare_frame_svg,
+ _wrap_label_geometry,
butterfly_story_player,
direct_ntt_player,
schoolbook_diagonal_player,
@@ -98,6 +100,37 @@ class VisualUxTests(unittest.TestCase):
self.assertNotEqual(before, frame_html.value)
self.assertIn("Frame 2 of", caption_html.value)
+ def test_wraparound_callouts_do_not_share_title_band(self) -> None:
+ cell = 58
+ top_y = 118
+ cyclic_y = 252
+ neg_y = 348
+ arrow_start_y = top_y + 48
+ current_index = 6
+ slot = current_index % 4
+ source_x = 72 + current_index * cell
+ target_x = 72 + slot * cell
+
+ cyclic_label_x, cyclic_label_y = _wrap_label_geometry(source_x, target_x, arrow_start_y + 10, cyclic_y - 26)
+ neg_label_x, neg_label_y = _wrap_label_geometry(source_x, target_x, cyclic_y + 56, neg_y - 26)
+
+ self.assertLess(cyclic_label_y, cyclic_y - 28)
+ self.assertLess(neg_label_y, neg_y - 28)
+ self.assertGreater(cyclic_label_y, top_y + 24)
+ self.assertGreater(neg_label_y, cyclic_y + 24)
+ self.assertNotAlmostEqual(cyclic_label_y, cyclic_y - 18, delta=1)
+ self.assertNotAlmostEqual(neg_label_y, neg_y - 18, delta=1)
+ self.assertGreater(cyclic_label_x, 0)
+ self.assertGreater(neg_label_x, 0)
+
+ def test_wraparound_frame_uses_badge_callouts_for_wrap_labels(self) -> None:
+ frame = _wrap_compare_frame_svg([5, 16, 34, 60, 61, 52, 32], n=4, step=6)
+
+ self.assertIn('fill="#f4fbff"', frame)
+ self.assertIn('fill="#fff0ec"', frame)
+ self.assertIn("+ wrap 1", frame)
+ self.assertIn("- wrap 1", frame)
+
if __name__ == "__main__":
unittest.main()