from ..graphs import Graph
import rustworkx as rx
from functools import cached_property, lru_cache
from .group import Group
import copy
[docs]
class Tree(Graph):
    """Rooted tree backed by a directed RustworkX graph.

    Every node has at most one parent. The tree is built from a nested tuple
    structure and stored in the directed graph inherited from ``Graph``.

    Subclasses can override ``_node_value_attr`` to store node values under a
    different attribute name (e.g. RhythmTree uses ``'proportion'`` instead of
    ``'label'``).

    Parameters
    ----------
    root : hashable
        Value stored on the root node of the tree.
    children : tuple
        Nested tuple structure describing the children. Each element is either
        a leaf value or a ``(value, children)`` pair.
    """

    # Node-data key under which each node's value is stored.
    _node_value_attr = 'label'
[docs]
def __init__(self, root, children:tuple):
    """Build the tree from a root value and a nested ``children`` tuple.

    ``_building_tree`` is raised while the graph is populated because
    ``add_node``/``add_edge`` refuse to work otherwise.
    """
    super().__init__(directed=True)
    self._building_tree = True
    root_id = self._build_tree(root, children)
    self._root = root_id
    self._building_tree = False
    # Keep a Group mirror of the original nested structure.
    self._list = Group((root, children))
@property
def root(self):
    """Node id of the tree's root."""
    root_id = self._root
    return root_id
@property
def group(self):
    """The Group object mirroring the tree's nested ``(value, children)`` structure."""
    structure = self._list
    return structure
[docs]
@classmethod
def from_tree_structure(cls, source_tree):
    """Return a new instance with ``source_tree``'s topology but no node data.

    The graph is copied, then every node payload in the copy is replaced by a
    fresh empty dict. The Group structure is deep-copied and metadata reset.
    Subclasses initialize their own state in ``_post_structure_clone``.
    """
    clone = cls.__new__(cls)
    clone._graph = source_tree._graph.copy()
    for index in clone._graph.node_indices():
        clone._graph[index] = {}
    clone._root = source_tree._root
    clone._list = copy.deepcopy(source_tree._list)
    clone._meta = {}
    clone._structure_version = 0
    clone._building_tree = False
    clone._post_structure_clone()
    return clone
def _post_structure_clone(self):
"""Subclass hook: initialize domain-specific state after topology-only cloning."""
pass
def _normalize_mutation_scope(self, scope_node=None, op=None):
if scope_node is None:
return None
return scope_node if scope_node in self else None
def _resolve_data_update_scope(self, node, changed_keys, op=None):
return self._normalize_mutation_scope(scope_node=node, op=op)
def _before_post_mutation(self, scope_node=None, op=None):
pass
def _after_post_mutation(self, scope_node=None, op=None):
pass
def _after_node_data_mutation(self, node, attrs, scope_node=None, op=None):
normalized_scope = self._normalize_mutation_scope(scope_node=scope_node, op=op)
self._post_mutation(scope_node=normalized_scope, op=op or 'set_node_data')
def _post_mutation(self, scope_node=None, op=None):
scope_node = self._normalize_mutation_scope(scope_node=scope_node, op=op)
self._before_post_mutation(scope_node=scope_node, op=op)
self._invalidate_caches()
self._update_group_structure()
self._after_post_mutation(scope_node=scope_node, op=op)
def _invalidate_caches(self):
    """Invalidate all tree caches.

    Drops the ``cached_property`` values (``depth``, ``k``, ``leaf_nodes``).
    Clears the ``lru_cache`` of every memoized structural query
    (``parent``, ``ancestors``, ``descendants``, ``branch``). All four must be
    cleared: any of them left warm returns pre-mutation answers (e.g. a stale
    ``descendants`` makes ``remove_subtree`` miss newly added nodes).
    """
    super()._invalidate_caches()
    for attr in ('depth', 'k', 'leaf_nodes'):
        # cached_property stores its value in the instance __dict__.
        if attr in self.__dict__:
            delattr(self, attr)
    for name in ('parent', 'ancestors', 'descendants', 'branch'):
        # A subclass may override one of these without lru_cache, so only
        # clear when a cache is present. The cache lives on the function,
        # so this clears the entries of every instance.
        cache_clear = getattr(getattr(self, name, None), 'cache_clear', None)
        if cache_clear is not None:
            cache_clear()
