Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions complex_tokenization/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

from collections.abc import Callable
from functools import reduce
from functools import lru_cache, reduce

from tokenizers.pre_tokenizers import PreTokenizer

Expand All @@ -37,6 +37,7 @@ def __init__(
merge_size: int = 2,
connected: bool = False,
pretokenizer: PreTokenizer = GPTPretokenizer,
cache_maxsize: int | None = None,
):
if isinstance(units, str):
if units not in UNIT_FUNCTIONS:
Expand All @@ -47,6 +48,7 @@ def __init__(
self.merge_size = merge_size
self.connected = connected
self.pretokenizer = pretokenizer
self.cache_maxsize = cache_maxsize
self.merges: list[tuple[str, ...]] = []

@staticmethod
Expand All @@ -57,8 +59,17 @@ def add_merges(self, merges: list[tuple[str, ...]]):
self.merges.extend(merges)

def _build_graphs(self, texts: list[str]) -> tuple[GraphVertex, ...]:
    """Build one word graph per input text.

    Args:
        texts: Input strings. Each one is passed to ``words`` together with
            this tokenizer's ``connected``, ``units`` and ``pretokenizer``
            settings.

    Returns:
        A tuple holding one graph per text, in input order.
    """
    # Deduplicate identical word graphs within this build: repeated words
    # share one immutable subgraph (and its get_merges memo) instead of N
    # copies. The cache is local to the build, so it's freed before training
    # (no pinning of pre-merge graphs) and can't leak a settings-dependent
    # graph to a later run. cache_maxsize=None is unbounded; 0 disables.
    cache_disabled = self.cache_maxsize is not None and self.cache_maxsize <= 0
    if cache_disabled:
        # A non-positive size can never hold an entry, so skip the wrapper
        # entirely rather than paying for a cache that stores nothing.
        units = self.units
    else:
        units = lru_cache(maxsize=self.cache_maxsize)(self.units)
    return tuple(
        words(text, connected=self.connected, units=units,
              pretokenizer=self.pretokenizer)
        for text in texts
    )
Expand Down Expand Up @@ -98,29 +109,29 @@ def get_merges(self) -> list[tuple[str, ...]]:


class BPETokenizer(Tokenizer):
    """Tokenizer preset that fixes ``merge_size=2`` and ``connected=False``."""

    def __init__(self, **kwargs):
        # Every other option (e.g. units, pretokenizer, cache_maxsize) is
        # forwarded untouched, so new Tokenizer options need no change here.
        super().__init__(merge_size=2, connected=False, **kwargs)


class BNETokenizer(Tokenizer):
    """Tokenizer preset that fixes ``connected=False`` and merges ``n`` units."""

    def __init__(self, n=4, **kwargs):
        # ``n`` becomes the base class's merge_size; every other option
        # (e.g. units, pretokenizer, cache_maxsize) is forwarded untouched.
        super().__init__(merge_size=n, connected=False, **kwargs)


class BoundlessBPETokenizer(Tokenizer):
    """Tokenizer preset that fixes ``merge_size=2`` and ``connected=True``."""

    def __init__(self, **kwargs):
        # Every other option (e.g. units, pretokenizer, cache_maxsize) is
        # forwarded untouched to the base class.
        super().__init__(merge_size=2, connected=True, **kwargs)


class SuperBPETokenizer(Tokenizer):
def __init__(self, disconnected_merges: int | None = None, **kwargs):
    """Create the tokenizer, starting with ``merge_size=2, connected=False``.

    Args:
        disconnected_merges: Stored for ``train``, which uses it as the
            number of merges for its first phase and falls back to
            ``num_merges // 2`` when this value is falsy.
        **kwargs: Remaining options (e.g. units, pretokenizer,
            cache_maxsize), forwarded untouched to the base class.
    """
    super().__init__(merge_size=2, connected=False, **kwargs)
    self._disconnected_merges = disconnected_merges

def train(self, texts: list[str], num_merges: int = 100, progress: bool = False) -> list[tuple[str, ...]]:
disconnected_merges = self._disconnected_merges or num_merges // 2

phase1 = BPETokenizer(units=self.units, pretokenizer=self.pretokenizer)
phase1 = BPETokenizer(units=self.units, pretokenizer=self.pretokenizer, cache_maxsize=self.cache_maxsize)
phase1.train(texts, num_merges=disconnected_merges, progress=progress)

self.connected = True
Expand Down
8 changes: 8 additions & 0 deletions tests/test_tokenizer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def test_get_merges_before_train(self):
tok = Tokenizer()
assert tok.get_merges() == []

def test_cache_maxsize_does_not_change_merges(self):
    """Training must yield the same merges with the cache off or on."""
    # Word-graph dedup is purely a memory/speed optimization, so a run with
    # the cache disabled (0) and a run with a small cache (10) must agree.
    corpus = ["the the the cat the cat sat the"]
    merges_by_cache_size = {
        size: BPETokenizer(cache_maxsize=size).train(corpus, num_merges=10)
        for size in (0, 10)
    }
    assert merges_by_cache_size[0] == merges_by_cache_size[10]

def test_super_bpe_phase1_matches_bpe(self):
texts = ["the teacher teaches the thick thing"] * 3
bpe = BPETokenizer()
Expand Down
Loading