This repository has been archived by the owner on Feb 9, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathstreamlit_app.py
185 lines (161 loc) · 6.11 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import time
import os
import pathlib, platform
import pprint as pp
import shutil
from io import BytesIO
from os.path import basename, join
import timm
from natsort import natsorted
import skimage
import streamlit as st
from fastai.vision.all import *
from natsort import natsorted
from skimage import io
from skimage.transform import resize
# The exported learners were pickled on Linux (per the original note), and
# pathlib refuses to instantiate PosixPath on Windows. Alias PosixPath to
# WindowsPath there so unpickling can proceed; the real class is kept in `temp`.
if platform.system() == "Windows":
    print("on Windows OS - adjusting PosixPath")
    temp, pathlib.PosixPath = pathlib.PosixPath, pathlib.WindowsPath
def load_best_model():
    """Load the "advanced" ResNetV2-50x1 BiT learner on CPU.

    First tries the archive bundled in the working directory (extracting it
    only if the pickle is not already present). If anything goes wrong
    locally, it falls back to downloading the exported learner from Dropbox.

    Returns:
        The object returned by fastai's ``load_learner(..., cpu=True)``.

    Raises:
        requests.HTTPError: if the fallback download returns an error status.
    """
    path_to_archive = r"model-resnetv2_50x1_bigtransfer_u.zip"
    best_model_name = "model-resnetv2_50x1_bigtransfer.pkl"
    model_path = join(os.getcwd(), best_model_name)
    try:
        # unpack_archive extracts into the cwd; skip it when a previous call
        # already extracted the pickle instead of re-unpacking per prediction.
        if not os.path.isfile(model_path):
            shutil.unpack_archive(path_to_archive)
        return load_learner(model_path, cpu=True)
    except Exception:
        # Deliberate best-effort: any local failure triggers the download.
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
        st.write("unable to load locally. downloading model file")

    model_b_best = "https://www.dropbox.com/s/9c1ovx6dclp8uve/model-resnetv2_50x1_bigtransfer.pkl?dl=1"
    # NOTE(review): `requests` is not imported explicitly in this file; it is
    # presumably provided by `from fastai.vision.all import *` — confirm.
    best_model_response = requests.get(model_b_best, timeout=60)
    # Fail loudly on HTTP errors instead of unpickling an error page.
    best_model_response.raise_for_status()
    return load_learner(BytesIO(best_model_response.content), cpu=True)
def load_mixnet_model():
    """Load the default MixNet-XL learner on CPU.

    Tries the bundled local file first; on any local failure it downloads the
    exported learner from Dropbox instead.

    Returns:
        The object returned by fastai's ``load_learner(..., cpu=True)``.

    Raises:
        requests.HTTPError: if the fallback download returns an error status.
    """
    # NOTE(review): the local file uses a ".pil" extension while the download
    # is ".pkl" — confirm the bundled filename is intentional.
    path_to_model = r"model-mixnetXL-20epoch_u.pil"
    try:
        return load_learner(path_to_model, cpu=True)
    except Exception:
        # Deliberate best-effort fallback; unlike a bare except this does not
        # swallow KeyboardInterrupt/SystemExit.
        st.write("unable to load locally. downloading model file")

    model_backup = (
        "https://www.dropbox.com/s/bwfar78vds9ou1r/model-mixnetXL-20epoch.pkl?dl=1"
    )
    # NOTE(review): `requests` is presumably provided by the fastai star import.
    model_response = requests.get(model_backup, timeout=60)
    # Surface HTTP errors instead of trying to unpickle an error page.
    model_response.raise_for_status()
    return load_learner(BytesIO(model_response.content), cpu=True)
# App title and intro
supplemental_dir = os.path.join(os.getcwd(), "info")
fp_header = os.path.join(supplemental_dir, "climb_area_examples.png")

# Newer Streamlit releases renamed `st.beta_container` to `st.container` and
# later dropped the beta_ alias; prefer the new name, fall back on old versions.
_container = getattr(st, "container", None) or st.beta_container

st.title("NatureGeoDiscoverer MVP: Detect Bouldering Areas")
st.markdown(
    "by [Peter Szemraj](https://peterszemraj.ch/) | [GitHub](https://github.com/pszemraj)"
)

with _container():
    st.header("Basic Instructions")
    st.markdown(
        "*This app assesses a satellite or aerial image of land chosen by the user (scroll down) and "
        "decides whether it is suitable for outdoor bouldering.*"
    )
    st.markdown("---")
    st.markdown("**Examples of Images in the *climb area* class**")
    st.image(skimage.io.imread(fp_header))
    st.markdown("---")