@cached_property
def depth(self):
"""Maximum depth of the tree.
Returns
-------
int
The maximum depth of any node in the tree
"""
if not hasattr(self, '_root') or self._root is None:
return 0
root_idx = self._get_node_index(self._root)
if root_idx is None:
return 0
def edge_cost_fn(edge_data):
return 1.0
distances = rx.digraph_dijkstra_shortest_path_lengths(
self._graph, root_idx, edge_cost_fn
)
return int(max(distances.values())) if distances else 0
@cached_property
def k(self):
"""Maximum branching factor of the tree"""
return max((self.out_degree(n) for n in self.nodes), default=0)
@cached_property
def leaf_nodes(self):
"""Return leaf nodes (nodes with no successors) in tree traversal order.
Returns
-------
tuple
All leaf nodes in the tree, in left-right traversal order
"""
leaf_nodes_list = []
def collect_leaves(node):
if self.out_degree(node) == 0:
leaf_nodes_list.append(node)
else:
for child in self.successors(node):
collect_leaves(child)
collect_leaves(self.root)
return tuple(leaf_nodes_list)
[docs]
def subtree_leaves(self, node):
    """Return leaf nodes of the subtree rooted at the given node, in left-right order.

    Uses an explicit stack instead of recursion so deep subtrees do not hit
    Python's recursion limit.

    Parameters
    ----------
    node : int
        The root of the subtree whose leaves to return
    Returns
    -------
    tuple
        Leaf nodes of the subtree in left-right traversal order
    Raises
    ------
    ValueError
        If the node is not found in the tree
    """
    if node not in self:
        raise ValueError(f"Node {node} not found in tree")
    leaves = []
    stack = [node]
    while stack:
        current = stack.pop()
        if self.out_degree(current) == 0:
            leaves.append(current)
        else:
            # Reversed push keeps the left-to-right visiting order.
            stack.extend(reversed(list(self.successors(current))))
    return tuple(leaves)
[docs]
def depth_of(self, node):
    """Return the depth of a node: the number of edges from the root to it.

    Parameters
    ----------
    node : int
        The node to measure.
    Returns
    -------
    int
        0 for the root, otherwise the length of the path from the root.
    Raises
    ------
    ValueError
        If the node is not in the tree, or climbing its first-predecessor
        chain never reaches the root.
    """
    if node not in self:
        raise ValueError(f"Node {node} not found in tree")
    root_idx = self._get_node_index(self._root)
    cursor = self._get_node_index(node)
    steps = 0
    # Climb towards the root via the first predecessor of each node.
    while cursor != root_idx:
        predecessors = list(self._graph.predecessor_indices(cursor))
        if not predecessors:
            raise ValueError(f"Node {node} is not reachable from root")
        cursor = predecessors[0]
        steps += 1
    return steps
[docs]
@lru_cache(maxsize=None)
def parent(self, node):
    """Return the parent of ``node`` (its first predecessor), or None for the root.

    Parameters
    ----------
    node : hashable
        The node whose parent to look up.
    Returns
    -------
    int or None
        The parent node, or None if ``node`` has no predecessor.

    NOTE(review): results are memoized in an ``lru_cache`` keyed on
    ``(self, node)``. That cache is shared by all instances and keeps them
    alive, so it must be cleared after structural mutations.
    """
    for candidate in self.predecessors(node):
        return candidate
    return None
[docs]
@lru_cache(maxsize=None)
def ancestors(self, node):
    """Return all ancestors of ``node``, ordered from the root down to its parent.

    Results are memoized with ``lru_cache``.

    Parameters
    ----------
    node : int
        The node whose ancestors to return.
    Returns
    -------
    tuple
        Ancestors from root to parent; empty for the root itself.
    Raises
    ------
    ValueError
        If the node is not in the tree or not reachable from the root.
    """
    if node not in self:
        raise ValueError(f"Node {node} not found in tree")
    if node == self._root:
        return tuple()
    root_idx = self._get_node_index(self._root)
    climbed = []
    cursor = self._get_node_index(node)
    while cursor != root_idx:
        preds = list(self._graph.predecessor_indices(cursor))
        if not preds:
            raise ValueError(f"Node {node} is not reachable from root")
        cursor = preds[0]
        climbed.append(cursor)
    # climbed runs parent -> root; reverse it to root -> parent.
    return tuple(self._get_node_object(i) for i in reversed(climbed))
