Skip to content

Commit

Permalink
Merge branch 'huggingface:main' into gguf-dequant
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored Aug 24, 2024
2 parents b7eae81 + 0a7af19 commit 8d66904
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 128 deletions.
70 changes: 34 additions & 36 deletions docs/source/en/chat_templating.md
Original file line number Diff line number Diff line change
Expand Up @@ -785,14 +785,23 @@ it's time to put an end to them!

## Advanced: Template writing tips

If you're unfamiliar with Jinja, we generally find that the easiest way to write a chat template is to first
write a short Python script that formats messages the way you want, and then convert that script into a template.
<Tip>

The easiest way to get started with writing Jinja templates is to take a look at some existing ones. You can use
`print(tokenizer.chat_template)` for any chat model to see what template it's using. In general, models that support tool use have
much more complex templates than other models - so when you're just getting started, they're probably a bad example
to learn from! You can also take a look at the
[Jinja documentation](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) for details
of general Jinja formatting and syntax.

Remember that the template handler will receive the conversation history as a variable called `messages`.
</Tip>

Jinja templates in `transformers` are identical to Jinja templates elsewhere. The main thing to know is that
the conversation history will be accessible inside your template as a variable called `messages`.
You will be able to access `messages` in your template just like you can in Python, which means you can loop over
it with `{% for message in messages %}` or access individual messages with `{{ messages[0] }}`, for example.

You can also use the following tips to convert your code to Jinja:
You can also use the following tips to write clean, efficient Jinja templates:

### Trimming whitespace

Expand All @@ -817,46 +826,35 @@ rather than like this:
Adding `-` will strip any whitespace that comes before the block. The second example looks innocent, but the newline
and indentation may end up being included in the output, which is probably not what you want!

### For loops

For loops in Jinja look like this:

```
{%- for message in messages %}
{{- message['content'] }}
{%- endfor %}
```
### Special variables

Note that whatever's inside the {{ expression block }} will be printed to the output. You can use operators like
`+` to combine strings inside expression blocks.
Inside your template, you will have access several special variables. The most important of these is `messages`,
which contains the chat history as a list of message dicts. However, there are several others. Not every
variable will be used in every template. The most common other variables are:

### If statements
- `tools` contains a list of tools in JSON schema format. Will be `None` or undefined if no tools are passed.
- `documents` contains a list of documents in the format `{"title": "Title", "contents": "Contents"}`, used for retrieval-augmented generation. Will be `None` or undefined if no documents are passed.
- `add_generation_prompt` is a bool that is `True` if the user has requested a generation prompt, and `False` otherwise. If this is set, your template should add the header for an assistant message to the end of the conversation. If your model doesn't have a specific header for assistant messages, you can ignore this flag.
- **Special tokens** like `bos_token` and `eos_token`. These are extracted from `tokenizer.special_tokens_map`. The exact tokens available inside each template will differ depending on the parent tokenizer.

If statements in Jinja look like this:
<Tip>

```
{%- if message['role'] == 'user' %}
{{- message['content'] }}
{%- endif %}
```
You can actually pass any `kwarg` to `apply_chat_template`, and it will be accessible inside the template as a variable. In general,
we recommend trying to stick to the core variables above, as it will make your model harder to use if users have
to write custom code to pass model-specific `kwargs`. However, we're aware that this field moves quickly, so if you
have a new use-case that doesn't fit in the core API, feel free to use a new `kwarg` for it! If a new `kwarg`
becomes common we may promote it into the core API and create a standard, documented format for it.

Note how where Python uses whitespace to mark the beginnings and ends of `for` and `if` blocks, Jinja requires you
to explicitly end them with `{% endfor %}` and `{% endif %}`.
</Tip>

### Special variables
### Callable functions

Inside your template, you will have access to the list of `messages`, but you can also access several other special
variables. These include special tokens like `bos_token` and `eos_token`, as well as the `add_generation_prompt`
variable that we discussed above. You can also use the `loop` variable to access information about the current loop
iteration, for example using `{% if loop.last %}` to check if the current message is the last message in the
conversation. Here's an example that puts these ideas together to add a generation prompt at the end of the
conversation if add_generation_prompt is `True`:
There is also a short list of callable functions available to you inside your templates. These are:

