Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions complex_tokenization/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ def node_count(self) -> int:
raise NotImplementedError


@dataclass(frozen=True, slots=True)
class Node(GraphVertex):
value: bytes
class Node(bytes, GraphVertex):
# A Node *is* its bytes, so __hash__/__eq__/__len__ are bytes' C-level
# operations — which is what the trainer's Counter and merge scans hammer.
# bytes wins the MRO for those, but we still want GraphVertex's __str__.
__str__ = GraphVertex.__str__

def __new__(cls, value: bytes):
return super().__new__(cls, value)

def __bytes__(self):
return self.value
return self[:] # a plain bytes copy (not the Node subclass)

def dot(self, level=0) -> Iterable[str]:
yield "\t" * level + f'{self.oid} [label="{dot_escape(str(self))}"];'
Expand All @@ -63,23 +68,10 @@ def merge(self, token: "Node", merge: tuple):
def node_count(self) -> int:
return 1

def __eq__(self, other):
if not isinstance(other, Node):
return False
return self.value == other.value

def __hash__(self):
# hash(bytes) is cached by CPython; the dataclass default hash((value,))
# rebuilds and rehashes a 1-tuple on every call.
return hash(self.value)

def __add__(self, other):
if isinstance(other, NodesSequence):
return NodesSequence(tuple([self]) + other.nodes)
return Node(value=self.value + other.value)

def __len__(self):
return len(self.value)
return NodesSequence((self,) + other.nodes)
return Node(b"".join((self, other))) # both are bytes; join avoids Node.__add__ recursion


@dataclass(frozen=True, slots=True)
Expand Down Expand Up @@ -233,7 +225,7 @@ def merge(self, token: Node, nodes: tuple):
if nodes[0] == self.root:
if len(nodes) == len(self.children) + 1:
if all(nodes[i + 1] == child for i, child in enumerate(self.children)):
return Node(value=token.value)
return token

root = self.root.merge(token, nodes)
children = tuple(child.merge(token, nodes) for child in self.children)
Expand Down
Loading