# Source code for klotho.topos.collections._pattern

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto
from itertools import cycle as itertools_cycle
from math import lcm
from typing import Any, List, Optional, TypeVar, Union

import numpy as np

# Generic element-type placeholder.
# NOTE(review): not referenced by the code visible in this section — confirm
# whether anything else in the module still needs it.
T = TypeVar("T")


def _is_pattern(obj) -> bool:
    """Return True when *obj* is a ``Pattern`` instance.

    ``Pattern`` is imported at call time rather than at module import time,
    presumably to avoid a circular import with
    ``klotho.topos.collections.sequences``.
    """
    from klotho.topos.collections.sequences import Pattern as _Pattern

    is_pattern = isinstance(obj, _Pattern)
    return is_pattern


class SlotKind(Enum):
    """Category of a single slot in a compiled Pattern tree."""

    # A constant value that is returned unchanged every time the slot is visited.
    LEAF = auto()
    # A nested list / non-scalar ndarray whose children are visited round-robin.
    STRUCTURAL = auto()
    # Another Pattern instance; each visit pulls its next value.
    DELEGATE_PATTERN = auto()
    # A ``Cyclic`` sequence (``itertools.cycle`` is classified here but rejected).
    DELEGATE_CYCLIC = auto()


@dataclass(frozen=True)
class NodeSpec:
    """Immutable, compiled description of one node of a Pattern tree.

    Being a frozen dataclass, instances compare by field values and are
    hashable whenever all of their field values are hashable.
    """

    # What kind of slot this node represents.
    kind: SlotKind
    # LEAF: the emitted value. DELEGATE_PATTERN: ``id()`` of the delegate Pattern.
    value: Any = field(default=None)
    # STRUCTURAL: compiled specs of the child slots, in visiting order.
    children: tuple[NodeSpec, ...] = field(default=())
    # DELEGATE_CYCLIC: the elements of one cycle.
    sequence: tuple[Any, ...] = field(default=())
    # Number of advances after which this node's output repeats.
    period: int = field(default=1)


class Cyclic:
    """Finite repeating sequence for use as a Pattern delegate.

    Wraps a non-empty, finite sequence. Inside a Pattern, each visit to a
    ``Cyclic`` slot yields the next element, wrapping back to the first one
    after the last.

    Iterating a ``Cyclic`` directly (``list(c)``, ``for x in c``) yields one
    period only. It does not repeat forever, which keeps ``len(c)`` and
    iteration consistent and lets a ``Cyclic`` be re-wrapped
    (``Cyclic(other_cyclic)``).
    """

    __slots__ = ('_sequence',)

    def __init__(self, sequence):
        """Store *sequence* as an immutable tuple.

        Raises:
            ValueError: if *sequence* has no elements.
        """
        self._sequence = tuple(sequence)
        if not self._sequence:
            raise ValueError("Cyclic sequence cannot be empty")

    @property
    def sequence(self):
        """The wrapped elements (one full period) as a tuple."""
        return self._sequence

    def __len__(self):
        """Return the period length, i.e. the number of wrapped elements."""
        return len(self._sequence)

    def __iter__(self):
        """Iterate once over the wrapped elements (finite; no wrap-around)."""
        return iter(self._sequence)

    def __repr__(self):
        return f"Cyclic({list(self._sequence)!r})"
