Source code for sdgx.models.ml.single_table.ctgan

from __future__ import annotations

import math
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from packaging import version
from torch import optim
from torch.nn import (
    BatchNorm1d,
    Dropout,
    LeakyReLU,
    Linear,
    Module,
    ReLU,
    Sequential,
    functional,
)
from tqdm import autonotebook as tqdm

from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
from sdgx.models.components.optimize.ndarray_loader import NDArrayLoader
from sdgx.models.components.optimize.sdv_ctgan.data_sampler import DataSampler
from sdgx.models.components.optimize.sdv_ctgan.data_transformer import DataTransformer
from sdgx.models.components.sdv_ctgan.synthesizers.base import (
    BaseSynthesizer as SDVBaseSynthesizer,
)
from sdgx.models.components.sdv_ctgan.synthesizers.base import (
    BatchedSynthesizer,
    random_state,
)
from sdgx.models.extension import hookimpl
from sdgx.models.ml.single_table.base import MLSynthesizerModel
from sdgx.utils import logger


class Discriminator(Module):
    """Discriminator (critic) for the CTGAN.

    Input rows are grouped into "packs" of ``pac`` consecutive rows. Each pack is
    flattened into one vector of width ``input_dim * pac`` and scored as a whole,
    so the network emits one score per pack.

    Args:
        input_dim (int): Width of a single, un-packed row.
        discriminator_dim (iterable of int): Output width of each hidden
            ``Linear -> LeakyReLU(0.2) -> Dropout(0.5)`` stage.
        pac (int): Number of rows grouped into one pack. Defaults to 10.
    """

    def __init__(self, input_dim, discriminator_dim, pac=10):
        super(Discriminator, self).__init__()
        dim = input_dim * pac
        self.pac = pac
        # Width of one flattened pack; forward() reshapes its input to (-1, pacdim).
        self.pacdim = dim
        seq = []
        for item in list(discriminator_dim):
            seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
            dim = item
        seq += [Linear(dim, 1)]
        self.seq = Sequential(*seq)

    def calc_gradient_penalty(self, real_data, fake_data, device="cpu", pac=None, lambda_=10):
        """Compute the gradient penalty on real/fake interpolations.

        One mixing coefficient ``alpha ~ U(0, 1)`` is drawn per pack of ``pac`` rows.
        The critic is evaluated on ``alpha * real + (1 - alpha) * fake``. The penalty
        is ``lambda_ * mean((||grad||_2 - 1) ** 2)``, where the gradient norm is taken
        over each flattened pack.

        Args:
            real_data (torch.Tensor): Real rows of shape ``(n, width)``. ``n`` is
                expected to be a multiple of ``pac``.
            fake_data (torch.Tensor): Generated rows, same shape as ``real_data``.
                NOTE(review): autograd.grad below needs the interpolation to require
                grad, which here comes from ``fake_data`` — confirm callers never
                pass detached fakes together with plain real data.
            device: Device on which ``alpha`` and the gradient seed are created.
            pac (int, optional): Pack size. ``None`` (the default) uses ``self.pac``,
                so the grouping matches the one used by :meth:`forward`.
            lambda_ (float): Penalty weight. Defaults to 10.

        Returns:
            torch.Tensor: Scalar penalty.
        """
        if pac is None:
            # Previously hard-coded to 10, which mis-grouped (and crashed on the
            # broadcast below) for discriminators built with a different pac.
            pac = self.pac

        # One alpha per pack, repeated over every row and feature of that pack.
        alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, real_data.size(1))
        alpha = alpha.view(-1, real_data.size(1))

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        disc_interpolates = self(interpolates)

        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        # L2 norm per flattened pack, measured as a deviation from 1.
        gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
        gradient_penalty = ((gradients_view) ** 2).mean() * lambda_

        return gradient_penalty

    def forward(self, input_):
        """Apply the Discriminator to the `input_`.

        The number of rows must be a multiple of ``self.pac``. The output has one
        row per pack.
        """
        assert input_.size()[0] % self.pac == 0
        return self.seq(input_.view(-1, self.pacdim))
