NTT-learning/ntt_learning/visuals.py

804 lines
27 KiB
Python

"""Blunt visual helpers for the NTT notebooks."""
from __future__ import annotations
from typing import Sequence
import ipywidgets as widgets
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
from .toy_ntt import (
TransformStage,
TransformTrace,
base_multiply_pair,
bit_reversed_order,
forward_ntt_psi,
inverse_ntt_psi,
ntt_psi_exponent_grid,
ntt_psi_matrix,
pairwise_product_grid,
pointwise_multiply,
stage_pairings,
wraparound_contributions,
)
def _value_colors(values: Sequence[int]) -> list[str]:
colors = []
for value in values:
if value < 0:
colors.append("#f08a5d")
elif value == 0:
colors.append("#d9d9d9")
else:
colors.append("#7ad3a8")
return colors
def _draw_value_row(ax, values: Sequence[int], y: float, prefix: str) -> None:
    """Draw a horizontal row of labelled value boxes on ``ax``.

    Box number ``k`` is centred at ``(k, y)``. Its first text line is
    ``prefix`` followed by ``k`` and its second line is the value itself.
    Each box is filled with the sign-based color from ``_value_colors``.
    """
    fills = _value_colors(values)
    for slot, (entry, fill) in enumerate(zip(values, fills)):
        box_style = {
            "boxstyle": "round,pad=0.35",
            "facecolor": fill,
            "edgecolor": "#222222",
            "linewidth": 1.2,
        }
        caption = f"{prefix}{slot}\n{entry}"
        ax.text(
            slot,
            y,
            caption,
            ha="center",
            va="center",
            fontsize=10,
            family="monospace",
            bbox=box_style,
        )
def _annotate_grid(ax, grid: Sequence[Sequence[int]]) -> None:
for row, row_values in enumerate(grid):
for column, value in enumerate(row_values):
ax.text(
column,
row,
str(value),
ha="center",
va="center",
color="#101010",
fontsize=10,
family="monospace",
)
def plot_integer_grid(
    grid: Sequence[Sequence[int]],
    *,
    title: str,
    x_label: str,
    y_label: str,
    cmap: str = "YlGnBu",
):
    """Plot a heatmap with the exact integer values written in every cell.

    Args:
        grid: Non-empty rectangular table of values, indexed
            ``grid[row][column]``.
        title: Bold title placed above the heatmap.
        x_label: Label for the column axis.
        y_label: Label for the row axis.
        cmap: Name of the matplotlib colormap used for the cell colors.

    Returns:
        The matplotlib figure holding the annotated heatmap.

    Raises:
        ValueError: If ``grid`` is empty, its first row is empty, or its rows
            do not all have the same length.
    """
    if not grid or not grid[0]:
        raise ValueError("plot_integer_grid requires a non-empty rectangular grid")
    row_count = len(grid)
    column_count = len(grid[0])
    # Reject ragged input up front so the caller sees the documented
    # ValueError rather than an error from inside the plotting call.
    if any(len(row) != column_count for row in grid):
        raise ValueError("plot_integer_grid requires a non-empty rectangular grid")
    fig, ax = plt.subplots(figsize=(max(6, column_count * 1.2), max(4, row_count * 0.85)))
    ax.imshow(grid, cmap=cmap, aspect="auto")
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    # One tick per cell so the axis indices line up with the written values.
    ax.set_xticks(range(column_count))
    ax.set_yticks(range(row_count))
    _annotate_grid(ax, grid)
    fig.tight_layout()
    return fig
def plot_convolution_grid(
    left: Sequence[int], right: Sequence[int], title: str = "Schoolbook Product Grid"
):
    """Plot the full schoolbook multiplication table and the diagonal sums.

    The upper panel is a heatmap of ``pairwise_product_grid(left, right)``
    with every product written in its cell. The lower panel lists, for each
    anti-diagonal ``d``, the sum of the cells whose row and column indices add
    up to ``d``, labelled ``y0``, ``y1``, and so on.

    Args:
        left: Coefficients indexed along the heatmap rows.
        right: Coefficients indexed along the heatmap columns.
        title: Title of the heatmap panel.

    Returns:
        The matplotlib figure containing both panels.

    Raises:
        ValueError: If either coefficient vector is empty.
    """
    # Fail early with a clear message, matching plot_integer_grid, instead of
    # handing an empty table to the image call.
    if len(left) == 0 or len(right) == 0:
        raise ValueError("plot_convolution_grid requires two non-empty coefficient vectors")
    grid = pairwise_product_grid(left, right)
    row_count = len(left)
    column_count = len(right)
    diagonal_sums = [
        sum(
            grid[row][diagonal - row]
            for row in range(row_count)
            if 0 <= diagonal - row < column_count
        )
        for diagonal in range(row_count + column_count - 1)
    ]
    # gridspec_kw carries the panel height ratio so this also works on
    # matplotlib releases where subplots() has no height_ratios keyword.
    fig, axes = plt.subplots(
        2,
        1,
        figsize=(max(7, column_count * 1.2), 6),
        gridspec_kw={"height_ratios": [3, 1]},
    )
    heatmap_ax, sum_ax = axes
    heatmap_ax.imshow(grid, cmap="YlGnBu", aspect="auto")
    heatmap_ax.set_title(title, fontsize=14, fontweight="bold")
    heatmap_ax.set_xlabel("right coefficient index")
    heatmap_ax.set_ylabel("left coefficient index")
    heatmap_ax.set_xticks(range(column_count))
    heatmap_ax.set_yticks(range(row_count))
    _annotate_grid(heatmap_ax, grid)
    sum_ax.axis("off")
    sum_ax.set_title("Diagonal Sums = Convolution Coefficients", fontsize=12, fontweight="bold", pad=8)
    for index, value in enumerate(diagonal_sums):
        sum_ax.text(
            index,
            0,
            f"y{index}\n{value}",
            ha="center",
            va="center",
            fontsize=10,
            family="monospace",
            bbox={
                "boxstyle": "round,pad=0.35",
                "facecolor": "#f4f1de",
                "edgecolor": "#222222",
                "linewidth": 1.0,
            },
        )
    sum_ax.set_xlim(-0.5, len(diagonal_sums) - 0.5)
    sum_ax.set_ylim(-1, 1)
    fig.tight_layout()
    return fig
def plot_ntt_psi_exponent_heatmap(length: int, title: str = "NTT_psi Exponent Grid"):
    """Plot the exponent pattern 2ij + i used by the direct negative-wrapped NTT.

    The table comes from ``ntt_psi_exponent_grid(length)`` and is rendered by
    ``plot_integer_grid`` with the "YlOrRd" colormap. Columns are labelled as
    output index j and rows as input index i.
    """
    exponents = ntt_psi_exponent_grid(length)
    axis_labels = {"x_label": "output index j", "y_label": "input index i"}
    return plot_integer_grid(exponents, title=title, cmap="YlOrRd", **axis_labels)
def plot_ntt_psi_matrix_heatmap(length: int, modulus: int, psi: int, title: str = "NTT_psi Matrix Values"):
    """Plot the concrete direct transform matrix over Z_q.

    The table comes from ``ntt_psi_matrix(length, modulus, psi)`` and is
    rendered by ``plot_integer_grid`` with the "PuBuGn" colormap. Columns are
    labelled as output index j and rows as input index i.
    """
    matrix = ntt_psi_matrix(length, modulus, psi)
    axis_labels = {"x_label": "output index j", "y_label": "input index i"}
    return plot_integer_grid(matrix, title=title, cmap="PuBuGn", **axis_labels)
def plot_wraparound(
    coefficients: Sequence[int],
    n: int,
    *,
    negacyclic: bool = True,
    title: str | None = None,
):
    """Plot how the tail wraps back into degree < n.

    The top row shows the incoming coefficients and the bottom row shows the
    ``"total"`` of each slot reported by ``wraparound_contributions``. Every
    contribution gets an arrow from its source coefficient down to its slot.
    Arrows with a negative ``"sign"`` are red and marked "-"; the others are
    teal and marked "+". The label also carries the contribution's
    ``"wraps"`` value.

    Args:
        coefficients: Values drawn in the top row.
        n: Passed through to ``wraparound_contributions`` and used for the
            x-axis extent.
        negacyclic: Passed through to ``wraparound_contributions``; also
            selects the default title.
        title: Figure title; defaults to "Negacyclic Folding" or
            "Cyclic Folding".

    Returns:
        The matplotlib figure.
    """
    folded = wraparound_contributions(coefficients, n=n, negacyclic=negacyclic)
    if title is not None:
        heading = title
    elif negacyclic:
        heading = "Negacyclic Folding"
    else:
        heading = "Cyclic Folding"

    fig, ax = plt.subplots(figsize=(max(8, len(coefficients) * 1.1), 5.5))
    ax.set_title(heading, fontsize=14, fontweight="bold")
    ax.axis("off")

    top_y = 2.4
    bottom_y = 0.4
    _draw_value_row(ax, coefficients, top_y, "x^")
    _draw_value_row(ax, [entry["total"] for entry in folded], bottom_y, "slot ")

    for slot, entry in enumerate(folded):
        for piece in entry["contributions"]:
            origin = piece["source_index"]
            if piece["sign"] < 0:
                tint, marker = "#d1495b", "-"
            else:
                tint, marker = "#2a9d8f", "+"
            ax.annotate(
                "",
                xy=(slot, bottom_y + 0.3),
                xytext=(origin, top_y - 0.25),
                arrowprops={"arrowstyle": "->", "color": tint, "linewidth": 2.0},
            )
            # The label sits at the arrow's horizontal midpoint, slightly
            # above the vertical midpoint between the two rows.
            ax.text(
                (slot + origin) / 2,
                (top_y + bottom_y) / 2 + 0.25,
                f"{marker} wrap {piece['wraps']}",
                ha="center",
                va="center",
                fontsize=9,
                color=tint,
                family="monospace",
            )

    ax.set_xlim(-0.8, max(len(coefficients), n) - 0.2)
    ax.set_ylim(-0.4, 3.2)
    fig.tight_layout()
    return fig
def plot_vector_comparison(
    left: Sequence[int],
    right: Sequence[int],
    *,
    left_label: str = "left",
    right_label: str = "right",
    title: str = "Vector Comparison",
):
    """Plot two vectors slot-by-slot with explicit differences.

    Three stacked panels show ``left``, ``right`` and ``right - left``. Each
    box carries the slot index above the value. Panels whose label is not
    "delta" use a fixed row color; a panel labelled "delta" colors each box
    by the sign of its value.

    Args:
        left: First vector.
        right: Second vector; must have the same length as ``left``.
        left_label: Heading of the first panel.
        right_label: Heading of the second panel.
        title: Overall figure title.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If the two vectors differ in length.
    """
    if len(left) != len(right):
        raise ValueError("plot_vector_comparison requires equal-length vectors")

    differences = [int(b - a) for a, b in zip(left, right)]
    panels = [
        (left_label, left, "#edf6f9"),
        (right_label, right, "#fff3b0"),
        ("delta", differences, "#f5cac3"),
    ]
    fig, axes = plt.subplots(3, 1, figsize=(max(8, len(left) * 1.3), 7), height_ratios=[1, 1, 1])
    for ax, (label, values, row_color) in zip(axes, panels):
        ax.axis("off")
        ax.set_title(label, fontsize=12, fontweight="bold", pad=6)
        for index, value in enumerate(values):
            if label != "delta":
                fill = row_color
            else:
                fill = _value_colors([value])[0]
            ax.text(
                index,
                0,
                f"{index}\n{value}",
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.32",
                    "facecolor": fill,
                    "edgecolor": "#222222",
                    "linewidth": 1.1,
                },
            )
        ax.set_xlim(-0.8, len(values) - 0.2)
        ax.set_ylim(-0.8, 0.8)
    fig.suptitle(title, fontsize=14, fontweight="bold")
    fig.tight_layout()
    return fig
def plot_bit_reversal_mapping(length: int, title: str = "Normal Order To Bit-Reversed Order"):
    """Plot the bit-reversal permutation as explicit wires.

    The left column (headed "NO") lists each index with its binary form. The
    right column (headed "BO") lists each value of
    ``bit_reversed_indices(length)`` the same way, placed on the row given by
    that value. A wire joins every index to its permuted partner.

    Args:
        length: Number of indices; must be a positive power of two.
        title: Figure title.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If ``length`` is not a positive power of two.
    """
    is_power_of_two = length > 0 and (length & (length - 1)) == 0
    if not is_power_of_two:
        raise ValueError("plot_bit_reversal_mapping requires a power-of-two length")

    from .toy_ntt import bit_reversed_indices

    permutation = bit_reversed_indices(length)
    # Number of binary digits used when printing each index.
    width = length.bit_length() - 1

    fig, ax = plt.subplots(figsize=(8, max(4, length * 0.65)))
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axis("off")

    source_box = {"boxstyle": "round,pad=0.25", "facecolor": "#edf6f9", "edgecolor": "#264653"}
    target_box = {"boxstyle": "round,pad=0.25", "facecolor": "#fff3b0", "edgecolor": "#9c6644"}
    for index, target in enumerate(permutation):
        ax.text(
            0,
            -index,
            f"{index:>2} | {index:0{width}b}",
            ha="center",
            va="center",
            family="monospace",
            bbox=dict(source_box),
        )
        ax.text(
            4,
            -target,
            f"{target:>2} | {target:0{width}b}",
            ha="center",
            va="center",
            family="monospace",
            bbox=dict(target_box),
        )
        ax.plot([0.6, 3.4], [-index, -target], color="#7f5539", linewidth=2.2, alpha=0.9)

    for x, heading in ((0, "NO"), (4, "BO")):
        ax.text(x, 1, heading, ha="center", va="center", fontsize=12, fontweight="bold")
    ax.set_xlim(-1.2, 5.2)
    ax.set_ylim(-length + 0.2, 1.8)
    fig.tight_layout()
    return fig
def plot_butterfly_network(trace: TransformTrace, title: str | None = None):
    """Plot the whole staged network with pair links visible at each stage.

    The first column holds ``trace.input_values`` and each later column holds
    one stage's ``output_values``. Between neighbouring columns every row gets
    a faint grey wire. Each pair in ``stage.pairings`` is overdrawn in color,
    with a vertical link labelled by the matching entry of ``stage.zetas``.

    Args:
        trace: Transform trace supplying ``algorithm``, ``input_values`` and
            ``stages``.
        title: Figure title; defaults to the upper-cased algorithm name
            followed by "Butterfly Network".

    Returns:
        The matplotlib figure.
    """
    heading = f"{trace.algorithm.upper()} Butterfly Network" if title is None else title
    columns = [trace.input_values, *(stage.output_values for stage in trace.stages)]
    row_total = len(trace.input_values)

    fig, ax = plt.subplots(figsize=(max(10, len(columns) * 2.4), max(5, row_total * 0.8)))
    ax.set_title(heading, fontsize=14, fontweight="bold")
    ax.axis("off")

    x_positions = [position * 2.4 for position in range(len(columns))]
    palette = ["#264653", "#2a9d8f", "#e76f51", "#8d99ae", "#c1121f", "#3a86ff"]

    for column_index, (x, values) in enumerate(zip(x_positions, columns)):
        for row_index, value in enumerate(values):
            ax.text(
                x,
                -row_index,
                str(value),
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.26",
                    "facecolor": _value_colors([value])[0],
                    "edgecolor": "#222222",
                    "linewidth": 1.0,
                },
            )
        column_label = f"s{column_index}" if column_index else "input"
        ax.text(x, 1.1, column_label, ha="center", va="center", fontsize=11, fontweight="bold")

        if column_index > 0:
            # Wires run from the previous column into this one.
            wire_x = [x_positions[column_index - 1] + 0.35, x - 0.35]
            x_mid = (x_positions[column_index - 1] + x) / 2
            stage = trace.stages[column_index - 1]

            for wire in range(len(values)):
                ax.plot(wire_x, [-wire, -wire], color="#b0b0b0", linewidth=0.9, alpha=0.65)

            for pair_index, ((left, right), zeta) in enumerate(zip(stage.pairings, stage.zetas)):
                color = palette[pair_index % len(palette)]
                for endpoint in (left, right):
                    ax.plot(wire_x, [-endpoint, -endpoint], color=color, linewidth=2.2)
                ax.plot([x_mid, x_mid], [-left, -right], color=color, linewidth=2.6, alpha=0.95)
                ax.text(
                    x_mid,
                    -((left + right) / 2),
                    f"zeta={zeta}",
                    ha="center",
                    va="center",
                    fontsize=8,
                    family="monospace",
                    bbox={
                        "boxstyle": "round,pad=0.18",
                        "facecolor": "#ffffff",
                        "edgecolor": color,
                        "linewidth": 1.0,
                    },
                )

    ax.set_xlim(-1.1, x_positions[-1] + 1.1)
    ax.set_ylim(-row_total + 0.2, 1.8)
    fig.tight_layout()
    return fig
