from klotho.topos.graphs import Graph, Tree
from klotho.topos.collections.sets import CombinationSet, PartitionSet
from klotho.topos.graphs.lattices import Lattice
from klotho.thetos.parameters.parameter_fields import ParameterField
from klotho.chronos.rhythm_trees import RhythmTree
from klotho.chronos.temporal_units import TemporalUnit, TemporalUnitSequence, TemporalBlock
from klotho.tonos.systems.combination_product_sets import CombinationProductSet, MasterSet
from klotho.tonos.scales import Scale
from klotho.tonos.chords import Chord, Voicing
from klotho.dynatos.dynamics import DynamicRange
from klotho.dynatos.envelopes import Envelope
from klotho.thetos.composition.compositional import CompositionalUnit
from klotho.thetos.parameters.parameter_tree import ParameterTree
from klotho.topos.collections.sequences import Pattern
try:
import networkx as nx
except ImportError:
class _NetworkXCompat:
def __getattr__(self, name):
raise ImportError(f"NetworkX function '{name}' not available. Visualization may be limited.")
nx = _NetworkXCompat()
import matplotlib.pyplot as plt
from fractions import Fraction
import numpy as np
import plotly.graph_objects as go
import math
from sklearn.manifold import MDS, SpectralEmbedding
from ._dispatch import _plot_rt, _plot_master_set, _plot_cps, _reduce_positions, _cps_node_positions, _plot_lattice
from ._dispatch import KlothoPlot
from ._plot_pattern import plot_pattern
__all__ = ['plot']
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not part of the module)
def plot(obj, **kwargs):
    """
    Dispatch a Klotho object to its type-specific plotting routine.

    Animatable types (``RhythmTree``, ``Lattice``, ``MasterSet``,
    ``CombinationProductSet``, ``CompositionalUnit``, ``TemporalUnit``)
    are wrapped in a :class:`KlothoPlot` built from the plotting function,
    the target object and the keyword arguments.  Jupyter renders the
    static figure automatically via ``_repr_html_``; call ``.play()`` on
    the result to trigger animated playback with audio.  Passing
    ``animate=True`` bypasses the wrapper: the plotting function is called
    immediately with ``animate=True`` and its figure is displayed.

    For every other supported type the figure is displayed immediately via
    ``IPython.display.display`` (when IPython is importable) and ``None``
    is returned.

    Parameters
    ----------
    obj : object
        Object to plot. Supported types include ``Tree``,
        ``RhythmTree``, ``CombinationSet``, ``CombinationProductSet``,
        ``DynamicRange``, ``Envelope``, ``Lattice``, ``Scale``,
        ``Chord``, ``Voicing``, and graph-like objects (anything exposing
        ``nodes`` and ``edges`` attributes).
    **kwargs
        Keyword arguments forwarded to the type-specific plotting
        function.

    Returns
    -------
    KlothoPlot or None
        A ``KlothoPlot`` wrapper for animatable types, otherwise
        ``None``.

    Raises
    ------
    NotImplementedError
        For ``TemporalUnitSequence`` and ``TemporalBlock`` objects.
    TypeError
        If the object type is not supported for plotting.
    """
    def _show(fig):
        # Best-effort inline display; a missing IPython is silently ignored.
        if fig is None:
            return
        try:
            from IPython.display import display as ipy_display, HTML
            import matplotlib.figure
            if hasattr(fig, 'to_html'):
                # Plotly-style figure: embed as self-contained HTML.
                html_str = fig.to_html(full_html=False, include_plotlyjs=True)
                ipy_display(HTML(html_str))
            else:
                ipy_display(fig)
                if isinstance(fig, matplotlib.figure.Figure):
                    # Close so the figure is not rendered a second time.
                    plt.close(fig)
        except ImportError:
            pass

    def _wrap(plot_fn, target, kw):
        # ``animate`` is consumed here and never forwarded inside ``kw``.
        if kw.pop('animate', False):
            return _show(plot_fn(target, animate=True, **kw))
        return KlothoPlot(plot_fn, target, kw)

    # --- Graph family (most specific subclasses first) -------------------
    if isinstance(obj, Graph):
        if isinstance(obj, Tree):
            if isinstance(obj, RhythmTree):
                return _wrap(_plot_rt, obj, dict(kwargs))
            if isinstance(obj, ParameterTree):
                return _show(_plot_parameter_tree(obj, **kwargs))
            return _show(_plot_tree(obj, **kwargs))
        if isinstance(obj, ParameterField):
            return _show(_plot_field(obj, **kwargs))
        if isinstance(obj, Lattice):
            return _wrap(_plot_lattice, obj, dict(kwargs))
        return _show(_plot_graph(obj._graph, **kwargs))

    # --- Combination sets -------------------------------------------------
    if isinstance(obj, MasterSet):
        return _wrap(_plot_master_set, obj, dict(kwargs))
    if isinstance(obj, CombinationSet):
        if isinstance(obj, CombinationProductSet):
            return _wrap(_plot_cps, obj, dict(kwargs))
        return _show(_plot_cs(obj, **kwargs))

    # --- Pitch collections, partitions, dynamics, envelopes, patterns ------
    if isinstance(obj, (Scale, Chord, Voicing)):
        return _show(_plot_scale_chord(obj, **kwargs))
    if isinstance(obj, PartitionSet):
        return _show(_plot_graph(obj.graph._graph, **kwargs))
    if isinstance(obj, DynamicRange):
        return _show(_plot_dynamic_range(obj, **kwargs))
    if isinstance(obj, Envelope):
        return _show(_plot_envelope(obj, **kwargs))
    if isinstance(obj, Pattern):
        return _show(plot_pattern(obj, **kwargs))

    # --- Temporal objects: plot the underlying rhythm tree -----------------
    if isinstance(obj, (CompositionalUnit, TemporalUnit)):
        return _wrap(lambda o, **kw: _plot_rt(o._rt, audio_source=o, **kw), obj, dict(kwargs))
    if isinstance(obj, TemporalUnitSequence):
        raise NotImplementedError("Plotting for temporal unit sequences not yet implemented")
    if isinstance(obj, TemporalBlock):
        raise NotImplementedError("Plotting for temporal blocks not yet implemented")

    # --- Duck-typed graphs ------------------------------------------------
    if hasattr(obj, 'nodes') and hasattr(obj, 'edges'):
        return _show(_plot_graph(obj, **kwargs))

    raise TypeError(f"Unsupported object type for plotting: {type(obj)}")
def _plot_parameter_tree(tree: ParameterTree, attributes: list[str] | None = None, figsize: tuple[float, float] = (20, 5),
                         invert: bool = True, output_file: str | None = None) -> go.Figure:
    """
    Build a Plotly figure of a ParameterTree using each node's active items.

    Works like ``_plot_tree`` but reads node content through
    ``tree[node].active_items()``, so only active (non-muted) attributes
    are shown.  A node's visible text is its ``"defName"`` item (if any);
    the hover text is either the node label or the requested attributes.

    Parameters
    ----------
    tree : ParameterTree
        ParameterTree instance to visualize.
    attributes : list of str or None, optional
        Attributes to list in the hover text instead of the label.  The
        special values ``"node_id"``, ``"node"`` and ``"id"`` show the
        node identifier.
    figsize : tuple of float, optional
        Figure width and height; each is multiplied by 100 to obtain the
        pixel size.
    invert : bool, optional
        When ``True``, places the root at the top of the diagram.
    output_file : str or None, optional
        Destination path.  ``.html`` paths are written with
        ``write_html``; anything else with ``write_image``.

    Returns
    -------
    go.Figure
        The assembled figure (returned whether or not it was saved).
    """
    def _subtree_depth(graph, node, parent=None, current_depth=0):
        # Depth of the deepest descendant, counted from ``current_depth``.
        kids = _get_children(graph, node, parent)
        if not kids:
            return current_depth
        return max(_subtree_depth(graph, kid, node, current_depth + 1) for kid in kids)

    def _widest_level(graph, start, parent=None):
        # Largest number of nodes found on any single level below ``start``.
        level_counts = {}

        def _tally(node, level=0, parent=None):
            level_counts[level] = level_counts.get(level, 0) + 1
            for kid in _get_children(graph, node, parent):
                _tally(kid, level + 1, node)

        _tally(start, parent=parent)
        return max(level_counts.values()) if level_counts else 1

    def _layout(graph, node, width=1.5, height=1.0, xcenter=0.5, pos=None, parent=None,
                depth=0, inverted=True, vert_gap=None):
        # Recursive top-down placement; the first call sizes the whole layout.
        if pos is None:
            max_depth = _subtree_depth(graph, node)
            vert_gap = height / max(max_depth, 1) if max_depth > 0 else height
            width = max(2.5, 1.5 * _widest_level(graph, node))
            pos = {node: (xcenter, height if inverted else 0)}
        elif inverted:
            pos[node] = (xcenter, height - (depth * vert_gap))
        else:
            pos[node] = (xcenter, depth * vert_gap)

        kids = _get_children(graph, node, parent)
        if not kids:
            return pos

        chain_depths = {kid: _subtree_depth(graph, kid, parent=node) for kid in kids}
        total_depth = sum(chain_depths.values())
        only_child = len(kids) == 1
        dx = width * 0.8 if only_child else width / len(kids)
        next_x = xcenter - width/2 + dx/2
        for kid in kids:
            # Deeper branches receive a larger share of horizontal room.
            if total_depth > 0 and not only_child:
                depth_factor = 0.5 + (0.5 * chain_depths[kid] / total_depth)
            else:
                depth_factor = 1.0
            _layout(graph, kid,
                    width=dx * depth_factor * 1.5,
                    height=height,
                    xcenter=next_x,
                    pos=pos,
                    parent=node,
                    depth=depth + 1,
                    inverted=inverted,
                    vert_gap=vert_gap)
            next_x += dx
        return pos

    graph = tree._graph
    pos = _layout(graph, tree.root, height=figsize[1] / 1.5, inverted=invert)
    fig = go.Figure()

    # Raw RustworkX graphs expose (u, v) pairs through edge_list(); wrapped
    # or NetworkX graphs through edges().
    if hasattr(graph, 'edge_list') and 'rustworkx' in str(type(graph)):
        edge_pairs = graph.edge_list()
    else:
        edge_pairs = graph.edges()

    for src, dst in edge_pairs:
        if src not in pos or dst not in pos:
            continue
        (x1, y1), (x2, y2) = pos[src], pos[dst]
        fig.add_trace(
            go.Scatter(
                x=[x1, x2], y=[y1, y2],
                mode='lines+markers',
                marker=dict(size=0.1, opacity=0),
                line=dict(color='#808080', width=2),
                showlegend=False,
                hoverinfo='none'
            )
        )

    # Same distinction for node identifiers.
    if hasattr(graph, 'node_indices') and 'rustworkx' in str(type(graph)):
        node_ids = graph.node_indices()
    else:
        node_ids = graph.nodes()

    node_x, node_y = [], []
    node_text, hover_data, node_symbols = [], [], []
    for node in node_ids:
        if node not in pos:
            continue
        px, py = pos[node]
        node_x.append(px)
        node_y.append(py)

        active_items = tree[node].active_items()
        node_text.append(str(active_items["defName"]) if "defName" in active_items else "")

        if attributes is None:
            label = graph[node].get('label')
            hover_text = str(label) if label is not None else str(node)
        else:
            parts = []
            for attr in attributes:
                if attr in {"node_id", "node", "id"}:
                    parts.append(str(node))
                elif attr in active_items:
                    # A None value renders as the text "None".
                    parts.append(f"{attr}: {active_items[attr]}")
            hover_text = "<br>".join(parts)
        hover_data.append(hover_text)

        has_neighbors = len(list(graph.neighbors(node))) != 0
        node_symbols.append('square' if has_neighbors else 'circle')

    fig.add_trace(
        go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            marker=dict(
                size=30,
                color='white',
                line=dict(color='white', width=2),
                symbol=node_symbols
            ),
            text=node_text,
            textposition='middle center',
            textfont=dict(color='#404040', size=10, family='Arial', weight='bold'),
            hovertemplate='%{customdata}<extra></extra>',
            customdata=hover_data,
            showlegend=False
        )
    )

    # Axis ranges hug the node cloud with a small proportional padding.
    if node_x:
        x_lo, x_hi = min(node_x), max(node_x)
        x_padding = (x_hi - x_lo) * 0.02
        x_range = [x_lo - x_padding, x_hi + x_padding]
    else:
        x_range = [-1, 1]
    if node_y:
        y_lo, y_hi = min(node_y), max(node_y)
        y_padding = (y_hi - y_lo) * 0.1
        y_range = [y_lo - y_padding, y_hi + y_padding]
    else:
        y_range = [-1, 1]

    fig.update_layout(
        width=int(figsize[0] * 100),
        height=int(figsize[1] * 100),
        paper_bgcolor='black',
        plot_bgcolor='black',
        xaxis=dict(
            showgrid=False, zeroline=False, showticklabels=False,
            range=x_range
        ),
        yaxis=dict(
            showgrid=False, zeroline=False, showticklabels=False,
            scaleanchor="x", scaleratio=1,
            range=y_range
        ),
        hovermode='closest',
        margin=dict(l=0, r=0, t=0, b=0),
    )

    if output_file:
        if output_file.endswith('.html'):
            fig.write_html(output_file)
        else:
            fig.write_image(output_file)
    return fig
def _plot_tree(tree: Tree, attributes: list[str] | None = None, figsize: tuple[float, float] = (20, 5),
               invert: bool = True, output_file: str | None = None) -> None:
    """
    Visualize a tree structure with matplotlib.

    Nodes are positioned hierarchically.  Internal nodes are drawn as
    squares and leaf nodes as circles, with white borders on a black
    background.

    Parameters
    ----------
    tree : Tree
        Tree instance to visualize.
    attributes : list of str or None, optional
        Node attributes to display instead of labels. Special values
        ``"node_id"``, ``"node"``, or ``"id"`` display the node
        identifier.
    figsize : tuple of float, optional
        Width and height of the output figure in inches.
    invert : bool, optional
        When ``True``, places the root at the top of the diagram.
    output_file : str or None, optional
        Path to save the visualization.  When ``None`` the figure is shown
        with ``plt.show()`` instead.

    Returns
    -------
    None
    """
    def _hierarchy_pos(G, root, width=1.5, vert_gap=0.2, xcenter=0.5, pos=None, parent=None, depth=0, inverted=True):
        """
        Position nodes in a hierarchical layout for wide and deep trees.

        On the first call (``pos is None``) the vertical gap and total
        width are derived from the tree's maximum depth and breadth; each
        recursive call then splits its width evenly among its children.

        Returns
        -------
        dict
            Mapping of each node to its ``(x, y)`` position.
        """
        if pos is None:
            max_depth = _get_max_depth(G, root)
            vert_gap = min(0.2, 0.8 / max(max_depth, 1))
            max_breadth = _get_max_breadth(G, root)
            width = max(1.5, 0.8 * max_breadth)
            pos = {root: (xcenter, 1 if inverted else 0)}
        else:
            y = (1 - (depth * vert_gap)) if inverted else (depth * vert_gap)
            pos[root] = (xcenter, y)

        children = _get_children(G, root, parent)
        if children:
            if len(children) == 1:
                # A single child stays directly below its parent.
                _hierarchy_pos(G, children[0],
                               width=width * 0.8,
                               vert_gap=vert_gap,
                               xcenter=xcenter,
                               pos=pos,
                               parent=root,
                               depth=depth+1,
                               inverted=inverted)
            else:
                dx = width / len(children)
                start_x = xcenter - width/2 + dx/2
                for i, child in enumerate(children):
                    _hierarchy_pos(G, child,
                                   width=dx * 0.9,
                                   vert_gap=vert_gap,
                                   xcenter=start_x + i * dx,
                                   pos=pos,
                                   parent=root,
                                   depth=depth+1,
                                   inverted=inverted)
        return pos

    def _get_max_depth(G, node, parent=None, current_depth=0):
        """Return the depth of the deepest descendant of *node*."""
        children = _get_children(G, node, parent)
        if not children:
            return current_depth
        return max(_get_max_depth(G, child, node, current_depth + 1) for child in children)

    def _get_max_breadth(G, root, parent=None):
        """Return the maximum number of nodes at any single level of the tree."""
        nodes_by_level = {}

        def _count_by_level(node, level=0, parent=None):
            nodes_by_level[level] = nodes_by_level.get(level, 0) + 1
            for child in _get_children(G, node, parent):
                _count_by_level(child, level+1, node)

        _count_by_level(root, parent=parent)
        return max(nodes_by_level.values()) if nodes_by_level else 1

    original_G = tree._graph
    root = tree.root

    # Layout is computed on the wrapped tree, which _get_children walks via
    # its ``successors`` method.
    pos = _hierarchy_pos(tree, root, inverted=invert)

    # Convert to NetworkX for matplotlib drawing.
    # NOTE: the module-level ``nx`` is used on purpose.  A function-local
    # ``import networkx as nx`` inside the branch below would make ``nx`` a
    # local name for the whole function and raise UnboundLocalError at the
    # final draw call whenever the branch is skipped.
    is_rustworkx = hasattr(original_G, 'node_indices') or 'rustworkx' in str(type(original_G))
    if is_rustworkx:
        G = nx.DiGraph() if hasattr(original_G, 'in_degree') else nx.Graph()

        # Add nodes with their data; fall back to a bare node when the data
        # cannot be expanded as keyword arguments.
        for node_idx in original_G.node_indices():
            try:
                node_data = tree[node_idx] if hasattr(tree, '__getitem__') else {}
                G.add_node(node_idx, **node_data)
            except Exception:
                G.add_node(node_idx)

        # Edges are best-effort: a failure leaves the nodes drawn without
        # connecting lines rather than aborting the plot.
        try:
            for src, dst in original_G.edge_list():
                G.add_edge(src, dst)
        except Exception:
            pass
    else:
        G = original_G

    plt.figure(figsize=figsize)
    ax = plt.gca()
    ax.set_facecolor('black')
    plt.gcf().set_facecolor('black')

    for node, (x, y) in pos.items():
        if attributes is None:
            value_attr = getattr(tree, '_node_value_attr', 'label')
            value = tree[node].get(value_attr)
            label_text = str(value) if value is not None else str(node)
        else:
            label_parts = []
            for attr in attributes:
                if attr in {"node_id", "node", "id"}:
                    label_parts.append(str(node))
                elif attr in tree[node]:
                    value = tree[node][attr]
                    label_parts.append(str(value) if value is not None else '')
            label_text = "\n".join(label_parts)

        is_leaf = len(list(G.neighbors(node))) == 0
        box_style = "circle,pad=0.3" if is_leaf else "square,pad=0.3"
        ax.text(x, y, label_text, ha='center', va='center', zorder=5, fontsize=16,
                bbox=dict(boxstyle=box_style, fc="black", ec="white", linewidth=2),
                color='white')

    nx.draw_networkx_edges(G, pos, arrows=False, width=2.0, edge_color='white')
    plt.axis('off')
    plt.margins(x=0)
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

    if output_file:
        plt.savefig(output_file, bbox_inches='tight', pad_inches=0)
        plt.close()
    else:
        plt.show()
def _get_children(G, node, parent=None):
    """List the children of *node* in a tree-like graph.

    The lookup strategy is chosen by duck typing, in this order:

    1. ``G.successors(node)`` when *G* has that method and its type name
       does not contain ``"rustworkx"`` (wrapped Graph/Tree classes).
    2. ``G.successor_indices(node)`` (raw RustworkX graphs, where indices
       rather than node data are wanted).
    3. ``G.neighbors(node)`` with the first occurrence of *parent*
       removed (NetworkX graphs, directed or undirected).

    Parameters
    ----------
    G : graph
        A RustworkX, NetworkX, or Klotho graph object.
    node : hashable
        The node whose children are requested.
    parent : hashable or None, optional
        Only used by the ``neighbors`` strategy, where it is dropped from
        the result to avoid walking back up an undirected edge.

    Returns
    -------
    list
        Child node identifiers; empty when *G* supports none of the
        strategies above.
    """
    if hasattr(G, 'successors') and 'rustworkx' not in str(type(G)):
        return list(G.successors(node))

    if hasattr(G, 'successor_indices'):
        return list(G.successor_indices(node))

    if not hasattr(G, 'neighbors'):
        return []

    adjacent = list(G.neighbors(node))
    if parent is not None:
        try:
            adjacent.remove(parent)
        except ValueError:
            # parent was not among the neighbours; nothing to drop
            pass
    return adjacent
def _is_leaf(G, node):
    """Tell whether *node* is childless in *G* (per ``_get_children``)."""
    # _get_children always returns a list, so emptiness is the leaf test.
    return not _get_children(G, node)
def _get_graph_layout(G, layout='spring', k=1, dim=2):
    """
    Compute node positions using RustworkX or NetworkX layout algorithms.

    Strategy, in order:

    1. If *G* wraps or is a RustworkX graph, try the RustworkX
       ``spring``/``circular``/``random`` layouts.  Any exception there is
       swallowed and the next step is tried.
    2. Convert RustworkX graphs to NetworkX (keeping node indices) and use
       the NetworkX ``<layout>_layout`` function.
    3. If that fails too, fall back to progressively simpler layouts, down
       to nodes placed on a line at ``(i, 0)`` / ``(i, 0, 0)``.

    Parameters
    ----------
    G : graph
        A RustworkX or NetworkX graph, or an object exposing a RustworkX
        graph as ``_graph``.
    layout : str, optional
        Layout algorithm name (e.g. ``'spring'``, ``'circular'``).
    k : float, optional
        Optimal distance between nodes for force-directed layouts.
    dim : int, optional
        Number of spatial dimensions (2 or 3).  2-D-only layouts are
        padded with ``z = 0`` when ``dim == 3``.

    Returns
    -------
    dict
        Mapping of node identifiers to position tuples.
    """
    import rustworkx as rx

    # NOTE: the module-level ``nx`` is used on purpose.  A function-local
    # ``import networkx as nx`` inside the conversion branch makes ``nx`` a
    # local name for the whole function; for NetworkX input every ``nx.*``
    # call then raised UnboundLocalError, which the broad handlers below
    # swallowed, silently degrading the layout to a straight line.
    is_rustworkx = hasattr(G, 'node_indices') or 'rustworkx' in str(type(G))

    def _pad_to_dim(pos_2d):
        # Lift a 2-D layout into 3-D by appending z = 0 when requested.
        if dim == 3:
            return {node: (*coords, 0) for node, coords in pos_2d.items()}
        return pos_2d

    # --- 1. RustworkX layouts (preferred for performance) -----------------
    try:
        if hasattr(G, '_graph') and hasattr(G._graph, 'node_indices'):
            rx_graph = G._graph      # Klotho Graph wrapping RustworkX
        elif is_rustworkx:
            rx_graph = G             # raw RustworkX graph
        else:
            rx_graph = None

        if rx_graph is not None:
            if layout == 'spring':
                return rx.spring_layout(rx_graph, k=k, dim=dim)
            elif layout == 'circular':
                return _pad_to_dim(rx.circular_layout(rx_graph))
            elif layout == 'random':
                return rx.random_layout(rx_graph, dim=dim)
            # other layout names fall through to NetworkX
    except Exception:
        # Deliberate best-effort: unsupported arguments or layouts fall
        # through to the NetworkX implementation below.
        pass

    # --- 2. NetworkX layouts ------------------------------------------------
    try:
        nx_graph = G
        if is_rustworkx:
            nx_graph = nx.DiGraph() if hasattr(G, 'in_degree') else nx.Graph()

            # Preserve original node indices (and data where possible) so
            # the returned positions are keyed like the input graph.
            for node_idx in G.node_indices():
                try:
                    node_data = G[node_idx] if hasattr(G, '__getitem__') else {}
                    nx_graph.add_node(node_idx, **node_data)
                except Exception:
                    nx_graph.add_node(node_idx)

            try:
                for src, dst in G.edge_list():
                    try:
                        edge_data = G.get_edge_data(src, dst) if hasattr(G, 'get_edge_data') else {}
                        if edge_data is None:
                            edge_data = {}
                        nx_graph.add_edge(src, dst, **edge_data)
                    except Exception:
                        nx_graph.add_edge(src, dst)
            except Exception:
                # If edge_list fails, probe every node pair instead.
                try:
                    edges = [(i, j) for i in G.node_indices() for j in G.node_indices()
                             if G.has_edge(i, j)]
                    nx_graph.add_edges_from(edges)
                except Exception:
                    pass

        if layout == 'spring':
            pos = nx.spring_layout(nx_graph, k=k, dim=dim)
        elif layout == 'random':
            pos = nx.random_layout(nx_graph, dim=dim)
        elif layout == 'kamada_kawai':
            pos = nx.kamada_kawai_layout(nx_graph, dim=dim)
        elif layout == 'spectral':
            if dim == 3:
                try:
                    pos = nx.spectral_layout(nx_graph, dim=3)
                except Exception:
                    pos_2d = nx.spectral_layout(nx_graph, dim=2)
                    pos = {node: (*coords, 0) for node, coords in pos_2d.items()}
            else:
                pos = nx.spectral_layout(nx_graph, dim=dim)
        elif layout == 'circular':
            pos = _pad_to_dim(nx.circular_layout(nx_graph))
        else:
            # Any other name is resolved dynamically, e.g. 'shell' ->
            # nx.shell_layout; an unknown name lands in the fallback below.
            layout_func = getattr(nx, f'{layout}_layout')
            pos = _pad_to_dim(layout_func(nx_graph))

        # No need to map back since original node indices were preserved.
        return pos

    except Exception:
        # --- 3. Fallbacks ---------------------------------------------------
        try:
            if is_rustworkx:
                node_indices = list(G.node_indices())
                return {node: (i, 0, 0) if dim == 3 else (i, 0) for i, node in enumerate(node_indices)}
            else:
                return nx.spring_layout(G, k=k, dim=dim)
        except Exception:
            try:
                if is_rustworkx:
                    nodes = list(G.node_indices())
                else:
                    nodes = list(G.nodes()) if hasattr(G, 'nodes') else list(range(len(G)))
                return {node: (i, 0, 0) if dim == 3 else (i, 0) for i, node in enumerate(nodes)}
            except Exception:
                # Ultimate fallback - simple positions for any graph.
                if hasattr(G, '__len__'):
                    return {i: (i, 0, 0) if dim == 3 else (i, 0) for i in range(len(G))}
                else:
                    return {0: (0, 0, 0) if dim == 3 else (0, 0)}
def _plot_graph(G, figsize: tuple[float, float] = (10, 10),
                node_size: float = 1000, font_size: float = 12,
                layout: str = 'spring', k: float = 1,
                show_edge_labels: bool = True,
                edge_width: bool = False,
                edge_color: bool = False,
                width_range: tuple[float, float] = (0.75, 3),
                cmap: str = 'viridis',
                invert_weights: bool = False,
                path: list | None = None,
                attributes: list[str] | None = None,
                dim: int = 2,
                output_file: str | None = None):
    """
    Render a general graph with customizable layout and styling.

    ``dim=2`` draws with matplotlib/NetworkX and returns ``None``;
    ``dim=3`` builds and returns a Plotly figure.

    Parameters
    ----------
    G : graph
        RustworkX or NetworkX graph to visualize.  RustworkX graphs are
        converted to NetworkX for drawing.
    figsize : tuple of float, optional
        Width and height of the figure in inches (2-D), or in units of
        100 pixels (3-D).
    node_size : float, optional
        Size of the drawn nodes (divided by 50 for the 3-D markers).
    font_size : float, optional
        Font size for node labels.
    layout : str, optional
        Layout algorithm (``'spring'``, ``'circular'``, etc.).
    k : float, optional
        Optimal node distance for spring layout.
    show_edge_labels : bool, optional
        Whether to display edge weight labels.
    edge_width : bool, optional
        Scale edge widths by weight when ``True``.
    edge_color : bool, optional
        Color edges by weight when ``True``.
    width_range : tuple of float, optional
        Min and max edge widths when *edge_width* is enabled.
    cmap : str, optional
        Matplotlib colormap for edge coloring.
    invert_weights : bool, optional
        Invert the weight scale for width/color mapping.
    path : list or None, optional
        Sequence of node identifiers to highlight as a path.
    attributes : list of str or None, optional
        Node attributes to display as labels.
    dim : int, optional
        Dimensionality of the layout (2 or 3).
    output_file : str or None, optional
        Path to save the figure.  In 2-D the plot is shown with
        ``plt.show()`` when ``None``; in 3-D the figure is only returned.

    Returns
    -------
    plotly.graph_objects.Figure or None
        The Plotly figure when ``dim == 3``, otherwise ``None``.

    Raises
    ------
    ValueError
        If *dim* is neither 2 nor 3.
    """
    # NOTE: the module-level ``nx`` is used on purpose.  A function-local
    # ``import networkx as nx`` inside the conversion branch makes ``nx`` a
    # local name for the whole function, so plain NetworkX input crashed
    # with UnboundLocalError at the ``isinstance(G, nx.DiGraph)`` check.
    original_G = G
    is_rustworkx = hasattr(G, 'node_indices') or 'rustworkx' in str(type(G))

    if is_rustworkx:
        # Convert RustworkX to NetworkX for the drawing functions.
        G = nx.DiGraph() if hasattr(original_G, 'in_degree') else nx.Graph()

        for node_idx in original_G.node_indices():
            try:
                node_data = original_G[node_idx] if hasattr(original_G, '__getitem__') else {}
                G.add_node(node_idx, **node_data)
            except Exception:
                # Node data not usable as attributes: keep the bare node.
                G.add_node(node_idx)

        try:
            for src, dst in original_G.edge_list():
                try:
                    edge_data = original_G.get_edge_data(src, dst) if hasattr(original_G, 'get_edge_data') else {}
                    if edge_data is None:
                        edge_data = {}
                    G.add_edge(src, dst, **edge_data)
                except Exception:
                    G.add_edge(src, dst)
        except Exception:
            # If edge_list fails, probe every node pair instead.
            try:
                edges = [(i, j) for i in original_G.node_indices() for j in original_G.node_indices()
                         if original_G.has_edge(i, j)]
                G.add_edges_from(edges)
            except Exception:
                pass

    if dim not in [2, 3]:
        raise ValueError(f"dim must be 2 or 3, got {dim}")

    # Use the original RustworkX graph for layout if possible.
    layout_graph = original_G if is_rustworkx else G
    pos = _get_graph_layout(layout_graph, layout=layout, k=k, dim=dim)
    is_directed = isinstance(G, nx.DiGraph)

    weights = []
    min_weight = max_weight = None
    if edge_width or edge_color:
        weight_dict = nx.get_edge_attributes(G, 'weight')
        if weight_dict:
            weights = list(weight_dict.values())
            min_weight, max_weight = min(weights), max(weights)

    def to_hex(rgba):
        # Matplotlib RGBA floats (0..1) -> '#rrggbb' for Plotly.
        return '#%02x%02x%02x' % (int(rgba[0]*255), int(rgba[1]*255), int(rgba[2]*255))

    def normalized_weight(u, v):
        # Map the edge weight into [0, 1]; all-equal weights map to 0.
        w = G[u][v].get('weight', min_weight)
        if max_weight > min_weight:
            norm_w = (w - min_weight) / (max_weight - min_weight)
            return 1 - norm_w if invert_weights else norm_w
        return 0

    def get_edge_props(edge_list, for_plotly=False):
        # Per-edge widths and colors; defaults are width 2 and mid-grey.
        widths = []
        colors = []
        colormap = None
        for u, v in edge_list:
            if edge_width and weights:
                widths.append(width_range[0] + normalized_weight(u, v) * (width_range[1] - width_range[0]))
            else:
                widths.append(2)

            if edge_color and weights:
                if colormap is None:
                    # plt.get_cmap replaces plt.cm.get_cmap, which was
                    # deprecated in Matplotlib 3.7 and removed in 3.9.
                    colormap = plt.get_cmap(cmap)
                color = colormap(normalized_weight(u, v))
                colors.append(to_hex(color) if for_plotly else color)
            else:
                colors.append('#808080')
        return widths, colors

    def get_edge_rad(u, v):
        # Curvature for directed edges: self-loops and reciprocal pairs
        # are bent more so that both arrows stay visible.
        if not is_directed:
            return 0.0
        if u == v:
            return 0.24
        if G.has_edge(v, u):
            pair_key = tuple(sorted((repr(u), repr(v))))
            pair_sign = 1 if (sum(ord(ch) for ch in ''.join(pair_key)) % 2 == 0) else -1
            return 0.28 * pair_sign
        return 0.06

    def get_node_labels():
        # Label text per node; falls back to the node id when none of the
        # requested attributes yields text.
        labels = {}
        for node in G.nodes():
            node_attrs = G.nodes[node] if hasattr(G, 'nodes') else {}
            if attributes is None:
                label_text = str(node)
            else:
                label_parts = []
                for attr in attributes:
                    if attr in {"node_id", "node", "id"}:
                        label_parts.append(str(node))
                    elif attr in node_attrs:
                        value = node_attrs[attr]
                        label_parts.append(str(value) if value is not None else '')
                if not label_parts or all(part == '' for part in label_parts):
                    label_parts = [str(node)]
                # Plotly wants HTML line breaks, matplotlib plain newlines.
                label_text = ("<br>" if dim == 3 else "\n").join(label_parts)
            labels[node] = label_text
        return labels

    def split_path_edges():
        # Consecutive path pairs, plus every graph edge not on the path
        # (in either direction).
        on_path = list(zip(path[:-1], path[1:]))
        on_path_set = set(on_path)
        off_path = [(u, v) for u, v in G.edges()
                    if (u, v) not in on_path_set and (v, u) not in on_path_set]
        return on_path, off_path

    if dim == 3:
        fig = go.Figure()

        def add_edge_trace(u, v, color, width, **extra):
            # One line trace per edge; the invisible markers keep Plotly
            # from dropping very short segments.
            x1, y1, z1 = pos[u]
            x2, y2, z2 = pos[v]
            fig.add_trace(
                go.Scatter3d(
                    x=[x1, x2], y=[y1, y2], z=[z1, z2],
                    mode='lines+markers',
                    marker=dict(size=0.1, opacity=0),
                    line=dict(color=color, width=width),
                    showlegend=False,
                    hoverinfo='none',
                    **extra
                )
            )

        if path:
            path_edges, non_path_edges = split_path_edges()

            if non_path_edges:
                widths, colors = get_edge_props(non_path_edges, for_plotly=True)
                for i, (u, v) in enumerate(non_path_edges):
                    add_edge_trace(u, v, colors[i], widths[i], opacity=0.5)

            if path_edges:
                path_colors = plt.cm.viridis(np.linspace(0, 1, len(path_edges)))
                for (u, v), color in zip(path_edges, path_colors):
                    add_edge_trace(u, v, to_hex(color), 6)
        else:
            edge_list = list(G.edges())
            widths, colors = get_edge_props(edge_list, for_plotly=True)
            for i, (u, v) in enumerate(edge_list):
                add_edge_trace(u, v, colors[i], widths[i])

        node_x, node_y, node_z = [], [], []
        node_text, hover_data = [], []
        node_colors = []

        labels = get_node_labels()
        for node in G.nodes():
            if node not in pos:
                continue
            x, y, z = pos[node]
            node_x.append(x)
            node_y.append(y)
            node_z.append(z)
            node_text.append(labels[node])
            hover_data.append(labels[node])

            if path and node in path:
                # Color by position of the node's first occurrence in path.
                path_index = path.index(node)
                node_colors.append(to_hex(plt.cm.viridis(path_index / len(path))))
            else:
                node_colors.append('white')

        fig.add_trace(
            go.Scatter3d(
                x=node_x, y=node_y, z=node_z,
                mode='markers+text',
                marker=dict(
                    size=node_size/50,
                    color=node_colors,
                    line=dict(color='white', width=2)
                ),
                text=node_text,
                textposition='middle center',
                textfont=dict(color='black', size=font_size, family='Arial', weight='bold'),
                hovertemplate='%{customdata}<extra></extra>',
                customdata=hover_data,
                showlegend=False
            )
        )

        if show_edge_labels:
            edge_weights = nx.get_edge_attributes(G, 'weight')
            for (u, v), weight in edge_weights.items():
                if u in pos and v in pos:
                    x1, y1, z1 = pos[u]
                    x2, y2, z2 = pos[v]
                    mid_x, mid_y, mid_z = (x1 + x2) / 2, (y1 + y2) / 2, (z1 + z2) / 2
                    fig.add_trace(
                        go.Scatter3d(
                            x=[mid_x], y=[mid_y], z=[mid_z],
                            mode='text',
                            text=[f'{weight:.2f}'],
                            textfont=dict(color='white', size=font_size-2),
                            showlegend=False,
                            hoverinfo='none'
                        )
                    )

        width_px, height_px = int(figsize[0] * 100), int(figsize[1] * 100)
        hidden_axis = dict(
            showgrid=False, zeroline=False, showticklabels=False,
            showline=False, showbackground=False,
            title=dict(text='', font=dict(color='white'))
        )
        fig.update_layout(
            width=width_px,
            height=height_px,
            paper_bgcolor='black',
            plot_bgcolor='black',
            scene=dict(
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5),
                    center=dict(x=0, y=0, z=0)
                ),
                xaxis=dict(hidden_axis),
                yaxis=dict(hidden_axis),
                zaxis=dict(hidden_axis),
                bgcolor='black'
            ),
            hovermode='closest',
            margin=dict(l=0, r=0, t=0, b=0)
        )

        if output_file:
            if output_file.endswith('.html'):
                fig.write_html(output_file)
            else:
                fig.write_image(output_file)
        return fig

    else:
        plt.figure(figsize=figsize)
        ax = plt.gca()
        ax.set_facecolor('black')
        plt.gcf().set_facecolor('black')

        def draw_directed_edge(u, v, color, width, **extra):
            # Directed edges are drawn one at a time so each can have its
            # own curvature.
            nx.draw_networkx_edges(
                G,
                pos,
                edgelist=[(u, v)],
                edge_color=[color],
                width=width,
                arrows=True,
                arrowstyle='-|>',
                arrowsize=24,
                connectionstyle=f"arc3,rad={get_edge_rad(u, v)}",
                **extra
            )

        if path:
            path_edges, non_path_edges = split_path_edges()

            if non_path_edges:
                widths, colors = get_edge_props(non_path_edges)
                if is_directed:
                    for i, (u, v) in enumerate(non_path_edges):
                        draw_directed_edge(u, v, colors[i], widths[i], alpha=0.5)
                else:
                    nx.draw_networkx_edges(G, pos, edgelist=non_path_edges, edge_color=colors,
                                           width=widths, alpha=0.5)

            if path_edges:
                path_colors = plt.cm.viridis(np.linspace(0, 1, len(path_edges)))
                for (u, v), color in zip(path_edges, path_colors):
                    if is_directed:
                        draw_directed_edge(u, v, color, 3)
                    else:
                        nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], edge_color=[color],
                                               width=3)
        else:
            edge_list = list(G.edges())
            widths, colors = get_edge_props(edge_list)
            if is_directed:
                for i, (u, v) in enumerate(edge_list):
                    draw_directed_edge(u, v, colors[i], widths[i])
            else:
                nx.draw_networkx_edges(G, pos, edge_color=colors, width=widths)

        if path:
            non_path_nodes = [node for node in G.nodes() if node not in path]
            nx.draw_networkx_nodes(G, pos, nodelist=non_path_nodes, node_color='black',
                                   node_size=node_size, edgecolors='white', linewidths=2)

            path_node_colors = plt.cm.viridis(np.linspace(0, 1, len(path)))
            nx.draw_networkx_nodes(G, pos, nodelist=path, node_color=path_node_colors,
                                   node_size=node_size, edgecolors='white', linewidths=2)
        else:
            nx.draw_networkx_nodes(G, pos, node_color='black', node_size=node_size,
                                   edgecolors='white', linewidths=2)

        labels = get_node_labels()
        nx.draw_networkx_labels(
            G,
            pos,
            labels=labels,
            font_color='white',
            font_size=font_size,
            bbox=dict(facecolor='black', edgecolor='none', alpha=0.4),
        )

        if show_edge_labels:
            edge_weights = {(u, v): f'{w:.2f}' for (u, v), w in nx.get_edge_attributes(G, 'weight').items()}
            if is_directed:
                # Place each label near the apex of its curved edge.
                for (u, v), weight in edge_weights.items():
                    x1, y1 = pos[u]
                    x2, y2 = pos[v]
                    dx = x2 - x1
                    dy = y2 - y1
                    length = (dx**2 + dy**2)**0.5
                    mid_x = (x1 + x2) / 2
                    mid_y = (y1 + y2) / 2

                    if length > 0:
                        rad = get_edge_rad(u, v)
                        perp_x = -dy / length
                        perp_y = dx / length
                        offset = rad * length * 0.6
                        curve_mid_x = mid_x + perp_x * offset
                        curve_mid_y = mid_y + perp_y * offset
                    else:
                        curve_mid_x, curve_mid_y = mid_x, mid_y

                    ax.text(curve_mid_x, curve_mid_y, weight, ha='center', va='center',
                            color='white', fontsize=font_size,
                            bbox=dict(facecolor='black', edgecolor='none', alpha=0.6))
            else:
                nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_weights,
                                             font_color='white', font_size=font_size,
                                             bbox=dict(facecolor='black', edgecolor='none', alpha=0.6),
                                             label_pos=0.5, rotate=False)

        plt.axis('off')
        plt.margins(x=0.1, y=0.1)

        if output_file:
            plt.savefig(output_file, bbox_inches='tight', pad_inches=0,
                        facecolor='black', edgecolor='none')
            plt.close()
        else:
            plt.show()
def _plot_curve(*args, figsize=(16, 8), x_range=(0, 1), colors=None, labels=None,
title=None, grid=True, legend=True, output_file=None):
"""
Plot one or more curves with a dark-background style.
Parameters
----------
*args : array-like
One or more sequences of y-values to plot.
figsize : tuple of float, optional
Width and height of the figure in inches.
x_range : tuple of float, optional
``(min, max)`` range for the x-axis.
colors : list or None, optional
Colors for each curve. Defaults to the viridis colormap.
labels : list of str or None, optional
Labels for the legend entries.
title : str or None, optional
Title displayed above the plot.
grid : bool, optional
Whether to render grid lines.
legend : bool, optional
Whether to display the legend.
output_file : str or None, optional
Path to save the figure. Displays the plot when ``None``.
"""
plt.figure(figsize=figsize)
ax = plt.gca()
ax.set_facecolor('black')
plt.gcf().set_facecolor('black')
curves = args
if not curves:
raise ValueError("At least one curve must be provided")
if colors is None and len(curves) > 1:
colors = plt.cm.viridis(np.linspace(0, 0.8, len(curves)))
elif colors is None:
colors = ['#e6e6e6'] # Default white
if labels is None:
labels = [f"Curve {i+1}" for i in range(len(curves))]
for i, curve in enumerate(curves):
if i < len(colors):
color = colors[i]
else:
color = plt.cm.viridis(i / len(curves))
label = labels[i] if i < len(labels) else f"Curve {i+1}"
x = np.linspace(x_range[0], x_range[1], len(curve))
ax.plot(x, curve, color=color, linewidth=2.5, label=label)
if title:
ax.set_title(title, color='white', fontsize=14)
ax.spines['bottom'].set_color('white')
ax.spines['top'].set_color('white')
ax.spines['right'].set_color('white')
ax.spines['left'].set_color('white')
ax.tick_params(axis='x', colors='white')
ax.tick_params(axis='y', colors='white')
if grid:
ax.grid(color='#555555', linestyle='-', linewidth=0.5, alpha=0.5)
if legend and len(curves) > 1:
ax.legend(frameon=True, facecolor='black', edgecolor='#555555', labelcolor='white')
plt.tight_layout()
if output_file:
plt.savefig(output_file, bbox_inches='tight', facecolor='black')
plt.close()
else:
plt.show()
def _plot_cs(cs: CombinationSet, figsize: tuple[float, float] = (12, 12),
             node_size: float = 1000, font_size: float = 12,
             show_edge_labels: bool = False, edge_alpha: float = 0.3,
             title: str = None, output_file: str = None) -> None:
    """
    Draw the graph of a CombinationSet on a circular layout.

    The graph stored at ``cs.graph._graph`` is copied node by node and edge
    by edge into a NetworkX graph, which is then drawn with white-outlined
    black nodes on a black background.

    Parameters
    ----------
    cs : CombinationSet
        CombinationSet instance to visualize.
    figsize : tuple of float, optional
        Width and height of the figure in inches.
    node_size : float, optional
        Size of the drawn nodes.
    font_size : float, optional
        Font size for node labels (edge labels use ``font_size - 2``).
    show_edge_labels : bool, optional
        Whether to display ``u-v`` index labels on edges. They are drawn
        only when the graph has fewer than 50 edges.
    edge_alpha : float, optional
        Edge transparency in the range ``[0, 1]``.
    title : str or None, optional
        Plot title. Built from the rank and factor aliases when ``None``.
    output_file : str or None, optional
        Path to save the figure. Displays the plot when ``None``.
    """
    canvas = plt.figure(figsize=figsize)
    ax = plt.gca()
    ax.set_facecolor('black')
    canvas.set_facecolor('black')

    # Mirror the backing graph (described upstream as a RustworkX graph)
    # into NetworkX, which supplies the layout and drawing routines.
    backing = cs.graph._graph
    G = nx.Graph()
    for idx in backing.node_indices():
        payload = backing.get_node_data(idx)
        # Only dict payloads carry attributes; anything else adds a bare node.
        G.add_node(idx, **(payload if isinstance(payload, dict) else {}))
    for source, target in backing.edge_list():
        payload = backing.get_edge_data(source, target)
        G.add_edge(source, target, **(payload if isinstance(payload, dict) else {}))

    pos = nx.circular_layout(G)

    # Faint edges keep a densely connected graph readable.
    nx.draw_networkx_edges(G, pos, edge_color='#808080', width=1, alpha=edge_alpha)
    nx.draw_networkx_nodes(G, pos, node_color='black', node_size=node_size,
                           edgecolors='white', linewidths=2)

    # A node label is the concatenated aliases of its combo's factors, with
    # surrounding parentheses stripped from each alias.
    labels = {
        node: ''.join(str(cs.factor_to_alias[f]).strip('()') for f in attrs['combo'])
        for node, attrs in G.nodes(data=True)
        if 'combo' in attrs
    }
    nx.draw_networkx_labels(G, pos, labels=labels, font_color='white', font_size=font_size)

    # Edge labels are skipped on larger graphs, where they would overlap.
    if show_edge_labels and G.number_of_edges() < 50:
        edge_labels = {(u, v): f'{u}-{v}' for u, v in G.edges()}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels,
                                     font_color='white', font_size=font_size-2,
                                     bbox=dict(facecolor='black', edgecolor='none', alpha=0.6))

    if title is None:
        factor_string = ' '.join(str(cs.factor_to_alias[f]) for f in cs.factors)
        title = f"CombinationSet r={cs.rank} [{factor_string}]"
    ax.set_title(title, color='white', fontsize=14)

    plt.axis('off')
    plt.margins(x=0.1, y=0.1)

    if not output_file:
        plt.show()
        return
    plt.savefig(output_file, bbox_inches='tight', pad_inches=0,
                facecolor='black', edgecolor='none')
    plt.close()
def _plot_scale_chord(obj, figsize: tuple = (12, 12),
                      node_size: int = 30, text_size: int = 12, show_labels: bool = True,
                      title: str = None, output_file: str = None, nodes: list = None,
                      layout: str = 'circle') -> go.Figure:
    """
    Dispatch a Scale, Chord or Voicing to the renderer for its interval type.

    Objects whose ``_interval_type_mode`` is ``"cents"`` go to
    ``_plot_cents_scale_chord`` (circular node diagram); every other mode goes
    to ``_plot_ratio_scale_chord_clean``. Both receive a fresh Plotly figure.

    Parameters
    ----------
    obj : Scale, Chord, or Voicing
        The pitch structure to visualize. Its ``_degrees`` are used for the
        geometry when ``obj.is_relative`` is true, otherwise its ``_pitches``.
    figsize : tuple of float, optional
        Width and height of the figure, forwarded to the renderer.
    node_size : int, optional
        Node size (forwarded to the cents renderer only).
    text_size : int, optional
        Font size for text labels.
    show_labels : bool, optional
        Whether to display labels on segments / nodes.
    title : str or None, optional
        Plot title, forwarded to the renderer.
    output_file : str or None, optional
        Output path, forwarded to the renderer.
    nodes : list or None, optional
        Node IDs to highlight (forwarded to the cents renderer only).
    layout : str, optional
        Layout name (default ``'circle'``), forwarded to the ratio renderer
        only.

    Returns
    -------
    plotly.graph_objects.Figure
        Whatever the selected renderer returns.

    Raises
    ------
    ValueError
        If the object has no degrees / pitches to plot.
    """
    # Relative objects are laid out from their degrees, others from their pitches.
    source_values = obj._degrees if obj.is_relative else obj._pitches
    calc_degrees = list(source_values)
    if len(calc_degrees) == 0:
        raise ValueError(f"{type(obj).__name__} has no degrees to plot")

    degrees = obj.degrees
    fig = go.Figure()

    # The same object serves as display object and calculation object.
    if obj._interval_type_mode != "cents":
        return _plot_ratio_scale_chord_clean(obj, obj, degrees, calc_degrees, fig, figsize,
                                             text_size, show_labels, title, output_file, layout)
    return _plot_cents_scale_chord(obj, obj, degrees, calc_degrees, fig, figsize,
                                   node_size, text_size, show_labels, title, output_file, nodes)
def _plot_cents_scale_chord(obj, calc_obj, degrees, calc_degrees, fig, figsize,
                            node_size, text_size, show_labels, title, output_file, nodes):
    """
    Render a scale/chord as a circular node diagram on an existing figure.

    Each degree is placed on the unit circle at an angle proportional to its
    position within the equave, running clockwise from the top. Every pair of
    nodes is joined by a thin grey line. When *nodes* is given, the selected
    nodes keep their full rainbow color, the others are dimmed, and each
    selected shape is outlined as a closed polygon through its sorted indices.

    Parameters
    ----------
    obj : object
        Displayed object; supplies ``is_instanced``, ``_interval_type_mode``
        and, when instanced, ``reference_pitch``.
    calc_obj : object
        Object whose ``_interval_type_mode`` and ``_equave`` drive the angle
        computation and the equave annotation.
    degrees : sequence
        Display values, one per node. When ``obj.is_instanced`` each item
        must expose ``pitchclass`` and ``cents_offset``.
    calc_degrees : sequence
        Numeric values used for the angles, index-aligned with *degrees*.
    fig : plotly.graph_objects.Figure
        Figure that receives the traces, annotation and layout.
    figsize : tuple of float
        Width and height; each is multiplied by 100 to get the pixel size.
    node_size : int
        Marker size of the nodes.
    text_size : int
        Font size of node labels and of the equave annotation.
    show_labels : bool
        Whether node labels are drawn.
    title : str or None
        Plot title. Derived from the object type and interval mode when
        ``None``.
    output_file : str or None
        When set, the figure is written with ``write_html`` for ``.html``
        paths and ``write_image`` otherwise.
    nodes : list, array-like, or None
        Node indices to highlight. Either one flat collection of indices
        (outlined in white) or a list of collections (one viridis color per
        shape). Lists, arrays and objects exposing ``to_numpy()`` are
        accepted at both levels.

    Returns
    -------
    plotly.graph_objects.Figure
        The populated *fig*.
    """
    def _is_array(item):
        """Return True for array-likes that have at least one dimension."""
        # NumPy scalars also expose ``tolist``/``shape``, but their shape is
        # ``()``; they are single indices, not containers of indices.
        return hasattr(item, 'tolist') and bool(getattr(item, 'shape', ()))

    def _is_container(item):
        """Return True when *item* is itself a collection of node indices."""
        return hasattr(item, 'to_numpy') or _is_array(item) or isinstance(item, list)

    def _as_index_list(item):
        """Convert any supported node container (or a scalar) to a list."""
        if hasattr(item, 'to_numpy'):
            return item.to_numpy().tolist()
        if _is_array(item):
            return item.tolist()
        if isinstance(item, list):
            return item
        if hasattr(item, '__iter__') and not isinstance(item, str):
            return list(item)
        return [item]

    def _hex(rgba, scale=255):
        """Format the RGB channels of a 0-1 color as ``#rrggbb``."""
        return '#%02x%02x%02x' % tuple(int(channel * scale) for channel in rgba[:3])

    n_degrees = len(degrees)
    node_x, node_y = [], []
    node_text, hover_data, node_colors = [], [], []

    def _add_segment(a, b, color, width):
        """Add one straight line between nodes *a* and *b* as its own trace."""
        fig.add_trace(
            go.Scatter(
                x=[node_x[a], node_x[b]], y=[node_y[a], node_y[b]],
                mode='lines+markers',
                marker=dict(size=0.1, opacity=0),
                line=dict(color=color, width=width),
                showlegend=False,
                hoverinfo='none'
            )
        )

    def _outline(indices, color):
        """Connect the sorted *indices* cyclically, skipping out-of-range ones."""
        ordered = sorted(indices)
        for k, current_idx in enumerate(ordered):
            next_idx = ordered[(k + 1) % len(ordered)]
            if current_idx < len(node_x) and next_idx < len(node_x):
                _add_segment(current_idx, next_idx, color, 3)

    # Normalise the highlight request once; it does not depend on the degree
    # being drawn. ``shapes`` is set for a list of shapes, ``single_shape``
    # for one flat collection of indices.
    shapes = None
    single_shape = None
    highlighted = set()
    if nodes is not None:
        if hasattr(nodes, 'to_numpy') or _is_array(nodes):
            flat = _as_index_list(nodes)
        else:
            flat = nodes
        if isinstance(flat, list) and flat and _is_container(flat[0]):
            shapes = [_as_index_list(item) for item in flat]
            highlighted = set().union(*shapes)
        else:
            single_shape = list(flat)
            highlighted = set(single_shape)

    calc_is_cents = calc_obj._interval_type_mode == "cents"
    if calc_is_cents:
        # NOTE(review): any non-float equave (including ints) falls back to
        # 1200 cents here -- confirm this is intended.
        cents_equave = calc_obj._equave if isinstance(calc_obj._equave, float) else 1200.0

    for i, degree in enumerate(degrees):
        calc_degree = calc_degrees[i]
        if calc_is_cents:
            proportion = calc_degree / cents_equave
        else:
            proportion = math.log(float(calc_degree)) / math.log(float(calc_obj._equave))

        # Clockwise from 12 o'clock.
        angle = -2 * math.pi * proportion + math.pi / 2
        node_x.append(math.cos(angle))
        node_y.append(math.sin(angle))

        if obj.is_instanced:
            if obj._interval_type_mode == "cents":
                display_text = f"{calc_degree:.1f}¢"
                base_hover = f"{calc_degree:.1f} cents"
            else:
                display_text = f"{calc_degree}"
                base_hover = f"{calc_degree}"
            note_name = degree.pitchclass
            cents_offset = degree.cents_offset
            cent_info = ""
            if abs(cents_offset) > 0.01:
                cent_info = f" ({cents_offset:+.2f}¢)"
            hover_info = f"Node {i}<br>{base_hover}<br>{note_name}{cent_info}"
        elif obj._interval_type_mode == "cents":
            display_text = f"{degree:.1f}¢"
            hover_info = f"Node {i}<br>{degree:.1f} cents"
        else:
            display_text = f"{degree}"
            hover_info = f"Node {i}<br>{degree}"

        node_text.append(display_text if show_labels else "")
        hover_data.append(hover_info)

        # Full rainbow color for highlighted (or all) nodes, half intensity
        # for the rest.
        rainbow_color = plt.cm.hsv(i / n_degrees)
        if nodes is None or i in highlighted:
            node_colors.append(_hex(rainbow_color))
        else:
            node_colors.append(_hex(rainbow_color, 128))

    # Background: one thin grey line per pair of nodes.
    for i in range(n_degrees):
        for j in range(i + 1, n_degrees):
            _add_segment(i, j, '#444444', 1)

    # Foreground: outline the requested shapes.
    if shapes is not None:
        viridis_colors = plt.cm.viridis(np.linspace(0, 1, len(shapes)))
        for shape, rgba in zip(shapes, viridis_colors):
            _outline(shape, _hex(rgba))
    elif single_shape is not None:
        _outline(single_shape, 'white')

    fig.add_trace(
        go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text' if show_labels else 'markers',
            marker=dict(
                size=node_size,
                color=node_colors,
                line=dict(color='white', width=2)
            ),
            text=node_text,
            textposition='middle center',
            textfont=dict(color='white', size=text_size, family='Arial', weight='bold'),
            hovertemplate='%{customdata}<extra></extra>',
            hoverlabel=dict(bgcolor='lightgrey', font_color='black'),
            customdata=hover_data,
            showlegend=False
        )
    )

    if calc_is_cents:
        equave_text = f"Equave: {cents_equave:.1f}¢"
    else:
        equave_text = f"Equave: {calc_obj._equave}"
    fig.add_annotation(
        x=0, y=1.3,
        text=equave_text,
        showarrow=False,
        font=dict(color='white', size=text_size),
        align='center'
    )

    if title is None:
        obj_type = type(obj).__name__
        if obj.is_instanced:
            interval_type = "cents" if obj._interval_type_mode == "cents" else "ratios"
            root_pitch = obj.reference_pitch
            root_note = root_pitch.pitchclass
            if abs(root_pitch.cents_offset) > 0.01:
                root_note += f" ({root_pitch.cents_offset:+.2f}¢)"
            title = f"{obj_type} ({interval_type}) - Root: {root_note}"
        elif calc_is_cents:
            title = f"{obj_type} (cents)"
        else:
            title = f"{obj_type} (ratios)"

    width_px, height_px = int(figsize[0] * 100), int(figsize[1] * 100)
    fig.update_layout(
        title=dict(text=title, font=dict(color='white')),
        width=width_px,
        height=height_px,
        paper_bgcolor='black',
        plot_bgcolor='black',
        xaxis=dict(
            showgrid=False, zeroline=False, showticklabels=False,
            range=[-1.5, 1.5]
        ),
        yaxis=dict(
            showgrid=False, zeroline=False, showticklabels=False,
            scaleanchor="x", scaleratio=1,
            range=[-1.5, 1.5]
        ),
        hovermode='closest',
        margin=dict(l=0, r=0, t=50, b=0),
    )

    if output_file:
        if output_file.endswith('.html'):
            fig.write_html(output_file)
        else:
            fig.write_image(output_file)
    return fig
def _plot_dynamic_range(dynamic_range: DynamicRange, mode: str = 'db', figsize=(20, 5),
resolution: int = 1000, show_labels: bool = True,
show_grid: bool = True, title: str = None, output_file: str = None):
"""
Render a DynamicRange as a colored curve with dynamic-marking labels.
Parameters
----------
dynamic_range : DynamicRange
DynamicRange instance to visualize.
mode : str, optional
``'db'`` for decibels or ``'amp'`` for linear amplitude.
figsize : tuple of float, optional
Width and height of the figure in inches.
resolution : int, optional
Number of sample points used for smooth curve rendering.
show_labels : bool, optional
Whether to display dynamic-marking labels along the curve.
show_grid : bool, optional
Whether to draw grid lines.
title : str or None, optional
Plot title. Auto-generated when ``None``.
output_file : str or None, optional
Path to save the figure. Displays the plot when ``None``.
"""
plt.figure(figsize=figsize)
ax = plt.gca()
ax.set_facecolor('black')
plt.gcf().set_facecolor('black')
dynamics = dynamic_range._dynamics
num_dynamics = len(dynamics)
match mode.lower():
case 'db':
min_val = dynamic_range.min_dynamic.db
max_val = dynamic_range.max_dynamic.db
ylabel = 'Decibels (dB)'
get_value = lambda d: d.db
mode_display = 'dB'
case 'amp':
min_val = dynamic_range.min_dynamic.amp
max_val = dynamic_range.max_dynamic.amp
ylabel = 'Amplitude'
get_value = lambda d: d.amp
mode_display = 'amp'
case _:
raise ValueError(f"Invalid mode '{mode}'. Must be 'db' or 'amp'.")
x = np.linspace(0, 1, resolution)
y = np.zeros(resolution)
for i, xi in enumerate(x):
norm_pos = xi
if dynamic_range.curve == 0:
curved_pos = norm_pos
else:
curved_pos = (np.exp(dynamic_range.curve * norm_pos) - 1) / (np.exp(dynamic_range.curve) - 1)
value = min_val + curved_pos * (max_val - min_val)
y[i] = value
colors = plt.cm.plasma(np.linspace(0, 1, resolution))
for i in range(resolution - 1):
ax.plot([x[i], x[i+1]], [y[i], y[i+1]],
color=colors[i], linewidth=3, alpha=0.8)
if show_labels:
dynamic_positions = np.linspace(0, 1, num_dynamics)
for i, (pos, dyn) in enumerate(zip(dynamic_positions, dynamics)):
dynamic_obj = dynamic_range[dyn]
value = get_value(dynamic_obj)
ax.axvline(x=pos, color='white', linestyle='--', alpha=0.6, linewidth=1)
ax.text(pos, max_val + (max_val - min_val) * 0.02, dyn,
ha='center', va='bottom', color='white', fontsize=12, fontweight='bold')
ax.scatter([pos], [value], color='white', s=50, zorder=5, edgecolor='black', linewidth=1)
ax.set_xlim(-0.01, 1.01)
ax.set_ylim(min_val - (max_val - min_val) * 0.05, max_val + (max_val - min_val) * 0.1)
# ax.set_xlabel('Dynamic Range Position', color='white', fontsize=12)
ax.set_ylabel(ylabel, color='white', fontsize=12)
if title is None:
curve_desc = f"curve={dynamic_range.curve}" if dynamic_range.curve != 0 else "linear"
title = f"Dynamic Range ({mode_display}) - {curve_desc}"
ax.set_title(title, color='white', fontsize=14)
ax.spines['bottom'].set_color('white')
ax.spines['top'].set_color('white')
ax.spines['right'].set_color('white')
ax.spines['left'].set_color('white')
ax.tick_params(axis='x', colors='white')
ax.tick_params(axis='y', colors='white')
if show_grid:
ax.grid(color='#555555', linestyle='-', linewidth=0.5, alpha=0.5)
plt.tight_layout()
if output_file:
plt.savefig(output_file, bbox_inches='tight', facecolor='black')
plt.close()
else:
plt.show()
def _plot_envelope(envelope: Envelope, figsize=(20, 5), show_points: bool = True,
                   show_grid: bool = True, title: str = None, output_file: str = None,
                   resolution: int = 1000):
    """
    Render an Envelope as a time-vs-value curve with breakpoint markers.

    The envelope is sampled with ``envelope.at_time`` at *resolution* evenly
    spaced times between ``0`` and ``envelope.total_time``.

    Parameters
    ----------
    envelope : Envelope
        Envelope instance to visualize.
    figsize : tuple of float, optional
        Width and height of the figure in inches.
    show_points : bool, optional
        Whether to draw a marker and a two-decimal value label at each
        breakpoint (``envelope.breakpoint_times`` / ``envelope.values``).
    show_grid : bool, optional
        Whether to display grid lines.
    title : str or None, optional
        Plot title. Defaults to ``"Envelope"``.
    output_file : str or None, optional
        Path to save the figure. Displays the plot when ``None``.
    resolution : int, optional
        Number of sample points for smooth curve rendering.
    """
    plt.figure(figsize=figsize)
    ax = plt.gca()
    ax.set_facecolor('black')
    plt.gcf().set_facecolor('black')

    x = np.linspace(0, envelope.total_time, resolution)
    y = np.array([envelope.at_time(t) for t in x])
    ax.plot(x, y, color='#e6e6e6', linewidth=2.5)

    if show_points:
        point_times = envelope.breakpoint_times
        point_values = envelope.values
        ax.scatter(point_times, point_values, color='white', s=80,
                   zorder=5, edgecolor='black', linewidth=2)
        # Labels sit 5% of the sampled value span above each breakpoint. The
        # span is the same for every label, so it is computed once.
        label_offset = (y.max() - y.min()) * 0.05 if y.size else 0.0
        for t, v in zip(point_times, point_values):
            ax.text(t, v + label_offset, f'{v:.2f}',
                    ha='center', va='bottom', color='white', fontsize=10, fontweight='bold')

    if title is None:
        title = "Envelope"
    ax.set_title(title, color='white', fontsize=14)
    ax.set_xlabel('Time', color='white', fontsize=12)
    ax.set_ylabel('Value', color='white', fontsize=12)

    for side in ('bottom', 'top', 'right', 'left'):
        ax.spines[side].set_color('white')
    ax.tick_params(axis='x', colors='white')
    ax.tick_params(axis='y', colors='white')

    if show_grid:
        ax.grid(color='#555555', linestyle='-', linewidth=0.5, alpha=0.5)

    plt.tight_layout()
    if output_file:
        plt.savefig(output_file, bbox_inches='tight', facecolor='black')
        plt.close()
    else:
        plt.show()
def _plot_field(field: ParameterField, figsize: tuple[float, float] = (12, 12),
node_size: float = 8, title: str = None,
output_file: str = None,
dim_reduction: str = None, target_dims: int = 3,
mds_metric: bool = True, mds_max_iter: int = 300,
spectral_affinity: str = 'rbf', spectral_gamma: float = None,
nodes: list = None, path: list = None, path_mode: str = 'adjacent',
mute_background: bool = False, colormap: str = 'coolwarm', show_colorbar: bool = False) -> go.Figure:
"""
Render a ParameterField as a 2D or 3D grid with value-mapped node colors.
Parameters
----------
field : ParameterField
ParameterField instance to visualize.
figsize : tuple of float, optional
Width and height of the figure in inches.
node_size : float, optional
Size of the drawn nodes.
title : str or None, optional
Plot title. Auto-generated when ``None``.
output_file : str or None, optional
Path to save the figure. Displays the plot when ``None``.
dim_reduction : str or None, optional
Dimensionality reduction method for fields with more than 3
dimensions. ``'mds'``, ``'spectral'``, or ``None`` (raises
for dim > 3).
target_dims : int, optional
Target dimensions after reduction (2 or 3).
mds_metric : bool, optional
Use metric MDS when ``True``, non-metric when ``False``.
mds_max_iter : int, optional
Maximum iterations for the MDS algorithm.
spectral_affinity : str, optional
Kernel for spectral embedding.
spectral_gamma : float or None, optional
Kernel coefficient for rbf. Auto-determined when ``None``.
nodes : list of tuple or None, optional
Coordinate tuples to highlight.
path : list of tuple or None, optional
Coordinate tuples defining a traversal path.
path_mode : str, optional
``'adjacent'`` to show edges between neighbouring selected
nodes (default).
mute_background : bool, optional
Only show highlighted nodes / path coordinates when ``True``.
colormap : str, optional
Matplotlib colormap name for field value colouring.
show_colorbar : bool, optional
Whether to display a colourbar for field values.
Returns
-------
plotly.graph_objects.Figure
Interactive Plotly figure.
Raises
------
ValueError
If the field dimensionality exceeds 3 and *dim_reduction* is
``None``.
"""
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.manifold import MDS, SpectralEmbedding
# Convert nodes parameter to tuples if needed (safety mechanism for all lattice types)
if nodes is not None:
converted_nodes = []
for node in nodes:
if isinstance(node, (list, tuple)):
converted_nodes.append(tuple(node))
elif hasattr(node, 'tolist'): # numpy array
converted_nodes.append(tuple(node.tolist()))
elif hasattr(node, '__iter__') and not isinstance(node, str):
converted_nodes.append(tuple(node))
else:
converted_nodes.append(node)
nodes = converted_nodes
# Convert path parameter to tuples if needed (safety mechanism for all lattice types)
if path is not None:
converted_path = []
for coord in path:
if isinstance(coord, (list, tuple)):
converted_path.append(tuple(coord))
elif hasattr(coord, 'tolist'): # numpy array
converted_path.append(tuple(coord.tolist()))
elif hasattr(coord, '__iter__') and not isinstance(coord, str):
converted_path.append(tuple(coord))
else:
converted_path.append(coord)
path = converted_path
if field.dimensionality > 3 and dim_reduction is None:
raise ValueError(f"Plotting dimensionality > 3 requires dim_reduction. Got dimensionality={field.dimensionality}. "
f"Use dim_reduction='mds' or 'spectral'")
if target_dims not in [2, 3]:
raise ValueError(f"target_dims must be 2 or 3, got {target_dims}")
if field.dimensionality <= 2:
max_resolution = 5
elif field.dimensionality == 3:
max_resolution = 2
else:
if target_dims == 3:
max_resolution = 1
else:
max_resolution = 3
expected_total = 1
for dim in field._dims:
expected_total *= len(dim)
if expected_total > 10000:
expected_total = float('inf')
break
if nodes or path:
coord_ranges = []
all_coords_to_fit = []
if nodes:
all_coords_to_fit.extend(coord for coord in nodes if coord in field)
if path:
all_coords_to_fit.extend(coord for coord in path if coord in field)
if all_coords_to_fit:
for dim in range(field.dimensionality):
dim_vals = [coord[dim] for coord in all_coords_to_fit]
if dim_vals:
min_val, max_val = min(dim_vals), max(dim_vals)
coord_ranges.append((min_val - 1, max_val + 1))
else:
coord_ranges.append((-1, 1))
coords = []
import itertools
ranges = [range(start, end + 1) for start, end in coord_ranges]
coords = list(itertools.product(*ranges))
coords = [coord for coord in coords if coord in field]
else:
coords = field.coords
elif field.dimensionality > 3 or field._is_lazy or expected_total > 1000:
# For large fields, determine plotting area based on path extent if provided
if path:
# Calculate the range needed to encompass the entire path
coord_ranges = []
for dim in range(field.dimensionality):
dim_vals = [coord[dim] for coord in path if coord in field]
if dim_vals:
min_val, max_val = min(dim_vals), max(dim_vals)
# Add buffer around path
coord_ranges.append((min_val - 2, max_val + 2))
else:
coord_ranges.append((-max_resolution, max_resolution))
# Generate coordinates for the path-encompassing area
coords = []
import itertools
ranges = [range(start, end + 1) for start, end in coord_ranges]
coords = list(itertools.product(*ranges))
coords = [coord for coord in coords if coord in field]
else:
# No path provided, use default reduced coordinates
coords = field._get_plot_coords(max_resolution)
else:
coords = field.coords
if nodes or field.dimensionality > 3 or field._is_lazy or expected_total > 1000:
G_reduced = nx.Graph()
G_reduced.add_nodes_from(coords)
for i, coord1 in enumerate(coords):
for j, coord2 in enumerate(coords):
if i < j:
diff_count = sum(1 for a, b in zip(coord1, coord2) if abs(a - b) == 1)
same_count = sum(1 for a, b in zip(coord1, coord2) if a == b)
if diff_count == 1 and same_count == len(coord1) - 1:
G_reduced.add_edge(coord1, coord2)
G = G_reduced
else:
G = field
if nodes and path_mode == 'origin':
origin = tuple(0 for _ in range(field.dimensionality))
if origin not in coords:
coords.append(origin)
if hasattr(G, 'add_node'):
G.add_node(origin)
for coord in coords:
if coord != origin:
diff_count = sum(1 for a, b in zip(origin, coord) if abs(a - b) == 1)
same_count = sum(1 for a, b in zip(origin, coord) if a == b)
if diff_count == 1 and same_count == len(coord) - 1:
G.add_edge(origin, coord)
original_coords = coords
if field.dimensionality > 3:
coord_matrix = np.array([list(coord) for coord in coords])
if dim_reduction == 'mds':
reducer = MDS(n_components=target_dims, metric_mds=mds_metric, max_iter=mds_max_iter, init='random', n_init=4, random_state=42)
reduced_coords = reducer.fit_transform(coord_matrix)
elif dim_reduction == 'spectral':
if spectral_affinity == 'precomputed':
coord_to_idx = {coord: i for i, coord in enumerate(coords)}
n = len(coords)
adjacency_matrix = np.zeros((n, n))
for i, coord1 in enumerate(coords):
for j, coord2 in enumerate(coords):
if i != j:
diff_count = sum(1 for a, b in zip(coord1, coord2) if abs(a - b) == 1)
same_count = sum(1 for a, b in zip(coord1, coord2) if a == b)
if diff_count == 1 and same_count == len(coord1) - 1:
adjacency_matrix[i, j] = 1
reducer = SpectralEmbedding(n_components=target_dims, affinity='precomputed', random_state=42)
reduced_coords = reducer.fit_transform(adjacency_matrix)
else:
reducer = SpectralEmbedding(n_components=target_dims, affinity=spectral_affinity,
gamma=spectral_gamma, random_state=42)
reduced_coords = reducer.fit_transform(coord_matrix)
else:
raise ValueError(f"Unknown dim_reduction method: {dim_reduction}. Use 'mds' or 'spectral'")
coords = [tuple(reduced_coords[i]) for i in range(len(coords))]
effective_dimensionality = target_dims
coord_mapping = {original_coords[i]: coords[i] for i in range(len(coords))}
G_reduced = nx.Graph()
G_reduced.add_nodes_from(coords)
for i, coord1 in enumerate(original_coords):
for j, coord2 in enumerate(original_coords):
if i < j:
diff_count = sum(1 for a, b in zip(coord1, coord2) if abs(a - b) == 1)
same_count = sum(1 for a, b in zip(coord1, coord2) if a == b)
if diff_count == 1 and same_count == len(coord1) - 1:
u_reduced = coord_mapping[coord1]
v_reduced = coord_mapping[coord2]
G_reduced.add_edge(u_reduced, v_reduced)
G = G_reduced
else:
effective_dimensionality = field.dimensionality
field_values = []
for coord in original_coords:
try:
field_values.append(field.get_field_value(coord))
except KeyError:
field_values.append(0.0)
field_values = np.array(field_values)
if len(field_values) > 0:
vmin, vmax = field_values.min(), field_values.max()
if vmax == vmin:
vmax = vmin + 1e-10
else:
vmin, vmax = 0, 1
cmap = plt.get_cmap(colormap)
normalized_values = (field_values - vmin) / (vmax - vmin)
colors = [cmap(val) for val in normalized_values]
color_hex = ['#%02x%02x%02x' % (int(c[0]*255), int(c[1]*255), int(c[2]*255)) for c in colors]
# Map matplotlib colormap names to plotly colormap names
plotly_colormap_mapping = {
'viridis': 'viridis',
'plasma': 'plasma',
'inferno': 'inferno',
'magma': 'magma',
'coolwarm': 'rdbu',
'hot': 'hot',
'cool': 'blues',
'spring': 'greens',
'summer': 'ylgnbu',
'autumn': 'orrd',
'winter': 'blues',
'copper': 'burg',
'gray': 'greys',
'grey': 'greys',
'jet': 'jet',
'hsv': 'hsv',
'rainbow': 'rainbow',
'seismic': 'rdbu',
'terrain': 'earth',
'spectral': 'spectral',
'RdYlBu': 'rdylbu',
'RdBu': 'rdbu',
'PiYG': 'piyg',
'PRGn': 'prgn',
'BrBG': 'brbg',
'RdGy': 'rdgy',
'PuOr': 'puor'
}
plotly_colormap = plotly_colormap_mapping.get(colormap, 'viridis')
if title is None:
resolution_str = 'x'.join(str(r) for r in field.resolution)
bipolar_str = "bipolar" if field.bipolar else "unipolar"
if field.dimensionality > 3:
title = f"{field.dimensionality}D→{target_dims}D Field ({resolution_str}, {bipolar_str}, {dim_reduction})"
else:
title = f"{field.dimensionality}D Field ({resolution_str}, {bipolar_str})"
valid_coords = set(coords) if field.dimensionality <= 3 else set(original_coords)
highlighted_coords = set()
if nodes:
highlighted_coords.update(coord for coord in nodes if coord in valid_coords)
if path:
highlighted_coords.update(coord for coord in path if coord in valid_coords)
use_dimmed = ((nodes is not None and len(nodes) > 0) or (path is not None and len(path) > 0)) and not mute_background
fig = go.Figure()
if effective_dimensionality == 1:
# Draw edges with gradient colors based on field values
for u, v in G.edges():
x1, y1 = u[0], 0
x2, y2 = v[0], 0
# Get field values at both nodes
u_idx = coords.index(u) if u in coords else -1
v_idx = coords.index(v) if v in coords else -1
if u_idx >= 0 and v_idx >= 0 and u_idx < len(field_values) and v_idx < len(field_values):
u_val = field_values[u_idx]
v_val = field_values[v_idx]
# Create gradient color between the two values
avg_val = (u_val + v_val) / 2
normalized_avg = (avg_val - vmin) / (vmax - vmin) if vmax != vmin else 0.5
avg_color = cmap(normalized_avg)
edge_color_hex = '#%02x%02x%02x' % (int(avg_color[0]*255), int(avg_color[1]*255), int(avg_color[2]*255))
# Dim the edge color slightly
edge_color_dimmed = '#%02x%02x%02x' % (
int(int(edge_color_hex[1:3], 16) * 0.7),
int(int(edge_color_hex[3:5], 16) * 0.7),
int(int(edge_color_hex[5:7], 16) * 0.7)
)
else:
edge_color_dimmed = '#555555'
edge_width = 1 if use_dimmed else 2
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color=edge_color_dimmed, width=edge_width),
showlegend=False,
hoverinfo='none'
)
)
if nodes and len(highlighted_coords) >= 1:
highlighted_list = list(highlighted_coords)
if path_mode == 'adjacent' and len(highlighted_coords) > 1:
for i in range(len(highlighted_list)):
for j in range(i + 1, len(highlighted_list)):
coord1, coord2 = highlighted_list[i], highlighted_list[j]
diff_count = sum(1 for a, b in zip(coord1, coord2) if abs(a - b) == 1)
same_count = sum(1 for a, b in zip(coord1, coord2) if a == b)
is_lattice_adjacent = diff_count == 1 and same_count == len(coord1) - 1
if is_lattice_adjacent:
x1, x2 = coord1[0], coord2[0]
y1, y2 = 0, 0
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color='white', width=4),
showlegend=False,
hoverinfo='none'
)
)
elif path_mode == 'origin':
origin = tuple(0 for _ in range(field.dimensionality))
for target_coord in highlighted_list:
if target_coord != origin:
try:
if hasattr(G, 'has_node') and G.has_node(origin) and G.has_node(target_coord):
path_coords = nx.shortest_path(G, origin, target_coord)
for k in range(len(path_coords) - 1):
pc1, pc2 = path_coords[k], path_coords[k + 1]
x1, x2 = pc1[0], pc2[0]
y1, y2 = 0, 0
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color='white', width=4),
showlegend=False,
hoverinfo='none'
)
)
except (KeyError, nx.NetworkXNoPath):
continue
# Draw path edges with viridis coloring for time progression
if path and len(path) > 1:
viridis_colors = plt.cm.viridis(np.linspace(0.15, 1, len(path) - 1))
for i in range(len(path) - 1):
coord1, coord2 = path[i], path[i + 1]
if coord1 in coords and coord2 in coords:
x1, y1 = coord1[0], coord1[1]
x2, y2 = coord2[0], coord2[1]
color = viridis_colors[i]
path_color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
# Add path edge with enhanced visibility
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines',
line=dict(color=path_color_hex, width=8),
opacity=0.9,
showlegend=False,
hoverinfo='none'
)
)
# Add subtle white outline for better contrast
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines',
line=dict(color='white', width=10),
opacity=0.3,
showlegend=False,
hoverinfo='none'
)
)
node_x, node_y = [], []
hover_data = []
node_colors = []
for i, coord in enumerate(coords):
x = coord[0]
node_x.append(x)
node_y.append(0)
orig_coord = coord if field.dimensionality <= 3 else original_coords[i]
field_val = field_values[i] if i < len(field_values) else 0.0
if field.dimensionality > 3:
orig_coord_str = str(original_coords[i]).replace(',)', ')')
reduced_coord_str = f"({x:.2f})"
hover_data.append(f"Original: {orig_coord_str}<br>Reduced: {reduced_coord_str}<br>Value: {field_val:.4f}")
else:
hover_data.append(f"Coordinate: ({x})<br>Value: {field_val:.4f}")
if nodes and orig_coord in highlighted_coords:
node_colors.append('white')
elif use_dimmed:
node_colors.append('#111111')
else:
try:
coord_idx = original_coords.index(orig_coord) if orig_coord in original_coords else i
color = color_hex[coord_idx] if coord_idx < len(color_hex) else '#FFFFFF'
node_colors.append(color)
except (ValueError, IndexError):
node_colors.append('#FFFFFF')
fig.add_trace(
go.Scatter(
x=node_x, y=node_y,
mode='markers',
marker=dict(
size=node_size * 2,
color=node_colors,
line=dict(color='white', width=2)
),
hovertemplate='%{text}<extra></extra>',
text=hover_data,
showlegend=False
)
)
fig.update_layout(
yaxis=dict(
showgrid=False, zeroline=False, showticklabels=False,
range=[-0.5, 0.5]
)
)
elif effective_dimensionality == 2:
# Draw edges with gradient colors based on field values
for u, v in G.edges():
x1, y1 = u
x2, y2 = v
# Get field values at both nodes
u_idx = coords.index(u) if u in coords else -1
v_idx = coords.index(v) if v in coords else -1
if u_idx >= 0 and v_idx >= 0 and u_idx < len(field_values) and v_idx < len(field_values):
u_val = field_values[u_idx]
v_val = field_values[v_idx]
# Create gradient color between the two values
avg_val = (u_val + v_val) / 2
normalized_avg = (avg_val - vmin) / (vmax - vmin) if vmax != vmin else 0.5
avg_color = cmap(normalized_avg)
edge_color_hex = '#%02x%02x%02x' % (int(avg_color[0]*255), int(avg_color[1]*255), int(avg_color[2]*255))
# Dim the edge color slightly
edge_color_dimmed = '#%02x%02x%02x' % (
int(int(edge_color_hex[1:3], 16) * 0.7),
int(int(edge_color_hex[3:5], 16) * 0.7),
int(int(edge_color_hex[5:7], 16) * 0.7)
)
else:
edge_color_dimmed = '#555555'
edge_width = 1 if use_dimmed else 2
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color=edge_color_dimmed, width=edge_width),
showlegend=False,
hoverinfo='none'
)
)
if nodes and len(highlighted_coords) >= 1:
highlighted_list = list(highlighted_coords)
if path_mode == 'adjacent' and len(highlighted_coords) > 1:
for i in range(len(highlighted_list)):
for j in range(i + 1, len(highlighted_list)):
coord1, coord2 = highlighted_list[i], highlighted_list[j]
diff_count = sum(1 for a, b in zip(coord1, coord2) if abs(a - b) == 1)
same_count = sum(1 for a, b in zip(coord1, coord2) if a == b)
is_lattice_adjacent = diff_count == 1 and same_count == len(coord1) - 1
if is_lattice_adjacent:
x1, y1 = coord1
x2, y2 = coord2
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color='white', width=4),
showlegend=False,
hoverinfo='none'
)
)
elif path_mode == 'origin':
origin = tuple(0 for _ in range(field.dimensionality))
for target_coord in highlighted_list:
if target_coord != origin:
try:
if hasattr(G, 'has_node') and G.has_node(origin) and G.has_node(target_coord):
path_coords = nx.shortest_path(G, origin, target_coord)
for k in range(len(path_coords) - 1):
pc1, pc2 = path_coords[k], path_coords[k + 1]
x1, y1 = pc1
x2, y2 = pc2
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color='white', width=4),
showlegend=False,
hoverinfo='none'
)
)
except (KeyError, nx.NetworkXNoPath):
continue
# Draw path edges with viridis coloring for time progression
if path and len(path) > 1:
viridis_colors = plt.cm.viridis(np.linspace(0.15, 1, len(path) - 1))
for i in range(len(path) - 1):
coord1, coord2 = path[i], path[i + 1]
if coord1 in coords and coord2 in coords:
x1, y1 = coord1
x2, y2 = coord2
color = viridis_colors[i]
path_color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
# Add path edge with enhanced visibility
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines',
line=dict(color=path_color_hex, width=8),
opacity=0.9,
showlegend=False,
hoverinfo='none'
)
)
# Add subtle white outline for better contrast
fig.add_trace(
go.Scatter(
x=[x1, x2], y=[y1, y2],
mode='lines',
line=dict(color='white', width=10),
opacity=0.3,
showlegend=False,
hoverinfo='none'
)
)
node_x, node_y = [], []
hover_data = []
node_colors = []
for i, coord in enumerate(coords):
x, y = coord
node_x.append(x)
node_y.append(y)
orig_coord = coord if field.dimensionality <= 3 else original_coords[i]
field_val = field_values[i] if i < len(field_values) else 0.0
if field.dimensionality > 3:
orig_coord_str = str(original_coords[i]).replace(',)', ')')
reduced_coord_str = f"({x:.2f}, {y:.2f})"
hover_data.append(f"Original: {orig_coord_str}<br>Reduced: {reduced_coord_str}<br>Value: {field_val:.4f}")
else:
hover_data.append(f"Coordinate: ({x}, {y})<br>Value: {field_val:.4f}")
if nodes and orig_coord in highlighted_coords:
node_colors.append('white')
elif use_dimmed:
node_colors.append('#111111')
else:
try:
coord_idx = original_coords.index(orig_coord) if orig_coord in original_coords else i
color = color_hex[coord_idx] if coord_idx < len(color_hex) else '#FFFFFF'
node_colors.append(color)
except (ValueError, IndexError):
node_colors.append('#FFFFFF')
fig.add_trace(
go.Scatter(
x=node_x, y=node_y,
mode='markers',
marker=dict(
size=node_size * 2,
color=node_colors,
line=dict(color='white', width=2)
),
hovertemplate='%{text}<extra></extra>',
text=hover_data,
showlegend=False
)
)
fig.update_layout(
yaxis=dict(
scaleanchor="x", scaleratio=1
)
)
elif effective_dimensionality == 3:
# Draw edges with gradient colors based on field values
for u, v in G.edges():
x1, y1, z1 = (u[0], u[1], u[2]) if len(u) >= 3 else (u[0], u[1] if len(u) >= 2 else 0, 0)
x2, y2, z2 = (v[0], v[1], v[2]) if len(v) >= 3 else (v[0], v[1] if len(v) >= 2 else 0, 0)
# Get field values at both nodes
u_idx = coords.index(u) if u in coords else -1
v_idx = coords.index(v) if v in coords else -1
if u_idx >= 0 and v_idx >= 0 and u_idx < len(field_values) and v_idx < len(field_values):
u_val = field_values[u_idx]
v_val = field_values[v_idx]
# Create gradient color between the two values
avg_val = (u_val + v_val) / 2
normalized_avg = (avg_val - vmin) / (vmax - vmin) if vmax != vmin else 0.5
avg_color = cmap(normalized_avg)
edge_color_hex = '#%02x%02x%02x' % (int(avg_color[0]*255), int(avg_color[1]*255), int(avg_color[2]*255))
# Dim the edge color slightly
edge_color_dimmed = '#%02x%02x%02x' % (
int(int(edge_color_hex[1:3], 16) * 0.7),
int(int(edge_color_hex[3:5], 16) * 0.7),
int(int(edge_color_hex[5:7], 16) * 0.7)
)
else:
edge_color_dimmed = '#555555'
edge_width = 1 if use_dimmed else 2
fig.add_trace(
go.Scatter3d(
x=[x1, x2], y=[y1, y2], z=[z1, z2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color=edge_color_dimmed, width=edge_width),
showlegend=False,
hoverinfo='none'
)
)
if nodes and len(highlighted_coords) >= 1:
highlighted_list = list(highlighted_coords)
if path_mode == 'adjacent' and len(highlighted_coords) > 1:
for i in range(len(highlighted_list)):
for j in range(i + 1, len(highlighted_list)):
coord1, coord2 = highlighted_list[i], highlighted_list[j]
diff_count = sum(1 for a, b in zip(coord1, coord2) if abs(a - b) == 1)
same_count = sum(1 for a, b in zip(coord1, coord2) if a == b)
is_lattice_adjacent = diff_count == 1 and same_count == len(coord1) - 1
if is_lattice_adjacent:
x1, y1, z1 = (coord1[0], coord1[1], coord1[2]) if len(coord1) >= 3 else (coord1[0], coord1[1] if len(coord1) >= 2 else 0, 0)
x2, y2, z2 = (coord2[0], coord2[1], coord2[2]) if len(coord2) >= 3 else (coord2[0], coord2[1] if len(coord2) >= 2 else 0, 0)
fig.add_trace(
go.Scatter3d(
x=[x1, x2], y=[y1, y2], z=[z1, z2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color='white', width=4),
showlegend=False,
hoverinfo='none'
)
)
elif path_mode == 'origin':
origin = tuple(0 for _ in range(field.dimensionality))
for target_coord in highlighted_list:
if target_coord != origin:
try:
if hasattr(G, 'has_node') and G.has_node(origin) and G.has_node(target_coord):
path_coords = nx.shortest_path(G, origin, target_coord)
for k in range(len(path_coords) - 1):
pc1, pc2 = path_coords[k], path_coords[k + 1]
x1, y1, z1 = (pc1[0], pc1[1], pc1[2]) if len(pc1) >= 3 else (pc1[0], pc1[1] if len(pc1) >= 2 else 0, 0)
x2, y2, z2 = (pc2[0], pc2[1], pc2[2]) if len(pc2) >= 3 else (pc2[0], pc2[1] if len(pc2) >= 2 else 0, 0)
fig.add_trace(
go.Scatter3d(
x=[x1, x2], y=[y1, y2], z=[z1, z2],
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
line=dict(color='white', width=4),
showlegend=False,
hoverinfo='none'
)
)
except (KeyError, nx.NetworkXNoPath):
continue
# Draw path edges with viridis coloring for time progression
if path and len(path) > 1:
viridis_colors = plt.cm.viridis(np.linspace(0.15, 1, len(path) - 1))
for i in range(len(path) - 1):
coord1, coord2 = path[i], path[i + 1]
if coord1 in coords and coord2 in coords:
x1, y1, z1 = (coord1[0], coord1[1], coord1[2]) if len(coord1) >= 3 else (coord1[0], coord1[1] if len(coord1) >= 2 else 0, 0)
x2, y2, z2 = (coord2[0], coord2[1], coord2[2]) if len(coord2) >= 3 else (coord2[0], coord2[1] if len(coord2) >= 2 else 0, 0)
color = viridis_colors[i]
path_color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
# Add path edge with enhanced visibility
fig.add_trace(
go.Scatter3d(
x=[x1, x2], y=[y1, y2], z=[z1, z2],
mode='lines',
line=dict(color=path_color_hex, width=8),
opacity=0.9,
showlegend=False,
hoverinfo='none'
)
)
# Add subtle white outline for better contrast
fig.add_trace(
go.Scatter3d(
x=[x1, x2], y=[y1, y2], z=[z1, z2],
mode='lines',
line=dict(color='white', width=10),
opacity=0.3,
showlegend=False,
hoverinfo='none'
)
)
node_x, node_y, node_z = [], [], []
hover_data = []
node_colors = []
for i, coord in enumerate(coords):
if len(coord) >= 3:
x, y, z = coord[0], coord[1], coord[2]
elif len(coord) == 2:
x, y, z = coord[0], coord[1], 0
else:
x, y, z = coord[0], 0, 0
node_x.append(x)
node_y.append(y)
node_z.append(z)
orig_coord = coord if field.dimensionality <= 3 else original_coords[i]
field_val = field_values[i] if i < len(field_values) else 0.0
if field.dimensionality > 3:
orig_coord_str = str(original_coords[i]).replace(',)', ')')
reduced_coord_str = f"({x:.2f}, {y:.2f}, {z:.2f})"
hover_data.append(f"Original: {orig_coord_str}<br>Reduced: {reduced_coord_str}<br>Value: {field_val:.4f}")
else:
hover_data.append(f"Coordinate: ({x}, {y}, {z})<br>Value: {field_val:.4f}")
if nodes and orig_coord in highlighted_coords:
node_colors.append('white')
elif use_dimmed:
node_colors.append('#111111')
else:
try:
coord_idx = original_coords.index(orig_coord) if orig_coord in original_coords else i
color = color_hex[coord_idx] if coord_idx < len(color_hex) else '#FFFFFF'
node_colors.append(color)
except (ValueError, IndexError):
node_colors.append('#FFFFFF')
fig.add_trace(
go.Scatter3d(
x=node_x, y=node_y, z=node_z,
mode='markers',
marker=dict(
size=node_size,
color=node_colors,
line=dict(color='white', width=2)
),
hovertemplate='%{text}<extra></extra>',
text=hover_data,
showlegend=False
)
)
width_px, height_px = int(figsize[0] * 100), int(figsize[1] * 100)
x_coords = [coord[0] for coord in coords]
x_min, x_max = min(x_coords), max(x_coords)
if effective_dimensionality >= 2:
y_coords = [coord[1] for coord in coords]
y_min, y_max = min(y_coords), max(y_coords)
if effective_dimensionality == 3:
z_coords = [coord[2] for coord in coords]
z_min, z_max = min(z_coords), max(z_coords)
if field.dimensionality > 3:
x_ticks = np.linspace(x_min, x_max, min(10, int(x_max - x_min) + 1))
x_ticks = [round(t, 1) for t in x_ticks]
if effective_dimensionality >= 2:
y_ticks = np.linspace(y_min, y_max, min(10, int(y_max - y_min) + 1))
y_ticks = [round(t, 1) for t in y_ticks]
if effective_dimensionality == 3:
z_ticks = np.linspace(z_min, z_max, min(10, int(z_max - z_min) + 1))
z_ticks = [round(t, 1) for t in z_ticks]
else:
x_ticks = list(range(int(x_min), int(x_max) + 1))
if effective_dimensionality >= 2:
y_ticks = list(range(int(y_min), int(y_max) + 1))
if effective_dimensionality == 3:
z_ticks = list(range(int(z_min), int(z_max) + 1))
layout_dict = dict(
title=dict(text=title, font=dict(color='white')),
width=width_px,
height=height_px,
paper_bgcolor='black',
plot_bgcolor='black',
hovermode='closest',
margin=dict(l=0, r=0, t=50, b=0)
)
if effective_dimensionality <= 2:
if field.dimensionality > 3:
layout_dict.update(dict(
xaxis=dict(
showgrid=False, zeroline=False, showticklabels=False,
showline=False, title=dict(text='', font=dict(color='white'))
),
yaxis=dict(
showgrid=False, zeroline=False, showticklabels=False,
showline=False, title=dict(text='', font=dict(color='white'))
)
))
else:
x_title = 'X'
y_title = 'Y' if effective_dimensionality == 2 else ''
layout_dict.update(dict(
xaxis=dict(
title=dict(text=x_title, font=dict(color='white')),
tickfont=dict(color='white'),
gridcolor='#555555',
zerolinecolor='#555555',
tickmode='array',
tickvals=x_ticks,
ticktext=[str(t) for t in x_ticks]
),
yaxis=dict(
title=dict(text=y_title, font=dict(color='white')),
tickfont=dict(color='white'),
gridcolor='#555555',
zerolinecolor='#555555',
tickmode='array',
tickvals=y_ticks if effective_dimensionality == 2 else [0],
ticktext=[str(t) for t in y_ticks] if effective_dimensionality == 2 else ['']
)
))
else:
if field.dimensionality > 3:
layout_dict.update(dict(
scene=dict(
camera=dict(
eye=dict(x=1.5, y=1.5, z=1.5),
center=dict(x=0, y=0, z=0)
),
xaxis=dict(
showgrid=False, zeroline=False, showticklabels=False,
showline=False, showbackground=False,
title=dict(text='', font=dict(color='white'))
),
yaxis=dict(
showgrid=False, zeroline=False, showticklabels=False,
showline=False, showbackground=False,
title=dict(text='', font=dict(color='white'))
),
zaxis=dict(
showgrid=False, zeroline=False, showticklabels=False,
showline=False, showbackground=False,
title=dict(text='', font=dict(color='white'))
),
bgcolor='black'
)
))
else:
layout_dict.update(dict(
scene=dict(
camera=dict(
eye=dict(x=1.5, y=1.5, z=1.5),
center=dict(x=0, y=0, z=0)
),
xaxis=dict(
title=dict(text='X', font=dict(color='white')),
tickfont=dict(color='white'),
gridcolor='#555555',
zerolinecolor='#555555',
backgroundcolor='black',
tickmode='array',
tickvals=x_ticks,
ticktext=[str(t) for t in x_ticks]
),
yaxis=dict(
title=dict(text='Y', font=dict(color='white')),
tickfont=dict(color='white'),
gridcolor='#555555',
zerolinecolor='#555555',
backgroundcolor='black',
tickmode='array',
tickvals=y_ticks,
ticktext=[str(t) for t in y_ticks]
),
zaxis=dict(
title=dict(text='Z', font=dict(color='white')),
tickfont=dict(color='white'),
gridcolor='#555555',
zerolinecolor='#555555',
backgroundcolor='black',
tickmode='array',
tickvals=z_ticks,
ticktext=[str(t) for t in z_ticks]
),
bgcolor='black'
)
))
if show_colorbar and len(field_values) > 0 and vmax != vmin:
fig.add_trace(
go.Scatter(
x=[None], y=[None],
mode='markers',
marker=dict(
colorscale=plotly_colormap,
cmin=vmin,
cmax=vmax,
colorbar=dict(
title=dict(text="Field Value", font=dict(color='white')),
tickfont=dict(color='white'),
x=1.02
),
showscale=True
),
showlegend=False,
hoverinfo='skip'
)
)
fig.update_layout(**layout_dict)
if output_file:
if output_file.endswith('.html'):
fig.write_html(output_file)
else:
fig.write_image(output_file)
return fig
def _plot_ratio_scale_chord_new(obj, calc_obj, degrees, calc_degrees, fig, figsize,
text_size, show_labels, title, output_file, layout):
"""Render ratio-based scales/chords as proportional segments."""
n_degrees = len(degrees)
if n_degrees < 2:
raise ValueError("Need at least 2 degrees to plot intervals")
# Get the complete intervals (including final interval to equave)
if hasattr(calc_obj, 'complete_intervals'):
interval_ratios = calc_obj.complete_intervals
else:
# Fallback for other types
interval_ratios = list(calc_obj.intervals)
if calc_obj._degrees:
final_interval = calc_obj._equave / calc_obj._degrees[-1]
interval_ratios.append(final_interval)
# Convert interval ratios to log sizes for proportional display
intervals = [math.log(float(ratio)) for ratio in interval_ratios]
n_segments = len(interval_ratios)
total_log_size = sum(intervals)
if layout == 'circle':
current_angle = math.pi / 2
# Color based on distance from unison (1/1)
colors = []
for i in range(n_segments):
if i < len(calc_degrees):
# Distance from unison in log space
distance = abs(math.log(float(calc_degrees[i])))
hue = min(distance / 2.0, 1.0) # Normalize and cap at 1.0
else:
# Final interval - use distance based on equave
distance = abs(math.log(float(calc_obj._equave)))
hue = min(distance / 2.0, 1.0)
colors.append(plt.cm.hsv(hue))
for i, (interval_size, interval_ratio) in enumerate(zip(intervals, interval_ratios)):
if interval_size <= 0:
continue
proportion = interval_size / total_log_size
angle_span = 2 * math.pi * proportion
num_points = max(50, int(angle_span * 50))
angles = np.linspace(current_angle, current_angle - angle_span, num_points)
inner_radius = 0.85
outer_radius = 1.0
x_outer = outer_radius * np.cos(angles)
y_outer = outer_radius * np.sin(angles)
x_inner = inner_radius * np.cos(angles)
y_inner = inner_radius * np.sin(angles)
x_coords = np.concatenate([x_outer, x_inner[::-1], [x_outer[0]]])
y_coords = np.concatenate([y_outer, y_inner[::-1], [y_outer[0]]])
color = colors[i]
color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
if i < len(degrees):
degree = degrees[i]
calc_degree = calc_degrees[i]
if obj.is_instanced:
note_name = degree.pitchclass
cents_offset = degree.cents_offset
cent_info = f" ({cents_offset:+.2f}¢)" if abs(cents_offset) > 0.01 else ""
hover_text = f"Degree {i}<br>{calc_degree}<br>{note_name}{cent_info}"
label_text = f"{calc_degree}" if show_labels else ""
else:
hover_text = f"Degree {i}<br>{calc_degree}"
label_text = f"{calc_degree}" if show_labels else ""
else:
hover_text = f"To Equave<br>{calc_obj._equave}"
label_text = f"{calc_obj._equave}" if show_labels else ""
fig.add_trace(go.Scatter(
x=x_coords, y=y_coords,
fill='toself',
fillcolor=color_hex,
line=dict(color='white', width=1),
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
showlegend=False,
hovertemplate=f'{hover_text}<extra></extra>',
hoverlabel=dict(bgcolor='lightgrey', font_color='black')
))
if show_labels and label_text:
mid_angle = current_angle - angle_span / 2
label_radius = 0.85
label_x = label_radius * math.cos(mid_angle)
label_y = label_radius * math.sin(mid_angle)
fig.add_trace(go.Scatter(
x=[label_x], y=[label_y],
mode='text',
text=[label_text],
textfont=dict(color='white', size=text_size, family='Arial', weight='bold'),
showlegend=False,
hoverinfo='skip'
))
current_angle -= angle_span
fig.update_layout(
xaxis=dict(
range=[-1.2, 1.2],
showgrid=False,
zeroline=False,
showticklabels=False
),
yaxis=dict(
range=[-1.2, 1.2],
scaleanchor="x",
scaleratio=1,
showgrid=False,
zeroline=False,
showticklabels=False
)
)
else:
current_pos = 0
# Color based on distance from unison (1/1)
colors = []
for i in range(n_segments):
if i < len(calc_degrees):
# Distance from unison in log space
distance = abs(math.log(float(calc_degrees[i])))
hue = min(distance / 2.0, 1.0) # Normalize and cap at 1.0
else:
# Final interval - use distance based on equave
distance = abs(math.log(float(calc_obj._equave)))
hue = min(distance / 2.0, 1.0)
colors.append(plt.cm.hsv(hue))
y_center = 0
bar_height = 0.3
for i, interval_size in enumerate(intervals):
if interval_size <= 0:
continue
proportion = interval_size / total_log_size
segment_width = proportion * 2.0
x_coords = [current_pos, current_pos + segment_width, current_pos + segment_width, current_pos, current_pos]
y_coords = [y_center - bar_height/2, y_center - bar_height/2, y_center + bar_height/2, y_center + bar_height/2, y_center - bar_height/2]
color = colors[i]
color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
if i < len(degrees):
degree = degrees[i]
calc_degree = calc_degrees[i]
if obj.is_instanced:
note_name = degree.pitchclass
cents_offset = degree.cents_offset
cent_info = f" ({cents_offset:+.2f}¢)" if abs(cents_offset) > 0.01 else ""
hover_text = f"Degree {i}<br>{calc_degree}<br>{note_name}{cent_info}"
label_text = f"{calc_degree}" if show_labels else ""
else:
hover_text = f"Degree {i}<br>{calc_degree}"
label_text = f"{calc_degree}" if show_labels else ""
else:
hover_text = f"To Equave<br>{calc_obj._equave}"
label_text = f"{calc_obj._equave}" if show_labels else ""
fig.add_trace(go.Scatter(
x=x_coords, y=y_coords,
fill='toself',
fillcolor=color_hex,
line=dict(color='white', width=1),
mode='lines+markers',
marker=dict(size=0.1, opacity=0),
showlegend=False,
hovertemplate=f'{hover_text}<extra></extra>',
hoverlabel=dict(bgcolor='lightgrey', font_color='black')
))
if show_labels and label_text:
label_x = current_pos + segment_width / 2
label_y = y_center
fig.add_trace(go.Scatter(
x=[label_x], y=[label_y],
mode='text',
text=[label_text],
textfont=dict(color='white', size=text_size, family='Arial', weight='bold'),
showlegend=False,
hoverinfo='skip'
))
current_pos += segment_width
fig.update_layout(
xaxis=dict(
range=[-0.1, 2.1],
showgrid=False,
zeroline=False,
showticklabels=False
),
yaxis=dict(
range=[-0.5, 0.5],
showgrid=False,
zeroline=False,
showticklabels=False
)
)
if title is None:
title = repr(obj)
width_px, height_px = int(figsize[0] * 100), int(figsize[1] * 100)
fig.update_layout(
title=dict(text=title, font=dict(color='white')),
width=width_px,
height=height_px,
paper_bgcolor='black',
plot_bgcolor='black',
hovermode='closest',
margin=dict(l=0, r=0, t=50, b=0),
)
if output_file:
if output_file.endswith('.html'):
fig.write_html(output_file)
else:
fig.write_image(output_file)
return fig
def _plot_ratio_scale_chord_fixed(obj, calc_obj, degrees, calc_degrees, fig, figsize,
                                  text_size, show_labels, title, output_file, layout):
    """Render ratio-based scales/chords as proportional segments (fixed intervals).

    Every entry of ``calc_obj.intervals`` becomes one filled segment whose
    size is proportional to the natural log of that interval ratio.  With
    ``layout == 'circle'`` the segments form a ring that starts at the top
    and proceeds by decreasing angle; any other ``layout`` value draws a
    horizontal bar spanning x = 0 .. 2.  Segments are coloured with the
    ``hsv`` colormap according to the log-distance from unison of the degree
    they start on (the closing segment uses ``calc_obj._equave``).

    Parameters
    ----------
    obj : object
        Object being plotted; only used for the default title
        (``repr(obj)``).
    calc_obj : object
        Supplies ``intervals`` (the interval ratios) and ``_equave``.
    degrees : sequence
        Only its length is inspected, to validate the input.
    calc_degrees : sequence
        Degree values used for the border labels and the segment colours.
    fig : object
        Figure receiving the traces and layout; the code calls
        ``add_trace``, ``update_layout``, ``write_html`` and ``write_image``
        on it (presumably a ``plotly.graph_objects.Figure``).
    figsize : tuple
        ``(width, height)``; each value is multiplied by 100 to get pixels.
    text_size : int
        Base font size; every label is drawn at ``text_size + 2``.
    show_labels : bool
        Whether to draw degree labels at segment borders and interval
        labels inside the segments.
    title : str or None
        Figure title; ``repr(obj)`` is used when ``None``.
    output_file : str or None
        When truthy, the figure is written with ``write_html`` if the name
        ends in ``'.html'`` and with ``write_image`` otherwise.
    layout : str
        ``'circle'`` for the ring; anything else selects the bar.

    Returns
    -------
    object
        The same ``fig`` that was passed in.

    Raises
    ------
    ValueError
        If ``degrees`` has fewer than two entries.
    """
    if len(degrees) < 2:
        raise ValueError("Need at least 2 degrees to plot intervals")

    # Use the intervals property (Scale now includes final interval)
    interval_ratios = calc_obj.intervals
    # Log sizes make equal ratios occupy equal lengths/angles.
    log_sizes = [math.log(float(ratio)) for ratio in interval_ratios]
    total_log_size = sum(log_sizes)

    def segment_color(index):
        # Colour by log-distance from unison (1/1), capped at hue 1.0.
        # Segments beyond the last degree fall back to the equave.
        if index < len(calc_degrees):
            reference = calc_degrees[index]
        else:
            reference = calc_obj._equave
        hue = min(abs(math.log(float(reference))) / 2.0, 1.0)
        return plt.cm.hsv(hue)

    # Computed once; both layouts use the same colouring rule.
    colors = [segment_color(index) for index in range(len(interval_ratios))]

    def add_border_label(x, y, text, node_index):
        # White label marking a degree at the border between two segments.
        fig.add_trace(go.Scatter(
            x=[x], y=[y],
            mode='text',
            text=[text],
            textfont=dict(color='white', size=text_size+2, family='Arial'),
            showlegend=False,
            hovertemplate=f'Node {node_index}<br>Degree: {text}<extra></extra>',
            hoverlabel=dict(bgcolor='lightgrey', font_color='black')
        ))

    def add_segment(x_coords, y_coords, rgba, interval_ratio):
        # Filled outline for one interval; near-invisible markers are kept
        # in the trace alongside the outline.
        color_hex = '#%02x%02x%02x' % (int(rgba[0]*255), int(rgba[1]*255), int(rgba[2]*255))
        fig.add_trace(go.Scatter(
            x=x_coords, y=y_coords,
            fill='toself',
            fillcolor=color_hex,
            line=dict(color='white', width=1),
            mode='lines+markers',
            marker=dict(size=0.1, opacity=0),
            showlegend=False,
            hovertemplate=f'Interval: {interval_ratio}<extra></extra>',
            hoverlabel=dict(bgcolor='lightgrey', font_color='black')
        ))

    def add_interval_label(x, y, interval_ratio):
        # Bold black interval ratio drawn inside its segment.
        fig.add_trace(go.Scatter(
            x=[x], y=[y],
            mode='text',
            text=[f"{interval_ratio}"],
            textfont=dict(color='black', size=text_size+2, family='Arial', weight='bold'),
            showlegend=False,
            hoverinfo='skip'
        ))

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)

    if layout == 'circle':
        start_angle = math.pi / 2
        degree_radius = 1.1

        # Add degree labels at borders first
        if show_labels:
            degree_angle = start_angle
            for i, calc_degree in enumerate(calc_degrees):
                if i:
                    # Running total: step past the previous interval only.
                    degree_angle -= 2 * math.pi * (log_sizes[i - 1] / total_log_size)
                add_border_label(degree_radius * math.cos(degree_angle),
                                 degree_radius * math.sin(degree_angle),
                                 f"{calc_degree}", i)
            # On the ring the closing border coincides with the starting
            # point, so it is labelled as the unison.
            add_border_label(degree_radius * math.cos(start_angle),
                             degree_radius * math.sin(start_angle),
                             "1/1", 0)

        # Draw interval segments
        inner_radius = 0.85
        outer_radius = 1.0
        label_radius = 0.925
        current_angle = start_angle
        for color, interval_size, interval_ratio in zip(colors, log_sizes, interval_ratios):
            if interval_size <= 0:
                continue
            angle_span = 2 * math.pi * (interval_size / total_log_size)
            num_points = max(50, int(angle_span * 50))
            angles = np.linspace(current_angle, current_angle - angle_span, num_points)
            x_outer = outer_radius * np.cos(angles)
            y_outer = outer_radius * np.sin(angles)
            x_inner = inner_radius * np.cos(angles)
            y_inner = inner_radius * np.sin(angles)
            # Outer arc, then inner arc reversed, then back to the start
            # to close the annular wedge.
            x_coords = np.concatenate([x_outer, x_inner[::-1], [x_outer[0]]])
            y_coords = np.concatenate([y_outer, y_inner[::-1], [y_outer[0]]])
            add_segment(x_coords, y_coords, color, interval_ratio)
            if show_labels:
                mid_angle = current_angle - angle_span / 2
                add_interval_label(label_radius * math.cos(mid_angle),
                                   label_radius * math.sin(mid_angle),
                                   interval_ratio)
            current_angle -= angle_span

        fig.update_layout(
            xaxis=dict(range=[-1.2, 1.2], **hidden_axis),
            yaxis=dict(range=[-1.2, 1.2], scaleanchor="x", scaleratio=1, **hidden_axis)
        )
    else:  # line layout
        y_center = 0
        bar_height = 0.3
        label_y = y_center + bar_height/2 + 0.1

        # Add degree labels at borders
        if show_labels:
            degree_pos = 0
            for i, calc_degree in enumerate(calc_degrees):
                if i:
                    degree_pos += (log_sizes[i - 1] / total_log_size) * 2.0
                add_border_label(degree_pos, label_y, f"{calc_degree}", i)
            # The right end of the bar is the equave, not the unison; label
            # it the same way _plot_ratio_scale_chord_clean does.
            add_border_label(2.0, label_y, f"{calc_obj._equave}", len(calc_degrees))

        # Draw interval segments
        bottom = y_center - bar_height/2
        top = y_center + bar_height/2
        current_pos = 0
        for color, interval_size, interval_ratio in zip(colors, log_sizes, interval_ratios):
            if interval_size <= 0:
                continue
            segment_width = (interval_size / total_log_size) * 2.0
            right = current_pos + segment_width
            add_segment([current_pos, right, right, current_pos, current_pos],
                        [bottom, bottom, top, top, bottom],
                        color, interval_ratio)
            if show_labels:
                add_interval_label(current_pos + segment_width / 2, y_center, interval_ratio)
            current_pos += segment_width

        fig.update_layout(
            xaxis=dict(range=[-0.1, 2.1], **hidden_axis),
            yaxis=dict(range=[-0.5, 0.5], **hidden_axis)
        )

    if title is None:
        title = repr(obj)
    width_px, height_px = int(figsize[0] * 100), int(figsize[1] * 100)
    fig.update_layout(
        title=dict(text=title, font=dict(color='white')),
        width=width_px,
        height=height_px,
        paper_bgcolor='black',
        plot_bgcolor='black',
        hovermode='closest',
        margin=dict(l=0, r=0, t=50, b=0),
    )
    if output_file:
        if output_file.endswith('.html'):
            fig.write_html(output_file)
        else:
            fig.write_image(output_file)
    return fig
def _plot_ratio_scale_chord_clean(obj, calc_obj, degrees, calc_degrees, fig, figsize,
                                  text_size, show_labels, title, output_file, layout):
    """Render ratio-based scales/chords as proportional segments (clean style).

    One filled segment is drawn per entry of ``calc_obj.intervals``, sized
    by the natural log of the interval ratio and coloured from matplotlib's
    ``Set1`` colormap sampled evenly over the number of segments.
    ``layout == 'circle'`` produces a ring starting at the top and running
    by decreasing angle; any other value produces a horizontal bar over
    x = 0 .. 2.

    Parameters
    ----------
    obj : object
        Object being plotted; only used for the default title
        (``repr(obj)``).
    calc_obj : object
        Supplies ``intervals`` and, for the bar's end label, ``_equave``.
    degrees : sequence
        Only its length is inspected, to validate the input.
    calc_degrees : sequence
        Degree values shown as labels at segment borders.
    fig : object
        Figure receiving the traces and layout; the code calls
        ``add_trace``, ``update_layout``, ``write_html`` and ``write_image``
        on it (presumably a ``plotly.graph_objects.Figure``).
    figsize : tuple
        ``(width, height)``, multiplied by 100 to get pixels.  Ignored for
        the bar layout, which always uses ``(20, 1.5)``.
    text_size : int
        Base font size; every label is drawn at ``text_size + 2``.
    show_labels : bool
        Whether to draw border labels and in-segment interval labels.
    title : str or None
        Figure title; ``repr(obj)`` is used when ``None``.
    output_file : str or None
        When truthy, written with ``write_html`` if the name ends in
        ``'.html'`` and with ``write_image`` otherwise.
    layout : str
        ``'circle'`` for the ring; anything else selects the bar.

    Returns
    -------
    object
        The same ``fig`` that was passed in.

    Raises
    ------
    ValueError
        If ``degrees`` has fewer than two entries.
    """
    if len(degrees) < 2:
        raise ValueError("Need at least 2 degrees to plot intervals")

    # calc_obj.intervals is expected to include the closing interval.
    ratios = calc_obj.intervals
    log_steps = [math.log(float(ratio)) for ratio in ratios]
    segment_count = len(ratios)
    log_total = sum(log_steps)
    # One distinct palette row per segment.
    palette = plt.cm.Set1(np.linspace(0, 1, segment_count))

    def border_label(x, y, caption, node):
        # White degree label sitting on a segment border.
        white_font = dict(color='white', size=text_size+2, family='Arial')
        fig.add_trace(go.Scatter(
            x=[x], y=[y],
            mode='text',
            text=[caption],
            textfont=white_font,
            showlegend=False,
            hovertemplate=f'Node {node}<br>Degree: {caption}<extra></extra>',
            hoverlabel=dict(bgcolor='lightgrey', font_color='black')
        ))

    def filled_segment(xs, ys, rgba, ratio):
        # Closed, filled outline representing one interval.
        fill_hex = '#' + ''.join(f'{int(channel * 255):02x}' for channel in rgba[:3])
        hover_text = f"Interval: {ratio}"
        fig.add_trace(go.Scatter(
            x=xs, y=ys,
            fill='toself',
            fillcolor=fill_hex,
            line=dict(color='white', width=1),
            mode='lines+markers',
            marker=dict(size=0.1, opacity=0),
            showlegend=False,
            hovertemplate=f'{hover_text}<extra></extra>',
            hoverlabel=dict(bgcolor='lightgrey', font_color='black')
        ))

    def interval_label(x, y, ratio):
        # Bold black ratio text centred inside a segment.
        black_font = dict(color='black', size=text_size+2, family='Arial', weight='bold')
        fig.add_trace(go.Scatter(
            x=[x], y=[y],
            mode='text',
            text=[f"{ratio}"],
            textfont=black_font,
            showlegend=False,
            hoverinfo='skip'
        ))

    def bare_axis(low, high, **extra):
        # Axis with a fixed range and no grid, zero line or tick labels.
        return dict(range=[low, high], showgrid=False, zeroline=False,
                    showticklabels=False, **extra)

    if layout != 'circle':
        # Line layout always uses its own wide, short canvas.
        figsize = (20, 1.5)
        y_center = 0
        bar_height = 0.2
        caption_y = y_center + bar_height/2 + 0.05

        if show_labels:
            # First degree (1/1) sits at the left edge.
            border_label(0, caption_y, "1/1", 0)
            # Remaining degrees: advance a running position one interval
            # at a time.
            offset = 0
            for node, degree_value in enumerate(calc_degrees[1:], 1):
                offset += (log_steps[node - 1] / log_total) * 2.0
                border_label(offset, caption_y, f"{degree_value}", node)
            # Equave label at the right edge.
            border_label(2.0, caption_y, f"{calc_obj._equave}", len(calc_degrees))

        low_y = y_center - bar_height/2
        high_y = y_center + bar_height/2
        left = 0
        for rgba, step, ratio in zip(palette, log_steps, ratios):
            if step <= 0:
                continue
            width = (step / log_total) * 2.0
            right = left + width
            filled_segment([left, right, right, left, left],
                           [low_y, low_y, high_y, high_y, low_y],
                           rgba, ratio)
            if show_labels:
                interval_label(left + width / 2, y_center, ratio)
            left += width

        fig.update_layout(
            xaxis=bare_axis(-0.05, 2.05),
            yaxis=bare_axis(-0.2, 0.2)
        )
    else:
        top_of_ring = math.pi / 2

        if show_labels:
            theta = top_of_ring
            for node, degree_value in enumerate(calc_degrees):
                if node > 0:
                    theta -= 2 * math.pi * (log_steps[node - 1] / log_total)
                border_label(1.1 * math.cos(theta), 1.1 * math.sin(theta),
                             f"{degree_value}", node)
            # Closing label, placed back at the top of the ring.
            border_label(1.1 * math.cos(math.pi / 2), 1.1 * math.sin(math.pi / 2),
                         "1/1", 0)

        theta = top_of_ring
        for rgba, step, ratio in zip(palette, log_steps, ratios):
            if step <= 0:
                continue
            sweep = 2 * math.pi * (step / log_total)
            arc = np.linspace(theta, theta - sweep, max(50, int(sweep * 50)))
            outer_x, outer_y = 1.0 * np.cos(arc), 1.0 * np.sin(arc)
            inner_x, inner_y = 0.85 * np.cos(arc), 0.85 * np.sin(arc)
            # Trace the outer arc, return along the inner arc, then close.
            ring_x = np.concatenate([outer_x, inner_x[::-1], [outer_x[0]]])
            ring_y = np.concatenate([outer_y, inner_y[::-1], [outer_y[0]]])
            filled_segment(ring_x, ring_y, rgba, ratio)
            if show_labels:
                middle = theta - sweep / 2
                interval_label(0.925 * math.cos(middle), 0.925 * math.sin(middle), ratio)
            theta -= sweep

        fig.update_layout(
            xaxis=bare_axis(-1.2, 1.2),
            yaxis=bare_axis(-1.2, 1.2, scaleanchor="x", scaleratio=1)
        )

    heading = repr(obj) if title is None else title
    pixel_width = int(figsize[0] * 100)
    pixel_height = int(figsize[1] * 100)
    fig.update_layout(
        title=dict(text=heading, font=dict(color='white')),
        width=pixel_width,
        height=pixel_height,
        paper_bgcolor='black',
        plot_bgcolor='black',
        hovermode='closest',
        margin=dict(l=0, r=0, t=50, b=0),
    )
    if output_file:
        writer = fig.write_html if output_file.endswith('.html') else fig.write_image
        writer(output_file)
    return fig