class Residual(Module):
    """Residual layer for the CTGAN.

    Maps the input through ``Linear(i, o) -> BatchNorm1d(o) -> ReLU``. The result
    is concatenated in front of the untouched input along dim 1, so the output
    width is ``o + i``.
    """

    def __init__(self, i, o):
        super().__init__()
        # Attribute names (fc / bn / relu) determine state_dict keys; keep them.
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input_):
        """Apply the Residual layer to the `input_`."""
        hidden = self.relu(self.bn(self.fc(input_)))
        return torch.cat([hidden, input_], dim=1)
class Generator(Module):
    """Generator for the CTGAN.

    Stacks one :class:`Residual` block per entry of ``generator_dim``. Each block
    concatenates its output with its input, so the running width grows by that
    entry. A final ``Linear`` maps the accumulated width to ``data_dim``.
    """

    def __init__(self, embedding_dim, generator_dim, data_dim):
        super().__init__()
        layers = []
        width = embedding_dim
        for hidden in generator_dim:
            layers.append(Residual(width, hidden))
            # Residual output is [hidden features, original input].
            width += hidden
        layers.append(Linear(width, data_dim))
        self.seq = Sequential(*layers)

    def forward(self, input_):
        """Apply the Generator to the `input_`."""
        return self.seq(input_)
