Skip to content

Commit

Permalink
Fix graph and fit labels shape mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
KulikDM committed Jan 11, 2025
1 parent 80db464 commit 044f064
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 6 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ repos:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: double-quote-string-fixer
- id: requirements-txt-fixer
- id: name-tests-test
always_run: true
args: [--pytest-test-first]
Expand All @@ -21,7 +21,7 @@ repos:
name: Format docstrings

- repo: https://github.com/asottile/pyupgrade
rev: v3.18.0
rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -42,7 +42,7 @@ repos:
name: Sort imports

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.8.3
rev: v0.8.6
hooks:
- id: ruff
args: [--exit-non-zero-on-fix, --fix, --line-length=180, --extend-ignore=F401]
Expand Down
1 change: 1 addition & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ v<0.0.5>, <12/16/2024> -- chore: format and linted code
v<0.0.6>, <12/20/2024> -- chore: unskip huggingface tests
v<0.0.6>, <12/21/2024> -- chore: switch from pydantic.v1 to pydantic
v<0.0.6>, <12/21/2024> -- chore: added init tests
v<0.0.6>, <01/11/2025> -- fix: mismatch of input graph to output fit labels shape for GraphOutlierDetector
18 changes: 15 additions & 3 deletions muzlin/anomaly/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class GraphOutlierDetector(BaseEstimator, OutlierMixin, BaseModel):
threshold_ (float): The percentile used to threshold inliers from outliers in the model.
labels_ (array-like): The array of the fitted binary labels for the training data.
reg_R2_ (float): The R2 score of the fitted regression model on the training data.
rm_indices_ (list): Removed node indices from the networkx graph prior to fitting. Unconnected nodes cannot be used during fitting and will be assigned outlier labels in the fitted output.
"""

Expand All @@ -66,6 +67,7 @@ class GraphOutlierDetector(BaseEstimator, OutlierMixin, BaseModel):
threshold_: float = Field(default=None, exclude=True)
labels_: Type[np.ndarray] = Field(default=None, exclude=True)
reg_R2_: float = Field(default=None, exclude=True)
rm_indices_: list = Field(default=None, exclude=True)

class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -110,6 +112,8 @@ def fit(self, graph: Gtype, y=None):
y (array-like, or None, optional): Not required.
"""

len_g = len(graph)

# Prepare graph and vector data before fitting
graph_torch, vectors = self._preprocess_graph(graph)

Expand Down Expand Up @@ -147,10 +151,16 @@ def fit(self, graph: Gtype, y=None):
self.threshold_ = (np.percentile(scores, contam*100) if
contam <= 1.0 else contam * np.max(scores))

labels = (scores > self.threshold_).astype('int').ravel()
fitted_labels = (scores > self.threshold_).astype('int').ravel()

# Assure that the lengths of the output labels and input graph nodes match
full_labels = np.ones(len_g)
full_labels[np.setdiff1d(
np.arange(len_g), self.rm_indices_)] = fitted_labels

setattr(self.pipeline, 'threshold_', self.threshold_)
setattr(self.pipeline, 'labels_', labels)
setattr(self.pipeline, 'labels_', full_labels)
setattr(self.pipeline, 'rm_indices_', self.rm_indices_)
setattr(self.regressor, 'reg_R2_', reg_R2_)

# Relog model to save attr
Expand Down Expand Up @@ -209,7 +219,7 @@ def _preprocess_graph(self, graph: Gtype) -> Tuple[Ttype, XType]:
nodes_and_indices = [(index, node) for index, (node, degree) in
enumerate(dict(graph.degree()).items()) if degree == 0]

indices, nodes_to_remove = zip(
self.rm_indices_, nodes_to_remove = zip(
*nodes_and_indices) if nodes_and_indices else ([], [])

graph.remove_nodes_from(nodes_to_remove)
Expand Down Expand Up @@ -255,6 +265,8 @@ def _check_is_initalized(self):

self.threshold_ = self.pipeline.threshold_
self.labels_ = self.pipeline.labels_
self.rm_indices_ = self.pipeline.rm_indices_ if hasattr(
self.pipeline, 'rm_indices_') else []
self.decision_scores_ = self.pipeline.named_steps['detector'].decision_score_.numpy(
)
self.reg_R2_ = self.regressor.reg_R2_
3 changes: 3 additions & 0 deletions tests/anomaly/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def pipeline_checks(detector, X, y):
"""Standard checks for a fitted GraphOutlierDetector."""
assert hasattr(detector, 'threshold_')
assert hasattr(detector, 'labels_')
assert hasattr(detector, 'rm_indices_')
assert hasattr(
detector.pipeline.named_steps['detector'], 'decision_score_')
assert hasattr(detector.regressor, 'reg_R2_')
Expand Down Expand Up @@ -148,6 +149,8 @@ def test_low_degree_extra_attr(self, outlier_detector):
decision_scores = detector.pipeline.named_steps['detector'].decision_score_.numpy(
)
assert decision_scores.shape[0] < X.shape[0]
assert len(detector.rm_indices_) > 0
assert len(detector.labels_) == len(X)

def test_missing_x_attr(self, outlier_detector):
"""Test fit failure due to missing x node attribute."""
Expand Down

0 comments on commit 044f064

Please sign in to comment.