Skip to content

Commit

Permalink
Merge pull request #10 from shamspias/feat/rt-detr
Browse files Browse the repository at this point in the history
Fixing label index
  • Loading branch information
shamspias authored Sep 11, 2024
2 parents 1f3c24e + 22a6ed4 commit df50307
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def continue_ui(self):
"mmdet",
]
)
self.config.sahi_device = st.sidebar.selectbox("Device:", ["cpu"])
self.config.sahi_device = st.sidebar.selectbox("Device:", ["cpu", "cuda:0"])
self.config.sahi_slice_size = st.sidebar.slider("SAHI slice size:", 128, 512, (256, 256))
self.config.sahi_overlap_ratio = st.sidebar.slider("SAHI overlap ratio:", 0.1, 0.5, (0.2, 0.2))
self.sahi_config = {
Expand Down
9 changes: 7 additions & 2 deletions formats/base_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,29 @@ def process_results(self, results: Dict, img_dimensions, supported_classes) -> L
for box in results['boxes']: # Assuming SAHI results are formatted similarly
class_id = int(box['cls'][0])
if class_id in supported_classes: # Check if class_id is in the list of supported classes
class_id_index = supported_classes.index(
class_id) # Get index of class_id in supported_classes list
xmin, ymin, xmax, ymax = box['xyxy'][0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
annotations.append(f"{class_id_index} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
else:
for result in results:
if hasattr(result, 'boxes') and result.boxes is not None:
for box in result.boxes:
class_id = int(box.cls[0])
if class_id in supported_classes: # Check if class_id is in the list of supported classes
class_id_index = supported_classes.index(
class_id) # Get index of class_id in supported_classes list
xmin, ymin, xmax, ymax = box.xyxy[0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
annotations.append(
f"{class_id_index} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

return annotations

Expand Down

0 comments on commit df50307

Please sign in to comment.