class _CyclicRuntime: __slots__ = ('_items', '_index') def __init__(self, items: tuple[Any, ...]): self._items = items self._index = 0 def __next__(self): value = self._items[self._index] self._index = (self._index + 1) % len(self._items) return value def reset(self): self._index = 0 def snapshot(self) -> int: return self._index def restore(self, index: int): self._index = index class _LeafRuntime: __slots__ = ('value',) def __init__(self, value): self.value = value class _StructuralRuntime: __slots__ = ('children', 'index') def __init__(self, children): self.children = children self.index = 0 def snapshot(self): return (self.index, [child_snapshot(c) for c in self.children]) def restore(self, state): self.index, child_states = state for child, child_state in zip(self.children, child_states): child_restore(child, child_state) class _DelegateRuntime: __slots__ = ('target',) def __init__(self, target): self.target = target def snapshot(self): target = self.target if _is_pattern(target): return ('pattern', target._snapshot()) return ('cyclic', target.snapshot()) def restore(self, state): kind, payload = state if kind == 'pattern': self.target._restore(payload) else: self.target.restore(payload) def child_snapshot(node): if isinstance(node, _LeafRuntime): return ('leaf',) if isinstance(node, _DelegateRuntime): return ('delegate', node.snapshot()) return ('struct', node.snapshot()) def child_restore(node, state): kind = state[0] if kind == 'leaf': return if kind == 'delegate': node.restore(state[1]) else: node.restore(state[1]) def _leaf_label(value: Any) -> str: text = repr(value) return text if len(text) <= 28 else text[:25] + '...' 
def _classify(item) -> tuple[SlotKind, Any]:
    """Classify one raw Pattern item into a ``SlotKind`` and its payload.

    Checks run in priority order: Pattern delegates, cyclic delegates
    (``Cyclic`` or ``itertools.cycle``), lists, numpy arrays, and finally
    everything else as a leaf. A 0-d ndarray is unwrapped with
    ``ndarray.item()`` and becomes a leaf. Tuples and other non-list
    sequences are leaves, not structural containers.
    """
    if _is_pattern(item):
        return SlotKind.DELEGATE_PATTERN, item
    if isinstance(item, (Cyclic, itertools_cycle)):
        return SlotKind.DELEGATE_CYCLIC, item
    if isinstance(item, list):
        return SlotKind.STRUCTURAL, item
    if isinstance(item, np.ndarray):
        if item.ndim == 0:
            return SlotKind.LEAF, item.item()
        return SlotKind.STRUCTURAL, item
    return SlotKind.LEAF, item


def _normalize_cyclic(item) -> tuple[Any, ...]:
    """Return the element tuple of a ``Cyclic`` delegate.

    Raises:
        TypeError: for ``itertools.cycle`` (its elements cannot be read back
            without consuming it) or any other type.
    """
    if isinstance(item, Cyclic):
        return item.sequence
    if isinstance(item, itertools_cycle):
        raise TypeError(
            "itertools.cycle is not supported in Pattern; use Cyclic([...]) "
            "or pass a list directly."
        )
    raise TypeError(f"Expected Cyclic or itertools.cycle, got {type(item)!r}")


def _compile_item(item) -> NodeSpec:
    """Recursively compile a raw Pattern item into an immutable ``NodeSpec``.

    Raises:
        ValueError: if a list/array container is empty.
        TypeError: if an ``itertools.cycle`` is encountered.
    """
    kind, payload = _classify(item)
    if kind is SlotKind.LEAF:
        return NodeSpec(SlotKind.LEAF, value=payload, period=1)
    if kind is SlotKind.DELEGATE_PATTERN:
        delegate = payload
        # The spec records only the delegate's identity; the live object is
        # re-attached by _build_runtime from the source item.
        return NodeSpec(
            SlotKind.DELEGATE_PATTERN,
            value=id(delegate),
            period=delegate.length,
        )
    if kind is SlotKind.DELEGATE_CYCLIC:
        sequence = _normalize_cyclic(payload)
        return NodeSpec(
            SlotKind.DELEGATE_CYCLIC,
            sequence=sequence,
            period=len(sequence),
        )
    children_payload = list(payload)
    if not children_payload:
        raise ValueError("Pattern cannot contain empty structural container")
    children = tuple(_compile_item(child) for child in children_payload)
    n = len(children)
    # Each child is visited once every n parent steps, so child i repeats
    # after n * period_i parent steps. The node as a whole repeats after the
    # lcm of those (assuming delegates are advanced only through this slot).
    period = lcm(*[n * child.period for child in children])
    return NodeSpec(SlotKind.STRUCTURAL, children=children, period=period)


def _build_runtime(spec: NodeSpec, source_item):
    """Build the mutable runtime tree for *source_item*.

    *spec* is walked in lockstep with *source_item* and is assumed to be
    ``_compile_item(source_item)``. Pattern delegates are attached by
    reference, so their state is shared with every other holder.
    """
    kind, payload = _classify(source_item)
    if kind is SlotKind.LEAF:
        return _LeafRuntime(spec.value)
    if kind is SlotKind.DELEGATE_PATTERN:
        return _DelegateRuntime(payload)
    if kind is SlotKind.DELEGATE_CYCLIC:
        return _DelegateRuntime(_CyclicRuntime(_normalize_cyclic(payload)))
    return _StructuralRuntime(
        [_build_runtime(child_spec, child_source)
         for child_spec, child_source in zip(spec.children, payload)]
    )


