Skip to content

Commit

Permalink
[fix] Prevents long waiting times when connecting to incorrect or unr…
Browse files Browse the repository at this point in the history
…esponsive addresses (#3185)
  • Loading branch information
zhiyxu authored Jul 12, 2024
1 parent da429d4 commit dd246b3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

### Fixes:
- Fix SB3 callback metric tracking (mihran113)
- Prevent long waiting times when connecting to incorrect or unresponsive addresses (xuzhiy)

## 3.22.0 Jun 20, 2024

Expand Down
16 changes: 8 additions & 8 deletions aim/ext/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, remote_path: str):
def protocol_probe(self):
endpoint = f'http://{self.remote_path}/status/'
try:
response = requests.get(endpoint, headers=self.request_headers)
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
if response.status_code == 200:
if response.url.startswith('https://'):
self._http_protocol = 'https://'
Expand All @@ -76,7 +76,7 @@ def protocol_probe(self):

endpoint = f'https://{self.remote_path}/status/'
try:
response = requests.get(endpoint, headers=self.request_headers)
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
if response.status_code == 200:
self._http_protocol = 'https://'
self._ws_protocol = 'wss://'
Expand Down Expand Up @@ -132,7 +132,7 @@ def _check_remote_version_compatibility(self):

def client_heartbeat(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/heartbeat/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers)
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -145,7 +145,7 @@ def client_heartbeat(self):
)
def connect(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/connect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers)
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -154,7 +154,7 @@ def connect(self):

def reconnect(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/reconnect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers)
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -170,7 +170,7 @@ def disconnect(self):
self._ws.close()

endpoint = f'{self._http_protocol}{self._client_endpoint}/disconnect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers)
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -181,7 +181,7 @@ def get_version(
self,
):
endpoint = f'{self._http_protocol}{self._client_endpoint}/get-version/'
response = requests.get(endpoint, headers=self.request_headers)
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response_json = response.json()
if response.status_code == 404:
return '<3.19.0'
Expand Down Expand Up @@ -215,7 +215,7 @@ def release_resource(self, queue_id, resource_handler):
if queue_id != -1:
self.get_queue().wait_for_finish()

response = requests.get(endpoint, headers=self.request_headers)
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response_json = response.json()
if response.status_code == 400:
raise_exception(response_json.get('exception'))
Expand Down

0 comments on commit dd246b3

Please sign in to comment.