diff --git a/src/sparseml/yolov8/trainers.py b/src/sparseml/yolov8/trainers.py index 76a55fb190a..ba817dab660 100644 --- a/src/sparseml/yolov8/trainers.py +++ b/src/sparseml/yolov8/trainers.py @@ -26,6 +26,7 @@ from typing import List, Optional import torch +import yaml from sparseml.optim.helpers import load_recipe_yaml_str from sparseml.pytorch.optim.manager import ScheduledModifierManager @@ -46,7 +47,7 @@ SparseSegmentationValidator, ) from sparsezoo import Model -from sparsezoo.utils import validate_onnx +from sparsezoo.utils import validate_onnx, load_model, save_onnx from ultralytics import __version__ from ultralytics.nn.modules import Detect, Segment from ultralytics.nn.tasks import SegmentationModel, attempt_load_one_weight @@ -784,6 +785,28 @@ def export(self, **kwargs): except Exception: pass + # Add model metadata + data_yaml = Path(self.ckpt.get("train_args", {}).get("data", "")).read_text("utf-8") + names = yaml.safe_load(data_yaml).get("names", {}) + metadata = { + "description": f"Ultralytics YOLOv8 via {source}", + "author": "Ultralytics", + "date": self.ckpt.get("date"), + "version": self.ckpt.get("version"), + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + "stride": int(max(model.model.stride)), + "task": self.task, + "batch": self.overrides.get("batch"), + "imgsz": self.overrides.get("imgsz"), + "names": names.copy(), + } + _onnx_m = load_model(complete_path) + for k,v in metadata.items(): + meta = _onnx_m.metadata_props.add() + meta.key, meta.value = k, str(v) + _ = save_onnx(model=_onnx_m, model_path=complete_path) + validate_onnx(complete_path) deployment_folder = exporter.create_deployment_folder(onnx_model_name=name) if args["export_samples"]: