From 5ff88a326e64da0fcfb27d0dd6e90f9405f17b3e Mon Sep 17 00:00:00 2001
From: Alex Pilon
Date: Tue, 13 Jun 2023 12:26:21 -0400
Subject: [PATCH 1/2] wip

---
Review notes (text in this area is discarded by `git am`):
 * Adds an optional `grammar` argument to Llama.__init__ and applies it to the
   candidate list in Llama._sample() before a token is chosen.
 * Repoints the vendor/llama.cpp submodule at the xaptronic fork.
 * _sample() still contains a debugging breakpoint(); PATCH 2/2 removes it.
 * grammar_test.py uses absolute paths under /Users/alex and drops into an
   interactive console; it is a manual scratch script, not an automated test.

 .gitmodules            |  2 +-
 grammar_test.py        | 31 +++++++++++++++++++++++++++++++
 llama_cpp/llama.py     | 33 +++++++++++++++++++++++++++++----
 llama_cpp/llama_cpp.py | 41 +++++++++++++++++++++++++++++++++++++++++
 vendor/llama.cpp       |  2 +-
 5 files changed, 103 insertions(+), 6 deletions(-)
 create mode 100644 grammar_test.py

diff --git a/.gitmodules b/.gitmodules
index 7edf0975dc..4451d88384 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
 [submodule "vendor/llama.cpp"]
 	path = vendor/llama.cpp
