"""Base synthesizer model interface (``sdgx.models.base``)."""

from __future__ import annotations

from pathlib import Path

import pandas as pd

from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
from sdgx.exceptions import SynthesizerInitError


class SynthesizerModel:
    """Base class for sdgx synthesizer models.

    Subclasses are expected to implement :meth:`fit`, :meth:`sample`,
    :meth:`save` and :meth:`load`; in this base class each of them raises
    ``NotImplementedError``.

    The ``use_dataloader`` / ``use_raw_data`` flags select the model's data
    access type. :meth:`_check_access_type` requires exactly one of them to be
    enabled.
    """

    # Class-level defaults for the data access type. Subclasses may override
    # them at class level; ``__init__`` overrides them per instance when the
    # matching keyword argument is passed.
    use_dataloader: bool = False
    use_raw_data: bool = False

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the model.

        Args:
            *args: Ignored by the base class.
            **kwargs: When present, ``use_dataloader`` and ``use_raw_data`` set
                the data access type on this instance. Other keyword arguments
                are ignored by the base class.
        """
        # Only assign when explicitly given, so the class-level defaults
        # (possibly overridden by a subclass) remain in effect otherwise.
        if "use_dataloader" in kwargs:
            self.use_dataloader = kwargs["use_dataloader"]
        if "use_raw_data" in kwargs:
            self.use_raw_data = kwargs["use_raw_data"]

    def _check_access_type(self) -> None:
        """Validate that exactly one data access type is enabled.

        Both flags are interpreted by truthiness, so e.g. ``None`` counts as
        "not enabled".

        Raises:
            SynthesizerInitError: If neither or both of ``use_dataloader`` and
                ``use_raw_data`` are enabled.
        """
        # Coerce once so non-bool values (None, 0, 1, ...) are judged by
        # truthiness rather than by `== True` / `== False` equality.
        use_dataloader = bool(self.use_dataloader)
        use_raw_data = bool(self.use_raw_data)

        if not (use_dataloader or use_raw_data):
            raise SynthesizerInitError(
                "Data access type not specified, please use `use_raw_data: bool` or `use_dataloader: bool` to specify data access type."
            )
        if use_dataloader and use_raw_data:
            raise SynthesizerInitError("Duplicate data access type found.")

    def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
        """
        Fit the model using the given metadata and dataloader.

        Args:
            metadata (Metadata): The metadata to use.
            dataloader (DataLoader): The dataloader to use.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
        """
        Sample data from the model.

        Args:
            count (int): The number of samples to generate.

        Returns:
            pd.DataFrame: The generated data.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def save(self, save_dir: str | Path):
        """
        Dump model to file.

        Args:
            save_dir (str | Path): The directory to save the model.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    @classmethod
    def load(cls, save_dir: str | Path, **kwargs) -> SynthesizerModel:
        """
        Load model from file.

        Args:
            save_dir (str | Path): The directory to load the model from.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError