Tch backend (#23)
* update YoloAnnotation

now includes confidence param

* add example

* fighting with IValue handling

* wip

* v9 inference

* nms

* nms iou

* fmt

* stash

* stash2

* nms and iou

* nms/iou working BUT

somehow I can only get detections from ONE image :/

* TorchScript working :)

* same device

* ayy lmao

* stash

* cleanup

* make onnx optional for kesa_al

* AL backends are all optional

* yolo bbox test

* fmt

* to_shape

* labelme from Vec<Shape>

* logging
heabeounMKTO authored Apr 23, 2024
1 parent 69f6e11 commit 7e03c21
Showing 9 changed files with 138 additions and 105 deletions.
10 changes: 6 additions & 4 deletions examples/tch_inference.rs
@@ -6,7 +6,7 @@ use kesa::{
backends::tch_backend::{self, TchModel},
fileutils::get_config_from_name,
image_utils,
label::{Shape, YoloAnnotation},
label::{Shape, YoloAnnotation, LabelmeAnnotation},
output::OutputFormat,
};

@@ -29,10 +29,12 @@ fn load_tch(input: &str, device: Option<tch::Device>) -> Result<TchModel, Error>
let _ac: Vec<String> = all_classes.iter().map(|x| String::from(*x)).collect();
let preproc_img = image_utils::preprocess_imagef16(&_img2, 640)?;
let mut _pimg2 = tch::Tensor::try_from(preproc_img)?;
println!("pimg2 {:?}", _pimg2);
println!("pimg2 {:?}", _img2.dimensions());
let mut test_inf = loaded_model.run_fp16(&_pimg2, 0.7, 0.6, "yolov9")?;
println!(
"testinf[1] to yolo: {:?}", test_inf[0].to_normalized((640,640)) );
let uhhh = test_inf[0].to_normalized(&(640, 640)).to_screen(&(690, 1035)).to_shape(&_ac, &(690, 1035))?;
let _2vec: Vec<Shape> = vec![uhhh];
let _lm: LabelmeAnnotation = LabelmeAnnotation::from_shape_vec(imgpath, &_img2, &_2vec)?;
println!("labelme anno w:{:?} h:{:?}", _lm.imageWidth, _lm.imageHeight);
Ok(loaded_model)
}

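The example maps detections from 640×640 model space back onto the 690×1035 source image by normalizing and then rescaling. Below is a minimal standalone sketch of that two-step mapping; `BBox` is a hypothetical stand-in for kesa's own `Xyxy`/`Shape` types, not the crate's API.

```rust
/// Illustrative stand-in for kesa's box types.
#[derive(Debug, Clone, Copy)]
struct BBox {
    x1: f32,
    y1: f32,
    x2: f32,
    y2: f32,
}

impl BBox {
    /// Model-space pixels -> [0, 1], given the network input size.
    fn to_normalized(self, (w, h): (f32, f32)) -> BBox {
        BBox { x1: self.x1 / w, y1: self.y1 / h, x2: self.x2 / w, y2: self.y2 / h }
    }

    /// [0, 1] -> pixels on the original image.
    fn to_screen(self, (w, h): (f32, f32)) -> BBox {
        BBox { x1: self.x1 * w, y1: self.y1 * h, x2: self.x2 * w, y2: self.y2 * h }
    }
}

fn main() {
    // A detection in 640x640 model space, mapped onto a 690x1035 image
    // (the sizes the example above uses).
    let det = BBox { x1: 64.0, y1: 128.0, x2: 320.0, y2: 480.0 };
    let on_screen = det.to_normalized((640.0, 640.0)).to_screen((690.0, 1035.0));
    println!("{:?}", on_screen);
}
```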
10 changes: 5 additions & 5 deletions src/backends/onnx_backend.rs
@@ -61,7 +61,7 @@ impl InferenceModel for OnnxModel {
}
// returns a vec with a single numer on error :|
Err(e) => {
println!("cannot find any detections! {}", e);
println!("[error]::onnx_backend: cannot find any detections! {}", e);
let mut rnd = rand::thread_rng();
let random_num = vec![rnd.gen::<u8>() as usize];
Embeddings::new(Array::zeros(IxDyn(&random_num)))
@@ -115,7 +115,7 @@ impl<'a> OnnxInference<'a> {
}
// returns a vec with a single numer on error :|
Err(e) => {
println!("cannot find any detections! {}", e);
println!("[error]::onnx_backend: cannot find any detections! {}", e);
let mut rnd = rand::thread_rng();
let random_num = vec![rnd.gen::<u8>() as usize];
Embeddings::new(Array::zeros(IxDyn(&random_num)))
@@ -163,7 +163,7 @@ pub fn load_onnx_model(
.commit_from_file(&model_path)
.unwrap();
let model_yaml_config_path = get_config_from_name(&config_path, &model_path)
.expect("Cannot Find model Configuration file");
.expect("[error]::onnx_backend: cannot find model configuration file");

let model_details =
serde_yaml::from_reader(std::fs::File::open(&model_yaml_config_path)?).unwrap();
@@ -172,13 +172,13 @@
// I accept I have no luck in this world (roughly translated from Khmer)
let mut spinna = Spinner::new(
Spinners::Dots12,
format!("Loading Model {:?}", &model_path).into(),
format!("[info]::onnx_backend: loading model {:?}", &model_path).into(),
);
let loaded_model: OnnxModel = OnnxModel::new(model_details, model, false).unwrap();
let mut _dummy_input: ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> =
Array::ones((1, 3, 640, 640));
let original_img = image::open(Path::new(image_path)).unwrap();
println!("\nRunning Warmup");
println!("\n[info]::onnx_backend: running Warmup");
// runs a forward pass on a random image from the folder
let _ = &loaded_model.run(original_img);
spinna.stop_with_symbol("✅");
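Before handing the model back, the loader runs one forward pass on a dummy input so the first real inference doesn't pay one-time initialization costs. A compact sketch of that warmup pattern; `run` is a hypothetical stand-in for the backend's inference entry point.

```rust
use ndarray::{Array, Array4};

/// Warmup sketch: one forward pass on a dummy NCHW tensor of ones,
/// shaped like the (1, 3, 640, 640) input the loader above builds.
fn warmup(run: impl Fn(&Array4<f32>)) {
    let dummy: Array4<f32> = Array::ones((1, 3, 640, 640));
    run(&dummy);
}
```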
25 changes: 13 additions & 12 deletions src/backends/tch_backend.rs
@@ -1,5 +1,5 @@
use crate::label::Xyxy;
use crate::label::{YoloBbox, CoordinateType};
use crate::label::{CoordinateType, YoloBbox};
use anyhow::bail;
use anyhow::{Error, Result};
use conv::TryInto;
@@ -12,11 +12,13 @@ use tch::Kind;
use tch::Tensor;
use tch::{self, vision::image};


#[derive(Debug)]
pub struct TchModel {
model: tch::CModule,
device: tch::Device,
w: i64,
h: i64,
pub model: tch::CModule,
pub device: tch::Device,
pub w: i64,
pub h: i64,
}

impl TchModel {
@@ -92,9 +94,9 @@ impl TchModel {
.unwrap()
.to_device(self.device);
let _transposed_o = pred.transpose(2, 1);
let t1 = std::time::Instant::now();
// let t1 = std::time::Instant::now();
let results = self.nms_yolov9(&_transposed_o.get(0), conf_thresh, iou_thresh);
println!("inference time: {:?}", t1.elapsed());
// println!("inference time: {:?}", t1.elapsed());
results
}
_ => {
@@ -174,18 +176,17 @@ impl TchModel {
// you can normalize these coordinates [x,y]/640 then multiply
// it by its dimension i.e [x, y]*[imagewidth, imageheight]
if pred[4 + class_index] > 0. {

let xyxy: Xyxy = Xyxy {
coordinate_type: CoordinateType::Screen,
x1: (pred[0] - pred[2] / 2.0),
y1: (pred[1] - pred[3] / 2.0),
x2: (pred[0] + pred[2] / 2.0),
y2: (pred[0] + pred[3] / 2.0)
y2: (pred[0] + pred[3] / 2.0),
};
let bbox: YoloBbox = YoloBbox {
class: class_index as i64,
xyxy: xyxy,
confidence: confidence
confidence: confidence,
};
bboxes[class_index].push(bbox);
}
@@ -249,12 +250,12 @@ impl TchModel {
x1: (pred[0] - pred[2] / 2.0),
y1: (pred[1] - pred[3] / 2.0),
x2: (pred[0] + pred[2] / 2.0),
y2: (pred[0] + pred[3] / 2.0)
y2: (pred[0] + pred[3] / 2.0),
};
let bbox: YoloBbox = YoloBbox {
class: class_index as i64,
xyxy: xyxy,
confidence: confidence
confidence: confidence,
};
bboxes[class_index].push(bbox);
}
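The hunks above decode each prediction row as [cx, cy, w, h, per-class scores], convert it to corner form (x1 = cx - w/2, y1 = cy - h/2, x2 = cx + w/2, y2 = cy + h/2), and then suppress per-class overlaps. Below is a minimal greedy-NMS sketch over plain structs, an illustration of the technique rather than the crate's actual nms_yolov9.

```rust
/// Illustrative detection type; confidences are assumed finite.
#[derive(Debug, Clone, Copy)]
struct Det {
    x1: f32,
    y1: f32,
    x2: f32,
    y2: f32,
    conf: f32,
}

/// Intersection-over-union of two corner-format boxes.
fn iou(a: &Det, b: &Det) -> f32 {
    let ix = (a.x2.min(b.x2) - a.x1.max(b.x1)).max(0.0);
    let iy = (a.y2.min(b.y2) - a.y1.max(b.y1)).max(0.0);
    let inter = ix * iy;
    let area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
    let area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
    inter / (area_a + area_b - inter + 1e-7)
}

/// Greedy NMS: keep the highest-confidence box, drop everything that
/// overlaps it above `iou_thresh`, repeat. Run once per class, as above.
fn nms(mut dets: Vec<Det>, iou_thresh: f32) -> Vec<Det> {
    dets.sort_by(|a, b| b.conf.partial_cmp(&a.conf).unwrap());
    let mut keep: Vec<Det> = Vec::new();
    for d in dets {
        if keep.iter().all(|k| iou(k, &d) < iou_thresh) {
            keep.push(d);
        }
    }
    keep
}
```

Sorting once by confidence keeps the loop simple; at post-threshold detection counts, the quadratic worst case is rarely a concern.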
17 changes: 14 additions & 3 deletions src/kesa_al.rs
@@ -20,6 +20,8 @@ use backends::onnx_backend::{init_onnx_backend, load_onnx_model};

#[cfg(feature = "torch")]
use backends::tch_backend::TchModel;
#[cfg(feature = "torch")]
use tch::Device;

use clap::{ArgAction, Parser};
use fileutils::{open_image, write_yolo_to_txt};
@@ -111,8 +113,9 @@ fn main() -> Result<(), Error> {
.unwrap();
let all_imgs = get_all_images(&args.folder);
let model_type: ComputeBackendType = get_backend(&args.weights)?;
println!("Detected model format : {:#?}", &model_type);
println!("[info]::kesa_al: detected model format : {:#?}", &model_type);
match model_type {

#[cfg(feature = "onnxruntime")]
ComputeBackendType::OnnxModel => {
let init_onnx = init_onnx_backend()?;
@@ -122,7 +125,7 @@
false,
None,
)?;
println!("leme get a uhh : {:?}", &load_model.model);
println!("[info]::kesa_al: onnx_model {:#?}", &load_model.model);
let prog = ProgressBar::new(all_imgs.to_owned().len() as u64);
all_imgs.par_iter().for_each(|image_path| {
let orig_img = open_image(&image_path);
@@ -173,9 +176,17 @@

#[cfg(feature = "torch")]
ComputeBackendType::TchModel => {
let cuda = tch::Device::cuda_if_available();
let torch_model = TchModel::new(
&args.weights,
args.imgsize.to_owned() as i64,
args.imgsize.to_owned() as i64,
cuda,
);
println!("[info]::kesa_al: torch_model {:#?}", &torch_model);
todo!()
}
_ => panic!("cannot infer model type!"),
_ => panic!("[error]::kesa_al: cannot infer model type!"),
};
// draw_dummy_graph();
Ok(())
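The Torch arm picks CUDA when libtorch reports a usable device and hands it to TchModel::new. A minimal sketch of the tch-rs calls a constructor like that sits on top of, assuming `weights` points at a valid TorchScript export:

```rust
use tch::{CModule, Device, TchError};

/// Sketch: load a TorchScript module straight onto the best available device.
fn load_torchscript(weights: &str) -> Result<(CModule, Device), TchError> {
    // Falls back to CPU when no CUDA device is available.
    let device = Device::cuda_if_available();
    let model = CModule::load_on_device(weights, device)?;
    Ok((model, device))
}
```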
8 changes: 4 additions & 4 deletions src/kesa_aug.rs
@@ -65,11 +65,11 @@ fn main() -> Result<(), Error> {
.num_threads(workers.unwrap().try_into().unwrap())
.build_global()
.unwrap();
let mut spinner0 = Spinner::new(spinners::Hearts, "collecting jsons..", Color::White);
let mut spinner0 = Spinner::new(spinners::Hearts, "[info]::kesa_aug: collecting jsons..", Color::White);
let all_json = get_all_jsons(&args.folder)?;
let all_classes = get_all_classes(&all_json)?;
let classes_hash = get_all_classes_hash(&all_classes)?;
spinner0.success(format!("found {:?} json files", &all_json.len()).as_str());
spinner0.success(format!("[info]::kesa_aug: found {:?} json files", &all_json.len()).as_str());
let prog = ProgressBar::new(all_json.len().to_owned() as u64);

all_json.par_iter().for_each(|file| {
@@ -85,7 +85,7 @@
.unwrap();
}
});
prog.finish_with_message("created augmentations!\n");
prog.finish_with_message("[info]::kesa_aug: created augmentations!\n");
Ok(())
}

@@ -171,7 +171,7 @@ fn get_random_aug() -> Result<AugmentationType, Error> {
10 => AugmentationType::HueRotate270,
11 => AugmentationType::Grayscale,
12 => AugmentationType::Rotate90,
_ => panic!("unknown augmentation type!"),
_ => panic!("[error]::kesa_aug: unknown augmentation type!"),
};
Ok(do_aug)
}
20 changes: 10 additions & 10 deletions src/kesa_l2y.rs
@@ -51,19 +51,19 @@ fn main() -> Result<(), Error> {
// TODO: put split portions
let export_options = ExportFolderOptions::new(export.unwrap().as_str(), 0.7)?;
println!("Export Options: {:#?}", &export_options);
let mut spinner0 = Spinner::new(spinners::Hearts, "Creating export paths", Color::White);
let mut spinner0 = Spinner::new(spinners::Hearts, "[info]::kesa_l2y: creating export paths", Color::White);
export_options.create_folders()?;
spinner0.success("Created export paths");
spinner0.success("[info]::kesa_l2y: created export paths");

let mut spinner = Spinner::new(
spinners::Hearts,
format!("Searching for .json files in {:?}", &args.folder),
format!("[info]::kesa_l2y: searching for .json files in {:?}", &args.folder),
Color::White,
);
let all_json = get_all_jsons(&args.folder)?;
let all_classes = get_all_classes(&all_json)?;

spinner.success(format!("found {:?} json files", &all_json.len()).as_str());
spinner.success(format!("[info]::kesa_l2y: found {:?} json files", &all_json.len()).as_str());

let prog = ProgressBar::new(all_json.len().to_owned() as u64);
let class_hash = get_all_classes_hash(&all_classes)?;
@@ -72,7 +72,7 @@
prog.inc(1);
convert_labelme2yolo(file, &class_hash)
});
prog.finish_with_message("Conversion done !\n");
prog.finish_with_message("[info]::kesa_l2y: conversion done !\n");

// split array into 3
let train_split = all_json.len().to_owned() as f32 * export_options.train_ratio;
@@ -96,7 +96,7 @@ fn convert_labelme2yolo(json: &PathBuf, class_hash: &HashMap<String, i64>) -> ()
// convert to yolo txt format
let all_yolo = all_shapes
.to_yolo(&class_hash)
.expect("cannot convert yolo");
.expect("[error]::kesa_l2y: cannot convert yolo");
let _write = write_yolo_to_txt(all_yolo, &json);
}

@@ -106,7 +106,7 @@ fn move_files(
export_options: &ExportFolderOptions,
batch: &str,
) -> Result<(), Error> {
println!("moving files to `{}` batch", &batch);
println!("[info]::kesa_l2y: moving files to `{}` batch", &batch);
let prog = ProgressBar::new(input_array.len().to_owned() as u64);
for orig_json_file in input_array.iter() {
prog.inc(1);
@@ -144,11 +144,11 @@ fn move_files(
fs::rename(orig_txt_file, dest_label)?;
}
_ => {
bail!("unrecognized batch name {:?}", batch)
bail!("[error]::kesa_l2y: unrecognized batch name {:?}", batch)
}
}
}
prog.finish_with_message("files moved !\n");
println!("files moving done!");
prog.finish_with_message("[info]::kesa_l2y: files moved !\n");
println!("[info]::kesa_l2y: files moving done!");
Ok(())
}
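After conversion, the json list is split by `train_ratio` and the pieces are moved into train/val/test batches. A small sketch of one way to do the three-way split; the even val/test division here is an assumption, not necessarily what `ExportFolderOptions` does.

```rust
/// Split `items` into (train, val, test): the first `train_ratio` share
/// goes to train, and the remainder is divided evenly between val and test.
fn split_three<T>(mut items: Vec<T>, train_ratio: f32) -> (Vec<T>, Vec<T>, Vec<T>) {
    let n_train = (items.len() as f32 * train_ratio) as usize;
    let mut rest = items.split_off(n_train);
    let test = rest.split_off(rest.len() / 2);
    (items, rest, test)
}
```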
4 changes: 2 additions & 2 deletions src/kesa_split.rs
@@ -129,7 +129,7 @@ fn move_txt_files(
}
}
}
prog.finish_with_message("files moved !\n");
println!("files moving done!");
prog.finish_with_message("[info]::kesa_split: files moved !\n");
println!("[info]::kesa_split: files moving done!");
Ok(())
}