[docs]
@lru_cache(maxsize=None)
def descendants(self, node):
    """Return all descendants of ``node`` in depth-first order.

    Results are memoized with ``lru_cache``.

    Parameters
    ----------
    node : int
        The node whose descendants to return.
    Returns
    -------
    tuple
        Descendants (excluding ``node``) in the order ``rx.dfs_edges``
        discovers them.
    Raises
    ------
    ValueError
        If the node is not found in the tree.
    """
    if node not in self:
        raise ValueError(f"Node {node} not found in tree")
    start = self._get_node_index(node)
    seen = {start}
    found = []
    for source, target in rx.dfs_edges(self._graph, start):
        # Record each target once, and only when reached from a discovered node.
        from_known = source == start or source in seen
        if from_known and target not in seen:
            found.append(target)
            seen.add(target)
    return tuple(self._get_node_object(i) for i in found)
[docs]
@lru_cache(maxsize=None)
def branch(self, node):
    """Return every node on the path from the root to ``node``, inclusive.

    Results are memoized with ``lru_cache``.

    Parameters
    ----------
    node : int
        The target node.
    Returns
    -------
    tuple
        Nodes from the root to ``node``. If the climb from ``node`` dead-ends
        before the root, an empty tuple is returned.
    Raises
    ------
    ValueError
        If the node is not found in the tree.
    """
    if node not in self:
        raise ValueError(f"Node {node} not found in tree")
    if node == self._root:
        return (self._root,)
    root_idx = self._get_node_index(self._root)
    path = []
    cursor = self._get_node_index(node)
    while cursor != root_idx:
        path.append(cursor)
        preds = list(self._graph.predecessor_indices(cursor))
        if not preds:
            return tuple()
        cursor = preds[0]
    path.append(root_idx)
    return tuple(self._get_node_object(i) for i in reversed(path))
[docs]
def path_signature(self, root_node, target_node):
    """Return the child-index path from ``root_node`` down to ``target_node``.

    Each entry is the position of the next node among its parent's
    ``successors``. Replaying the tuple from ``root_node`` therefore reaches
    ``target_node``.

    Raises
    ------
    ValueError
        If either node is missing or ``target_node`` is not below ``root_node``.
    """
    if root_node not in self:
        raise ValueError(f"Root node {root_node} not found in tree")
    if target_node not in self:
        raise ValueError(f"Target node {target_node} not found in tree")
    path = list(self.branch(target_node))
    if root_node not in path:
        raise ValueError(
            f"Node {target_node} is not in subtree rooted at {root_node}"
        )
    start = path.index(root_node)
    return tuple(
        list(self.successors(parent)).index(child)
        for parent, child in zip(path[start:], path[start + 1:])
    )
[docs]
def node_from_signature(self, root_node, signature):
    """Follow the child indices in ``signature`` from ``root_node`` and return the node reached.

    Raises
    ------
    ValueError
        If ``root_node`` is missing or an index is out of range for its node.
    """
    if root_node not in self:
        raise ValueError(f"Root node {root_node} not found in tree")
    node = root_node
    for position in signature:
        kids = list(self.successors(node))
        if not 0 <= position < len(kids):
            raise ValueError(
                f"Invalid child index {position} for node {node} with {len(kids)} children"
            )
        node = kids[position]
    return node
[docs]
def map_parallel_nodes(self, other_tree, self_root=None, other_root=None):
    """Pair up nodes of two topologically identical subtrees.

    Walks both subtrees in lockstep, matching children by position.

    Parameters
    ----------
    other_tree : Tree
        The tree to map onto.
    self_root, other_root : optional
        Subtree roots; default to each tree's root.
    Returns
    -------
    dict
        Mapping from nodes of ``self`` to the corresponding nodes of ``other_tree``.
    Raises
    ------
    TypeError
        If ``other_tree`` is not a Tree.
    ValueError
        If a root is missing or the two subtrees differ in shape.
    """
    if not isinstance(other_tree, Tree):
        raise TypeError("other_tree must be a Tree instance")
    if self_root is None:
        self_root = self.root
    if other_root is None:
        other_root = other_tree.root
    if self_root not in self:
        raise ValueError(f"Node {self_root} not found in source tree")
    if other_root not in other_tree:
        raise ValueError(f"Node {other_root} not found in target tree")
    mapping = {}
    pending = [(self_root, other_root)]
    while pending:
        mine, theirs = pending.pop()
        mapping[mine] = theirs
        my_kids = list(self.successors(mine))
        their_kids = list(other_tree.successors(theirs))
        if len(my_kids) != len(their_kids):
            raise ValueError(
                "Topology mismatch while mapping parallel subtrees"
            )
        # Push pairs in reverse so they are popped in left-to-right order.
        pending.extend(reversed(list(zip(my_kids, their_kids))))
    return mapping
