Skip to content

Commit ee82bb7

Browse files
authored
feat: add exponential retries to parallel mode calls (#216)
Replace our parallel mode retry logic with the `backoff` library. This gives us exponential backoff, retryable error codes, etc with just a decorator, which really cleans up the code. Changes: * Refactor `partition_file_via_api` and move the request with backoff to `call_api` * Add `backoff` as a dependency and `pip compile` * Make sure we don't dump api parameters on every parallel call * Don't allow internal calls to bypass the 503 low memory gate (Should be handle in the retries like everything else) To test this, try adding an HTTPException to the code. Add a non-retryable exception in `partition_pdf_splits`: ``` # If it's small enough, just process locally # (Some kwargs need to be renamed for local partition) if len(pdf_pages) <= pages_per_pdf: raise HTTPException(status_code=400) ``` When you run this and send a file, you'll get the 400 back immediately: ``` export UNSTRUCTURED_PARALLEL_MODE_ENABLED=true export UNSTRUCTURED_PARALLEL_MODE_URL=http://localhost:8000/general/v0/general export UNSTRUCTURED_PARALLEL_NUM_THREADS=1 make run-web-app curl -X POST 'http://localhost:8000/general/v0/general' --form files=@sample-docs/layout-parser-paper.pdf {"detail":"Bad Request"} ``` Now, return a 500 error instead and run again. In this case you'll get a server error, but in the logs you should see that the retries happened: ``` Giving up call_api(...) after 3 tries (fastapi.exceptions.HTTPException) ```
1 parent 91c0617 commit ee82bb7

File tree

7 files changed

+93
-100
lines changed

7 files changed

+93
-100
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
## 0.0.44-dev0
1+
## 0.0.44-dev1
22

33
* Bump unstructured to 0.10.14
4+
* Improve parallel mode retry handling
45

56
## 0.0.43
67

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,7 @@ As mentioned above, processing a pdf using `hi_res` is currently a slow operatio
274274
* `UNSTRUCTURED_PARALLEL_MODE_URL` - the location to send pdf page asynchronously, no default setting at the moment.
275275
* `UNSTRUCTURED_PARALLEL_MODE_THREADS` - the number of threads making requests at once, default is `3`.
276276
* `UNSTRUCTURED_PARALLEL_MODE_SPLIT_SIZE` - the number of pages to be processed in one request, default is `1`.
277-
* `UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS` - the number of retry attempts, default is `1`.
278-
* `UNSTRUCTURED_PARALLEL_RETRY_BACKOFF_TIME` - the backoff time in seconds for each retry attempt, default is `1.0`.
277+
* `UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS` - the number of retry attempts on a retryable error, default is `2`. (i.e. 3 attempts are made in total)
279278

280279
### Generating Python files from the pipeline notebooks
281280

prepline_general/api/general.py

Lines changed: 73 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from unstructured.staging.base import convert_to_isd, convert_to_dataframe, elements_from_json
2222
import psutil
2323
import requests
24-
import time
25-
from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
24+
import backoff
2625
import logging
26+
from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
2727

2828

2929
app = FastAPI()
@@ -39,7 +39,7 @@ def is_expected_response_type(media_type, response_type):
3939
return False
4040

4141

42-
# pipeline-api
42+
logger = logging.getLogger("unstructured_api")
4343

4444

4545
DEFAULT_MIMETYPES = (
@@ -92,6 +92,38 @@ def get_pdf_splits(pdf_pages, split_size=1):
9292
return split_pdfs
9393

9494

95+
# Do not retry with these status codes
96+
def is_non_retryable(e):
97+
return 400 <= e.status_code < 500
98+
99+
100+
@backoff.on_exception(
101+
backoff.expo,
102+
HTTPException,
103+
max_tries=int(os.environ.get("UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS", 2)) + 1,
104+
giveup=is_non_retryable,
105+
logger=logger,
106+
)
107+
def call_api(request_url, api_key, filename, file, content_type, **partition_kwargs):
108+
"""
109+
Call the api with the given request_url.
110+
"""
111+
headers = {"unstructured-api-key": api_key}
112+
113+
response = requests.post(
114+
request_url,
115+
files={"files": (filename, file, content_type)},
116+
data=partition_kwargs,
117+
headers=headers,
118+
)
119+
120+
if response.status_code != 200:
121+
detail = response.json().get("detail") or response.text
122+
raise HTTPException(status_code=response.status_code, detail=detail)
123+
124+
return response.text
125+
126+
95127
def partition_file_via_api(file_tuple, request, filename, content_type, **partition_kwargs):
96128
"""
97129
Send the given file to be partitioned remotely with retry logic,
@@ -103,40 +135,16 @@ def partition_file_via_api(file_tuple, request, filename, content_type, **partit
103135
filename and content_type are passed in the file form data
104136
partition_kwargs holds any form parameters to be sent on
105137
"""
106-
request_url = os.environ.get("UNSTRUCTURED_PARALLEL_MODE_URL")
138+
file, page_offset = file_tuple
107139

140+
request_url = os.environ.get("UNSTRUCTURED_PARALLEL_MODE_URL")
108141
if not request_url:
109142
raise HTTPException(status_code=500, detail="Parallel mode enabled but no url set!")
110143

111-
file, page_offset = file_tuple
112-
113-
headers = {"unstructured-api-key": request.headers.get("unstructured-api-key")}
144+
api_key = request.headers.get("unstructured-api-key")
114145

115-
# Retry parameters
116-
try_attempts = int(os.environ.get("UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS", 1)) + 1
117-
retry_backoff_time = float(os.environ.get("UNSTRUCTURED_PARALLEL_RETRY_BACKOFF_TIME", 1.0))
118-
119-
while try_attempts >= 0:
120-
response = requests.post(
121-
request_url,
122-
files={"files": (filename, file, content_type)},
123-
data=partition_kwargs,
124-
headers=headers,
125-
)
126-
try_attempts -= 1
127-
non_retryable_error_codes = [400, 401, 402, 403]
128-
status_code = response.status_code
129-
if status_code != 200:
130-
if try_attempts == 0 or status_code in non_retryable_error_codes:
131-
detail = response.json().get("detail") or response.text
132-
raise HTTPException(status_code=response.status_code, detail=detail)
133-
else:
134-
# Retry after backoff
135-
time.sleep(retry_backoff_time)
136-
else:
137-
break
138-
139-
elements = elements_from_json(text=response.text)
146+
result = call_api(request_url, api_key, filename, file, content_type, **partition_kwargs)
147+
elements = elements_from_json(text=result)
140148

141149
# We need to account for the original page numbers
142150
for element in elements:
@@ -196,9 +204,6 @@ def partition_pdf_splits(
196204
return results
197205

198206

199-
logger = logging.getLogger("unstructured_api")
200-
201-
202207
def pipeline_api(
203208
file,
204209
request=None,
@@ -215,47 +220,47 @@ def pipeline_api(
215220
m_strategy=[],
216221
m_xml_keep_tags=[],
217222
):
218-
logger.debug(
219-
"pipeline_api input params: {}".format(
220-
json.dumps(
221-
{
222-
"filename": filename,
223-
"file_content_type": file_content_type,
224-
"response_type": response_type,
225-
"m_coordinates": m_coordinates,
226-
"m_encoding": m_encoding,
227-
"m_hi_res_model_name": m_hi_res_model_name,
228-
"m_include_page_breaks": m_include_page_breaks,
229-
"m_ocr_languages": m_ocr_languages,
230-
"m_pdf_infer_table_structure": m_pdf_infer_table_structure,
231-
"m_skip_infer_table_types": m_skip_infer_table_types,
232-
"m_strategy": m_strategy,
233-
"m_xml_keep_tags": m_xml_keep_tags,
234-
},
235-
default=str,
223+
if filename.endswith(".msg"):
224+
# Note(yuming): convert file type for msg files
225+
# since fast api might sent the wrong one.
226+
file_content_type = "application/x-ole-storage"
227+
228+
# We don't want to keep logging the same params for every parallel call
229+
origin_ip = request.headers.get("X-Forwarded-For") or request.client.host
230+
is_internal_request = origin_ip.startswith("10.")
231+
232+
if not is_internal_request:
233+
logger.debug(
234+
"pipeline_api input params: {}".format(
235+
json.dumps(
236+
{
237+
"filename": filename,
238+
"response_type": response_type,
239+
"m_coordinates": m_coordinates,
240+
"m_encoding": m_encoding,
241+
"m_hi_res_model_name": m_hi_res_model_name,
242+
"m_include_page_breaks": m_include_page_breaks,
243+
"m_ocr_languages": m_ocr_languages,
244+
"m_pdf_infer_table_structure": m_pdf_infer_table_structure,
245+
"m_skip_infer_table_types": m_skip_infer_table_types,
246+
"m_strategy": m_strategy,
247+
"m_xml_keep_tags": m_xml_keep_tags,
248+
},
249+
default=str,
250+
)
236251
)
237252
)
238-
)
253+
254+
logger.debug(f"filetype: {file_content_type}")
239255

240256
# If this var is set, reject traffic when free memory is below minimum
241-
# Allow internal requests - these are parallel calls already in progress
242257
mem = psutil.virtual_memory()
243258
memory_free_minimum = int(os.environ.get("UNSTRUCTURED_MEMORY_FREE_MINIMUM_MB", 0))
244259

245260
if memory_free_minimum > 0 and mem.available <= memory_free_minimum * 1024 * 1024:
246-
# Note(yuming): Use X-Forwarded-For header to find the orginal IP for external API
247-
# requests,since LB forwards requests in AWS
248-
origin_ip = request.headers.get("X-Forwarded-For") or request.client.host
249-
250-
if not origin_ip.startswith("10."):
251-
raise HTTPException(
252-
status_code=503, detail="Server is under heavy load. Please try again later."
253-
)
254-
255-
if filename.endswith(".msg"):
256-
# Note(yuming): convert file type for msg files
257-
# since fast api might sent the wrong one.
258-
file_content_type = "application/x-ole-storage"
261+
raise HTTPException(
262+
status_code=503, detail="Server is under heavy load. Please try again later."
263+
)
259264

260265
if file_content_type == "application/pdf":
261266
try:

requirements/base.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ fastapi
99
uvicorn
1010
ratelimit
1111
requests
12+
backoff
1213
pypdf
1314
pycryptodome
1415
psutil

requirements/base.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ anyio==3.7.1
1010
# via
1111
# fastapi
1212
# starlette
13+
backoff==2.2.1
14+
# via -r requirements/base.in
1315
beautifulsoup4==4.12.2
1416
# via unstructured
1517
certifi==2023.7.22

requirements/test.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ babel==2.12.1
4141
# via jupyterlab-server
4242
backcall==0.2.0
4343
# via ipython
44+
backoff==2.2.1
45+
# via -r requirements/base.txt
4446
beautifulsoup4==4.12.2
4547
# via
4648
# -r requirements/base.txt

test_general/api/test_app.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import requests
66
import pandas as pd
77
from fastapi.testclient import TestClient
8+
from fastapi import HTTPException
89
from unittest.mock import Mock, ANY
910

1011
from prepline_general.api.app import app
@@ -384,23 +385,6 @@ def test_general_api_returns_503(monkeypatch, mocker):
384385

385386
assert response.status_code == 503
386387

387-
mock_client = mocker.patch("fastapi.Request.client")
388-
mock_client.host = "10.5.0.0"
389-
response = client.post(
390-
MAIN_API_ROUTE,
391-
files=[("files", (str(test_file), open(test_file, "rb")))],
392-
)
393-
394-
assert response.status_code == 200
395-
396-
mock_client.host = "10.4.0.0"
397-
response = client.post(
398-
MAIN_API_ROUTE,
399-
files=[("files", (str(test_file), open(test_file, "rb")))],
400-
)
401-
402-
assert response.status_code == 200
403-
404388

405389
class MockResponse:
406390
def __init__(self, status_code):
@@ -514,17 +498,14 @@ def test_partition_file_via_api_will_retry(monkeypatch, mocker):
514498
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_URL", "unused")
515499
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_THREADS", "1")
516500

517-
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS", "2")
518-
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_RETRY_BACKOFF_TIME", "0.1")
519-
520501
num_calls = 0
521502

522-
# Return a transient error the first time
503+
# Validate the retry count by returning an error the first 2 times
523504
def mock_response(*args, **kwargs):
524505
nonlocal num_calls
525506
num_calls += 1
526507

527-
if num_calls == 1:
508+
if num_calls <= 2:
528509
return MockResponse(status_code=500)
529510

530511
return MockResponse(status_code=200)
@@ -549,34 +530,36 @@ def mock_response(*args, **kwargs):
549530
assert response.status_code == 200
550531

551532

552-
def test_partition_file_via_api_no_retryable_error_code(monkeypatch, mocker):
533+
def test_partition_file_via_api_not_retryable_error_code(monkeypatch, mocker):
553534
"""
554535
Verify we didn't retry if the error code is not retryable
555536
"""
556537
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_ENABLED", "true")
557538
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_URL", "unused")
558539
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_THREADS", "1")
540+
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_MODE_RETRY_ATTEMPTS", "3")
559541

560-
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_RETRY_ATTEMPTS", "2")
561-
monkeypatch.setenv("UNSTRUCTURED_PARALLEL_RETRY_BACKOFF_TIME", "0.1")
542+
remote_partition = Mock(side_effect=HTTPException(status_code=401))
562543

563544
monkeypatch.setattr(
564545
requests,
565546
"post",
566-
lambda *args, **kwargs: MockResponse(status_code=401),
547+
remote_partition,
567548
)
568-
mock_sleep = mocker.patch("time.sleep")
569549
client = TestClient(app)
570550
test_file = Path("sample-docs") / "layout-parser-paper.pdf"
571551

572552
response = client.post(
573553
MAIN_API_ROUTE,
574554
files=[("files", (str(test_file), open(test_file, "rb"), "application/pdf"))],
575-
data={"pdf_processing_mode": "parallel"},
576555
)
577556

578557
assert response.status_code == 401
579-
assert mock_sleep.call_count == 0
558+
559+
# Often page 2 will start processing before the page 1 exception is raised.
560+
# So we can't assert called_once, but we can assert the count is less than it
561+
# would have been if we used all retries.
562+
assert remote_partition.call_count < 4
580563

581564

582565
def test_password_protected_pdf():

0 commit comments

Comments
 (0)