Formatron v0.4.2
Formatron empowers everyone to control the output format of language models with minimal overhead.
Loading...
Searching...
No Matches
exllamav2.py
Go to the documentation of this file.
1"""
2This module integrates the ExLlamaV2 library by providing convenience utilities.
3"""
4import typing
5from copy import copy, deepcopy
6import kbnf
7import torch
8from exllamav2 import ExLlamaV2Tokenizer, ExLlamaV2
9from exllamav2.generator.base import ExLlamaV2Filter
10from formatron.config import EngineGenerationConfig
11from formatron.formatter import FormatterBase, FormatterBuilder
12from formatron.integrations._utils import get_original_characters
13from functools import lru_cache
14
15
def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine from an ExLlamaV2 tokenizer.

    :param tokenizer: tokenizer whose ``tokenizer_model`` must expose a vocab.
    :return: a ``kbnf.Vocabulary`` built from the processed token strings
        (id -> Token) together with the raw id -> string mapping.
    :raises AssertionError: if the underlying tokenizer model has no vocab.
    """
    assert hasattr(tokenizer.tokenizer_model, "vocab"), (f"tokenizer({tokenizer})"
                                                        f" with tokenizer_model({tokenizer.tokenizer_model})"
                                                        f" does not have vocab attribute!")
    # piece string -> token id; duplicate pieces collapse, keeping the last id.
    vocab = {tokenizer.tokenizer_model.id_to_piece(
        i): i for i in range(tokenizer.tokenizer_model.vocab_size())}
    new_vocab = get_original_characters(vocab)
    # Fix: map each surviving token id to its piece via vocab.items() rather than
    # enumerate(vocab) — positional enumeration diverges from the true token ids
    # whenever two ids share the same piece string and collapse into one dict key.
    return kbnf.Vocabulary({v: kbnf.Token(k) for k, v in new_vocab.items()},
                           {v: k for k, v in vocab.items()})
28
29
def create_formatter_filter(model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer,
                            formatter_builder: FormatterBuilder,
                            engine_config: EngineGenerationConfig = None) -> ExLlamaV2Filter:
    """
    Create a formatter filter for the ExLlamaV2 engine.

    :param model: the ExLlamaV2 model the filter will be attached to.
    :param tokenizer: the ExLlamaV2 tokenizer used to build the vocabulary
        and to decode token sequences.
    :param formatter_builder: builder that produces the formatter driving the filter.
    :param engine_config: optional generation config; defaults are applied
        by ``FormatterFilter`` when ``None``.
    :return: an ``ExLlamaV2Filter`` that constrains generation via the formatter.
    """
    vocabulary = create_engine_vocabulary(tokenizer)

    def decode_tokens(token_ids):
        # The tokenizer expects a tensor of token ids, not a plain list.
        return tokenizer.decode(torch.tensor(token_ids))

    built_formatter = formatter_builder.build(vocabulary, decode_tokens)
    return FormatterFilter(model, tokenizer, built_formatter, engine_config)
40
41
class FormatterFilter(ExLlamaV2Filter):
    """
    ExLlamaV2Filter that uses a formatter to mask logits.
    """

    def __init__(self, model, tokenizer, formatter: FormatterBase,
                 config: EngineGenerationConfig = None):
        """
        Initialize the filter.

        :param model: the ExLlamaV2 model.
        :param tokenizer: the ExLlamaV2 tokenizer.
        :param formatter: formatter that decides which tokens are allowed.
        :param config: generation config; a default one is created when ``None``.
        """
        super().__init__(model, tokenizer)
        self._formatter = formatter
        if config is None:
            config = EngineGenerationConfig()
        self._config = config
        # Reusable set of allowed token ids, refilled on each next_set() call.
        self._pass_tokens = set()

    def is_completed(self) -> bool:
        """
        Check if the formatter is completed.
        """
        return self._formatter.is_completed()

    def clone(self, c=None) -> "FormatterFilter":
        """
        Clone this filter for a forked generation sequence.

        :param c: optional pre-allocated instance to populate; a fresh one is
            created (bypassing ``__init__``) when ``None``.
        """
        if c is None:
            c = FormatterFilter.__new__(FormatterFilter)
        c.model = self.model
        c.tokenizer = self.tokenizer
        c.sequence_str = self.sequence_str
        # formatter does not have mutable public state anyway
        c._formatter = copy(self._formatter)
        c._config = deepcopy(self._config)
        # NOTE(review): the pass-token set is intentionally shared, not copied —
        # it is recomputed from scratch on every next_set() call anyway.
        c._pass_tokens = self._pass_tokens
        return c

    def begin(self, prefix_str: str) -> None:
        """
        Prepare the formatter at the start of generation, optionally resetting
        it and feeding it the prompt, as dictated by the engine config.
        """
        if self._config.reset_at_beginning:
            self._formatter.reset()
        if self._config.read_prompt:
            prompt = prefix_str.encode("utf-8")
            self._formatter.accept_bytes(prompt)

    def reset(self) -> None:
        """Reset the underlying formatter to its initial state."""
        self._formatter.reset()

    def feed(self, token: int):
        """
        Feed one generated token to the formatter; a no-op once the
        formatter has completed.
        """
        if self._formatter.is_completed():
            return None
        self._formatter.accept_token(token)

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    # Old version for compatibility
    def next_set(self) -> typing.Tuple[typing.Set[int], typing.Set[int]]:
        """
        Return (allowed token ids, preferred token ids) as sets.
        Once the formatter is completed, only EOS is allowed.
        """
        if self._formatter.is_completed():
            return {self.tokenizer.eos_token_id}, {self.tokenizer.eos_token_id}
        self._formatter.compute_allowed_tokens()
        self._pass_tokens.clear()
        self._pass_tokens.update(self._formatter.get_allowed_tokens_since_last_computation())
        return self._pass_tokens, set()

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    def next(self) -> typing.Tuple[typing.Sequence[int], typing.Sequence[int]]:
        """
        Return (allowed token ids, preferred token ids) as sequences.
        Falls back to the set-returning next_set() on old exllamav2 versions.
        """
        # Kludge to maintain compatibility with exllamav2 <= 0.2.0
        if not hasattr(self, "allow_return_type_list"):
            return self.next_set()
        if self._formatter.is_completed():
            return [self.tokenizer.eos_token_id], [self.tokenizer.eos_token_id]
        self._formatter.compute_allowed_tokens()
        return self._formatter.get_allowed_tokens_since_last_computation(), []

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    def use_background_worker(self) -> bool:
        """Opt in to running this filter in exllamav2's background worker."""
        return True

    @property
    def formatter_captures(self) -> dict[str, typing.Any]:
        """Captured values extracted by the formatter during generation."""
        return self._formatter.captures