[docs]
def siblings(self, node):
    """Return the siblings of ``node``: the other children of its parent.

    Returns an empty tuple when ``node`` has no parent. The parent is
    compared with ``None`` explicitly because node ids are integers, and a
    parent with id ``0`` is falsy. A plain truthiness test would wrongly
    report its children as having no siblings.
    """
    parent = self.parent(node)
    if parent is None:
        return tuple()
    return tuple(n for n in self.successors(parent) if n != node)
[docs]
def lowest_common_ancestor(self, node_a, node_b):
    """Return the deepest node lying on both nodes' root-to-node branches.

    A node counts as its own ancestor, so ``lca(x, x) == x``.

    Raises
    ------
    ValueError
        If either node is not in the tree.
    """
    if node_a not in self or node_b not in self:
        raise ValueError("Both nodes must exist in the tree")
    path_a = self.branch(node_a)
    path_b = self.branch(node_b)
    common = self.root
    # Branches start at the root; the last position where they agree is the LCA.
    for left, right in zip(path_a, path_b):
        if left != right:
            break
        common = left
    return common
[docs]
def subtree(self, node, renumber=True):
    """Extract a tree subtree rooted at the given node.

    Parameters
    ----------
    node : int
        The node to use as the root of the subtree
    renumber : bool, optional
        Whether to renumber the nodes in the new tree (default: True)
    Returns
    -------
    Tree
        A new instance of this tree's class containing ``node`` and all its
        descendants. Node payloads are shallow-copied and ``_meta`` is
        shallow-copied.
    Raises
    ------
    ValueError
        If the node is not found in the tree
    """
    if node not in self:
        raise ValueError(f"Node {node} not found in tree")
    descendants = [node] + list(self.descendants(node))
    # Set membership keeps edge copying and child lookup linear in subtree size.
    in_subtree = set(descendants)
    new_tree = self.__class__.__new__(self.__class__)
    new_tree._graph = rx.PyDiGraph()
    new_tree._meta = self._meta.copy()
    new_tree._structure_version = 0
    node_mapping = {}
    for old_node in descendants:
        node_mapping[old_node] = new_tree._graph.add_node(self[old_node].copy())
    for old_node in descendants:
        for successor in self.successors(old_node):
            if successor in in_subtree:
                new_tree._graph.add_edge(
                    node_mapping[old_node],
                    node_mapping[successor],
                    None
                )
    new_tree._root = node_mapping[node]
    if hasattr(self, 'group'):
        attr = getattr(self, '_node_value_attr', 'label')

        def get_node_value(n):
            return self[n].get(attr, n)

        def get_children(n):
            # Same filter as the edge copy above: only nodes inside the subtree.
            return [c for c in self.successors(n) if c in in_subtree]

        structure = self._build_nested_structure(node, get_node_value, get_children)
        if isinstance(structure, tuple) and len(structure) > 1:
            new_tree._list = Group(structure)
        else:
            new_tree._list = Group((structure, tuple()))
    if hasattr(self, '_after_subtree_built'):
        self._after_subtree_built(new_tree, node_mapping, renumber)
    if renumber:
        new_tree.renumber_nodes()
    return new_tree
[docs]
def at_depth(self, n, operator='=='):
    """Return nodes at a specific depth.

    Parameters
    ----------
    n : int
        The depth level to query (the root is depth 0).
    operator : str, optional
        Comparison operator ('==', '>=', '<=', '<', '>'), default is '=='
    Returns
    -------
    list
        Nodes whose depth ``d`` satisfies ``d <operator> n``, in
        breadth-first order. A depth that does not exist (including a
        negative ``n`` with '==') yields an empty list.
    Raises
    ------
    ValueError
        If operator is not supported
    """
    if operator not in ['==', '>=', '<=', '<', '>']:
        raise ValueError(f"Unsupported operator: {operator}")
    comparisons = {
        '==': lambda d: d == n,
        '>=': lambda d: d >= n,
        '<=': lambda d: d <= n,
        '<': lambda d: d < n,
        '>': lambda d: d > n,
    }
    matches = comparisons[operator]
    result = []
    level = [self.root]
    depth = 0
    # Breadth-first sweep one level at a time until no nodes remain.
    while level:
        if matches(depth):
            result.extend(level)
        level = [child for node in level for child in self.successors(node)]
        depth += 1
    return result
