Skip to content

Commit 704039b

Browse files
committed
feat(server): add model-load chat_template_kwargs
1 parent 7613aca commit 704039b

6 files changed

Lines changed: 147 additions & 4 deletions

File tree

llama_cpp/llama_chat_format.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def raise_exception(message: str):
243243
tools=tools,
244244
tool_choice=tool_choice,
245245
strftime_now=self.strftime_now,
246+
**kwargs,
246247
)
247248

248249
stopping_criteria = None
@@ -617,6 +618,7 @@ def chat_completion_handler(
617618
function_call=function_call,
618619
tools=tools,
619620
tool_choice=tool_choice,
621+
**kwargs,
620622
)
621623
prompt = llama.tokenize(
622624
result.prompt.encode("utf-8"),
@@ -734,7 +736,9 @@ def format_autotokenizer(
734736
**kwargs: Any,
735737
) -> ChatFormatterResponse:
736738
tokenizer.use_default_system_prompt = False # type: ignore
737-
prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
739+
prompt: str = tokenizer.apply_chat_template( # type: ignore
740+
messages, tokenize=False, **kwargs
741+
)
738742
assert isinstance(prompt, str)
739743
# Return formatted prompt and eos token by default
740744
return ChatFormatterResponse(
@@ -791,6 +795,7 @@ def format_tokenizer_config(
791795
messages=messages,
792796
bos_token=bos_token,
793797
eos_token=eos_token,
798+
**kwargs,
794799
)
795800
return ChatFormatterResponse(
796801
prompt=prompt, stop=[eos_token, bos_token], added_special=True

llama_cpp/server/cli.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
import argparse
4+
import json
45

5-
from typing import List, Literal, Union, Any, Type, TypeVar
6+
from typing import List, Literal, Union, Any, Type, TypeVar, Dict
67

78
from pydantic import BaseModel
89

@@ -40,6 +41,17 @@ def _contains_list_type(annotation: Type[Any] | None) -> bool:
4041
return False
4142

4243

44+
def _contains_dict_type(annotation: Type[Any] | None) -> bool:
45+
origin = getattr(annotation, "__origin__", None)
46+
47+
if origin is dict or origin is Dict:
48+
return True
49+
elif origin in (Literal, Union):
50+
return any(_contains_dict_type(arg) for arg in annotation.__args__) # type: ignore
51+
else:
52+
return False
53+
54+
4355
def _parse_bool_arg(arg: str | bytes | bool) -> bool:
4456
if isinstance(arg, bytes):
4557
arg = arg.decode("utf-8")
@@ -57,6 +69,16 @@ def _parse_bool_arg(arg: str | bytes | bool) -> bool:
5769
raise ValueError(f"Invalid boolean argument: {arg}")
5870

5971

72+
def _parse_json_object_arg(arg: str | bytes) -> dict[str, Any]:
73+
if isinstance(arg, bytes):
74+
arg = arg.decode("utf-8")
75+
76+
value = json.loads(arg)
77+
if not isinstance(value, dict):
78+
raise ValueError(f"Invalid JSON object argument: {arg}")
79+
return value
80+
81+
6082
def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]):
6183
"""Add arguments from a pydantic model to an argparse parser."""
6284

@@ -68,7 +90,15 @@ def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel])
6890
_get_base_type(field.annotation) if field.annotation is not None else str
6991
)
7092
list_type = _contains_list_type(field.annotation)
71-
if base_type is not bool:
93+
dict_type = _contains_dict_type(field.annotation)
94+
if dict_type:
95+
parser.add_argument(
96+
f"--{name}",
97+
dest=name,
98+
type=_parse_json_object_arg,
99+
help=description,
100+
)
101+
elif base_type is not bool:
72102
parser.add_argument(
73103
f"--{name}",
74104
dest=name,

llama_cpp/server/model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,21 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
299299
# Misc
300300
verbose=settings.verbose,
301301
)
302+
if settings.chat_template_kwargs:
303+
base_chat_handler = (
304+
_model.chat_handler
305+
or _model._chat_handlers.get(_model.chat_format)
306+
or llama_cpp.llama_chat_format.get_chat_completion_handler(
307+
_model.chat_format
308+
)
309+
)
310+
311+
def chat_handler_with_kwargs(*args, **kwargs):
312+
return base_chat_handler(
313+
*args, **{**settings.chat_template_kwargs, **kwargs}
314+
)
315+
316+
_model.chat_handler = chat_handler_with_kwargs
302317
if settings.cache:
303318
if settings.cache_type == "disk":
304319
if settings.verbose:

