Skip to content

Commit

Permalink
heuristics to get title from section headers (#1033)
Browse files Browse the repository at this point in the history
* heuristics to get title from section headers

* updating get_text return

* rps unit test fix

* unit test fixes

* table element fixes

* updating function, adding comments and parameters

* removing redundant checks

* more fixes for font sizes

* changing variable names, function definition  and updating doc string

* typo handling

* updating docs

* updating unit tests
  • Loading branch information
Soeb-aryn authored Nov 26, 2024
1 parent 639d007 commit a9142a6
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_partition_with_ocr_instance(self):
s = ArynPDFPartitioner("Aryn/deformable-detr-DocLayNet")
ocr = Mock(spec=OcrModel)
dummy_text = "mocked ocr text"
ocr.get_text.return_value = dummy_text
ocr.get_text.return_value = (dummy_text, 7)
d = check_partition(
s, TEST_DIR / "resources/data/pdfs/visit_aryn.pdf", use_ocr=True, use_cache=False, ocr_model=ocr
)
Expand Down
46 changes: 46 additions & 0 deletions lib/sycamore/sycamore/tests/unit/utils/test_pdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
filter_elements_by_page,
select_pdf_pages,
select_pages,
promote_title,
)
from sycamore.tests.config import TEST_DIR

Expand Down Expand Up @@ -156,3 +157,48 @@ def test_select_pages():
assert new_doc.binary_representation is not None
assert len(new_doc.binary_representation) < len(doc.binary_representation)
assert all(e.properties["page_number"] in [1, 2, 4] for e in new_doc.elements)


def test_promote_title_with_title_element():
    """A page-1 ``Title`` already exists, so no element may be relabelled."""
    elements = [
        Element(type="Title", properties={"page_number": 1}),
        Element(type="Section-header", properties={"page_number": 2}),
        Element(type="Caption", properties={"page_number": 2}),
    ]
    # promote_title hands back the very list it was given, so comparing the
    # result with ``elements`` alone can never fail. Snapshot the labels
    # before the call and compare against that snapshot as well.
    expected_types = [e.type for e in elements]

    result = promote_title(elements)

    assert result == elements
    assert [e.type for e in result] == expected_types


def test_promote_title_with_section_header_as_title_candidate():
    """With no ``Title`` present, the page-1 section header becomes the title."""

    def build(first_label):
        # Input and expectation differ only in the label of the first element.
        return [
            Element(type=first_label, properties={"page_number": 1, "font_size": 12}),
            Element(type="Section-header", properties={"page_number": 2, "font_size": 14}),
            Element(type="Caption", properties={"page_number": 2, "font_size": 16}),
        ]

    source_elements = build("Section-header")
    expected_elements = build("Title")

    promoted = promote_title(source_elements)
    assert promoted == expected_elements


def test_promote_title_with_caption_as_title_candidate():
    """With no ``Title`` present, the page-1 caption becomes the title."""
    # (label, page number, font size) for the elements after the first one.
    trailing_rows = (("Section-header", 2, 14), ("Caption", 2, 16))

    def make(label, page, size):
        return Element(type=label, properties={"page_number": page, "font_size": size})

    candidates = [make("Caption", 1, 12)] + [make(*row) for row in trailing_rows]
    wanted = [make("Title", 1, 12)] + [make(*row) for row in trailing_rows]

    outcome = promote_title(candidates)

    assert outcome == wanted