with _container():
    st.subheader("Sample Images")
    st.markdown(
        "If lacking satellite images, the dropbox folder ["
        "here](https://www.dropbox.com/sh/0hz4lh9h8v30a8d/AACFwlIAvdnDdc6RvrcXVpnsa?dl=0) contains "
        "images that were not used for model training."
    )
# Cached helper for user-uploaded images
@st.cache
def load_image(image_file):
    """Open ``image_file`` with PIL; ``st.cache`` memoizes the result across reruns."""
    return Image.open(image_file)
# prediction function
def predict(img, img_flex, use_best_model=False):
    """Show ``img``, classify ``img_flex`` and report the verdict in the app.

    Args:
        img: image displayed to the user via ``st.image``.
        img_flex: the model input. It is either a file path (``str``), which
            fastai opens itself, or an in-memory image object, which is
            wrapped as a ``PILImage`` first.
        use_best_model: if True, use ``load_best_model()`` (the advanced
            learner); otherwise use ``load_mixnet_model()``.
    """
    st.image(img, caption="Chosen Image to Analyze", use_column_width=True)

    # Show the spinner during the genuinely slow work (loading or downloading
    # the learner, then inference) rather than during an artificial sleep.
    with st.spinner("thinking..."):
        model_pred = load_best_model() if use_best_model else load_mixnet_model()
        if isinstance(img_flex, str):
            # a path: fastai's test-time pipeline loads the file itself
            model_input = img_flex
        else:
            model_input = PILImage(img_flex)
        pred_class, pred_idx, pred_prob = model_pred.predict(model_input)

    # Confidence of the *predicted* class, looked up via the index returned by
    # fastai, so it no longer assumes "climb_area" is class 0 (and avoids the
    # float artifacts of `100 - round(...)`).
    confidence = round(100 * float(pred_prob[int(pred_idx)]), 2)

    # Display the prediction
    if str(pred_class) == "climb_area":
        st.balloons()
        st.subheader(
            "Area in test image is good for climbing! {}% confident.".format(confidence)
        )
    else:
        st.subheader(
            "Area in test image not great for climbing :/ - {}% confident.".format(
                confidence
            )
        )
# Model choice: the advanced learner is opt-in since the label says it is slower.
want_adv = st.checkbox("Use Advanced model (slower)?")
if want_adv:
    # confirm the choice back to the user
    st.markdown("*will analyze with advanced model*")
# Image source selection
option1_text = "Use an example image"
option2_text = "Upload a custom image for analysis"
option = st.radio("Choose a method to load an image:", [option1_text, option2_text])

# provide different options based on selection
if option == option1_text:
    # Offer the bundled example images (regular files only, natural sort order).
    working_dir = os.path.join(os.getcwd(), "test_images")
    test_images = natsorted(
        [
            f
            for f in os.listdir(working_dir)
            if os.path.isfile(os.path.join(working_dir, f))
        ]
    )
    test_image = st.selectbox("Please select a test image:", test_images)
    if st.button("Analyze!"):
        if test_image is None:
            # selectbox returns None for an empty option list
            st.warning("No example images found in the test_images folder.")
        else:
            file_path = os.path.join(working_dir, test_image)
            # Downscaled copy is for display only; the model reads file_path.
            img = resize(skimage.io.imread(file_path), (256, 256))
            predict(img, file_path, want_adv)
else:
    image_file = st.file_uploader("Upload Image", type=["png", "jpeg", "jpg"])
    if st.button("Analyze!"):
        if image_file is None:
            st.warning("Please upload an image before clicking Analyze.")
        else:
            # Resize to 256x256 (the dataset's image size) and drop any alpha channel.
            img = load_image(image_file).resize((256, 256)).convert("RGB")
            predict(img, img, want_adv)
# "How it Works" footer section
st.markdown("---")
st.subheader("How it Works:")
st.markdown(
    "**BoulderAreaDetector** uses a Convolutional Neural Network (CNN) trained on a labeled dataset "
    "(approx. 3000 satellite images, each 256x256, in two classes). More "
    "specifically, the primary model is [MixNet-XL](https://paperswithcode.com/method/mixconv)."
)
# The link needs an explicit scheme (otherwise it resolves as a relative URL),
# and the italic span must be closed for the markdown to render correctly.
st.markdown(
    "_Note that the model used on [BoulderSpot](https://boulderspot.io) is different and more "
    "advanced, due to git file limitations and compute resources on the app._"
)