Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update GPT4ALL integration #4567

Merged
merged 4 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/modules/models/llms/integrations/gpt4all.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
}
],
"source": [
"%pip install pygpt4all > /dev/null"
"%pip install gpt4all > /dev/null"
]
},
{
Expand Down Expand Up @@ -64,7 +64,7 @@
"source": [
"### Specify Model\n",
"\n",
"To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/pygpt4all\n",
"To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/gpt4all\n",
"\n",
"For full installation instructions go [here](https://gpt4all.io/index.html).\n",
"\n",
Expand Down Expand Up @@ -102,7 +102,7 @@
"\n",
"# Path(local_path).parent.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# # Example model. Check https://github.com/nomic-ai/pygpt4all for the latest models.\n",
"# # Example model. Check https://github.com/nomic-ai/gpt4all for the latest models.\n",
"# url = 'http://gpt4all.io/models/ggml-gpt4all-l13b-snoozy.bin'\n",
"\n",
"# # send a GET request to the URL to download the file. Stream since it's large\n",
Expand All @@ -126,7 +126,8 @@
"callbacks = [StreamingStdOutCallbackHandler()]\n",
"# Verbose is required to pass to the callback manager\n",
"llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
"# If you want to use GPT4ALL_J model add the backend parameter\n",
"# If you want to use a custom model add the backend parameter\n",
"# Check https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/python for supported backends\n",
Chae4ek marked this conversation as resolved.
Show resolved Hide resolved
"llm = GPT4All(model=local_path, backend='gptj', callbacks=callbacks, verbose=True)"
]
},
Expand Down
103 changes: 37 additions & 66 deletions langchain/llms/gpt4all.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class GPT4All(LLM):
r"""Wrapper around GPT4All language models.

To use, you should have the ``pygpt4all`` python package installed, the
To use, you should have the ``gpt4all`` python package installed, the
pre-trained model file, and the model's config information.

Example:
Expand All @@ -28,7 +28,7 @@ class GPT4All(LLM):
model: str
"""Path to the pre-trained GPT4All model file."""

backend: str = Field("llama", alias="backend")
backend: Optional[str] = Field(None, alias="backend")

n_ctx: int = Field(512, alias="n_ctx")
"""Token context window."""
Expand Down Expand Up @@ -88,93 +88,66 @@ class GPT4All(LLM):
streaming: bool = False
"""Whether to stream the results or not."""

context_erase: float = 0.5
"""Leave (n_ctx * context_erase) tokens
starting from the beginning if the context has run out."""

client: Any = None #: :meta private:

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

def _llama_default_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
@staticmethod
def _model_param_names() -> Set[str]:
return {
"n_predict": self.n_predict,
"n_threads": self.n_threads,
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"top_k": self.top_k,
"top_p": self.top_p,
"temp": self.temp,
"n_ctx",
"n_predict",
"top_k",
"top_p",
"temp",
"n_batch",
"repeat_penalty",
"repeat_last_n",
"context_erase",
}

def _gptj_default_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
def _default_params(self) -> Dict[str, Any]:
return {
"n_ctx": self.n_ctx,
"n_predict": self.n_predict,
"n_threads": self.n_threads,
"top_k": self.top_k,
"top_p": self.top_p,
"temp": self.temp,
"n_batch": self.n_batch,
"repeat_penalty": self.repeat_penalty,
"repeat_last_n": self.repeat_last_n,
"context_erase": self.context_erase,
}

@staticmethod
def _llama_param_names() -> Set[str]:
"""Get the identifying parameters."""
return {
"seed",
"n_ctx",
"n_parts",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"embedding",
}

@staticmethod
def _gptj_param_names() -> Set[str]:
"""Get the identifying parameters."""
return set()

@staticmethod
def _model_param_names(backend: str) -> Set[str]:
if backend == "llama":
return GPT4All._llama_param_names()
else:
return GPT4All._gptj_param_names()

def _default_params(self) -> Dict[str, Any]:
if self.backend == "llama":
return self._llama_default_params()
else:
return self._gptj_default_params()

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in the environment."""
try:
backend = values["backend"]
if backend == "llama":
from pygpt4all import GPT4All as GPT4AllModel
elif backend == "gptj":
from pygpt4all import GPT4All_J as GPT4AllModel
else:
raise ValueError(f"Incorrect gpt4all backend {cls.backend}")

model_kwargs = {
k: v
for k, v in values.items()
if k in GPT4All._model_param_names(backend)
}
from gpt4all import GPT4All as GPT4AllModel

full_path = values["model"]
model_path, delimiter, model_name = full_path.rpartition("/")
model_path += delimiter

values["client"] = GPT4AllModel(
model_path=values["model"],
**model_kwargs,
model_name=model_name,
model_path=model_path or None,
model_type=values["backend"],
allow_download=False,
)
Chae4ek marked this conversation as resolved.
Show resolved Hide resolved
values["backend"] = values["client"].model.model_type

except ImportError:
raise ValueError(
"Could not import pygpt4all python package. "
"Please install it with `pip install pygpt4all`."
"Could not import gpt4all python package. "
"Please install it with `pip install gpt4all`."
)
return values

Expand All @@ -185,9 +158,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
"model": self.model,
**self._default_params(),
**{
k: v
for k, v in self.__dict__.items()
if k in self._model_param_names(self.backend)
k: v for k, v in self.__dict__.items() if k in self._model_param_names()
},
}

Expand Down