Compatible with lmdeploy #258

Merged: 5 commits, Oct 21, 2024
51 changes: 46 additions & 5 deletions lagent/llms/lmdeploy_wrapper.py
@@ -1,3 +1,5 @@
+import copy
+import logging
 from typing import List, Optional, Union
 
 from lagent.llms.base_llm import BaseModel
@@ -23,7 +25,12 @@ def __init__(self,
                  log_level: str = 'WARNING',
                  **kwargs):
         super().__init__(path=None, **kwargs)
-        from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
+        try:
+            from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
+        except Exception as e:
+            logging.error(f'{e}')
+            raise RuntimeError('DO NOT use turbomind.chatbot since it has '
+                               'been removed by lmdeploy since v0.5.2')
         self.state_map = {
             StatusCode.TRITON_STREAM_END: ModelStatusCode.END,
             StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR,
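For readers hitting this guard after upgrading, here is a minimal standalone sketch (a hypothetical helper, not part of the PR) that probes whether the legacy turbomind chatbot module still exists before constructing the client; it relies only on the same lmdeploy.serve.turbomind.chatbot import path used above.

def legacy_chatbot_available() -> bool:
    """Return True if the installed lmdeploy still ships the legacy turbomind chatbot."""
    try:
        # Removed from lmdeploy in v0.5.2 and later, per the error message above.
        from lmdeploy.serve.turbomind.chatbot import Chatbot  # noqa: F401
        return True
    except Exception:
        return False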
@@ -226,11 +233,32 @@ def __init__(self,
                  tp: int = 1,
                  pipeline_cfg=dict(),
                  **kwargs):
+
+        import lmdeploy
+        from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, pipeline, version_info
+
+        self.str_version = lmdeploy.__version__
+        self.version = version_info
+        self.do_sample = kwargs.pop('do_sample', None)
+        if self.do_sample is not None and self.version < (0, 6, 0):
+            raise RuntimeError(
+                '`do_sample` parameter is not supported by lmdeploy until '
+                f'v0.6.0, but currently using lmdeploy {self.str_version}')
         super().__init__(path=path, **kwargs)
-        from lmdeploy import pipeline
+        backend_config = copy.deepcopy(pipeline_cfg)
+        backend_config.update(tp=tp)
+        backend_config = {
+            k: v
+            for k, v in backend_config.items()
+            if hasattr(TurbomindEngineConfig, k)
+        }
+        backend_config = TurbomindEngineConfig(**backend_config)
+        chat_template_config = ChatTemplateConfig(
+            model_name=model_name) if model_name else None
         self.model = pipeline(
-            model_path=self.path, model_name=model_name, tp=tp, **pipeline_cfg)
+            model_path=self.path,
+            backend_config=backend_config,
+            chat_template_config=chat_template_config,
+            log_level='WARNING')
 
     def generate(self,
                  inputs: Union[str, List[str]],
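As a reading aid, the config filtering added above can be exercised on its own. The sketch below is illustrative only (the keys in pipeline_cfg are made-up examples); it shows that entries which are not TurbomindEngineConfig fields are silently dropped rather than forwarded to the engine.

from lmdeploy import TurbomindEngineConfig

pipeline_cfg = dict(session_len=8192, tp=2, some_unknown_key='ignored')
filtered = {
    k: v
    for k, v in pipeline_cfg.items() if hasattr(TurbomindEngineConfig, k)
}
backend_config = TurbomindEngineConfig(**filtered)
# 'some_unknown_key' is dropped; session_len and tp are forwarded to the engine.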
@@ -249,13 +277,26 @@ def generate(self,
             (a list of/batched) text/chat completion
         """
         from lmdeploy.messages import GenerationConfig
 
         batched = True
         if isinstance(inputs, str):
             inputs = [inputs]
             batched = False
         prompt = inputs
+        do_sample = kwargs.pop('do_sample', None)
         gen_params = self.update_gen_params(**kwargs)
+
+        if do_sample is None:
+            do_sample = self.do_sample
+        if do_sample is not None and self.version < (0, 6, 0):
+            raise RuntimeError(
+                '`do_sample` parameter is not supported by lmdeploy until '
+                f'v0.6.0, but currently using lmdeploy {self.str_version}')
+        if self.version >= (0, 6, 0):
+            if do_sample is None:
+                do_sample = gen_params['top_k'] > 1 or gen_params[
+                    'temperature'] > 0
+            gen_params.update(do_sample=do_sample)
         gen_config = GenerationConfig(
             skip_special_tokens=skip_special_tokens, **gen_params)
         response = self.model.batch_infer(
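The precedence of the new sampling flag (per-call argument, then the constructor default, then inference from top_k/temperature on lmdeploy >= 0.6.0) can be summarized in a small hypothetical helper. This is a sketch of the logic above, not code from the PR.

def resolve_do_sample(call_arg, ctor_default, gen_params, version):
    """Mirror the do_sample resolution added in generate()."""
    do_sample = call_arg if call_arg is not None else ctor_default
    if do_sample is not None and version < (0, 6, 0):
        raise RuntimeError('`do_sample` requires lmdeploy >= 0.6.0')
    if version >= (0, 6, 0):
        if do_sample is None:
            # Greedy unless top_k/temperature already imply sampling.
            do_sample = gen_params['top_k'] > 1 or gen_params['temperature'] > 0
        gen_params = dict(gen_params, do_sample=do_sample)
    return gen_params

# Example: resolve_do_sample(None, None, {'top_k': 1, 'temperature': 0}, (0, 6, 1))
# yields {'top_k': 1, 'temperature': 0, 'do_sample': False}, i.e. greedy decoding.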
2 changes: 1 addition & 1 deletion requirements/optional.txt
@@ -1,5 +1,5 @@
 google-search-results
-lmdeploy<=0.5.3
+lmdeploy>=0.2.5
 pillow
 python-pptx
 timeout_decorator
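With the upper pin removed, the runtime checks in the wrapper become the compatibility guard instead of the requirements file. A quick way to verify the installed lmdeploy before relying on do_sample, using the same version_info the wrapper imports:

from lmdeploy import version_info

if version_info < (0, 6, 0):
    print('installed lmdeploy predates 0.6.0; do_sample will be rejected')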