diff --git a/complex_tokenization/graph.py b/complex_tokenization/graph.py index 27e631e..efd63d1 100644 --- a/complex_tokenization/graph.py +++ b/complex_tokenization/graph.py @@ -47,12 +47,17 @@ def node_count(self) -> int: raise NotImplementedError -@dataclass(frozen=True, slots=True) -class Node(GraphVertex): - value: bytes +class Node(bytes, GraphVertex): + # A Node *is* its bytes, so __hash__/__eq__/__len__ are bytes' C-level + # operations — which is what the trainer's Counter and merge scans hammer. + # bytes wins the MRO for those, but we still want GraphVertex's __str__. + __str__ = GraphVertex.__str__ + + def __new__(cls, value: bytes): + return super().__new__(cls, value) def __bytes__(self): - return self.value + return self[:] # a plain bytes copy (not the Node subclass) def dot(self, level=0) -> Iterable[str]: yield "\t" * level + f'{self.oid} [label="{dot_escape(str(self))}"];' @@ -63,23 +68,10 @@ def merge(self, token: "Node", merge: tuple): def node_count(self) -> int: return 1 - def __eq__(self, other): - if not isinstance(other, Node): - return False - return self.value == other.value - - def __hash__(self): - # hash(bytes) is cached by CPython; the dataclass default hash((value,)) - # rebuilds and rehashes a 1-tuple on every call. - return hash(self.value) - def __add__(self, other): if isinstance(other, NodesSequence): - return NodesSequence(tuple([self]) + other.nodes) - return Node(value=self.value + other.value) - - def __len__(self): - return len(self.value) + return NodesSequence((self,) + other.nodes) + return Node(b"".join((self, other))) # both are bytes; join avoids Node.__add__ recursion @dataclass(frozen=True, slots=True) @@ -233,7 +225,7 @@ def merge(self, token: Node, nodes: tuple): if nodes[0] == self.root: if len(nodes) == len(self.children) + 1: if all(nodes[i + 1] == child for i, child in enumerate(self.children)): - return Node(value=token.value) + return token root = self.root.merge(token, nodes) children = tuple(child.merge(token, nodes) for child in self.children)