From 25b3773726649dd2708f92328a6c0db0ade0547c Mon Sep 17 00:00:00 2001 From: AmitMY Date: Sat, 27 Jun 2026 10:28:49 +0200 Subject: [PATCH] perf: incremental candidate counting for disconnected training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The training loop rebuilt Counter(graph.get_merges()) from scratch every step — recounting the whole forest, even the subgraphs the chosen merge never touched. For a disconnected forest (BPE/BNE, where each merge hits only the few words containing the pair) almost all of it is wasted. _train_incremental keeps each subgraph's candidate counts plus a running global total, and after a merge updates only the subgraphs that contained it (located via an index). total is summed in subgraph order, so picking the first max-score candidate reproduces max(Counter(...), key=score) exactly, including tie-breaks — output is byte-identical. Gated to the disconnected case (Tokenizer passes incremental=not connected); connected/draw/verbose keep the plain recount loop, so Boundless/SuperBPE are unchanged. The two loops share _steps() for the progress scaffolding; their cores differ (Counter-rebuild vs total+index update). Trade-off: the persistent count state raises peak memory. Stacked on the bytes-Node change: BPE ~6.6x, BNE ~7.4x over the original baseline. 138 tests pass; digests identical. Co-Authored-By: Claude Opus 4.8 (1M context) --- complex_tokenization/tokenizer.py | 2 +- complex_tokenization/trainer.py | 60 +++++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/complex_tokenization/tokenizer.py b/complex_tokenization/tokenizer.py index c6ba315..c75a4a8 100644 --- a/complex_tokenization/tokenizer.py +++ b/complex_tokenization/tokenizer.py @@ -100,7 +100,7 @@ def train_on_trainer( GraphSettings.ONLY_MINIMAL_MERGES = True GraphSettings.MAX_MERGE_SIZE = self.merge_size - trainer.train(num_merges=num_merges, progress=progress) + trainer.train(num_merges=num_merges, progress=progress, incremental=not self.connected) self.merges = trainer.get_merges() return trainer, self.merges diff --git a/complex_tokenization/trainer.py b/complex_tokenization/trainer.py index 1b19ca2..39cb544 100644 --- a/complex_tokenization/trainer.py +++ b/complex_tokenization/trainer.py @@ -1,4 +1,4 @@ -from collections import Counter +from collections import Counter, defaultdict from functools import reduce from complex_tokenization.draw import create_gif, draw_dot_content @@ -12,6 +12,23 @@ def _merge_score(item): return (len(nodes) - 1) * count +def _index_add(total, index, i, counts): + for merge, count in counts.items(): + total[merge] += count + index[merge].add(i) + + +def _index_remove(total, index, i, counts): + for merge, count in counts.items(): + total[merge] -= count + if total[merge] == 0: + del total[merge] + others = index[merge] + others.discard(i) + if not others: + del index[merge] + + class Trainer: def __init__(self, graph: GraphVertex | None = None, graphs: tuple[GraphVertex, ...] | None = None): if graphs is None and graph is None: @@ -25,14 +42,21 @@ def __init__(self, graph: GraphVertex | None = None, graphs: tuple[GraphVertex, self.graph = graph self.merges = [] - def train(self, num_merges: int = 100, draw=False, verbose=False, progress=False): + def _steps(self, num_merges: int, progress: bool): remaining = range(len(self.merges), num_merges) if progress: from tqdm import tqdm remaining = tqdm(remaining, desc="Training", initial=len(self.merges), total=num_merges) + return remaining + + def train(self, num_merges: int = 100, draw=False, verbose=False, progress=False, incremental=False): + # Incremental counting only helps a forest where each merge touches few + # subgraphs (disconnected, word-level). draw/verbose use the plain loop. + if incremental and not draw and not verbose and isinstance(self.graph, UnconnectedGraphs): + return self._train_incremental(num_merges, progress) frames = [] - for _ in remaining: + for _ in self._steps(num_merges, progress): if draw: frames.append(draw_dot_content("\n".join(self.graph.dot()))) @@ -51,6 +75,36 @@ def train(self, num_merges: int = 100, draw=False, verbose=False, progress=False if draw: create_gif(frames, save="example.gif").show() + def _train_incremental(self, num_merges: int, progress: bool = False): + # Rebuilding Counter(graph.get_merges()) every step recounts the whole + # forest. Instead, keep each subgraph's candidate counts plus a running + # global total, and after a merge update only the subgraphs that + # contained it (found via `index`). total is summed in subgraph order, so + # picking the first max-score candidate matches max(Counter(...)) exactly. + components = list(self.graph.subgraphs) + comp_counts = [Counter(c.get_merges()) for c in components] + total: dict[tuple, int] = defaultdict(int) + index: dict[tuple, set[int]] = defaultdict(set) + for i, counts in enumerate(comp_counts): + _index_add(total, index, i, counts) + + for _ in self._steps(num_merges, progress): + if not total: + break + best = max((len(m) - 1) * c for m, c in total.items()) + nodes = next(m for cc in comp_counts for m in cc if (len(m) - 1) * total[m] == best) + token = reduce(lambda x, y: x + y, nodes) + + for i in list(index[nodes]): + _index_remove(total, index, i, comp_counts[i]) + components[i] = components[i].merge(token, nodes) + comp_counts[i] = Counter(components[i].get_merges()) + _index_add(total, index, i, comp_counts[i]) + + self.merges.append((token, nodes)) + + self.graph = UnconnectedGraphs(tuple(components)) + def get_merges(self): return [tuple(str(node) for node in nodes) for _, nodes in self.merges]