Formatron v0.4.9
Formatron empowers everyone to control the output format of language models with minimal overhead.
Loading...
Searching...
No Matches
exllamav2.py
Go to the documentation of this file.
1"""
2This module integrates the ExLlamaV2 library by providing convenience utilities.
3"""
4import typing
5from copy import copy, deepcopy
6import kbnf
7import torch
8from exllamav2 import ExLlamaV2Tokenizer, ExLlamaV2
9from exllamav2.generator.base import ExLlamaV2Filter
10from formatron.config import EngineGenerationConfig
11from formatron.formatter import FormatterBase, FormatterBuilder
12from formatron.integrations.utils import get_original_characters
13
14
15__all__ = ["create_engine_vocabulary", "create_formatter_filter", "FormatterFilter"]
def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer,
                             vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine from an ExLlamaV2 tokenizer.

    Args:
        tokenizer: The ExLlamaV2 tokenizer.
        vocab_processors: List of callables with signature
            (token_to_char: typing.Dict[bytes, bytes]) -> None.
            Callables can be used to "unmangle" encoded characters to original
            characters. If None, processors will be auto-detected.
    Returns:
        A kbnf.Vocabulary mapping token ids to (unmangled) token strings.
    Raises:
        AssertionError: if the tokenizer's underlying model exposes no vocab.
    """
    # The KBNF engine needs the raw piece table; fail fast if this tokenizer
    # model does not expose one.
    assert hasattr(tokenizer.tokenizer_model, "vocab"), (f"tokenizer({tokenizer})"
                                                         f" with tokenizer_model({tokenizer.tokenizer_model})"
                                                         f" does not have vocab attribute!")
    vocab = {tokenizer.tokenizer_model.id_to_piece(
        i): i for i in range(tokenizer.tokenizer_model.vocab_size())}
    # Undo tokenizer-specific character mangling so the engine matches real text.
    new_vocab = get_original_characters(vocab, vocab_processors)
    return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
                           {v: k for k, v in vocab.items()})
33
34
def create_formatter_filter(model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer,
                            formatter_builder: FormatterBuilder,
                            engine_config: EngineGenerationConfig = None,
                            vocab_processors: typing.Optional[list[typing.Callable]] = None) -> ExLlamaV2Filter:
    """
    Create a formatter filter for the ExLlamaV2 engine.

    Args:
        model: The ExLlamaV2 model.
        tokenizer: The ExLlamaV2 tokenizer.
        formatter_builder: The formatter builder.
        engine_config: The engine generation configuration. If None, the
            FormatterFilter falls back to a default EngineGenerationConfig.
        vocab_processors: List of callables with signature
            (token_to_char: typing.Dict[bytes, bytes]) -> None.
            Callables can be used to "unmangle" encoded characters to original
            characters. If None, processors will be auto-detected.
    Returns:
        An ExLlamaV2Filter that constrains generation to the built format.
    """
    vocab = create_engine_vocabulary(tokenizer, vocab_processors)
    # The decode callback turns engine token ids back into text via the tokenizer.
    f = formatter_builder.build(
        vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens)))
    return FormatterFilter(model, tokenizer, f, engine_config)
53
54
class FormatterFilter(ExLlamaV2Filter):
    """
    ExLlamaV2Filter that uses a formatter to mask logits.

    NOTE(review): several single-line bodies of this class were dropped by the
    documentation extractor (the original file's lines 74, 96, 99, 106, 136).
    They are reconstructed below from the surrounding evidence — the duplicated
    `is_completed` guard visible in `next()`, the otherwise-unreachable
    `return None` in `feed()`, and the definition index — and are marked
    individually. Confirm against the upstream formatron repository.
    """

    def __init__(self, model, tokenizer, formatter: FormatterBase,
                 config: EngineGenerationConfig | None = None):
        super().__init__(model, tokenizer)
        self._formatter = formatter
        # Fall back to default generation config when none is supplied.
        if config is None:
            config = EngineGenerationConfig()
        self._config = config
        # Reusable set returned by next_set(); cleared and refilled each step.
        self._pass_tokens = set()
        # Lazily-built logits tensor that allows only EOS (see mask_logits).
        self.eos_logits = None

    def is_completed(self) -> bool:
        """
        Check if the formatter is completed.
        """
        # NOTE(review): body reconstructed — dropped by the doc extractor.
        return self._formatter.is_completed()

    def clone(self, c=None) -> "FormatterFilter":
        if c is None:
            c = FormatterFilter.__new__(FormatterFilter)
        c.model = self.model
        c.tokenizer = self.tokenizer
        c.sequence_str = self.sequence_str
        # formatter does not have mutable public state anyway
        c._formatter = copy(self._formatter)
        c._config = deepcopy(self._config)
        c._pass_tokens = self._pass_tokens
        return c

    def begin(self, prefix_str: str) -> None:
        if self._config.reset_at_beginning:
            self._formatter.reset()
        if self._config.read_prompt:
            prompt = prefix_str.encode("utf-8")
            self._formatter.accept_bytes(prompt)

    def reset(self) -> None:
        # NOTE(review): body reconstructed — dropped by the doc extractor.
        self._formatter.reset()

    def feed(self, token: int):
        # NOTE(review): guard reconstructed — without it the accept_token call
        # below would be unreachable.
        if self._formatter.is_completed():
            return None
        self._formatter.accept_token(token)

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    # Old version for compatibility
    def next_set(self) -> typing.Tuple[typing.Set[int], typing.Set[int]]:
        # NOTE(review): guard reconstructed — mirrors the same check in next().
        if self._formatter.is_completed():
            return {self.tokenizer.eos_token_id}, {self.tokenizer.eos_token_id}
        self._formatter.compute_allowed_tokens()
        self._pass_tokens.clear()
        self._pass_tokens.update(self._formatter.get_allowed_tokens_since_last_computation())
        return self._pass_tokens, set()

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    def next(self) -> typing.Tuple[typing.Sequence[int], typing.Sequence[int]]:
        # Kludge to maintain compatibility with exllamav2 <= 0.2.0
        if not hasattr(self, "allow_return_type_list"):
            return self.next_set()
        if self._formatter.is_completed():
            return [self.tokenizer.eos_token_id], [self.tokenizer.eos_token_id]
        self._formatter.compute_allowed_tokens()
        return self._formatter.get_allowed_tokens_since_last_computation(), []

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    def use_background_worker(self) -> bool:
        return True

    # Used by ExLlamaV2 > 0.2.3
    def can_mask_logits(self) -> bool:
        return True

    def prepare_logit_mask(self):
        self._formatter.compute_allowed_tokens()
        return True

    def mask_logits(self, logits: torch.Tensor) -> torch.Tensor:
        # NOTE(review): guard reconstructed — the eos_logits branch only makes
        # sense once the formatter has completed.
        if self._formatter.is_completed():
            # Once complete, force EOS: every other token gets -inf.
            if self.eos_logits is None:
                self.eos_logits = torch.full_like(logits, float("-inf"))
                self.eos_logits[self.tokenizer.eos_token_id] = 0
            return self.eos_logits
        return self._formatter.mask_logits(logits)

    @property
    def formatter_captures(self) -> dict[str, typing.Any]:
        """
        Get the captures of the formatter.
        """
        return self._formatter.captures
Configuration for how a KBNF engine should be used in text generation.
Definition config.py:14
ExLlamaV2Filter that uses a formatter to mask logits.
Definition exllamav2.py:59
"FormatterFilter" clone(self, c=None)
Definition exllamav2.py:77
torch.Tensor mask_logits(self, torch.Tensor logits)
Definition exllamav2.py:136
typing.Tuple[typing.Sequence[int], typing.Sequence[int]] next(self)
Definition exllamav2.py:115
dict[str, typing.Any] formatter_captures(self)
Get the captures of the formatter.
Definition exllamav2.py:155
typing.Tuple[typing.Set[int], typing.Set[int]] next_set(self)
Definition exllamav2.py:106
bool is_completed(self)
Check if the formatter is completed.
Definition exllamav2.py:74
__init__(self, model, tokenizer, FormatterBase formatter, EngineGenerationConfig|None config=None)
Definition exllamav2.py:62
Configuration classes for Formatron.
Definition config.py:1
This module contains the Formatter class and its related classes.
Definition formatter.py:1
ExLlamaV2Filter create_formatter_filter(ExLlamaV2 model, ExLlamaV2Tokenizer tokenizer, FormatterBuilder formatter_builder, EngineGenerationConfig engine_config=None, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a formatter filter for the ExLlamaV2 engine.
Definition exllamav2.py:49
kbnf.Vocabulary create_engine_vocabulary(ExLlamaV2Tokenizer tokenizer, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a vocabulary for the KBNF engine.
Definition exllamav2.py:25