Skip to content

Commit

Permalink
train mdl mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Aug 7, 2024
1 parent f88452a commit 539a00d
Show file tree
Hide file tree
Showing 5 changed files with 469 additions and 318 deletions.
2 changes: 1 addition & 1 deletion simba/SimBA.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ def __init__(self):
video_process_menu.add_command(label="Convert ROI definitions", compound="left", image=self.menu_icons["roi"]["img"], command=lambda: ConvertROIDefinitionsPopUp(), font=Formats.FONT_REGULAR.value)
convert_data_menu = Menu(video_process_menu)
convert_data_menu.add_command(label="Convert CSV to parquet", compound="left", image=self.menu_icons["parquet"]["img"], command=Csv2ParquetPopUp, font=Formats.FONT_REGULAR.value)
convert_data_menu.add_command(label="Convert parquet o CSV", compound="left", image=self.menu_icons["csv"]["img"], command=Parquet2CsvPopUp, font=Formats.FONT_REGULAR.value)
convert_data_menu.add_command(label="Convert parquet o CSV", compound="left", image=self.menu_icons["csv_grey"]["img"], command=Parquet2CsvPopUp, font=Formats.FONT_REGULAR.value)

video_process_menu.add_cascade(label="Convert working file type...", compound="left", image=self.menu_icons["change"]["img"], menu=convert_data_menu, font=Formats.FONT_REGULAR.value)

Expand Down
2 changes: 1 addition & 1 deletion simba/data_processors/timebins_movement_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def save(self):
self.__create_plots()


# test = TimeBinsMovementCalculator(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
# test = TimeBinsMovementCalculator(config_path=r"C:\troubleshooting\two_black_animals_14bp\project_folder\project_config.ini",
# bin_length=0.1,
# plots=True,
# body_parts=['Nose_1'])
Expand Down
94 changes: 53 additions & 41 deletions simba/mixins/plotting_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,58 +1052,70 @@ def joint_plot(
return plot

@staticmethod
def line_plot(
df: pd.DataFrame,
x: str,
y: str,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
title: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None,
):

check_instance(
source=f"{PlottingMixin.line_plot.__name__} df",
instance=df,
accepted_types=(pd.DataFrame),
)
check_str(
name=f"{PlottingMixin.line_plot.__name__} x",
value=x,
options=tuple(df.columns),
)
check_str(
name=f"{PlottingMixin.line_plot.__name__} y",
value=y,
options=tuple(df.columns),
)

check_valid_lst(
data=list(df[y]),
source=f"{PlottingMixin.line_plot.__name__} y",
valid_dtypes=(np.float32, np.float64, np.int32, np.int64, int, float),
)
sns.set_style("whitegrid", {"grid.linestyle": "--"})
plot = sns.lineplot(data=df, x=x, y=y)
def line_plot(df: pd.DataFrame,
x: str,
y: Union[str, List[str]],
error: Optional[Union[str, List[str]]] = None,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
title: Optional[str] = None,
fig_size: Optional[Tuple[int]] = (10, 6),
error_opacity: Optional[float] = 0.2,
palette: Optional[str] = 'Set1',
save_path: Optional[Union[str, os.PathLike]] = None, ):

check_instance(source=f"{PlottingMixin.line_plot.__name__} df", instance=df, accepted_types=(pd.DataFrame))
check_str(name=f"{PlottingMixin.line_plot.__name__} x", value=x, options=tuple(df.columns))
check_instance(source=f"{PlottingMixin.line_plot.__name__} y", instance=y, accepted_types=(str, list))
sns.set_style(style="whitegrid", rc={"grid.linestyle": "--"})

if isinstance(y, str):
check_str(name=f"{PlottingMixin.line_plot.__name__} y", value=y, options=tuple(df.columns))
check_valid_lst(data=list(df[y]), source=f"{PlottingMixin.line_plot.__name__} y",
valid_dtypes=Formats.NUMERIC_DTYPES.value)
y = [y]
if error is not None:
check_instance(source=f"{PlottingMixin.line_plot.__name__} error", instance=error,
accepted_types=(str,))
check_str(name=f"{PlottingMixin.line_plot.__name__} error", value=error, options=tuple(df.columns))
check_valid_lst(data=list(df[error]), source=f"{PlottingMixin.line_plot.__name__} error",
valid_dtypes=Formats.NUMERIC_DTYPES.value)
error = [error]
else:
for i in y:
check_str(name=f"{PlottingMixin.line_plot.__name__} y", value=i, options=tuple(df.columns))
check_valid_lst(data=list(df[i]), source=f"{PlottingMixin.line_plot.__name__} error",
valid_dtypes=Formats.NUMERIC_DTYPES.value)
if error is not None:
check_instance(source=f"{PlottingMixin.line_plot.__name__} error", instance=error,
accepted_types=(list,))
for i in error:
check_str(name=f"{PlottingMixin.line_plot.__name__} error", value=i, options=tuple(df.columns))
check_valid_lst(data=list(df[i]), source=f"{PlottingMixin.line_plot.__name__} error",
valid_dtypes=Formats.NUMERIC_DTYPES.value)

fig, ax = plt.subplots(figsize=fig_size)
for i in range(len(y)):
sns.lineplot(data=df, x=x, y=y[i], label=y[i], palette=palette)
if error is not None:
ax.fill_between(df[x], df[y[i]] - df[error[i]], df[y[i]] + df[error[i]], alpha=error_opacity)

if x_label is not None:
check_str(name=f"{PlottingMixin.line_plot.__name__} x_label", value=x_label)
plt.xlabel(x_label)
ax.set_xlabel(x_label)
if y_label is not None:
check_str(name=f"{PlottingMixin.line_plot.__name__} y_label", value=y_label)
plt.ylabel(y_label)
ax.set_ylabel(y_label)
if title is not None:
check_str(name=f"{PlottingMixin.line_plot.__name__} title", value=title)
plt.title(title, ha="center", fontsize=15)
ax.set_title(title, ha="center", fontsize=15)
if save_path is not None:
check_str(
name=f"{PlottingMixin.line_plot.__name__} save_path", value=save_path
)
check_str(name=f"{PlottingMixin.line_plot.__name__} save_path", value=save_path)
check_if_dir_exists(in_dir=os.path.dirname(save_path))
plot.figure.savefig(save_path)
plt.savefig(save_path)
plt.close("all")
else:
return plot
return fig

@staticmethod
def make_line_plot(
Expand Down
Loading

0 comments on commit 539a00d

Please sign in to comment.