Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolaiPetukhov committed Mar 5, 2024
1 parent 00d5936 commit 233d871
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 7 deletions.
2 changes: 1 addition & 1 deletion supervisely/train/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
supervisely==6.73.4
supervisely==6.73.41
74 changes: 74 additions & 0 deletions supervisely/train/src/sly_project_cached.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os

import supervisely as sly
from supervisely.project.download import (
download_to_cache,
copy_from_cache,
is_cached,
get_cache_size,
)
from sly_utils import get_progress_cb
import sly_train_globals as g


def download_project(
    api: sly.Api,
    project_info: sly.ProjectInfo,
    project_dir: str,
    use_cache: bool,
) -> None:
    """Download a project into ``project_dir``, optionally via the dataset cache.

    Any existing content of ``project_dir`` is removed first. With
    ``use_cache=False`` the whole project is fetched with ``sly.download``.
    Otherwise the datasets that are not cached yet are downloaded into the
    cache, and then every dataset of the project is copied from the cache
    into ``project_dir``.

    Args:
        api: Supervisely API client used for all requests.
        project_info: Info of the project to download; its ``id`` and
            ``items_count`` are read here.
        project_dir: Destination directory for the project data.
        use_cache: Whether to go through the dataset cache.
    """
    if os.path.exists(project_dir):
        sly.fs.clean_dir(project_dir)

    if not use_cache:
        # NOTE(review): the total is doubled, presumably because each item
        # reports progress twice (image + annotation) -- confirm.
        download_progress = get_progress_cb(
            "Downloading input data...", project_info.items_count * 2
        )
        sly.download(
            api=api,
            project_id=project_info.id,
            dest_dir=project_dir,
            dataset_ids=None,
            log_progress=True,
            progress_cb=download_progress,
            cache=g.my_app.cache,
        )
        return

    # Split datasets into cached / missing with a single is_cached() call per
    # dataset, so the two lists are always complementary.
    dataset_infos = api.dataset.get_list(project_info.id)
    cached = []
    to_download = []
    for ds_info in dataset_infos:
        if is_cached(project_info.id, ds_info.name):
            cached.append(ds_info)
        else:
            to_download.append(ds_info)

    if not cached:
        log_msg = "No cached datasets found"
    else:
        log_msg = "Using cached datasets: " + ", ".join(
            f"{ds_info.name} ({ds_info.id})" for ds_info in cached
        )
    sly.logger.info(log_msg)

    if not to_download:
        log_msg = "All datasets are cached. No datasets to download"
    else:
        log_msg = "Downloading datasets: " + ", ".join(
            f"{ds_info.name} ({ds_info.id})" for ds_info in to_download
        )
    sly.logger.info(log_msg)

    # Download the missing datasets into the cache.
    total = sum(ds_info.images_count for ds_info in to_download)
    download_progress = get_progress_cb("Downloading input data...", total * 2)
    download_to_cache(
        api=api,
        project_id=project_info.id,
        dataset_infos=to_download,
        log_progress=True,
        progress_cb=download_progress,
    )

    # Copy every dataset of the project from the cache to the destination;
    # progress for this step is measured by cache size (is_size=True).
    total = sum(get_cache_size(project_info.id, ds_info.name) for ds_info in dataset_infos)
    dataset_names = [ds_info.name for ds_info in dataset_infos]
    download_progress = get_progress_cb("Retrieving data from cache...", total, is_size=True)
    copy_from_cache(
        project_id=project_info.id,
        dest_dir=project_dir,
        dataset_names=dataset_names,
        progress_cb=download_progress,
    )
9 changes: 7 additions & 2 deletions supervisely/train/src/sly_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
root_source_dir, scratch_str, finetune_str

import ui as ui
from sly_project_cached import download_project
from sly_train_utils import init_script_arguments
from sly_utils import get_progress_cb, upload_artifacts
from splits import get_train_val_sets, verify_train_val_sets
Expand Down Expand Up @@ -39,9 +40,13 @@ def train(api: sly.Api, task_id, context, state, app_logger):
sly.fs.mkdir(project_dir, remove_content_if_exists=True) # clean content for debug, has no effect in prod

# download and preprocess Supervisely project (using cache)
download_progress = get_progress_cb("Download data (using cache)", g.project_info.items_count * 2)
try:
sly.download_project(api, project_id, project_dir, cache=my_app.cache, progress_cb=download_progress)
download_project(
api=api,
project_info=g.project_info,
project_dir=project_dir,
use_cache=state.get("useCache", True),
)
except Exception as e:
sly.logger.warn("Can not download project")
raise Exception(
Expand Down
10 changes: 7 additions & 3 deletions supervisely/train/src/ui/input_project.html
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
<sly-card title="1. Input Project" subtitle="This project will be used for training">
<sly-field title="" description="Project">
<a slot="title" target="_blank"
:href="`/projects/${data.projectId}/datasets`">{{data.projectName}} ({{data.projectImagesCount}}
<a slot="title" target="_blank" :href="`/projects/${data.projectId}/datasets`">{{data.projectName}}
({{data.projectImagesCount}}
images)</a>
<sly-icon slot="icon" :options="{ imageUrl: `${data.projectPreviewUrl}` }"/>
<sly-icon slot="icon" :options="{ imageUrl: `${data.projectPreviewUrl}` }" />
</sly-field>
<el-checkbox v-model="state.useCache">
<span v-if="data.isCached">Use cached data stored on the agent to optimize project download</span>
<span v-else>Cache data on the agent to optimize project download for future trainings</span>
</el-checkbox>
</sly-card>
5 changes: 4 additions & 1 deletion supervisely/train/src/ui/input_project.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from supervisely.project.download import is_cached
import sly_train_globals as g


def init(data):
def init(data, state):
    """Fill the "Input Project" card fields and set the default cache option.

    Writes the project id, name, image count, preview URL and cache status
    into ``data``, and enables ``useCache`` in ``state``.
    """
    project = g.project_info

    data["projectId"] = project.id
    data["projectName"] = project.name
    data["projectImagesCount"] = project.items_count

    # 100x100 preview built from the project's reference image.
    preview_url = g.api.image.preview_url(project.reference_image_url, 100, 100)
    data["projectPreviewUrl"] = preview_url

    data["isCached"] = is_cached(project.id)
    state["useCache"] = True

0 comments on commit 233d871

Please sign in to comment.