Source code for sdgx.models.manager
from __future__ import annotations
from typing import Any
from sdgx.exceptions import ManagerLoadModelError
from sdgx.manager import Manager
from sdgx.models import extension, ml, statistics
from sdgx.models.base import SynthesizerModel
from sdgx.models.extension import project_name as PROJECT_NAME
[docs]
class ModelManager(Manager):
register_type = SynthesizerModel
project_name = PROJECT_NAME
hookspecs_model = extension
@property
def registed_models(self):
"""
redirect to registed_cls
"""
return self.registed_cls
[docs]
def load_all_local_model(self):
self._load_dir(ml.single_table)
self._load_dir(ml.multi_tables)
self._load_dir(statistics.single_table)
self._load_dir(statistics.multi_tables)
[docs]
def init_model(self, model_name, **kwargs: dict[str, Any]) -> SynthesizerModel:
"""
redirect to init
"""
return self.init(model_name, **kwargs)
[docs]
def load(self, model: type[SynthesizerModel] | str, model_path, **kwargs) -> SynthesizerModel:
if not (isinstance(model, type) or isinstance(model, str)):
raise ManagerLoadModelError(
"model must be type of SynthesizerModel or str for model_name"
)
if isinstance(model, str):
model = self._normalize_name(model)
if isinstance(model, str) and model not in self.registed_models:
raise ManagerLoadModelError(f"{model} is not registered.")
model = model if isinstance(model, type) else self.registed_models[model]
try:
return model.load(model_path, **kwargs)
except Exception as e:
raise ManagerLoadModelError(e)