diff --git a/src/marqo/_httprequests.py b/src/marqo/_httprequests.py index f04a4ff2..7ed89d48 100644 --- a/src/marqo/_httprequests.py +++ b/src/marqo/_httprequests.py @@ -21,6 +21,7 @@ 'put': session.put } + class HttpRequests: def __init__(self, config: Config) -> None: self.config = config @@ -32,12 +33,13 @@ def _operation(self, method: HTTP_OPERATIONS) -> Callable: return OPERATION_MAPPING[method] - def _construct_path(self, path: str) -> str: + def _construct_path(self, path: str, index_name="") -> str: """Augment the URL request path based if telemetry is required.""" + url = f"{self.config.get_url(index_name=index_name)}/{path}" if self.config.use_telemetry: delimeter= "?" if "?" not in f"{self.config.url}/{path}" else "&" - return f"{self.config.url}/{path}{delimeter}telemetry=True" - return f"{self.config.url}/{path}" + return url + f"{delimeter}telemetry=True" + return url def send_request( self, @@ -45,6 +47,7 @@ def send_request( path: str, body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str], str]] = None, content_type: Optional[str] = None, + index_name: str = "" ) -> Any: req_headers = copy.deepcopy(self.headers) @@ -56,7 +59,7 @@ def send_request( try: response = self._operation(http_operation)( - url=self._construct_path(path), + url=self._construct_path(path, index_name), timeout=self.config.timeout, headers=req_headers, data=body, @@ -68,40 +71,43 @@ def send_request( except requests.exceptions.ConnectionError as err: raise BackendCommunicationError(str(err)) from err - def get( self, path: str, body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str], str]] = None, + index_name: str = "" ) -> Any: content_type = None if body is not None: content_type = 'application/json' - return self.send_request('get', path=path, body=body, content_type=content_type) + return self.send_request('get', path=path, body=body, content_type=content_type,index_name=index_name) def post( self, path: str, body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str], str]] = None, content_type: Optional[str] = 'application/json', + index_name: str = "" ) -> Any: - return self.send_request('post', path, body, content_type) + return self.send_request('post', path, body, content_type, index_name=index_name) def put( self, path: str, body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str], str]] = None, content_type: Optional[str] = None, + index_name: str = "" ) -> Any: if body is not None: content_type = 'application/json' - return self.send_request('put', path, body, content_type) + return self.send_request('put', path, body, content_type, index_name=index_name) def delete( self, path: str, body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str]]] = None, + index_name: str = "" ) -> Any: - return self.send_request('delete', path, body) + return self.send_request('delete', path, body, index_name=index_name) @staticmethod def __to_json( diff --git a/src/marqo/client.py b/src/marqo/client.py index c8e2dc27..d16f6fe1 100644 --- a/src/marqo/client.py +++ b/src/marqo/client.py @@ -46,9 +46,12 @@ def create_index( sentences_per_chunk=2, sentence_overlap=0, image_preprocessing_method=None, - settings_dict=None + settings_dict=None, + inference_node_type=None, + storage_node_type=None, + inference_node_count=1, ) -> Dict[str, Any]: - """Create the index. + """Create the index. Please refer to the marqo cloud to see options for inference and storage node types. Args: index_name: name of the index. @@ -61,6 +64,9 @@ def create_index( settings_dict: if specified, overwrites all other setting parameters, and is passed directly as the index's index_settings + inference_node_type: + storage_node_type: + inference_node_count; Returns: Response body, containing information about index creation result """ @@ -70,7 +76,8 @@ def create_index( model=model, normalize_embeddings=normalize_embeddings, sentences_per_chunk=sentences_per_chunk, sentence_overlap=sentence_overlap, image_preprocessing_method=image_preprocessing_method, - settings_dict=settings_dict + settings_dict=settings_dict, inference_node_type=inference_node_type, storage_node_type=storage_node_type, + inference_node_count=inference_node_count ) def delete_index(self, index_name: str) -> Dict[str, Any]: @@ -101,7 +108,7 @@ def get_index(self, index_name: str) -> Index: """ ix = Index(self.config, index_name) # verify it exists: - self.http.get(path=f"indexes/{index_name}/stats") + self.http.get(path=f"indexes/{index_name}/stats", index_name=index_name) return ix def index(self, index_name: str) -> Index: @@ -170,19 +177,15 @@ def get_marqo(self): def health(self): return self.http.get(path="health") - def eject_model(self, model_name:str, model_device:str): return self.http.delete(path=f"models?model_name={model_name}&model_device={model_device}") - def get_loaded_models(self): return self.http.get(path="models") - def get_cuda_info(self): return self.http.get(path="device/cuda") - def get_cpu_info(self): return self.http.get(path="device/cpu") diff --git a/src/marqo/config.py b/src/marqo/config.py index 81c17ec1..16693b68 100644 --- a/src/marqo/config.py +++ b/src/marqo/config.py @@ -2,6 +2,7 @@ from marqo import enums, utils import urllib3 import warnings +from marqo.marqo_url_resolver import MarqoUrlResolver class Config: @@ -24,9 +25,11 @@ def __init__( """ self.cluster_is_remote = False self.cluster_is_s2search = False + self.cluster_is_marqo = False + self.marqo_url_resolver = None + self.api_key = api_key self.url = self.set_url(url) self.timeout = timeout - self.api_key = api_key # suppress warnings until we figure out the dependency issues: # warnings.filterwarnings("ignore") self.use_telemetry = use_telemetry @@ -43,5 +46,18 @@ def set_url(self, url): self.cluster_is_remote = True if "s2search.io" in lowered_url: self.cluster_is_s2search = True + if "api.marqo.ai" in lowered_url: + self.cluster_is_marqo = True + self.marqo_url_resolver = MarqoUrlResolver(api_key=self.api_key, expiration_time=15) self.url = url return self.url + + def get_url(self, index_name=None,): + """Get the URL, and infers whether that url is marqo cloud, + and if it is targeting a specific index resolves the index-specific url""" + if not self.cluster_is_marqo: + return self.url + if self.cluster_is_marqo and not index_name: + return self.url + "/api" + # calls resolver to get index-specific url for when cluster is marqo and index_name is not None + return self.marqo_url_resolver[index_name] diff --git a/src/marqo/defaults.py b/src/marqo/defaults.py index 640d6cd8..df180f89 100644 --- a/src/marqo/defaults.py +++ b/src/marqo/defaults.py @@ -18,6 +18,6 @@ def get_cloud_default_index_settings(): "patch_method": None } }, - "number_of_shards": 2, - "number_of_replicas": 1, + "number_of_shards": 1, + "number_of_replicas": 0, } diff --git a/src/marqo/errors.py b/src/marqo/errors.py index ee7e90ae..d9bfce1b 100644 --- a/src/marqo/errors.py +++ b/src/marqo/errors.py @@ -165,3 +165,27 @@ class BackendTimeoutError(InternalError): def __init__(self, message: str,) -> None: self.message = f"Timeout error communicating with Marqo: {message}" + +class MarqoCloudIndexNotReadyError(MarqoError): + """Error when Marqo index is not ready""" + code = "index_not_ready_cloud" + status_code = HTTPStatus.NOT_FOUND + + def __init__(self, index_name: str,) -> None: + self.message = f"The Python client could not resolve the endpoint for the index name {index_name}. " \ + f"This could be due to the index is still in the process of being created," \ + f" or that the client's cache has not yet been updated.\n" \ + f"- Please try again in a couple of minutes, or you can query the index status" \ + f" with mq.index({index_name}).get_status function to see when index is ready.\n" \ + f"- If the problem persists, please contact marqo support at support@marqo.ai" + + +class MarqoCloudIndexNotFoundError(MarqoError): + """Error when Marqo index is not ready""" + code = "index_not_found_cloud" + status_code = HTTPStatus.NOT_FOUND + + def __init__(self, index_name: str,) -> None: + self.message = f"The index name {index_name} does not exist in the Marqo cloud or client's cache" \ + f" has not yet been updated. Please check the index name and try again.\n" \ + f"- If the problem persists, please contact marqo support at support@marqo.ai" diff --git a/src/marqo/index.py b/src/marqo/index.py index 05409442..d7b8a562 100644 --- a/src/marqo/index.py +++ b/src/marqo/index.py @@ -2,6 +2,8 @@ import json import logging import pprint +import time + from marqo import defaults import typing from urllib import parse @@ -54,7 +56,10 @@ def create(config: Config, index_name: str, sentences_per_chunk=2, sentence_overlap=0, image_preprocessing_method=None, - settings_dict: dict = None + settings_dict: dict = None, + inference_node_type: str = None, + storage_node_type: str = None, + inference_node_count: int = 1, ) -> Dict[str, Any]: """Create the index. @@ -70,6 +75,9 @@ def create(config: Config, index_name: str, settings_dict: if specified, overwrites all other setting parameters, and is passed directly as the index's index_settings + inference_node_type: inference type for the index + storage_node_type: storage type for the index + inference_node_count: number of inference nodes for the index Returns: Response body, containing information about index creation result """ @@ -91,7 +99,19 @@ def create(config: Config, index_name: str, cl_text_preprocessing['split_length'] = sentences_per_chunk cl_img_preprocessing = cl_ix_defaults['image_preprocessing'] cl_img_preprocessing['patch_method'] = image_preprocessing_method - return req.post(f"indexes/{index_name}", body=cl_settings) + if not config.cluster_is_marqo: + return req.post(f"indexes/{index_name}", body=cl_settings) + cl_settings['inference_type'] = inference_node_type + cl_settings['storage_class'] = storage_node_type + cl_settings['inference_node_count'] = inference_node_count + response = req.post(f"indexes/{index_name}", body=cl_settings) + index = Index(config, index_name) + creation = index.get_status() + while creation['index_status'] != 'READY': + time.sleep(10) + creation = index.get_status() + mq_logger.info(f"Index creation status: {creation['index_status']}") + return response return req.post(f"indexes/{index_name}", body={ "index_defaults": { @@ -111,7 +131,11 @@ def create(config: Config, index_name: str, def refresh(self): """refreshes the index""" - return self.http.post(path=F"indexes/{self.index_name}/refresh") + return self.http.post(path=F"indexes/{self.index_name}/refresh", index_name=self.index_name,) + + def get_status(self): + """gets the status of the index""" + return self.http.get(path=F"indexes/{self.index_name}/status") def search(self, q: Union[str, dict], searchable_attributes: Optional[List[str]] = None, limit: int = 10, offset: int = 0, search_method: Union[SearchMethods.TENSOR, str] = SearchMethods.TENSOR, @@ -184,7 +208,8 @@ def search(self, q: Union[str, dict], searchable_attributes: Optional[List[str]] body["modelAuth"] = model_auth res = self.http.post( path=path_with_query_str, - body=body + body=body, + index_name=self.index_name, ) num_results = len(res["hits"]) @@ -214,7 +239,7 @@ def get_document(self, document_id: str, expose_facets=None) -> Dict[str, Any]: url_string = f"indexes/{self.index_name}/documents/{document_id}" if expose_facets is not None: url_string += f"?expose_facets={expose_facets}" - return self.http.get(url_string) + return self.http.get(url_string, index_name=self.index_name,) def get_documents(self, document_ids: List[str], expose_facets=None) -> Dict[str, Any]: """Gets a selection of documents based on their IDs. @@ -233,7 +258,8 @@ def get_documents(self, document_ids: List[str], expose_facets=None) -> Dict[str url_string += f"?expose_facets={expose_facets}" return self.http.get( url_string, - body=document_ids + body=document_ids, + index_name=self.index_name, ) def add_documents( @@ -340,7 +366,9 @@ def _add_docs_organiser( # ADD DOCS TIMER-LOGGER (2) start_time_client_request = timer() - res = self.http.post(path=path_with_query_str, body=documents) + res = self.http.post( + path=path_with_query_str, body=documents, index_name=self.index_name, + ) end_time_client_request = timer() total_client_request_time = end_time_client_request - start_time_client_request @@ -374,13 +402,11 @@ def delete_documents(self, ids: List[str], auto_refresh: bool = None) -> Dict[st base_path = f"indexes/{self.index_name}/documents/delete-batch" path_with_refresh = base_path if auto_refresh is None else base_path + f"?refresh={str(auto_refresh).lower()}" - return self.http.post( - path=path_with_refresh, body=ids - ) + return self.http.post(path=path_with_refresh, body=ids, index_name=self.index_name,) def get_stats(self) -> Dict[str, Any]: """Get stats about the index""" - return self.http.get(path=f"indexes/{self.index_name}/stats") + return self.http.get(path=f"indexes/{self.index_name}/stats", index_name=self.index_name,) @staticmethod def _maybe_datetime(the_date: Optional[Union[datetime, str]]) -> Optional[datetime]: @@ -434,7 +460,7 @@ def verbosely_add_docs(i, docs): errors_detected = False t0 = timer() - res = self.http.post(path=path_with_query_str, body=docs) + res = self.http.post(path=path_with_query_str, body=docs, index_name=self.index_name,) total_batch_time = timer() - t0 num_docs = len(docs) @@ -491,4 +517,4 @@ def verbosely_add_docs(i, docs): def get_settings(self) -> dict: """Get all settings of the index""" - return self.http.get(path=f"indexes/{self.index_name}/settings") + return self.http.get(path=f"indexes/{self.index_name}/settings", index_name=self.index_name,) diff --git a/src/marqo/marqo_url_resolver.py b/src/marqo/marqo_url_resolver.py new file mode 100644 index 00000000..07d6627d --- /dev/null +++ b/src/marqo/marqo_url_resolver.py @@ -0,0 +1,42 @@ +import logging +import time +import requests + +from marqo.errors import MarqoCloudIndexNotFoundError, MarqoCloudIndexNotReadyError + + +class MarqoUrlResolver: + def __init__(self, api_key=None, expiration_time: int = 15): + """ URL Resolver is a cache for urls that are resolved to their respective indices only for marqo cloud. """ + self.timestamp = time.time() - expiration_time + self._urls_mapping = {"READY": {}, "CREATING": {}} + self.api_key = api_key + self.expiration_time = expiration_time + + def refresh_urls_if_needed(self, index_name): + if index_name not in self._urls_mapping['READY'] and time.time() - self.timestamp > self.expiration_time: + # fast refresh to catch if index was created + self._refresh_urls() + if index_name in self._urls_mapping['READY'] and time.time() - self.timestamp > 360: + # slow refresh in case index was deleted + self._refresh_urls(timeout=3) + + def __getitem__(self, item): + self.refresh_urls_if_needed(item) + if item in self._urls_mapping['READY']: + return self._urls_mapping['READY'][item] + if item in self._urls_mapping['CREATING']: + raise MarqoCloudIndexNotReadyError(item) + raise MarqoCloudIndexNotFoundError(item) + + def _refresh_urls(self, timeout=None): + response = requests.get('https://api.marqo.ai/api/indexes', + headers={"x-api-key": self.api_key}, timeout=timeout) + if not response.ok: + logging.warning(response.text) + response_json = response.json() + for index in response_json['indices']: + if index.get('index_status') in ["READY", "CREATING"]: + self._urls_mapping[index['index_status']][index['index_name']] = index.get('load_balancer_dns_name') + if self._urls_mapping: + self.timestamp = time.time() diff --git a/tests/marqo_test.py b/tests/marqo_test.py index 7461aa12..2a2e3a57 100644 --- a/tests/marqo_test.py +++ b/tests/marqo_test.py @@ -48,7 +48,7 @@ def wrapper(self, *args, **kwargs): call_count = defaultdict(int) # Used to ensure expected_calls for each MockHTTPTraffic with mock.patch("marqo._httprequests.HttpRequests.send_request") as mock_send_request: - def side_effect(http_operation, path, body=None, content_type=None): + def side_effect(http_operation, path, body=None, content_type=None, index_name=""): if isinstance(body, str): body = json.loads(body) for i, config in enumerate(mock_config): diff --git a/tests/v0_tests/test_config.py b/tests/v0_tests/test_config.py index 7f5d7a9e..dd050c31 100644 --- a/tests/v0_tests/test_config.py +++ b/tests/v0_tests/test_config.py @@ -19,3 +19,25 @@ def test_url_is_s2search(self): def test_url_is_not_s2search(self): c = config.Config(url="https://som_random_cluster/abdcde:8882") assert not c.cluster_is_s2search + + def test_url_is_marqo(self): + c = config.Config(url="https://api.marqo.ai") + assert c.cluster_is_marqo + + def test_get_url_when_cluster_is_marqo_and_no_index_name_specified(self): + c = config.Config(url="https://api.marqo.ai") + assert c.get_url() == "https://api.marqo.ai/api" + + @mock.patch("requests.get") + def test_get_url_when_cluster_is_marqo_and_index_name_specified(self, mock_get): + mock_get.return_value.json.return_value = {"indices": [ + {"index_name": "index1", "load_balancer_dns_name": "example.com", "index_status": "READY"}, + {"index_name": "index2", "load_balancer_dns_name": "example2.com", "index_status": "READY"} + ]} + c = config.Config(url="https://api.marqo.ai") + print(c.marqo_url_resolver._urls_mapping) + assert c.get_url(index_name="index1") == "example.com" + + def test_get_url_when_cluster_is_not_marqo_and_index_name_specified(self): + c = config.Config(url="https://s2search.io/abdcde:8882") + assert c.get_url(index_name="index1") == "https://s2search.io/abdcde:8882" diff --git a/tests/v0_tests/test_index.py b/tests/v0_tests/test_index.py index 76e2e969..e00c7820 100644 --- a/tests/v0_tests/test_index.py +++ b/tests/v0_tests/test_index.py @@ -149,8 +149,8 @@ def run(): test_client.create_index(index_name=self.index_name_1) args, kwargs = mock__post.call_args # this is specific to cloud - assert kwargs['body']['number_of_shards'] == 2 - assert kwargs['body']['number_of_replicas'] == 1 + assert kwargs['body']['number_of_shards'] == 1 + assert kwargs['body']['number_of_replicas'] == 0 assert kwargs['body']['index_defaults']['treat_urls_and_pointers_as_images'] is False return True assert run() @@ -166,8 +166,8 @@ def run(): index_name=self.index_name_1, model='sentence-transformers/stsb-xlm-r-multilingual') args, kwargs = mock__post.call_args assert kwargs['body']['index_defaults']['model'] == 'sentence-transformers/stsb-xlm-r-multilingual' - assert kwargs['body']['number_of_shards'] == 2 - assert kwargs['body']['number_of_replicas'] == 1 + assert kwargs['body']['number_of_shards'] == 1 + assert kwargs['body']['number_of_replicas'] == 0 assert kwargs['body']['index_defaults']['treat_urls_and_pointers_as_images'] is False return True assert run() @@ -200,4 +200,98 @@ def test_create_custom_number_of_replicas(self): self.client.create_index(index_name=self.index_name_1, settings_dict = settings) index_setting = self.client.index(self.index_name_1).get_settings() print(index_setting) - assert intended_replicas == index_setting['number_of_replicas'] \ No newline at end of file + assert intended_replicas == index_setting['number_of_replicas'] + + @mock.patch("marqo._httprequests.HttpRequests.post", return_value={"acknowledged": True}) + @mock.patch("marqo._httprequests.HttpRequests.get", return_value={"index_status": "READY"}) + def test_create_marqo_cloud_index(self, mock_get, mock_post): + self.client.config.url = "https://api.marqo.ai" + self.client.config.api_key = 'some-super-secret-API-key' + self.client.config.cluster_is_marqo = True + + result = self.client.create_index( + index_name=self.index_name_1, inference_node_type="marqo.CPU", inference_node_count=1, + storage_node_type="marqo.basic" + ) + + mock_post.assert_called_with('indexes/my-test-index-1', body={ + 'index_defaults': { + 'treat_urls_and_pointers_as_images': False, 'model': None, 'normalize_embeddings': True, + 'text_preprocessing': {'split_length': 2, 'split_overlap': 0, 'split_method': 'sentence'}, + 'image_preprocessing': {'patch_method': None} + }, + 'number_of_shards': 1, 'number_of_replicas': 0, + 'inference_type': "marqo.CPU", 'storage_class': "marqo.basic", 'inference_node_count': 1}) + mock_get.assert_called_with(path="indexes/my-test-index-1/status") + assert result == {"acknowledged": True} + + @mock.patch("marqo._httprequests.HttpRequests.post", return_value={"error": "inference_type is required"}) + @mock.patch("marqo._httprequests.HttpRequests.get", return_value={"index_status": "READY"}) + def test_create_marqo_cloud_index_wrong_inference_settings(self, mock_get, mock_post): + self.client.config.url = "https://api.marqo.ai" + self.client.config.api_key = 'some-super-secret-API-key' + self.client.config.cluster_is_marqo = True + + result = self.client.create_index( + index_name=self.index_name_1, inference_node_type=None, inference_node_count=1, + storage_node_type="marqo.basic" + ) + + mock_post.assert_called_with('indexes/my-test-index-1', body={ + 'index_defaults': { + 'treat_urls_and_pointers_as_images': False, 'model': None, 'normalize_embeddings': True, + 'text_preprocessing': {'split_length': 2, 'split_overlap': 0, 'split_method': 'sentence'}, + 'image_preprocessing': {'patch_method': None} + }, + 'number_of_shards': 1, 'number_of_replicas': 0, + 'inference_type': None, 'storage_class': "marqo.basic", 'inference_node_count': 1}) + mock_get.assert_called_with(path="indexes/my-test-index-1/status") + assert result == {"error": "inference_type is required"} + + @mock.patch("marqo._httprequests.HttpRequests.post", return_value={"error": "storage_class is required"}) + @mock.patch("marqo._httprequests.HttpRequests.get", return_value={"index_status": "READY"}) + def test_create_marqo_cloud_index_wrong_storage_settings(self, mock_get, mock_post): + self.client.config.url = "https://api.marqo.ai" + self.client.config.api_key = 'some-super-secret-API-key' + self.client.config.cluster_is_marqo = True + + result = self.client.create_index( + index_name=self.index_name_1, inference_node_type="marqo.CPU", inference_node_count=1, + storage_node_type=None + ) + + mock_post.assert_called_with('indexes/my-test-index-1', body={ + 'index_defaults': { + 'treat_urls_and_pointers_as_images': False, 'model': None, 'normalize_embeddings': True, + 'text_preprocessing': {'split_length': 2, 'split_overlap': 0, 'split_method': 'sentence'}, + 'image_preprocessing': {'patch_method': None} + }, + 'number_of_shards': 1, 'number_of_replicas': 0, + 'inference_type': "marqo.CPU", 'storage_class': None, 'inference_node_count': 1}) + mock_get.assert_called_with(path="indexes/my-test-index-1/status") + assert result == {"error": "storage_class is required"} + + @mock.patch("marqo._httprequests.HttpRequests.post", + return_value={"error": "inference_node_count must be greater than 0"}) + @mock.patch("marqo._httprequests.HttpRequests.get", return_value={"index_status": "READY"}) + def test_create_marqo_cloud_index_wrong_inference_node_count(self, mock_get, mock_post): + self.client.config.url = "https://api.marqo.ai" + self.client.config.api_key = 'some-super-secret-API-key' + self.client.config.cluster_is_marqo = True + + result = self.client.create_index( + index_name=self.index_name_1, inference_node_type="marqo.CPU", inference_node_count=-1, + storage_node_type="marqo.basic" + ) + + mock_post.assert_called_with('indexes/my-test-index-1', body={ + 'index_defaults': { + 'treat_urls_and_pointers_as_images': False, 'model': None, 'normalize_embeddings': True, + 'text_preprocessing': {'split_length': 2, 'split_overlap': 0, 'split_method': 'sentence'}, + 'image_preprocessing': {'patch_method': None} + }, + 'number_of_shards': 1, 'number_of_replicas': 0, + 'inference_type': "marqo.CPU", 'storage_class': "marqo.basic", 'inference_node_count': -1}) + mock_get.assert_called_with(path="indexes/my-test-index-1/status") + assert result == {"error": "inference_node_count must be greater than 0"} + diff --git a/tests/v0_tests/test_marqo_url_resolver.py b/tests/v0_tests/test_marqo_url_resolver.py new file mode 100644 index 00000000..494b5666 --- /dev/null +++ b/tests/v0_tests/test_marqo_url_resolver.py @@ -0,0 +1,71 @@ +import time +from unittest.mock import patch + +from marqo.marqo_url_resolver import MarqoUrlResolver +from tests.marqo_test import MarqoTestCase + + +class TestMarqoUrlResolver(MarqoTestCase): + @patch("requests.get") + def test_refresh_urls_if_needed(self, mock_get): + mock_get.return_value.json.return_value = {"indices": [ + {"index_name": "index1", "load_balancer_dns_name": "example.com", "index_status": "READY"}, + {"index_name": "index2", "load_balancer_dns_name": "example2.com", "index_status": "READY"} + ]} + resolver = MarqoUrlResolver(api_key="your-api-key", expiration_time=60) + initial_timestamp = resolver.timestamp + + # Wait for more than the expiration time + time.sleep(0.1) + + resolver.refresh_urls_if_needed("index1") + + # Check that the timestamp has been updated + print(resolver.timestamp, initial_timestamp) + assert resolver.timestamp > initial_timestamp + + # Check that the URLs mapping has been refreshed + assert resolver._urls_mapping["READY"] == { + "index1": "example.com", + "index2": "example2.com", + } + + @patch("requests.get") + def test_refresh_urls_if_not_needed(self, mock_get): + mock_get.return_value.json.return_value = {"indices": [ + {"index_name": "index1", "load_balancer_dns_name": "example.com", "index_status": "READY"}, + {"index_name": "index2", "load_balancer_dns_name": "example2.com", "index_status": "READY"} + ]} + resolver = MarqoUrlResolver(api_key="your-api-key", expiration_time=60) + + # Call refresh_urls_if_needed without waiting + resolver.refresh_urls_if_needed("index1") + initial_timestamp = resolver.timestamp + time.sleep(0.1) + resolver.refresh_urls_if_needed("index2") + + # Check that the timestamp has not been updated + assert resolver.timestamp == initial_timestamp + + # Check that the URLs mapping has been initially populated + assert resolver._urls_mapping["READY"] == { + "index1": "example.com", + "index2": "example2.com", + } + + @patch("requests.get") + def test_refresh_includes_only_ready(self, mock_get): + mock_get.return_value.json.return_value = {"indices": [ + {"index_name": "index1", "load_balancer_dns_name": "example.com", "index_status": "READY"}, + {"index_name": "index2", "load_balancer_dns_name": "example2.com", "index_status": "NOT READY"} + ]} + resolver = MarqoUrlResolver(api_key="your-api-key", expiration_time=60) + + # Access the urls_mapping property + resolver.refresh_urls_if_needed("index1") + urls_mapping = resolver._urls_mapping + + # Check that the URLs mapping has been initially populated + assert urls_mapping["READY"] == { + "index1": "example.com", + }