def plot_stage_pairing_map(
    length: int,
    block_size: int,
    *,
    title: str | None = None,
):
    """Plot which indices talk to each other in one butterfly stage.

    Indices ``0 .. length - 1`` are stacked in a column on the left. For each
    pair returned by ``stage_pairings(length, block_size)`` a colored line is
    drawn between the two row heights, and a "left <-> right" label is placed
    on the right at the pair's mean row.

    Args:
        length: Number of indices to list.
        block_size: Block size passed to ``stage_pairings``.
        title: Figure title; a default naming ``length`` and ``block_size``
            is used when omitted.

    Returns:
        The matplotlib figure.
    """
    heading = title if title is not None else f"Stage Pairing Map (n={length}, block={block_size})"
    pairs = stage_pairings(length, block_size)

    fig, ax = plt.subplots(figsize=(10, max(4.2, length * 0.55)))
    ax.set_title(heading, fontsize=14, fontweight="bold")
    ax.axis("off")

    for slot in range(length):
        ax.text(
            0,
            -slot,
            f"{slot}",
            ha="center",
            va="center",
            fontsize=10,
            family="monospace",
            bbox={"boxstyle": "round,pad=0.25", "facecolor": "#edf6f9", "edgecolor": "#264653"},
        )

    palette = ["#264653", "#2a9d8f", "#e76f51", "#8d99ae", "#c1121f", "#3a86ff"]
    for pair_number, (first, second) in enumerate(pairs):
        tint = palette[pair_number % len(palette)]
        ax.plot([0.6, 4.2], [-first, -second], color=tint, linewidth=2.4)
        ax.text(
            5.4,
            -((first + second) / 2),
            f"{first} <-> {second}",
            ha="center",
            va="center",
            fontsize=9,
            family="monospace",
            bbox={"boxstyle": "round,pad=0.22", "facecolor": "#ffffff", "edgecolor": tint},
        )

    for x, caption in ((0, "indices"), (5.4, "pairs")):
        ax.text(x, 1.0, caption, ha="center", va="center", fontsize=11, fontweight="bold")
    ax.set_xlim(-1.0, 6.8)
    ax.set_ylim(-length + 0.2, 1.8)
    fig.tight_layout()
    return fig
