diff --git a/OTAnalytics/adapter_ui/view_model.py b/OTAnalytics/adapter_ui/view_model.py
index 4f3d19550..b4fe4b8dc 100644
--- a/OTAnalytics/adapter_ui/view_model.py
+++ b/OTAnalytics/adapter_ui/view_model.py
@@ -417,3 +417,7 @@ def get_weather_types(self) -> ColumnResources:
@abstractmethod
def set_svz_metadata_frame(self, frame: AbstractFrameSvzMetadata) -> None:
raise NotImplementedError
+
+ @abstractmethod
+ def get_save_path_suggestion(self, file_type: str, context_file_type: str) -> Path:
+ raise NotImplementedError
diff --git a/OTAnalytics/application/application.py b/OTAnalytics/application/application.py
index 502627881..22f48adb7 100644
--- a/OTAnalytics/application/application.py
+++ b/OTAnalytics/application/application.py
@@ -54,6 +54,7 @@
GetSectionsById,
)
from OTAnalytics.application.use_cases.start_new_project import StartNewProject
+from OTAnalytics.application.use_cases.suggest_save_path import SavePathSuggester
from OTAnalytics.application.use_cases.track_repository import (
GetAllTrackFiles,
TrackRepositorySize,
@@ -129,6 +130,7 @@ def __init__(
load_otconfig: LoadOtconfig,
config_has_changed: ConfigHasChanged,
export_road_user_assignments: ExportRoadUserAssignments,
+ file_name_suggester: SavePathSuggester,
) -> None:
self._datastore: Datastore = datastore
self.track_state: TrackState = track_state
@@ -168,6 +170,7 @@ def __init__(
self._load_otconfig = load_otconfig
self._config_has_changed = config_has_changed
self._export_road_user_assignments = export_road_user_assignments
+ self._file_name_suggester = file_name_suggester
def connect_observers(self) -> None:
"""
@@ -640,6 +643,30 @@ def get_road_user_export_formats(
) -> Iterable[ExportFormat]:
return self._export_road_user_assignments.get_supported_formats()
+ def suggest_save_path(self, file_type: str, context_file_type: str = "") -> Path:
+ """Suggests a save path based on the given file type and an optional
+ related file type.
+
+ The suggested path is in the following format:
+ /..
+
+ The base folder will be determined in the following precedence:
+ 1. First loaded config file (otconfig or otflow)
+ 2. First loaded track file (ottrk)
+ 3. First loaded video file
+ 4. Default: Current working directory
+
+ The file stem suggestion will be determined in the following precedence:
+ 1. The file stem of the loaded config file (otconfig or otflow)
+ 2. _
+ 3. Default:
+
+ Args:
+ file_type (str): the file type.
+ context_file_type (str): the context file type.
+ """
+ return self._file_name_suggester.suggest(file_type, context_file_type)
+
class MissingTracksError(Exception):
pass
diff --git a/OTAnalytics/application/config.py b/OTAnalytics/application/config.py
index 4de2ee055..62a8833c7 100644
--- a/OTAnalytics/application/config.py
+++ b/OTAnalytics/application/config.py
@@ -13,7 +13,6 @@
CLI_CUTTING_SECTION_MARKER: str = "#clicut"
DEFAULT_EVENTLIST_FILE_STEM: str = "events"
DEFAULT_EVENTLIST_FILE_TYPE: str = "otevents"
-DEFAULT_COUNTS_FILE_STEM: str = "counts"
DEFAULT_COUNTS_FILE_TYPE: str = "csv"
DEFAULT_COUNT_INTERVAL_TIME_UNIT: str = "min"
DEFAULT_TRACK_FILE_TYPE: str = "ottrk"
@@ -23,6 +22,15 @@
DEFAULT_PROGRESSBAR_STEP_PERCENTAGE: int = 5
DEFAULT_NUM_PROCESSES = 4
+
+# File Types
+CONTEXT_FILE_TYPE_ROAD_USER_ASSIGNMENTS = "road_user_assignments"
+CONTEXT_FILE_TYPE_EVENTS = "events"
+CONTEXT_FILE_TYPE_COUNTS = "counts"
+OTCONFIG_FILE_TYPE = "otconfig"
+OTFLOW_FILE_TYPE = "otflow"
+
+
OS: str = platform.system()
"""OS OTAnalytics is currently running on"""
diff --git a/OTAnalytics/application/use_cases/suggest_save_path.py b/OTAnalytics/application/use_cases/suggest_save_path.py
new file mode 100644
index 000000000..20ae98c51
--- /dev/null
+++ b/OTAnalytics/application/use_cases/suggest_save_path.py
@@ -0,0 +1,113 @@
+from datetime import datetime
+from pathlib import Path
+from typing import Callable
+
+from OTAnalytics.application.state import FileState
+from OTAnalytics.application.use_cases.get_current_project import GetCurrentProject
+from OTAnalytics.application.use_cases.track_repository import GetAllTrackFiles
+from OTAnalytics.application.use_cases.video_repository import GetAllVideos
+
+DATETIME_FORMAT = "%Y-%m-%d_%H-%M-%S"
+
+
+class SavePathSuggester:
+ """
+ Class for suggesting save paths based on the config file, otflow file,
+ the first track file, and video file.
+
+ Args:
+ file_state (FileState): Holds information on files loaded in application.
+ get_all_track_files (GetAllTrackFiles): A use case that retrieves
+ all track files.
+ get_all_videos (GetAllVideos): A use case that retrieves all
+ video files.
+ get_project (GetCurrentProject): A use case that retrieves
+ the current project.
+ """
+
+ @property
+ def __config_file(self) -> Path | None:
+ """The path to the last loaded or saved configuration file."""
+ if config_file := self._file_state.last_saved_config.get():
+ return config_file.file
+ return None
+
+ @property
+ def __first_track_file(self) -> Path | None:
+ """The path to the first track file."""
+
+ if track_files := self._get_all_track_files():
+ return next(iter(track_files))
+ return None
+
+ @property
+ def __first_video_file(self) -> Path | None:
+ """The path to the first video file."""
+
+ if video_files := self._get_all_videos.get():
+ return video_files[0].get_path()
+ return None
+
+ def __init__(
+ self,
+ file_state: FileState,
+ get_all_track_files: GetAllTrackFiles,
+ get_all_videos: GetAllVideos,
+ get_project: GetCurrentProject,
+ provide_datetime: Callable[[], datetime] = datetime.now,
+ ) -> None:
+ self._file_state = file_state
+ self._get_all_track_files = get_all_track_files
+ self._get_all_videos = get_all_videos
+ self._get_project = get_project
+ self._provide_datetime = provide_datetime
+
+ def suggest(self, file_type: str, context_file_type: str = "") -> Path:
+ """Suggests a save path based on the given file type and an optional
+ related file type.
+
+ The suggested path is in the following format:
+ /..
+
+ The base folder will be determined in the following precedence:
+ 1. First loaded config file (otconfig or otflow)
+ 2. First loaded track file (ottrk)
+ 3. First loaded video file
+ 4. Default: Current working directory
+
+ The file stem suggestion will be determined in the following precedence:
+ 1. The file stem of the loaded config file (otconfig or otflow)
+ 2. _
+ 3. Default:
+
+ Args:
+ file_type (str): the file type.
+ context_file_type (str): the context file type.
+ """
+
+ base_folder = self._retrieve_base_folder()
+ file_stem = self._suggest_file_stem()
+ if context_file_type:
+ return base_folder / f"{file_stem}.{context_file_type}.{file_type}"
+ return base_folder / f"{file_stem}.{file_type}"
+
+ def _retrieve_base_folder(self) -> Path:
+ """Returns the base folder for suggesting a new file name."""
+ if self.__config_file:
+ return self.__config_file.parent
+ if self.__first_track_file:
+ return self.__first_track_file.parent
+ if self.__first_video_file:
+ return self.__first_video_file.parent
+ return Path.cwd()
+
+ def _suggest_file_stem(self) -> str:
+ """Generates a suggestion for the file stem."""
+
+ if self.__config_file:
+ return f"{self.__config_file.stem}"
+
+ current_time = self._provide_datetime().strftime(DATETIME_FORMAT)
+ if project_name := self._get_project.get().name:
+ return f"{project_name}_{current_time}"
+ return current_time
diff --git a/OTAnalytics/plugin_ui/cli.py b/OTAnalytics/plugin_ui/cli.py
index 091d6f082..f67a0d5a9 100644
--- a/OTAnalytics/plugin_ui/cli.py
+++ b/OTAnalytics/plugin_ui/cli.py
@@ -6,8 +6,9 @@
CountingSpecificationDto,
)
from OTAnalytics.application.config import (
+ CONTEXT_FILE_TYPE_COUNTS,
+ CONTEXT_FILE_TYPE_ROAD_USER_ASSIGNMENTS,
DEFAULT_COUNT_INTERVAL_TIME_UNIT,
- DEFAULT_COUNTS_FILE_STEM,
DEFAULT_COUNTS_FILE_TYPE,
DEFAULT_SECTIONS_FILE_TYPE,
DEFAULT_TRACK_FILE_TYPE,
@@ -244,7 +245,9 @@ def _export_events(self, sections: Iterable[Section], save_path: Path) -> None:
event_list_exporter.export(events, sections, actual_save_path)
logger().info(f"Event list saved at '{actual_save_path}'")
- assignment_path = save_path.with_suffix(".road_user_assignment.csv")
+ assignment_path = save_path.with_suffix(
+ f".{CONTEXT_FILE_TYPE_ROAD_USER_ASSIGNMENTS}.csv"
+ )
specification = ExportSpecification(
save_path=assignment_path, format=CSV_FORMAT.name
)
@@ -267,7 +270,7 @@ def _do_export_counts(self, save_path: Path) -> None:
raise ValueError("modes is None but has to be defined for exporting counts")
for count_interval in self._run_config.count_intervals:
output_file = save_path.with_suffix(
- f".{DEFAULT_COUNTS_FILE_STEM}_{count_interval}"
+ f".{CONTEXT_FILE_TYPE_COUNTS}_{count_interval}"
f"{DEFAULT_COUNT_INTERVAL_TIME_UNIT}."
f"{DEFAULT_COUNTS_FILE_TYPE}"
)
diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py
index b653cc689..44e20f7e6 100644
--- a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py
+++ b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py
@@ -61,6 +61,8 @@
from OTAnalytics.application.config import (
CUTTING_SECTION_MARKER,
DEFAULT_COUNTING_INTERVAL_IN_MINUTES,
+ OTCONFIG_FILE_TYPE,
+ OTFLOW_FILE_TYPE,
)
from OTAnalytics.application.logger import logger
from OTAnalytics.application.parser.flow_parser import FlowParser
@@ -173,14 +175,12 @@
LINE_SECTION: str = "line_section"
TO_SECTION = "to_section"
FROM_SECTION = "from_section"
-OTFLOW = "otflow"
MISSING_TRACK_FRAME_MESSAGE = "tracks frame"
MISSING_VIDEO_FRAME_MESSAGE = "videos frame"
MISSING_VIDEO_CONTROL_FRAME_MESSAGE = "video control frame"
MISSING_SECTION_FRAME_MESSAGE = "sections frame"
MISSING_FLOW_FRAME_MESSAGE = "flows frame"
MISSING_ANALYSIS_FRAME_MESSAGE = "analysis frame"
-OTCONFIG = "otconfig"
class MissingInjectedInstanceError(Exception):
@@ -514,16 +514,17 @@ def _show_current_project(self) -> None:
self._frame_project.update(name=project.name, start_date=project.start_date)
def save_otconfig(self) -> None:
- title = "Save configuration as"
- file_types = [(f"{OTCONFIG} file", f"*.{OTCONFIG}")]
- defaultextension = f".{OTCONFIG}"
- initialfile = f"config.{OTCONFIG}"
- otconfig_file: Path = ask_for_save_file_path(
- title, file_types, defaultextension, initialfile=initialfile
+ suggested_save_path = self._application.suggest_save_path(OTCONFIG_FILE_TYPE)
+ configuration_file = ask_for_save_file_path(
+ title="Save configuration as",
+ filetypes=[(f"{OTCONFIG_FILE_TYPE} file", f"*.{OTCONFIG_FILE_TYPE}")],
+ defaultextension=f".{OTCONFIG_FILE_TYPE}",
+ initialfile=suggested_save_path.name,
+ initialdir=suggested_save_path.parent,
)
- if not otconfig_file:
+ if not configuration_file:
return
- self._save_otconfig(otconfig_file)
+ self._save_otconfig(configuration_file)
def _save_otconfig(self, otconfig_file: Path) -> None:
logger().info(f"Config file to save: {otconfig_file}")
@@ -573,10 +574,10 @@ def load_otconfig(self) -> None:
askopenfilename(
title="Load configuration file",
filetypes=[
- (f"{OTFLOW} file", f"*.{OTFLOW}"),
- (f"{OTCONFIG} file", f"*.{OTCONFIG}"),
+ (f"{OTFLOW_FILE_TYPE} file", f"*.{OTFLOW_FILE_TYPE}"),
+ (f"{OTCONFIG_FILE_TYPE} file", f"*.{OTCONFIG_FILE_TYPE}"),
],
- defaultextension=f".{OTFLOW}",
+ defaultextension=f".{OTFLOW_FILE_TYPE}",
)
)
if not otconfig_file:
@@ -595,7 +596,7 @@ def _load_otconfig(self, otconfig_file: Path) -> None:
)
if proceed.canceled:
return
- logger().info(f"{OTCONFIG} file to load: {otconfig_file}")
+ logger().info(f"{OTCONFIG_FILE_TYPE} file to load: {otconfig_file}")
self._application.load_otconfig(file=Path(otconfig_file))
self._show_current_project()
self._show_current_svz_metadata()
@@ -726,17 +727,17 @@ def load_configuration(self) -> None: # sourcery skip: avoid-builtin-shadow
askopenfilename(
title="Load sections file",
filetypes=[
- (f"{OTFLOW} file", f"*.{OTFLOW}"),
- (f"{OTCONFIG} file", f"*.{OTCONFIG}"),
+ (f"{OTFLOW_FILE_TYPE} file", f"*.{OTFLOW_FILE_TYPE}"),
+ (f"{OTCONFIG_FILE_TYPE} file", f"*.{OTCONFIG_FILE_TYPE}"),
],
- defaultextension=f".{OTFLOW}",
+ defaultextension=f".{OTFLOW_FILE_TYPE}",
)
)
if not configuration_file.stem:
return
- elif configuration_file.suffix == f".{OTFLOW}":
+ elif configuration_file.suffix == f".{OTFLOW_FILE_TYPE}":
self._load_otflow(configuration_file)
- elif configuration_file.suffix == f".{OTCONFIG}":
+ elif configuration_file.suffix == f".{OTCONFIG_FILE_TYPE}":
self._load_otconfig(configuration_file)
else:
raise ValueError("Configuration file to load has unknown file extension")
@@ -763,25 +764,22 @@ def _load_otflow(self, otflow_file: Path) -> None:
self.refresh_items_on_canvas()
def save_configuration(self) -> None:
- initial_dir = Path.cwd()
- if config_file := self._application.file_state.last_saved_config.get():
- initial_dir = config_file.file.parent
-
+ suggested_save_path = self._application.suggest_save_path(OTFLOW_FILE_TYPE)
configuration_file = ask_for_save_file_path(
title="Save configuration as",
filetypes=[
- (f"{OTFLOW} file", f"*.{OTFLOW}"),
- (f"{OTCONFIG} file", f"*.{OTCONFIG}"),
+ (f"{OTFLOW_FILE_TYPE} file", f"*.{OTFLOW_FILE_TYPE}"),
+ (f"{OTCONFIG_FILE_TYPE} file", f"*.{OTCONFIG_FILE_TYPE}"),
],
- defaultextension=f".{OTFLOW}",
- initialfile=f"flows.{OTFLOW}",
- initialdir=initial_dir,
+ defaultextension=f".{OTFLOW_FILE_TYPE}",
+ initialfile=suggested_save_path.name,
+ initialdir=suggested_save_path.parent,
)
if not configuration_file.stem:
return
- elif configuration_file.suffix == f".{OTFLOW}":
+ elif configuration_file.suffix == f".{OTFLOW_FILE_TYPE}":
self._save_otflow(configuration_file)
- elif configuration_file.suffix == f".{OTCONFIG}":
+ elif configuration_file.suffix == f".{OTCONFIG_FILE_TYPE}":
self._save_otconfig(configuration_file)
else:
raise ValueError("Configuration file to save has unknown file extension")
@@ -1397,6 +1395,7 @@ def _configure_event_exporter(
initial_position=(50, 50),
input_values=default_values,
export_format_extensions=export_format_extensions,
+ viewmodel=self,
).get_data()
file = input_values[toplevel_export_events.EXPORT_FILE]
export_format = input_values[toplevel_export_events.EXPORT_FORMAT]
@@ -1759,6 +1758,7 @@ def export_road_user_assignments(self) -> None:
input_values=default_values,
export_format_extensions=export_formats,
initial_file_stem="road_user_assignments",
+ viewmodel=self,
).get_data()
logger().debug(export_values)
save_path = export_values[toplevel_export_events.EXPORT_FILE]
@@ -1832,3 +1832,6 @@ def _show_current_svz_metadata(self) -> None:
self._frame_svz_metadata.update(metadata=metadata.to_dict())
else:
self._frame_svz_metadata.update({})
+
+ def get_save_path_suggestion(self, file_type: str, context_file_type: str) -> Path:
+ return self._application.suggest_save_path(file_type, context_file_type)
diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_counts.py b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_counts.py
index 0858beffa..18d43436f 100644
--- a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_counts.py
+++ b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_counts.py
@@ -5,6 +5,10 @@
from customtkinter import CTkEntry, CTkLabel, CTkOptionMenu
from OTAnalytics.adapter_ui.view_model import ViewModel
+from OTAnalytics.application.config import (
+ CONTEXT_FILE_TYPE_COUNTS,
+ DEFAULT_COUNT_INTERVAL_TIME_UNIT,
+)
from OTAnalytics.plugin_ui.customtkinter_gui.constants import PADX, PADY, STICKY
from OTAnalytics.plugin_ui.customtkinter_gui.frame_filter import DateRow
from OTAnalytics.plugin_ui.customtkinter_gui.helpers import ask_for_save_file_name
@@ -18,7 +22,6 @@
END = "end"
EXPORT_FORMAT = "export_format"
EXPORT_FILE = "export_file"
-INITIAL_FILE_STEM = "counts"
class CancelExportCounts(Exception):
@@ -130,11 +133,17 @@ def _create_frame_content(self, master: Any) -> FrameContent:
def _choose_file(self) -> None:
export_format = self._input_values[EXPORT_FORMAT] #
export_extension = self._export_formats[export_format]
+ suggested_save_path = self._viewmodel.get_save_path_suggestion(
+ export_extension[1:],
+ f"{CONTEXT_FILE_TYPE_COUNTS}"
+ f"_{self._input_values[INTERVAL]}{DEFAULT_COUNT_INTERVAL_TIME_UNIT}",
+ )
export_file = ask_for_save_file_name(
title="Save counts as",
filetypes=[(export_format, export_extension)],
defaultextension=export_extension,
- initialfile=INITIAL_FILE_STEM,
+ initialfile=suggested_save_path.name,
+ initialdir=suggested_save_path.parent,
)
self._input_values[EXPORT_FILE] = export_file
if export_file == "":
diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py
index c6f58295e..50ff7391f 100644
--- a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py
+++ b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py
@@ -3,6 +3,8 @@
from customtkinter import CTkLabel, CTkOptionMenu
+from OTAnalytics.adapter_ui.view_model import ViewModel
+from OTAnalytics.application.config import CONTEXT_FILE_TYPE_EVENTS
from OTAnalytics.plugin_ui.customtkinter_gui.constants import PADX, PADY, STICKY
from OTAnalytics.plugin_ui.customtkinter_gui.helpers import ask_for_save_file_name
from OTAnalytics.plugin_ui.customtkinter_gui.toplevel_template import (
@@ -12,7 +14,6 @@
EXPORT_FORMAT = "export_format"
EXPORT_FILE = "export_file"
-INITIAL_FILE_STEM = "events"
class CancelExportEvents(Exception):
@@ -70,11 +71,13 @@ def _is_int_above_zero(self, entry_value: Any) -> bool:
class ToplevelExportEvents(ToplevelTemplate):
def __init__(
self,
+ viewmodel: ViewModel,
export_format_extensions: dict[str, str],
input_values: dict,
- initial_file_stem: str = INITIAL_FILE_STEM,
+ initial_file_stem: str = CONTEXT_FILE_TYPE_EVENTS,
**kwargs: Any,
) -> None:
+ self._viewmodel = viewmodel
self._input_values = input_values
self._export_format_extensions = export_format_extensions
self._initial_file_stem = initial_file_stem
@@ -89,12 +92,17 @@ def _create_frame_content(self, master: Any) -> FrameContent:
def _choose_file(self) -> None:
export_format = self._input_values[EXPORT_FORMAT] #
- export_extension = f"*{self._export_format_extensions[export_format]}"
+ export_file_type = self._export_format_extensions[export_format][1:]
+ export_extension = f"*.{export_file_type}"
+ suggested_save_path = self._viewmodel.get_save_path_suggestion(
+ export_file_type, context_file_type=self._initial_file_stem
+ )
export_file = ask_for_save_file_name(
title="Save counts as",
filetypes=[(export_format, export_extension)],
defaultextension=export_extension,
- initialfile=self._initial_file_stem,
+ initialfile=suggested_save_path.name,
+ initialdir=suggested_save_path.parent,
)
self._input_values[EXPORT_FILE] = export_file
if export_file == "":
diff --git a/OTAnalytics/plugin_ui/main_application.py b/OTAnalytics/plugin_ui/main_application.py
index d451d3132..0c0bfb0bb 100644
--- a/OTAnalytics/plugin_ui/main_application.py
+++ b/OTAnalytics/plugin_ui/main_application.py
@@ -127,6 +127,7 @@
RemoveSection,
)
from OTAnalytics.application.use_cases.start_new_project import StartNewProject
+from OTAnalytics.application.use_cases.suggest_save_path import SavePathSuggester
from OTAnalytics.application.use_cases.track_repository import (
AddAllTracks,
ClearAllTracks,
@@ -473,13 +474,15 @@ def start_gui(self, run_config: RunConfiguration) -> None:
AddAllFlows(add_flow),
parse_json,
)
+ get_all_videos = GetAllVideos(video_repository)
+ get_current_project = GetCurrentProject(datastore)
config_has_changed = ConfigHasChanged(
OtconfigHasChanged(
config_parser,
get_sections,
get_flows,
- GetCurrentProject(datastore),
- GetAllVideos(video_repository),
+ get_current_project,
+ get_all_videos,
),
OtflowHasChanged(flow_parser, get_sections, get_flows),
file_state,
@@ -491,6 +494,9 @@ def start_gui(self, run_config: RunConfiguration) -> None:
flow_repository,
create_events,
)
+ save_path_suggester = SavePathSuggester(
+ file_state, get_all_track_files, get_all_videos, get_current_project
+ )
application = OTAnalyticsApplication(
datastore,
track_state,
@@ -524,6 +530,7 @@ def start_gui(self, run_config: RunConfiguration) -> None:
load_otconfig,
config_has_changed,
export_road_user_assignments,
+ save_path_suggester,
)
section_repository.register_sections_observer(cut_tracks_intersecting_section)
section_repository.register_section_changed_observer(
diff --git a/tests/OTAnalytics/application/use_cases/test_suggest_save_path.py b/tests/OTAnalytics/application/use_cases/test_suggest_save_path.py
new file mode 100644
index 000000000..6864b822d
--- /dev/null
+++ b/tests/OTAnalytics/application/use_cases/test_suggest_save_path.py
@@ -0,0 +1,166 @@
+from datetime import datetime
+from pathlib import Path
+from unittest.mock import Mock
+
+import pytest
+
+from OTAnalytics.application.state import ConfigurationFile, FileState
+from OTAnalytics.application.use_cases.get_current_project import GetCurrentProject
+from OTAnalytics.application.use_cases.suggest_save_path import (
+ DATETIME_FORMAT,
+ SavePathSuggester,
+)
+from OTAnalytics.application.use_cases.track_repository import GetAllTrackFiles
+from OTAnalytics.application.use_cases.video_repository import GetAllVideos
+
+FIRST_TRACK_FILE = Path("path/to/tracks/first.ottrk")
+SECOND_TRACK_FILE = Path("path/to/tracks/second.ottrk")
+FIRST_VIDEO_FILE = Path("path/to/videos/first.mp4")
+SECOND_VIDEO_FILE = Path("path/to/videos/second.mp4")
+PROJECT_NAME = "My Project Name"
+DATETIME_NOW = datetime(2024, 1, 2, 3, 4, 5)
+DATETIME_NOW_FORMATTED = DATETIME_NOW.strftime(DATETIME_FORMAT)
+LAST_SAVED_OTCONFIG = Path("path/to/config/last.otconfig")
+LAST_SAVED_OTFLOW = Path("path/to/otflow/last.otflow")
+
+
+def create_file_state(last_saved_config_file: Path | None = None) -> FileState:
+ state = FileState()
+ if last_saved_config_file:
+ state.last_saved_config.set(ConfigurationFile(last_saved_config_file, {}))
+ return state
+
+
+def create_track_file_provider(
+ track_files: set[Path] | None = None,
+) -> GetAllTrackFiles:
+ if track_files:
+ return Mock(return_value=track_files)
+ return Mock(return_value=set())
+
+
+def create_video_provider(video_files: list[Path] | None = None) -> GetAllVideos:
+ videos = []
+ if video_files:
+ for video_file in video_files:
+ video = Mock()
+ video.get_path.return_value = video_file
+ videos.append(video)
+ get_videos = Mock()
+ get_videos.get.return_value = videos
+ return get_videos
+
+
+def create_project_provider(project_name: str = "") -> GetCurrentProject:
+ project = Mock()
+ project.name = project_name
+ get_project = Mock()
+ get_project.get.return_value = project
+ return get_project
+
+
+def create_suggestor(
+ project_name: str = "",
+ last_saved_config: Path | None = None,
+ track_files: set[Path] | None = None,
+ video_files: list[Path] | None = None,
+) -> SavePathSuggester:
+ get_project = create_project_provider(project_name)
+ file_state = create_file_state(last_saved_config)
+ get_track_files = create_track_file_provider(track_files)
+ get_videos = create_video_provider(video_files)
+ return SavePathSuggester(
+ file_state,
+ get_track_files,
+ get_videos,
+ get_project,
+ provide_datetime,
+ )
+
+
+def provide_datetime() -> datetime:
+ return DATETIME_NOW
+
+
+class TestSavePathSuggester:
+ @pytest.mark.parametrize(
+ (
+ "project_name,last_saved_config,track_files,video_files,"
+ "context_file_type,file_type,expected"
+ ),
+ [
+ (
+ "",
+ None,
+ None,
+ None,
+ "",
+ "otconfig",
+ Path.cwd() / f"{DATETIME_NOW_FORMATTED}.otconfig",
+ ),
+ (
+ PROJECT_NAME,
+ LAST_SAVED_OTCONFIG,
+ {FIRST_TRACK_FILE, SECOND_TRACK_FILE},
+ [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE],
+ "",
+ "otconfig",
+ LAST_SAVED_OTCONFIG.with_name(f"{LAST_SAVED_OTCONFIG.stem}.otconfig"),
+ ),
+ (
+ PROJECT_NAME,
+ LAST_SAVED_OTCONFIG,
+ {FIRST_TRACK_FILE, SECOND_TRACK_FILE},
+ [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE],
+ "events",
+ "csv",
+ LAST_SAVED_OTCONFIG.with_name(f"{LAST_SAVED_OTCONFIG.stem}.events.csv"),
+ ),
+ (
+ PROJECT_NAME,
+ None,
+ {FIRST_TRACK_FILE, SECOND_TRACK_FILE},
+ [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE],
+ "events",
+ "csv",
+ FIRST_TRACK_FILE.with_name(
+ f"{PROJECT_NAME}_{DATETIME_NOW_FORMATTED}.events.csv"
+ ),
+ ),
+ (
+ PROJECT_NAME,
+ None,
+ None,
+ [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE],
+ "events",
+ "csv",
+ FIRST_VIDEO_FILE.with_name(
+ f"{PROJECT_NAME}_{DATETIME_NOW_FORMATTED}.events.csv"
+ ),
+ ),
+ (
+ PROJECT_NAME,
+ LAST_SAVED_OTCONFIG,
+ None,
+ [FIRST_VIDEO_FILE, SECOND_VIDEO_FILE],
+ "events",
+ "csv",
+ LAST_SAVED_OTCONFIG.with_name(f"{LAST_SAVED_OTCONFIG.stem}.events.csv"),
+ ),
+ ],
+ )
+ def test_suggest_default(
+ self,
+ project_name: str,
+ last_saved_config: Path | None,
+ track_files: set[Path] | None,
+ video_files: list[Path] | None,
+ context_file_type: str,
+ file_type: str,
+ expected: Path,
+ ) -> None:
+ suggestor = create_suggestor(
+ project_name, last_saved_config, track_files, video_files
+ )
+ suggestion = suggestor.suggest(file_type, context_file_type)
+ assert suggestion == expected
diff --git a/tests/OTAnalytics/plugin_ui/test_cli.py b/tests/OTAnalytics/plugin_ui/test_cli.py
index d22e1de0b..d151a1358 100644
--- a/tests/OTAnalytics/plugin_ui/test_cli.py
+++ b/tests/OTAnalytics/plugin_ui/test_cli.py
@@ -17,8 +17,8 @@
CountingSpecificationDto,
)
from OTAnalytics.application.config import (
+ CONTEXT_FILE_TYPE_COUNTS,
DEFAULT_COUNT_INTERVAL_TIME_UNIT,
- DEFAULT_COUNTS_FILE_STEM,
DEFAULT_COUNTS_FILE_TYPE,
DEFAULT_EVENTLIST_FILE_TYPE,
DEFAULT_NUM_PROCESSES,
@@ -564,7 +564,7 @@ def test_use_video_start_and_end_for_counting(
interval = 15
filename = "filename"
expected_output_file = (
- test_data_tmp_dir / f"{filename}.{DEFAULT_COUNTS_FILE_STEM}_{interval}"
+ test_data_tmp_dir / f"{filename}.{CONTEXT_FILE_TYPE_COUNTS}_{interval}"
f"{DEFAULT_COUNT_INTERVAL_TIME_UNIT}."
f"{DEFAULT_COUNTS_FILE_TYPE}"
)