Formatron v0.4.9
Formatron empowers everyone to control the output format of language models with minimal overhead.
Loading...
Searching...
No Matches
RWKV.py
Go to the documentation of this file.
1"""
2This module integrates the RWKV library by providing convenience utilities.
3"""
4import kbnf
5import rwkv.utils
6from kbnf import Token
7
8from formatron.config import EngineGenerationConfig
9from formatron.formatter import FormatterBuilder
10
# Public API of this module.
__all__ = [
    "create_engine_vocabulary",
    "PIPELINE",
    "PIPELINE_ARGS",
]
class PIPELINE_ARGS(rwkv.utils.PIPELINE_ARGS):
    """
    A wrapper for the arguments of the pipeline of RWKV.

    Extends ``rwkv.utils.PIPELINE_ARGS`` with ``engine_gen_config``, which
    ``PIPELINE.generate`` reads to decide how the Formatron engine is driven
    (its ``reset_at_beginning`` and ``read_prompt`` flags).
    """

    def __init__(self,
                 temperature=1.0,
                 top_p=0.2,
                 top_k=0,
                 alpha_frequency=0.2,
                 alpha_presence=0.2,
                 alpha_decay=0.996,
                 token_ban=None,
                 token_stop=None,
                 chunk_len=256,
                 engine_gen_config=None):
        """
        Args:
            temperature: Passed to ``sample_logits`` by ``PIPELINE.generate``.
            top_p: Passed to ``sample_logits`` by ``PIPELINE.generate``.
            top_k: Passed to ``sample_logits`` by ``PIPELINE.generate``.
            alpha_frequency: Per-occurrence penalty subtracted from the logits
                of previously generated tokens.
            alpha_presence: Flat penalty subtracted from the logits of any
                previously generated token.
            alpha_decay: Factor applied to the accumulated occurrence counts
                after every generated token.
            token_ban: Token ids whose logits are forced to ``-inf``.
                ``None`` (default) means a fresh empty list.
            token_stop: Token ids that end generation when sampled (the stop
                token itself is not emitted). ``None`` (default) means a fresh
                empty list.
            chunk_len: Maximum number of tokens fed to the model per forward
                call.
            engine_gen_config: Engine configuration. ``None`` (default) means
                a fresh ``EngineGenerationConfig()``.
        """
        # Create the containers per instance. Defaults written in the signature
        # would be evaluated once and shared by every PIPELINE_ARGS instance.
        if token_ban is None:
            token_ban = []
        if token_stop is None:
            token_stop = []
        if engine_gen_config is None:
            engine_gen_config = EngineGenerationConfig()
        super().__init__(temperature, top_p, top_k, alpha_frequency, alpha_presence, alpha_decay, token_ban, token_stop,
                         chunk_len)
        self.engine_gen_config = engine_gen_config
32
def create_engine_vocabulary(WORD_NAME: str, tokenizer) -> kbnf.Vocabulary:  # NOSONAR
    """
    Create a vocabulary for the KBNF engine.

    Args:
        WORD_NAME: Name of the RWKV vocabulary. Only the "world" vocabulary
            ``'rwkv_vocab_v20230424'`` is supported.
        tokenizer: RWKV tokenizer whose ``idx2token`` mapping (token id ->
            token bytes; the values are ``.decode``-d below) is read.

    Returns:
        A ``kbnf.Vocabulary`` built from two maps:
        - token id -> ``kbnf.Token`` of the raw token bytes;
        - token id -> UTF-8 text, with undecodable bytes replaced by U+FFFD.

    Raises:
        ValueError: If ``WORD_NAME`` is not the world vocabulary.
    """
    # Explicit check rather than ``assert``, so it still runs under ``python -O``.
    if WORD_NAME != 'rwkv_vocab_v20230424':
        raise ValueError("Only world vocabulary is supported!")
    id_to_token = {}
    id_to_text = {}
    # Single pass over the tokenizer table builds both maps.
    for token_id, token_bytes in tokenizer.idx2token.items():
        id_to_token[token_id] = Token(token_bytes)
        id_to_text[token_id] = token_bytes.decode("UTF-8", errors="replace")
    return kbnf.Vocabulary(id_to_token, id_to_text)