def plot_stage_schedule(length: int, title: str | None = None):
    """Plot the full stage schedule for a power-of-two transform length.

    One table row is drawn per stage, with block sizes 2, 4, ... up to
    ``length``. Each row shows the stage number, the block size, the pair
    distance (half the block size), the number of pairs reported by
    ``stage_pairings`` and a one-line explanation.

    Args:
        length: Transform length; must be a positive power of two.
        title: Figure title; a default naming ``length`` is used when omitted.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If ``length`` is not a positive power of two.
    """
    is_power_of_two = length > 0 and (length & (length - 1)) == 0
    if not is_power_of_two:
        raise ValueError("plot_stage_schedule requires a power-of-two length")
    heading = f"Butterfly Stage Schedule For n={length}" if title is None else title

    # For length == 2**k the stages are numbered 1..k and stage s works on
    # blocks of size 2**s.
    schedule = []
    for stage_number in range(1, length.bit_length()):
        block = 1 << stage_number
        schedule.append((stage_number, block, block // 2, len(stage_pairings(length, block))))

    fig, ax = plt.subplots(figsize=(11, max(4.5, len(schedule) * 0.95)))
    ax.set_title(heading, fontsize=14, fontweight="bold")
    ax.axis("off")

    x_positions = [0, 2.0, 4.0, 6.0, 9.2]
    headers = ["stage", "block", "distance", "pairs", "what happens"]
    for x, header in zip(x_positions, headers):
        ax.text(x, 1.0, header, ha="center", va="center", fontsize=11, fontweight="bold")

    for row_number, (stage_number, block, distance, pair_count) in enumerate(schedule):
        cells = [
            str(stage_number),
            str(block),
            str(distance),
            str(pair_count),
            f"indices {distance} apart talk inside blocks of {block}",
        ]
        for x, cell in zip(x_positions, cells):
            # Only the wide explanation column sits past x = 8; it gets a
            # white background.
            fill = "#edf6f9" if x < 8 else "#ffffff"
            ax.text(
                x,
                -row_number,
                cell,
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.24",
                    "facecolor": fill,
                    "edgecolor": "#264653",
                    "linewidth": 1.0,
                },
            )

    ax.set_xlim(-1.0, 12.3)
    ax.set_ylim(-len(schedule) + 0.2, 1.7)
    fig.tight_layout()
    return fig
def plot_stage(stage: TransformStage, title: str | None = None):
    """Plot one explicit butterfly stage with input and output rows.

    ``stage.input_values`` are drawn as the upper row (prefix "i") and
    ``stage.output_values`` as the lower row (prefix "o"). Each pair in
    ``stage.pairings`` gets a colored bracket joining its two columns and a
    label naming the pair and the matching entry of ``stage.zetas``.
    ``stage.note`` is printed under the rows.

    Args:
        stage: The stage to draw.
        title: Figure title; defaults to the upper-cased algorithm name
            followed by "Stage" and ``stage.stage_index``.

    Returns:
        The matplotlib figure.
    """
    heading = title if title is not None else f"{stage.algorithm.upper()} Stage {stage.stage_index}"
    slot_count = len(stage.input_values)

    fig, ax = plt.subplots(figsize=(max(8, slot_count * 1.35), 5.8))
    ax.set_title(heading, fontsize=14, fontweight="bold")
    ax.axis("off")

    input_y = 2.6
    output_y = 0.5
    _draw_value_row(ax, stage.input_values, input_y, "i")
    _draw_value_row(ax, stage.output_values, output_y, "o")

    palette = ["#264653", "#2a9d8f", "#e76f51", "#8d99ae", "#c1121f", "#3a86ff"]
    bracket_top = input_y - 0.45
    bracket_bottom = output_y + 0.55
    for pair_number, ((first, second), zeta) in enumerate(zip(stage.pairings, stage.zetas)):
        tint = palette[pair_number % len(palette)]
        # Horizontal bar under the input row, then one leg down per member.
        ax.plot([first, second], [bracket_top, bracket_top], color=tint, linewidth=2.5)
        for leg in (first, second):
            ax.plot([leg, leg], [bracket_top, bracket_bottom], color=tint, linewidth=1.5, alpha=0.85)
        ax.text(
            (first + second) / 2,
            1.55,
            f"pair {first}-{second}\nzeta={zeta}",
            ha="center",
            va="center",
            fontsize=10,
            family="monospace",
            bbox={
                "boxstyle": "round,pad=0.35",
                "facecolor": "#ffffff",
                "edgecolor": tint,
                "linewidth": 1.4,
            },
        )

    ax.text(
        slot_count / 2 - 0.5,
        -0.05,
        stage.note,
        ha="center",
        va="center",
        fontsize=10,
        color="#333333",
    )
    ax.set_xlim(-0.8, slot_count - 0.2)
    ax.set_ylim(-0.5, 3.3)
    fig.tight_layout()
    return fig
def plot_transform_pipeline(
    left: Sequence[int],
    right: Sequence[int],
    *,
    modulus: int,
    psi: int,
    title: str = "Transform-Domain Multiply Pipeline",
):
    """Plot the end-to-end direct NTT_psi multiply pipeline.

    Six columns are drawn left to right: both inputs, their
    ``forward_ntt_psi`` images, the ``pointwise_multiply`` of those images,
    and the ``inverse_ntt_psi`` of that product. Every value box is colored
    by sign and a grey arrow links each column to the next.

    Args:
        left: First input vector.
        right: Second input vector.
        modulus: Modulus passed to the transform helpers.
        psi: Root parameter passed to the transform helpers.
        title: Figure title.

    Returns:
        The matplotlib figure.
    """
    left_hat = forward_ntt_psi(left, modulus, psi)
    right_hat = forward_ntt_psi(right, modulus, psi)
    product_hat = pointwise_multiply(left_hat, right_hat, modulus)
    recovered = inverse_ntt_psi(product_hat, modulus, psi)

    lane_labels = ["left", "right", "left_hat", "right_hat", "pointwise", "inverse"]
    lane_values = [list(left), list(right), left_hat, right_hat, product_hat, recovered]
    lane_count = len(lane_labels)

    fig, ax = plt.subplots(figsize=(max(10, lane_count * 2.15), max(4.6, len(left) * 0.85)))
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axis("off")

    for lane_number, (label, values) in enumerate(zip(lane_labels, lane_values)):
        x = lane_number * 2.2
        for row_number, value in enumerate(values):
            ax.text(
                x,
                -row_number,
                str(value),
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.24",
                    "facecolor": _value_colors([value])[0],
                    "edgecolor": "#222222",
                    "linewidth": 1.0,
                },
            )
        ax.text(x, 1.1, label, ha="center", va="center", fontsize=10, fontweight="bold")
        is_last_lane = lane_number >= lane_count - 1
        if not is_last_lane:
            arrow_y = -len(values) / 2 + 0.4
            ax.annotate(
                "",
                xy=(x + 1.5, arrow_y),
                xytext=(x + 0.6, arrow_y),
                arrowprops={"arrowstyle": "->", "color": "#6c757d", "linewidth": 1.8},
            )

    # Operation captions sit above the lanes that hold each step's result.
    for caption_x, caption in ((4.4, "NTT_psi"), (8.8, "slotwise *"), (11.0, "INTT_psi")):
        ax.text(caption_x, 1.6, caption, ha="center", va="center", fontsize=10, family="monospace")

    ax.set_xlim(-1.0, (lane_count - 1) * 2.2 + 1.1)
    ax.set_ylim(-len(left) + 0.2, 2.0)
    fig.tight_layout()
    return fig
