Skip to content

Commit

Permalink
feat: add support and tests for nested data
Browse files (browse the repository at this point in the history)
  • Loading branch information
ireneisdoomed committed Apr 12, 2023
1 parent 5414538 commit 51ad0aa
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 31 deletions.
10 changes: 6 additions & 4 deletions src/otg/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from otg.common.schemas import flatten_schema

if TYPE_CHECKING:
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType
Expand Down Expand Up @@ -64,10 +66,10 @@ def validate_schema(self: Dataset) -> None:
Raises:
ValueError: DataFrame schema is not valid
"""
expected_schema = self._schema # type: ignore[attr-defined]
expected_fields = [(field.name, field.dataType) for field in expected_schema]
observed_schema = self._df.schema # type: ignore[attr-defined]
observed_fields = [(field.name, field.dataType) for field in observed_schema]
expected_schema = self._schema
expected_fields = flatten_schema(expected_schema)
observed_schema = self._df.schema
observed_fields = flatten_schema(observed_schema)

# Unexpected fields in dataset
if unexpected_struct_fields := [
Expand Down
102 changes: 75 additions & 27 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from pathlib import Path
from typing import TYPE_CHECKING

import pyspark.sql.functions as f
import pytest
from pyspark.sql.types import StringType, StructField, StructType
from pyspark.sql.types import ArrayType, StringType, StructField, StructType

from otg.dataset.dataset import Dataset

Expand Down Expand Up @@ -44,42 +45,89 @@ def test_schema(schema_json: str) -> None:
class TestValidateSchema:
"""Test validate_schema method."""

@pytest.fixture(scope="class")
def mock_expected_schema(self: TestValidateSchema) -> StructType:
"""Mock expected schema."""
return StructType(
[
StructField("studyLocusId", StringType(), nullable=False),
StructField("geneId", StringType(), nullable=False),
]
)

# Shared inputs for the parametrized validate_schema tests below: each
# observed-data list is paired with the expected schema of the same shape.

# Flat case: two non-nullable string columns.
mock_expected_schema = StructType(
[
StructField("studyLocusId", StringType(), nullable=False),
StructField("geneId", StringType(), nullable=False),
]
)
# Nested case: a non-nullable string column plus "credibleSet", a non-nullable
# array whose elements are structs holding a single string field.
mock_expected_nested_schema = StructType(
[
StructField("studyLocusId", StringType(), nullable=False),
StructField(
"credibleSet",
ArrayType(
# Third positional StructField argument is `nullable`.
StructType([StructField("tagVariantId", StringType(), False)])
),
# Positional `nullable=False` for the credibleSet column itself.
False,
),
]
)
# Rows for the flat schema, in (studyLocusId, geneId) order.
mock_observed_data = [("A", "ENSG0001"), ("B", "ENSG0002")]
# Rows for the nested schema: (studyLocusId, credibleSet), where credibleSet
# is a list of dicts keyed by the nested struct's field name.
mock_observed_nested_data = [
("A", [{"tagVariantId": "varA"}]),
("B", [{"tagVariantId": "varB"}]),
]

@pytest.mark.parametrize(
("observed_data", "expected_schema"),
[
(mock_observed_data, mock_expected_schema),
(mock_observed_nested_data, mock_expected_nested_schema),
],
)
def test_validate_schema_extra_field(
self: TestValidateSchema, spark: SparkSession, mock_expected_schema: StructType
self: TestValidateSchema,
spark: SparkSession,
observed_data: list,
expected_schema: StructType,
) -> None:
"""Test that validate_schema raises an error if the observed schema has an extra field."""
df = spark.createDataFrame(
[("A", "ENSG0001", "extra1"), ("B", "ENSG0002", "extra2")],
schema=["studyLocusId", "geneId", "extraField"],
)
observed_data,
schema=expected_schema,
).withColumn("extraField", f.lit("extra"))
with pytest.raises(ValueError, match="extraField"):
Dataset(df, mock_expected_schema)

Dataset(df, expected_schema)

@pytest.mark.parametrize(
("observed_data", "expected_schema"),
[
(mock_observed_data, mock_expected_schema),
(mock_observed_nested_data, mock_expected_nested_schema),
],
)
def test_validate_schema_missing_field(
self: TestValidateSchema, spark: SparkSession, mock_expected_schema: StructType
self: TestValidateSchema,
spark: SparkSession,
observed_data: list,
expected_schema: StructType,
) -> None:
"""Test that validate_schema raises an error if the observed schema is missing a required field."""
df = spark.createDataFrame([("A",), ("B",)], schema=["geneId"])
with pytest.raises(ValueError, match="geneId"):
Dataset(df, mock_expected_schema)

df = spark.createDataFrame(
observed_data,
schema=expected_schema,
).drop("studyLocusId")
with pytest.raises(ValueError, match="studyLocusId"):
Dataset(df, expected_schema)

@pytest.mark.parametrize(
("observed_data", "expected_schema"),
[
(mock_observed_data, mock_expected_schema),
(mock_observed_nested_data, mock_expected_nested_schema),
],
)
def test_validate_schema_different_datatype(
self: TestValidateSchema, spark: SparkSession, mock_expected_schema: StructType
self: TestValidateSchema,
spark: SparkSession,
observed_data: list,
expected_schema: StructType,
) -> None:
"""Test that validate_schema raises an error if any field in the observed schema has a different type than expected."""
df = spark.createDataFrame(
[(2, "ENSG0001"), (1, "ENSG0002")],
schema=["studyLocusId", "geneId"],
)
observed_data,
schema=expected_schema,
).withColumn("studyLocusId", f.lit(1))
with pytest.raises(ValueError, match="studyLocusId"):
Dataset(df, mock_expected_schema)
Dataset(df, expected_schema)

0 comments on commit 51ad0aa

Please sign in to comment.