Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 27, 2024
1 parent 7ab27b7 commit 6e05343
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
1 change: 1 addition & 0 deletions tests/data_connector/test_dataframe_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from sdgx.data_connectors.dataframe_connector import DataFrameConnector


@pytest.fixture
def data_for_test():
    """Provide the small in-memory frame used as connector input.

    Returns:
        A ``pandas.DataFrame`` with two integer columns, ``a`` holding
        1, 2, 3 and ``b`` holding 4, 5, 6.
    """
    columns = {"a": [1, 2, 3], "b": [4, 5, 6]}
    return pd.DataFrame(columns)
Expand Down
30 changes: 18 additions & 12 deletions tests/models/test_nlabelencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,37 @@
import pandas as pd
import pytest

from sdgx.models.components.sdv_rdt.transformers.categorical import NormalizedLabelEncoder
from sdgx.models.components.sdv_rdt.transformers.categorical import (
NormalizedLabelEncoder,
)

@pytest.fixture(scope="module")
def data_test():
    """Build the categorical frame shared by the tests in this module.

    Every column holds 100 string values, but the columns differ in
    cardinality: ``x`` has 100 distinct values, ``y`` has 50 (each
    appearing twice) and ``z`` has 25 (each appearing four times).

    Returns:
        A ``pandas.DataFrame`` with string columns ``x``, ``y`` and ``z``.
    """
    return pd.DataFrame(
        {
            "x": [str(i) for i in range(100)],
            "y": [str(-i) for i in range(50)] * 2,
            "z": [str(i) for i in range(25)] * 4,
        },
        columns=["x", "y", "z"],
    )


def test_encoder(data_test: pd.DataFrame):
    """Round-trip each column through ``NormalizedLabelEncoder``.

    For every column a fresh encoder is fitted, the frame is transformed
    and reverse-transformed, and the results are compared against the
    input for value agreement, value range, shape and cardinality.

    Args:
        data_test: Frame of string columns ``x``, ``y`` and ``z``.
    """
    for col in ["x", "y", "z"]:
        nlabel_encoder = NormalizedLabelEncoder()
        nlabel_encoder.fit(data_test, col)
        # Work on copies so the shared fixture frame is not passed to the
        # encoder directly.
        td = nlabel_encoder.transform(data_test.copy())
        rd = nlabel_encoder.reverse_transform(td.copy())
        # The assertions below read the encoded values under the original
        # column name, so map "<col>.value" back to "<col>".
        td.rename(columns={f"{col}.value": col}, inplace=True)

        # NOTE(review): the three ``.any()`` checks below pass as soon as a
        # single element satisfies the condition; ``.all()`` looks like the
        # intended strength. Confirm against the encoder before tightening.
        assert (rd[col].sort_values().values == data_test[col].sort_values().values).any()
        assert (td[col] >= 0).any()
        assert (td[col] <= 1).any()
        assert td[col].shape == data_test[col].shape
        # The encoding must keep distinct categories distinct.
        assert len(td[col].unique()) == len(data_test[col].unique())


if __name__ == "__main__":
    # Running this module as a script hands the file to pytest with
    # verbose output (-vv) and output capturing disabled (-s).
    pytest_args = ["-vv", "-s", __file__]
    pytest.main(pytest_args)

0 comments on commit 6e05343

Please sign in to comment.