Lokasi ngalangkungan proxy:   [ UP ]  
[Ngawartoskeun bug]   [Panyetelan cookie]                
Skip to content

Commit a7db23a

Browse files
committed
feat(chat-format): improve Jinja2ChatFormatter HF compatibility
Enhance Jinja2ChatFormatter to better support HuggingFace-style chat templates while keeping the formatter lightweight and aligned with llama-cpp-python's prompt-rendering needs. This change adds a custom Jinja extension for `{% generation %}` blocks. HuggingFace Transformers uses this tag to track assistant-token spans for assistant masks, but llama-cpp-python only needs the final rendered prompt. The new IgnoreGenerationTags extension therefore treats the tag as a transparent wrapper: it removes the generation/endgeneration tag pair while rendering the inner template body normally. This allows templates that contain `{% generation %}` blocks to render successfully without introducing span tracking overhead. The Jinja environment is also expanded to more closely match Transformers' chat-template runtime behavior. It now enables `jinja2.ext.loopcontrols` for templates that use `{% break %}` or `{% continue %}`, registers a plain JSON `tojson` filter that avoids Jinja's HTML escaping behavior, and exposes `raise_exception` and `strftime_now` as globals instead of passing them on every render call. The formatter now accepts an optional `special_tokens_map`, making additional tokenizer special tokens available to templates. This improves compatibility with templates that reference variables such as `pad_token`, `unk_token`, `sep_token`, or model-specific special tokens beyond `bos_token` and `eos_token`. This also adds optional `documents` support to `__call__`, allowing RAG-style or document-aware chat templates to receive a `documents` variable in the render context. Finally, static stop fields are precomputed during initialization. Text stop sequences and token-id stopping criteria are now built once instead of being recreated for every chat formatting call. The token-id stopping callback also guards against empty token arrays before reading the last token. Key changes: - Add IgnoreGenerationTags Jinja extension for HF `{% generation %}` blocks. - Enable Jinja loop controls for chat templates using break/continue. - Register Transformers-compatible `tojson` behavior. - Register `raise_exception` and `strftime_now` as Jinja globals. - Add `special_tokens_map` support for additional template variables. - Add optional `documents` argument for document-aware templates. - Precompute text stop sequences and token-id stopping criteria. - Improve type normalization for `stop_token_ids`. - Expand docstrings for formatter initialization and render-time variables. Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent 9bb06da commit a7db23a

1 file changed

Lines changed: 232 additions & 32 deletions

File tree

llama_cpp/llama_chat_format.py

Lines changed: 232 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727

2828
import jinja2
29+
from jinja2.ext import Extension
2930
from jinja2.sandbox import ImmutableSandboxedEnvironment
3031

