Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue with nan value in median of correlations is fixed #439

Merged
merged 11 commits into from
Aug 22, 2024
20 changes: 18 additions & 2 deletions src/syngen/ml/metrics/metrics_classes/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,23 @@ def calculate_all(self, categ_columns: List[str], cont_columns: List[str]):
)
)
self.corr_score = self.original_heatmap - self.synthetic_heatmap
self.corr_score = self.corr_score.dropna(how="all").dropna(how="all", axis=1)
self.corr_score = (
self.corr_score
.dropna(how="all")
.dropna(how="all", axis=1)
)

# check if there are any nans left in corr_score
if self.corr_score.isna().values.any():
# mask for NaNs in both original_heatmap and synthetic_heatmap
nan_mask = (
np.isnan(self.original_heatmap) &
np.isnan(self.synthetic_heatmap)
)

# Set the NaN values in corr_score to 0 where both
# original_heatmap and synthetic_heatmap have NaNs
self.corr_score[nan_mask] = 0

if self.plot:
plt.clf()
Expand All @@ -294,7 +310,7 @@ def calculate_all(self, categ_columns: List[str], cont_columns: List[str]):

@staticmethod
def __calculate_correlations(data):
return abs(data.corr())
return abs(data.corr(method="spearman"))


class BivariateMetric(BaseMetric):
Expand Down