max_tokens to max_new_tokens (#149)
* Fix: max_new_tokens to max_tokens

* change `max_tokens` to `max_new_tokens` in API models

* max_tokens to max_new_tokens

* inject parameter 'max_new_tokens' for examples

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
liujiangning30 and wangzy authored Feb 6, 2024
1 parent 90ef521 commit 7b71988
Showing 8 changed files with 31 additions and 72 deletions.
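Taken together, these diffs move the generation-length budget out of each generate/stream_chat call and into the model's gen_params, under the single name max_new_tokens (API backends translate it back to whatever their service expects). A minimal sketch of the resulting call pattern, mirroring the example scripts below; the import and model path are assumptions, not part of this diff:

    from lagent.llms import HFTransformer  # import path assumed

    # The length budget is now fixed once, at construction time.
    model = HFTransformer(
        path='internlm/internlm2-chat-7b',  # placeholder model path
        max_new_tokens=1024,
        top_p=0.8,
        temperature=0.1,
    )

    # stream_chat is called without a per-request max_new_tokens argument.
    history = [dict(role='user', content='Hello')]
    current_length = 0
    for status, response, _ in model.stream_chat(history):
        print(response[current_length:], end='', flush=True)
        current_length = len(response)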
3 changes: 2 additions & 1 deletion examples/internlm2_agent_cli_demo.py
@@ -24,6 +24,7 @@ def main():
model = HFTransformer(
path=args.path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
@@ -69,7 +70,7 @@ def input_prompt():
print('\nInternLm2:', end='')
current_length = 0
last_status = None
for agent_return in chatbot.stream_chat(history, max_new_tokens=512):
for agent_return in chatbot.stream_chat(history):
status = agent_return.state
if status not in [
AgentStatusCode.STREAM_ING, AgentStatusCode.CODING,
1 change: 1 addition & 0 deletions examples/internlm2_agent_web_demo.py
@@ -120,6 +120,7 @@ def init_model(self, option, ip=None):
model_name='internlm2-chat-20b',
url=model_url,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=100,
temperature=0,
4 changes: 2 additions & 2 deletions examples/model_cli_demo.py
@@ -26,6 +26,7 @@ def main():
model = HFTransformer(
path=args.path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
@@ -51,8 +52,7 @@ def input_prompt():
history = [dict(role='user', content=prompt)]
print('\nInternLm2:', end='')
current_length = 0
for status, response, _ in model.stream_chat(
history, max_new_tokens=512):
for status, response, _ in model.stream_chat(history):
print(response[current_length:], end='', flush=True)
current_length = len(response)
history.append(dict(role='assistant', content=response))
4 changes: 2 additions & 2 deletions lagent/llms/base_api.py
@@ -152,7 +152,7 @@ def __init__(self,
template_parser: 'APITemplateParser' = APITemplateParser,
meta_template: Optional[Dict] = None,
*,
max_tokens: int = 512,
max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
@@ -169,7 +169,7 @@ def __init__(self,
if isinstance(stop_words, str):
stop_words = [stop_words]
self.gen_params = dict(
max_tokens=max_tokens,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
8 changes: 4 additions & 4 deletions lagent/llms/base_llm.py
@@ -99,8 +99,8 @@ class BaseModel:
Args:
path (str): The path to the model.
max_seq_len (int): The maximum sequence length of the model. Defaults
to 2048.
max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults
to 512.
tokenizer_only (bool): If True, only the tokenizer will be initialized.
Defaults to False.
meta_template (list of dict, optional): The model's meta prompt
@@ -116,7 +116,7 @@ def __init__(self,
template_parser: 'LMTemplateParser' = LMTemplateParser,
meta_template: Optional[List[Dict]] = None,
*,
max_tokens: int = 512,
max_new_tokens: int = 512,
top_p: float = 0.8,
top_k: float = None,
temperature: float = 0.8,
@@ -133,7 +133,7 @@ def __init__(self,
if isinstance(stop_words, str):
stop_words = [stop_words]
self.gen_params = dict(
max_tokens=max_tokens,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature,
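With both base classes patched, the default budget is stored in self.gen_params under max_new_tokens instead of max_tokens, so subclasses and callers that read or override gen_params must use the new key. A minimal sketch of a subclass consuming it (MyLLM and its generate body are illustrative, not part of the diff):

    from lagent.llms.base_llm import BaseModel

    class MyLLM(BaseModel):  # illustrative subclass
        def generate(self, prompt: str, **kwargs) -> str:
            params = {**self.gen_params, **kwargs}   # defaults merged with per-call overrides
            budget = params['max_new_tokens']        # formerly params['max_tokens']
            return f'would generate up to {budget} tokens for: {prompt!r}'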
49 changes: 3 additions & 46 deletions lagent/llms/huggingface.py
@@ -1,6 +1,5 @@
import copy
import logging
import warnings
from typing import Dict, List, Optional, Union

from lagent.schema import ModelStatusCode
@@ -19,8 +18,6 @@ class HFTransformer(BaseModel):
Args:
path (str): The name or path to HuggingFace's model.
max_seq_len (int): The maximum length of the input sequence. Defaults
to 2048.
tokenizer_path (str): The path to the tokenizer. Defaults to None.
tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
Defaults to {}.
@@ -157,49 +154,9 @@ def stream_generate(
eos_token_id.extend(self.additional_eos_token_id)
eos_token_id_tensor = torch.tensor(eos_token_id).to(
input_ids.device) if eos_token_id is not None else None
has_default_max_length = (
kwargs.get('max_length') is None
and generation_config.max_length is not None)
if (has_default_max_length
and generation_config.max_new_tokens is None):
warnings.warn(
"Using `max_length`'s default"
f'({generation_config.max_length})'
'to control the generation length. '
'This behaviour is deprecated and will be removed'
' from the config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the'
' maximum length of the generation.',
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length)
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
'Both `max_new_tokens`'
f'(={generation_config.max_new_tokens})'
'and `max_length`'
f'(={generation_config.max_length})'
' seem to have been set.`max_new_tokens`'
' will take precedence. Please refer to'
' the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/en'
'/main_classes/text_generation)',
UserWarning,
)

if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f'Input length of {input_ids_string}'
f' is {input_ids_seq_length},'
' but `max_length` is set to'
f' {generation_config.max_length}.'
'This can lead to unexpected behavior.'
' You should consider increasing `max_new_tokens`.')

# 2. Set generation parameters if not already defined
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length)
# Set generation parameters if not already defined
logits_processor = self.logits_processor
stopping_criteria = self.stopping_criteria
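With the transformers-style max_length bookkeeping and its deprecation warnings removed, stream_generate now derives the absolute cap directly from the configured budget every time. A worked line with illustrative numbers:

    # illustrative numbers, not taken from the diff
    input_ids_seq_length = 300                           # prompt length in tokens
    max_new_tokens = 1024                                # set on the model at construction time
    max_length = max_new_tokens + input_ids_seq_length   # 1324, the cap handed to generation
    print(max_length)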

30 changes: 14 additions & 16 deletions lagent/llms/lmdepoly_wrapper.py
@@ -46,7 +46,6 @@ def generate(self,
inputs: Union[str, List[str]],
session_id: int = 2967,
request_id: str = '',
max_tokens: int = 512,
sequence_start: bool = True,
sequence_end: bool = True,
**kwargs):
@@ -57,7 +56,6 @@ def generate(self,
inputs (str, List[str]): user's prompt(s) in this round
session_id (int): the identical id of a session
request_id (str): the identical id of this round conversation
max_tokens (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
@@ -72,9 +70,12 @@ def generate(self,
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'

self.chatbot.cfg = self._update_gen_params(**kwargs)
max_new_tokens = self.chatbot.cfg.max_new_tokens

logger = get_logger(log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_tokens}')
f'max_out_len {max_new_tokens}')

if self.chatbot._session is None:
sequence_start = True
@@ -88,12 +89,9 @@ def generate(self,
self.chatbot._session.request_id = request_id
self.chatbot._session.response = ''

self.chatbot.cfg = self._update_gen_params(
max_tokens=max_tokens, **kwargs)

status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_tokens, sequence_start,
self.chatbot._session, prompt, max_new_tokens, sequence_start,
sequence_end):
status = self.state_map.get(status)
if status < ModelStatusCode.END:
@@ -111,7 +109,6 @@ def stream_chat(self,
inputs: List[dict],
session_id: int = 2967,
request_id: str = '',
max_tokens: int = 512,
sequence_start: bool = True,
sequence_end: bool = True,
**kwargs):
@@ -122,7 +119,6 @@ def stream_chat(self,
session_id (int): the identical id of a session
inputs (List[dict]): user's inputs in this round conversation
request_id (str): the identical id of this round conversation
max_tokens (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
@@ -134,9 +130,12 @@ def stream_chat(self,
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'

self.chatbot.cfg = self._update_gen_params(**kwargs)
max_new_tokens = self.chatbot.cfg.max_new_tokens

logger = get_logger(log_level=self.chatbot.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'max_out_len {max_tokens}')
f'max_out_len {max_new_tokens}')

if self.chatbot._session is None:
sequence_start = True
@@ -150,13 +149,10 @@ def stream_chat(self,
self.chatbot._session.request_id = request_id
self.chatbot._session.response = ''

self.chatbot.cfg = self._update_gen_params(
max_tokens=max_tokens, **kwargs)
prompt = self.template_parser(inputs)

status, res, _ = None, '', 0
for status, res, _ in self.chatbot._stream_infer(
self.chatbot._session, prompt, max_tokens, sequence_start,
self.chatbot._session, prompt, max_new_tokens, sequence_start,
sequence_end):
status = self.state_map.get(status)
# The stop symbol also appears in the output of the last STREAM_ING state.
@@ -246,9 +242,7 @@ def generate(self,
batched = False
prompt = inputs
gen_params = self.update_gen_params(**kwargs)
max_tokens = gen_params.pop('max_tokens')
gen_config = GenerationConfig(**gen_params)
gen_config.max_new_tokens = max_tokens
response = self.model.batch_infer(
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
response = [resp.text for resp in response]
@@ -336,6 +330,8 @@ def generate(self,
batched = False

gen_params = self.update_gen_params(**kwargs)
max_new_tokens = gen_params.pop('max_new_tokens')
gen_params.update(max_tokens=max_new_tokens)

resp = [''] * len(inputs)
for text in self.client.completions_v1(
@@ -383,6 +379,8 @@ def stream_chat(self,
generated token number
"""
gen_params = self.update_gen_params(**kwargs)
max_new_tokens = gen_params.pop('max_new_tokens')
gen_params.update(max_tokens=max_new_tokens)
prompt = self.template_parser(inputs)

resp = ''
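Inside the lmdeploy wrappers the parameter now flows in one direction: callers and gen_params speak max_new_tokens, and only at the boundary to lmdeploy's completion API is it renamed back to max_tokens. A hypothetical caller-side override, assuming a client class constructed as in the web demo above and that **kwargs reach update_gen_params as in these hunks:

    from lagent.llms.lmdepoly_wrapper import LMDeployClient  # class/import assumed

    llm = LMDeployClient(
        model_name='internlm2-chat-20b',   # mirrors the web demo settings
        url='http://127.0.0.1:23333',      # placeholder server URL
        max_new_tokens=1024,               # default budget stored in gen_params
    )
    text = llm.generate('Summarize this diff.', max_new_tokens=256)  # per-call override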
4 changes: 3 additions & 1 deletion lagent/llms/openai.py
@@ -108,6 +108,8 @@ def chat(
assert isinstance(inputs, list)
if isinstance(inputs[0], dict):
inputs = [inputs]
if 'max_tokens' in gen_params:
raise NotImplementedError('unsupported parameter: max_tokens')
gen_params = {**self.gen_params, **gen_params}
with ThreadPoolExecutor(max_workers=20) as executor:
tasks = [
Expand All @@ -133,7 +135,7 @@ def _chat(self, messages: List[dict], **gen_params) -> str:

# Hold out 100 tokens due to potential errors in tiktoken calculation
max_tokens = min(
gen_params.pop('max_tokens'),
gen_params.pop('max_new_tokens'),
self.context_window - len(self.tokenize(str(input))) - 100)
if max_tokens <= 0:
return ''
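For the OpenAI-style backend, callers must now supply max_new_tokens; a stray max_tokens in gen_params raises NotImplementedError up in chat(). The requested budget is still clamped to what fits in the context window, holding back 100 tokens for tiktoken estimation error. A worked example of that clamp with illustrative numbers:

    # illustrative numbers, not taken from the diff
    context_window = 4096      # model context size
    prompt_tokens = 3900       # tokens consumed by the prompt
    max_new_tokens = 512       # requested budget

    max_tokens = min(max_new_tokens, context_window - prompt_tokens - 100)
    print(max_tokens)          # 96; if this were <= 0 the wrapper would return '' early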
