max_tokens to max_new_tokens #149

Merged
merged 5 commits on Feb 6, 2024
examples/internlm2_agent_cli_demo.py (3 changes: 2 additions & 1 deletion)
@@ -24,6 +24,7 @@ def main():
model = HFTransformer(
path=args.path,
meta_template=META,
+ max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
@@ -69,7 +70,7 @@ def input_prompt():
print('\nInternLm2:', end='')
current_length = 0
last_status = None
- for agent_return in chatbot.stream_chat(history, max_new_tokens=512):
+ for agent_return in chatbot.stream_chat(history):
status = agent_return.state
if status not in [
AgentStatusCode.STREAM_ING, AgentStatusCode.CODING,
examples/internlm2_agent_web_demo.py (1 change: 1 addition & 0 deletions)
@@ -120,6 +120,7 @@ def init_model(self, option, ip=None):
model_name='internlm2-chat-20b',
url=model_url,
meta_template=META,
+ max_new_tokens=1024,
top_p=0.8,
top_k=100,
temperature=0,
examples/model_cli_demo.py (4 changes: 2 additions & 2 deletions)
@@ -26,6 +26,7 @@ def main():
model = HFTransformer(
path=args.path,
meta_template=META,
+ max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
@@ -51,8 +52,7 @@ def input_prompt():
history = [dict(role='user', content=prompt)]
print('\nInternLm2:', end='')
current_length = 0
- for status, response, _ in model.stream_chat(
- history, max_new_tokens=512):
+ for status, response, _ in model.stream_chat(history):
print(response[current_length:], end='', flush=True)
current_length = len(response)
history.append(dict(role='assistant', content=response))
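Taken together, the three demo edits move the generation budget from each `stream_chat` call into the model constructor. A minimal usage sketch of the new convention follows; the import path, model path and prompt are illustrative assumptions, not part of this PR.

```python
from lagent.llms import HFTransformer  # assumed export, matching the demos

# The budget is configured once, at construction time (as in the demos above).
model = HFTransformer(
    path='internlm/internlm2-chat-7b',  # hypothetical model path
    max_new_tokens=1024,
    top_p=0.8,
    top_k=None,
    temperature=0.1,
)

history = [dict(role='user', content='Hello!')]
current_length = 0
# The demos no longer pass a per-call budget; the constructor value is used.
for status, response, _ in model.stream_chat(history):
    print(response[current_length:], end='', flush=True)
    current_length = len(response)
```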
lagent/llms/base_api.py (4 changes: 2 additions & 2 deletions)
@@ -152,7 +152,7 @@ def __init__(self,
template_parser: 'APITemplateParser' = APITemplateParser,
meta_template: Optional[Dict] = None,
*,
- max_tokens: int = 512,
+ max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
@@ -169,7 +169,7 @@
if isinstance(stop_words, str):
stop_words = [stop_words]
self.gen_params = dict(
- max_tokens=max_tokens,
+ max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
lagent/llms/base_llm.py (8 changes: 4 additions & 4 deletions)
@@ -99,8 +99,8 @@ class BaseModel:

Args:
path (str): The path to the model.
- max_seq_len (int): The maximum sequence length of the model. Defaults
- to 2048.
+ max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults
+ to 512.
tokenizer_only (bool): If True, only the tokenizer will be initialized.
Defaults to False.
meta_template (list of dict, optional): The model's meta prompt
@@ -116,7 +116,7 @@ def __init__(self,
template_parser: 'LMTemplateParser' = LMTemplateParser,
meta_template: Optional[List[Dict]] = None,
*,
- max_tokens: int = 512,
+ max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
@@ -133,7 +133,7 @@
if isinstance(stop_words, str):
stop_words = [stop_words]
self.gen_params = dict(
- max_tokens=max_tokens,
+ max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
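In both base classes the budget now lives in `gen_params` under the key `max_new_tokens`. A small sketch of the resulting state, assuming `BaseModel` can be instantiated directly for inspection and that `update_gen_params` merges keyword overrides into a copy of `gen_params`, as its uses in `lmdepoly_wrapper.py` below suggest:

```python
from lagent.llms.base_llm import BaseModel

# Any concrete subclass would do; the base class is enough to inspect defaults.
model = BaseModel(path='dummy')
print(model.gen_params['max_new_tokens'])  # expected: 512 (the new default keyword)

# Per-call overrides still flow through **kwargs and update_gen_params,
# which is assumed to return a merged copy without mutating gen_params.
params = model.update_gen_params(max_new_tokens=256)
print(params['max_new_tokens'])            # expected: 256
print(model.gen_params['max_new_tokens'])  # expected: still 512
```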
lagent/llms/huggingface.py (49 changes: 3 additions & 46 deletions)
@@ -1,6 +1,5 @@
import copy
import logging
- import warnings
from typing import Dict, List, Optional, Union

from lagent.schema import ModelStatusCode
@@ -19,8 +18,6 @@ class HFTransformer(BaseModel):

Args:
path (str): The name or path to HuggingFace's model.
- max_seq_len (int): The maximum length of the input sequence. Defaults
- to 2048.
tokenizer_path (str): The path to the tokenizer. Defaults to None.
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
Defaults to {}.
@@ -157,49 +154,9 @@ def stream_generate(
eos_token_id.extend(self.additional_eos_token_id)
eos_token_id_tensor = torch.tensor(eos_token_id).to(
input_ids.device) if eos_token_id is not None else None
- has_default_max_length = (
- kwargs.get('max_length') is None
- and generation_config.max_length is not None)
- if (has_default_max_length
- and generation_config.max_new_tokens is None):
- warnings.warn(
- "Using `max_length`'s default"
- f'({generation_config.max_length})'
- 'to control the generation length. '
- 'This behaviour is deprecated and will be removed'
- ' from the config in v5 of Transformers -- we'
- ' recommend using `max_new_tokens` to control the'
- ' maximum length of the generation.',
- UserWarning,
- )
- elif generation_config.max_new_tokens is not None:
- generation_config.max_length = (
- generation_config.max_new_tokens + input_ids_seq_length)
- if not has_default_max_length:
- logger.warn( # pylint: disable=W4902
- 'Both `max_new_tokens`'
- f'(={generation_config.max_new_tokens})'
- 'and `max_length`'
- f'(={generation_config.max_length})'
- ' seem to have been set.`max_new_tokens`'
- ' will take precedence. Please refer to'
- ' the documentation for more information. '
- '(https://huggingface.co/docs/transformers/main/en'
- '/main_classes/text_generation)',
- UserWarning,
- )
-
- if input_ids_seq_length >= generation_config.max_length:
- input_ids_string = 'input_ids'
- logger.warning(
- f'Input length of {input_ids_string}'
- f' is {input_ids_seq_length},'
- ' but `max_length` is set to'
- f' {generation_config.max_length}.'
- 'This can lead to unexpected behavior.'
- ' You should consider increasing `max_new_tokens`.')
-
- # 2. Set generation parameters if not already defined
+ generation_config.max_length = (
+ generation_config.max_new_tokens + input_ids_seq_length)
+ # Set generation parameters if not already defined
logits_processor = self.logits_processor
stopping_criteria = self.stopping_criteria

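With the deprecation and length-validation branches removed, `stream_generate` now derives `max_length` unconditionally from the configured budget plus the prompt length. A toy calculation, with made-up numbers, of what that implies:

```python
# Illustrative numbers only; nothing here comes from the PR itself.
input_ids_seq_length = 300   # tokens in the rendered prompt
max_new_tokens = 1024        # the budget carried in gen_params / GenerationConfig

# The surviving line of the hunk: total length = prompt tokens + new-token budget.
max_length = max_new_tokens + input_ids_seq_length
assert max_length == 1324    # generation stops after at most 1024 new tokens
```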
lagent/llms/lmdepoly_wrapper.py (30 changes: 14 additions & 16 deletions)
@@ -46,7 +46,6 @@ def generate(self,
inputs: Union[str, List[str]],
session_id: int = 2967,
request_id: str = '',
- max_tokens: int = 512,
sequence_start: bool = True,
sequence_end: bool = True,
**kwargs):
@@ -57,7 +56,6 @@
inputs (str, List[str]): user's prompt(s) in this round
session_id (int): the identical id of a session
request_id (str): the identical id of this round conversation
- max_tokens (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session

@@ -72,9 +70,12 @@
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'

+ self.chatbot.cfg = self._update_gen_params(**kwargs)
+ max_new_tokens = self.chatbot.cfg.max_new_tokens
+
logger = get_logger(log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
- f'max_out_len {max_tokens}')
+ f'max_out_len {max_new_tokens}')

if self.chatbot._session is None:
sequence_start = True
@@ -88,12 +89,9 @@
self.chatbot._session.request_id = request_id
self.chatbot._session.response = ''

- self.chatbot.cfg = self._update_gen_params(
- max_tokens=max_tokens, **kwargs)
-
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
- self.chatbot._session, prompt, max_tokens, sequence_start,
+ self.chatbot._session, prompt, max_new_tokens, sequence_start,
sequence_end):
status = self.state_map.get(status)
if status < ModelStatusCode.END:
@@ -111,7 +109,6 @@ def stream_chat(self,
inputs: List[dict],
session_id: int = 2967,
request_id: str = '',
- max_tokens: int = 512,
sequence_start: bool = True,
sequence_end: bool = True,
**kwargs):
@@ -122,7 +119,6 @@
session_id (int): the identical id of a session
inputs (List[dict]): user's inputs in this round conversation
request_id (str): the identical id of this round conversation
- max_tokens (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session

@@ -134,9 +130,12 @@
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'

+ self.chatbot.cfg = self._update_gen_params(**kwargs)
+ max_new_tokens = self.chatbot.cfg.max_new_tokens
+
logger = get_logger(log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
- f'max_out_len {max_tokens}')
+ f'max_out_len {max_new_tokens}')

if self.chatbot._session is None:
sequence_start = True
@@ -150,13 +149,10 @@
self.chatbot._session.request_id = request_id
self.chatbot._session.response = ''

- self.chatbot.cfg = self._update_gen_params(
- max_tokens=max_tokens, **kwargs)
prompt = self.template_parser(inputs)
-
status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
- self.chatbot._session, prompt, max_tokens, sequence_start,
+ self.chatbot._session, prompt, max_new_tokens, sequence_start,
sequence_end):
status = self.state_map.get(status)
# The stop symbol also appears in the output of the last STREAM_ING state.
@@ -246,9 +242,7 @@ def generate(self,
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
- max_tokens = gen_params.pop('max_tokens')
gen_config = GenerationConfig(**gen_params)
- gen_config.max_new_tokens = max_tokens
response = self.model.batch_infer(
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
response = [resp.text for resp in response]
@@ -336,6 +330,8 @@ def generate(self,
batched = False

gen_params = self.update_gen_params(**kwargs)
+ max_new_tokens = gen_params.pop('max_new_tokens')
+ gen_params.update(max_tokens=max_new_tokens)

resp = [''] * len(inputs)
for text in self.client.completions_v1(
@@ -383,6 +379,8 @@ def stream_chat(self,
generated token number
"""
gen_params = self.update_gen_params(**kwargs)
+ max_new_tokens = gen_params.pop('max_new_tokens')
+ gen_params.update(max_tokens=max_new_tokens)
prompt = self.template_parser(inputs)

resp = ''
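The client-based code paths keep lagent's unified `max_new_tokens` key but translate it back to the `max_tokens` keyword expected by the lmdeploy completion endpoint. The translation is plain dict manipulation, sketched below with placeholder values:

```python
# Placeholder gen_params; real values come from update_gen_params(**kwargs).
gen_params = {'max_new_tokens': 512, 'top_p': 0.8, 'temperature': 0.8}

# Same two lines as in generate()/stream_chat() above: rename the key for lmdeploy.
max_new_tokens = gen_params.pop('max_new_tokens')
gen_params.update(max_tokens=max_new_tokens)

assert 'max_new_tokens' not in gen_params
assert gen_params['max_tokens'] == 512
```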
lagent/llms/openai.py (4 changes: 3 additions & 1 deletion)
@@ -108,6 +108,8 @@ def chat(
assert isinstance(inputs, list)
if isinstance(inputs[0], dict):
inputs = [inputs]
+ if 'max_tokens' in gen_params:
+ raise NotImplementedError('unsupported parameter: max_tokens')
gen_params = {**self.gen_params, **gen_params}
with ThreadPoolExecutor(max_workers=20) as executor:
tasks = [
@@ -133,7 +135,7 @@ def _chat(self, messages: List[dict], **gen_params) -> str:

# Hold out 100 tokens due to potential errors in tiktoken calculation
max_tokens = min(
- gen_params.pop('max_tokens'),
+ gen_params.pop('max_new_tokens'),
self.context_window - len(self.tokenize(str(input))) - 100)
if max_tokens <= 0:
return ''
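For the OpenAI wrapper, passing `max_tokens` directly is now rejected, and `_chat` clamps the requested `max_new_tokens` to whatever still fits in the context window, minus a 100-token margin for tokenizer-count errors. A worked example with assumed numbers:

```python
# Assumed numbers for illustration; only the min(...) formula mirrors _chat above.
context_window = 4096   # e.g. a 4k model; the real value comes from self.context_window
prompt_tokens = 3900    # stand-in for len(self.tokenize(str(input)))
requested = 512         # gen_params.pop('max_new_tokens')

max_tokens = min(requested, context_window - prompt_tokens - 100)
print(max_tokens)       # 96 new tokens are actually requested from the API
# If the result were <= 0, _chat would return '' without calling the API.
```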