Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deformable table extractor #1053

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions lib/sycamore/sycamore/transforms/table_structure/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,16 @@ def _prepare_tokens(self, tokens: list[dict[str, Any]], crop_box, width, height)
t["block_num"] = 0
return tokens

def _init_structure_model(self):
    """Load the TableTransformer detection model and store it on ``self.structure_model``.

    The pretrained weights are looked up via ``self.model`` and the resulting
    model is moved to whatever ``self._get_device()`` returns. ``transformers``
    is imported here, inside the method, rather than at module import time.
    """
    from transformers import TableTransformerForObjectDetection

    # Keep the original evaluation order: load first, then resolve the device.
    loaded_model = TableTransformerForObjectDetection.from_pretrained(self.model)
    target_device = self._get_device()
    self.structure_model = loaded_model.to(target_device)

@timetrace("tblExtr")
@requires_modules(["torch", "torchvision"], extra="local-inference")
def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=False) -> TableElement:
def extract(
self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=False
) -> TableElement:
"""Extracts the table structure from the specified element using a TableTransformer model.

Takes a TableElement containing a bounding box, for example from the SycamorePartitioner,
Expand All @@ -112,6 +119,8 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
element: A TableElement. The bounding box must be non-null.
doc_image: A PIL object containing an image of the Document page containing the element.
Used for bounding box calculations.
union_tokens: Make sure that ocr/pdfminer tokens are _all_ included in the table.
apply_thresholds: Apply class thresholds to the objects output by the model.
"""

# We need a bounding box to be able to do anything.
Expand All @@ -123,9 +132,7 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
width, height = doc_image.size

if self.structure_model is None:
from transformers import TableTransformerForObjectDetection

self.structure_model = TableTransformerForObjectDetection.from_pretrained(self.model).to(self._get_device())
self._init_structure_model()
assert self.structure_model is not None # For typechecking

# Crop the image to encompass just the table + some padding.
Expand Down Expand Up @@ -161,7 +168,9 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
structure_id2label = self.structure_model.config.id2label
structure_id2label[len(structure_id2label)] = "no object"

objects = table_transformers.outputs_to_objects(outputs, cropped_image.size, structure_id2label)
objects = table_transformers.outputs_to_objects(
outputs, cropped_image.size, structure_id2label, apply_thresholds=apply_thresholds
)

# Convert the raw objects to our internal table representation. This involves multiple
# phases of postprocessing.
Expand All @@ -182,6 +191,43 @@ def extract(self, element: TableElement, doc_image: Image.Image, union_tokens=Fa
return element


class DeformableTableStructureExtractor(TableTransformerStructureExtractor):
    """A TableStructureExtractor implementation that uses the Deformable DETR model."""

    def __init__(self, model: str, device=None):
        """
        Creates a DeformableTableStructureExtractor

        Args:
            model: The HuggingFace URL or local path for the DeformableDETR model to use.
            device: Passed through unchanged to the parent constructor.
        """

        super().__init__(model, device)

    def _init_structure_model(self):
        """Load the Deformable DETR detection model and store it on ``self.structure_model``.

        The pretrained weights are looked up via ``self.model`` and the resulting
        model is moved to whatever ``self._get_device()`` returns. ``transformers``
        is imported here, inside the method, rather than at module import time.
        """
        from transformers import DeformableDetrForObjectDetection

        self.structure_model = DeformableDetrForObjectDetection.from_pretrained(self.model).to(self._get_device())

    def extract(
        self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=True
    ) -> TableElement:
        """Extracts the table structure from the specified element using a DeformableDETR model.

        Takes a TableElement containing a bounding box, for example from the SycamorePartitioner,
        and populates the table property with information about the cells.

        Args:
            element: A TableElement. The bounding box must be non-null.
            doc_image: A PIL object containing an image of the Document page containing the element.
                Used for bounding box calculations.
            union_tokens: Make sure that ocr/pdfminer tokens are _all_ included in the table.
            apply_thresholds: Apply class thresholds to the objects output by the model.
                Defaults to True in this subclass.

        Returns:
            Whatever the parent class's ``extract`` returns for these arguments.
        """
        # Literally just call the super but change the default for apply_thresholds
        return super().extract(element, doc_image, union_tokens, apply_thresholds)


# Module-level alias bound to the TableTransformerStructureExtractor class itself (not an instance).
DEFAULT_TABLE_STRUCTURE_EXTRACTOR = TableTransformerStructureExtractor


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,18 @@ def rescale_bboxes(out_bbox, size):
return b


def outputs_to_objects(outputs, img_size, id2label):
def outputs_to_objects(outputs, img_size, id2label, apply_thresholds: bool = False):
m = outputs.logits.softmax(-1).max(-1)
pred_labels = list(m.indices.detach().cpu().numpy())[0]
pred_scores = list(m.values.detach().cpu().numpy())[0]
pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

if apply_thresholds:
pred_bboxes, pred_scores, pred_labels = apply_class_thresholds(
pred_bboxes, pred_labels, pred_scores, id2label, DEFAULT_STRUCTURE_CLASS_THRESHOLDS
)

objects = []
for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
if float(bbox[0]) > float(bbox[2]) or float(bbox[1]) > float(bbox[3]):
Expand Down Expand Up @@ -906,9 +911,10 @@ def objects_to_structures(objects, tokens, class_thresholds):
if len(tables) == 0:
return {}
if len(tables) > 1:
tables.sort(key=lambda x: x["score"], reverse=True)
import logging

logging.warning("Got multiple tables in document. Using only the first one")
logging.warning("Got multiple tables in document. Using only the highest-scoring one")

table = tables[0]
structure = {}
Expand Down
Loading