From 0e651f23097d8c213aecb8b416e445e5b00d32f1 Mon Sep 17 00:00:00 2001 From: AmitMY Date: Sat, 27 Jun 2026 18:07:05 +0200 Subject: [PATCH] perf: make Node subclass bytes for C-level hash and equality MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The training loop's cost is hashing and comparing candidate node-tuples: Counter(graph.get_merges()) hashes them every step, and merge scans compare them. With Node a frozen dataclass, __hash__/__eq__ were Python calls. Making Node a bytes subclass (a Node *is* its bytes) gives those operations C-level implementations — and tuples of bytes hash ~4x faster than tuples of dataclass instances. This speeds the recount path too, so unlike the disconnected-only incremental work it also helps Boundless/SuperBPE. It's net simpler: __eq__/__hash__/__len__ are inherited from bytes, and Node no longer needs a `value` field/property at all (a Node *is* its bytes) — Tree.merge just returns the token. bytes wins the MRO for the dunders we want; we keep GraphVertex.__str__ explicitly. Memory is unchanged (nodes are interned). ~1.7x BPE, ~2.3x BNE, ~1.8x Boundless; memory unchanged; output identical (digests unchanged); 138 tests pass. Note: drops the public Node.value attribute — a Node is now its own bytes, so use the node directly or bytes(node). Co-Authored-By: Claude Opus 4.8 (1M context) --- complex_tokenization/graph.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/complex_tokenization/graph.py b/complex_tokenization/graph.py index 27e631e..efd63d1 100644 --- a/complex_tokenization/graph.py +++ b/complex_tokenization/graph.py @@ -47,12 +47,17 @@ def node_count(self) -> int: raise NotImplementedError -@dataclass(frozen=True, slots=True) -class Node(GraphVertex): - value: bytes +class Node(bytes, GraphVertex): + # A Node *is* its bytes, so __hash__/__eq__/__len__ are bytes' C-level + # operations — which is what the trainer's Counter and merge scans hammer. + # bytes wins the MRO for those, but we still want GraphVertex's __str__. + __str__ = GraphVertex.__str__ + + def __new__(cls, value: bytes): + return super().__new__(cls, value) def __bytes__(self): - return self.value + return self[:] # a plain bytes copy (not the Node subclass) def dot(self, level=0) -> Iterable[str]: yield "\t" * level + f'{self.oid} [label="{dot_escape(str(self))}"];' @@ -63,23 +68,10 @@ def merge(self, token: "Node", merge: tuple): def node_count(self) -> int: return 1 - def __eq__(self, other): - if not isinstance(other, Node): - return False - return self.value == other.value - - def __hash__(self): - # hash(bytes) is cached by CPython; the dataclass default hash((value,)) - # rebuilds and rehashes a 1-tuple on every call. - return hash(self.value) - def __add__(self, other): if isinstance(other, NodesSequence): - return NodesSequence(tuple([self]) + other.nodes) - return Node(value=self.value + other.value) - - def __len__(self): - return len(self.value) + return NodesSequence((self,) + other.nodes) + return Node(b"".join((self, other))) # both are bytes; join avoids Node.__add__ recursion @dataclass(frozen=True, slots=True) @@ -233,7 +225,7 @@ def merge(self, token: Node, nodes: tuple): if nodes[0] == self.root: if len(nodes) == len(self.children) + 1: if all(nodes[i + 1] == child for i, child in enumerate(self.children)): - return Node(value=token.value) + return token root = self.root.merge(token, nodes) children = tuple(child.merge(token, nodes) for child in self.children)