[docs]
def add_node(self, **attr):
    """Add a raw node; only allowed while the tree is being built internally.

    Raises
    ------
    NotImplementedError
        Outside of tree construction; use ``add_child()`` instead.
    """
    if not getattr(self, '_building_tree', False):
        raise NotImplementedError("Use add_child() to add nodes to a tree")
    return Graph.add_node(self, **attr)
[docs]
def add_edge(self, u, v, **attr):
    """Add a raw edge; only allowed while the tree is being built internally.

    Raises
    ------
    NotImplementedError
        Outside of tree construction; use ``add_child()`` instead.
    """
    if not getattr(self, '_building_tree', False):
        raise NotImplementedError("Use add_child() to add edges to a tree")
    return Graph.add_edge(self, u, v, **attr)
[docs]
def remove_node(self, node):
    """Unsupported on trees: always raises NotImplementedError.

    Use ``prune()`` (promote children) or ``remove_subtree()`` (drop descendants).
    """
    message = "Use prune() or remove_subtree() to remove nodes from a tree"
    raise NotImplementedError(message)
[docs]
def remove_edge(self, u, v):
    """Unsupported on trees: always raises NotImplementedError.

    Use ``prune()`` or ``remove_subtree()`` instead.
    """
    message = "Use prune() or remove_subtree() to remove edges from a tree"
    raise NotImplementedError(message)
[docs]
def add_child(self, parent, index=None, **attr):
    """Add a child node to a parent.

    Parameters
    ----------
    parent : int
        The parent node ID.
    index : int or None, optional
        Requested insertion position. NOTE(review): currently not used; the
        child is connected with a plain edge from ``parent``.
    **attr : dict
        Node attributes (normalized and validated before insertion).
    Returns
    -------
    int
        The new child node ID.
    Raises
    ------
    ValueError
        If ``parent`` is not in the tree. This is checked before anything is
        added, so a failed call cannot leave an orphan node behind.
    """
    if parent not in self:
        raise ValueError(f"Node {parent} not found in tree")
    self._building_tree = True
    try:
        normalized = self._normalize_node_attrs(node=parent, attrs=attr, op='add_child')
        self._validate_node_attrs(node=parent, attrs=normalized, op='add_child')
        child_id = super().add_node(**normalized)
        super().add_edge(parent, child_id)
        self._post_mutation(scope_node=parent, op='add_child')
        return child_id
    finally:
        # Always drop the construction flag, even if validation raises.
        self._building_tree = False
[docs]
def add_subtree(self, parent, subtree, index=None):
    """Copy ``subtree`` into this tree as a new child of ``parent``.

    Parameters
    ----------
    parent : int
        The parent node to attach to.
    subtree : Tree
        Tree instance to copy in (it is not modified).
    index : int or None, optional
        Requested insertion position. NOTE(review): currently not used.
    Returns
    -------
    int
        The node ID of the copied subtree's root in this tree.
    Raises
    ------
    TypeError
        If ``subtree`` is not a Tree.
    ValueError
        If ``parent`` is not in the tree. This is checked before copying, so
        a failed call leaves the tree untouched.
    """
    if not isinstance(subtree, Tree):
        raise TypeError("subtree must be a Tree instance")
    if parent not in self:
        raise ValueError(f"Node {parent} not found in tree")
    # Snapshot nodes/edges first so copying a tree into itself cannot
    # iterate over the nodes it is adding.
    source_nodes = list(subtree.nodes)
    source_edges = list(subtree.edges)
    node_mapping = {}
    for node in source_nodes:
        node_mapping[node] = Graph.add_node(self, **dict(subtree.nodes[node]))
    for u, v in source_edges:
        Graph.add_edge(self, node_mapping[u], node_mapping[v])
    subtree_root = node_mapping[subtree.root]
    Graph.add_edge(self, parent, subtree_root)
    self._post_mutation(scope_node=parent, op='add_subtree')
    return subtree_root
[docs]
def prune(self, node):
    """Remove ``node`` and re-attach each of its children directly to its parent.

    Parameters
    ----------
    node : int
        The node to remove.
    Raises
    ------
    ValueError
        If ``node`` is the root.

    NOTE(review): promoted children are linked with new edges from the
    parent, so they may not keep ``node``'s position among its siblings.
    """
    if node == self.root:
        raise ValueError("Cannot prune the root node")
    grandparent = self.parent(node)
    orphans = list(self.successors(node))
    for orphan in orphans:
        Graph.add_edge(self, grandparent, orphan)
    Graph.remove_node(self, node)
    self._post_mutation(scope_node=grandparent, op='prune')