3132
import numpy as np
@@ -220,28 +221,165 @@ def __call__(
220221

221222

222223
class Jinja2ChatFormatter(ChatFormatter):
224+
class IgnoreGenerationTags(Extension):
225+
"""Render HuggingFace `{% generation %}` blocks without tracking.
226+
227+
HuggingFace chat templates may wrap assistant text with:
228+
229+
{% generation %}
230+
...
231+
{% endgeneration %}
232+
233+
Transformers uses this tag to compute assistant-token masks. In
234+
llama-cpp-python chat formatting we only need the final rendered prompt,
235+
so this extension simply removes the tag pair and renders the inner
236+
content as normal Jinja template content.
237+
238+
This keeps compatibility with HF templates while avoiding the overhead
239+
of span tracking.
240+
241+
More information see:
242+
https://github.com/huggingface/transformers/blob/39603d0e5cdb6f00e8d473d7fcbb01032d709181/src/transformers/utils/chat_template_utils.py#L425
243+
"""
244+
245+
tags = {"generation"}
246+
247+
def parse(self, parser: jinja2.parser.Parser):
248+
# Consume the opening `{% generation %}` token.
249+
lineno = next(parser.stream).lineno
250+
251+
# Parse and return the block body until `{% endgeneration %}`.
252+
# Returning the body directly makes the tag a transparent wrapper.
253+
body = parser.parse_statements(
254+
("name:endgeneration",),
255+
drop_needle=True,
256+
)
257+
258+
# Preserve line numbers for better template error messages.
259+
for node in body:
260+
node.set_lineno(lineno)
261+
262+
return body
263+
223264
def __init__(
224265
self,
225266
template: str,
226267
eos_token: str,
227268
bos_token: str,
228269
add_generation_prompt: bool = True,
229270
stop_token_ids: Optional[List[int]] = None,
271+
special_tokens_map: Optional[Dict[str, str]] = None,
230272
):
231-
"""A chat formatter that uses jinja2 templates to format the prompt."""
273+
"""Format chat messages with a HuggingFace-style Jinja2 chat template.
274+
275+
Args:
276+
template:
277+
Raw HuggingFace chat template string.
278+
eos_token:
279+
Text form of the model EOS token.
280+
bos_token:
281+
Text form of the model BOS token.
282+
add_generation_prompt:
283+
Whether to ask the template to append the assistant generation
284+
prefix. This mirrors Transformers' `add_generation_prompt`.
285+
stop_token_ids:
286+
Optional token ids that should stop generation when they appear
287+
as the last generated token. This is llama-cpp-python specific.
288+
special_tokens_map:
289+
Optional tokenizer special-token map. Some HF templates may
290+
reference extra variables such as `pad_token`, `unk_token`,
291+
`sep_token`, or model-specific special tokens.
292+
"""
232293
self.template = template
233294
self.eos_token = eos_token
234295
self.bos_token = bos_token
235296
self.add_generation_prompt = add_generation_prompt
297+
self.special_tokens_map = special_tokens_map or {}
298+
236299
self.stop_token_ids = (
237-
set(stop_token_ids) if stop_token_ids is not None else None
300+
{int(token_id) for token_id in stop_token_ids}
301+
if stop_token_ids is not None
302+
else None
238303
)
239304

240-
self._environment = ImmutableSandboxedEnvironment(
305+
environment = ImmutableSandboxedEnvironment(
241306
loader=jinja2.BaseLoader(),
242307
trim_blocks=True,
243308
lstrip_blocks=True,
244-
).from_string(self.template)
309+
# Keep this aligned with Transformers' chat-template Jinja setup:
310+
# - IgnoreGenerationTags supports `{% generation %}` blocks.
311+
# - loopcontrols supports `{% break %}` and `{% continue %}`.
312+
extensions=[
313+
Jinja2ChatFormatter.IgnoreGenerationTags,
314+
jinja2.ext.loopcontrols,
315+
],
316+
)
317+
318+
# Match Transformers' chat-template JSON behavior.
319+
# Jinja's default `tojson` escapes HTML characters, which is not what
320+
# plain-text chat templates usually expect.
321+
environment.filters["tojson"] = self.tojson
322+
323+
# Register these as globals once instead of passing them on every render.
324+
environment.globals["raise_exception"] = self.raise_exception
325+
environment.globals["strftime_now"] = self.strftime_now
326+
327+
self._environment = environment
328+
self._template = environment.from_string(self.template)
329+
330+
# Precompute static stop fields once. This avoids rebuilding closures and
331+
# StoppingCriteriaList objects for every chat completion request.
332+
self._stop = [self.eos_token] if self.eos_token else []
333+
self._stopping_criteria = self._build_stopping_criteria()
334+
335+
@staticmethod
336+
def raise_exception(message: str):
337+
"""Raise a Jinja template error from inside a chat template."""
338+
raise jinja2.exceptions.TemplateError(message)
339+
340+
@staticmethod
341+
def strftime_now(format_string: str = "%Y-%m-%d %H:%M:%S") -> str:
342+
"""Return the current local time formatted with `datetime.strftime`."""
343+
return datetime.datetime.now().strftime(format_string)
344+
345+
@staticmethod
346+
def tojson(
347+
x: Any,
348+
ensure_ascii: bool = False,
349+
indent: Optional[int] = None,
350+
separators: Optional[Tuple[str, str]] = None,
351+
sort_keys: bool = False,
352+
) -> str:
353+
"""Serialize an object to JSON for chat-template rendering.
354+
355+
This intentionally bypasses Jinja's built-in `tojson` filter because
356+
the built-in filter escapes HTML-sensitive characters. HuggingFace chat
357+
templates expect plain JSON text instead.
358+
"""
359+
return json.dumps(
360+
x,
361+
ensure_ascii=ensure_ascii,
362+
indent=indent,
363+
separators=separators,
364+
sort_keys=sort_keys,
365+
)
366+
367+
def _build_stopping_criteria(self):
368+
"""Create stopping criteria once during initialization."""
369+
if self.stop_token_ids is None:
370+
return None
371+
372+
stop_token_ids = self.stop_token_ids
373+
374+
def stop_on_last_token(
375+
tokens: npt.NDArray[np.intc],
376+
logits: npt.NDArray[np.single],
377+
) -> bool:
378+
# Defensive guard: generation normally calls this with at least one
379+
# token, but the callback should never crash on empty input.
380+
return len(tokens) > 0 and int(tokens[-1]) in stop_token_ids
381+
382+
return llama_core.StoppingCriteriaList([stop_on_last_token])
245383

246384
def __call__(
247385
self,
@@ -251,44 +389,106 @@ def __call__(
251389
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
252390
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
253391
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
392+
documents: Optional[List[Dict[str, Any]]] = None,
254393
**kwargs: Any,
255394
) -> ChatFormatterResponse:
256-
def raise_exception(message: str):
257-
raise ValueError(message)
395+
"""Render OpenAI-style chat messages into a model prompt.
258396
259-
def strftime_now(format_string="%Y-%m-%d %H:%M:%S") -> str:
260-
"""
261-
Returns the current time formatted as a string.
262-
"""
263-
return datetime.datetime.now().strftime(format_string)
397+
The method builds the variable context expected by HuggingFace-style
398+
Jinja chat templates and renders the final prompt string used by
399+
llama-cpp-python.
264400
265-
prompt = self._environment.render(
266-
messages=messages,
267-
eos_token=self.eos_token,
268-
bos_token=self.bos_token,
269-
raise_exception=raise_exception,
270-
strftime_now=strftime_now,
271-
add_generation_prompt=self.add_generation_prompt,
272-
functions=functions,
273-
function_call=function_call,
274-
tools=tools,
275-
tool_choice=tool_choice,
276-
)
401+
Template variables provided by default:
402+
messages:
403+
The chat history to render. Each item is expected to be an
404+
OpenAI-style message dictionary, usually containing at least
405+
`role` and `content`.
277406
278-
stopping_criteria = None
279-
if self.stop_token_ids is not None:
407+
eos_token:
408+
The model's end-of-sequence token string.
409+
410+
bos_token:
411+
The model's beginning-of-sequence token string.
412+
413+
add_generation_prompt:
414+
Whether the template should append the assistant generation
415+
prefix. This mirrors Transformers' `add_generation_prompt`.
416+
417+
functions:
418+
Legacy OpenAI-compatible function definitions, if provided.
280419
281-
def stop_on_last_token(
282-
tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
283-
) -> bool:
284-
return tokens[-1] in self.stop_token_ids
420+
function_call:
421+
Legacy OpenAI-compatible function-call selection, if provided.
285422
286-
stopping_criteria = llama_core.StoppingCriteriaList([stop_on_last_token])
423+
tools:
424+
OpenAI/HuggingFace-compatible tool definitions, if provided.
425+
This formatter expects tools to already be normalized into
426+
JSON-schema-like dictionaries. It does not auto-convert Python
427+
callables into JSON schemas like Transformers can.
428+
429+
tool_choice:
430+
Optional tool-choice instruction, such as `"auto"`, `"none"`,
431+
or a specific tool/function selection object.
432+
433+
documents:
434+
Optional RAG/document context. Some HF chat templates reference
435+
this variable when rendering retrieval-augmented prompts.
436+
437+
**kwargs:
438+
Extra model-specific or template-specific variables. These are
439+
merged into the template context last, so they can intentionally
440+
override the defaults above when needed.
441+
442+
Additional variables:
443+
Values from `special_tokens_map` are also exposed to the template,
444+
such as `pad_token`, `unk_token`, `sep_token`, or custom
445+
model-specific special tokens. Core variables like `messages`,
446+
`eos_token`, and `bos_token` override `special_tokens_map` entries
447+
by default.
448+
449+
Returns:
450+
ChatFormatterResponse:
451+
Contains the rendered prompt, text stop sequences, optional
452+
token-id stopping criteria, and `added_special=True` because the
453+
chat template is responsible for adding model special tokens.
454+
455+
Raises:
456+
jinja2.exceptions.TemplateError:
457+
If the template calls `raise_exception(...)` or Jinja rendering
458+
fails.
459+
"""
460+
template_kwargs: Dict[str, Any] = {}
461+
462+
# Make extra tokenizer special tokens available to templates, e.g.
463+
# `pad_token`, `unk_token`, `sep_token`, or model-specific tokens.
464+
template_kwargs.update(self.special_tokens_map)
465+
466+
# Explicit core variables should override values from special_tokens_map.
467+
template_kwargs.update(
468+
{
469+
"messages": messages,
470+
"eos_token": self.eos_token,
471+
"bos_token": self.bos_token,
472+
"add_generation_prompt": self.add_generation_prompt,
473+
"functions": functions,
474+
"function_call": function_call,
475+
"tools": tools,
476+
"tool_choice": tool_choice,
477+
"documents": documents,
478+
}
479+
)
480+
481+
# Let caller-provided kwargs extend the template context.
482+
# If a caller intentionally passes a same-name key, it will override the
483+
# defaults above. This is useful for model-specific template variables.
484+
template_kwargs.update(kwargs)
485+
486+
prompt = self._template.render(**template_kwargs)
287487

288488
return ChatFormatterResponse(
289489
prompt=prompt,
290-
stop=[self.eos_token],
291-
stopping_criteria=stopping_criteria,
490+
stop=self._stop,
491+
stopping_criteria=self._stopping_criteria,
292492
added_special=True,
293493
)
294494

0 commit comments

Comments
 (0)