@@ -47,16 +47,22 @@ def test_llama_cpp_tokenization():
4747@pytest .fixture
4848def mock_llama (monkeypatch ):
4949 def setup_mock (llama : llama_cpp .Llama , output_text : str ):
50+ n_ctx = llama .n_ctx ()
5051 n_vocab = llama .n_vocab ()
5152 output_tokens = llama .tokenize (
5253 output_text .encode ("utf-8" ), add_bos = True , special = True
5354 )
55+ logits = (llama_cpp .c_float * (n_vocab * n_ctx ))(- 100.0 )
56+ for i in range (n_ctx ):
57+ output_idx = i + 1 # logits for first tokens predict second token
58+ if output_idx < len (output_tokens ):
59+ logits [i * n_vocab + output_tokens [output_idx ]] = 100.0
60+ else :
61+ logits [i * n_vocab + llama .token_eos ()] = 100.0
5462 n = 0
5563 last_n_tokens = 0
5664
5765 def mock_decode (ctx : llama_cpp .llama_context_p , batch : llama_cpp .llama_batch ):
58- nonlocal n
59- nonlocal last_n_tokens
6066 # Test some basic invariants of this mocking technique
6167 assert ctx == llama ._ctx .ctx , "context does not match mock_llama"
6268 assert batch .n_tokens > 0 , "no tokens in batch"
@@ -70,26 +76,22 @@ def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
7076 batch .n_tokens - 1
7177 ], "logits not allocated for last token"
7278 # Update the mock context state
79+ nonlocal n
80+ nonlocal last_n_tokens
7381 n = max (batch .pos [i ] for i in range (batch .n_tokens )) + 1
7482 last_n_tokens = batch .n_tokens
7583 return 0
7684
77- def mock_get_logits (* args , ** kwargs ):
78- nonlocal n
79- nonlocal last_n_tokens
85+ def mock_get_logits (ctx : llama_cpp . llama_context_p ):
86+ # Test some basic invariants of this mocking technique
87+ assert ctx == llama . _ctx . ctx , "context does not match mock_llama"
8088 assert n > 0 , "mock_llama_decode not called"
8189 assert last_n_tokens > 0 , "mock_llama_decode not called"
82- logits = (llama_cpp .c_float * (last_n_tokens * n_vocab ))(- 100.0 )
83- for logits_idx , output_idx in enumerate (
84- range (n - last_n_tokens + 1 , n + 1 )
85- ):
86- if output_idx < len (output_tokens ):
87- logits [
88- logits_idx * last_n_tokens + output_tokens [output_idx ]
89- ] = 100.0
90- else :
91- logits [logits_idx * last_n_tokens + llama .token_eos ()] = 100.0
92- return logits
90+ # Return view of logits for last_n_tokens
91+ return (llama_cpp .c_float * (last_n_tokens * n_vocab )).from_address (
92+ ctypes .addressof (logits )
93+ + (n - last_n_tokens ) * n_vocab * ctypes .sizeof (llama_cpp .c_float )
94+ )
9395
9496 monkeypatch .setattr ("llama_cpp.llama_cpp.llama_decode" , mock_decode )
9597 monkeypatch .setattr ("llama_cpp.llama_cpp.llama_get_logits" , mock_get_logits )
0 commit comments