264 improve normality test #267

Merged 12 commits on Jan 2, 2024
4 changes: 4 additions & 0 deletions eis_toolkit/exceptions.py
@@ -96,3 +96,7 @@ class NonNumericDataException(Exception):

class InvalidCompositionException(Exception):
"""Exception error class for when the data is not in suitable form for compositional data transforms."""


class SampleSizeExceededException(Exception):
"""Exception error class for when the data exceeds maximum sample size."""
67 changes: 53 additions & 14 deletions eis_toolkit/exploratory_analyses/statistical_tests.py
@@ -1,10 +1,11 @@
import numpy as np
import pandas as pd
from beartype import beartype
from beartype.typing import Literal, Optional, Sequence
from beartype.typing import Dict, Literal, Optional, Sequence, Tuple, Union
from scipy.stats import chi2_contingency, shapiro

from eis_toolkit import exceptions
from eis_toolkit.utilities.checks.dataframe import check_columns_valid, check_empty_dataframe
from eis_toolkit.utilities.checks.dataframe import check_columns_numeric, check_columns_valid, check_empty_dataframe


@beartype
@@ -52,27 +53,65 @@ def chi_square_test(data: pd.DataFrame, target_column: str, columns: Optional[Se


@beartype
def normality_test(data: pd.DataFrame) -> dict:
def normality_test(
data: Union[pd.DataFrame, np.ndarray], columns: Optional[Sequence[str]] = None
) -> Union[Dict[str, Tuple[float, float]], Tuple[float, float]]:
"""Compute Shapiro-Wilk test for normality on the input data.

It is assumed that the input data is normally distributed and numeric, i.e. integers or floats.

Args:
data: Dataframe containing the input data.
data: Dataframe or Numpy array containing the input data.
columns: Optional columns to be used for testing.

Returns:
Test statistics for each variable.
Test statistics for each variable, output differs based on input data type.
Numpy array input returns a Tuple of statistic and p_value.
Dataframe input returns a dictionary where keys are column names
and values are tuples containing the statistic and p-value.

Raises:
EmptyDataFrameException: The input Dataframe is empty.
EmptyDataException: The input data is empty.
InvalidColumnException: All selected columns were not found in the input data.
NonNumericDataException: Selected data or columns contains non-numeric data.
SampleSizeExceededException: Input data exceeds the maximum of 5000 samples.
"""
if check_empty_dataframe(data):
raise exceptions.EmptyDataFrameException("The input Dataframe is empty.")

statistics = {}
for column in data.columns:
statistic, p_value = shapiro(data[column])
statistics[column] = (statistic, p_value)
if isinstance(data, pd.DataFrame):
if check_empty_dataframe(data):
raise exceptions.EmptyDataException("The input Dataframe is empty.")

if columns is not None:
if not check_columns_valid(data, columns):
raise exceptions.InvalidColumnException("All selected columns were not found in the input DataFrame.")
if not check_columns_numeric(data, columns):
raise exceptions.NonNumericDataException("The selected columns contain non-numeric data.")

data = data[columns].dropna()

else:
if not check_columns_numeric(data, data.columns):
raise exceptions.NonNumericDataException("The input data contain non-numeric data.")
columns = data.columns

for column in columns:
if len(data[column]) > 5000:
raise exceptions.SampleSizeExceededException(
f"Sample size for '{column}' exceeds the limit of 5000 samples."
)
statistic, p_value = shapiro(data[column])
statistics[column] = (statistic, p_value)

else:
if data.size == 0:
raise exceptions.EmptyDataException("The input numpy array is empty.")
if len(data) > 5000:
raise exceptions.SampleSizeExceededException("Sample size exceeds the limit of 5000 samples.")

nan_mask = np.isnan(data)
data = data[~nan_mask]

flattened_data = data.flatten()
statistic, p_value = shapiro(flattened_data)
statistics = (statistic, p_value)

return statistics

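A minimal usage sketch based on the new signature and docstring above (the example data is illustrative). The 5000-sample cap matches scipy's `shapiro` documentation, which notes that the p-value may not be accurate for N > 5000:

```python
import numpy as np
import pandas as pd

from eis_toolkit.exploratory_analyses.statistical_tests import normality_test

df = pd.DataFrame({"a": [0.1, 0.5, 0.3, 0.9], "b": [1.2, 0.7, 0.4, 0.8]})

# DataFrame input returns {column_name: (statistic, p_value), ...}
df_result = normality_test(data=df, columns=["a"])
statistic, p_value = df_result["a"]

# ndarray input masks NaNs, flattens, and returns a single (statistic, p_value) tuple.
arr_result = normality_test(data=np.array([0.1, 0.5, np.nan, 0.9, 0.3]))
```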
2 changes: 1 addition & 1 deletion eis_toolkit/utilities/checks/dataframe.py
@@ -31,7 +31,7 @@ def check_columns_numeric(df: pd.DataFrame, columns: Sequence[str]) -> bool:
Returns:
True if all columns are numeric, otherwise False.
"""
columns_numeric = df.columns.select_dtypes(include="number").columns.to_list()
columns_numeric = df[columns].select_dtypes(include="number").columns.to_list()
return all(column in columns_numeric for column in columns)


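A small sketch of the corrected behaviour (example data is illustrative): the dtype check now inspects only the requested columns instead of calling `select_dtypes` on the column Index.

```python
import pandas as pd

from eis_toolkit.utilities.checks.dataframe import check_columns_numeric

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# Only the selected columns are inspected for numeric dtypes.
check_columns_numeric(df, ["a"])       # True  -- "a" is numeric
check_columns_numeric(df, ["a", "b"])  # False -- "b" holds strings
```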
41 changes: 38 additions & 3 deletions tests/exploratory_analyses/statistical_tests_test.py
@@ -12,20 +12,41 @@
)

data = np.array([[0, 1, 2, 1], [2, 0, 1, 2], [2, 1, 0, 2], [0, 1, 2, 1]])
missing_data = np.array([[0, 1, 2, 1, np.nan], [2, 0, 1, 2, np.nan], [2, 1, 0, 2, np.nan], [0, 1, 2, 1, np.nan]])
non_numeric_data = np.array([[0, 1, 2, 1], ["a", "b", "c", "d"], [3, 2, 1, 0], ["c", "d", "b", "a"]])
numeric_data = pd.DataFrame(data, columns=["a", "b", "c", "d"])
non_numeric_df = pd.DataFrame(non_numeric_data, columns=["a", "b", "c", "d"])
missing_values_df = pd.DataFrame(missing_data, columns=["a", "b", "c", "d", "na"])
categorical_data = pd.DataFrame({"e": [0, 0, 1, 1], "f": [True, False, True, True]})
target_column = "e"
np.random.seed(42)
large_data = np.random.normal(size=5001)
large_df = pd.DataFrame(large_data, columns=["a"])


def test_chi_square_test():
"""Test that returned statistics for independence are correct."""
output_statistics = chi_square_test(data=categorical_data, target_column=target_column, columns=("f"))
output_statistics = chi_square_test(data=categorical_data, target_column=target_column, columns=["f"])
np.testing.assert_array_equal((output_statistics["f"]), (0.0, 1.0, 1))
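
The switch from `columns=("f")` to `columns=["f"]` is more than cosmetic: parentheses without a trailing comma do not create a tuple, so the old call passed the bare string `"f"`. A standalone illustration:

```python
print(type(("f")))    # <class 'str'>   -- just a parenthesised string
print(type(("f",)))   # <class 'tuple'> -- the trailing comma makes a tuple
print(type(["f"]))    # <class 'list'>  -- the form used in the updated test
```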


def test_normality_test():
"""Test that returned statistics for normality are correct."""
output_statistics = normality_test(data=numeric_data)
output_statistics = normality_test(data=numeric_data, columns=["a"])
np.testing.assert_array_almost_equal(output_statistics["a"], (0.72863, 0.02386), decimal=5)
output_statistics = normality_test(data=data)
np.testing.assert_array_almost_equal(output_statistics, (0.8077, 0.00345), decimal=5)
output_statistics = normality_test(data=np.array([0, 2, 2, 0]))
np.testing.assert_array_almost_equal(output_statistics, (0.72863, 0.02386), decimal=5)


def test_normality_test_missing_data():
"""Test that input with missing data returns statistics correctly."""
output_statistics = normality_test(data=missing_data)
np.testing.assert_array_almost_equal(output_statistics, (0.8077, 0.00345), decimal=5)
output_statistics = normality_test(data=np.array([0, 2, 2, 0, np.nan]))
np.testing.assert_array_almost_equal(output_statistics, (0.72863, 0.02386), decimal=5)
output_statistics = normality_test(data=missing_values_df, columns=["a", "b"])
np.testing.assert_array_almost_equal(output_statistics["a"], (0.72863, 0.02386), decimal=5)


@@ -60,14 +81,28 @@ def test_covariance_matrix():
def test_empty_df():
"""Test that empty DataFrame raises the correct exception."""
empty_df = pd.DataFrame()
with pytest.raises(exceptions.EmptyDataFrameException):
with pytest.raises(exceptions.EmptyDataException):
normality_test(data=empty_df)


def test_max_samples():
"""Test that sample count > 5000 raises the correct exception."""
with pytest.raises(exceptions.SampleSizeExceededException):
normality_test(data=large_data)
normality_test(data=large_df, columns=["a"])


def test_invalid_columns():
"""Test that invalid column name in raises the correct exception."""
with pytest.raises(exceptions.InvalidParameterValueException):
chi_square_test(data=categorical_data, target_column=target_column, columns=["f", "x"])
normality_test(data=numeric_data, columns=["e", "f"])


def test_non_numeric_data():
"""Test that non-numeric data raises the correct exception."""
with pytest.raises(exceptions.NonNumericDataException):
normality_test(data=non_numeric_df, columns=["a"])


def test_invalid_target_column():