fix handling of categorical columns #457

Open
wants to merge 33 commits into base: main
Changes from 24 commits
Commits (33)
154f6a8
temporary changes
Jul 3, 2024
6edc279
resolve conflicts
Aug 13, 2024
e61159f
refactor the class Dataset
Aug 13, 2024
47c69f0
refactor the class Dataset
Aug 13, 2024
025c1f8
revert changes in the class Dataset
Aug 13, 2024
2b5c6e9
temporary changes in 'tests/unit/dataset'
Aug 14, 2024
37f0fbb
Merge branch 'main' of github.com:tdspora/syngen into EPMCTDM-7032_fi…
Aug 16, 2024
b0761b7
update unit tests, refactor the class Dataset
Aug 16, 2024
0cfce76
provide a fix for handling string columns containing values that migh…
Aug 20, 2024
238b858
resolve conflicts
Aug 20, 2024
ae769a7
minor changes in 'tests/unit/dataset'
Aug 20, 2024
4445f0d
update unit tests
Aug 20, 2024
0f50fa2
update unit tests, fix issues raised by 'flake8'
Aug 20, 2024
6c5ed86
refactor 'ml/data_loaders'
Aug 22, 2024
bfccd87
resolve conflicts
Aug 29, 2024
77358ce
refactor the code
Aug 30, 2024
7f6fb42
refactor the code
Sep 3, 2024
14329ff
resolve conflicts
Sep 3, 2024
57dda79
refactor the code
Sep 3, 2024
8ae24c0
refactor the code, update unit tests
Sep 3, 2024
ca0cd0a
resolve conflicts
Sep 4, 2024
f869cf6
update unit tests
Sep 4, 2024
6c0a78a
update the process of the identification of categorical columns
Sep 6, 2024
98b4caf
update 'VERSION'
Sep 6, 2024
40bb03a
resolve conflicts
Sep 6, 2024
ef027f5
minor changes in 'tests/unit/config'
Sep 6, 2024
57a1fda
update unit tests, fix issues raised by 'flake8'
Sep 6, 2024
43e4f16
resolve conflicts
Sep 27, 2024
c6afb41
fix the issue raised by 'flake8'
Sep 27, 2024
2294647
update 'VERSION'
Sep 27, 2024
ba1143a
resolve conflicts
Oct 2, 2024
8e53b60
refactor the class Dataset, refactor the method get_nan_labels in 'ml…
Oct 2, 2024
7c4a7c4
refactor the code in the class Dataset
Oct 3, 2024
2 changes: 1 addition & 1 deletion src/syngen/VERSION
@@ -1 +1 @@
0.9.33
0.9.36rc0
91 changes: 53 additions & 38 deletions src/syngen/ml/config/configurations.py
@@ -27,9 +27,11 @@ class TrainConfig:
print_report: bool
batch_size: int
loader: Optional[Callable[[str], pd.DataFrame]]
data: pd.DataFrame = field(init=False)
paths: Dict = field(init=False)
row_subset: int = field(init=False)
schema: Dict = field(init=False)
original_schema: Dict = field(init=False)
slugify_table_name: str = field(init=False)
columns: List = field(init=False)
dropped_columns: Set = field(init=False)
@@ -51,11 +53,12 @@ def __getstate__(self) -> Dict:
return instance

def preprocess_data(self):
data, self.schema = self._extract_data()
self.columns = list(data.columns)
data = self._remove_empty_columns(data)
self._mark_removed_columns(data)
self._prepare_data(data)
self._extract_data()
self._save_original_schema()
self.columns = list(self.data.columns)
self._remove_empty_columns()
self._mark_removed_columns()
self._prepare_data()

def to_dict(self) -> Dict:
"""
@@ -110,21 +113,22 @@ def _load_source(self) -> Tuple[pd.DataFrame, Dict]:
table_name=self.table_name
).fetch_data()
else:
return DataLoader(self.source).load_data()
data_loader = DataLoader(self.source)
self.original_schema = data_loader.original_schema
return data_loader.load_data()

def _remove_empty_columns(self, data: pd.DataFrame) -> pd.DataFrame:
def _remove_empty_columns(self):
"""
Remove completely empty columns from dataframe
"""
data_columns = set(data.columns)
data = data.dropna(how="all", axis=1)
data_columns = set(self.data.columns)
self.data = self.data.dropna(how="all", axis=1)

self.dropped_columns = data_columns - set(data.columns)
self.dropped_columns = data_columns - set(self.data.columns)
if len(self.dropped_columns) > 0:
logger.info(f"Empty columns - {', '.join(self.dropped_columns)} were removed")
return data

def _mark_removed_columns(self, data: pd.DataFrame):
def _mark_removed_columns(self):
"""
Mark removed columns in the schema
"""
@@ -133,27 +137,26 @@ def _mark_removed_columns(self, data: pd.DataFrame):
self.schema["fields"] = {column: "removed" for column in self.dropped_columns}
else:
for column, data_type in self.schema.get("fields", {}).items():
if column not in data.columns:
if column not in self.data.columns:
self.schema["fields"][column] = "removed"

def _check_if_data_is_empty(self, data: pd.DataFrame):
def _check_if_data_is_empty(self):
"""
Check if the provided data is empty
"""
if data.shape[0] < 1:
if self.data.shape[0] < 1:
raise ValueError(
f"The empty table was provided. Unable to train the table - '{self.table_name}'"
)

def _extract_data(self) -> Tuple[pd.DataFrame, Dict]:
def _extract_data(self):
"""
Extract data and schema necessary for training process
"""
data, schema = self._load_source()
self._check_if_data_is_empty(data)
return data, schema
self.data, self.schema = self._load_source()
self._check_if_data_is_empty()

def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
def _preprocess_data(self):
"""
Preprocess data and set the parameter "row_subset" for training process
"""
@@ -168,15 +171,15 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
logger.warning(f"The 'row_limit' {warning_message}")
else:
if self.drop_null:
if not data.dropna().empty:
initial_data = data
data = data.dropna()
if count_of_dropped_rows := initial_data.shape[0] - data.shape[0]:
if not self.data.dropna().empty:
initial_data = self.data
self.data = self.data.dropna()
if count_of_dropped_rows := initial_data.shape[0] - self.data.shape[0]:
logger.info(
f"As the parameter 'drop_null' set to 'True', "
f"{count_of_dropped_rows} rows of the table - '{self.table_name}' "
f"that have empty values have been dropped. "
f"The count of remained rows is {data.shape[0]}."
f"The count of remained rows is {self.data.shape[0]}."
)
else:
logger.warning(
@@ -185,39 +188,47 @@
)

if self.row_limit:
self.row_subset = min(self.row_limit, len(data))
self.row_subset = min(self.row_limit, len(self.data))

data = data.sample(n=self.row_subset)
self.data = self.data.sample(n=self.row_subset)

if len(data) < 100:
if len(self.data) < 100:
logger.warning(
"The input table is too small to provide any meaningful results. "
"Please consider 1) disable drop_null argument, 2) provide bigger table"
)
elif len(data) < 500:
elif len(self.data) < 500:
logger.warning(
f"The amount of data is {len(data)} rows. It seems that it isn't enough "
f"The amount of data is {len(self.data)} rows. It seems that it isn't enough "
f"to supply high-quality results. To improve the quality of generated data "
f"please consider any of the steps: 1) provide a bigger table, "
f"2) disable drop_null argument"
)

logger.info(f"The subset of rows was set to {len(data)}")
logger.info(f"The subset of rows was set to {len(self.data)}")

self.row_subset = len(data)
self.row_subset = len(self.data)
self._set_batch_size()
return data

def _save_input_data(self, data: pd.DataFrame):
DataLoader(self.paths["input_data_path"]).save_data(self.paths["input_data_path"], data)
def _save_input_data(self):
"""
Save the subset of the original data
"""
DataLoader(self.paths["input_data_path"]).save_data(self.data)

def _save_original_schema(self):
"""
Save the schema of the original data
"""
DataLoader(self.paths["original_schema_path"]).save_data(self.original_schema)

def _prepare_data(self, data: pd.DataFrame):
def _prepare_data(self):
"""
Preprocess and save the data necessary for the training process
"""
data = self._preprocess_data(data)
data = self._preprocess_data()
if not self.loader:
self._save_input_data(data)
self._save_input_data()

@slugify_attribute(table_name="slugify_table_name")
def _get_paths(self) -> Dict:
@@ -245,6 +256,8 @@ def _get_paths(self) -> Dict:
f"checkpoints/stat_keys/",
"original_data_path": f"model_artifacts/tmp_store/{self.slugify_table_name}/"
f"input_data_{self.slugify_table_name}.pkl",
"original_schema_path": f"model_artifacts/tmp_store/{self.slugify_table_name}/"
f"original_schema_{self.slugify_table_name}.pkl",
"path_to_merged_infer": f"model_artifacts/tmp_store/{self.slugify_table_name}/"
f"merged_infer_{self.slugify_table_name}.csv",
"no_ml_state_path":
@@ -369,6 +382,8 @@ def _get_paths(self) -> Dict:
"state_path": f"model_artifacts/resources/{dynamic_name}/vae/checkpoints",
"train_config_pickle_path":
f"model_artifacts/resources/{dynamic_name}/vae/checkpoints/train_config.pkl",
"original_schema_path": f"model_artifacts/tmp_store/{self.slugify_table_name}/"
f"original_schema_{self.slugify_table_name}.pkl",
"tmp_store_path": f"model_artifacts/tmp_store/{dynamic_name}",
"vae_resources_path": f"model_artifacts/resources/{dynamic_name}/vae/checkpoints/",
"dataset_pickle_path":
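For context, the net effect of this hunk set is that TrainConfig's preprocessing no longer threads a dataframe through return values: the methods read and mutate instance state (self.data, self.schema, self.original_schema), and the original schema is persisted next to the input subset. Below is a minimal sketch of that flow; the class name, the in-memory stand-ins, and the sample dataframe are illustrative assumptions, not the real TrainConfig.

import copy
from dataclasses import dataclass, field
from typing import Dict, List, Set

import pandas as pd


@dataclass
class TrainConfigSketch:
    # Illustrative stand-in for TrainConfig: only the attributes touched by the diff.
    table_name: str
    data: pd.DataFrame = field(init=False)
    schema: Dict = field(init=False)
    original_schema: Dict = field(init=False)
    columns: List = field(init=False)
    dropped_columns: Set = field(init=False)

    def preprocess_data(self):
        # Same ordering as the refactored method (the _prepare_data step is omitted):
        # extract, snapshot the schema, then clean and flag columns via instance state.
        self._extract_data()
        self._save_original_schema()
        self.columns = list(self.data.columns)
        self._remove_empty_columns()
        self._mark_removed_columns()

    def _extract_data(self):
        # Stand-in for DataLoader/MetadataLoader fetching; in the real code the
        # original schema comes from the data loader. Sample data is hypothetical.
        self.data = pd.DataFrame({"a": [1, 2], "b": [None, None]})
        self.schema = {"fields": {}, "format": "CSV"}
        self.original_schema = copy.deepcopy(self.schema)

    def _save_original_schema(self):
        # The real code persists self.original_schema via
        # DataLoader(self.paths["original_schema_path"]).save_data(...); omitted here.
        pass

    def _remove_empty_columns(self):
        before = set(self.data.columns)
        self.data = self.data.dropna(how="all", axis=1)
        self.dropped_columns = before - set(self.data.columns)

    def _mark_removed_columns(self):
        for column in self.dropped_columns:
            self.schema.setdefault("fields", {})[column] = "removed"


config = TrainConfigSketch(table_name="demo")
config.preprocess_data()
print(config.dropped_columns)   # {'b'}
print(config.schema["fields"])  # {'b': 'removed'}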
39 changes: 15 additions & 24 deletions src/syngen/ml/convertor/convertor.py
@@ -1,25 +1,18 @@
from typing import Dict, Tuple
from abc import ABC, abstractmethod
from typing import Dict
from dataclasses import dataclass

import pandas as pd
import numpy as np
from loguru import logger


class Convertor(ABC):
@dataclass
class Convertor:
"""
Abstract class for converting fetched schema in Avro, Parquet or Delta formats
"""

def __init__(self, schema, df):
self.converted_schema, self.preprocessed_df = self._convert_schema_and_df(schema, df)

@abstractmethod
def _convert_schema_and_df(self, schema: Dict, df: pd.DataFrame) -> Tuple[Dict, pd.DataFrame]:
"""
Convert the schema of file to unified format, preprocess dataframe
"""
pass
schema: Dict
df: pd.DataFrame

@staticmethod
def _update_data_types(schema: Dict, df: pd.DataFrame):
@@ -78,7 +71,7 @@ def _preprocess_df(self, schema: Dict, df: pd.DataFrame) -> pd.DataFrame:
"""
Preprocess data frame, update data types of columns
"""
if not df.empty and schema["format"] != "CSV":
if not df.empty:
try:
self._update_data_types(schema, df)
except Exception as e:
@@ -96,16 +89,12 @@ class CSVConvertor(Convertor):
"""
Class for supporting custom schema for csv files
"""

df: pd.DataFrame()
schema = {"fields": {}, "format": "CSV"}

def __init__(self, schema, df):
def __init__(self, df):
schema = {"fields": {}, "format": "CSV"}
super().__init__(schema, df)

def _convert_schema_and_df(self, schema, df) -> Tuple[Dict, pd.DataFrame]:
preprocessed_df = self._preprocess_df(schema, df)
return schema, preprocessed_df
self.preprocessed_df = self._preprocess_df(schema, df)


class AvroConvertor(Convertor):
@@ -115,8 +104,11 @@ class AvroConvertor(Convertor):

def __init__(self, schema, df):
super().__init__(schema, df)
self.converted_schema = self._convert_schema(schema)
self.preprocessed_df = self._preprocess_df(self.converted_schema, df)

def _convert_schema_and_df(self, schema, df) -> Tuple[Dict, pd.DataFrame]:
@staticmethod
def _convert_schema(schema) -> Dict:
"""
Convert the schema of Avro file to unified format, preprocess dataframe
"""
@@ -142,5 +134,4 @@ def _convert_schema_and_df(self, schema, df) -> Tuple[Dict, pd.DataFrame]:
logger.error(message)
raise ValueError(message)
converted_schema["format"] = "Avro"
preprocessed_df = self._preprocess_df(converted_schema, df)
return converted_schema, preprocessed_df
return converted_schema
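For context, a hedged usage sketch of the reworked convertors under the constructor signatures shown above (CSVConvertor(df) and AvroConvertor(schema, df)). The sample dataframe, the import path, and the shape of the raw Avro schema are assumptions; the real Avro schema comes from the Avro reader and its type names may differ.

import pandas as pd

# Import path assumed from the file location in this diff.
from syngen.ml.convertor.convertor import AvroConvertor, CSVConvertor

df = pd.DataFrame({"id": [1, 2, 3], "category": ["a", "b", "a"]})

# CSV: the convertor supplies its own schema ({"fields": {}, "format": "CSV"}) and,
# with the format check removed from _preprocess_df, preprocesses any non-empty frame.
csv_convertor = CSVConvertor(df)
print(csv_convertor.schema)
print(csv_convertor.preprocessed_df.dtypes)

# Avro: schema conversion and dataframe preprocessing are now two explicit steps in
# __init__. The field-to-type mapping below is a guess at the raw Avro schema shape.
avro_schema = {"id": "int", "category": "string"}
avro_convertor = AvroConvertor(avro_schema, df)
print(avro_convertor.converted_schema)
print(avro_convertor.preprocessed_df.dtypes)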