Skip to content

Commit

Permalink
issue #750 regularization strength for logistic classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
grefrathc committed Jun 21, 2024
1 parent b81bcd6 commit 1b37cd6
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/safeds/ml/classical/classification/_logistic_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class LogisticClassifier(Classifier):
# Dunder methods
# ------------------------------------------------------------------------------------------------------------------

def __init__(self) -> None:
def __init__(self, c: float=1.0) -> None:
super().__init__()

self.c = c
def __hash__(self) -> int:
return _structural_hash(
super().__hash__(),
Expand All @@ -30,12 +30,13 @@ def __hash__(self) -> int:
# ------------------------------------------------------------------------------------------------------------------

def _clone(self) -> LogisticClassifier:
return LogisticClassifier()

return LogisticClassifier(c=self.c)
def _get_sklearn_model(self) -> ClassifierMixin:
from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression

return SklearnLogisticRegression(
random_state=_get_random_seed(),
n_jobs=-1,
)
C=self.c,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
from safeds.data.labeled.containers import TabularDataset
from safeds.data.tabular.containers import Table
from safeds.ml.classical.classification import LogisticClassifier


@pytest.fixture()
def training_set() -> TabularDataset:
table = Table({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
return table.to_tabular_dataset(target_name="col1")

class TestC:
def test_should_be_passed_to_fitted_model(self, training_set: TabularDataset) -> None:
fitted_model = LogisticClassifier(c=2).fit(training_set)
assert fitted_model.c == 2

def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None:
fitted_model = LogisticClassifier(c=2).fit(training_set)
assert fitted_model._wrapped_model is not None
assert fitted_model._wrapped_model.C == 2

def test_clone(self, training_set: TabularDataset) -> None:
fitted_model = LogisticClassifier(c=2).fit(training_set)
cloned_classifier = fitted_model._clone()
assert isinstance(cloned_classifier, LogisticClassifier)
assert cloned_classifier.c == fitted_model.c

0 comments on commit 1b37cd6

Please sign in to comment.