Added structured generation support to MlxLLM using Outlines #1108
base: develop
Conversation
Awesome, @dameikle! Would you be able to forward some example code and write some tests for this too? Also, normally we work on top of `develop`.
Thanks for the reply @davidberenstein1957. Sorry, I should have noticed I'd branched off `main` instead of `develop` 🙈 Sure thing, I'll add a test. For the example, should I just add it to the docstring on the model?
@dameikle thanks for the quick response 🔥 Try to align the docstring with what we've got for other LLMs. W.r.t. the example code, it helps maintainers to quickly copy-paste it and test the integration :)
@dameikle force-pushed from 55bc71f to e24470c
@davidberenstein1957 Hopefully this rebase has worked and not left too much noise. I've added example code to the docstring as well as a test using the same model the other one uses. You should be able to do something like this with it:

```python
from distilabel.models.llms import MlxLLM
from pydantic import BaseModel


class User(BaseModel):
    name: str
    last_name: str
    email: str


llm = MlxLLM(
    path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    structured_output={"format": "json", "schema": User},
)
llm.load()
output = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Create a user profile for John Smith"}]]
)
print(output)
# [{'generations': ['{ "name": "John Smith", "last_name": "Smith", "email": "john.smith@email.com" }'], 'statistics': {'input_tokens': [7], 'output_tokens': [26]}}]
```
Looking great! Some minor comments :) The code snippet looks good. For some reason it works in 0.1.13 but not in 0.1.11; however, we can redirect users to a library upgrade if errors occur.
```diff
@@ -63,6 +64,47 @@ def test_generate(self, llm: MlxLLM) -> None:
         assert "input_tokens" in statistics
         assert "output_tokens" in statistics

+    def test_structured_generation_json(self, llm: MlxLLM) -> None:
```
We have structured generation tests in `tests/unit/steps/tasks/structured_outputs/test_outlines.py` — could you add/integrate this test there?
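For reference, a minimal sketch of what that relocated test might look like, reusing the schema and model from the example earlier in this thread (the assertions are illustrative, not the PR's actual test):

```python
from pydantic import BaseModel

from distilabel.models.llms import MlxLLM


def test_structured_generation_json() -> None:
    class User(BaseModel):
        name: str
        last_name: str
        email: str

    llm = MlxLLM(
        path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        structured_output={"format": "json", "schema": User},
    )
    llm.load()
    output = llm.generate_outputs(
        inputs=[[{"role": "user", "content": "Create a user profile for John Smith"}]]
    )
    # If constrained decoding worked, the generation parses against the schema.
    User.model_validate_json(output[0]["generations"][0])
```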
```python
        self.model = model
        self.tokenizer = tokenizer


class MlxLLM(LLM, MagpieChatTemplateMixin):
```
We should also be able to pass the structured output format during class init.
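For illustration, a hedged sketch of the field this suggests, modeled on how other distilabel LLMs declare it (the `OutlinesStructuredOutputType` import path and description text are assumptions):

```python
from typing import Optional

from pydantic import Field

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType  # assumed path


# LLM and MagpieChatTemplateMixin as in the module under review.
class MlxLLM(LLM, MagpieChatTemplateMixin):
    path_or_hf_repo: str
    # Accept the structured output config at init time, like other LLMs.
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )
```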
```python
if TYPE_CHECKING:
    import mlx.nn as nn
    from mlx_lm.tokenizer_utils import TokenizerWrapper


class MlxModel:
```
I think we can import this from `outlines.models.mlxlm`.
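i.e. something along these lines (the `MLXLM` symbol name is an assumption for outlines 0.1.x; verify against the installed version):

```python
# Reuse outlines' bundled MLX wrapper instead of keeping a local copy.
from outlines.models.mlxlm import MLXLM  # name assumed; check your outlines version
```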
```diff
@@ -99,6 +140,7 @@ def load(self) -> None:
             model_config=self.mlx_model_config,
             adapter_path=self.adapter_path,
         )
+        self._wrapped_model = MlxModel(self._model, self._tokenizer)
```
I would create this during the load of the class.
```diff
@@ -101,6 +102,11 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
         "JSONLogitsProcessor",
         "RegexLogitsProcessor",
     ),
+    "mlx": (
```
Is mlx not implemented for outlines below 0.1?
```diff
@@ -37,10 +37,11 @@
     from llama_cpp import Llama  # noqa
     from transformers import Pipeline  # noqa
     from vllm import LLM as _vLLM  # noqa
+    from distilabel.models.llms.mlx import MlxModel  # noqa
```
I would import this class from `outlines.models.mlxlm` to avoid code duplication.
@davidberenstein1957 thanks for the review, and sorry it's taken a while to look at the comments. I'll work through them and get them resolved. It looks like there was a bug in outlines that stopped this from working correctly in 0.1.11; I'll see if I can do something smart to highlight this to users.
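One hedged option for that (not the PR's actual code; the 0.1.13 floor simply mirrors the version that worked earlier in this thread):

```python
import importlib.metadata

from packaging.version import Version


def _check_outlines_version() -> None:
    """Fail fast with an upgrade hint if the installed outlines predates MLX support."""
    installed = Version(importlib.metadata.version("outlines"))
    if installed < Version("0.1.13"):
        raise RuntimeError(
            f"Structured generation with MlxLLM requires outlines>=0.1.13 (found {installed}); "
            "upgrade with `pip install -U outlines`."
        )
```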
@dameikle I think we can keep it like this and remember it, since we can't do triple constraints. Something like this could be an option too.
Adds initial support for structured generation using Outlines to the MlxLLM.

I've tried to build on what was there, using wrappers to integrate with outlines.py rather than changing the inference approach in the current class, so it could probably be simpler.
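For context, a minimal sketch of that wrapper idea (matching the shape visible in the diffs above; purely illustrative):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import mlx.nn as nn
    from mlx_lm.tokenizer_utils import TokenizerWrapper


class MlxModel:
    """Thin adapter exposing the mlx-lm (model, tokenizer) pair to outlines,
    so MlxLLM's existing inference path stays unchanged."""

    def __init__(self, model: "nn.Module", tokenizer: "TokenizerWrapper") -> None:
        self.model = model
        self.tokenizer = tokenizer
```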