This repository has been archived by the owner on May 22, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathsection.py
365 lines (307 loc) · 12.8 KB
/
section.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
#####################################################################
# THIS FILE IS AUTOMATICALLY GENERATED BY UNSTRUCTURED API TOOLS.
# DO NOT MODIFY DIRECTLY
#####################################################################
import io
import os
import gzip
import mimetypes
from typing import List, Union
from fastapi import status, FastAPI, File, Form, Request, UploadFile, APIRouter, HTTPException
from fastapi.responses import PlainTextResponse
import json
from fastapi.responses import StreamingResponse
from starlette.datastructures import Headers
from starlette.types import Send
from base64 import b64encode
from typing import Optional, Mapping, Iterator, Tuple
import secrets
from prepline_sec_filings.sections import section_string_to_enum, validate_section_names, SECSection
from prepline_sec_filings.sec_document import SECDocument, REPORT_TYPES, VALID_FILING_TYPES
from enum import Enum
import re
import signal
from unstructured.staging.base import convert_to_isd
from prepline_sec_filings.sections import (
ALL_SECTIONS,
SECTIONS_10K,
SECTIONS_10Q,
SECTIONS_S1,
)
import csv
from typing import Dict
from unstructured.documents.elements import Text, NarrativeText, Title, ListItem
from unstructured.staging.label_studio import stage_for_label_studio
# FastAPI application and the router all endpoints are declared on; the
# router is attached to the app at the bottom of the file.
app = FastAPI()
router = APIRouter()
def is_expected_response_type(media_type, response_type):
    """Return True when *response_type* CONFLICTS with *media_type*.

    NOTE: despite the name, a True result signals a mismatch — callers raise
    HTTP 406 when this returns True. JSON responses must be dict/list and
    CSV responses must be str; any other media type is never a conflict.
    """
    compatible_types = {
        "application/json": (dict, list),
        "text/csv": (str,),
    }.get(media_type)
    return compatible_types is not None and response_type not in compatible_types
