Formatron v0.5.0
Formatron empowers everyone to control the output format of language models with minimal overhead.
utils.py
import re
import typing
from functools import lru_cache

__all__ = ["get_original_characters", "update_vocab_0xHH", "update_vocab_sentencepiece", "update_vocab_dot_G"]

def _multiple_replace(replacements: typing.Dict[bytes, bytes], regex: re.Pattern[bytes], text: bytes) -> bytes:
    # For each match, look up the corresponding replacement value in the dictionary
    return regex.sub(lambda mo: replacements[mo.group()], text)
def default_mask_logits_fn(bit_mask, formatter, logits):
    return formatter.mask_logits(logits)
def _get_mask_logits_fn():  # name and zero-argument signature assumed; the original definition line is missing from this listing
    try:
        from kbnf.triton_logits_mask import mask_logits_inplace
        def fast_mask_logits_fn(bit_mask, formatter, logits):
            mask_logits_inplace(logits, bit_mask, [formatter._engine])
            return logits
        return fast_mask_logits_fn
    except ImportError:
        return default_mask_logits_fn

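# Illustrative sketch; `_example_default_mask_logits_fn` and `_ExampleFormatter` are hypothetical
# stand-ins, not part of Formatron. It shows that the default path simply delegates masking to the
# formatter and ignores the preallocated bit mask.
def _example_default_mask_logits_fn():
    class _ExampleFormatter:
        def mask_logits(self, logits):
            return logits  # a real formatter would suppress the logits of disallowed tokens here
    logits = [0.1, 0.2, 0.3]
    assert default_mask_logits_fn(None, _ExampleFormatter(), logits) == logits
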
def get_bit_mask(logits):
    # One bit per vocabulary entry, packed into int32 words, allocated in pinned CPU memory.
    import torch
    return torch.empty((logits.shape[-1] + 31) // 32, dtype=torch.int32, device='cpu', pin_memory=True)

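# Illustrative sketch; `_example_get_bit_mask` is a hypothetical helper, not part of Formatron.
# It shows the sizing above: one bit per vocabulary entry, rounded up to whole int32 words.
# Assumes a CUDA-capable torch build, since pinned host memory requires one.
def _example_get_bit_mask():
    import torch
    logits = torch.zeros(1, 32000)  # e.g. a llama-sized vocabulary
    bit_mask = get_bit_mask(logits)
    assert bit_mask.shape == ((32000 + 31) // 32,)  # 1000 int32 words
    assert bit_mask.is_pinned()
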
def get_original_characters(vocab: typing.Dict[str, int],
                            processors: typing.Optional[list[typing.Callable]] = None) -> typing.Dict[int, bytes]:
    """
    Get a vocabulary of original characters unmangled to raw UTF-8 bytes by the provided processors.

    Args:
        vocab: The mangled vocabulary.
        processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes]) -> None.
            Callables can be used to "unmangle" encoded characters to original characters.
            If None, processors will be auto-detected.
    """
    old_char_to_new_char = {}
    assert len(set(vocab.values())) == len(vocab), "Vocabulary contains duplicate token IDs!"
    if processors is None:
        processors = autodetect_processors(vocab)
    for update_vocab in processors:
        update_vocab(old_char_to_new_char)
    # Create a regular expression from the dictionary keys, longest keys first, to avoid prefix conflicts
    regex = re.compile(b"(%s)" % b"|".join(sorted(map(re.escape, old_char_to_new_char.keys()), key=len, reverse=True)))
    new_vocab = {}
    for k, token_id in vocab.items():
        new_vocab[token_id] = _multiple_replace(old_char_to_new_char, regex, k.encode("UTF-8"))
    return new_vocab


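# Illustrative sketch; `_example_get_original_characters` and the toy vocabulary are hypothetical,
# not taken from a real tokenizer. The <0xHH> and sentencepiece processors unmangle the token
# strings back to raw UTF-8 bytes keyed by token ID.
def _example_get_original_characters():
    toy_vocab = {"<0x0A>": 0, "\u2581Hello": 1, "world": 2}
    raw = get_original_characters(toy_vocab,
                                  processors=[update_vocab_0xHH, update_vocab_sentencepiece])
    assert raw == {0: b"\n", 1: b" Hello", 2: b"world"}
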
def autodetect_processors(vocab: typing.Dict[str, int]) -> typing.List[typing.Callable]:
    """
    Autodetect vocabulary processors from the token strings themselves.
    """
    result = []
    # A literal "<0xF0>" token indicates llama-style <0xHH> byte tokens.
    llama_present = any(i.find('<0xF0>') != -1 for i in vocab.keys())
    # More than 20% of tokens containing "▁" (U+2581) indicates a sentencepiece vocabulary.
    underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581') != -1]) / len(vocab)) > 0.2
    # More than 20% of tokens containing "Ġ" (U+0120) indicates a GPT-2 byte-level vocabulary.
    g_present = (len([1 for i in vocab.keys() if i.find('\u0120') != -1]) / len(vocab)) > 0.2
    if llama_present:
        result.append(update_vocab_0xHH)
    if underscore_present:
        result.append(update_vocab_sentencepiece)
    elif g_present:
        result.append(update_vocab_dot_G)
    return result


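# Illustrative sketch; `_example_autodetect_processors` and the toy vocabulary are hypothetical.
# A literal "<0xF0>" token triggers the <0xHH> processor, and more than 20% of tokens containing
# "▁" (U+2581) selects the sentencepiece processor.
def _example_autodetect_processors():
    sp_like = {"\u2581the": 0, "\u2581a": 1, "x": 2, "<0xF0>": 3}
    assert autodetect_processors(sp_like) == [update_vocab_0xHH, update_vocab_sentencepiece]
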
def update_vocab_0xHH(token_to_char: typing.Dict[bytes, bytes]):
    """
    Vocabulary processor for <0xHH> tokens (used in llama tokenizers).
    """
    for j in range(256):
        token_to_char[("<0x" + f"{j:02x}".upper() + ">").encode("UTF-8")] = bytes([j])


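# Illustrative sketch; `_example_update_vocab_0xHH` is hypothetical. Each of the 256 <0xHH>
# tokens maps straight to its single raw byte.
def _example_update_vocab_0xHH():
    mapping = {}
    update_vocab_0xHH(mapping)
    assert mapping[b"<0x0A>"] == b"\n"
    assert mapping[b"<0xFF>"] == b"\xff"
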
def update_vocab_sentencepiece(token_to_char: typing.Dict[bytes, bytes]):
    """
    Vocabulary processor for the ▁ token (used in sentencepiece tokenizers).
    """
    token_to_char["\u2581".encode("UTF-8")] = b" "


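# Illustrative sketch; `_example_update_vocab_sentencepiece` is hypothetical. The single entry
# replaces the sentencepiece word-boundary marker with a plain space.
def _example_update_vocab_sentencepiece():
    mapping = {}
    update_vocab_sentencepiece(mapping)
    assert mapping["\u2581".encode("UTF-8")] == b" "
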
def update_vocab_dot_G(token_to_char: typing.Dict[bytes, bytes]):
    """
    Vocabulary processor for GPT-2 style token mangling, e.g. a space mangled to Ġ (used in huggingface bytelevel preprocessors).
    """
    token_to_char.update(_huggingface_bytelevel_decoder())


@lru_cache()
def _huggingface_bytelevel_decoder():
    """
    I hate legacy code.
    """
    # Inverse of the GPT-2 byte-level encoder: printable latin-1 bytes map to themselves, while
    # every other byte was remapped to a substitute character starting at U+0100; map them back.
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n).encode("UTF-8") for n in cs]
    for i in range(len(bs)):
        bs[i] = bytes([bs[i]])
    return dict(zip(cs, bs))
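
# Illustrative sketch; `_example_bytelevel_decoder` is hypothetical. The decoder inverts the GPT-2
# byte-level mapping, e.g. "Ġ" (U+0120) stands for a space and "Ċ" (U+010A) for a newline.
def _example_bytelevel_decoder():
    decoder = _huggingface_bytelevel_decoder()
    assert decoder["\u0120".encode("UTF-8")] == b" "
    assert decoder["\u010a".encode("UTF-8")] == b"\n"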