-
Notifications
You must be signed in to change notification settings - Fork 0
/
yolov8-classify.py
130 lines (106 loc) · 3.11 KB
/
yolov8-classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st
from ultralytics import YOLO
from PIL import Image
import pandas as pd
import plotly.graph_objects as go
# Local Modules
import content
import utils
# Sidebar: returns the weights identifier that is passed straight to YOLO().
model_path = utils.yolo_classify_sidebar_options()

# Session-state slot holding a (path, model) pair for the current session.
_MODEL_STATE_KEY = "_yolov8_classify_model"


def _get_classify_model(path):
    """Return a YOLO model for *path*, loading it at most once per session.

    Streamlit re-executes this script on every widget interaction, so calling
    ``YOLO(path)`` unconditionally would re-read the weights on each rerun.
    The loaded model is stored in ``st.session_state`` together with the path
    it came from, and is reloaded only when the sidebar selection changes.

    ``st.cache_resource`` is deliberately not used here. It would share a
    single model object across all concurrent sessions (threads), and
    Ultralytics does not guarantee thread-safe inference on a shared instance.
    """
    cached = st.session_state.get(_MODEL_STATE_KEY)
    if cached is None or cached[0] != path:
        cached = (path, YOLO(path))
        st.session_state[_MODEL_STATE_KEY] = cached
    return cached[1]


# Load YOLO model
with st.spinner("Model is downloading..."):
    model = _get_classify_model(model_path)
st.success("Model loaded successfully!", icon="✅")

# Content
content.content_yolov8_classify()
# Image selection (helper lives in the local ``utils`` module).
uploaded_file = utils.image_selector()
# NOTE(review): assumes image_selector() returns None until an image has been
# chosen (as st.file_uploader does) — confirm. Stopping this rerun early avoids
# passing None to st.image() / Image.open() below; if the helper never returns
# None this guard is inert.
if uploaded_file is None:
    st.info("Please select or upload an image to classify.")
    st.stop()

col1, col2 = st.columns(2)

with col1:
    # Display the uploaded image.
    # NOTE(review): recent Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; kept as-is for older versions.
    st.markdown("## 🖼️ Input Image")
    st.image(uploaded_file, caption="Uploaded Image | Input", use_column_width="auto")

# Open the image using PIL and run the classifier on it.
image = Image.open(uploaded_file)
results = model.predict(image)

# A single image was passed in, so only results[0] is used below.
summary = results[0].summary()  # list of dicts; [0] has 'name' and 'confidence'
probs = results[0].probs.data.tolist()  # per-class probabilities
names = results[0].names  # dict of class names
speeds = results[0].speed  # inference timing information
original_shape = results[0].orig_shape  # Original image shape
with col2:
    # Display the annotated classification output next to the input image.
    # NOTE(review): recent Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; kept as-is for older versions.
    st.markdown("## 🔍 Classification Results")
    annotated_img = results[0].plot()  # plot a BGR numpy array of predictions
    st.image(
        annotated_img,
        channels="BGR",
        caption="Classification Results | Classification Model Output",
        use_column_width="auto",
    )

# Pair each probability with its label by class id (probs[i] belongs to class
# id i) rather than relying on the insertion order of the `names` dict.
class_names = [names[class_id] for class_id in range(len(probs))]

# Create a DataFrame for Plotly
df = pd.DataFrame({"Class": class_names, "Probability": probs})

# Bar chart of per-class probabilities.
fig = go.Figure(
    data=[
        go.Bar(
            x=df["Class"],
            y=df["Probability"],
            text=df["Probability"].apply(lambda x: f"{x:.4f}"),
            textposition="outside",
            textfont=dict(size=18),
            # The y-axis is pinned to [0, 1]; without this, "outside" labels on
            # bars close to 1.0 fall above the plot area and get clipped.
            cliponaxis=False,
            hoverinfo="text",
            hovertext=[
                f"Class: {cls}<br>Probability: {prob:.4f}"
                for cls, prob in zip(df["Class"], df["Probability"])
            ],
        )
    ]
)

# Update layout with larger font sizes. Axis titles carry both their text and
# font in one place.
fig.update_layout(
    title={
        "text": "Class Probabilities",
        "font": {"size": 20},
    },
    xaxis={
        "title": {"text": "Class", "font": {"size": 24}},
        "tickfont": {"size": 18},
    },
    yaxis={
        "title": {"text": "Probability", "font": {"size": 24}},
        "tickfont": {"size": 18},
        "range": [0, 1],
    },
    width=800,  # overridden by use_container_width=True below
    height=500,
    font=dict(size=14),
)

# Display the plot in Streamlit
st.markdown("## 📊 Bar chart of classification results")
st.plotly_chart(fig, use_container_width=True)
st.divider()

# Headline with the top-1 prediction reported by results[0].summary().
md_prediction = f"""
## 👁️ Model Prediction
## Model predicts the class **:red[{summary[0]['name']}]** with a probability of **:red[{summary[0]['confidence']:.2f}]**
"""
st.markdown(md_prediction)
st.divider()

utils.show_inference(speeds=speeds)
st.divider()
utils.show_data_objects(
    speeds=speeds,
    original_shape=original_shape,
    probabilities=probs,
    class_names=class_names,
)