From 75deccaf605c6a6a622330000373fd3b5af6d7d7 Mon Sep 17 00:00:00 2001 From: Raphael Glon Date: Thu, 26 Sep 2024 11:36:46 +0200 Subject: [PATCH] fix(diffusers): LoRA adapters, handle several base models Signed-off-by: Raphael Glon --- docker_images/diffusers/app/lora.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docker_images/diffusers/app/lora.py b/docker_images/diffusers/app/lora.py index 82308211..ba468f98 100644 --- a/docker_images/diffusers/app/lora.py +++ b/docker_images/diffusers/app/lora.py @@ -141,7 +141,8 @@ def _load_textual_embeddings(self, adapter, model_data): logger.info("Text embeddings loaded for adapter %s", adapter) else: logger.info( - "No text embeddings were loaded due to invalid embeddings or a mismatch of token sizes for adapter %s", + "No text embeddings were loaded due to invalid embeddings or a mismatch of token sizes " + "for adapter %s", adapter, ) self.current_tokens_loaded = tokens_to_add @@ -157,7 +158,8 @@ def _load_lora_adapter(self, kwargs): logger.error(msg) raise ValueError(msg) base_model = model_data.cardData["base_model"] - if self.model_id != base_model: + if (isinstance(base_model, list) and (self.model_id not in base_model)) or \ + (self.model_id != base_model): msg = f"Requested adapter {adapter:s} is not a LoRA adapter for base model {self.model_id:s}" logger.error(msg) raise ValueError(msg)