diff --git a/lollms/services/midjourney/lollms_midjourney.py b/lollms/services/midjourney/lollms_midjourney.py index 2bc3a6a..201dc36 100644 --- a/lollms/services/midjourney/lollms_midjourney.py +++ b/lollms/services/midjourney/lollms_midjourney.py @@ -180,7 +180,31 @@ def download_image(self, uri, folder_path): else: print(f"Failed to download image. Status code: {response.status_code}") return None - + def get_nearest_aspect_ratio(self, width: int, height: int) -> str: + # Define the available aspect ratios + aspect_ratios = { + "1:2": 0.5, + "2:3": 0.6667, + "3:4": 0.75, + "4:5": 0.8, + "1:1": 1, + "5:4": 1.25, + "4:3": 1.3333, + "3:2": 1.5, + "16:9": 1.7778, + "7:4": 1.75, + "2:1": 2 + } + + # Calculate the input aspect ratio + input_ratio = width / height + + # Find the nearest aspect ratio + nearest_ratio = min(aspect_ratios.items(), key=lambda x: abs(x[1] - input_ratio)) + + # Return the formatted string + return f"--ar {nearest_ratio[0]}" + def paint( self, positive_prompt, @@ -200,6 +224,7 @@ def paint( try: # Send prompt and get initial response + positive_prompt += self.get_nearest_aspect_ratio(width, height) initial_response = self.send_prompt_with_retry(positive_prompt, self.retries) message_id = initial_response.get("messageId") if not message_id: @@ -210,7 +235,7 @@ def paint( if "error" in progress_response: raise ValueError(progress_response["error"]) - if width<1024: + if width<=1024: file_name = self.download_image(progress_response["uri"], output_path) return file_name, {"prompt":positive_prompt, "negative_prompt":negative_prompt}