This module integrates the ExLlamaV2 library by providing convenience utilities.
import typing
from copy import copy, deepcopy

import torch
from exllamav2 import ExLlamaV2Tokenizer, ExLlamaV2
from exllamav2.generator.base import ExLlamaV2Filter
15__all__ = [
"create_engine_vocabulary",
"create_formatter_filter",
"FormatterFilter"]
def create_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer,
                             vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine.

    Args:
        tokenizer: The tokenizer.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters.
            If None, processors will be auto-detected.
    Returns:
        A ``kbnf.Vocabulary`` built from the tokenizer's pieces.
    """
    # The underlying tokenizer model must expose its vocab to enumerate pieces.
    assert hasattr(tokenizer.tokenizer_model, "vocab"), (f"tokenizer({tokenizer})"
                                                         f" with tokenizer_model({tokenizer.tokenizer_model})"
                                                         f" does not have vocab attribute!")
    vocab = {tokenizer.tokenizer_model.id_to_piece(i): i
             for i in range(tokenizer.tokenizer_model.vocab_size())}
    # Undo tokenizer-specific character mangling so the engine sees the original bytes.
    new_vocab = get_original_characters(vocab, vocab_processors)
    return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
                           {v: k for k, v in vocab.items()})
def create_formatter_filter(model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer,
                            formatter_builder: FormatterBuilder,
                            engine_config: EngineGenerationConfig = None,
                            vocab_processors: typing.Optional[list[typing.Callable]] = None) -> ExLlamaV2Filter:
    """
    Create a formatter filter for the ExLlamaV2 engine.

    Args:
        model: The ExLlamaV2 model.
        tokenizer: The ExLlamaV2 tokenizer.
        formatter_builder: The formatter builder.
        engine_config: The engine generation configuration.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters.
            If None, processors will be auto-detected.
    Returns:
        An ``ExLlamaV2Filter`` that constrains generation with the built formatter.
    """
    vocab = create_engine_vocabulary(tokenizer, vocab_processors)
    # Wrap decode in a lambda: ExLlamaV2's decode expects a torch tensor of token ids.
    f = formatter_builder.build(vocab,
                                lambda tokens: tokenizer.decode(torch.tensor(tokens)))
    return FormatterFilter(model, tokenizer, f, engine_config)
57 ExLlamaV2Filter that uses a formatter to mask logits.
60 def __init__(self, model, tokenizer, formatter: FormatterBase,
61 config: EngineGenerationConfig|
None =
None):
72 Check if the formatter is completed.
76 def clone(self, c=None) -> "FormatterFilter":
78 c = FormatterFilter.__new__(FormatterFilter)
80 c.tokenizer = self.tokenizer
81 c.sequence_str = self.sequence_str
84 c._config = deepcopy(self.
_config)
88 def begin(self, prefix_str: str) ->
None:
92 prompt = prefix_str.encode(
"utf-8")
95 def reset(self) -> None:
98 def feed(self, token: int):
105 def next_set(self) -> typing.Tuple[typing.Set[int], typing.Set[int]]:
107 return {self.tokenizer.eos_token_id}, {self.tokenizer.eos_token_id}
114 def next(self) -> typing.Tuple[typing.Sequence[int], typing.Sequence[int]]:
116 if not hasattr(self,
"allow_return_type_list"):
119 return [self.tokenizer.eos_token_id], [self.tokenizer.eos_token_id]
121 return self.
_formatter.get_allowed_tokens_since_last_computation(), []
135 def mask_logits(self, logits: torch.Tensor) -> torch.Tensor:
138 self.
eos_logits = torch.full_like(logits, float(
"-inf"))
139 self.
eos_logits[self.tokenizer.eos_token_id] = 0
146 Get the captures of the formatter.