Skip to content

Commit

Permalink
Merge pull request #749 from theislab/fix/cat_dtype
Browse files Browse the repository at this point in the history
update check for categorical dtype
  • Loading branch information
ArinaDanilina authored Oct 7, 2024
2 parents b516bd9 + d9f65e6 commit 578e3eb
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks
3 changes: 1 addition & 2 deletions src/moscot/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import numpy as np
import pandas as pd
from pandas.api.types import is_categorical_dtype
from sklearn.preprocessing import MinMaxScaler

import matplotlib as mpl
Expand Down Expand Up @@ -523,7 +522,7 @@ def _color_transition(c1: str, c2: str, num: int, alpha: float) -> List[str]:
def _create_col_colors(adata: AnnData, obs_col: str, subset: Union[str, List[str]]) -> Optional[mpl.colors.Colormap]:
if isinstance(subset, list):
subset = subset[0]
if not is_categorical_dtype(adata.obs[obs_col]):
if not isinstance(adata.obs[obs_col].dtype, pd.CategoricalDtype):
raise TypeError(f"`adata.obs[{obs_col!r}] must be of categorical type.")

for i, cat in enumerate(adata.obs[obs_col].cat.categories):
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np
import pandas as pd
from pandas.api.types import infer_dtype, is_categorical_dtype, is_numeric_dtype
from pandas.api.types import infer_dtype, is_numeric_dtype

from anndata import AnnData

Expand Down Expand Up @@ -1034,7 +1034,7 @@ def temporal_key(self: TemporalMixinProtocol[K, B], key: Optional[str]) -> None:
raise KeyError(f"Unable to find temporal key in `adata.obs[{key!r}]`.")
self.adata.obs[key] = self.adata.obs[key].astype("category")
col = self.adata.obs[key]
if not (is_categorical_dtype(col) and is_numeric_dtype(col.cat.categories)):
if not (isinstance(col.dtype, pd.CategoricalDtype) and is_numeric_dtype(col.cat.categories)):
raise TypeError(
f"Expected `adata.obs[{key!r}]` to be categorical with numeric categories, "
f"found `{infer_dtype(col)}`."
Expand Down

0 comments on commit 578e3eb

Please sign in to comment.