llama_cpp/server/settings.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import multiprocessing
44

5-
from typing import Optional, List, Literal, Union, Dict, cast
5+
from typing import Any, Optional, List, Literal, Union, Dict, cast
66
from typing_extensions import Self
77

88
from pydantic import Field, model_validator
@@ -131,6 +131,10 @@ class ModelSettings(BaseSettings):
131131
default=None,
132132
description="Chat format to use.",
133133
)
134+
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
135+
default=None,
136+
description="Extra keyword arguments forwarded to chat templates at model load time. Matches llama.cpp server `chat_template_kwargs`.",
137+
)
134138
clip_model_path: Optional[str] = Field(
135139
default=None,
136140
description="Path to a CLIP model to use for multi-modal chat completion.",

tests/test_llama_chat_format.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,22 @@ def test_hf_tokenizer_config_str_to_chat_formatter():
9292
)
9393

9494
assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>")
95+
96+
97+
def test_hf_tokenizer_config_chat_formatter_passes_template_kwargs():
    """Kwargs supplied at formatter-construction time must reach the template."""
    config = {
        "chat_template": "{{ bos_token }}{{ enable_thinking | default(false) }} {{ messages[0]['content'] }}",
        "bos_token": "<s>",
        "eos_token": "</s>",
    }
    formatter = hf_tokenizer_config_to_chat_formatter(
        config, add_generation_prompt=False
    )

    result = formatter(
        messages=[
            ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
        ],
        enable_thinking=True,
    )

    # enable_thinking=True overrides the template's default(false) fallback.
    assert result.prompt == "<s>True Hello, world!"

tests/test_server_model.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import argparse
2+
3+
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
4+
import llama_cpp.server.model as server_model
5+
from llama_cpp.server.settings import ModelSettings
6+
7+
8+
def test_cli_parses_chat_template_kwargs_json():
    """--chat_template_kwargs must be parsed from a JSON-object string."""
    parser = argparse.ArgumentParser()
    add_args_from_model(parser, ModelSettings)

    argv = [
        "--model",
        "test.gguf",
        "--chat_template_kwargs",
        '{"enable_thinking": true, "template_mode": "extended"}',
    ]
    settings = parse_model_from_args(ModelSettings, parser.parse_args(argv))

    expected = {"enable_thinking": True, "template_mode": "extended"}
    assert settings.chat_template_kwargs == expected
26+
27+
28+
def test_load_llama_from_model_settings_merges_chat_template_kwargs(monkeypatch):
    """Load-time chat_template_kwargs are merged into every chat-handler call,
    with per-call kwargs taking precedence over the load-time defaults."""
    captured = {}

    def base_handler(*args, **kwargs):
        captured["args"] = args
        captured["kwargs"] = kwargs
        return "ok"

    class FakeLlama:
        """Minimal stand-in exposing only what the wrapper logic touches."""

        def __init__(self, **kwargs):
            self.chat_handler = kwargs["chat_handler"]
            self.chat_format = kwargs["chat_format"]
            self._chat_handlers = {}

        def set_cache(self, cache):
            raise AssertionError("cache should not be set in this test")

    monkeypatch.setattr(server_model.llama_cpp, "Llama", FakeLlama)
    monkeypatch.setattr(
        server_model.llama_cpp.llama_chat_format,
        "get_chat_completion_handler",
        lambda chat_format: base_handler,
    )

    settings = ModelSettings(
        model="test.gguf",
        chat_format="chatml",
        chat_template_kwargs={
            "enable_thinking": True,
            "template_mode": "default",
        },
    )
    model = server_model.LlamaProxy.load_llama_from_model_settings(settings)

    outcome = model.chat_handler(template_mode="override", extra_flag="x")

    assert outcome == "ok"
    # Call-site kwargs win over load-time defaults for the shared key.
    assert captured["kwargs"] == {
        "enable_thinking": True,
        "template_mode": "override",
        "extra_flag": "x",
    }

0 commit comments

Comments (0)