"""Source code for klotho.topos.graphs.graphs: a rustworkx-backed Graph wrapper with NetworkX-style node and edge views."""

import rustworkx as rx
import copy
from functools import lru_cache
from typing import List, TypeVar, Optional, Any, Union, Dict, Iterator, Tuple
from types import MappingProxyType
import numpy as np

# Generic item type; used to annotate the ``items`` argument of Graph.from_cost_matrix.
T = TypeVar('T')


class Graph:
    """Graph wrapper around a rustworkx ``PyGraph`` / ``PyDiGraph``.

    Nodes are identified by their rustworkx integer indices. Node payloads are
    ``dict`` objects exposed read-only (``MappingProxyType``) and are changed
    only through ``_apply_node_data_mutation``. Subclasses can customise that
    pipeline through the ``_normalize_node_attrs``, ``_validate_node_attrs``,
    ``_resolve_data_update_scope``, ``_before_node_data_mutation`` and
    ``_after_node_data_mutation`` hooks. Topology and node-attribute mutability
    can each be disabled via ``_set_mutability_policy``.
    """

    def __init__(self, directed: bool = False):
        """Initialize an empty Klotho graph.

        Parameters
        ----------
        directed : bool
            Use a ``rx.PyDiGraph`` backend when True, otherwise ``rx.PyGraph``.
        """
        self._meta = {}
        self._structure_version = 0
        self._graph = rx.PyDiGraph() if directed else rx.PyGraph()
        self._topology_mutable = True
        self._node_attr_mutable = True

    @property
    def nodes(self):
        """Return a view of the nodes that can be subscripted."""
        return GraphNodeView(self)

    @property
    def edges(self):
        """Return a view of the edges."""
        return GraphEdgeView(self)

    def __getitem__(self, node):
        """Get node data for a given node as a read-only mapping.

        Raises
        ------
        KeyError
            If ``node`` is not in the graph.
        """
        if not self._graph.has_node(node):
            raise KeyError(f"Node {node} not found in graph")
        node_data = self._graph.get_node_data(node)
        if not isinstance(node_data, dict):
            return MappingProxyType({})
        return MappingProxyType(node_data)

    # ------------------------------------------------------------------
    # Mutability policy
    # ------------------------------------------------------------------
    def _is_topology_mutable(self):
        # getattr default keeps instances created via __new__ (which may lack
        # the attribute) mutable.
        return getattr(self, '_topology_mutable', True)

    def _is_node_attr_mutable(self):
        return getattr(self, '_node_attr_mutable', True)

    def _set_mutability_policy(self, topology_mutable=True, node_attr_mutable=True):
        self._topology_mutable = bool(topology_mutable)
        self._node_attr_mutable = bool(node_attr_mutable)

    def _ensure_topology_mutable(self, op=None):
        """Raise PermissionError if topology mutation is disabled."""
        if not self._is_topology_mutable():
            op_name = op or "topology mutation"
            raise PermissionError(f"{self.__class__.__name__} does not allow {op_name}")

    def _ensure_node_attr_mutable(self, op=None):
        """Raise PermissionError if node-attribute mutation is disabled."""
        if not self._is_node_attr_mutable():
            op_name = op or "node-attribute mutation"
            raise PermissionError(f"{self.__class__.__name__} does not allow {op_name}")

    # ------------------------------------------------------------------
    # Node-data mutation pipeline (hooks are no-ops here; subclasses override)
    # ------------------------------------------------------------------
    def _normalize_node_attrs(self, node, attrs, op=None):
        return dict(attrs) if isinstance(attrs, dict) else {}

    def _validate_node_attrs(self, node, attrs, op=None):
        pass

    def _resolve_data_update_scope(self, node, changed_keys, op=None):
        return node

    def _before_node_data_mutation(self, node, attrs, scope_node=None, op=None):
        pass

    def _after_node_data_mutation(self, node, attrs, scope_node=None, op=None):
        pass

    def _apply_node_data_mutation(self, node, attrs, op=None, replace=False):
        """Merge (or, with ``replace=True``, replace) ``attrs`` into a node's data.

        A new dict is always stored, so previously handed-out read-only views
        keep seeing the old payload.

        Raises
        ------
        KeyError
            If ``node`` is not in the graph.
        PermissionError
            If node-attribute mutation is disabled.
        """
        if not self._graph.has_node(node):
            raise KeyError(f"Node {node} not found in graph")
        self._ensure_node_attr_mutable(op=op)
        normalized = self._normalize_node_attrs(node=node, attrs=attrs, op=op)
        self._validate_node_attrs(node=node, attrs=normalized, op=op)
        changed_keys = tuple(sorted(normalized.keys()))
        scope_node = self._resolve_data_update_scope(node=node, changed_keys=changed_keys, op=op)
        self._before_node_data_mutation(node=node, attrs=normalized, scope_node=scope_node, op=op)
        existing_data = self._graph.get_node_data(node)
        existing_data = existing_data if isinstance(existing_data, dict) else {}
        if replace:
            new_data = dict(normalized)
        else:
            new_data = dict(existing_data)
            new_data.update(normalized)
        self._graph[node] = new_data
        self._invalidate_caches()
        self._after_node_data_mutation(node=node, attrs=normalized, scope_node=scope_node, op=op)

    # ------------------------------------------------------------------
    # Dunder protocol
    # ------------------------------------------------------------------
    def __len__(self):
        """Return the number of nodes."""
        return self._graph.num_nodes()

    def __str__(self):
        """String representation of the graph."""
        return f"Graph with {self.number_of_nodes()} nodes and {self.number_of_edges()} edges"

    def __repr__(self):
        """String representation of the graph."""
        return f"Graph({self.number_of_nodes()}, {self.number_of_edges()})"

    def __iter__(self):
        """Iterate over node indices."""
        return iter(self._graph.node_indices())

    def __contains__(self, node):
        """Check if a node is in the graph."""
        return self._graph.has_node(node)

    def _invalidate_caches(self):
        """Invalidate all caches when structure changes.

        Bumps ``_structure_version`` and clears the ``lru_cache`` of each
        cached traversal method. The ``lru_cache`` wrappers live on the class,
        so ``cache_clear`` drops cached results for every instance, not only
        this one. A method overridden without a cache is skipped instead of
        raising ``AttributeError``.
        """
        self._structure_version += 1
        for method_name in ('descendants', 'ancestors', 'successors', 'predecessors'):
            cache_clear = getattr(getattr(self, method_name, None), 'cache_clear', None)
            if cache_clear is not None:
                cache_clear()

    def out_degree(self, node):
        """Get the out-degree of a node (plain degree for undirected graphs)."""
        if hasattr(self._graph, 'out_degree'):
            return self._graph.out_degree(node)
        return self._graph.degree(node)

    def in_degree(self, node):
        """Get the in-degree of a node (plain degree for undirected graphs)."""
        if hasattr(self._graph, 'in_degree'):
            return self._graph.in_degree(node)
        return self._graph.degree(node)

    def _get_node_object(self, index):
        """Convert a RustworkX node index to a node object.

        For the base Graph class, nodes are just their indices. Subclasses can
        override this for different node representations.
        """
        return index

    def _get_node_index(self, node):
        """Convert a node object to a RustworkX index.

        For the base Graph class, nodes are just their indices. Subclasses can
        override this for different node representations.
        """
        return node

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------
    @classmethod
    def directed(cls):
        """Create an empty directed graph."""
        return cls(directed=True)

    @classmethod
    def digraph(cls):
        """Alias for directed graph creation."""
        return cls.directed()

    @classmethod
    def from_rustworkx(cls, graph: Union[rx.PyGraph, rx.PyDiGraph], copy_graph: bool = True):
        """Create a Graph from an existing RustworkX graph.

        ``cls.__init__`` is bypassed; the instance attributes are set here.

        Parameters
        ----------
        graph : rx.PyGraph or rx.PyDiGraph
            Source graph.
        copy_graph : bool
            When True, wrap ``graph.copy()``; otherwise wrap ``graph`` itself.

        Raises
        ------
        TypeError
            If ``graph`` is not a rustworkx graph.
        """
        if not isinstance(graph, (rx.PyGraph, rx.PyDiGraph)):
            raise TypeError(f"Expected rustworkx graph, got: {type(graph)}")
        new_graph = cls.__new__(cls)
        new_graph._graph = graph.copy() if copy_graph else graph
        new_graph._meta = {}
        new_graph._structure_version = 0
        new_graph._topology_mutable = True
        new_graph._node_attr_mutable = True
        return new_graph

    @classmethod
    def from_networkx(cls, graph, copy_graph: bool = True):
        """Create a Graph from a NetworkX graph.

        NetworkX node keys are not stored; nodes get consecutive rustworkx
        indices in NetworkX iteration order. Node and edge attribute dicts are
        copied.

        Raises
        ------
        TypeError
            If ``graph`` is not a ``networkx.Graph`` / ``networkx.DiGraph``.
        """
        import networkx as nx
        if not isinstance(graph, (nx.Graph, nx.DiGraph)):
            raise TypeError(f"Expected networkx.Graph or networkx.DiGraph, got: {type(graph)}")
        nx_graph = graph.copy() if copy_graph else graph
        new_graph = cls(directed=nx_graph.is_directed())
        node_index_map = {}
        for node, attrs in nx_graph.nodes(data=True):
            node_attrs = dict(attrs) if isinstance(attrs, dict) else {}
            node_index_map[node] = new_graph._graph.add_node(node_attrs)
        for u, v, attrs in nx_graph.edges(data=True):
            edge_attrs = dict(attrs) if isinstance(attrs, dict) else {}
            new_graph._graph.add_edge(node_index_map[u], node_index_map[v], edge_attrs)
        return new_graph

    @classmethod
    def from_nodes_edges(
        cls,
        nodes: Optional[List[Any]] = None,
        edges: Optional[List[Tuple[Any, ...]]] = None,
        directed: bool = False,
        node_mode: str = 'label',
        node_key: str = 'label',
    ):
        """Create a graph from node and edge iterables.

        Parameters
        ----------
        nodes : list, optional
            Entries are either a node value or a ``(value, attrs_dict)`` pair.
        edges : list of tuple, optional
            ``(u, v)`` or ``(u, v, attrs)``. Endpoints that were not listed in
            ``nodes`` are created on demand.
        directed : bool
            Build a directed graph.
        node_mode : {'label', 'id'}
            ``'label'``: values are labels stored under ``node_key`` and each
            distinct label gets one node. ``'id'``: values are non-negative
            integer node indices, and nodes are added until the index exists.
        node_key : str
            Attribute name used for labels in ``'label'`` mode.

        Raises
        ------
        ValueError
            On an unknown ``node_mode``, a malformed edge tuple, or a negative
            id in ``'id'`` mode.
        TypeError
            If a non-integer value is used in ``'id'`` mode.
        """
        graph = cls(directed=directed)
        node_lookup: Dict[Any, int] = {}
        if node_mode not in {'label', 'id'}:
            raise ValueError("node_mode must be 'label' or 'id'")

        def ensure_node(node_value, attrs=None):
            # Return the node index for ``node_value``, creating the node if needed.
            attrs = dict(attrs) if isinstance(attrs, dict) else {}
            if node_mode == 'id':
                if not isinstance(node_value, int):
                    raise TypeError("node_mode='id' requires integer node values")
                if node_value < 0:
                    raise ValueError("node_mode='id' requires non-negative node values")
                while graph.number_of_nodes() <= node_value:
                    graph.add_node()
                graph.set_node_data(node_value, **attrs)
                return node_value
            if node_value in node_lookup:
                node_id = node_lookup[node_value]
                if attrs:
                    graph.set_node_data(node_id, **attrs)
                return node_id
            node_attrs = {node_key: node_value}
            node_attrs.update(attrs)
            node_id = graph.add_node(**node_attrs)
            node_lookup[node_value] = node_id
            return node_id

        if nodes:
            for node_entry in nodes:
                if isinstance(node_entry, tuple) and len(node_entry) == 2 and isinstance(node_entry[1], dict):
                    ensure_node(node_entry[0], node_entry[1])
                else:
                    ensure_node(node_entry)
        if edges:
            for edge in edges:
                if len(edge) == 2:
                    u_label, v_label = edge
                    edge_attrs = {}
                elif len(edge) == 3:
                    u_label, v_label, edge_attrs = edge
                    edge_attrs = dict(edge_attrs) if isinstance(edge_attrs, dict) else {}
                else:
                    raise ValueError("Edges must be (u, v) or (u, v, attrs)")
                u = ensure_node(u_label)
                v = ensure_node(v_label)
                graph.add_edge(u, v, **edge_attrs)
        return graph

    @classmethod
    def from_edges(
        cls,
        edges: List[Tuple[Any, ...]],
        directed: bool = False,
        node_mode: str = 'label',
        node_key: str = 'label',
    ):
        """Create a graph from an edge list (see ``from_nodes_edges``)."""
        return cls.from_nodes_edges(
            nodes=None,
            edges=edges,
            directed=directed,
            node_mode=node_mode,
            node_key=node_key,
        )

    @classmethod
    def empty_graph(
        cls,
        n_nodes: int = 0,
        labels: Optional[List[Any]] = None,
        directed: bool = False,
        node_key: str = 'label',
    ):
        """Create an edgeless graph with ``n_nodes`` labeled nodes.

        Raises
        ------
        ValueError
            If ``n_nodes`` is negative or ``labels`` has the wrong length.
        """
        if n_nodes < 0:
            raise ValueError("n_nodes must be >= 0")
        resolved_labels = cls._resolve_labels(n_nodes, labels)
        return cls.from_nodes_edges(
            nodes=[(label, {node_key: label}) for label in resolved_labels],
            edges=[],
            directed=directed,
            node_mode='label',
            node_key=node_key,
        )

    @classmethod
    def path_graph(
        cls,
        n_nodes: int,
        labels: Optional[List[Any]] = None,
        directed: bool = False,
        node_key: str = 'label',
    ):
        """Create a path graph ``0 - 1 - ... - n-1`` with optional labels.

        Raises
        ------
        ValueError
            If ``n_nodes`` is negative or ``labels`` has the wrong length.
        """
        if n_nodes < 0:
            raise ValueError("n_nodes must be >= 0")
        resolved_labels = cls._resolve_labels(n_nodes, labels)
        edges = [(resolved_labels[i], resolved_labels[i + 1]) for i in range(max(0, n_nodes - 1))]
        return cls.from_nodes_edges(
            nodes=[(label, {node_key: label}) for label in resolved_labels],
            edges=edges,
            directed=directed,
            node_mode='label',
            node_key=node_key,
        )

    @classmethod
    def cycle_graph(
        cls,
        n_nodes: int,
        labels: Optional[List[Any]] = None,
        directed: bool = False,
        node_key: str = 'label',
    ):
        """Create a cycle graph with optional labels.

        The closing edge (last -> first) is only added when ``n_nodes > 2``.
        With fewer nodes the result is the corresponding path graph.

        Raises
        ------
        ValueError
            If ``n_nodes`` is negative or ``labels`` has the wrong length.
        """
        if n_nodes < 0:
            raise ValueError("n_nodes must be >= 0")
        resolved_labels = cls._resolve_labels(n_nodes, labels)
        edges = [(resolved_labels[i], resolved_labels[i + 1]) for i in range(max(0, n_nodes - 1))]
        if n_nodes > 2:
            edges.append((resolved_labels[-1], resolved_labels[0]))
        return cls.from_nodes_edges(
            nodes=[(label, {node_key: label}) for label in resolved_labels],
            edges=edges,
            directed=directed,
            node_mode='label',
            node_key=node_key,
        )

    @classmethod
    def star_graph(
        cls,
        n_nodes: int,
        center: int = 0,
        labels: Optional[List[Any]] = None,
        directed: bool = False,
        node_key: str = 'label',
    ):
        """Create a star graph whose hub is the node at position ``center``.

        Raises
        ------
        ValueError
            If ``n_nodes`` is negative, ``center`` is out of range, or
            ``labels`` has the wrong length.
        """
        if n_nodes < 0:
            raise ValueError("n_nodes must be >= 0")
        if n_nodes == 0:
            return cls.empty_graph(0, directed=directed, node_key=node_key)
        if center < 0 or center >= n_nodes:
            raise ValueError("center must be a valid node index")
        resolved_labels = cls._resolve_labels(n_nodes, labels)
        center_label = resolved_labels[center]
        edges = [(center_label, resolved_labels[i]) for i in range(n_nodes) if i != center]
        return cls.from_nodes_edges(
            nodes=[(label, {node_key: label}) for label in resolved_labels],
            edges=edges,
            directed=directed,
            node_mode='label',
            node_key=node_key,
        )

    @classmethod
    def random_graph(
        cls,
        n_nodes: int,
        p: float = 0.3,
        labels: Optional[List[Any]] = None,
        directed: bool = False,
        seed: Optional[int] = None,
        node_key: str = 'label',
    ):
        """Create a random Erdos-Renyi style graph with optional labels.

        Each candidate edge is kept independently with probability ``p``.
        Directed graphs consider every ordered pair ``i != j``; undirected
        graphs consider every pair ``i < j``. Randomness comes from
        ``np.random.default_rng(seed)``.

        Raises
        ------
        ValueError
            If ``n_nodes`` is negative, ``p`` is outside [0, 1], or ``labels``
            has the wrong length.
        """
        if n_nodes < 0:
            raise ValueError("n_nodes must be >= 0")
        if p < 0 or p > 1:
            raise ValueError("p must be between 0 and 1")
        resolved_labels = cls._resolve_labels(n_nodes, labels)
        rng = np.random.default_rng(seed)
        edges: List[Tuple[Any, Any]] = []
        if directed:
            for i in range(n_nodes):
                for j in range(n_nodes):
                    if i == j:
                        continue
                    if rng.random() < p:
                        edges.append((resolved_labels[i], resolved_labels[j]))
        else:
            for i in range(n_nodes):
                for j in range(i + 1, n_nodes):
                    if rng.random() < p:
                        edges.append((resolved_labels[i], resolved_labels[j]))
        return cls.from_nodes_edges(
            nodes=[(label, {node_key: label}) for label in resolved_labels],
            edges=edges,
            directed=directed,
            node_mode='label',
            node_key=node_key,
        )

    @staticmethod
    def _resolve_labels(n_nodes: int, labels: Optional[List[Any]] = None) -> List[Any]:
        """Return ``labels`` as a list, or ``range(n_nodes)`` when None.

        Raises
        ------
        ValueError
            If ``labels`` does not contain exactly ``n_nodes`` entries.
        """
        if labels is None:
            return list(range(n_nodes))
        if len(labels) != n_nodes:
            raise ValueError(f"Expected {n_nodes} labels, got {len(labels)}")
        return list(labels)

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def neighbors(self, node):
        """Get neighbors of a node, as returned by rustworkx ``neighbors``."""
        return list(self._graph.neighbors(node))

    @lru_cache(maxsize=None)
    def predecessors(self, node):
        """Returns all predecessors of a node.

        For undirected graphs this falls back to ``neighbors``. Results are
        cached until ``_invalidate_caches`` runs.

        Parameters
        ----------
        node : int
            The node whose predecessors to return.

        Returns
        -------
        tuple
            All predecessors of the node.
        """
        if hasattr(self._graph, 'predecessor_indices'):
            return tuple(self._graph.predecessor_indices(node))
        return tuple(self.neighbors(node))

    @lru_cache(maxsize=None)
    def successors(self, node):
        """Returns all successors of a node.

        For undirected graphs this falls back to ``neighbors``. Results are
        cached until ``_invalidate_caches`` runs.

        Parameters
        ----------
        node : int
            The node whose successors to return.

        Returns
        -------
        tuple
            All successors of the node in sorted order (left-to-right).
        """
        if hasattr(self._graph, 'successor_indices'):
            return tuple(sorted(self._graph.successor_indices(node)))
        return tuple(sorted(self.neighbors(node)))

    @lru_cache(maxsize=None)
    def descendants(self, node):
        """Returns all descendants of a node using ``rx.descendants``.

        Any exception raised by rustworkx (e.g. for an unknown node) is
        swallowed and an empty tuple is returned. Results are cached until
        ``_invalidate_caches`` runs.

        Parameters
        ----------
        node : int
            The node whose descendants to return.

        Returns
        -------
        tuple
            All descendants of the node.
        """
        try:
            return tuple(rx.descendants(self._graph, node))
        except Exception:
            return tuple()

    @lru_cache(maxsize=None)
    def ancestors(self, node):
        """Returns all ancestors of a node using ``rx.ancestors``.

        Any exception raised by rustworkx (e.g. for an unknown node) is
        swallowed and an empty tuple is returned. Results are cached until
        ``_invalidate_caches`` runs.

        Parameters
        ----------
        node : int
            The node whose ancestors to return.

        Returns
        -------
        tuple
            All ancestors of the node.
        """
        try:
            return tuple(rx.ancestors(self._graph, node))
        except Exception:
            return tuple()

    def topological_sort(self):
        """Returns nodes in topological order.

        For undirected graphs, node indices are returned in ``node_indices()``
        order.

        Returns
        -------
        generator
            Nodes in topological order.
        """
        if hasattr(self._graph, 'out_degree'):
            indices = rx.topological_sort(self._graph)
        else:
            indices = self._graph.node_indices()
        return (idx for idx in indices)

    def to_directed(self):
        """Return a directed version of this graph.

        Every node payload and every edge is copied into a new
        ``rx.PyDiGraph``, each edge in its stored ``(source, target)``
        orientation. Parallel edges keep their individual payloads. Payload
        objects are shared with this graph, not copied. Nodes are re-added in
        ``node_indices()`` order, so node indices are compacted when this
        graph has had nodes removed. Edges are remapped accordingly.

        Returns
        -------
        Graph
            A new Graph instance with directed edges.
        """
        directed_rx = rx.PyDiGraph()
        # Map old -> new indices: after removals, the new graph's sequential
        # indices no longer equal the old ones.
        index_map = {}
        for idx in self._graph.node_indices():
            index_map[idx] = directed_rx.add_node(self._graph.get_node_data(idx))
        for src, tgt, edge_data in self._graph.weighted_edge_list():
            directed_rx.add_edge(
                index_map[src],
                index_map[tgt],
                edge_data if isinstance(edge_data, dict) else {},
            )
        new_graph = Graph.__new__(Graph)
        new_graph._graph = directed_rx
        new_graph._meta = copy.deepcopy(self._meta)
        new_graph._structure_version = 0
        new_graph._topology_mutable = True
        new_graph._node_attr_mutable = True
        return new_graph

    def number_of_nodes(self):
        """Return the number of nodes in the graph.

        Returns
        -------
        int
            Number of nodes.
        """
        return self._graph.num_nodes()

    def number_of_edges(self):
        """Return the number of edges in the graph.

        Returns
        -------
        int
            Number of edges.
        """
        return self._graph.num_edges()

    def nodes_with_data(self, data=True):
        """Return nodes with their data.

        Unlike ``nodes(data=True)``, the yielded payloads are the stored dicts
        themselves, not read-only proxies.

        Parameters
        ----------
        data : bool
            If True, return node data as well.

        Returns
        -------
        Iterator
            Iterator of (node, data) pairs if data=True, else just nodes.
        """
        if data:
            for idx in self._graph.node_indices():
                node_data = self._graph.get_node_data(idx)
                yield (idx, node_data if isinstance(node_data, dict) else {})
        else:
            for idx in self._graph.node_indices():
                yield idx

    def subgraph(self, node, renumber=True):
        """Extract the subgraph induced by a node and its descendants.

        Parameters
        ----------
        node : int
            The node to use as the starting point of the subgraph.
        renumber : bool
            Forwarded to ``_from_graph``, which currently ignores it.

        Returns
        -------
        Graph
            A new Graph object representing the subgraph.

        Raises
        ------
        ValueError
            If ``node`` is not in the graph.
        """
        if node not in self:
            raise ValueError(f"Node {node} not found in graph")
        members = [node] + list(self.descendants(node))
        subgraph_rx = self._graph.subgraph(members)
        return self._from_graph(subgraph_rx, renumber=renumber)

    @property
    def root_nodes(self):
        """Returns root nodes.

        Directed graphs: all nodes with in-degree 0. Undirected graphs: the
        lowest index among nodes that have at least one edge, or an empty
        tuple if there are no such nodes.
        """
        root_indices = []
        if hasattr(self._graph, 'in_degree'):
            for idx in self._graph.node_indices():
                if self._graph.in_degree(idx) == 0:
                    root_indices.append(idx)
        else:
            if self._graph.num_nodes() == 0:
                return tuple()
            connected_nodes = [
                idx for idx in self._graph.node_indices() if self._graph.degree(idx) > 0
            ]
            if connected_nodes:
                root_indices = [min(connected_nodes)]
        return tuple(root_indices)

    # ------------------------------------------------------------------
    # Mutation
    # ------------------------------------------------------------------
    def add_node(self, **attr):
        """Add a node to the graph.

        Parameters
        ----------
        **attr : dict
            Node attributes as keyword arguments.

        Returns
        -------
        int
            The node ID that was added.
        """
        self._ensure_topology_mutable(op='add_node')
        node_id = self._graph.add_node(attr if attr else {})
        self._invalidate_caches()
        return node_id

    def set_node_data(self, node, **attr):
        """Update data for an existing node.

        Parameters
        ----------
        node : int
            The node to update.
        **attr : dict
            Node attributes to set.
        """
        self._apply_node_data_mutation(node=node, attrs=attr, op='set_node_data', replace=False)

    def update_node_data(self, node, attrs: Dict[str, Any]):
        """Update data for an existing node from a dictionary."""
        self._apply_node_data_mutation(node=node, attrs=attrs, op='update_node_data', replace=False)

    def replace_node_data(self, node, attrs: Dict[str, Any]):
        """Replace all data for an existing node."""
        self._apply_node_data_mutation(node=node, attrs=attrs, op='replace_node_data', replace=True)

    def remove_node(self, node):
        """Remove a node from the graph."""
        self._ensure_topology_mutable(op='remove_node')
        self._graph.remove_node(node)
        self._invalidate_caches()

    def add_edge(self, u, v, **attr):
        """Add an edge to the graph with optional attributes.

        Raises
        ------
        KeyError
            If either endpoint is not in the graph.
        """
        self._ensure_topology_mutable(op='add_edge')
        if not self._graph.has_node(u):
            raise KeyError(f"Node {u} not found in graph")
        if not self._graph.has_node(v):
            raise KeyError(f"Node {v} not found in graph")
        self._graph.add_edge(u, v, attr if attr else {})
        self._invalidate_caches()

    def has_edge(self, u, v):
        """Check if an edge exists between two nodes."""
        return self._graph.has_edge(u, v)

    def remove_edge(self, u, v):
        """Remove an edge from the graph."""
        self._ensure_topology_mutable(op='remove_edge')
        self._graph.remove_edge(u, v)
        self._invalidate_caches()

    def update(self, edges=None, nodes=None):
        """Update the graph with nodes and edges.

        Parameters
        ----------
        edges : iterable of tuple, optional
            ``(u, v)`` or ``(u, v, attrs)`` between existing nodes.
        nodes : iterable, optional
            ``(node, attrs)`` 2-tuples update an existing node's data; any
            other entry adds a new node with ``label=entry``.

        Raises
        ------
        KeyError
            If a ``(node, attrs)`` entry or an edge endpoint refers to a
            missing node.
        ValueError
            If an edge is not a 2- or 3-tuple. Previously such edges were
            silently skipped.
        """
        self._ensure_topology_mutable(op='update')
        if nodes:
            for node_data in nodes:
                if isinstance(node_data, tuple) and len(node_data) == 2:
                    node, attrs = node_data
                    if not self._graph.has_node(node):
                        raise KeyError(f"Node {node} not found in graph")
                    self.set_node_data(node, **attrs)
                else:
                    self.add_node(label=node_data)
        if edges:
            for edge_data in edges:
                if len(edge_data) == 2:
                    u, v = edge_data
                    self.add_edge(u, v)
                elif len(edge_data) == 3:
                    u, v, attrs = edge_data
                    self.add_edge(u, v, **attrs)
                else:
                    raise ValueError("Edges must be (u, v) or (u, v, attrs)")
        self._invalidate_caches()

    def clear(self):
        """Remove all nodes and edges from the graph."""
        self._ensure_topology_mutable(op='clear')
        self._graph.clear()
        self._invalidate_caches()

    def set_node_attributes(self, node, attributes):
        """Set attributes for a node."""
        self._apply_node_data_mutation(node=node, attrs=attributes, op='set_node_attributes', replace=False)

    def clear_node_attributes(self, nodes=None):
        """Clear attributes of specified nodes or all nodes.

        Nodes that are not in the graph are skipped.

        Parameters
        ----------
        nodes : list, optional
            Specific nodes to clear attributes for, or None for all nodes.
        """
        self._ensure_node_attr_mutable(op='clear_node_attributes')
        nodes_to_clear = nodes if nodes is not None else self._graph.node_indices()
        for node in nodes_to_clear:
            if node in self:
                self._apply_node_data_mutation(node=node, attrs={}, op='clear_node_attributes', replace=True)

    def renumber_nodes(self, method='default'):
        """Validate a renumbering method; renumbering itself is not implemented.

        NOTE(review): every accepted method currently leaves the graph
        unchanged.

        Parameters
        ----------
        method : str
            One of ``'default'``, ``'dfs'``, ``'bfs'``.

        Returns
        -------
        Graph
            Self, unchanged.

        Raises
        ------
        ValueError
            For an unknown method.
        """
        if method not in ('default', 'dfs', 'bfs'):
            raise ValueError(f"Unknown renumbering method: {method}")
        return self

    def copy(self):
        """Create a deep copy of this graph."""
        return copy.deepcopy(self)

    def is_directed(self):
        """Return True if graph is directed, False otherwise."""
        return isinstance(self._graph, rx.PyDiGraph)

    def is_multigraph(self):
        """Return False.

        NOTE(review): ``add_edge`` does not reject parallel edges; confirm
        whether False is the intended answer.
        """
        return False

    def to_networkx(self):
        """Convert this Graph to a NetworkX graph.

        Node indices become NetworkX node keys.

        Returns
        -------
        networkx.Graph or networkx.DiGraph
            NetworkX equivalent of this graph.
        """
        import networkx as nx
        nx_graph = nx.DiGraph() if self.is_directed() else nx.Graph()
        for node, attrs in self.nodes(data=True):
            nx_graph.add_node(node, **attrs)
        for u, v, attrs in self.edges(data=True):
            nx_graph.add_edge(u, v, **attrs)
        return nx_graph

    @classmethod
    def _from_graph(cls, G, **kwargs):
        """Create a new instance from an existing graph.

        Parameters
        ----------
        G : Graph or rustworkx graph
            The graph to create a new instance from. A Graph's ``_meta`` is
            deep-copied.
        **kwargs : dict
            Accepted for subclass compatibility; ignored here.

        Returns
        -------
        Graph
            A new Graph instance.

        Raises
        ------
        TypeError
            For unsupported input types.
        """
        if isinstance(G, Graph):
            new_graph = cls.from_rustworkx(G._graph)
            new_graph._meta = copy.deepcopy(G._meta)
            return new_graph
        if isinstance(G, (rx.PyGraph, rx.PyDiGraph)):
            return cls.from_rustworkx(G)
        raise TypeError(f"Unsupported graph type: {type(G)}")

    @classmethod
    def grid_graph(cls, dims, periodic=False):
        """Create a grid graph with dimensions specified in dims.

        Nodes are numbered sequentially 0, 1, 2, ... and each node's data is
        ``{'coord': tuple}``, with one coordinate value per dimension.

        Parameters
        ----------
        dims : list or tuple
            Ranges or iterables defining the coordinate values of each
            dimension.
        periodic : bool, optional
            Whether to create periodic boundary conditions (default is False).

        Returns
        -------
        Graph
            A new Graph instance with grid structure and coordinate data in
            nodes.
        """
        if dims is None or len(dims) == 0:
            return cls()
        dims = [list(d) for d in dims]
        if len(dims) == 1:
            if periodic:
                rx_graph = rx.generators.cycle_graph(len(dims[0]))
            else:
                rx_graph = rx.generators.path_graph(len(dims[0]))
            for i, coord_val in enumerate(dims[0]):
                rx_graph[i] = {'coord': (coord_val,)}
        elif len(dims) == 2 and not periodic:
            import itertools
            rows, cols = len(dims[0]), len(dims[1])
            rx_graph = rx.generators.grid_graph(rows, cols)
            # Assumes rustworkx numbers grid nodes row-major, matching
            # itertools.product order.
            for i, coord in enumerate(itertools.product(dims[0], dims[1])):
                rx_graph[i] = {'coord': coord}
        else:
            factor_graphs = []
            for dim_values in dims:
                if periodic:
                    factor_graphs.append(rx.generators.cycle_graph(len(dim_values)))
                else:
                    factor_graphs.append(rx.generators.path_graph(len(dim_values)))
            product_graph = factor_graphs[0]
            # Track, per product node, the per-dimension position it came from.
            id_to_coord_index = {node_id: (node_id,) for node_id in product_graph.node_indices()}
            for next_factor in factor_graphs[1:]:
                product_graph, node_map = rx.graph_cartesian_product(product_graph, next_factor)
                next_id_to_coord_index = {}
                for (left_node, right_node), product_node in node_map.items():
                    next_id_to_coord_index[product_node] = id_to_coord_index[left_node] + (right_node,)
                id_to_coord_index = next_id_to_coord_index
            for node_id, coord_idx_tuple in id_to_coord_index.items():
                coord = tuple(dims[i][coord_idx_tuple[i]] for i in range(len(dims)))
                product_graph[node_id] = {'coord': coord}
            rx_graph = product_graph
        return cls.from_rustworkx(rx_graph)

    @classmethod
    def complete_graph(cls, n_nodes, labels: Optional[List[Any]] = None, directed: bool = False, node_key: str = 'label'):
        """Create a complete graph.

        Parameters
        ----------
        n_nodes : int
            Number of nodes in the complete graph.
        labels : list, optional
            Node labels; defaults to ``range(n_nodes)``.
        directed : bool
            If True, add an edge for every ordered pair ``i != j``.
        node_key : str
            Attribute name under which labels are stored.

        Returns
        -------
        Graph
            A new Graph instance with complete structure.

        Raises
        ------
        ValueError
            If ``n_nodes`` is negative or ``labels`` has the wrong length.
        """
        if n_nodes < 0:
            raise ValueError("n_nodes must be >= 0")
        resolved_labels = cls._resolve_labels(n_nodes, labels)
        edges: List[Tuple[Any, Any]] = []
        if directed:
            for i in range(n_nodes):
                for j in range(n_nodes):
                    if i != j:
                        edges.append((resolved_labels[i], resolved_labels[j]))
        else:
            for i in range(n_nodes):
                for j in range(i + 1, n_nodes):
                    edges.append((resolved_labels[i], resolved_labels[j]))
        return cls.from_nodes_edges(
            nodes=[(label, {node_key: label}) for label in resolved_labels],
            edges=edges,
            directed=directed,
            node_mode='label',
            node_key=node_key,
        )

    @classmethod
    def from_cost_matrix(cls, cost_matrix: np.ndarray, items: List[T]):
        """Create an undirected Graph from a cost matrix.

        Only the strict upper triangle (``i < j``) is read, so self-loops and
        the lower triangle are ignored. An edge is added for each strictly
        positive cost.

        Parameters
        ----------
        cost_matrix : numpy.ndarray
            Square cost matrix, at least ``len(items)`` in each dimension.
        items : List[T]
            Items corresponding to matrix indices; stored as node attribute
            ``'value'``.

        Returns
        -------
        Graph
            Undirected graph whose node ``i`` corresponds to ``items[i]``, with
            edge attribute ``'weight'`` equal to the matrix entry.

        Examples
        --------
        >>> import numpy as np
        >>> matrix = np.array([[0, 1, 2], [1, 0, 3], [2, 3, 0]])
        >>> graph = Graph.from_cost_matrix(matrix, ['A', 'B', 'C'])
        >>> [(u, v, int(d['weight'])) for u, v, d in graph.edges(data=True)]
        [(0, 1, 1), (0, 2, 2), (1, 2, 3)]
        >>> [graph[n]['value'] for n in graph]
        ['A', 'B', 'C']
        """
        graph = cls()
        node_ids = [graph.add_node(value=item) for item in items]
        edge_list = []
        for i in range(len(items)):
            for j in range(i + 1, len(items)):
                cost = cost_matrix[i, j]
                if cost > 0:
                    edge_list.append((node_ids[i], node_ids[j], {'weight': cost}))
        if edge_list:
            graph._graph.add_edges_from(edge_list)
            # The bulk insert bypasses add_edge, so invalidate explicitly.
            graph._invalidate_caches()
        return graph

    def __deepcopy__(self, memo):
        """Deep-copy the graph, including node/edge payloads and ``_meta``.

        rustworkx's ``copy()`` shares payload objects with the source graph,
        so payloads are deep-copied individually. Otherwise mutating an edge
        dict of the copy would also change the original. ``_structure_version``
        is reset to 0; the mutability flags are carried over.
        """
        new_graph = self.__class__.__new__(self.__class__)
        memo[id(self)] = new_graph
        graph_copy = self._graph.copy()  # preserves node/edge indices
        for idx in graph_copy.node_indices():
            graph_copy[idx] = copy.deepcopy(graph_copy[idx], memo)
        for edge_idx in graph_copy.edge_indices():
            graph_copy.update_edge_by_index(
                edge_idx, copy.deepcopy(graph_copy.get_edge_data_by_index(edge_idx), memo)
            )
        new_graph._graph = graph_copy
        new_graph._meta = copy.deepcopy(self._meta, memo)
        new_graph._structure_version = 0
        new_graph._topology_mutable = self._is_topology_mutable()
        new_graph._node_attr_mutable = self._is_node_attr_mutable()
        return new_graph
