diff --git a/.github/workflows/CI.yml b/.github/workflows/publish.yml similarity index 100% rename from .github/workflows/CI.yml rename to .github/workflows/publish.yml diff --git a/Cargo.lock b/Cargo.lock index dffcdad..e3fb417 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -71,12 +71,24 @@ dependencies = [ "winapi", ] +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + [[package]] name = "base64" version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "byteorder" version = "1.5.0" @@ -172,6 +184,19 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "either" version = "1.13.0" @@ -214,6 +239,12 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "heck" version = "0.5.0" @@ -261,10 +292,11 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "labelme2yolo" -version = "0.2.2" +version = "0.2.3" dependencies = [ "base64", "clap", + "dashmap", "env_logger", "glob", "indicatif", @@ -288,6 +320,16 @@ version = "0.2.157" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "374af5f94e54fa97cf75e945cce8a6b201e88a1a07e688b47dfd2a59c66dbd86" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.22" @@ -306,6 +348,25 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -383,6 +444,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.10.6" @@ -428,6 +498,12 @@ dependencies = [ "regex", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "serde" version = "1.0.208" @@ -460,6 +536,12 @@ dependencies = [ "serde", ] +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + [[package]] name = "strsim" version = "0.11.1" diff --git a/Cargo.toml b/Cargo.toml index c5d381f..603a686 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "labelme2yolo" -version = "0.2.2" +version = "0.2.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -17,3 +17,4 @@ rayon = "1.5" base64 = "0.13" log = "0.4" env_logger = "0.9" +dashmap = "5.0" diff --git a/src/main.rs b/src/main.rs index 8611043..cb48825 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ use clap::{Parser, ValueEnum}; +use dashmap::DashMap; use env_logger; use glob::glob; use indicatif::{ProgressBar, ProgressStyle}; @@ -10,15 +11,16 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; use serde_json; use std::collections::HashMap; -use std::fs; -use std::fs::copy; -use std::fs::File; +use std::fs::{self, copy, File}; use std::io::Write; use std::path::{Path, PathBuf}; use std::str::FromStr; -use std::sync::{Arc, Mutex}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] struct Shape { label: String, points: Vec<(f64, f64)>, @@ -28,7 +30,7 @@ struct Shape { mask: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] #[serde(rename_all = "camelCase")] struct ImageAnnotation { version: String, @@ -40,23 +42,23 @@ struct ImageAnnotation { image_width: u32, } -/// A powerful tool for converting LabelMe's JSON format to YOLO dataset format. +/// Command-line arguments parser for converting LabelMe JSON to YOLO format. #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { - /// The dir of the labelme json files + /// Directory containing LabelMe JSON files #[arg(short = 'd', long = "json_dir")] json_dir: String, - /// The validation dataset size + /// Proportion of the dataset to use for validation #[arg(long = "val_size", default_value_t = 0.2, value_parser = validate_size)] val_size: f32, - /// The test dataset size + /// Proportion of the dataset to use for testing #[arg(long = "test_size", default_value_t = 0.0, value_parser = validate_size)] test_size: f32, - /// The output format of yolo + /// Output format (bbox or polygon) for YOLO annotations #[arg( long = "output_format", visible_alias = "format", @@ -65,19 +67,23 @@ struct Args { )] output_format: Format, - /// The ordered label list + /// List of labels in the dataset #[arg(use_value_delimiter = true)] label_list: Vec, + + /// Seed for random shuffling + #[arg(long = "seed", default_value_t = 42)] + seed: u64, } +/// Enumeration for the YOLO output format #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)] enum Format { - /// Output as polygon format Polygon, - /// Output as bounding-box format Bbox, } +/// Validate that the size is between 0.0 and 1.0 fn validate_size(s: &str) -> Result { match f32::from_str(s) { Ok(val) if val >= 0.0 && val <= 1.0 => Ok(val), @@ -85,36 +91,12 @@ fn validate_size(s: &str) -> Result { } } -fn create_dir(path: &Path) { - if path.exists() { - fs::remove_dir_all(path).expect("Failed to remove existing directory"); - } - fs::create_dir_all(path).expect("Failed to create directory"); -} - -fn read_and_parse_json(path: &Path) -> Option { - match fs::read_to_string(path) { - Ok(content) => match serde_json::from_str::(&content) { - Ok(annotation) => Some(annotation), - Err(e) => { - error!("Failed to parse JSON ({}): {:?}", path.display(), e); - None - } - }, - Err(e) => { - error!("Failed to read file ({}): {:?}", path.display(), e); - None - } - } -} - fn main() { - env_logger::init(); + // Initialize the logger + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); let args = Args::parse(); let dirname = PathBuf::from(&args.json_dir); - - // Check if args.json_dir exists if !dirname.exists() { error!("The specified json_dir does not exist: {}", args.json_dir); return; @@ -122,124 +104,214 @@ fn main() { info!("Starting the conversion process..."); - let pattern = dirname.join("**/*.json"); - let labels_dir = dirname.join("YOLODataset/labels"); - let images_dir = dirname.join("YOLODataset/images"); - create_dir(&labels_dir); - create_dir(&images_dir); - let train_labels_dir = labels_dir.join("train"); - let val_labels_dir = labels_dir.join("val"); - let train_images_dir = images_dir.join("train"); - let val_images_dir = images_dir.join("val"); - create_dir(&train_labels_dir); - create_dir(&val_labels_dir); - create_dir(&train_images_dir); - create_dir(&val_images_dir); + let output_dirs = match setup_output_directories(&args, &dirname) { + Ok(dirs) => dirs, + Err(e) => { + error!("Failed to set up output directories: {}", e); + return; + } + }; + + let annotations = read_and_parse_json_files(&dirname); + info!("Read and parsed {} JSON files.", annotations.len()); + + let split_data = split_annotations(annotations, args.val_size, args.test_size, args.seed); + + let label_map = Arc::new(DashMap::new()); + let next_class_id = Arc::new(AtomicUsize::new(0)); + + if !args.label_list.is_empty() { + initialize_label_map(&args.label_list, &label_map, &next_class_id); + } + + process_all_annotations( + &split_data, + &output_dirs, + &label_map, + &next_class_id, + &args, + &dirname, + ); + + info!("Creating dataset.yaml file..."); + if let Err(e) = create_dataset_yaml(&dirname, &args, &label_map) { + error!("Failed to create dataset.yaml: {}", e); + } else { + info!("Conversion process completed successfully."); + } +} + +/// Struct to hold the paths to the output directories for train/val/test splits +struct OutputDirs { + train_labels_dir: PathBuf, + val_labels_dir: PathBuf, + train_images_dir: PathBuf, + val_images_dir: PathBuf, + test_labels_dir: Option, + test_images_dir: Option, +} + +/// Safely create output directories and return their paths +fn create_output_directory(path: &Path) -> std::io::Result { + if path.exists() { + warn!( + "Directory {:?} already exists. Deleting and recreating it.", + path + ); + fs::remove_dir_all(path)?; + } + fs::create_dir_all(path)?; + Ok(path.to_path_buf()) +} + +/// Set up the directory structure for YOLO dataset output +fn setup_output_directories(args: &Args, dirname: &Path) -> std::io::Result { + let labels_dir = create_output_directory(&dirname.join("YOLODataset/labels"))?; + let images_dir = create_output_directory(&dirname.join("YOLODataset/images"))?; + + let train_labels_dir = create_output_directory(&labels_dir.join("train"))?; + let val_labels_dir = create_output_directory(&labels_dir.join("val"))?; + let train_images_dir = create_output_directory(&images_dir.join("train"))?; + let val_images_dir = create_output_directory(&images_dir.join("val"))?; + let (test_labels_dir, test_images_dir) = if args.test_size > 0.0 { - let test_labels_dir = labels_dir.join("test"); - let test_images_dir = images_dir.join("test"); - create_dir(&test_labels_dir); - create_dir(&test_images_dir); - (Some(test_labels_dir), Some(test_images_dir)) + ( + Some(create_output_directory(&labels_dir.join("test"))?), + Some(create_output_directory(&images_dir.join("test"))?), + ) } else { (None, None) }; - let label_map = Arc::new(Mutex::new(HashMap::new())); - let next_class_id = Arc::new(Mutex::new(0)); - let mut annotations = Vec::new(); - for entry in glob(pattern.to_str().expect("Failed to convert path to string")) + + Ok(OutputDirs { + train_labels_dir, + val_labels_dir, + train_images_dir, + val_images_dir, + test_labels_dir, + test_images_dir, + }) +} + +/// Read and parse JSON files from the specified directory +fn read_and_parse_json_files(dirname: &Path) -> Vec<(PathBuf, ImageAnnotation)> { + let pattern = dirname.join("**/*.json"); + let annotations: Vec<_> = glob(pattern.to_str().expect("Failed to convert path to string")) .expect("Failed to read glob pattern") - { - if let Ok(path) = entry { - if let Some(annotation) = read_and_parse_json(&path) { - annotations.push((path, annotation)); - } - } - } + .filter_map(|entry| { + entry + .ok() + .and_then(|path| read_and_parse_json(&path).map(|annotation| (path, annotation))) + }) + .collect(); + annotations +} - info!("Read and parsed {} JSON files.", annotations.len()); +/// Read and parse a single JSON file into an ImageAnnotation struct +fn read_and_parse_json(path: &Path) -> Option { + fs::read_to_string(path).ok().and_then(|content| { + serde_json::from_str::(&content) + .map_err(|e| error!("Failed to parse JSON ({}): {:?}", path.display(), e)) + .ok() + }) +} + +/// Struct to hold the split datasets for training, validation, and testing +struct SplitData { + train_annotations: Vec<(PathBuf, ImageAnnotation)>, + val_annotations: Vec<(PathBuf, ImageAnnotation)>, + test_annotations: Vec<(PathBuf, ImageAnnotation)>, +} - // Shuffle and split the annotations into train, val, and test sets - let seed: u64 = 42; // Fixed random seed +/// Split the annotations into training, validation, and testing sets +fn split_annotations( + mut annotations: Vec<(PathBuf, ImageAnnotation)>, + val_size: f32, + test_size: f32, + seed: u64, +) -> SplitData { let mut rng = StdRng::seed_from_u64(seed); annotations.shuffle(&mut rng); - let test_size = (annotations.len() as f32 * args.test_size).ceil() as usize; - let val_size = (annotations.len() as f32 * args.val_size).ceil() as usize; + + let test_size = (annotations.len() as f32 * test_size).ceil() as usize; + let val_size = (annotations.len() as f32 * val_size).ceil() as usize; + let (test_annotations, rest_annotations) = annotations.split_at(test_size); let (val_annotations, train_annotations) = rest_annotations.split_at(val_size); - info!( - "Split data into {} training, {} validation, and {} test annotations.", - train_annotations.len(), - val_annotations.len(), - test_annotations.len() - ); - - // Update label_map from label_list if not empty - if !args.label_list.is_empty() { - let mut label_map_guard = label_map.lock().unwrap(); - for (id, label) in args.label_list.iter().enumerate() { - label_map_guard.insert(label.clone(), id); - } - *next_class_id.lock().unwrap() = args.label_list.len(); + SplitData { + train_annotations: train_annotations.to_vec(), + val_annotations: val_annotations.to_vec(), + test_annotations: test_annotations.to_vec(), } +} - // Create progress bars - let train_pb = create_progress_bar(train_annotations.len() as u64, "Train"); - let val_pb = create_progress_bar(val_annotations.len() as u64, "Val"); - let test_pb = create_progress_bar(test_annotations.len() as u64, "Test"); +/// Initialize the label map with the provided label list +fn initialize_label_map( + label_list: &[String], + label_map: &Arc>, + next_class_id: &Arc, +) { + for (id, label) in label_list.iter().enumerate() { + label_map.insert(label.clone(), id); + } + next_class_id.store(label_list.len(), Ordering::Relaxed); +} - // Process train_annotations in parallel - info!("Processing training annotations..."); +/// Process all annotations in parallel for train, val, and test splits +fn process_all_annotations( + split_data: &SplitData, + output_dirs: &OutputDirs, + label_map: &Arc>, + next_class_id: &Arc, + args: &Args, + dirname: &Path, +) { + let train_pb = create_progress_bar(split_data.train_annotations.len() as u64, "Train"); process_annotations_in_parallel( - &train_annotations, - &train_labels_dir, - &train_images_dir, - &label_map, - &next_class_id, - &args, - &dirname, + &split_data.train_annotations, + &output_dirs.train_labels_dir, + &output_dirs.train_images_dir, + label_map, + next_class_id, + args, + dirname, &train_pb, ); train_pb.finish_with_message("Train processing complete"); - // Process val_annotations in parallel - info!("Processing validation annotations..."); + let val_pb = create_progress_bar(split_data.val_annotations.len() as u64, "Val"); process_annotations_in_parallel( - &val_annotations, - &val_labels_dir, - &val_images_dir, - &label_map, - &next_class_id, - &args, - &dirname, + &split_data.val_annotations, + &output_dirs.val_labels_dir, + &output_dirs.val_images_dir, + label_map, + next_class_id, + args, + dirname, &val_pb, ); val_pb.finish_with_message("Val processing complete"); - // Process test_annotations in parallel - if let (Some(test_labels_dir), Some(test_images_dir)) = (test_labels_dir, test_images_dir) { - info!("Processing test annotations..."); + if let (Some(test_labels_dir), Some(test_images_dir)) = + (&output_dirs.test_labels_dir, &output_dirs.test_images_dir) + { + let test_pb = create_progress_bar(split_data.test_annotations.len() as u64, "Test"); process_annotations_in_parallel( - &test_annotations, - &test_labels_dir, - &test_images_dir, - &label_map, - &next_class_id, - &args, - &dirname, + &split_data.test_annotations, + test_labels_dir, + test_images_dir, + label_map, + next_class_id, + args, + dirname, &test_pb, ); test_pb.finish_with_message("Test processing complete"); } - - // Create dataset.yaml file after processing annotations - info!("Creating dataset.yaml file..."); - create_dataset_yaml(&dirname, &args, &label_map); - - info!("Conversion process completed successfully."); } +/// Create a progress bar with the given length and label fn create_progress_bar(len: u64, label: &str) -> ProgressBar { let pb = ProgressBar::new(len); pb.set_style( @@ -253,76 +325,45 @@ fn create_progress_bar(len: u64, label: &str) -> ProgressBar { pb } +/// Process a batch of annotations in parallel fn process_annotations_in_parallel( annotations: &[(PathBuf, ImageAnnotation)], labels_dir: &Path, images_dir: &Path, - label_map: &Arc>>, - next_class_id: &Arc>, + label_map: &Arc>, + next_class_id: &Arc, args: &Args, base_dir: &Path, pb: &ProgressBar, ) { - annotations.par_iter().for_each(|(path, annotation)| { - let mut label_map_guard = label_map.lock().unwrap(); - let mut next_class_id_guard = next_class_id.lock().unwrap(); - process_annotation( - path, - annotation, - labels_dir, - images_dir, - &mut label_map_guard, - &mut next_class_id_guard, - args, - base_dir, - ); - pb.inc(1); + annotations.par_chunks(10).for_each(|chunk| { + chunk.iter().for_each(|(path, annotation)| { + process_annotation( + path, + annotation, + labels_dir, + images_dir, + label_map, + next_class_id, + args, + base_dir, + ); + }); + pb.inc(chunk.len() as u64); }); } -fn create_dataset_yaml( - dirname: &Path, - args: &Args, - label_map: &Arc>>, -) { - let dataset_yaml_path = dirname.join("YOLODataset/dataset.yaml"); - let mut dataset_yaml = - File::create(dataset_yaml_path).expect("Failed to create dataset.yaml file"); - let absolute_path = - fs::canonicalize(&dirname.join("YOLODataset")).expect("Failed to get absolute path"); - let mut yaml_content = format!( - "path: {}\ntrain: images/train\nval: images/val\n", - absolute_path.to_str().unwrap() - ); - if args.test_size > 0.0 { - yaml_content.push_str("test: images/test\n"); - } else { - yaml_content.push_str("test:\n"); - } - yaml_content.push_str("\nnames:\n"); - // Read names from label_map - let label_map_guard = label_map.lock().unwrap(); - let mut sorted_labels: Vec<_> = label_map_guard.iter().collect(); - sorted_labels.sort_by_key(|&(_, id)| id); - for (label, id) in sorted_labels { - yaml_content.push_str(&format!(" {}: {}\n", id, label)); - } - dataset_yaml - .write_all(yaml_content.as_bytes()) - .expect("Failed to write to dataset.yaml file"); -} - +/// Process a single annotation and convert it to YOLO format fn process_annotation( path: &Path, annotation: &ImageAnnotation, labels_dir: &Path, images_dir: &Path, - label_map: &mut HashMap, - next_class_id: &mut usize, + label_map: &Arc>, + next_class_id: &Arc, args: &Args, base_dir: &Path, ) { - // Skip processing if image_path file does not exist and image_data is empty let image_path = base_dir.join(&annotation.image_path); if !image_path.exists() && annotation @@ -337,101 +378,36 @@ fn process_annotation( return; } - let mut yolo_data = String::new(); - for shape in &annotation.shapes { - let class_id = if args.label_list.is_empty() { - // Dynamically generate class ID - *label_map.entry(shape.label.clone()).or_insert_with(|| { - let id = *next_class_id; - *next_class_id += 1; - id - }) - } else { - match args.label_list.iter().position(|r| r == &shape.label) { - Some(id) => id, - None => continue, // Ignore labels not in label_list - } - }; - - if args.output_format == Format::Polygon { - // Write polygon format - yolo_data.push_str(&format!("{}", class_id)); - if shape.shape_type == "rectangle" { - let (x1, y1) = shape.points[0]; - let (x2, y2) = shape.points[1]; - let rect_points = vec![(x1, y1), (x2, y1), (x2, y2), (x1, y2)]; - for &(x, y) in &rect_points { - let x_norm = x / annotation.image_width as f64; - let y_norm = y / annotation.image_height as f64; - yolo_data.push_str(&format!(" {:.6} {:.6}", x_norm, y_norm)); - } - } else { - for &(x, y) in &shape.points { - let x_norm = x / annotation.image_width as f64; - let y_norm = y / annotation.image_height as f64; - yolo_data.push_str(&format!(" {:.6} {:.6}", x_norm, y_norm)); - } - } - yolo_data.push_str("\n"); - } else { - // Write bbox format - let (x_min, y_min, x_max, y_max) = shape.points.iter().fold( - (f64::MAX, f64::MAX, f64::MIN, f64::MIN), - |(x_min, y_min, x_max, y_max), &(x, y)| { - (x_min.min(x), y_min.min(y), x_max.max(x), y_max.max(y)) - }, - ); - - let x_center = (x_min + x_max) / 2.0 / annotation.image_width as f64; - let y_center = (y_min + y_max) / 2.0 / annotation.image_height as f64; - let width = (x_max - x_min) / annotation.image_width as f64; - let height = (y_max - y_min) / annotation.image_height as f64; + let (yolo_data, should_skip) = + convert_to_yolo_format(annotation, args, label_map, next_class_id); - yolo_data.push_str(&format!( - "{} {:.6} {:.6} {:.6} {:.6}\n", - class_id, x_center, y_center, width, height - )); - } + if should_skip { + return; } - let output_path = labels_dir - .join(sanitize_filename::sanitize( - path.file_stem().unwrap().to_str().unwrap(), - )) - .with_extension("txt"); - let mut file = File::create(output_path).expect("Failed to create YOLO data file"); + let sanitized_name = sanitize_filename::sanitize(path.file_stem().unwrap().to_str().unwrap()); + let output_path = labels_dir.join(&sanitized_name).with_extension("txt"); + + let mut file = File::create(&output_path).expect("Failed to create YOLO data file"); file.write_all(yolo_data.as_bytes()) .expect("Failed to write YOLO data"); - // Copy image to images directory if image_path.exists() { - let image_output_path = images_dir.join(sanitize_filename::sanitize( - image_path.file_name().unwrap().to_str().unwrap(), - )); + let image_output_path = images_dir + .join(sanitized_name) + .with_extension(image_path.extension().unwrap_or_default()); copy(&image_path, &image_output_path).expect("Failed to copy image"); } else if let Some(image_data) = &annotation.image_data { if !image_data.is_empty() { - // Decode base64 image data and write to file let image_data = base64::decode(image_data).expect("Failed to decode image data"); let extension = match image_path.extension().and_then(|ext| ext.to_str()) { - Some(ext) => { - let ext_lower = ext.to_lowercase(); - match ext_lower.as_str() { - "jpg" | "jpeg" => "jpeg", - _ => "png", - } - } + Some(ext) => match ext.to_lowercase().as_str() { + "jpg" | "jpeg" => "jpeg", + _ => "png", + }, None => "png", }; - let image_output_path = images_dir - .join(sanitize_filename::sanitize( - Path::new(&annotation.image_path) - .file_stem() - .unwrap() - .to_str() - .unwrap(), - )) - .with_extension(extension); + let image_output_path = images_dir.join(sanitized_name).with_extension(extension); let mut file = File::create(&image_output_path).expect("Failed to create image file"); file.write_all(&image_data) .expect("Failed to write image data"); @@ -443,3 +419,135 @@ fn process_annotation( ); } } + +/// Convert an annotation to YOLO format (bounding box or polygon) +fn convert_to_yolo_format( + annotation: &ImageAnnotation, + args: &Args, + label_map: &Arc>, + next_class_id: &Arc, +) -> (String, bool) { + let mut yolo_data = String::new(); + let mut should_skip = false; + + for shape in &annotation.shapes { + let class_id = match label_map.get(&shape.label) { + Some(class_id) => *class_id, + None if args.label_list.is_empty() => { + let new_id = next_class_id.fetch_add(1, Ordering::Relaxed); + label_map.insert(shape.label.clone(), new_id); + new_id + } + _ => { + should_skip = true; + continue; + } + }; + + match args.output_format { + Format::Polygon => { + yolo_data.push_str(&format!("{}", class_id)); + process_polygon_shape(&mut yolo_data, annotation, shape); + yolo_data.push_str("\n"); + } + Format::Bbox => { + let (x_center, y_center, width, height) = calculate_bounding_box(annotation, shape); + yolo_data.push_str(&format!( + "{} {:.6} {:.6} {:.6} {:.6}\n", + class_id, x_center, y_center, width, height + )); + } + } + } + + (yolo_data, should_skip) +} + +/// Process polygon shape data for YOLO format +fn process_polygon_shape(yolo_data: &mut String, annotation: &ImageAnnotation, shape: &Shape) { + if shape.shape_type == "rectangle" { + let (x1, y1) = shape.points[0]; + let (x2, y2) = shape.points[1]; + let rect_points = vec![(x1, y1), (x2, y1), (x2, y2), (x1, y2)]; + for &(x, y) in &rect_points { + let x_norm = x / annotation.image_width as f64; + let y_norm = y / annotation.image_height as f64; + yolo_data.push_str(&format!(" {:.6} {:.6}", x_norm, y_norm)); + } + } else if shape.shape_type == "circle" { + let (cx, cy) = shape.points[0]; + let (px, py) = shape.points[1]; + let radius = ((cx - px).powi(2) + (cy - py).powi(2)).sqrt(); + for i in 0..12 { + let angle = 2.0 * std::f64::consts::PI * i as f64 / 12.0; + let x = cx + radius * angle.cos(); + let y = cy + radius * angle.sin(); + let x_norm = x / annotation.image_width as f64; + let y_norm = y / annotation.image_height as f64; + yolo_data.push_str(&format!(" {:.6} {:.6}", x_norm, y_norm)); + } + } else { + for &(x, y) in &shape.points { + let x_norm = x / annotation.image_width as f64; + let y_norm = y / annotation.image_height as f64; + yolo_data.push_str(&format!(" {:.6} {:.6}", x_norm, y_norm)); + } + } +} + +/// Calculate bounding box for YOLO format +fn calculate_bounding_box(annotation: &ImageAnnotation, shape: &Shape) -> (f64, f64, f64, f64) { + let (x_min, y_min, x_max, y_max) = if shape.shape_type == "circle" { + let (cx, cy) = shape.points[0]; + let (px, py) = shape.points[1]; + let radius = ((cx - px).powi(2) + (cy - py).powi(2)).sqrt(); + (cx - radius, cy - radius, cx + radius, cy + radius) + } else { + shape.points.iter().fold( + (f64::MAX, f64::MAX, f64::MIN, f64::MIN), + |(x_min, y_min, x_max, y_max), &(x, y)| { + (x_min.min(x), y_min.min(y), x_max.max(x), y_max.max(y)) + }, + ) + }; + + let x_center = (x_min + x_max) / 2.0 / annotation.image_width as f64; + let y_center = (y_min + y_max) / 2.0 / annotation.image_height as f64; + let width = (x_max - x_min) / annotation.image_width as f64; + let height = (y_max - y_min) / annotation.image_height as f64; + + (x_center, y_center, width, height) +} + +/// Create the dataset.yaml file for YOLO training +fn create_dataset_yaml( + dirname: &Path, + args: &Args, + label_map: &Arc>, +) -> std::io::Result<()> { + let dataset_yaml_path = dirname.join("YOLODataset/dataset.yaml"); + let mut dataset_yaml = File::create(&dataset_yaml_path)?; + let absolute_path = fs::canonicalize(&dirname.join("YOLODataset"))?; + let mut yaml_content = format!( + "path: {}\ntrain: images/train\nval: images/val\n", + absolute_path.to_string_lossy() + ); + if args.test_size > 0.0 { + yaml_content.push_str("test: images/test\n"); + } else { + yaml_content.push_str("test:\n"); + } + yaml_content.push_str("\nnames:\n"); + + // Extract and sort labels by their ID + let mut sorted_labels: Vec<_> = label_map + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect(); + sorted_labels.sort_by_key(|&(_, id)| id); + + for (label, id) in sorted_labels { + yaml_content.push_str(&format!(" {}: {}\n", id, label)); + } + dataset_yaml.write_all(yaml_content.as_bytes()) +}