# pipeline-api
class timeout:
    """Context manager that raises TimeoutError if its body runs too long.

    Implemented with SIGALRM, which only works in the main thread; when
    signal registration fails with ValueError (e.g. in a worker thread) the
    manager silently degrades to a no-op rather than crashing the request.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message

    def handle_timeout(self, signum, frame):
        # Installed as the SIGALRM handler; fires when the alarm expires.
        raise TimeoutError(self.error_message)

    def __enter__(self):
        try:
            signal.signal(signal.SIGALRM, self.handle_timeout)
            signal.alarm(self.seconds)
        except ValueError:
            # Not in the main thread: proceed without a timeout.
            pass

    def __exit__(self, exc_type, exc_value, exc_traceback):
        try:
            # Cancel any pending alarm so it cannot fire after the block.
            signal.alarm(0)
        except ValueError:
            pass
def get_regex_enum(section_regex):
    """Wrap *section_regex* in a one-member Enum mimicking SECSection.

    A fresh Enum class is built on every call so each custom regex gets its
    own member; the ``pattern`` property mirrors the interface that
    SECDocument.get_section_narrative expects from SECSection members.
    """
    compiled = re.compile(section_regex)

    class CustomSECSection(Enum):
        CUSTOM = compiled

        @property
        def pattern(self):
            # The member's value is the compiled regex itself.
            return self.value

    return CustomSECSection.CUSTOM
def convert_to_isd_csv(results: dict) -> str:
    """
    Returns the representation of document elements as an Initial Structured Document (ISD)
    in CSV Format.
    """
    csv_fieldnames: List[str] = ["section", "element_type", "text"]
    # Flatten every section's ISD rows into (section, element_type, text).
    flattened: List[Dict[str, str]] = [
        {
            "section": section,
            "element_type": row["type"],
            "text": row["text"],
        }
        for section, section_narrative in results.items()
        for row in convert_to_isd(section_narrative)
    ]
    with io.StringIO() as buffer:
        csv_writer = csv.DictWriter(buffer, fieldnames=csv_fieldnames)
        csv_writer.writeheader()
        csv_writer.writerows(flattened)
        # getvalue() must run before the StringIO is closed by the with-block.
        return buffer.getvalue()
# List of valid response schemas
LABELSTUDIO = "labelstudio"  # elements staged for LabelStudio annotation (JSON only)
ISD = "isd"  # Initial Structured Document representation (JSON or CSV)
def pipeline_api(
    text,
    response_type="application/json",
    response_schema="isd",
    m_section=None,
    m_section_regex=None,
):
    """Extract narrative text for the requested sections of an SEC filing.

    Many supported sections including: RISK_FACTORS, MANAGEMENT_DISCUSSION, and many more

    Args:
        text: raw text of the SEC filing (10-K/10-Q report or S-1 family).
        response_type: "application/json" or "text/csv".
        response_schema: "isd" or "labelstudio" ("labelstudio" is JSON-only).
        m_section: section names to extract, or [ALL_SECTIONS] to expand to
            every section defined for the document's filing type.
        m_section_regex: regex strings, each extracted under key "REGEX_{i}".

    Returns:
        dict of section -> staged elements (JSON) or a CSV string.

    Raises:
        ValueError: unsupported filing type, response schema, or response type.
    """
    # FIX: None sentinels instead of mutable [] defaults (shared-default trap);
    # passing no sections behaves exactly as before.
    m_section = [] if m_section is None else m_section
    m_section_regex = [] if m_section_regex is None else m_section_regex
    validate_section_names(m_section)

    sec_document = SECDocument.from_string(text)
    if sec_document.filing_type not in VALID_FILING_TYPES:
        raise ValueError(
            f"SEC document filing type {sec_document.filing_type} is not supported, "
            f"must be one of {','.join(VALID_FILING_TYPES)}"
        )
    results = {}
    # [ALL_SECTIONS] expands to the full section list for the filing type.
    if m_section == [ALL_SECTIONS]:
        filing_type = sec_document.filing_type
        if filing_type in REPORT_TYPES:
            if filing_type.startswith("10-K"):
                m_section = [enum.name for enum in SECTIONS_10K]
            elif filing_type.startswith("10-Q"):
                m_section = [enum.name for enum in SECTIONS_10Q]
            else:
                raise ValueError(f"Invalid report type: {filing_type}")
        else:
            m_section = [enum.name for enum in SECTIONS_S1]
    for section in m_section:
        results[section] = sec_document.get_section_narrative(section_string_to_enum[section])
    for i, section_regex in enumerate(m_section_regex):
        regex_enum = get_regex_enum(section_regex)
        # A pathological user regex could backtrack forever; bound each search.
        with timeout(seconds=5):
            section_elements = sec_document.get_section_narrative(regex_enum)
            results[f"REGEX_{i}"] = section_elements
    if response_type == "application/json":
        if response_schema == LABELSTUDIO:
            return {
                section: stage_for_label_studio(section_narrative)
                for section, section_narrative in results.items()
            }
        elif response_schema == ISD:
            return {
                section: convert_to_isd(section_narrative)
                for section, section_narrative in results.items()
            }
        else:
            raise ValueError(
                f"output_schema '{response_schema}' is not supported for {response_type}"
            )
    elif response_type == "text/csv":
        if response_schema != ISD:
            raise ValueError(
                f"output_schema '{response_schema}' is not supported for {response_type}"
            )
        return convert_to_isd_csv(results)
    else:
        raise ValueError(f"response_type '{response_type}' is not supported")
def get_validated_mimetype(file):
    """Resolve (and optionally validate) the mimetype of an uploaded file.

    Prefers ``file.content_type``; when that is missing or the overly generic
    application/octet-stream, falls back to guessing from the filename. If the
    UNSTRUCTURED_ALLOWED_MIMETYPES env var is set (comma-separated list), any
    type outside that list is rejected with HTTP 400.
    """
    content_type = file.content_type
    if not content_type or content_type == "application/octet-stream":
        content_type = mimetypes.guess_type(str(file.filename))[0]

        # Some filetypes missing for this library, just hardcode them for now
        if not content_type:
            fallback_types = {".md": "text/markdown", ".msg": "message/rfc822"}
            for suffix, guessed_type in fallback_types.items():
                if file.filename.endswith(suffix):
                    content_type = guessed_type
                    break

    allowed_mimetypes_str = os.environ.get("UNSTRUCTURED_ALLOWED_MIMETYPES")
    if allowed_mimetypes_str is not None and content_type not in allowed_mimetypes_str.split(","):
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unable to process {file.filename}: "
                f"File type {content_type} is not supported."
            ),
        )

    return content_type
class MultipartMixedResponse(StreamingResponse):
    """StreamingResponse variant that emits a multipart/mixed body.

    Each chunk from the body iterator is base64-encoded and wrapped in its
    own MIME part, separated by a randomly generated boundary.
    """

    # Line terminator between part headers and bodies (MIME convention).
    CRLF = b"\r\n"

    def __init__(self, *args, content_type: Optional[str] = None, **kwargs):
        # NOTE: this content_type is the per-PART Content-Type header; the
        # response-level header is set to multipart/mixed in init_headers.
        super().__init__(*args, **kwargs)
        self.content_type = content_type

    def init_headers(self, headers: Optional[Mapping[str, str]] = None) -> None:
        super().init_headers(headers)
        # Random boundary token separating the MIME parts.
        self.boundary_value = secrets.token_hex(16)
        content_type = f'multipart/mixed; boundary="{self.boundary_value}"'
        self.raw_headers.append((b"content-type", content_type.encode("latin-1")))

    @property
    def boundary(self):
        # Boundary delimiter line as bytes ("--" prefix per RFC 2046).
        return b"--" + self.boundary_value.encode()

    def _build_part_headers(self, headers: dict) -> bytes:
        # Serialize part headers as "Name: value" lines, CRLF-terminated.
        header_bytes = b""
        for header, value in headers.items():
            header_bytes += f"{header}: {value}".encode() + self.CRLF
        return header_bytes

    def build_part(self, chunk: bytes) -> bytes:
        """Wrap one already-base64-encoded chunk in a boundary+headers envelope."""
        part = self.boundary + self.CRLF
        part_headers = {"Content-Length": len(chunk), "Content-Transfer-Encoding": "base64"}
        if self.content_type is not None:
            part_headers["Content-Type"] = self.content_type
        part += self._build_part_headers(part_headers)
        # Blank line separates part headers from the part body.
        part += self.CRLF + chunk + self.CRLF
        return part

    async def stream_response(self, send: Send) -> None:
        """Send the HTTP start message, then one MIME part per body chunk."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        async for chunk in self.body_iterator:
            if not isinstance(chunk, bytes):
                chunk = chunk.encode(self.charset)
            # Parts are base64-encoded, matching the declared
            # Content-Transfer-Encoding in build_part.
            chunk = b64encode(chunk)
            await send(
                {"type": "http.response.body", "body": self.build_part(chunk), "more_body": True}
            )

        # Terminate the response stream.
        await send({"type": "http.response.body", "body": b"", "more_body": False})
