from __future__ import annotations
import time
from pathlib import Path
from typing import Any, Generator
import pandas as pd
from tqdm import autonotebook as tqdm
from sdgx.data_connectors.base import DataConnector
from sdgx.data_connectors.generator_connector import GeneratorConnector
from sdgx.data_connectors.manager import DataConnectorManager
from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.base import DataProcessor
from sdgx.data_processors.manager import DataProcessorManager
from sdgx.exceptions import SynthesizerInitError, SynthesizerSampleError
from sdgx.models.base import SynthesizerModel
from sdgx.models.components.sdv_ctgan.synthesizers.base import BatchedSynthesizer
from sdgx.models.manager import ModelManager
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.models.statistics.single_table.base import StatisticSynthesizerModel
from sdgx.utils import logger
# NOTE: a Sphinx "[docs]" source-link marker stood here (HTML extraction artifact, not code).
class Synthesizer:
    """
    Synthesizer is the high level interface for synthesizing data.

    We provide several usage examples in our `Github repository <https://github.com/hitsz-ids/synthetic-data-generator/tree/main/example>`_.

    Args:
        model (str | SynthesizerModel | type[SynthesizerModel]): The name of the model or the model itself. Type of model must be :class:`~sdgx.models.base.SynthesizerModel`.
            When model is a string, it must be registered in :class:`~sdgx.models.manager.ModelManager`.
        model_path (str | Path, optional): The path to the model file. Defaults to None. Used to load the model if ``model`` is a string or type of :class:`~sdgx.models.base.SynthesizerModel`.
        model_kwargs (dict[str, Any], optional): The keyword arguments for model. Defaults to None.
        metadata (Metadata, optional): The metadata to use. Defaults to None.
        metadata_path (str | Path, optional): The path to the metadata file. Defaults to None. Used to load the metadata if ``metadata`` is None.
        data_connector (DataConnector | type[DataConnector] | str, optional): The data connector to use. Defaults to None.
            When data_connector is a string, it must be registered in :class:`~sdgx.data_connectors.manager.DataConnectorManager`.
        data_connector_kwargs (dict[str, Any], optional): The keyword arguments for data connectors. Defaults to None.
        raw_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for raw data loaders. Defaults to None.
        processed_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for processed data loaders. Defaults to None.
        data_processors (list[str | DataProcessor | type[DataProcessor]], optional): The data processors to use. Defaults to None.
            When data_processor is a string, it must be registered in :class:`~sdgx.data_processors.manager.DataProcessorManager`.
        data_processors_kwargs (dict[str, dict[str, Any]], optional): The keyword arguments for data processors. Defaults to None.
            The same keyword arguments are forwarded to every processor that is not already an instance.

    Raises:
        SynthesizerInitError: If both ``metadata`` and ``metadata_path`` are given, if an
            initialized model is combined with ``model_path``, or if ``model`` has an unsupported type.

    Example:

        .. code-block:: python

            from sdgx.data_connectors.csv_connector import CsvConnector
            from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
            from sdgx.synthesizer import Synthesizer
            from sdgx.utils import download_demo_data

            dataset_csv = download_demo_data()
            data_connector = CsvConnector(path=dataset_csv)
            synthesizer = Synthesizer(
                model=CTGANSynthesizerModel(epochs=1),  # For quick demo
                data_connector=data_connector,
            )
            synthesizer.fit()
            sampled_data = synthesizer.sample(1000)
    """

    METADATA_SAVE_NAME = "metadata.json"
    """
    Default name for metadata file
    """

    MODEL_SAVE_DIR = "model"
    """
    Default name for model directory
    """

    def __init__(
        self,
        model: str | SynthesizerModel | type[SynthesizerModel],
        model_path: None | str | Path = None,
        model_kwargs: None | dict[str, Any] = None,
        metadata: None | Metadata = None,
        metadata_path: None | str | Path = None,
        data_connector: None | str | DataConnector | type[DataConnector] = None,
        data_connector_kwargs: None | dict[str, Any] = None,
        raw_data_loaders_kwargs: None | dict[str, Any] = None,
        processed_data_loaders_kwargs: None | dict[str, Any] = None,
        data_processors: None | list[str | DataProcessor | type[DataProcessor]] = None,
        data_processors_kwargs: None | dict[str, Any] = None,
    ):
        # Assign first so that ``cleanup`` / ``__del__`` always find the attribute,
        # even when one of the initialization steps below raises.
        self.dataloader = None

        # Init data connectors
        if isinstance(data_connector, (str, type)):
            data_connector = DataConnectorManager().init_data_connector(
                data_connector, **(data_connector_kwargs or {})
            )
        if data_connector:
            self.dataloader = DataLoader(
                data_connector,
                **(raw_data_loaders_kwargs or {}),
            )
        else:
            logger.warning("No data_connector provided, will not support `fit`")

        # Init data processors
        self.data_processors_manager = DataProcessorManager()
        if not data_processors:
            data_processors = self.data_processors_manager.registed_default_processor_list
        logger.info(f"Using data processors: {data_processors}")
        # Already-built processors are used as is; names and classes are
        # instantiated through the manager with the shared keyword arguments.
        self.data_processors = [
            (
                d
                if isinstance(d, DataProcessor)
                else self.data_processors_manager.init_data_processor(
                    d, **(data_processors_kwargs or {})
                )
            )
            for d in data_processors
        ]

        if metadata and metadata_path:
            raise SynthesizerInitError(
                "metadata and metadata_path cannot be specified at the same time"
            )

        # Load metadata
        # metadata also can be changed in ``fit`` or ``sample``
        # Always use the latest metadata configured.
        if metadata:
            self.metadata = metadata
        elif metadata_path:
            self.metadata = Metadata.load(metadata_path)
        else:
            self.metadata = None

        # Init model
        self.model_manager = ModelManager()
        if isinstance(model, SynthesizerModel) and model_path:
            # Initialized model cannot load from model_path
            raise SynthesizerInitError(
                "model as instance and model_path cannot be specified at the same time"
            )
        if isinstance(model, (str, type)) and model_path:
            # Load model by cls or str
            self.model = self.model_manager.load(model, model_path, **(model_kwargs or {}))
            if model_kwargs:
                # NOTE(review): model_kwargs are forwarded to ``load`` above, yet this
                # message says they are ignored - confirm which one is intended.
                logger.warning("model_kwargs will be ignored when loading model from model_path")
        elif isinstance(model, (str, type)):
            # Init model by cls or str
            self.model = self.model_manager.init_model(model, **(model_kwargs or {}))
        elif isinstance(model, (SynthesizerModel, StatisticSynthesizerModel)):
            # Already initialized model
            self.model = model
            if model_kwargs:
                logger.warning("model_kwargs will be ignored when using already initialized model")
        else:
            raise SynthesizerInitError("model or model_path must be specified")

        # Other arguments
        self.processed_data_loaders_kwargs = processed_data_loaders_kwargs or {}

    def save(self, save_dir: str | Path) -> Path:
        """
        Dump metadata and model to file

        The metadata (when present) is written to ``METADATA_SAVE_NAME`` and the model
        is saved into the ``MODEL_SAVE_DIR`` sub-directory of ``save_dir``.

        Args:
            save_dir (str | Path): The directory to save the model.

        Returns:
            Path: The directory to save the synthesizer.
        """
        save_dir = Path(save_dir).expanduser().resolve()
        save_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving synthesizer to {save_dir}")

        if self.metadata:
            self.metadata.save(save_dir / self.METADATA_SAVE_NAME)

        model_save_dir = save_dir / self.MODEL_SAVE_DIR
        model_save_dir.mkdir(parents=True, exist_ok=True)
        self.model.save(model_save_dir)
        return save_dir

    @classmethod
    def load(
        cls,
        load_dir: str | Path,
        model: str | type[SynthesizerModel],
        metadata: None | Metadata = None,
        data_connector: None | str | DataConnector | type[DataConnector] = None,
        data_connector_kwargs: None | dict[str, Any] = None,
        raw_data_loaders_kwargs: None | dict[str, Any] = None,
        processed_data_loaders_kwargs: None | dict[str, Any] = None,
        data_processors: None | list[str | DataProcessor | type[DataProcessor]] = None,
        data_processors_kwargs: None | dict[str, dict[str, Any]] = None,
        model_kwargs: None | dict[str, Any] = None,
    ) -> "Synthesizer":
        """
        Load metadata and model, allow rebuilding Synthesizer for finetuning or other use cases.

        We need ``model`` as not every model support *pickle* way to save and load.

        Args:
            load_dir (str | Path): The directory to load the model.
            model (str | type[SynthesizerModel]): The name of the model or the model itself. Type of model must be :class:`~sdgx.models.base.SynthesizerModel`.
                When model is a string, it must be registered in :class:`~sdgx.models.manager.ModelManager`.
            metadata (Metadata, optional): The metadata to use. Defaults to None.
                When given, it takes precedence over the metadata file saved in ``load_dir``.
            data_connector (DataConnector | type[DataConnector] | str, optional): The data connector to use. Defaults to None.
                When data_connector is a string, it must be registered in :class:`~sdgx.data_connectors.manager.DataConnectorManager`.
            data_connector_kwargs (dict[str, Any], optional): The keyword arguments for data connectors. Defaults to None.
            raw_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for raw data loaders. Defaults to None.
            processed_data_loaders_kwargs (dict[str, Any], optional): The keyword arguments for processed data loaders. Defaults to None.
            data_processors (list[str | DataProcessor | type[DataProcessor]], optional): The data processors to use. Defaults to None.
                When data_processor is a string, it must be registered in :class:`~sdgx.data_processors.manager.DataProcessorManager`.
            data_processors_kwargs (dict[str, dict[str, Any]], optional): The keyword arguments for data processors. Defaults to None.
            model_kwargs (dict[str, Any], optional): The keyword arguments for model. Defaults to None.

        Returns:
            Synthesizer: The synthesizer instance.

        Raises:
            SynthesizerInitError: If ``load_dir`` or its model directory does not exist.
        """
        load_dir = Path(load_dir).expanduser().resolve()
        logger.info(f"Loading synthesizer from {load_dir}")

        if not load_dir.exists():
            raise SynthesizerInitError(f"{load_dir.as_posix()} does not exist")

        model_path = load_dir / cls.MODEL_SAVE_DIR
        if not model_path.exists():
            raise SynthesizerInitError(
                f"{model_path.as_posix()} does not exist, cannot load model."
            )

        metadata_path = load_dir / cls.METADATA_SAVE_NAME
        # ``__init__`` rejects metadata together with metadata_path, so the saved
        # file is only used when the caller did not supply metadata explicitly.
        if metadata or not metadata_path.exists():
            metadata_path = None

        # Build through ``cls`` so that subclasses are restored as themselves.
        return cls(
            model=model,
            model_path=model_path,
            metadata=metadata,
            metadata_path=metadata_path,
            model_kwargs=model_kwargs,
            data_connector=data_connector,
            data_connector_kwargs=data_connector_kwargs,
            raw_data_loaders_kwargs=raw_data_loaders_kwargs,
            processed_data_loaders_kwargs=processed_data_loaders_kwargs,
            data_processors=data_processors,
            data_processors_kwargs=data_processors_kwargs,
        )

    def fit(
        self,
        metadata: None | Metadata = None,
        inspector_max_chunk: int = 10,
        metadata_include_inspectors: None | list[str] = None,
        metadata_exclude_inspectors: None | list[str] = None,
        inspector_init_kwargs: None | dict[str, Any] = None,
        model_fit_kwargs: None | dict[str, Any] = None,
    ):
        """
        Fit the synthesizer with metadata and data processors.

        Raw data will be loaded from the dataloader and processed by the data processors in a Generator.
        The Generator, which provides the processed data, will be wrapped into a DataLoader, aka ProcessedDataLoader.
        The ProcessedDataLoader will be used to fit the model.

        For more information about DataLoaders, please refer to the :class:`~sdgx.data_loaders.base.DataLoader`.
        For more information about DataProcessors, please refer to the :class:`~sdgx.data_processors.base.DataProcessor`.
        For more information about DataConnectors, please refer to the :class:`~sdgx.data_connectors.base.DataConnector`. Especially, the :class:`~sdgx.data_connectors.generator_connector.GeneratorConnector`.

        Args:
            metadata (Metadata, optional): The metadata to use. Defaults to None. If None, the metadata configured on the synthesizer is used; if that is also missing, it will be inferred from the dataloader with the :func:`~sdgx.data_models.metadata.Metadata.from_dataloader` method.
            inspector_max_chunk (int, optional): The maximum number of chunks to inspect. Defaults to 10.
            metadata_include_inspectors (list[str], optional): The list of metadata inspectors to include. Defaults to None.
            metadata_exclude_inspectors (list[str], optional): The list of metadata inspectors to exclude. Defaults to None.
            inspector_init_kwargs (dict[str, Any], optional): The keyword arguments for metadata inspectors. Defaults to None.
            model_fit_kwargs (dict[str, Any], optional): The keyword arguments for model.fit. Defaults to None.

        Raises:
            SynthesizerInitError: If the synthesizer was created without a data connector.
        """
        if self.dataloader is None:
            raise SynthesizerInitError(
                "Cannot fit without dataloader, check `data_connector` parameter when initializing Synthesizer"
            )

        metadata = (
            metadata
            or self.metadata
            or Metadata.from_dataloader(
                self.dataloader,
                max_chunk=inspector_max_chunk,
                include_inspectors=metadata_include_inspectors,
                exclude_inspectors=metadata_exclude_inspectors,
                inspector_init_kwargs=inspector_init_kwargs,
            )
        )
        # Some processors may cause metadata update before model fitting, we need to make a copy.
        self.metadata = metadata.model_copy()  # Ensure update metadata

        logger.info("Fitting data processors...")
        # NOTE(review): the dataloader is known to be not None here, so these
        # truthiness checks only matter if DataLoader can evaluate as falsy - confirm.
        if not self.dataloader:
            logger.info("Fitting without dataloader.")
        start_time = time.time()
        for d in self.data_processors:
            if self.dataloader:
                d.fit(metadata=metadata, tabular_data=self.dataloader)
            else:
                d.fit(metadata=metadata)
        logger.info(
            f"Fitted {len(self.data_processors)} data processors in {time.time() - start_time}s."
        )

        def chunk_generator() -> Generator[pd.DataFrame, None, None]:
            # Lazily push every raw chunk through all processors, in order.
            for chunk in self.dataloader.iter():
                for d in self.data_processors:
                    chunk = d.convert(chunk)
                yield chunk

        logger.info("Initializing processed data loader...")
        start_time = time.time()
        processed_dataloader = DataLoader(
            GeneratorConnector(chunk_generator),
            identity=self.dataloader.identity,
            **self.processed_data_loaders_kwargs,
        )
        logger.info(f"Initialized processed data loader in {time.time() - start_time}s")

        try:
            logger.info("Model fit Started...")
            self.model.fit(metadata, processed_dataloader, **(model_fit_kwargs or {}))
            logger.info("Model fit... Finished")
        finally:
            # Always release the processed data cache, even when fitting fails.
            processed_dataloader.finalize(clear_cache=True)

    def sample(
        self,
        count: int,
        chunksize: None | int = None,
        metadata: None | Metadata = None,
        model_sample_args: None | dict[str, Any] = None,
    ) -> pd.DataFrame | Generator[pd.DataFrame, None, None]:
        """
        Sample data from the synthesizer.

        Args:
            count (int): The number of samples to generate.
            chunksize (int, optional): The chunksize to use. Defaults to None. If is not None, the data will be sampled in chunks.
                And will return a generator that yields chunks of samples.
            metadata (Metadata, optional): The metadata to use. Defaults to None. If None, will use the metadata in fit first.
            model_sample_args (dict[str, Any], optional): The keyword arguments for model.sample. Defaults to None.

        Returns:
            pd.DataFrame | typing.Generator[pd.DataFrame, None, None]: The sampled data. When chunksize is not None, it will be a generator.

        Raises:
            SynthesizerSampleError: If ``chunksize`` is not positive or is larger than ``count``.
        """
        logger.info("Sampling...")
        metadata = metadata or self.metadata
        self.metadata = metadata  # Ensure update metadata
        # data_processors do not need to be fit again in the sampling stage
        model_sample_args = model_sample_args or {}

        if chunksize is None:
            return self._sample_once(count, model_sample_args)

        if chunksize <= 0:
            # Would otherwise fail lazily (division by zero) or silently yield nothing.
            raise SynthesizerSampleError("chunksize must be a positive integer")
        if chunksize > count:
            raise SynthesizerSampleError("chunksize must be less than or equal to count")

        def generator_sample_caller():
            # ``_sample_once`` already reverse-converts its output through the data
            # processors, so the chunks must not be reverse-converted a second time.
            full_chunks, remainder = divmod(count, chunksize)
            for _ in range(full_chunks):
                yield self._sample_once(chunksize, model_sample_args)
            if remainder > 0:
                yield self._sample_once(remainder, model_sample_args)

        return generator_sample_caller()

    def _sample_once(
        self, count: int, model_sample_args: None | dict[str, Any] = None
    ) -> pd.DataFrame:
        """
        Sample data once.

        DataProcessors may drop some broken data after reverse_convert.
        So we oversample first and then take the first `count` samples.

        Args:
            count (int): The number of samples to return.
            model_sample_args (dict[str, Any], optional): The keyword arguments for model.sample. Defaults to None.
                For CTGAN models ``drop_more=False`` is always enforced on top of them.

        Returns:
            pd.DataFrame: At most ``count`` reverse-converted rows. Fewer rows (possibly an
            empty frame) are returned when the sampling attempts are exhausted or ``count`` is 0.

        TODO:

        - Use an adaptive scale for oversampling will be better for performance.
        """
        # Copy so the caller's dict is never mutated.
        model_sample_args = dict(model_sample_args or {})
        missing_count = count
        max_trials = 50
        remaining_trials = max_trials
        sample_data_list = []

        # To improve batched model performance, such as tvae or ctgan.
        batch_size: int = 0
        multiply_factor: float = 4.0
        if isinstance(self.model, BatchedSynthesizer):
            batch_size = self.model.get_batch_size()
            multiply_factor = 1.2
        if isinstance(self.model, CTGANSynthesizerModel):
            # Keep the caller's arguments and only force ``drop_more`` off.
            model_sample_args["drop_more"] = False

        # The context manager guarantees the progress bar is closed, also on errors.
        with tqdm.tqdm(total=count, desc="Sampling") as psb:
            while missing_count > 0 and remaining_trials > 0:
                sample_data = self.model.sample(
                    max(int(missing_count * multiply_factor), batch_size), **model_sample_args
                )
                # TODO table separated parallel process
                for d in self.data_processors:
                    sample_data = d.reverse_convert(sample_data)
                sample_data = sample_data.dropna(how="all")
                sample_data_list.append(sample_data)
                missing_count -= len(sample_data)
                psb.update(len(sample_data))
                remaining_trials -= 1

        if not sample_data_list:
            # Nothing was requested; ``pd.concat`` would raise on an empty list.
            return pd.DataFrame()
        if missing_count > 0:
            logger.warning(
                f"Only {count - missing_count} of {count} samples generated after {max_trials} trials."
            )
        return pd.concat(sample_data_list)[:count]

    def cleanup(self):
        """
        Cleanup resources. This will cause model unavailable and clear the cache.

        It is useful when Synthesizer object is no longer needed and may hold large resources like GPUs.
        Safe to call more than once and on a synthesizer whose initialization failed part-way.
        """
        # ``getattr`` keeps this safe when ``__init__`` raised before the attribute was set.
        dataloader = getattr(self, "dataloader", None)
        if dataloader is not None:
            dataloader.finalize(clear_cache=True)
        # Release resources
        if hasattr(self, "model"):
            del self.model

    def __del__(self):
        self.cleanup()