Lokasi ngalangkungan proxy:   [ UP ]  
[Ngawartoskeun bug]   [Panyetelan cookie]                
Skip to content

Commit 0ea2444

Browse files
committed
tests: avoid constantly reallocating logits
1 parent 0a7e05b commit 0ea2444

1 file changed

Lines changed: 18 additions & 16 deletions

File tree

tests/test_llama.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,22 @@ def test_llama_cpp_tokenization():
4747
@pytest.fixture
4848
def mock_llama(monkeypatch):
4949
def setup_mock(llama: llama_cpp.Llama, output_text: str):
50+
n_ctx = llama.n_ctx()
5051
n_vocab = llama.n_vocab()
5152
output_tokens = llama.tokenize(
5253
output_text.encode("utf-8"), add_bos=True, special=True
5354
)
55+
logits = (llama_cpp.c_float * (n_vocab * n_ctx))(-100.0)
56+
for i in range(n_ctx):
57+
output_idx = i + 1 # logits for first tokens predict second token
58+
if output_idx < len(output_tokens):
59+
logits[i * n_vocab + output_tokens[output_idx]] = 100.0
60+
else:
61+
logits[i * n_vocab + llama.token_eos()] = 100.0
5462
n = 0
5563
last_n_tokens = 0
5664

5765
def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
58-
nonlocal n
59-
nonlocal last_n_tokens
6066
# Test some basic invariants of this mocking technique
6167
assert ctx == llama._ctx.ctx, "context does not match mock_llama"
6268
assert batch.n_tokens > 0, "no tokens in batch"
@@ -70,26 +76,22 @@ def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
7076
batch.n_tokens - 1
7177
], "logits not allocated for last token"
7278
# Update the mock context state
79+
nonlocal n
80+
nonlocal last_n_tokens
7381
n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
7482
last_n_tokens = batch.n_tokens
7583
return 0
7684

77-
def mock_get_logits(*args, **kwargs):
78-
nonlocal n
79-
nonlocal last_n_tokens
85+
def mock_get_logits(ctx: llama_cpp.llama_context_p):
86+
# Test some basic invariants of this mocking technique
87+
assert ctx == llama._ctx.ctx, "context does not match mock_llama"
8088
assert n > 0, "mock_llama_decode not called"
8189
assert last_n_tokens > 0, "mock_llama_decode not called"
82-
logits = (llama_cpp.c_float * (last_n_tokens * n_vocab))(-100.0)
83-
for logits_idx, output_idx in enumerate(
84-
range(n - last_n_tokens + 1, n + 1)
85-
):
86-
if output_idx < len(output_tokens):
87-
logits[
88-
logits_idx * last_n_tokens + output_tokens[output_idx]
89-
] = 100.0
90-
else:
91-
logits[logits_idx * last_n_tokens + llama.token_eos()] = 100.0
92-
return logits
90+
# Return view of logits for last_n_tokens
91+
return (llama_cpp.c_float * (last_n_tokens * n_vocab)).from_address(
92+
ctypes.addressof(logits)
93+
+ (n - last_n_tokens) * n_vocab * ctypes.sizeof(llama_cpp.c_float)
94+
)
9395

9496
monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
9597
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)

0 commit comments

Comments
 (0)