def ungz_file(file: UploadFile, gz_uncompressed_content_type=None) -> UploadFile:
    """Decompress a gzipped upload into a plain in-memory UploadFile.

    The ".gz" suffix is stripped from the filename; the content-type header
    is *gz_uncompressed_content_type* when given, otherwise guessed from the
    uncompressed filename.
    """

    def return_content_type(filename):
        if gz_uncompressed_content_type:
            return gz_uncompressed_content_type
        else:
            # guess_type may return None; str() keeps the header value a string.
            return str(mimetypes.guess_type(filename)[0])

    filename = str(file.filename) if file.filename else ""
    if filename.endswith(".gz"):
        filename = filename[:-3]

    # FIX: close the GzipFile wrapper deterministically instead of leaking it.
    with gzip.open(file.file) as gz:
        decompressed = gz.read()

    return UploadFile(
        file=io.BytesIO(decompressed),
        size=len(decompressed),
        filename=filename,
        headers=Headers({"content-type": return_content_type(filename)}),
    )
@router.post("/sec-filings/v0/section")
@router.post("/sec-filings/v0.2.1/section")
def pipeline_1(
    request: Request,
    gz_uncompressed_content_type: Optional[str] = Form(default=None),
    text_files: Union[List[UploadFile], None] = File(default=None),
    output_format: Union[str, None] = Form(default=None),
    output_schema: str = Form(default=None),
    section: List[str] = Form(default=[]),
    section_regex: List[str] = Form(default=[]),
):
    """Extract SEC filing sections from one or more uploaded text files.

    Content negotiation: the Accept header wins over the output_format form
    field; multiple files must be returned as JSON or multipart/mixed.

    Raises:
        HTTPException: 400 when no files are supplied; 406 on media-type
            conflicts or unsupported media types.
    """
    # Transparently decompress gzipped uploads before processing.
    if text_files:
        for file_index in range(len(text_files)):
            if text_files[file_index].content_type == "application/gzip":
                # FIX: forward the client-declared uncompressed content type,
                # which was previously accepted as a form field but ignored.
                text_files[file_index] = ungz_file(
                    text_files[file_index], gz_uncompressed_content_type
                )

    content_type = request.headers.get("Accept")

    # Accept header takes precedence over the output_format form field.
    default_response_type = output_format or "application/json"
    if not content_type or content_type == "*/*" or content_type == "multipart/mixed":
        media_type = default_response_type
    else:
        media_type = content_type

    default_response_schema = output_schema or "isd"

    if isinstance(text_files, list) and len(text_files):
        if len(text_files) > 1:
            # Multiple files cannot be rendered as a single CSV/other body.
            if content_type and content_type not in ["*/*", "multipart/mixed", "application/json"]:
                raise HTTPException(
                    detail=(
                        f"Conflict in media type {content_type}"
                        ' with response type "multipart/mixed".\n'
                    ),
                    status_code=status.HTTP_406_NOT_ACCEPTABLE,
                )

        def response_generator(is_multipart):
            # Process files lazily so multipart responses can stream.
            for file in text_files:
                get_validated_mimetype(file)

                text = file.file.read().decode("utf-8")

                response = pipeline_api(
                    text,
                    m_section=section,
                    m_section_regex=section_regex,
                    response_type=media_type,
                    response_schema=default_response_schema,
                )

                # True means the payload type conflicts with the media type.
                if is_expected_response_type(media_type, type(response)):
                    raise HTTPException(
                        detail=(
                            f"Conflict in media type {media_type}"
                            f" with response type {type(response)}.\n"
                        ),
                        status_code=status.HTTP_406_NOT_ACCEPTABLE,
                    )

                valid_response_types = ["application/json", "text/csv", "*/*", "multipart/mixed"]
                if media_type in valid_response_types:
                    if is_multipart:
                        # Multipart parts must be str/bytes; serialize dicts.
                        if type(response) not in [str, bytes]:
                            response = json.dumps(response)
                    yield response
                else:
                    raise HTTPException(
                        detail=f"Unsupported media type {media_type}.\n",
                        status_code=status.HTTP_406_NOT_ACCEPTABLE,
                    )

        if content_type == "multipart/mixed":
            return MultipartMixedResponse(
                response_generator(is_multipart=True), content_type=media_type
            )
        else:
            # Single file: unwrap to the bare payload; several: return the
            # generator for FastAPI to iterate.
            return (
                list(response_generator(is_multipart=False))[0]
                if len(text_files) == 1
                else response_generator(is_multipart=False)
            )
    else:
        raise HTTPException(
            detail='Request parameter "text_files" is required.\n',
            status_code=status.HTTP_400_BAD_REQUEST,
        )
app.include_router(router)