From 86dba5ff20ca1bc64a6cc02e8ff4b2d8249f1019 Mon Sep 17 00:00:00 2001 From: Lalit Pant Date: Mon, 14 Oct 2024 09:01:14 +0530 Subject: [PATCH] Add support for yolov8. --- .../scala/net/kogics/kojo/aiapp/ObjectDetector.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/scala/net/kogics/kojo/aiapp/ObjectDetector.scala b/src/main/scala/net/kogics/kojo/aiapp/ObjectDetector.scala index e9fc068..c7b2a88 100644 --- a/src/main/scala/net/kogics/kojo/aiapp/ObjectDetector.scala +++ b/src/main/scala/net/kogics/kojo/aiapp/ObjectDetector.scala @@ -8,12 +8,16 @@ import scala.util.Using import ai.djl.modality.cv.output.DetectedObjects import ai.djl.modality.cv.Image import ai.djl.repository.zoo.Criteria -import org.bytedeco.javacv.Java2DFrameUtils import org.bytedeco.opencv.opencv_core.Mat class ObjectDetector(modelDir: String) { val (model, predictor) = { println("Loading 'object detection' model...") + val translatorFactory = modelDir match { + case s if s.contains("yolov5") => new ai.djl.modality.cv.translator.YoloV5TranslatorFactory() + case s if s.contains("yolov8") => new ai.djl.modality.cv.translator.YoloV8TranslatorFactory() + case _ => throw new RuntimeException("Unknown object detection model") + } import scala.jdk.CollectionConverters._ val args: Map[String, AnyRef] = Map( @@ -22,7 +26,7 @@ class ObjectDetector(modelDir: String) { "resize" -> Boolean.box(true), "rescale" -> Boolean.box(true), "optApplyRatio" -> Boolean.box(true), - "threshold" -> Double.box(0.4), + "threshold" -> Double.box(0.6), ) val criteria = @@ -30,7 +34,7 @@ class ObjectDetector(modelDir: String) { .builder() .setTypes(classOf[Image], classOf[DetectedObjects]) .optModelPath(Paths.get(modelDir)) - .optTranslatorFactory(new ai.djl.modality.cv.translator.YoloV5TranslatorFactory()) + .optTranslatorFactory(translatorFactory) .optArguments(args.asJava) .build()