diff --git a/complex_tokenization/trainer.py b/complex_tokenization/trainer.py index db7a667..1b19ca2 100644 --- a/complex_tokenization/trainer.py +++ b/complex_tokenization/trainer.py @@ -26,27 +26,21 @@ def __init__(self, graph: GraphVertex | None = None, graphs: tuple[GraphVertex, self.merges = [] def train(self, num_merges: int = 100, draw=False, verbose=False, progress=False): - frames = [] - remaining = range(len(self.merges), num_merges) if progress: from tqdm import tqdm remaining = tqdm(remaining, desc="Training", initial=len(self.merges), total=num_merges) + frames = [] for _ in remaining: - if draw: - dot_content = "\n".join(self.graph.dot()) - image = draw_dot_content(dot_content) - frames.append(image) + frames.append(draw_dot_content("\n".join(self.graph.dot()))) counts = Counter(self.graph.get_merges()) - if not counts: break nodes = max(counts.items(), key=_merge_score)[0] - if verbose: print("Merging", nodes, "count=", counts[nodes]) token = reduce(lambda x, y: x + y, nodes) @@ -55,8 +49,7 @@ def train(self, num_merges: int = 100, draw=False, verbose=False, progress=False self.merges.append((token, nodes)) if draw: - gif = create_gif(frames, save="example.gif") - gif.show() + create_gif(frames, save="example.gif").show() def get_merges(self): return [tuple(str(node) for node in nodes) for _, nodes in self.merges] @@ -70,12 +63,5 @@ def get_merges(self): utf8("头"), )), )) - # example_sentence = "the teacher teaches the thick." - example_sentence = "test test" - # example_graph = sentence_to_graph(example_sentence) - - # other_graph = words(example_sentence) - # example_graph = NodesSequence((example_graph, utf8(" "), other_graph)) - trainer = Trainer(graph=example_graph) trainer.train(num_merges=10, draw=True)