2121import diskcache
2222import ctypes
2323
24- from . import llama_cpp
2524from .llama_types import *
2625from .llama_grammar import LlamaGrammar
26+ import llama_cpp .llama_cpp as llama_cpp
2727import llama_cpp .llama_chat_format as llama_chat_format
2828
2929import numpy as np
@@ -752,6 +752,7 @@ def __init__(
752752 numa : bool = False ,
753753 # Chat Format Params
754754 chat_format : str = "llama-2" ,
755+ chat_handler : Optional [llama_chat_format .LlamaChatCompletionHandler ] = None ,
755756 # Misc
756757 verbose : bool = True ,
757758 # Extra Params
@@ -784,6 +785,7 @@ def __init__(
784785 lora_path: Path to a LoRA file to apply to the model.
785786 numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
786787 chat_format: String specifying the chat format to use when calling create_chat_completion.
788+ chat_handler: Optional chat handler to use when calling create_chat_completion.
787789 verbose: Print verbose output to stderr.
788790
789791 Raises:
@@ -910,6 +912,7 @@ def __init__(
910912 print (llama_cpp .llama_print_system_info ().decode ("utf-8" ), file = sys .stderr )
911913
912914 self .chat_format = chat_format
915+ self .chat_handler = chat_handler
913916
914917 self ._n_vocab = self .n_vocab ()
915918 self ._n_ctx = self .n_ctx ()
@@ -1231,7 +1234,7 @@ def create_embedding(
12311234 else :
12321235 inputs = input
12331236
1234- data : List [EmbeddingData ] = []
1237+ data : List [Embedding ] = []
12351238 total_tokens = 0
12361239 for index , input in enumerate (inputs ):
12371240 tokens = self .tokenize (input .encode ("utf-8" ), special = True )
@@ -1276,7 +1279,7 @@ def embed(self, input: str) -> List[float]:
12761279
12771280 def _create_completion (
12781281 self ,
1279- prompt : str ,
1282+ prompt : Union [ str , List [ int ]] ,
12801283 suffix : Optional [str ] = None ,
12811284 max_tokens : int = 16 ,
12821285 temperature : float = 0.8 ,
@@ -1297,7 +1300,9 @@ def _create_completion(
12971300 stopping_criteria : Optional [StoppingCriteriaList ] = None ,
12981301 logits_processor : Optional [LogitsProcessorList ] = None ,
12991302 grammar : Optional [LlamaGrammar ] = None ,
1300- ) -> Union [Iterator [Completion ], Iterator [CompletionChunk ]]:
1303+ ) -> Union [
1304+ Iterator [CreateCompletionResponse ], Iterator [CreateCompletionStreamResponse ]
1305+ ]:
13011306 assert self ._ctx is not None
13021307 assert suffix is None or suffix .__class__ is str
13031308
@@ -1309,7 +1314,7 @@ def _create_completion(
13091314 self .tokenize (prompt .encode ("utf-8" ), special = True )
13101315 if prompt != ""
13111316 else [self .token_bos ()]
1312- )
1317+ ) if isinstance ( prompt , str ) else prompt
13131318 text : bytes = b""
13141319 returned_tokens : int = 0
13151320 stop = (
@@ -1322,7 +1327,7 @@ def _create_completion(
13221327
13231328 if len (prompt_tokens ) >= self ._n_ctx :
13241329 raise ValueError (
1325- f"Requested tokens ({ len (prompt_tokens )} ) exceed context window of { llama_cpp .llama_n_ctx (self ._ctx )} "
1330+ f"Requested tokens ({ len (prompt_tokens )} ) exceed context window of { llama_cpp .llama_n_ctx (self .ctx )} "
13261331 )
13271332
13281333 if max_tokens <= 0 :
@@ -1732,7 +1737,7 @@ def _create_completion(
17321737
17331738 def create_completion (
17341739 self ,
1735- prompt : str ,
1740+ prompt : Union [ str , List [ int ]] ,
17361741 suffix : Optional [str ] = None ,
17371742 max_tokens : int = 128 ,
17381743 temperature : float = 0.8 ,
@@ -1753,7 +1758,7 @@ def create_completion(
17531758 stopping_criteria : Optional [StoppingCriteriaList ] = None ,
17541759 logits_processor : Optional [LogitsProcessorList ] = None ,
17551760 grammar : Optional [LlamaGrammar ] = None ,
1756- ) -> Union [Completion , Iterator [CompletionChunk ]]:
1761+ ) -> Union [CreateCompletionResponse , Iterator [CreateCompletionStreamResponse ]]:
17571762 """Generate text from a prompt.
17581763
17591764 Args:
@@ -1800,7 +1805,7 @@ def create_completion(
18001805 grammar = grammar ,
18011806 )
18021807 if stream :
1803- chunks : Iterator [CompletionChunk ] = completion_or_chunks
1808+ chunks : Iterator [CreateCompletionStreamResponse ] = completion_or_chunks
18041809 return chunks
18051810 completion : Completion = next (completion_or_chunks ) # type: ignore
18061811 return completion
@@ -1828,7 +1833,7 @@ def __call__(
18281833 stopping_criteria : Optional [StoppingCriteriaList ] = None ,
18291834 logits_processor : Optional [LogitsProcessorList ] = None ,
18301835 grammar : Optional [LlamaGrammar ] = None ,
1831- ) -> Union [Completion , Iterator [CompletionChunk ]]:
1836+ ) -> Union [CreateCompletionResponse , Iterator [CreateCompletionStreamResponse ]]:
18321837 """Generate text from a prompt.
18331838
18341839 Args:
@@ -1879,7 +1884,9 @@ def create_chat_completion(
18791884 self ,
18801885 messages : List [ChatCompletionRequestMessage ],
18811886 functions : Optional [List [ChatCompletionFunction ]] = None ,
1882- function_call : Optional [Union [str , ChatCompletionFunctionCall ]] = None ,
1887+ function_call : Optional [ChatCompletionRequestFunctionCall ] = None ,
1888+ tools : Optional [List [ChatCompletionTool ]] = None ,
1889+ tool_choice : Optional [ChatCompletionToolChoiceOption ] = None ,
18831890 temperature : float = 0.2 ,
18841891 top_p : float = 0.95 ,
18851892 top_k : int = 40 ,
@@ -1896,7 +1903,9 @@ def create_chat_completion(
18961903 model : Optional [str ] = None ,
18971904 logits_processor : Optional [LogitsProcessorList ] = None ,
18981905 grammar : Optional [LlamaGrammar ] = None ,
1899- ) -> Union [ChatCompletion , Iterator [ChatCompletionChunk ]]:
1906+ ) -> Union [
1907+ CreateChatCompletionResponse , Iterator [CreateChatCompletionStreamResponse ]
1908+ ]:
19001909 """Generate a chat completion from a list of messages.
19011910
19021911 Args:
@@ -1912,12 +1921,16 @@ def create_chat_completion(
19121921 Returns:
19131922 Generated chat completion or a stream of chat completion chunks.
19141923 """
1915- handler = llama_chat_format .get_chat_completion_handler (self .chat_format )
1924+ handler = self .chat_handler or llama_chat_format .get_chat_completion_handler (
1925+ self .chat_format
1926+ )
19161927 return handler (
1917- self ,
1928+ llama = self ,
19181929 messages = messages ,
19191930 functions = functions ,
19201931 function_call = function_call ,
1932+ tools = tools ,
1933+ tool_choice = tool_choice ,
19211934 temperature = temperature ,
19221935 top_p = top_p ,
19231936 top_k = top_k ,
@@ -1974,6 +1987,7 @@ def __getstate__(self):
19741987 numa = self .numa ,
19751988 # Chat Format Params
19761989 chat_format = self .chat_format ,
1990+ chat_handler = self .chat_handler ,
19771991 # Misc
19781992 verbose = self .verbose ,
19791993 )
@@ -2015,6 +2029,7 @@ def __setstate__(self, state):
20152029 numa = state ["numa" ],
20162030 # Chat Format Params
20172031 chat_format = state ["chat_format" ],
2032+ chat_handler = state ["chat_handler" ],
20182033 # Misc
20192034 verbose = state ["verbose" ],
20202035 )
0 commit comments