class GraphNodeView:
    """View of graph nodes that mimics NetworkX NodeView behavior.

    The view keeps a reference to the owning Graph and reads from it on every
    access, so it reflects the graph's current state. Node payloads are
    returned as read-only ``MappingProxyType`` objects; non-dict payloads read
    as an empty mapping.
    """

    def __init__(self, graph: 'Graph'):
        self._graph = graph

    @staticmethod
    def _as_mapping(node_data):
        # Read-only view of a node payload; anything that is not a dict reads as empty.
        return MappingProxyType(node_data if isinstance(node_data, dict) else {})

    def __iter__(self):
        return iter(self._graph._graph.node_indices())

    def __len__(self):
        return self._graph.number_of_nodes()

    def __contains__(self, node):
        return self._graph._graph.has_node(node)

    def __getitem__(self, node):
        """Return the read-only data of ``node``.

        Raises
        ------
        KeyError
            If ``node`` is not in the graph. This matches ``Graph.__getitem__``;
            previously the backend's lookup error leaked through.
        """
        if not self._graph._graph.has_node(node):
            raise KeyError(f"Node {node} not found in graph")
        return self._as_mapping(self._graph._graph.get_node_data(node))

    def __call__(self, data=False):
        """Return nodes with optional data.

        Yields node indices, or ``(index, read-only data)`` pairs when
        ``data`` is true.
        """
        if data:
            for idx in self._graph._graph.node_indices():
                yield (idx, self._as_mapping(self._graph._graph.get_node_data(idx)))
        else:
            for idx in self._graph._graph.node_indices():
                yield idx