[docs]
def remove_subtree(self, node):
    """Delete ``node`` together with all of its descendants.

    Parameters
    ----------
    node : int
        Root of the subtree to delete.
    Raises
    ------
    ValueError
        If ``node`` is the root.
    """
    if node == self.root:
        raise ValueError("Cannot remove the root node")
    doomed = (node,) + tuple(self.descendants(node))
    parent = self.parent(node)
    for victim in doomed:
        Graph.remove_node(self, victim)
    self._post_mutation(scope_node=parent, op='remove_subtree')
[docs]
def replace_node(self, old_node, **attr):
    """Overwrite a node's data in place, keeping its id and connections.

    Parameters
    ----------
    old_node : int
        The node whose data is replaced.
    **attr : dict
        New attributes; normalized and validated, then deep-copied in. An
        empty result stores ``{}``.
    Returns
    -------
    int
        ``old_node`` itself.
    """
    parent = self.parent(old_node)
    cleaned = self._normalize_node_attrs(node=old_node, attrs=attr, op='replace_node')
    self._validate_node_attrs(node=old_node, attrs=cleaned, op='replace_node')
    payload = cleaned if cleaned else {}
    self._graph[old_node] = copy.deepcopy(payload)
    # The root has no parent, so it becomes its own mutation scope.
    scope = old_node if parent is None else parent
    self._post_mutation(scope_node=scope, op='replace_node')
    return old_node
def _update_group_structure(self):
"""Update the Group structure based on current graph state.
This method rebuilds the _list Group from the current tree structure.
Subclasses can override this to preserve specific parts of the Group.
"""
if hasattr(self, '_list'):
attr = getattr(self, '_node_value_attr', 'label')
def get_node_value(node):
return self[node].get(attr, node)
def get_children(node):
return list(self.successors(node))
structure = self._build_nested_structure(self.root, get_node_value, get_children)
if isinstance(structure, tuple) and len(structure) > 1:
self._list = Group(structure)
else:
self._list = Group((structure, tuple()))
[docs]
def graft_subtree(self, target_node, subtree, mode='replace'):
    """Graft ``subtree`` onto the leaf ``target_node``.

    Parameters
    ----------
    target_node : int
        Leaf node receiving the graft.
    subtree : Tree
        Tree to copy in (it is not modified).
    mode : str, optional
        ``'replace'`` (default): the leaf takes over the subtree root's data
        and receives copies of its descendants.
        ``'adopt'``: the leaf keeps its own data and receives copies of the
        subtree root's children.
    Returns
    -------
    int
        ``target_node`` (both modes reuse the leaf's id).
    Raises
    ------
    TypeError
        If ``subtree`` is not a Tree.
    ValueError
        If ``target_node`` is missing, is not a leaf, or ``mode`` is invalid.
    """
    if not isinstance(subtree, Tree):
        raise TypeError("subtree must be a Tree instance")
    if target_node not in self:
        raise ValueError(f"Target node {target_node} not found in tree")
    if self.out_degree(target_node) > 0:
        raise ValueError(f"Target node {target_node} is not a leaf node. Can only graft to leaf nodes.")
    if mode not in ('replace', 'adopt'):
        raise ValueError(f"Invalid mode '{mode}'. Use 'replace' or 'adopt'")
    grafter = self._graft_replace_leaf if mode == 'replace' else self._graft_adopt_leaf
    return grafter(target_node, subtree)
def _graft_replace_leaf(self, target_node, subtree):
    """Splice ``subtree`` in place of the leaf ``target_node``.

    The leaf keeps its id but its data is replaced by a deep copy of the
    subtree root's data. Every other subtree node is added as a new node, and
    the subtree's edges are reproduced. Returns ``target_node``.
    """
    parent = self.parent(target_node)
    sub_root = subtree.root
    id_map = {sub_root: target_node}
    for source in subtree.nodes:
        if source != sub_root:
            id_map[source] = Graph.add_node(self, **dict(subtree.nodes[source]))
    self._graph[target_node] = copy.deepcopy(dict(subtree.nodes[sub_root]))
    # The subtree root maps to target_node, so one call covers root edges too.
    for u, v in subtree.edges:
        Graph.add_edge(self, id_map[u], id_map[v])
    self._post_mutation(scope_node=parent, op='graft_subtree')
    return target_node
