Skip to content

Commit

Permalink
updates to callback
Browse files Browse the repository at this point in the history
  • Loading branch information
Jordan-Pierce committed Jan 14, 2025
1 parent 81206ea commit 355ee7b
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 43 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,5 @@ ENV/
.vscode/
.idea/

tests/detection_tiled
tests/segmentation_tiled
tests/detection_*
tests/segmentation_*
24 changes: 21 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,32 @@ tiler = YoloTiler(
tiler.run()
```

An example of an (optional) `progress_callback` function can be seen below:
```python
@dataclass
class TileProgress:
"""Data class to track tiling progress"""
current_set_name: str = ""
current_image_name: str = ""
current_image_idx: int = 0
total_images: int = 0
current_tile_idx: int = 0
total_tiles: int = 0
```

Using `TileProgress` custom callback functions can be created. An example of an (optional) `progress_callback` function
can be seen below:

```python
from yolo_tiler import TilerProgress

def progress_callback(progress: TileProgress):
print(f"Processing {progress.current_image} in {progress.current_set} set: "
f"tile {progress.current_tile}/{progress.total_tiles}")
# Determine whether to show tile or image progress
if progress.total_tiles > 0:
print(f"Processing {progress.current_image_name} in {progress.current_set_name} set: "
f"Tile {progress.current_tile_idx}/{progress.total_tiles}")
else:
print(f"Processing {progress.current_image_name} in {progress.current_set_name} set: "
f"Image {progress.current_image_idx}/{progress.total_images}")

```

Expand Down
21 changes: 12 additions & 9 deletions tests/test_yolo_tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@


def progress_callback(progress: TileProgress):
print(f"Processing {progress.current_image_name} in {progress.current_set_name} set: "
f"tile {progress.current_tile_idx}/{progress.total_tiles}, "
f"image {progress.current_image_idx}/{progress.total_images}")


src = "./tests/segmentation"
dst = "./tests/segmentation_tiled"
# Determine whether to show tile or image progress
if progress.total_tiles > 0:
print(f"Processing {progress.current_image_name} in {progress.current_set_name} set: "
f"Tile {progress.current_tile_idx}/{progress.total_tiles}")
else:
print(f"Processing {progress.current_image_name} in {progress.current_set_name} set: "
f"Image {progress.current_image_idx}/{progress.total_images}")

src = "./tests/detection_tiled"
dst = "./tests/detection_tiled_tiled"

config = TileConfig(
slice_wh=(320, 240), # Slice width and height
overlap_wh=(0.0, 0.0), # Overlap width and height (10% overlap in this example, or 64x48 pixels)
input_ext=".png",
output_ext=None,
annotation_type="instance_segmentation",
annotation_type="object_detection",
train_ratio=0.7,
valid_ratio=0.2,
test_ratio=0.1,
Expand All @@ -33,7 +36,7 @@ def progress_callback(progress: TileProgress):
source=src,
target=dst,
config=config,
num_viz_samples=100,
num_viz_samples=25,
progress_callback=progress_callback
)

Expand Down
81 changes: 52 additions & 29 deletions yolo_tiler/yolo_tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ def get_effective_area(self, image_width: int, image_height: int) -> Tuple[int,
@dataclass
class TileProgress:
"""Data class to track tiling progress"""
current_set_name: str
current_image_name: str
current_image_idx: int
total_images: int
current_tile_idx: int
total_tiles: int
current_set_name: str = ""
current_image_name: str = ""
current_image_idx: int = 0
total_images: int = 0
current_tile_idx: int = 0
total_tiles: int = 0


class YoloTiler:
Expand Down Expand Up @@ -166,19 +166,37 @@ def _tqdm_callback(self, progress: TileProgress):
progress: TileProgress object containing current progress
"""
# Initialize or get progress bar for current set
if progress.current_set_name not in self._progress_bars:
# Determine if we're tracking tiles or images
if progress.total_tiles > 0:
total = progress.total_tiles
desc = f"{progress.current_set_name}: Tile"
unit = 'tiles'
else:
total = progress.total_images
desc = f"{progress.current_set_name}: Image"
unit = 'images'

self._progress_bars[progress.current_set_name] = tqdm(
total=progress.total_tiles,
desc=progress.current_set_name,
unit='items'
total=total,
desc=desc,
unit=unit
)

# Update progress
self._progress_bars[progress.current_set_name].n = progress.current_tile_idx
# Update progress based on available information
if progress.total_tiles > 0:
self._progress_bars[progress.current_set_name].n = progress.current_tile_idx
else:
self._progress_bars[progress.current_set_name].n = progress.current_image_idx

self._progress_bars[progress.current_set_name].refresh()

# Close and cleanup if task is complete
if progress.current_tile_idx >= progress.total_tiles:
is_complete = (progress.total_tiles > 0 and progress.current_tile_idx >= progress.total_tiles) or \
(progress.total_tiles == 0 and progress.current_image_idx >= progress.total_images)

if is_complete:
self._progress_bars[progress.current_set_name].close()
del self._progress_bars[progress.current_set_name]

Expand Down Expand Up @@ -414,7 +432,12 @@ def _save_labels(self, labels: List, path: Path, is_segmentation: bool) -> None:
df = pd.DataFrame(labels, columns=['class', 'x1', 'y1', 'w', 'h'])
df.to_csv(path, sep=' ', index=False, header=False, float_format='%.6f')

def tile_image(self, image_path: Path, label_path: Path, folder: str, current_image_idx: int, total_images: int) -> None:
def tile_image(self,
image_path: Path,
label_path: Path,
folder: str,
current_image_idx: int,
total_images: int) -> None:
"""
Tile an image and its corresponding labels, properly handling margins.
"""
Expand Down Expand Up @@ -667,31 +690,31 @@ def split_data(self) -> None:
num_test = len(test_set)

# Move files to valid folder
for tile_idx, (image_path, label_path) in enumerate(valid_set):
for image_idx, (image_path, label_path) in enumerate(valid_set):
self._move_split_data(image_path, label_path, 'valid')

if self.progress_callback:
progress = TileProgress(
current_tile_idx=tile_idx + 1,
total_tiles=num_valid,
current_tile_idx=0,
total_tiles=0,
current_set_name='valid',
current_image_name=image_path.name,
current_image_idx=0, # Placeholder, update as needed
total_images=0 # Placeholder, update as needed
current_image_idx=image_idx + 1,
total_images=num_valid
)
self.progress_callback(progress)

# Move files to test folder
for tile_idx, (image_path, label_path) in enumerate(test_set):
for image_idx, (image_path, label_path) in enumerate(test_set):
self._move_split_data(image_path, label_path, 'test')
if self.progress_callback:
progress = TileProgress(
current_tile_idx=tile_idx + 1,
total_tiles=num_test,
current_tile_idx=0,
total_tiles=0,
current_set_name='test',
current_image_name=image_path.name,
current_image_idx=0, # Placeholder, update as needed
total_images=0 # Placeholder, update as needed
current_image_idx=image_idx + 1,
total_images=num_test
)
self.progress_callback(progress)

Expand Down Expand Up @@ -779,7 +802,7 @@ def visualize_random_samples(self) -> None:
selected_images = random.sample(image_paths, num_samples)

# Process each selected image
for tile_idx, image_path in enumerate(selected_images):
for image_idx, image_path in enumerate(selected_images):
label_path = train_label_dir / f"{image_path.stem}.txt"

if not label_path.exists():
Expand All @@ -788,16 +811,16 @@ def visualize_random_samples(self) -> None:

if self.progress_callback:
progress = TileProgress(
current_tile_idx=tile_idx + 1,
total_tiles=num_samples,
current_tile_idx=0,
total_tiles=0,
current_set_name='rendered',
current_image_name=image_path.name,
current_image_idx=0, # Placeholder, update as needed
total_images=0 # Placeholder, update as needed
current_image_idx=image_idx + 1,
total_images=len(selected_images)
)
self.progress_callback(progress)

self._render_single_sample(image_path, label_path, tile_idx + 1)
self._render_single_sample(image_path, label_path, image_idx + 1)

def _render_single_sample(self, image_path: Path, label_path: Path, idx: int) -> None:
"""
Expand Down

0 comments on commit 355ee7b

Please sign in to comment.