implement the mechanism of avoiding the process of saving 'input_data.pkl' in case the callback function is used for fetching the data (#450)

* implement the mechanism of avoiding the process of saving 'input_data.pkl' in case the callback function is used for fetching the data

* minor changes in the class VaeInferHandler

* minor changes in 'ml/config'

* refactor the code

* refactor the handling of the log message in the method '_preprocess_data' of the class TrainConfig

* update 'VERSION'

* refactor the code

* remove the log message in the method 'fetch_data' of the class DataFrameFetcher, update 'VERSION'

---------

Co-authored-by: Hanna Imshenetska <Hanna_Imshenetska@epam.com@EVZZAMZSA0021.epam.com>
Anna050689 and Hanna Imshenetska authored Sep 3, 2024
1 parent b7c0966 commit f9bd9d5
Showing 8 changed files with 124 additions and 84 deletions.
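The contract at the center of this change: the loader is a plain callable, typed Callable[[str], pd.DataFrame], that receives a table name and returns the source dataframe. When such a callback is supplied, data is fetched through it instead of being read from disk, and 'input_data.pkl' is no longer persisted. A minimal sketch of a conforming callback (the function name and the sample data are illustrative only; the entry point for wiring the callback in is outside this diff):

import pandas as pd

def fetch_table(table_name: str) -> pd.DataFrame:
    # Hypothetical source: any in-memory or remote lookup qualifies,
    # as long as it returns a pandas DataFrame for the requested table.
    tables = {"orders": pd.DataFrame({"id": [1, 2, 3], "amount": [10.5, 7.25, 3.0]})}
    return tables[table_name]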
2 changes: 1 addition & 1 deletion src/syngen/VERSION
@@ -1 +1 @@
-0.9.31
+0.9.32
66 changes: 44 additions & 22 deletions src/syngen/ml/config/configurations.py
@@ -157,27 +157,37 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
"""
Preprocess data and set the parameter "row_subset" for training process
"""
if self.drop_null:
if not data.dropna().empty:
initial_data = data
data = data.dropna()
if count_of_dropped_rows := initial_data.shape[0] - data.shape[0]:
logger.info(
f"As the parameter 'drop_null' set to 'True', "
f"{count_of_dropped_rows} rows of the table - '{self.table_name}' "
f"that have empty values have been dropped. "
f"The count of remained rows is {data.shape[0]}."
if self.loader:
warning_message = (
"parameter will be ignored because the retrieval of the data "
"is handled by a callback function"
)
if self.drop_null:
logger.warning(f"The 'drop_null' {warning_message}")
if self.row_limit is not None:
logger.warning(f"The 'row_limit' {warning_message}")
else:
if self.drop_null:
if not data.dropna().empty:
initial_data = data
data = data.dropna()
if count_of_dropped_rows := initial_data.shape[0] - data.shape[0]:
logger.info(
f"As the parameter 'drop_null' set to 'True', "
f"{count_of_dropped_rows} rows of the table - '{self.table_name}' "
f"that have empty values have been dropped. "
f"The count of remained rows is {data.shape[0]}."
)
else:
logger.warning(
"The specified 'drop_null' argument results in the empty dataframe, "
"so it will be ignored"
)
else:
logger.warning(
"The specified 'drop_null' argument results in the empty dataframe, "
"so it will be ignored"
)

if self.row_limit:
self.row_subset = min(self.row_limit, len(data))
if self.row_limit:
self.row_subset = min(self.row_limit, len(data))

data = data.sample(n=self.row_subset)
data = data.sample(n=self.row_subset)

if len(data) < 100:
logger.warning(
@@ -206,7 +216,8 @@ def _prepare_data(self, data: pd.DataFrame):
         Preprocess and save the data necessary for the training process
         """
         data = self._preprocess_data(data)
-        self._save_input_data(data)
+        if not self.loader:
+            self._save_input_data(data)

     @slugify_attribute(table_name="slugify_table_name")
     def _get_paths(self) -> Dict:
@@ -259,6 +270,7 @@ class InferConfig:
     get_infer_metrics: bool
     both_keys: bool
     log_level: str
+    loader: Optional[Callable[[str], pd.DataFrame]]
     slugify_table_name: str = field(init=False)

     def __post_init__(self):
Expand Down Expand Up @@ -287,7 +299,10 @@ def _set_up_reporting(self):
"""
if (
(self.print_report or self.get_infer_metrics)
and not DataLoader(self.paths["input_data_path"]).has_existed_path
and (
not DataLoader(self.paths["input_data_path"]).has_existed_path
and not self.loader
)
):
message = (
f"It seems that the path to original data "
@@ -317,8 +332,15 @@ def _set_up_size(self):
         """
         Set up "size" of generated data
         """
-        if self.size is None and DataLoader(self.paths["input_data_path"]).has_existed_path:
-            data, schema = DataLoader(self.paths["input_data_path"]).load_data()
+        if self.size is None:
+            data_loader = DataLoader(self.paths["input_data_path"])
+            if data_loader.has_existed_path:
+                data, schema = data_loader.load_data()
+            elif self.loader:
+                data, schema = DataFrameFetcher(
+                    loader=self.loader,
+                    table_name=self.table_name
+                ).fetch_data()
             self.size = len(data)

     def _set_up_batch_size(self):
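To summarize the reworked _preprocess_data above: when a loader callback is present, 'drop_null' and 'row_limit' are ignored with warnings and the data passes through untouched; otherwise the previous dropping and sampling path runs. A condensed, self-contained rendering of that control flow (simplified to a free function; the fallback to the full table when no limit is set is an assumption, since self.row_subset is initialized outside the shown hunk):

from typing import Callable, Optional
import logging

import pandas as pd

logger = logging.getLogger(__name__)

def preprocess(
    data: pd.DataFrame,
    drop_null: bool,
    row_limit: Optional[int],
    loader: Optional[Callable[[str], pd.DataFrame]] = None,
) -> pd.DataFrame:
    if loader:
        # Data arrived through a callback: both knobs are ignored with a warning
        if drop_null:
            logger.warning("'drop_null' is ignored when a callback loader is used")
        if row_limit is not None:
            logger.warning("'row_limit' is ignored when a callback loader is used")
        return data
    if drop_null and not data.dropna().empty:
        data = data.dropna()
    # Assumed fallback: keep the whole table when no limit is configured
    row_subset = min(row_limit, len(data)) if row_limit else len(data)
    return data.sample(n=row_subset)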
5 changes: 0 additions & 5 deletions src/syngen/ml/data_loaders/dataframe_fetcher.py
@@ -14,11 +14,6 @@ class DataFrameFetcher:
     table_name: str

     def fetch_data(self) -> Tuple[pd.DataFrame, Dict]:
-        logger.info(
-            "Attempting to fetch the dataframe due "
-            "to the absence of the information about the path to the source."
-        )
-
         try:
             df = self.loader(self.table_name)
             default_schema = {"fields": {}, "format": "CSV"}
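Only the logger.info call was removed here. For context, the fetcher's remaining behavior is small enough to restate as a sketch; the except branch is outside the visible hunk, so the error handling below is an assumption:

from dataclasses import dataclass
from typing import Callable, Dict, Tuple

import pandas as pd

@dataclass
class DataFrameFetcherSketch:
    loader: Callable[[str], pd.DataFrame]
    table_name: str

    def fetch_data(self) -> Tuple[pd.DataFrame, Dict]:
        try:
            df = self.loader(self.table_name)
            # Callback-sourced frames are reported with a flat default schema
            default_schema = {"fields": {}, "format": "CSV"}
            return df, default_schema
        except Exception as error:
            raise ValueError(
                f"Failed to fetch the dataframe of the table '{self.table_name}'"
            ) from error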
102 changes: 52 additions & 50 deletions src/syngen/ml/handlers/handlers.py
@@ -1,10 +1,9 @@
-from typing import Tuple, Optional, Dict, List
+from typing import Tuple, Optional, Dict, List, Callable
 from abc import ABC, abstractmethod
 import os
 import math
 from ulid import ULID
 from uuid import UUID
-from dataclasses import dataclass, field

 import pandas as pd
 import numpy as np
@@ -16,9 +15,10 @@
 from tensorflow.keras.preprocessing.text import Tokenizer
 from slugify import slugify
 from loguru import logger
+from attrs import define, field

 from syngen.ml.vae import * # noqa: F403
-from syngen.ml.data_loaders import DataLoader
+from syngen.ml.data_loaders import DataLoader, DataFrameFetcher
 from syngen.ml.utils import (
     fetch_config,
     check_if_features_assigned,
@@ -38,15 +38,13 @@ def handle(self, data: pd.DataFrame, **kwargs):
         pass


-@dataclass
+@define
 class BaseHandler(AbstractHandler):
-    metadata: Dict
-    paths: Dict
-    table_name: str
-    _next_handler: Optional[AbstractHandler] = field(init=False)
-
-    def __post_init__(self):
-        self._next_handler = None
+    metadata: Dict = field(kw_only=True)
+    paths: Dict = field(kw_only=True)
+    table_name: str = field(kw_only=True)
+    loader: Optional[Callable[[str], pd.DataFrame]] = None
+    _next_handler: Optional[AbstractHandler] = None

     def set_next(self, handler: AbstractHandler) -> AbstractHandler:
         self._next_handler = handler
@@ -72,17 +70,34 @@ def create_wrapper(cls_name, data: pd.DataFrame, schema: Optional[Dict], **kwarg
             process=kwargs["process"],
         )

+    def fetch_data(self) -> Tuple[pd.DataFrame, Optional[Dict]]:
+        """
+        Fetch the data
+        """
+        data_loader = DataLoader(self.paths["input_data_path"])
+        data = pd.DataFrame()
+        schema = None
+        if data_loader.has_existed_path:
+            data, schema = data_loader.load_data()
+        elif self.loader:
+            data, schema = DataFrameFetcher(
+                loader=self.loader,
+                table_name=self.table_name
+            ).fetch_data()
+        return data, schema
+

-@dataclass
+@define
 class RootHandler(BaseHandler):

     def handle(self, **kwargs):
-        data, schema = DataLoader(self.paths["input_data_path"]).load_data()
+        data, schema = super().fetch_data()
         return super().handle(data, **kwargs)


-@dataclass
+@define
 class LongTextsHandler(BaseHandler):
-    schema: Optional[Dict]
+    schema: Optional[Dict] = field(kw_only=True)

     @staticmethod
     def series_count_words(x):
@@ -141,16 +156,16 @@ def handle(self, data: pd.DataFrame, **kwargs):
         return super().handle(data, **kwargs)


-@dataclass
+@define
 class VaeTrainHandler(BaseHandler):
-    wrapper_name: str
-    schema: Dict
-    epochs: int
-    row_subset: int
-    drop_null: bool
-    batch_size: int
-    type_of_process: str
-    print_report: bool
+    wrapper_name: str = field(kw_only=True)
+    schema: Dict = field(kw_only=True)
+    epochs: int = field(kw_only=True)
+    row_subset: int = field(kw_only=True)
+    drop_null: bool = field(kw_only=True)
+    batch_size: int = field(kw_only=True)
+    type_of_process: str = field(kw_only=True)
+    print_report: bool = field(kw_only=True)

     def __fit_model(self, data: pd.DataFrame):
         logger.info("Start VAE training")
@@ -196,25 +211,25 @@ def handle(self, data: pd.DataFrame, **kwargs):
         return super().handle(data, **kwargs)


-@dataclass
+@define
 class VaeInferHandler(BaseHandler):
-    metadata_path: str
-    random_seed: Optional[int]
-    size: int
-    batch_size: int
-    run_parallel: bool
-    print_report: bool
-    get_infer_metrics: bool
-    wrapper_name: str
-    log_level: str
-    type_of_process: str
+    metadata_path: str = field(kw_only=True)
+    random_seed: Optional[int] = field(kw_only=True)
+    size: int = field(kw_only=True)
+    batch_size: int = field(kw_only=True)
+    run_parallel: bool = field(kw_only=True)
+    print_report: bool = field(kw_only=True)
+    get_infer_metrics: bool = field(kw_only=True)
+    wrapper_name: str = field(kw_only=True)
+    log_level: str = field(kw_only=True)
+    type_of_process: str = field(kw_only=True)
     random_seed_list: List = field(init=False)
     vae: Optional[VAEWrapper] = field(init=False) # noqa: F405
     has_vae: bool = field(init=False)
     has_no_ml: bool = field(init=False)
     batch_num: int = field(init=False)

-    def __post_init__(self):
+    def __attrs_post_init__(self):
         if self.random_seed:
             seed(self.random_seed)
         self.batch_num = math.ceil(self.size / self.batch_size)
@@ -223,7 +238,7 @@ def __post_init__(self):
         self.dataset = fetch_config(self.paths["dataset_pickle_path"])
         self.has_vae = len(self.dataset.features) > 0

-        data, schema = self._get_data()
+        data, schema = self.fetch_data()

         if self.has_vae:
             self.vae = self._get_wrapper(data, schema)
@@ -240,19 +255,6 @@ def synth_word(size, indexes, counts):
             )
         )

-    def _get_data(self) -> Tuple[pd.DataFrame, Dict]:
-        """
-        Load the data from the input data path
-        """
-        input_data_existed = DataLoader(self.paths["input_data_path"]).has_existed_path
-
-        if input_data_existed:
-            data, schema = DataLoader(self.paths["input_data_path"]).load_data()
-        else:
-            data = pd.DataFrame()
-            schema = None
-        return data, schema
-
     def _get_wrapper(self, data: pd.DataFrame, schema: Dict):
         """
         Create and get the wrapper for the VAE model
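A note on the @dataclass to @define migration running through this file: BaseHandler now carries a defaulted attribute (loader=None), and with plain dataclass inheritance every non-defaulted field declared in a subclass would raise "non-default argument follows default argument". attrs' kw_only fields lift that ordering constraint (stdlib dataclasses only gained kw_only in Python 3.10), which would also explain the rename of __post_init__ to attrs' __attrs_post_init__ hook. A minimal sketch of the pattern, with illustrative names:

from typing import Optional
from attrs import define, field

@define
class Base:
    table_name: str = field(kw_only=True)
    loader: Optional[object] = None  # defaulted base-class attribute

@define
class Child(Base):
    schema: dict = field(kw_only=True)  # non-defaulted, yet legal after 'loader'

handler = Child(table_name="orders", schema={})
assert handler.loader is None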
22 changes: 18 additions & 4 deletions src/syngen/ml/reporters/reporters.py
@@ -1,8 +1,9 @@
 from abc import abstractmethod
-from typing import Dict
+from typing import Dict, Optional, Callable
 import itertools
 from collections import defaultdict

+import pandas as pd
 import numpy as np
 from loguru import logger

@@ -12,7 +13,7 @@
     datetime_to_timestamp,
 )
 from syngen.ml.metrics import AccuracyTest, SampleAccuracyTest
-from syngen.ml.data_loaders import DataLoader
+from syngen.ml.data_loaders import DataLoader, DataFrameFetcher
 from syngen.ml.metrics.utils import text_to_continuous
 from syngen.ml.mlflow_tracker import MlflowTracker
 from syngen.ml.utils import ProgressBarHandler
@@ -23,15 +24,28 @@ class Reporter:
     Abstract class for reporters
     """

-    def __init__(self, table_name: str, paths: Dict[str, str], config: Dict[str, str]):
+    def __init__(
+        self,
+        table_name: str,
+        paths: Dict[str, str],
+        config: Dict[str, str],
+        loader: Optional[Callable[[str], pd.DataFrame]] = None
+    ):
         self.table_name = table_name
         self.paths = paths
         self.config = config
+        self.loader = loader
         self.dataset = None
         self.columns_nan_labels = dict()

     def _extract_report_data(self):
-        original, schema = DataLoader(self.paths["original_data_path"]).load_data()
+        if self.loader:
+            original, schema = DataFrameFetcher(
+                loader=self.loader,
+                table_name=self.table_name
+            ).fetch_data()
+        else:
+            original, schema = DataLoader(self.paths["original_data_path"]).load_data()
         synthetic, schema = DataLoader(self.paths["path_to_merged_infer"]).load_data()
         return original, synthetic

4 changes: 4 additions & 0 deletions src/syngen/ml/strategies/strategies.py
@@ -86,6 +86,7 @@ def add_handler(self):
             metadata=self.metadata,
             table_name=self.config.table_name,
             paths=self.config.paths,
+            loader=self.config.loader
         )

         vae_handler = VaeTrainHandler(
@@ -203,6 +204,7 @@ def add_handler(self, type_of_process: str):
             get_infer_metrics=self.config.get_infer_metrics,
             log_level=self.config.log_level,
             type_of_process=type_of_process,
+            loader=self.config.loader
         )
         return self

@@ -216,6 +218,7 @@ def add_reporters(self):
                 table_name=get_initial_table_name(table_name),
                 paths=self.config.paths,
                 config=self.config.to_dict(),
+                loader=self.config.loader
             )
             Report().register_reporter(table=table_name, reporter=accuracy_reporter)

@@ -239,6 +242,7 @@ def run(self, **kwargs):
             get_infer_metrics=kwargs["get_infer_metrics"],
             log_level=kwargs["log_level"],
             both_keys=kwargs["both_keys"],
+            loader=kwargs["loader"]
         )

         MlflowTracker().log_params(self.config.to_dict())
1 change: 1 addition & 0 deletions src/syngen/ml/worker/worker.py
@@ -387,6 +387,7 @@ def _infer_table(self, table, metadata, type_of_process, delta, is_nested=False)
             log_level=self.log_level,
             both_keys=both_keys,
             type_of_process=self.type_of_process,
+            loader=self.loader
         )
         ProgressBarHandler().set_progress(
             delta=delta,
Expand Down