2 This module integrates the vllm library by providing convenience utilities.
17 Logit processor that uses formatters to mask batch logits.
20 def __init__(self, formatters: typing.Sequence[FormatterBase |
None], eos_token_id: int,
21 configs: typing.Sequence[EngineGenerationConfig] |
None =
None):
27 assert len(configs) == len(formatters), \
28 f
"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
35 return [f.captures
if f
is not None else None for f
in self.
_formatters]
39 Check if the formatters are completed. Each boolean in the list corresponds to the
40 completion status of the formatter at the same index.
42 return [f.is_completed()
if f
is not None else None for f
in self.
_formatters]
44 def reset(self) -> None:
55 def __call__(self, prompt, generated_tokens, logits):
56 result = next(self.
_iter,
None)
59 raise ValueError(f
"Batch size {self._debug_counter} "
60 f
"is greater than number of formatters({len(self._formatters)})!")
61 if len(generated_tokens) == 0:
63 formatter, config = result
66 if config.reset_at_beginning
and formatter.is_completed():
68 if config.read_prompt:
70 formatter.accept_token(token)
72 assert result
is None, (f
"Batch size {self._debug_counter} "
73 f
"is less than number of formatters({len(self._formatters)})!")
75 result = next(self.
_iter)
80 while formatter.is_completed():
83 formatter, _ = next(self.
_iter)
86 if len(generated_tokens) != 0:
87 input_id = generated_tokens[-1]
89 formatter.accept_token(input_id)
91 if formatter.is_completed():
92 logits[:] = float(
"-inf")
95 formatter.compute_allowed_tokens()
96 logits = formatter.mask_logits(logits)
102 Create a vocabulary for the KBNF engine.
104 tokenizer = llm.get_tokenizer()
105 vocab = tokenizer.get_vocab()
106 new_vocab = get_original_characters(vocab)
107 return kbnf.Vocabulary({v: kbnf.Token(k)
for k, v
in new_vocab.items()}, {
108 v: k
for k, v
in vocab.items()})
112 formatter_builders: typing.Sequence[FormatterBuilder |
None] | FormatterBuilder,
113 configs: typing.Sequence[EngineGenerationConfig] =
None) \
114 -> FormattersLogitsProcessor:
116 Create a formatter logits processor.
118 tokenizer = llm.get_tokenizer()
120 if not isinstance(formatter_builders, collections.abc.Sequence):
121 formatter_builders = [formatter_builders]
122 formatters = [i.build(vocab,
lambda tokens: tokenizer.decode(tokens))
if i
is not None else None
123 for i
in formatter_builders]