@@ -37,77 +37,106 @@ def test_llama_cpp_tokenization():
3737 assert tokens [- 1 ] == llama .token_eos ()
3838 assert tokens == [1 , 15043 , 2787 , 2 ]
3939
40-
41- def test_llama_patch (monkeypatch ):
40+ text = b""
41+ tokens = llama .tokenize (text , add_bos = True , special = True )
42+ assert tokens [- 1 ] != llama .token_eos ()
43+ assert tokens == [llama .token_bos ()]
44+ assert text == llama .detokenize (tokens )
45+
46+
@pytest.fixture
def mock_llama(monkeypatch):
    """Provide a helper that makes a ``Llama`` instance "generate" a fixed text.

    The fixture returns ``setup_mock(llama, output_text)``. Calling it resets
    ``llama`` and monkeypatches ``llama_decode``, ``llama_get_logits`` and
    ``llama_sample_token`` in ``llama_cpp.llama_cpp`` so that sampling replays
    the tokens of ``output_text`` and then returns the end-of-sequence token.
    """

    def setup_mock(llama: llama_cpp.Llama, output_text: str):
        llama.reset()
        n_vocab = llama.n_vocab()
        # Tokenized with a leading BOS, so index 0 is BOS and the text's own
        # tokens start at index 1.
        output_tokens = llama.tokenize(
            output_text.encode("utf-8"), add_bos=True, special=True
        )
        n = 0  # total number of tokens passed to mock_decode so far
        last_n_tokens = 0  # size of the most recently decoded batch

        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
            nonlocal n
            nonlocal last_n_tokens
            # Test some basic invariants of this mocking technique
            assert ctx == llama._ctx.ctx
            assert llama.n_tokens == n
            assert batch.n_tokens > 0
            n += batch.n_tokens
            last_n_tokens = batch.n_tokens
            return 0

        def mock_get_logits(*args, **kwargs):
            # Zero-filled buffer with one row of n_vocab floats per token of
            # the last decoded batch.
            size = n_vocab * last_n_tokens
            return (llama_cpp.c_float * size)()

        def mock_sample(*args, **kwargs):
            # After n tokens have been decoded, the next token to emit is
            # output_tokens[n]; once the text is exhausted, emit EOS.
            if n < len(output_tokens):
                return output_tokens[n]
            return llama.token_eos()

        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

    return setup_mock
86+
87+
def test_llama_patch(mock_llama):
    """Exercise ``create_completion`` against the mocked decode/sample functions.

    Covers plain and streaming completions that end at EOS, at a stop
    sequence, and at the ``max_tokens`` limit. ``mock_llama`` is re-applied
    before every scenario so each one starts from a reset model and counter.
    """
    n_ctx = 128
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
    n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
    assert n_vocab == 32000

    text = "The quick brown fox"
    output_text = " jumps over the lazy dog."
    all_text = text + output_text

    ## Test basic completion from bos until eos
    mock_llama(llama, all_text)
    completion = llama.create_completion("", max_tokens=36)
    assert completion["choices"][0]["text"] == all_text
    assert completion["choices"][0]["finish_reason"] == "stop"

    ## Test basic completion until eos
    mock_llama(llama, all_text)
    completion = llama.create_completion(text, max_tokens=20)
    assert completion["choices"][0]["text"] == output_text
    assert completion["choices"][0]["finish_reason"] == "stop"

    ## Test streaming completion until eos
    mock_llama(llama, all_text)
    chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
    assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

    ## Test basic completion until stop sequence
    mock_llama(llama, all_text)
    completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
    assert completion["choices"][0]["text"] == " jumps over the "
    assert completion["choices"][0]["finish_reason"] == "stop"

    ## Test streaming completion until stop sequence
    mock_llama(llama, all_text)
    chunks = list(
        llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
    )
    assert (
        "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
    )
    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

    ## Test basic completion until length
    mock_llama(llama, all_text)
    completion = llama.create_completion(text, max_tokens=2)
    assert completion["choices"][0]["text"] == " jumps"
    assert completion["choices"][0]["finish_reason"] == "length"

    ## Test streaming completion until length
    mock_llama(llama, all_text)
    chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
    assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
    assert chunks[-1]["choices"][0]["finish_reason"] == "length"
@@ -131,44 +160,55 @@ def test_llama_pickle():
131160 assert llama .detokenize (llama .tokenize (text )) == text
132161
133162
def test_utf8(mock_llama, monkeypatch):
    """Check completion text when the output is a multibyte UTF-8 character.

    Installs its own decode/logits/sample mocks, then asserts that the full
    emoji is returned with ``max_tokens=4`` and that an empty string is
    returned with ``max_tokens=1``.
    """
    # NOTE(review): the ``mock_llama`` fixture is requested but not called in
    # this test; it uses the local mocks and ``reset()`` below instead.
    # TODO confirm whether ``mock_llama(llama, output_text)`` can replace them.
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
    n_ctx = llama.n_ctx()
    n_vocab = llama.n_vocab()

    output_text = "😀"
    output_tokens = llama.tokenize(
        output_text.encode("utf-8"), add_bos=True, special=True
    )
    token_eos = llama.token_eos()
    n = 0  # total number of tokens passed to mock_decode so far

    def reset():
        # Return both the model state and the mock's counter to the start.
        nonlocal n
        llama.reset()
        n = 0

    ## Set up mock function
    def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
        nonlocal n
        assert batch.n_tokens > 0
        assert llama.n_tokens == n
        n += batch.n_tokens
        return 0

    def mock_get_logits(*args, **kwargs):
        # Zero-filled buffer with n_vocab floats for each of n_ctx positions.
        size = n_vocab * n_ctx
        return (llama_cpp.c_float * size)()

    def mock_sample(*args, **kwargs):
        # NOTE(review): this indexes output_tokens[n - 1] with ``<=``, whereas
        # the shared ``mock_llama`` fixture uses output_tokens[n] with ``<``.
        # Confirm which offset is intended.
        if n <= len(output_tokens):
            return output_tokens[n - 1]
        return token_eos

    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)

    ## Test basic completion with utf8 multibyte
    reset()
    completion = llama.create_completion("", max_tokens=4)
    assert completion["choices"][0]["text"] == output_text

    ## Test basic completion with incomplete utf8 multibyte
    reset()
    completion = llama.create_completion("", max_tokens=1)
    assert completion["choices"][0]["text"] == ""
174214
@@ -196,5 +236,6 @@ def test_llama_server():
196236 ],
197237 }
198238
239+
def test_llama_cpp_version():
    """Check that the package exposes a non-empty ``__version__``."""
    assert llama_cpp.__version__
0 commit comments