def _reset_runtime(node):
    """Rewind *node* and, recursively, all of its children to the start.

    Leaves are stateless. Delegates are rewound through ``reset()``, which
    both Pattern targets and ``_CyclicRuntime`` provide. A shared Pattern
    delegate is therefore reset for every holder.
    """
    if isinstance(node, _StructuralRuntime):
        node.index = 0
        for child in node.children:
            _reset_runtime(child)
    elif isinstance(node, _DelegateRuntime):
        node.target.reset()


def _advance_runtime(node):
    """Return the next value produced by *node*, advancing its state.

    A structural node forwards to the child at ``index``, then moves
    ``index`` to the following child (wrapping around).
    """
    if isinstance(node, _LeafRuntime):
        return node.value
    if isinstance(node, _DelegateRuntime):
        return next(node.target)
    child = node.children[node.index]
    value = _advance_runtime(child)
    node.index = (node.index + 1) % len(node.children)
    return value


def pattern_to_graph(
    pattern,
    *,
    expand_delegates: bool = True,
):
    """Build a NetworkX DiGraph describing *pattern* structure for plotting.

    NetworkX is imported lazily, so it is only required when this is called.

    Parameters
    ----------
    pattern
        A Pattern; its ``pattern``, ``spec``, ``length`` and
        ``materialize_period()`` members are read.
    expand_delegates
        When True, each nested Pattern delegate's own structure is walked and
        linked under its delegate node with an ``edge_kind='delegate'`` edge.
        Every occurrence is expanded separately, so a Pattern used in several
        slots appears several times.

    Returns
    -------
    networkx.DiGraph
        Integer node ids are assigned in pre-order. Every node has ``kind``
        (``'leaf'``, ``'structural'``, ``'delegate_pattern'`` or
        ``'delegate_cyclic'``), ``label`` and ``period``. Structural nodes
        also carry ``n_slots`` and ``child_ids``, Pattern delegates carry
        ``delegate_id``, and cyclic delegates carry ``sequence``. Edges
        carry ``slot`` and ``edge_kind``. Graph attributes: ``root``,
        ``pattern_length``, ``period_sequence`` and ``delegate_targets``
        (delegate node id -> delegate Pattern object).
    """
    import networkx as nx

    graph = nx.DiGraph()
    delegate_targets: dict[int, Any] = {}
    counter = 0

    def _next_id() -> int:
        nonlocal counter
        node_id = counter
        counter += 1
        return node_id

    def _walk(source, spec: NodeSpec) -> int:
        # Allocate this node's id before recursing so ids follow pre-order.
        node_id = _next_id()
        if spec.kind is SlotKind.LEAF:
            graph.add_node(node_id, kind='leaf', label=_leaf_label(spec.value), period=1)
        elif spec.kind is SlotKind.STRUCTURAL:
            # list(...) yields list items, or rows for an ndarray.
            child_sources = list(source)
            graph.add_node(
                node_id,
                kind='structural',
                label=f'cycle x{len(spec.children)}',
                period=spec.period,
                n_slots=len(spec.children),
                child_ids=[],
            )
            for idx, (child_spec, child_source) in enumerate(zip(spec.children, child_sources)):
                child_id = _walk(child_source, child_spec)
                graph.nodes[node_id]['child_ids'].append(child_id)
                graph.add_edge(node_id, child_id, slot=idx, edge_kind='slot')
        elif spec.kind is SlotKind.DELEGATE_PATTERN:
            delegate = source
            delegate_targets[node_id] = delegate
            graph.add_node(
                node_id,
                kind='delegate_pattern',
                label=f'Pattern (len={spec.period})',
                period=spec.period,
                delegate_id=id(delegate),
            )
            if expand_delegates:
                sub_id = _walk(delegate.pattern, delegate.spec)
                graph.add_edge(node_id, sub_id, slot=0, edge_kind='delegate')
        elif spec.kind is SlotKind.DELEGATE_CYCLIC:
            graph.add_node(
                node_id,
                kind='delegate_cyclic',
                label=f'Cyclic {list(spec.sequence)!r}',
                period=spec.period,
                sequence=spec.sequence,
            )
        return node_id

    root = _walk(pattern.pattern, pattern.spec)
    graph.graph['root'] = root
    graph.graph['pattern_length'] = pattern.length
    graph.graph['period_sequence'] = list(pattern.materialize_period())
    graph.graph['delegate_targets'] = delegate_targets
    return graph