def _graft_adopt_leaf(self, target_node, subtree):
    """Keep the leaf ``target_node`` and hang copies of the subtree root's children under it.

    The subtree root's own data is not copied; ``target_node`` keeps its data.
    Returns ``target_node``.
    """
    sub_root = subtree.root
    sources = [n for n in subtree.nodes if n != sub_root]
    id_map = {}
    for source in sources:
        id_map[source] = Graph.add_node(self, **dict(subtree.nodes[source]))
    # Recreate internal edges; edges touching the subtree root are replaced below.
    for u, v in subtree.edges:
        if sub_root not in (u, v):
            Graph.add_edge(self, id_map[u], id_map[v])
    for child in list(subtree.successors(sub_root)):
        Graph.add_edge(self, target_node, id_map[child])
    self._post_mutation(scope_node=target_node, op='graft_subtree')
    return target_node
[docs]
def move_subtree(self, node, new_parent, index=None):
    """Detach the subtree rooted at ``node`` and re-attach it under ``new_parent``.

    Parameters
    ----------
    node : int
        Root of subtree to move.
    new_parent : int
        New parent node.
    index : int or None, optional
        Requested position under the new parent. NOTE(review): currently not used.
    Raises
    ------
    ValueError
        If ``node`` is the root, either node is missing, or ``new_parent`` is
        ``node`` itself or one of its descendants. Such a move would create a
        cycle and disconnect the subtree from the root.
    """
    if node == self.root:
        raise ValueError("Cannot move the root node")
    if node not in self:
        raise ValueError(f"Node {node} not found in tree")
    if new_parent not in self:
        raise ValueError(f"Node {new_parent} not found in tree")
    # Climb from new_parent to the root; meeting `node` means new_parent
    # lies inside the subtree being moved.
    cursor = new_parent
    while cursor is not None:
        if cursor == node:
            raise ValueError(
                f"Cannot move node {node} under itself or its descendant {new_parent}"
            )
        cursor = self.parent(cursor)
    old_parent = self.parent(node)
    scope_node = new_parent
    if old_parent is not None:
        # Smallest subtree containing both the old and new attachment points.
        scope_node = self.lowest_common_ancestor(old_parent, new_parent)
    Graph.remove_edge(self, old_parent, node)
    Graph.add_edge(self, new_parent, node)
    self._post_mutation(scope_node=scope_node, op='move_subtree')
[docs]
def prune_to_depth(self, max_depth):
    """Remove every node deeper than ``max_depth`` (the root is depth 0).

    Raises
    ------
    ValueError
        If ``max_depth`` is negative.
    """
    from collections import deque

    if max_depth < 0:
        raise ValueError("max_depth must be non-negative")
    root_idx = self._get_node_index(self._root)
    depths = {}
    # deque.popleft is O(1); list.pop(0) would make the BFS quadratic.
    queue = deque([(root_idx, 0)])
    while queue:
        node_idx, depth = queue.popleft()
        if node_idx in depths:
            continue
        depths[node_idx] = depth
        for successor_idx in self._graph.successor_indices(node_idx):
            if successor_idx not in depths:
                queue.append((successor_idx, depth + 1))
    indices_to_remove = [idx for idx, depth in depths.items() if depth > max_depth]
    for idx in indices_to_remove:
        Graph.remove_node(self, self._get_node_object(idx))
    self._post_mutation(scope_node=None, op='prune_to_depth')