def plot_base_multiply_pair_diagram(
    left: Sequence[int],
    right: Sequence[int],
    *,
    zeta: int,
    modulus: int,
    title: str = "Base Multiplication On A Degree-1 Pair",
):
    """Plot the two-term base multiplication block used in Kyber-style explanations.

    Three columns of boxes show ``left``, ``right`` and the result of
    ``base_multiply_pair(left, right, zeta, modulus)``. Arrows run from the
    two input columns to the output column. Two formula boxes print the
    expressions shown for ``c0`` and ``c1`` along with the computed entries.

    Args:
        left: Two-entry input vector.
        right: Two-entry input vector.
        zeta: Value passed to ``base_multiply_pair`` and printed in the figure.
        modulus: Modulus passed to ``base_multiply_pair`` and printed in the
            formula boxes.
        title: Figure title.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: If either input does not have exactly two entries.
    """
    if len(left) != 2 or len(right) != 2:
        raise ValueError("plot_base_multiply_pair_diagram expects two 2-entry vectors")
    result = base_multiply_pair(left, right, zeta, modulus)

    fig, ax = plt.subplots(figsize=(9, 4.6))
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axis("off")

    left_x = 0
    right_x = 2.8
    out_x = 6.8
    ys = [0.9, -0.7]
    for x, label, values, facecolor in [
        (left_x, "left", left, "#edf6f9"),
        (right_x, "right", right, "#fff3b0"),
        (out_x, "out", result, "#d8f3dc"),
    ]:
        ax.text(x, 1.8, label, ha="center", va="center", fontsize=12, fontweight="bold")
        for index, (y, value) in enumerate(zip(ys, values)):
            ax.text(
                x,
                y,
                f"{label}[{index}] = {value}",
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.3",
                    "facecolor": facecolor,
                    "edgecolor": "#222222",
                    "linewidth": 1.0,
                },
            )

    # annotate() draws from xytext (tail) to xy (head). The head must sit on
    # the output side so the arrows read as inputs flowing into the result,
    # in line with the arrow direction used by the other diagrams here.
    for source_x, color in ((left_x, "#355070"), (right_x, "#6d597a")):
        for source_y, target_y in ((0.1, 0.9), (-0.7, -0.7)):
            ax.annotate(
                "",
                xy=(out_x - 1.0, target_y),
                xytext=(source_x + 0.9, source_y),
                arrowprops={"arrowstyle": "->", "linewidth": 2.0, "color": color},
            )

    ax.text(
        4.8,
        1.05,
        f"c0 = a0*b0 + zeta*a1*b1 mod {modulus}\n= {result[0]}",
        ha="center",
        va="center",
        fontsize=10,
        family="monospace",
        bbox={"boxstyle": "round,pad=0.32", "facecolor": "#ffffff", "edgecolor": "#355070"},
    )
    ax.text(
        4.8,
        -1.0,
        f"c1 = a0*b1 + a1*b0 mod {modulus}\n= {result[1]}",
        ha="center",
        va="center",
        fontsize=10,
        family="monospace",
        bbox={"boxstyle": "round,pad=0.32", "facecolor": "#ffffff", "edgecolor": "#6d597a"},
    )
    ax.text(4.8, 0.0, f"zeta = {zeta}", ha="center", va="center", fontsize=10, family="monospace")
    ax.set_xlim(-1.0, 8.2)
    ax.set_ylim(-2.0, 2.2)
    fig.tight_layout()
    return fig
