Formatron v0.5.0
Formatron empowers everyone to control the output format of language models with minimal overhead.
exllamav2.py
1"""
2This module integrates the ExLlamaV2 library by providing convenience utilities.
3"""
4import typing
5from copy import copy, deepcopy
6import kbnf
7import torch
8from exllamav2 import ExLlamaV2Tokenizer, ExLlamaV2
9from exllamav2.generator.base import ExLlamaV2Filter
10from formatron.config import EngineGenerationConfig
11from formatron.formatter import FormatterBase, FormatterBuilder
12from formatron.integrations.utils import get_original_characters, default_mask_logits_fn, get_bit_mask
13
14
15__all__ = ["create_engine_vocabulary", "create_formatter_filter", "FormatterFilter"]
16def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer,
17 vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
18 """
19 Create a vocabulary for the KBNF engine.
20 Args:
21 tokenizer: The tokenizer.
22 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
23 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
24 """
25 assert hasattr(tokenizer.tokenizer_model, "vocab"), (f"tokenizer({tokenizer})"
26 f" with tokenizer_model({tokenizer.tokenizer_model})"
27 f" does not have vocab attribute!")
28 vocab = {tokenizer.tokenizer_model.id_to_piece(
29 i): i for i in range(tokenizer.tokenizer_model.vocab_size())}
30 new_vocab = get_original_characters(vocab, vocab_processors)
31 return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
32 {v: k for k, v in vocab.items()})
33
34
35def create_formatter_filter(model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer,
36 formatter_builder: FormatterBuilder,
37 engine_config: EngineGenerationConfig = None,
38 vocab_processors: typing.Optional[list[typing.Callable]] = None) -> ExLlamaV2Filter:
39 """
40 Create a formatter filter for the ExLlamaV2 engine.
41 Args:
42 model: The ExLlamaV2 model.
43 tokenizer: The ExLlamaV2 tokenizer.
44 formatter_builder: The formatter builder.
45 engine_config: The engine generation configuration.
46 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
47 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
48 """
49 vocab = create_engine_vocabulary(tokenizer, vocab_processors)
50 f = formatter_builder.build(
51 vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens)))
52 return FormatterFilter(model, tokenizer, f, engine_config)
53
54
55class FormatterFilter(ExLlamaV2Filter):
56 """
57 ExLlamaV2Filter that uses a formatter to mask logits.
58 """
60 def __init__(self, model, tokenizer, formatter: FormatterBase,
61 config: EngineGenerationConfig|None = None):
62 super().__init__(model, tokenizer)
63 self._formatter = formatter
64 if config is None:
65 config = EngineGenerationConfig()
66 self._config = config
67 self._pass_tokens = set()
68 self.eos_logits = None
69 self._mask_logits_fn = default_mask_logits_fn
70 self._bit_mask = None

    def is_completed(self) -> bool:
        """
        Check if the formatter is completed.
        """
        return self._formatter.is_completed()

    def clone(self, c=None) -> "FormatterFilter":
        if c is None:
            c = FormatterFilter.__new__(FormatterFilter)
        c.model = self.model
        c.tokenizer = self.tokenizer
        c.sequence_str = self.sequence_str
        # formatter does not have mutable public state anyway
        c._formatter = copy(self._formatter)
        c._config = deepcopy(self._config)
        c._pass_tokens = self._pass_tokens
        # carry over logits-masking state so a cloned filter can mask logits immediately
        c.eos_logits = self.eos_logits
        c._mask_logits_fn = self._mask_logits_fn
        c._bit_mask = self._bit_mask
        return c

    def begin(self, prefix_str: str) -> None:
        if self._config.reset_at_beginning:
            self._formatter.reset()
        if self._config.read_prompt:
            prompt = prefix_str.encode("utf-8")
            self._formatter.accept_bytes(prompt)

    def reset(self) -> None:
        self._formatter.reset()

    def feed(self, token: int):
        if self._formatter.is_completed():
            return None
        self._formatter.accept_token(token)

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    # Old version for compatibility
    def next_set(self) -> typing.Tuple[typing.Set[int], typing.Set[int]]:
        if self._formatter.is_completed():
            return {self.tokenizer.eos_token_id}, {self.tokenizer.eos_token_id}
        self._formatter.compute_allowed_tokens()
        self._pass_tokens.clear()
        self._pass_tokens.update(self._formatter.get_allowed_tokens_since_last_computation())
        return self._pass_tokens, set()

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    def next(self) -> typing.Tuple[typing.Sequence[int], typing.Sequence[int]]:
        # Kludge to maintain compatibility with exllamav2 <= 0.2.0
        if not hasattr(self, "allow_return_type_list"):
            return self.next_set()
        if self._formatter.is_completed():
            return [self.tokenizer.eos_token_id], [self.tokenizer.eos_token_id]
        self._formatter.compute_allowed_tokens()
        return self._formatter.get_allowed_tokens_since_last_computation(), []

    # adapted from https://github.com/Dan-wanna-M/formatron/issues/14
    def use_background_worker(self) -> bool:
        return True

    # Used by ExLlamaV2 > 0.2.3
    def can_mask_logits(self) -> bool:
        return True

    def prepare_logit_mask(self):
        self._formatter.compute_allowed_tokens()
        return True

    def mask_logits(self, logits: torch.Tensor) -> torch.Tensor:
        if self._bit_mask is None:
            self._bit_mask = get_bit_mask(logits)
        if self._formatter.is_completed():
            if self.eos_logits is None:
                self.eos_logits = torch.full_like(logits, float("-inf"))
                self.eos_logits[self.tokenizer.eos_token_id] = 0
            return self.eos_logits
        return self._mask_logits_fn(self._bit_mask, self._formatter, logits)

    @property
    def formatter_captures(self) -> dict[str, typing.Any]:
        """
        Get the captures of the formatter.
        """
        return self._formatter.captures
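
A minimal usage sketch (not part of this file): it assumes a typical exllamav2 0.2.x dynamic-generator setup; the model directory, prompt, and regex are placeholders, and the exact generator arguments may differ across exllamav2 versions.

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator
from formatron.formatter import FormatterBuilder
from formatron.integrations.exllamav2 import create_formatter_filter

# Load model, cache, and tokenizer (placeholder path).
config = ExLlamaV2Config("/path/to/model")
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy=True)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2DynamicGenerator(model=model, cache=cache, tokenizer=tokenizer)

# Constrain the output to digits and capture them under the name "number".
builder = FormatterBuilder()
builder.append_line(f"{builder.regex('[0-9]+', capture_name='number')}")
filter_ = create_formatter_filter(model, tokenizer, builder)

# The filter masks logits during generation so only grammar-allowed tokens survive.
output = generator.generate(prompt="How many days are in a week? Answer: ",
                            max_new_tokens=16,
                            filters=[filter_],
                            add_bos=True)
print(output)
print(filter_.formatter_captures)  # e.g. {"number": ...}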
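
If auto-detection does not handle a tokenizer's character mangling, a custom processor can be passed through vocab_processors. A hedged sketch, assuming a SentencePiece-style metaspace marker; the function name and the marker are illustrative, not part of this module:

import typing


def metaspace_to_space(token_to_char: typing.Dict[bytes, bytes]) -> None:
    # Mutate the mapping in place, replacing the metaspace marker (U+2581) with a space.
    marker = "\u2581".encode("utf-8")
    for token, chars in token_to_char.items():
        token_to_char[token] = chars.replace(marker, b" ")


# Passed explicitly instead of relying on auto-detection:
# vocab = create_engine_vocabulary(tokenizer, vocab_processors=[metaspace_to_space])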