Configuration for how a KBNF engine should be used in text generation.
Definition config.py:14
ExLlamaV2Filter that uses a formatter to mask logits.
Definition exllamav2.py:46
"FormatterFilter" clone(self, c=None)
Definition exllamav2.py:63
__init__(self, model, tokenizer, FormatterBase formatter, EngineGenerationConfig config=None)
Definition exllamav2.py:49
typing.Tuple[typing.Sequence[int], typing.Sequence[int]] next(self)
Definition exllamav2.py:101
dict[str, typing.Any] formatter_captures(self)
Definition exllamav2.py:122
typing.Tuple[typing.Set[int], typing.Set[int]] next_set(self)
Definition exllamav2.py:92
bool is_completed(self)
Check if the formatter is completed.
Definition exllamav2.py:60
Configuration classes for Formatron.
Definition config.py:1
This module contains the Formatter class and its related classes.
Definition formatter.py:1
ExLlamaV2Filter create_formatter_filter(ExLlamaV2 model, ExLlamaV2Tokenizer tokenizer, FormatterBuilder formatter_builder, EngineGenerationConfig engine_config=None)
Create a formatter filter for the ExLlamaV2 engine.
Definition exllamav2.py:36
kbnf.Vocabulary create_engine_vocabulary(ExLlamaV2Tokenizer tokenizer)
Create a vocabulary for the KBNF engine.
Definition exllamav2.py:20