Source code for sdgx.data_processors.base
from __future__ import annotations
from typing import Any, Dict
import pandas as pd
from sdgx.data_models.metadata import Metadata
from sdgx.exceptions import SynthesizerProcessorError
from sdgx.utils import logger
[docs]
class DataProcessor:
    """
    Base class for data processors.

    A processor is fitted once through ``fit`` and can then transform data
    with ``convert`` and restore it with ``reverse_convert``. The base
    implementations of ``_fit``, ``convert`` and ``reverse_convert`` do
    nothing and hand their input back unchanged.
    """

    # Class-level default; ``fit`` sets this to True on the instance.
    fitted = False

    def check_fitted(self):
        """Check if the processor is fitted.

        Raises:
            SynthesizerProcessorError: If the processor is not fitted.
        """
        if not self.fitted:
            raise SynthesizerProcessorError("Processor NOT fitted.")

    def fit(self, metadata: Metadata | None = None, **kwargs: Dict[str, Any]):
        """Fit the processor and mark it as fitted.

        Delegates the actual work to ``_fit`` and sets ``fitted`` to True
        only after ``_fit`` returned without raising.

        Args:
            metadata (Metadata, optional): Metadata. Defaults to None.
            **kwargs: Forwarded unchanged to ``_fit``.
        """
        self._fit(metadata, **kwargs)
        self.fitted = True

    def _fit(self, metadata: Metadata | None = None, **kwargs: Dict[str, Any]):
        """Fit the data processor.

        Called before ``convert`` and ``reverse_convert``. The base
        implementation does nothing.

        Args:
            metadata (Metadata, optional): Metadata. Defaults to None.
            **kwargs: Extra keyword arguments; ignored by the base class.
        """
        return

    def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
        """Convert raw data into processed data.

        The base implementation returns ``raw_data`` itself (no copy).

        Args:
            raw_data (pd.DataFrame): Raw data

        Returns:
            pd.DataFrame: Processed data
        """
        return raw_data

    def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
        """Convert processed data into raw data.

        The base implementation returns ``processed_data`` itself (no copy).

        Args:
            processed_data (pd.DataFrame): Processed data

        Returns:
            pd.DataFrame: Raw data
        """
        return processed_data

    @staticmethod
    def remove_columns(tabular_data: pd.DataFrame, column_name_to_remove: list) -> pd.DataFrame:
        """Remove specified columns from the input tabular data.

        The input frame is never modified. If some of the requested columns
        are not present, a warning is logged and the columns that do exist
        are still removed.

        Args:
            tabular_data (pd.DataFrame): Processed tabular data
            column_name_to_remove (list): List of column names to be removed

        Returns:
            pd.DataFrame: Tabular data with specified columns removed
        """
        # Make a copy of the input data to avoid modifying the original data
        result_data = tabular_data.copy()
        try:
            return result_data.drop(columns=column_name_to_remove)
        except KeyError:
            logger.warning(
                "Duplicate column removal occurred, which might lead to unintended consequences."
            )
            # At least one requested column is absent. Still drop the ones
            # that are present instead of silently keeping all of them.
            return result_data.drop(columns=column_name_to_remove, errors="ignore")

    @staticmethod
    def attach_columns(tabular_data: pd.DataFrame, new_columns: pd.DataFrame) -> pd.DataFrame:
        """Attach additional columns to an existing DataFrame.

        Rows are matched by index label when that keeps the row count
        intact. When the labels of the two frames disagree, the new columns
        are attached by position and the result carries the index of
        ``tabular_data``. Neither input is modified.

        Args:
            tabular_data (pd.DataFrame): The original DataFrame.
            new_columns (pd.DataFrame): The DataFrame containing additional
                columns to be attached.

        Returns:
            pd.DataFrame: The DataFrame with new_columns attached.

        Raises:
            ValueError: If the number of rows in tabular_data and
                new_columns are not the same.
        """
        # Check if the number of rows in tabular_data and new_columns are the same
        expected_rows = tabular_data.shape[0]
        if expected_rows != new_columns.shape[0]:
            raise ValueError("Number of rows in tabular_data and new_columns must be the same.")

        # Concatenate tabular_data and new_columns along axis 1 (columns)
        result_data = pd.concat([tabular_data, new_columns], axis=1)
        if result_data.shape[0] == expected_rows:
            return result_data

        # pd.concat aligns on index labels, so mismatching labels pad the
        # result with extra NaN rows. Fall back to positional attachment on
        # a copy so the caller's frame keeps its own index.
        positional = new_columns.copy()
        positional.index = tabular_data.index
        return pd.concat([tabular_data, positional], axis=1)