-	url = https://github.com/ggerganov/llama.cpp.git
+	url = https://github.com/xaptronic/llama.cpp.git
diff --git a/grammar_test.py b/grammar_test.py
new file mode 100644
index 0000000000..535c1736e6
--- /dev/null
+++ b/grammar_test.py
@@ -0,0 +1,31 @@
+from llama_cpp import Llama
+
+grammar = """
+root ::= nav eol (commands eol)*
+commands ::= t | info
+nav ::= "nav(\\"admin/" [a-z/]* "\\")"
+info ::= "info(" setting ")"
+t ::= "t(" setting ", " value ")"
+value ::= color | string | number | boolean
+color ::= "#" [0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f][0-9a-f]
+setting ::= "\\"" [a-z ]+ "\\""
+string ::= "\\"" [ \\t!#-\\[\\]-~]* "\\""
+number ::= [0-9]+
+boolean ::= ("true" | "false")
+eol ::= "\\n"
+"""
+
+llm = Llama(
+    model_path="/Users/alex/llama-7b.ggmlv3.q8_0.bin",
+    lora_base="/Users/alex/llama-7b.ggml.f16.bin",
+    # python ~/llama.cpp/convert-lora-to-ggml.py .
+    lora_path="/Users/alex/src/github.com/Shopify/sidekick-data/src/webapp/models/ggml-adapter-model.bin",
+    # n_gpu_layers=1000,
+    n_ctx=2048,
+    grammar=grammar,
+)
+
+# response = llm("make my theme orange")
+
+import code
+code.interact(local=globals())
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 4b6ce8c37b..5846525f08 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -219,6 +219,7 @@ def __init__(
         last_n_tokens_size: int = 64,
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
+        grammar: Optional[str] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -272,6 +273,12 @@ def __init__(
 
         self.lora_base = lora_base
         self.lora_path = lora_path
+        self.grammar = grammar
+
+        if grammar:
+            self.grammar = llama_cpp.llama_parse_grammar(
+                llama_cpp.c_char_p(self.grammar.encode("utf-8"))
+            )
 
         ### DEPRECATED ###
         self.n_parts = n_parts
@@ -496,8 +503,16 @@ def _sample(
         )
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
+
+        if self.grammar:
+            llama_cpp.llama_sample_grammar(
+                self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),
+                grammar=self.grammar,
+            )  # type: ignore
+
         if temp.value == 0.0:
-            return llama_cpp.llama_sample_token_greedy(
+            id = llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
@@ -509,7 +524,7 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token_mirostat(
+            id = llama_cpp.llama_sample_token_mirostat(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 tau=mirostat_tau,
@@ -524,7 +539,7 @@ def _sample(
                 candidates=llama_cpp.ctypes.pointer(candidates),
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token_mirostat_v2(
+            id = llama_cpp.llama_sample_token_mirostat_v2(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 tau=mirostat_tau,
@@ -561,11 +576,21 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token(
+            id = llama_cpp.llama_sample_token(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
 
+        if self.grammar:
+            breakpoint()
+            id = llama_cpp.llama_grammar_accept_token(
+                self.ctx,
+                self.grammar,
+                id
+            )
+
+        return id
+
     def sample(
         self,
         top_k: int = 40,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 29136c7e93..3291a14462 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -113,6 +113,8 @@ def _load_shared_library(lib_base_name: str):
 llama_token = c_int
 llama_token_p = POINTER(llama_token)
 
+llama_grammar_p = c_void_p
+
 
 # typedef struct llama_token_data {
 #     llama_token id; // token id
@@ -793,6 +795,45 @@ def llama_sample_temperature(
 _lib.llama_sample_temperature.restype = None
 
 
+def llama_parse_grammar(grammar: str):
+    return _lib.llama_parse_grammar(grammar)
+
+_lib.llama_parse_grammar.argtypes = [
+    c_char_p,
+]
+_lib.llama_parse_grammar.restype = llama_grammar_p
+
+
+def llama_sample_grammar(
+    ctx: llama_context_p,
+    candidates,  # type: _Pointer[llama_token_data_array]
+    grammar: llama_grammar_p,
+):
+    return _lib.llama_sample_grammar(ctx, candidates, grammar)
+
+_lib.llama_sample_grammar.argtypes = [
+    llama_context_p,
+    llama_token_data_array_p,
+    llama_grammar_p,
+]
+_lib.llama_sample_grammar.restype = None
+
+
+def llama_grammar_accept_token(
+    ctx: llama_context_p,
+    grammar: llama_grammar_p,
+    id: llama_token,
+):
+    return _lib.llama_grammar_accept_token(ctx, grammar, id)
+
+_lib.llama_grammar_accept_token.argtypes = [
+    llama_context_p,
+    llama_grammar_p,
+    llama_token
+]
+_lib.llama_grammar_accept_token.restype = llama_token
+
+
 # @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
 # @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
 # @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index 4de0334f5c..3e78f0071a 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit 4de0334f5cabf4696eced2e5d6e279fdfaa6c0f2
+Subproject commit 3e78f0071a76fac0a9807bd32de805d2ac67401a

From c37a02982b319b5247c2e11191c1831980a60572 Mon Sep 17 00:00:00 2001
From: Alex Pilon
Date: Wed, 14 Jun 2023 17:14:14 -0400
Subject: [PATCH 2/2] add bindings for llama_grammar_parse / llama_grammar_from_state

---
 grammar_test.py        |  9 +++------
 llama_cpp/llama.py     | 16 ++++++++--------
 llama_cpp/llama_cpp.py | 19 +++++++++++++++----
 vendor/llama.cpp       |  2 +-
 4 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/grammar_test.py b/grammar_test.py
index 535c1736e6..a2e436ff3f 100644
--- a/grammar_test.py
+++ b/grammar_test.py
@@ -1,7 +1,6 @@
 from llama_cpp import Llama
 
-grammar = """
-root ::= nav eol (commands eol)*
+grammar = """root ::= nav eol (commands eol)*
 commands ::= t | info
 nav ::= "nav(\\"admin/" [a-z/]* "\\")"
 info ::= "info(" setting ")"
@@ -17,15 +16,13 @@
 
 llm = Llama(
     model_path="/Users/alex/llama-7b.ggmlv3.q8_0.bin",
-    lora_base="/Users/alex/llama-7b.ggml.f16.bin",
+    # lora_base="/Users/alex/llama-7b.ggml.f16.bin",
     # python ~/llama.cpp/convert-lora-to-ggml.py .
-    lora_path="/Users/alex/src/github.com/Shopify/sidekick-data/src/webapp/models/ggml-adapter-model.bin",
+    # lora_path="/Users/alex/src/github.com/Shopify/sidekick-data/src/webapp/models/ggml-adapter-model.bin",
     # n_gpu_layers=1000,
     n_ctx=2048,
     grammar=grammar,
 )
 
-# response = llm("make my theme orange")
-
 import code
 code.interact(local=globals())
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 5846525f08..20e6eb0835 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -273,12 +273,6 @@ def __init__(
 
         self.lora_base = lora_base
         self.lora_path = lora_path
-        self.grammar = grammar
-
-        if grammar:
-            self.grammar = llama_cpp.llama_parse_grammar(
-                llama_cpp.c_char_p(self.grammar.encode("utf-8"))
-            )
 
         ### DEPRECATED ###
         self.n_parts = n_parts
@@ -306,6 +300,12 @@ def __init__(
                     f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
                 )
 
+        # Always define both attributes: _sample() tests self.grammar even when no grammar was given.
+        self.parse_state = self.grammar = None
+        if grammar:
+            self.parse_state = llama_cpp.llama_grammar_parse(grammar.encode("utf-8"))
+            self.grammar = llama_cpp.llama_grammar_from_state(self.parse_state)
+
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
@@ -582,7 +582,6 @@ def _sample(
             )
 
         if self.grammar:
-            breakpoint()
             id = llama_cpp.llama_grammar_accept_token(
                 self.ctx,
                 self.grammar,
@@ -890,7 +889,8 @@ def _create_completion(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
         ):
-            if token == self._token_eos:
+
+            if token == self._token_eos: #or token == self._token_nl:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 3291a14462..d5eec0959e 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -114,6 +114,8 @@ def _load_shared_library(lib_base_name: str):
 llama_token = c_int
 llama_token_p = POINTER(llama_token)
 
+# struct llama_grammar
+parse_state_p = c_void_p
 llama_grammar_p = c_void_p
 
 # typedef struct llama_token_data {
@@ -796,13 +798,22 @@ def llama_sample_temperature(
 _lib.llama_sample_temperature.restype = None
 
 
-def llama_parse_grammar(grammar: str):
-    return _lib.llama_parse_grammar(grammar)
+def llama_grammar_parse(grammar: bytes):
+    return _lib.llama_grammar_parse(grammar)
 
-_lib.llama_parse_grammar.argtypes = [
+_lib.llama_grammar_parse.argtypes = [
     c_char_p,
 ]
-_lib.llama_parse_grammar.restype = llama_grammar_p
+_lib.llama_grammar_parse.restype = parse_state_p
+
+
+def llama_grammar_from_state(parse_state: parse_state_p):
+    return _lib.llama_grammar_from_state(parse_state)
+
+_lib.llama_grammar_from_state.argtypes = [
+    parse_state_p
+]
+_lib.llama_grammar_from_state.restype = llama_grammar_p
 
 
 def llama_sample_grammar(
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index 3e78f0071a..9d0fcb0c35 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit 3e78f0071a76fac0a9807bd32de805d2ac67401a
+Subproject commit 9d0fcb0c350305a91ce7460c57228f2d259a804f