class GraphEdgeView:
    """View of graph edges that mimics NetworkX EdgeView behavior.

    The view keeps a reference to the owning Graph and reads from it on every
    access. Edge payloads that are not dicts are reported as ``{}``.
    """

    def __init__(self, graph: 'Graph'):
        self._graph = graph

    def __iter__(self):
        for src_idx, tgt_idx in self._graph._graph.edge_list():
            yield (src_idx, tgt_idx)

    def __len__(self):
        return self._graph.number_of_edges()

    def __call__(self, data=False):
        """Return edges with optional data.

        With ``data=True``, each edge yields its own payload, read from
        ``weighted_edge_list()`` in one pass. Previously each payload was
        looked up with ``get_edge_data(u, v)``, so all parallel edges between
        the same pair reported a single edge's data.
        """
        if data:
            for src_idx, tgt_idx, edge_data in self._graph._graph.weighted_edge_list():
                yield (src_idx, tgt_idx, edge_data if isinstance(edge_data, dict) else {})
        else:
            for src_idx, tgt_idx in self._graph._graph.edge_list():
                yield (src_idx, tgt_idx)

    def __getitem__(self, edge):
        """Get edge data for a given edge ``(u, v)``.

        Returns the payload dict itself, which is mutable, or ``{}`` for a
        non-dict payload. For a missing edge, the backend's
        ``get_edge_data`` exception propagates unchanged.
        """
        u, v = edge
        edge_data = self._graph._graph.get_edge_data(u, v)
        return edge_data if isinstance(edge_data, dict) else {}