[docs]
def prune_leaves(self, n):
    """Remove up to ``n`` levels of leaves, one leaf layer per round.

    The root is never removed: once it is the only node left, or the tree is
    just the root, pruning stops. Previously a single-node tree lost its
    root, leaving an empty graph.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    if n == 0:
        return
    root_idx = self._get_node_index(self._root)
    for _ in range(n):
        leaf_indices = [
            idx for idx in self._graph.node_indices()
            if idx != root_idx and self._graph.out_degree(idx) == 0
        ]
        if not leaf_indices:
            # Only the root remains; nothing more can be pruned.
            break
        for idx in leaf_indices:
            Graph.remove_node(self, self._get_node_object(idx))
    self._post_mutation(scope_node=None, op='prune_leaves')
[docs]
def __deepcopy__(self, memo):
"""Create a deep copy of the tree including Tree-specific attributes."""
new_tree = self.__class__.__new__(self.__class__)
new_tree._graph = self._graph.copy()
new_tree._meta = copy.deepcopy(self._meta, memo)
new_tree._structure_version = 0
new_tree._root = self._root
new_tree._list = copy.deepcopy(self._list, memo)
if hasattr(self, '_building_tree'):
new_tree._building_tree = self._building_tree
return new_tree
def _build_tree(self, root, children):
    """Create the root node holding ``root``, add ``children`` beneath it, and return the root id."""
    value_key = getattr(self, '_node_value_attr', 'label')
    root_id = super().add_node(**{value_key: root})
    self._add_children(root_id, children)
    return root_id
def _add_children(self, parent_id, children_list):
attr = getattr(self, '_node_value_attr', 'label')
for child in children_list:
match child:
case tuple((D, S)):
duration_id = super().add_node(**{attr: D})
super().add_edge(parent_id, duration_id)
self._add_children(duration_id, S)
case Tree():
child_attr = getattr(child, '_node_value_attr', 'label')
val = child._graph.nodes[child.root].get(child_attr, child.root)
meta_val = child._meta if isinstance(child._meta, dict) else child._meta.to_dict('records')[0]
duration_id = super().add_node(**{attr: val}, meta=meta_val)
super().add_edge(parent_id, duration_id)
self._add_children(duration_id, child.group.S)
case _:
child_id = super().add_node(**{attr: child})
super().add_edge(parent_id, child_id)
@classmethod
def _from_graph(cls, G, clear_attributes=False, renumber=True, node_attr='label'):
    """Create a Tree (or subclass) from a RustworkX graph or ``Graph``.

    Parameters
    ----------
    G : rx.PyDiGraph or Graph
        Source graph; non-``Graph`` inputs go through ``Graph.from_rustworkx``.
    clear_attributes : bool, optional
        If True every node value becomes None (default False).
    renumber : bool, optional
        Whether to renumber nodes of the result (default True).
    node_attr : str, optional
        Node-data key read as each node's value; the node id is used when
        the key is missing or None (default 'label').
    Returns
    -------
    Tree
        New instance of ``cls``. Subclasses are produced through
        ``_from_base_tree`` applied to a plain ``Tree`` built from ``G``.
    Raises
    ------
    TypeError
        If the graph is not directed.
    ValueError
        If the graph does not have exactly one node with in-degree 0.
    """
    graph = G if isinstance(G, Graph) else Graph.from_rustworkx(G)
    if not hasattr(graph._graph, 'in_degree'):
        raise TypeError("Tree graphs must be directed")

    def value_of(node_obj):
        data = graph[node_obj]
        if clear_attributes:
            return None
        label = data.get(node_attr)
        return node_obj if label is None else label

    def children_of(node_obj):
        return list(graph.successors(node_obj))

    roots = [n for n in graph if graph.in_degree(n) == 0]
    if len(roots) != 1:
        raise ValueError(f"Graph must have exactly one root node, found {len(roots)}")
    structure = cls._build_nested_structure(roots[0], value_of, children_of)
    if cls is not Tree:
        base_tree = Tree._from_graph(G, clear_attributes, renumber=False, node_attr=node_attr)
        tree = cls._from_base_tree(base_tree)
    elif isinstance(structure, tuple) and len(structure) > 1:
        tree = cls(structure[0], structure[1])
    else:
        tree = cls(structure, tuple())
    if renumber:
        tree.renumber_nodes()
    return tree
@classmethod
def _build_nested_structure(cls, root_node, get_node_value, get_children):
"""Build nested tuple structure from a tree starting at root_node.
Parameters
----------
root_node : hashable
The node to start building from.
get_node_value : callable
Function that takes a node and returns its value.
get_children : callable
Function that takes a node and returns its children.
Returns
-------
tuple or hashable
Nested tuple structure representing the tree.
"""
children = get_children(root_node)
if not children:
return get_node_value(root_node)
child_structures = []
for child in sorted(children):
child_structure = cls._build_nested_structure(child, get_node_value, get_children)
child_structures.append(child_structure)
root_value = get_node_value(root_node)
return (root_value, tuple(child_structures))
@classmethod
def _from_base_tree(cls, base_tree):
"""Create a Tree subclass instance from a base Tree.
Subclasses should override this method to handle their specific
construction.
Parameters
----------
base_tree : Tree
Base Tree instance.
Returns
-------
Tree
New Tree subclass instance.
"""
return base_tree