def plot_root_order_comparison(samples: Sequence[tuple[int, int]], title: str = "Root Existence Check"):
    """Plot which moduli allow n-th and 2n-th root stories.

    Each ``(n, q)`` sample becomes one table row showing ``n``, ``q`` and two
    yes/no cells: whether ``n`` divides ``q - 1`` and whether ``2n`` divides
    ``q - 1``. "yes" cells are green, "no" cells are red and the plain number
    cells are light blue.

    Args:
        samples: ``(n, q)`` pairs, one per table row.
        title: Figure title.

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(10, max(4.5, len(samples) * 0.75)))
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.axis("off")

    x_positions = [0, 2, 4.2, 6.8]
    for x, header in zip(x_positions, ["n", "q", "n | q-1", "2n | q-1"]):
        ax.text(x, 1.2, header, ha="center", va="center", fontsize=11, fontweight="bold")

    verdict_fill = {"yes": "#d8f3dc", "no": "#f5cac3"}
    for row_number, (n, q) in enumerate(samples):
        divides_n = "yes" if (q - 1) % n == 0 else "no"
        divides_2n = "yes" if (q - 1) % (2 * n) == 0 else "no"
        cells = [str(n), str(q), divides_n, divides_2n]
        for x, cell in zip(x_positions, cells):
            ax.text(
                x,
                -row_number,
                cell,
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.25",
                    "facecolor": verdict_fill.get(cell, "#edf6f9"),
                    "edgecolor": "#222222",
                },
            )

    ax.set_xlim(-1.0, 8.0)
    ax.set_ylim(-len(samples) + 0.2, 1.8)
    fig.tight_layout()
    return fig
def plot_trace_overview(trace: TransformTrace, title: str | None = None):
    """Plot every stage output as a column of values.

    The first column, labelled "input", holds ``trace.input_values``. Column
    ``k`` after it, labelled "stage k", holds the ``output_values`` of the
    k-th stage. Every value box is colored by sign.

    Args:
        trace: Transform trace supplying ``algorithm``, ``input_values`` and
            ``stages``.
        title: Figure title; defaults to the upper-cased algorithm name
            followed by "Trace Overview".

    Returns:
        The matplotlib figure.
    """
    heading = f"{trace.algorithm.upper()} Trace Overview" if title is None else title
    columns = [trace.input_values, *(stage.output_values for stage in trace.stages)]
    row_total = len(trace.input_values)

    fig, ax = plt.subplots(figsize=(max(9, len(columns) * 2.0), max(4.5, row_total * 0.7)))
    ax.set_title(heading, fontsize=14, fontweight="bold")
    ax.axis("off")

    for column_number, values in enumerate(columns):
        x = column_number * 2.0
        for row_number, value in enumerate(values):
            ax.text(
                x,
                -row_number,
                str(value),
                ha="center",
                va="center",
                fontsize=10,
                family="monospace",
                bbox={
                    "boxstyle": "round,pad=0.25",
                    "facecolor": _value_colors([value])[0],
                    "edgecolor": "#222222",
                    "linewidth": 1.0,
                },
            )
        column_label = f"stage {column_number}" if column_number != 0 else "input"
        ax.text(x, 1, column_label, ha="center", va="center", fontsize=11, fontweight="bold")

    ax.set_xlim(-1.0, (len(columns) - 1) * 2.0 + 1.0)
    ax.set_ylim(-row_total + 0.2, 1.8)
    fig.tight_layout()
    return fig
def interactive_trace(trace: TransformTrace, title: str | None = None):
    """Return a slider-based stage explorer for a transform trace.

    The widget stacks an HTML heading, a "Stage" slider and an output area.
    Moving the slider redraws the selected stage with ``plot_stage`` and
    prints one line per butterfly pair showing its inputs, outputs and zeta.
    A trace without stages shows a short message and a disabled slider.

    Args:
        trace: Transform trace supplying ``algorithm`` and ``stages``.
        title: Heading text; defaults to the upper-cased algorithm name
            followed by "Stage Explorer".

    Returns:
        An ``ipywidgets.VBox`` containing the heading, slider and output.
    """
    import html  # local import: only needed to escape the heading below

    if title is None:
        title = f"{trace.algorithm.upper()} Stage Explorer"
    has_stages = len(trace.stages) > 0
    slider = widgets.IntSlider(
        value=1,
        min=1,
        max=max(1, len(trace.stages)),
        step=1,
        description="Stage",
        continuous_update=False,
        disabled=not has_stages,
    )
    output = widgets.Output()

    def render(stage_index: int) -> None:
        """Redraw the output area for the 1-based ``stage_index``."""
        with output:
            clear_output(wait=True)
            # The slider always offers position 1, even for an empty trace;
            # without this guard the initial render would raise IndexError.
            if not trace.stages:
                print("This trace has no stages to display.")
                return
            stage = trace.stages[stage_index - 1]
            fig = plot_stage(stage, title=f"{title} | Stage {stage_index}")
            display(fig)
            # Close the figure once shown so repeated slider moves do not
            # accumulate open figures.
            plt.close(fig)
            rows = [
                f"pair {(left, right)}: inputs=({stage.input_values[left]}, {stage.input_values[right]}) "
                f"-> outputs=({stage.output_values[left]}, {stage.output_values[right]}) | zeta={zeta}"
                for (left, right), zeta in zip(stage.pairings, stage.zetas)
            ]
            print("\n".join(rows))

    slider.observe(lambda change: render(change["new"]), names="value")
    render(slider.value)
    # The title is plain text, so escape it before embedding it in markup.
    heading = widgets.HTML(f"<h4>{html.escape(title)}</h4>")
    return widgets.VBox([heading, slider, output])
def show_trace(trace: TransformTrace, title: str | None = None):
    """Build the interactive trace explorer, display it, and return it.

    The widget comes from ``interactive_trace(trace, title=title)``. It is
    passed to ``display`` before being handed back to the caller.
    """
    explorer = interactive_trace(trace, title=title)
    display(explorer)
    return explorer