16 changes: 13 additions & 3 deletions lib/sycamore/sycamore/transforms/detr_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,18 @@ def _supplement_text(inferred: list[Element], text: list[Element], threshold: fl
if matched:
matches = []
full_text = []
font_sizes = []
for m in matched:
matches.append(m)
if m.text_representation:
full_text.append(m.text_representation)

if font_size := m.properties.get("font_size"):
font_sizes.append(font_size)
if isinstance(i, TableElement):
i.tokens = [{"text": elem.text_representation, "bbox": elem.bbox} for elem in matches]

i.data["text_representation"] = " ".join(full_text)

i.properties["font_size"] = sum(font_sizes) / len(font_sizes) if font_sizes else None
return inferred + unmatched

def partition_pdf(
Expand All @@ -157,6 +159,7 @@ def partition_pdf(
output_format: Optional[str] = None,
text_extraction_options: dict[str, Any] = {},
source: str = "",
output_label_options: dict[str, Any] = {},
) -> list[Element]:
if use_partitioning_service:
assert aryn_api_key != ""
Expand Down Expand Up @@ -199,6 +202,13 @@ def partition_pdf(
for ele in r:
ele.properties[DocumentPropertyTypes.PAGE_NUMBER] = i + 1
page.append(ele)
if output_label_options.get("promote_title", False):
from sycamore.utils.pdf_utils import promote_title

if title_candidate_elements := output_label_options.get("title_candidate_elements"):
promote_title(page, title_candidate_elements)
else:
promote_title(page)
bbox_sort_page(page)
elements.extend(page)
if output_format == "markdown":
Expand Down Expand Up @@ -811,6 +821,6 @@ def extract_ocr(
tokens.append(token)
elem.tokens = tokens
else:
elem.text_representation = ocr_model_obj.get_text(cropped_image)
elem.text_representation, elem.properties["font_size"] = ocr_model_obj.get_text(cropped_image)

return elements
7 changes: 7 additions & 0 deletions lib/sycamore/sycamore/transforms/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,10 @@ class ArynPartitioner(Partitioner):
either pdfminer or OCR. Currently supports the 'object_type' property for pdfminer,
which can be set to 'boxes' or 'lines' to control the granularity of output.
source: The application that is using the partitioner. This is used for logging purposes.
output_label_options: A dictionary for configuring output label behavior. It supports two options:
promote_title, a boolean that specifies whether to add a title to partitioned elements if one is missing, and
title_candidate_elements, a list of strings representing labels for potential titles.
default: {"promote_title": True , "title_candidate_elements":["Section-header", "Caption"]}
Example:
The following shows an example of using the ArynPartitioner to partition a PDF and extract
both table structure and image
Expand Down Expand Up @@ -454,6 +458,7 @@ def __init__(
output_format: Optional[str] = None,
text_extraction_options: dict[str, Any] = {},
source: str = "",
output_label_options: dict[str, Any] = {},
):
if use_partitioning_service:
device = "cpu"
Expand Down Expand Up @@ -494,6 +499,7 @@ def __init__(
self._pages_per_call = pages_per_call
self._text_extraction_options = text_extraction_options
self._source = source
self.output_label_options = output_label_options

@timetrace("SycamorePdf")
def partition(self, document: Document) -> Document:
Expand Down Expand Up @@ -523,6 +529,7 @@ def partition(self, document: Document) -> Document:
output_format=self._output_format,
text_extraction_options=self._text_extraction_options,
source=self._source,
output_label_options=self.output_label_options,
)
except Exception as e:
path = document.properties["path"]
Expand Down
31 changes: 20 additions & 11 deletions lib/sycamore/sycamore/transforms/text_extraction/ocr_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
class OcrModel(TextExtractor):

@abstractmethod
def get_text(self, image: Image.Image) -> str:
def get_text(self, image: Image.Image) -> tuple[str, Optional[float]]:
pass

@abstractmethod
Expand Down Expand Up @@ -72,22 +72,24 @@ def __init__(self, lang_list=["en"], **kwargs):

self.reader = easyocr.Reader(lang_list=lang_list, **kwargs)

def get_text(self, image: Image.Image) -> str:
def get_text(self, image: Image.Image) -> tuple[str, Optional[float]]:
image_bytes = BytesIO()
image.save(image_bytes, format="BMP")
raw_results = self.reader.readtext(image_bytes.getvalue())
out_list = []
font_sizes = []
for res in raw_results:
text = res[1]
out_list.append(text)
font_sizes.append(res[0][2][1] - res[0][0][1])
val = " ".join(out_list)
return val
avg_font_size = sum(font_sizes) / len(font_sizes) if font_sizes else None
return val, avg_font_size

def get_boxes_and_text(self, image: Image.Image) -> list[dict[str, Any]]:
image_bytes = BytesIO()
image.save(image_bytes, format="BMP")
raw_results = self.reader.readtext(image_bytes.getvalue())

out: list[dict[str, Any]] = []
for res in raw_results:
raw_bbox = res[0]
Expand All @@ -109,9 +111,10 @@ def __init__(self):

self.pytesseract = pytesseract

def get_text(self, image: Image.Image) -> str:
def get_text(self, image: Image.Image) -> tuple[str, Optional[float]]:
val = self.pytesseract.image_to_string(image)
return val
# font size calculation is not supported for tesseract
return val, None

def get_boxes_and_text(self, image: Image.Image) -> list[dict[str, Any]]:
output_list = []
Expand Down Expand Up @@ -140,7 +143,8 @@ def __init__(self):
self.tesseract = Tesseract()
self.easy_ocr = EasyOcr()

def get_text(self, image: Image.Image) -> str:
def get_text(self, image: Image.Image) -> tuple[str, Optional[float]]:
# font size calculation is not supported for tesseract
return self.tesseract.get_text(image)

def get_boxes_and_text(self, image: Image.Image) -> list[dict[str, Any]]:
Expand All @@ -165,14 +169,19 @@ def __init__(self, language="en", slice_kwargs={}):
self.reader = PaddleOCR(lang=self.language, use_gpu=self.use_gpu)
self.slice_kwargs = slice_kwargs

def get_text(self, image: Image.Image) -> str:
def get_text(self, image: Image.Image) -> tuple[str, Optional[float]]:
bytearray = BytesIO()
image.save(bytearray, format="BMP")
result = self.reader.ocr(bytearray.getvalue(), rec=True, det=True, cls=False)
if result and result[0]:
text_values = [value[1][0] for value in result[0]]
return " ".join(text_values)
return ""
text_values = []
font_sizes = []
for value in result[0]:
text_values.append(value[1][0])
font_sizes.append(value[0][3][1] - value[0][0][1])
avg_font_size = sum(font_sizes) / len(font_sizes) if font_sizes else None
return " ".join(text_values), avg_font_size
return "", None

def set_slicing_parameters(self, image_width, image_height) -> dict[str, Any]:
slicing_params = {}
Expand Down
14 changes: 14 additions & 0 deletions lib/sycamore/sycamore/transforms/text_extraction/pdf_miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ def extract_document(self, filename: str, hash_key: str, use_cache=False, **kwar
pdf_miner_cache.set(hash_key, pages)
return pages

def _get_font_size(self, objs) -> float:
font_size_list = []

def traverse(objs):
for obj in objs:
if isinstance(obj, Iterable):
traverse(obj)
elif hasattr(obj, "fontname"):
font_size_list.append(obj.size)

traverse(objs)
return sum(font_size_list) / len(font_size_list)

@timetrace("PdfMinerPageEx")
def extract_page(self, page: Optional[Union["PDFPage", "Image"]]) -> list[Element]:
from pdfminer.pdfpage import PDFPage
Expand All @@ -103,6 +116,7 @@ def extract_page(self, page: Optional[Union["PDFPage", "Image"]]) -> list[Elemen
{
"bbox": BoundingBox(x1, y1, x2, y2),
"text": obj.get_text(),
"font_size": self._get_font_size(obj),
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def parse_output(self, output: list[dict[str, Any]], width, height) -> list[Elem
obj_bbox.y2 / height,
)
text.text_representation = obj_text
if "font_size" in obj:
text.properties["font_size"] = obj["font_size"]
texts.append(text)
return texts

Expand Down
18 changes: 18 additions & 0 deletions lib/sycamore/sycamore/utils/pdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,21 @@ def display_page_and_table_properties(some_pages: list[Document]):
print("Element Type: ", e.type)
print("Element Properties: ", json.dumps(e.properties, indent=2, default=str))
display(HTML(e.text_representation))


def promote_title(elements: list[Element], title_candidate_elements=("Section-header", "Caption")) -> list[Element]:
    """Relabel the largest-font candidate on page 1 as ``Title`` when none exists.

    Only elements whose ``page_number`` property equals 1 are examined. If any
    of them is already a ``Title``, nothing is changed. Otherwise the element
    whose type is listed in *title_candidate_elements* and whose ``font_size``
    property is the largest has its type set to ``"Title"``. Candidates with
    a missing, ``None`` or zero ``font_size`` are never promoted, and the
    earliest element wins a tie.

    Args:
        elements: Elements to inspect; the chosen element is modified in place.
        title_candidate_elements: Element types eligible for promotion.

    Returns:
        The same ``elements`` list that was passed in.
    """
    largest_font_size = 0
    best_candidate = None
    for ele in elements:
        # .get() so that an element carrying no page number is skipped
        # rather than aborting the whole pass with a KeyError.
        if ele.properties.get("page_number") != 1:
            continue
        if ele.type == "Title":
            return elements
        font_size = ele.properties.get("font_size")
        if ele.type in title_candidate_elements and font_size and font_size > largest_font_size:
            largest_font_size = font_size
            best_candidate = ele
    # Identity check: the decision must not depend on the element's truthiness.
    if best_candidate is not None:
        best_candidate.type = "Title"
    return elements

0 comments on commit a9142a6

Please sign in to comment.