```
{%- if loop.last and add_generation_prompt %}
{{- bos_token + 'Assistant:\n' }}
{%- endif %}
```
- `raise_exception(msg)`: Raises a `TemplateException`. This is useful for debugging, and for telling users when they're
doing something that your template doesn't support.
- `strftime_now(format_str)`: Equivalent to `datetime.now().strftime(format_str)` in Python. This is used for getting
the current date/time in a specific format, which is sometimes included in system messages.

### Compatibility with non-Python Jinja

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ at::Tensor ms_deform_attn_cuda_forward(
const at::Tensor &attn_weight,
const int im2col_step)
{
at::DeviceGuard guard(value.device());

AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
Expand Down Expand Up @@ -92,6 +94,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &grad_output,
const int im2col_step)
{
at::DeviceGuard guard(value.device());

AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
Expand Down
94 changes: 3 additions & 91 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from collections.abc import Mapping, Sized
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
from inspect import isfunction
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -65,6 +64,7 @@
requires_backends,
to_py_obj,
)
from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices


if TYPE_CHECKING:
Expand Down Expand Up @@ -1791,7 +1791,7 @@ def apply_chat_template(
)

# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)
compiled_template = _compile_jinja_template(chat_template)

if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
Expand Down Expand Up @@ -1831,7 +1831,7 @@ def apply_chat_template(
# Indicates it's a Conversation object
chat = chat.messages
if return_assistant_tokens_mask:
rendered_chat, generation_indices = self._render_with_assistant_indices(
rendered_chat, generation_indices = _render_with_assistant_indices(
compiled_template=compiled_template,
messages=chat,
tools=tool_schemas,
Expand Down Expand Up @@ -1888,94 +1888,6 @@ def apply_chat_template(
else:
return rendered

def _render_with_assistant_indices(
self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
return rendered_chat, generation_indices

@lru_cache
def _compile_jinja_template(self, chat_template):
try:
import jinja2
from jinja2 import nodes
from jinja2.exceptions import TemplateError
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")

if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)

def raise_exception(message):
raise TemplateError(message)

def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}

def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only initiated by jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None

def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv

def is_active(self) -> bool:
return self._rendered_blocks or self._generation_indices

@contextmanager
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices

yield
finally:
self._rendered_blocks = None
self._generation_indices = None

jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)

def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
"""
Retrieve the chat template string used for tokenizing chat messages. This template is used
Expand Down
16 changes: 16 additions & 0 deletions src/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,22 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_token


class SchedulerType(ExplicitEnum):
"""
Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`].
By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from [`Trainer`].
Scheduler types:
- "linear" = get_linear_schedule_with_warmup
- "cosine" = get_cosine_schedule_with_warmup
- "cosine_with_restarts" = get_cosine_with_hard_restarts_schedule_with_warmup
- "polynomial" = get_polynomial_decay_schedule_with_warmup
- "constant" = get_constant_schedule
- "constant_with_warmup" = get_constant_schedule_with_warmup
- "inverse_sqrt" = get_inverse_sqrt_schedule
- "reduce_lr_on_plateau" = get_reduce_on_plateau_schedule
- "cosine_with_min_lr" = get_cosine_with_min_lr_schedule_with_warmup
- "warmup_stable_decay" = get_wsd_schedule
"""

LINEAR = "linear"
COSINE = "cosine"
COSINE_WITH_RESTARTS = "cosine_with_restarts"
Expand Down
104 changes: 103 additions & 1 deletion src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,22 @@
import inspect
import json
import re
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints

from packaging import version

from .import_utils import is_jinja_available


if is_jinja_available():
import jinja2
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
else:
jinja2 = None


BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
Expand Down Expand Up @@ -314,3 +329,90 @@ def get_json_schema(func: Callable) -> Dict:
if return_dict is not None:
output["return"] = return_dict
return {"type": "function", "function": output}


def _render_with_assistant_indices(
compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
return rendered_chat, generation_indices


@lru_cache
def _compile_jinja_template(chat_template):
class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}

def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only initiated by jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None

def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv

def is_active(self) -> bool:
return self._rendered_blocks or self._generation_indices

@contextmanager
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices

yield
finally:
self._rendered_blocks = None
self._generation_indices = None

if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)

def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)

def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

def strftime_now(format):
return datetime.now().strftime(format)

jinja_env = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
)
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["strftime_now"] = strftime_now
return jinja_env.from_string(chat_template)
Loading

0 comments on commit 8d66904

Please sign in to comment.