41
42
class PIPELINE(rwkv.utils.PIPELINE):  # NOSONAR
    """
    A wrapper for the pipeline of RWKV.

    Adds optional Formatron-constrained decoding on top of
    ``rwkv.utils.PIPELINE``. When a formatter is available, each step's logits
    are masked with it before sampling, and every sampled token is fed back
    into it.
    """

    def __init__(self, model, WORD_NAME, formatter_builder: FormatterBuilder = None):  # NOSONAR
        """
        Args:
            model: RWKV model, forwarded to ``rwkv.utils.PIPELINE``.
            WORD_NAME: Vocabulary name, forwarded to ``rwkv.utils.PIPELINE``.
                When ``formatter_builder`` is given, it must be a vocabulary
                accepted by ``create_engine_vocabulary``.
            formatter_builder: Optional builder for the formatter that
                constrains generation. With ``None`` (the default) the
                pipeline samples without constraints.
        """
        super().__init__(model, WORD_NAME)
        self.formatter = None
        # Only build the engine vocabulary when a formatter is requested.
        # Previously a ``None`` builder crashed on ``.build``.
        if formatter_builder is not None:
            vocabulary = create_engine_vocabulary(WORD_NAME, self.tokenizer)
            # A ``None`` result from ``build`` leaves the pipeline unconstrained.
            self.formatter = formatter_builder.build(vocabulary, lambda tokens: self.tokenizer.decode(tokens))

    def generate(self, ctx, token_count=100, args=None, callback=None, state=None):
        """
        Generate up to ``token_count`` tokens continuing ``ctx``.

        Args:
            ctx: Prompt text. It is encoded with ``self.encode`` and fed to the
                model on the first step.
            token_count: Maximum number of sampling steps.
            args: ``PIPELINE_ARGS`` with the sampling settings. ``None``
                (default) means a fresh ``PIPELINE_ARGS()``.
            callback: Optional callable invoked with each newly decoded text
                chunk. Only chunks free of U+FFFD, i.e. complete UTF-8, are
                passed.
            state: Model state passed to (and updated by) ``self.model.forward``.

        Returns:
            The generated text, excluding the prompt. A trailing chunk that
            still decodes to U+FFFD is not included.

        Raises:
            ValueError: If ``ctx`` encodes to no tokens and ``token_count > 0``.
        """
        if args is None:
            # Build per call. A default written in the signature would be
            # created once, at class definition, and shared by every call.
            args = PIPELINE_ARGS()
        formatter = self.formatter
        engine_config = args.engine_gen_config
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        # NOTE(review): only a *completed* formatter is reset. An unfinished one
        # (e.g. the previous call ran out of token_count) keeps its state.
        if engine_config.reset_at_beginning and formatter is not None and formatter.is_completed():
            formatter.reset()
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            if not tokens:
                # Without at least one forward pass there are no logits to sample from.
                raise ValueError("ctx must encode to at least one token")
            if i == 0 and formatter is not None and engine_config.read_prompt:
                # Let the engine consume the prompt so constraints apply after it.
                for prompt_token in tokens:
                    formatter.accept_token(prompt_token)
            # Feed the model at most ``chunk_len`` tokens per forward call.
            while tokens:
                out, state = self.model.forward(tokens[:args.chunk_len], state)
                tokens = tokens[args.chunk_len:]
            if formatter is not None and formatter.is_completed():
                break
            for banned in args.token_ban:
                out[banned] = -float('inf')
            for seen in occurrence:
                out[seen] -= (args.alpha_presence + occurrence[seen] * args.alpha_frequency)
            if formatter is not None:
                formatter.compute_allowed_tokens()
                out = out[:len(self.tokenizer.idx2token) + 1]  # account for the padding `0` token
                out = formatter.mask_logits(out)
            # sampler
            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
            if formatter is not None:
                formatter.accept_token(token)
            if token in args.token_stop:
                break
            all_tokens.append(token)
            for seen in occurrence:
                occurrence[seen] *= args.alpha_decay

            # Substring test: whitespace/digit pieces (and the empty string)
            # are recorded with weight 0, so they gain no new penalty.
            piece = self.decode([token])
            weight = 0 if piece in ' \t0123456789' else 1
            occurrence[token] = occurrence.get(token, 0) + weight

            # output: emit only once the pending tokens decode to valid UTF-8
            tmp = self.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = len(all_tokens)
            if formatter is not None and formatter.is_completed():
                break
        return out_str
A wrapper for the arguments of the pipeline of RWKV.
Definition RWKV.py:16
__init__(self, temperature=1.0, top_p=0.2, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, alpha_decay=0.996, token_ban=[], token_stop=[], chunk_len=256, engine_gen_config=EngineGenerationConfig())
Definition RWKV.py:28
A wrapper for the pipeline of RWKV.
Definition RWKV.py:47
generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None)
Definition RWKV.py:58
__init__(self, model, WORD_NAME, FormatterBuilder formatter_builder=None)
Definition RWKV.py:49
Configuration classes for Formatron.
Definition config.py:1
This module contains the Formatter class and its related classes.
Definition formatter.py:1
kbnf.Vocabulary create_engine_vocabulary(str WORD_NAME, tokenizer)
Create a vocabulary for the KBNF engine.
Definition RWKV.py:37