class CTGANSynthesizerModel(MLSynthesizerModel, BatchedSynthesizer):
    """
    Modified from ``sdgx.models.components.sdv_ctgan.synthesizers.ctgan.CTGANSynthesizer``.
    A CTGANSynthesizer that provides the :ref:`SynthesizerModel` interface with chunked fit.

    This is the core class of the CTGAN project, where the different components
    are orchestrated together.
    For more details about the process, please check the [Modeling Tabular data using
    Conditional GAN](https://arxiv.org/abs/1907.00503) paper.

    Args:
        embedding_dim (int):
            Size of the random sample passed to the Generator. Defaults to 128.
        generator_dim (tuple or list of ints):
            Size of the output samples for each one of the Residuals. A Residual Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        discriminator_dim (tuple or list of ints):
            Size of the output samples for each one of the Discriminator Layers. A Linear
            Layer will be created for each one of the values provided. Defaults to (256, 256).
        generator_lr (float):
            Learning rate for the generator. Defaults to 2e-4.
        generator_decay (float):
            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
        discriminator_lr (float):
            Learning rate for the discriminator. Defaults to 2e-4.
        discriminator_decay (float):
            Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
        batch_size (int):
            Number of data samples to process in each step. Must be even. It should
            also be a multiple of ``pac``, because the discriminator asserts that its
            input rows form whole packs. Defaults to 500.
        discriminator_steps (int):
            Number of discriminator updates to do for each generator update.
            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
            default is 5. Default used is 1 to match original CTGAN implementation.
        log_frequency (boolean):
            Whether to use log frequency of categorical levels in conditional
            sampling. Defaults to ``True``.
        epochs (int):
            Number of training epochs. Defaults to 300.
        pac (int):
            Number of samples to group together when applying the discriminator.
            Defaults to 10.
        device (str):
            Device to run the training on. Defaults to ``"cuda"`` when
            ``torch.cuda.is_available()``, otherwise ``"cpu"``.
    """

    MODEL_SAVE_NAME = "ctgan.pkl"

    def __init__(
        self,
        embedding_dim=128,
        generator_dim=(256, 256),
        discriminator_dim=(256, 256),
        generator_lr=2e-4,
        generator_decay=1e-6,
        discriminator_lr=2e-4,
        discriminator_decay=1e-6,
        batch_size=500,
        discriminator_steps=1,
        log_frequency=True,
        epochs=300,
        pac=10,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        assert batch_size % 2 == 0
        BatchedSynthesizer.__init__(self, batch_size=batch_size)

        self._embedding_dim = embedding_dim
        self._generator_dim = generator_dim
        self._discriminator_dim = discriminator_dim

        self._generator_lr = generator_lr
        self._generator_decay = generator_decay
        self._discriminator_lr = discriminator_lr
        self._discriminator_decay = discriminator_decay

        self._discriminator_steps = discriminator_steps
        self._log_frequency = log_frequency
        self._epochs = epochs
        self.pac = pac
        self._device = torch.device(device)

        # Following components are initialized in `_pre_fit`
        self._transformer: Optional[DataTransformer] = None
        self._data_sampler: Optional[DataSampler] = None
        self._generator = None
        self._ndarry_loader: Optional[NDArrayLoader] = None
        self.data_dim: Optional[int] = None

    def fit(self, metadata: Metadata, dataloader: DataLoader, epochs=None, *args, **kwargs):
        """Fit the transformer, the data sampler and the GAN on ``dataloader``.

        Args:
            metadata (Metadata): Source of the ``"discrete_columns"`` list. It is also
                forwarded to the DataTransformer.
            dataloader (DataLoader): Training data.
            epochs (int, optional): If given, overrides (and persistently replaces)
                the epoch count configured in ``__init__``.
            *args, **kwargs: Accepted for interface compatibility; unused here.
        """
        # In the future, sdgx will use `sdgx.data_processor.transformers.discrete`
        # to handle discrete_columns; the original sdv transformer will be removed
        # in version 0.3.0. This will be done in another PR.
        discrete_columns = list(metadata.get("discrete_columns"))
        if epochs is not None:
            self._epochs = epochs
        self._pre_fit(dataloader, discrete_columns, metadata)
        # NOTE(review): `fit_data_empty` is only ever set to True in this class;
        # presumably its False default lives on a base class — confirm.
        if self.fit_data_empty:
            logger.info("CTGAN fit finished because of empty df detected.")
            return
        logger.info("CTGAN prefit finished, start CTGAN training.")
        self._fit(len(self._ndarry_loader))
        logger.info("CTGAN training finished.")

    def _pre_fit(
        self,
        dataloader: DataLoader,
        discrete_columns: Optional[List[str]] = None,
        metadata: Metadata = None,
    ):
        """Prepare everything training needs, except the discriminator.

        This fits the DataTransformer, transforms the data into an NDArrayLoader,
        builds the DataSampler and creates the Generator on ``self._device``. It
        returns early without building anything when
        :meth:`_filter_discrete_columns` flags the data as empty.
        """
        if not discrete_columns:
            discrete_columns = []

        discrete_columns = self._filter_discrete_columns(dataloader.columns(), discrete_columns)

        # if the df is empty, we don't need to do anything
        if self.fit_data_empty:
            return

        # Fit Transformer and DataSampler
        self._transformer = DataTransformer(metadata=metadata)
        logger.info("Fitting model's transformer...")
        self._transformer.fit(dataloader, discrete_columns)
        logger.info("Transforming data...")
        self._ndarry_loader = self._transformer.transform(dataloader)

        logger.info("Sampling data.")
        self._data_sampler = DataSampler(
            self._ndarry_loader, self._transformer.output_info_list, self._log_frequency
        )

        logger.info("Initialize Generator.")
        # The generator input is the noise vector concatenated with the condition vector.
        self.data_dim = self._transformer.output_dimensions
        self._generator = Generator(
            self._embedding_dim + self._data_sampler.dim_cond_vec(),
            self._generator_dim,
            self.data_dim,
        ).to(self._device)

    @random_state
    def _fit(self, data_size: int):
        """Fit the CTGAN Synthesizer models to the training data.

        Runs ``self._epochs`` epochs. Each step performs ``self._discriminator_steps``
        critic updates (Wasserstein loss plus gradient penalty). It then performs one
        generator update: the adversarial loss, plus the conditional cross-entropy
        from :meth:`_cond_loss` when a condition vector was sampled.

        Args:
            data_size (int): Number of training rows. ``data_size // batch_size``
                steps (at least 1) are run per epoch.
        """
        logger.info(f"Fit using data_size:{data_size}, data_dim: {self.data_dim}.")
        epochs = self._epochs

        # The discriminator is local to this call; only the generator is kept on self.
        discriminator = Discriminator(
            self.data_dim + self._data_sampler.dim_cond_vec(),
            self._discriminator_dim,
            pac=self.pac,
        ).to(self._device)

        optimizerG = optim.Adam(
            self._generator.parameters(),
            lr=self._generator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._generator_decay,
        )

        optimizerD = optim.Adam(
            discriminator.parameters(),
            lr=self._discriminator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._discriminator_decay,
        )

        # Standard-normal noise parameters for one batch of generator input.
        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
        std = mean + 1

        logger.info("Starting model training, epochs: {}".format(epochs))
        steps_per_epoch = max(data_size // self._batch_size, 1)
        for i in range(epochs):
            start_time = time.time()
            for _step in tqdm.tqdm(range(steps_per_epoch), desc="Fitting batches", delay=3):
                # ---- Discriminator (critic) updates ----
                for _d_step in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)

                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real = self._data_sampler.sample_data(self._batch_size, col, opt)
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
                        m1 = torch.from_numpy(m1).to(self._device)
                        fakez = torch.cat([fakez, c1], dim=1)

                        # Real rows are sampled for a shuffled permutation of the
                        # conditions; c2 holds the matching condition vectors.
                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
                            self._batch_size, col[perm], opt[perm]
                        )
                        c2 = c1[perm]

                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake)

                    real = torch.from_numpy(real.astype("float32")).to(self._device)

                    if c1 is not None:
                        fake_cat = torch.cat([fakeact, c1], dim=1)
                        real_cat = torch.cat([real, c2], dim=1)
                    else:
                        real_cat = real
                        fake_cat = fakeact

                    y_fake = discriminator(fake_cat)
                    y_real = discriminator(real_cat)

                    pen = discriminator.calc_gradient_penalty(
                        real_cat, fake_cat, self._device, self.pac
                    )
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                    optimizerD.zero_grad()
                    pen.backward(retain_graph=True)
                    loss_d.backward()
                    optimizerD.step()

                # ---- Generator update ----
                fakez = torch.normal(mean=mean, std=std)
                condvec = self._data_sampler.sample_condvec(self._batch_size)

                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self._device)
                    m1 = torch.from_numpy(m1).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)

                if c1 is not None:
                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                else:
                    y_fake = discriminator(fakeact)

                if condvec is None:
                    cross_entropy = 0
                else:
                    cross_entropy = self._cond_loss(fake, c1, m1)

                loss_g = -torch.mean(y_fake) + cross_entropy

                # The critic step back-propagated through non-detached fakes, so the
                # generator gradients it left behind are cleared here first.
                optimizerG.zero_grad()
                loss_g.backward()
                optimizerG.step()

            logger.info(
                f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"  # noqa: T001
                f" Loss D: {loss_d.detach().cpu(): .4f},"
                f" Time: {time.time() - start_time: .4f}",
            )

    def sample(self, count: int, *args, **kwargs) -> pd.DataFrame:
        """Generate ``count`` synthetic rows.

        If fitting was skipped because the data was empty, this returns a DataFrame
        with ``count`` rows and no columns. Extra arguments are forwarded to
        :meth:`_sample`.
        """
        if self.fit_data_empty:
            return pd.DataFrame(index=range(count))
        return self._sample(count, *args, **kwargs)

    @random_state
    def _sample(self, n, condition_column=None, condition_value=None, drop_more=True):
        """Sample data similar to the training data.

        Choosing a condition_column and condition_value will increase the probability of the
        discrete condition_value happening in the condition_column.

        Args:
            n (int):
                Number of rows to sample.
            condition_column (string):
                Name of a discrete column.
            condition_value (string):
                Name of the category in the condition_column which we wish to increase the
                probability of happening.
            drop_more (bool):
                If True (default), trim the output to exactly ``n`` rows. Otherwise
                return every generated row, i.e. ``ceil(n / batch_size) * batch_size``.

        Returns:
            numpy.ndarray or pandas.DataFrame
        """
        if condition_column is not None and condition_value is not None:
            condition_info = self._transformer.convert_column_name_value_to_id(
                condition_column, condition_value
            )
            global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
                condition_info, self._batch_size
            )
        else:
            global_condition_vec = None

        steps = math.ceil(n / self._batch_size)
        # The noise distribution is identical for every batch, so build it once.
        mean = torch.zeros(self._batch_size, self._embedding_dim)
        std = mean + 1
        data = []
        # Sampling only needs forward passes; don't record autograd graphs.
        with torch.no_grad():
            for _ in tqdm.tqdm(range(steps), desc="Sampling batches", delay=3):
                fakez = torch.normal(mean=mean, std=std).to(self._device)

                if global_condition_vec is not None:
                    condvec = global_condition_vec.copy()
                else:
                    condvec = self._data_sampler.sample_original_condvec(self._batch_size)

                if condvec is not None:
                    c1 = torch.from_numpy(condvec).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)
                data.append(fakeact.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        logger.info("CTGAN Generated {} raw samples.".format(data.shape[0]))
        if drop_more:
            data = data[:n]

        return self._transformer.inverse_transform(data)

    def save(self, save_dir: str | Path):
        """Save the model to ``save_dir / MODEL_SAVE_NAME``.

        ``save_dir`` may be a ``str`` or a :class:`~pathlib.Path`. It is created
        (with parents) if it does not already exist.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        return SDVBaseSynthesizer.save(self, save_dir / self.MODEL_SAVE_NAME)

    @classmethod
    def load(cls, save_dir: str | Path, device: str = None) -> "CTGANSynthesizerModel":
        """Load a model previously written by :meth:`save`.

        ``save_dir`` may be a ``str`` or a :class:`~pathlib.Path`. ``device`` is
        forwarded to the underlying SDV loader.
        """
        return SDVBaseSynthesizer.load(Path(save_dir) / cls.MODEL_SAVE_NAME, device)

    @staticmethod
    def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
        """Deals with the instability of the gumbel_softmax for older versions of torch.

        For more details about the issue:
        https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing

        Args:
            logits […, num_features]:
                Unnormalized log probabilities
            tau:
                Non-negative scalar temperature
            hard (bool):
                If True, the returned samples will be discretized as one-hot vectors,
                but will be differentiated as if it is the soft sample in autograd
            dim (int):
                A dimension along which softmax will be computed. Default: -1.

        Returns:
            Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.

        Raises:
            ValueError: On torch < 1.2.0, if 10 consecutive draws all contain NaN.
        """
        if version.parse(torch.__version__) < version.parse("1.2.0"):
            for _ in range(10):
                transformed = functional.gumbel_softmax(
                    logits, tau=tau, hard=hard, eps=eps, dim=dim
                )
                if not torch.isnan(transformed).any():
                    return transformed
            raise ValueError("gumbel_softmax returning NaN.")

        return functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)

    def _apply_activate(self, data):
        """Apply proper activation function to the output of the generator.

        The generator output is walked span by span, in the order of
        ``self._transformer.output_info_list``:

        - ``"tanh"`` spans get ``torch.tanh``;
        - ``"softmax"`` spans get a Gumbel-softmax with ``tau=0.2``;
        - ``"linear"`` spans (label-encoded columns) are copied unchanged.

        Raises:
            ValueError: For any other activation name.
        """
        data_t = []
        st = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                ed = st + span_info.dim
                if span_info.activation_fn == "tanh":
                    data_t.append(torch.tanh(data[:, st:ed]))
                elif span_info.activation_fn == "softmax":
                    data_t.append(self._gumbel_softmax(data[:, st:ed], tau=0.2))
                elif span_info.activation_fn == "linear":
                    # for label encoder
                    data_t.append(data[:, st:ed].clone())
                else:
                    raise ValueError(f"Unexpected activation function {span_info.activation_fn}.")
                st = ed

        return torch.cat(data_t, dim=1)

    def _cond_loss(self, data, c, m):
        """Compute the cross entropy loss on the fixed discrete column.

        A per-row cross entropy is computed for every single-span softmax column
        (the columns that also have a slice in the condition vector ``c``). The
        result is masked by ``m`` and averaged over the batch.
        NOTE(review): ``m`` is presumably a (batch, n_discrete_columns) one-hot mask
        of the conditioned column — confirm against DataSampler.sample_condvec.
        """
        loss = []
        st = 0
        st_c = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if len(column_info) != 1 or span_info.activation_fn != "softmax":
                    # not discrete column
                    st += span_info.dim
                else:
                    ed = st + span_info.dim
                    ed_c = st_c + span_info.dim
                    tmp = functional.cross_entropy(
                        data[:, st:ed],
                        torch.argmax(c[:, st_c:ed_c], dim=1),
                        reduction="none",
                    )
                    loss.append(tmp)
                    st = ed
                    st_c = ed_c

        loss = torch.stack(loss, dim=1)  # noqa: PD013

        return (loss * m).sum() / data.size()[0]

    def _filter_discrete_columns(self, train_data: List[str], discrete_columns: List[str]):
        """Drop discrete columns that are absent from the training data (e.g. PII).

        PII columns are filtered here, and PII columns are only ever discrete for now.
        They are produced later by a PII generator rather than synthesised by this model.

        We also need to detect when model fitting should stop: the original data is made
        up entirely of discrete columns, and all of them are PII.

        For `train_data`, there are three possibilities for the columns type.
            - train_data = valid_discrete + valid_continue
            - train_data = valid_continue
            - train_data = valid_discrete

        For `discrete_columns`, discrete_columns = invalid_discrete(PII) + valid_discrete

        Thus,
            valid_discrete = discrete_columns ∩ train_data
            invalid_discrete = discrete_columns - train_data

        Thus,
            original_data_is_all_PII: discrete_columns is not empty & train_data is empty.
            In that case ``self.fit_data_empty`` is set to True.

        Returns:
            list: The valid discrete columns in their original order, without
            duplicates. If ``discrete_columns`` is empty, or ``train_data`` is empty,
            ``discrete_columns`` is returned unchanged.
        """
        # Discrete_columns is empty: return the empty list, but still fit the continuous columns.
        if len(discrete_columns) == 0:
            return discrete_columns

        # Discrete_columns is not empty: an empty train_data means model fitting should stop.
        if len(train_data) == 0:
            self.fit_data_empty = True
            return discrete_columns

        # Keep only discrete columns present in the data. Preserve the caller's order
        # (previously an unordered set was returned) so transformer fitting is deterministic.
        available = set(train_data)
        return [column for column in dict.fromkeys(discrete_columns) if column in available]

    def _validate_discrete_columns(self, train_data, discrete_columns):
        """Check whether ``discrete_columns`` exists in ``train_data``.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame or list):
                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame,
                or a list of column names.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame`` or a list, this list should contain the column names.

        Raises:
            TypeError: If ``train_data`` has an unsupported type.
            ValueError: If any discrete column is not found in ``train_data``.
        """
        if isinstance(train_data, pd.DataFrame):
            invalid_columns = set(discrete_columns) - set(train_data.columns)
        elif isinstance(train_data, np.ndarray):
            invalid_columns = []
            for column in discrete_columns:
                if column < 0 or column >= train_data.shape[1]:
                    invalid_columns.append(column)
        elif isinstance(train_data, list):
            invalid_columns = set(discrete_columns) - set(train_data)
        else:
            raise TypeError("``train_data`` should be either pd.DataFrame or np.array.")

        if invalid_columns:
            raise ValueError(f"Invalid columns found: {invalid_columns}")

    def set_device(self, device):
        """Set the torch device (e.g. ``"cuda"`` or ``"cpu"``) used for training and sampling.

        The value is normalised with ``torch.device`` (as in ``__init__``). The
        generator, if it has already been built, is moved to that device.
        """
        self._device = torch.device(device)
        if self._generator is not None:
            self._generator.to(self._device)
@hookimpl
def register(manager):
    """Plugin hook: expose :class:`CTGANSynthesizerModel` to ``manager`` under the name ``"CTGAN"``."""
    model_name = "CTGAN"
    manager.register(model_name, CTGANSynthesizerModel)