-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Support vector machine for regression and for classification (#236
) Closes #154. ### Summary of Changes Added a new svm class with a wrapped svc of scikit-learn to both the ml classification- and regression packages and tested the new class when added to the test_classifier.py and test_regressor.py. More than half of the time was spent at the level of testing due to error configurations with poetry and pyproject.toml, git rebase and merges. --------- Co-authored-by: Junior Atemebang Ashu <jay@juniors-mbp.localdomain> Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Co-authored-by: Lars Reimann <mail@larsreimann.com>
- Loading branch information
1 parent
0a9ce72
commit 7f6c3bd
Showing
6 changed files
with
196 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
90 changes: 90 additions & 0 deletions
90
src/safeds/ml/classical/classification/_support_vector_machine.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from sklearn.svm import SVC as sk_SVC # noqa: N811 | ||
|
||
from safeds.ml.classical._util_sklearn import fit, predict | ||
|
||
from ._classifier import Classifier | ||
|
||
if TYPE_CHECKING: | ||
from safeds.data.tabular.containers import Table, TaggedTable | ||
|
||
|
||
class SupportVectorMachine(Classifier): | ||
"""Support vector machine.""" | ||
|
||
def __init__(self) -> None: | ||
self._wrapped_classifier: sk_SVC | None = None | ||
self._feature_names: list[str] | None = None | ||
self._target_name: str | None = None | ||
|
||
def fit(self, training_set: TaggedTable) -> SupportVectorMachine: | ||
""" | ||
Create a copy of this classifier and fit it with the given training data. | ||
This classifier is not modified. | ||
Parameters | ||
---------- | ||
training_set : TaggedTable | ||
The training data containing the feature and target vectors. | ||
Returns | ||
------- | ||
fitted_classifier : SupportVectorMachine | ||
The fitted classifier. | ||
Raises | ||
------ | ||
LearningError | ||
If the training data contains invalid values or if the training failed. | ||
""" | ||
wrapped_classifier = sk_SVC() | ||
fit(wrapped_classifier, training_set) | ||
|
||
result = SupportVectorMachine() | ||
result._wrapped_classifier = wrapped_classifier | ||
result._feature_names = training_set.features.column_names | ||
result._target_name = training_set.target.name | ||
|
||
return result | ||
|
||
def predict(self, dataset: Table) -> TaggedTable: | ||
""" | ||
Predict a target vector using a dataset containing feature vectors. The model has to be trained first. | ||
Parameters | ||
---------- | ||
dataset : Table | ||
The dataset containing the feature vectors. | ||
Returns | ||
------- | ||
table : TaggedTable | ||
A dataset containing the given feature vectors and the predicted target vector. | ||
Raises | ||
------ | ||
ModelNotFittedError | ||
If the model has not been fitted yet. | ||
DatasetContainsTargetError | ||
If the dataset contains the target column already. | ||
DatasetMissesFeaturesError | ||
If the dataset misses feature columns. | ||
PredictionError | ||
If predicting with the given dataset failed. | ||
""" | ||
return predict(self._wrapped_classifier, dataset, self._feature_names, self._target_name) | ||
|
||
def is_fitted(self) -> bool: | ||
""" | ||
Check if the classifier is fitted. | ||
Returns | ||
------- | ||
is_fitted : bool | ||
Whether the classifier is fitted. | ||
""" | ||
return self._wrapped_classifier is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
90 changes: 90 additions & 0 deletions
90
src/safeds/ml/classical/regression/_support_vector_machine.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from sklearn.svm import SVR as sk_SVR # noqa: N811 | ||
|
||
from safeds.ml.classical._util_sklearn import fit, predict | ||
|
||
from ._regressor import Regressor | ||
|
||
if TYPE_CHECKING: | ||
from safeds.data.tabular.containers import Table, TaggedTable | ||
|
||
|
||
class SupportVectorMachine(Regressor): | ||
"""Support vector machine.""" | ||
|
||
def __init__(self) -> None: | ||
self._wrapped_regressor: sk_SVR | None = None | ||
self._feature_names: list[str] | None = None | ||
self._target_name: str | None = None | ||
|
||
def fit(self, training_set: TaggedTable) -> SupportVectorMachine: | ||
""" | ||
Create a copy of this regressor and fit it with the given training data. | ||
This regressor is not modified. | ||
Parameters | ||
---------- | ||
training_set : TaggedTable | ||
The training data containing the feature and target vectors. | ||
Returns | ||
------- | ||
fitted_regressor : SupportVectorMachine | ||
The fitted regressor. | ||
Raises | ||
------ | ||
LearningError | ||
If the training data contains invalid values or if the training failed. | ||
""" | ||
wrapped_regressor = sk_SVR() | ||
fit(wrapped_regressor, training_set) | ||
|
||
result = SupportVectorMachine() | ||
result._wrapped_regressor = wrapped_regressor | ||
result._feature_names = training_set.features.column_names | ||
result._target_name = training_set.target.name | ||
|
||
return result | ||
|
||
def predict(self, dataset: Table) -> TaggedTable: | ||
""" | ||
Predict a target vector using a dataset containing feature vectors. The model has to be trained first. | ||
Parameters | ||
---------- | ||
dataset : Table | ||
The dataset containing the feature vectors. | ||
Returns | ||
------- | ||
table : TaggedTable | ||
A dataset containing the given feature vectors and the predicted target vector. | ||
Raises | ||
------ | ||
ModelNotFittedError | ||
If the model has not been fitted yet. | ||
DatasetContainsTargetError | ||
If the dataset contains the target column already. | ||
DatasetMissesFeaturesError | ||
If the dataset misses feature columns. | ||
PredictionError | ||
If predicting with the given dataset failed. | ||
""" | ||
return predict(self._wrapped_regressor, dataset, self._feature_names, self._target_name) | ||
|
||
def is_fitted(self) -> bool: | ||
""" | ||
Check if the regressor is fitted. | ||
Returns | ||
------- | ||
is_fitted : bool | ||
Whether the regressor is fitted. | ||
""" | ||
return self._wrapped_regressor is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters