
Added structured generation support to MlxLLM using Outlines #1108

Status: Open · wants to merge 4 commits into base: develop

Conversation

@dameikle (Contributor) commented Jan 22, 2025

Adds initial support for structured generation using Outlines to the MlxLLM.

I've tried to build on what was already there, using wrappers to integrate with outlines.py rather than changing the inference approach in the current class, so there may be room to simplify.
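
The wrapper amounts to roughly the following (a minimal sketch reconstructed from the review diff further down this thread, not the verbatim implementation; the nn.Module and TokenizerWrapper names come from that diff):

class MlxModel:
    """Thin adapter exposing the mlx-lm model/tokenizer pair in the
    shape that outlines' logits-processor machinery expects."""

    def __init__(self, model: "nn.Module", tokenizer: "TokenizerWrapper") -> None:
        self.model = model
        self.tokenizer = tokenizer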

@davidberenstein1957 (Member) commented Jan 22, 2025

Awesome, @dameikle! Would you be able to provide some example code and write some tests for this too? Also, we normally work on top of develop; would you be able to cherry-pick the commits and change the PR base to develop?

@davidberenstein1957 changed the base branch from main to develop on January 22, 2025 at 13:27
@davidberenstein1957 changed the base branch from develop to main on January 22, 2025 at 13:27
@dameikle (Contributor, Author) commented

Thanks for the reply @davidberenstein1957. Sorry, I should have noticed I'd branched off main instead of develop 🙈 Sure thing, I'll add a test. For the example, should I just add it to the docstring on the model?

@davidberenstein1957 (Member) commented

@dameikle thanks for the quick response 🔥 Try to align the docstring with what we've got for other LLMs. W.r.t. the example code, it helps maintainers to quickly copy-paste it and test the integration :)

@dameikle force-pushed the mlx_structured_generation branch from 55bc71f to e24470c on January 22, 2025 at 17:18
@dameikle (Contributor, Author) commented

@davidberenstein1957 Hopefully this rebase has worked and not left too much noise.

I've added example code to the docstring as well as a test using the same model as the existing one. You should be able to do something like this with it:

from distilabel.models.llms import MlxLLM
from pydantic import BaseModel

class User(BaseModel):
    name: str
    last_name: str
    email: str


llm = MlxLLM(
    path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    structured_output={"format": "json", "schema": User},
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for John Smith"}]])
print(output)
# [{'generations': ['{ "name": "John Smith", "last_name": "Smith", "email": "john.smith@email.com" }'], 'statistics': {'input_tokens': [7], 'output_tokens': [26]}}]

@dameikle changed the base branch from main to develop on January 22, 2025 at 18:11
@davidberenstein1957 (Member) left a review


Looking great! Some minor comments :) The code snippet looks good. For some reason it works in 0.1.13 but not in 0.1.11; however, we can point users to a library upgrade if errors occur.

@@ -63,6 +64,47 @@ def test_generate(self, llm: MlxLLM) -> None:
assert "input_tokens" in statistics
assert "output_tokens" in statistics

    def test_structured_generation_json(self, llm: MlxLLM) -> None:


We have structured generation tests in tests/unit/steps/tasks/structured_outputs/test_outlines.py; could you add/integrate this test there?
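
Something along these lines would fit there (illustrative only; the schema, model repo, and assertions mirror the usage example earlier in this thread, not the actual test body):

from pydantic import BaseModel

from distilabel.models.llms import MlxLLM


class User(BaseModel):
    name: str
    last_name: str
    email: str


def test_mlx_structured_generation_json() -> None:
    # Mirrors the usage example from this PR thread.
    llm = MlxLLM(
        path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        structured_output={"format": "json", "schema": User},
    )
    llm.load()
    outputs = llm.generate_outputs(
        inputs=[[{"role": "user", "content": "Create a user profile for John Smith"}]]
    )
    assert "generations" in outputs[0]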

        self.model = model
        self.tokenizer = tokenizer


class MlxLLM(LLM, MagpieChatTemplateMixin):


We should also be able to pass the structured output format during class init.
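
For illustration, a field declaration roughly like this (a sketch mirroring the usage example above, not the final API; the base classes are stubbed here for brevity):

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class MlxLLM(BaseModel):  # in distilabel this subclasses LLM and MagpieChatTemplateMixin
    path_or_hf_repo: str
    structured_output: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Structured output config, e.g. {'format': 'json', 'schema': User}.",
    )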


if TYPE_CHECKING:
    import mlx.nn as nn
    from mlx_lm.tokenizer_utils import TokenizerWrapper


class MlxModel:


I think we can import this from outlines.models.mlxlm.
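
i.e. loading through outlines directly, assuming an outlines >= 0.1 install where this entry point exists:

import outlines

# Loads the MLX model and wraps it in outlines' own adapter in one call,
# instead of maintaining a hand-rolled MlxModel wrapper here.
model = outlines.models.mlxlm("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")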

@@ -99,6 +140,7 @@ def load(self) -> None:
            model_config=self.mlx_model_config,
            adapter_path=self.adapter_path,
        )
        self._wrapped_model = MlxModel(self._model, self._tokenizer)


I would create this during the load of the class.

@@ -101,6 +102,11 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
"JSONLogitsProcessor",
"RegexLogitsProcessor",
),
"mlx": (


Is mlx not implemented for outlines below 0.1?
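
If it isn't, a guard along these lines could keep older installs working (a sketch; the 0.1.0 cutoff and the outlines.processors module path are assumptions based on this thread):

from importlib.metadata import version

from packaging.version import Version

# Only wire up the mlx processors when the installed outlines is new enough.
if Version(version("outlines")) >= Version("0.1.0"):
    from outlines.processors import JSONLogitsProcessor, RegexLogitsProcessor  # noqa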

@@ -37,10 +37,11 @@
from llama_cpp import Llama # noqa
from transformers import Pipeline # noqa
from vllm import LLM as _vLLM # noqa
from distilabel.models.llms.mlx import MlxModel # noqa


I would import this class from outlines.models.mlxlm to avoid code duplication.
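
i.e. something like this (the MLXLM class name is an assumption about outlines' API):

from outlines.models.mlxlm import MLXLM as MlxModel  # noqa  (class name assumed)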

@dameikle (Contributor, Author) commented

@davidberenstein1957 thanks for the review, and sorry it's taken a while to get to it. I'll work through the comments and get them resolved. It looks like there was a bug in outlines that stopped this from working correctly in 0.1.11; I'll see if I can do something smart to highlight this to users.
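
Maybe something like this at load time (hypothetical; the version numbers come from the comments above):

import warnings
from importlib.metadata import version

# Nudge users off the release known to break MLX structured generation.
if version("outlines") == "0.1.11":
    warnings.warn(
        "outlines 0.1.11 has a bug affecting MLX structured generation; "
        "please upgrade (0.1.13 is known to work).",
        UserWarning,
    )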

@davidberenstein1957 (Member) commented

@dameikle I think we can keep it like this and just remember it, because we can't express triple constraints in the dependency specifier.

Something like this could be an option too.

from packaging.version import Version  # plain string comparison misorders versions, e.g. "0.10" < "0.9"

def is_valid_version(version: str) -> bool:
    v = Version(version)
    return Version("1.2") < v < Version("2") or v > Version("2.1")
