import sys
import os
import shutil
from itertools import chain, combinations
if sys.version_info[0] == 2:
import ConfigParser as configparser
else:
import configparser
from .formula import Formula
from .kwargs import MODEL_INITIALIZATION_KWARGS, BAYES_KWARGS, NN_BAYES_KWARGS, \
PLOT_KWARGS_CORE, PLOT_KWARGS_OTHER
PLOT_KEYS_CORE = [x.key for x in PLOT_KWARGS_CORE]
PLOT_KEYS_OTHER = [x.key for x in PLOT_KWARGS_OTHER]
# Thanks to Brice (https://stackoverflow.com/users/140264/brice) at Stack Overflow for this
def powerset(iterable):
xs = list(iterable)
return chain.from_iterable(combinations(xs,n) for n in range(1, len(xs)+1))
[docs]
class Config(object):
"""
Parses an \*.ini file and stores settings needed to define a set of CDR experiments.
:param path: Path to \*.ini file
"""
def __init__(self, path):
self.current_model = None
config = configparser.ConfigParser()
config.optionxform = str
assert os.path.exists(path), 'Config file %s does not exist' %path
config.read(path)
data = config['data']
global_settings = config['global_settings']
if 'cdr_settings' in config:
cdr_settings = config['cdr_settings']
elif 'dtsr_settings' in config:
cdr_settings = config['dtsr_settings']
else:
config['cdr_settings'] = {}
cdr_settings = config['cdr_settings']
########
# Data #
########
self.X_train = data.get('X_train').split()
self.X_dev = data.get('X_dev', None)
if self.X_dev:
self.X_dev = self.X_dev.split()
self.X_test = data.get('X_test', None)
if self.X_test:
self.X_test = self.X_test.split()
self.Y_train = data.get('Y_train', data.get('y_train', None))
assert self.Y_train, 'Y_train must be provided'
self.Y_train = self.Y_train.split()
self.Y_dev = data.get('Y_dev', data.get('y_dev', None))
if self.Y_dev:
self.Y_dev = self.Y_dev.split()
self.Y_test = data.get('Y_test', data.get('y_test', None))
if self.Y_test:
self.Y_test = self.Y_test.split()
sep = data.get('sep', ',')
if sep.lower() in ['', "' '", '" "', 's', 'space']:
sep = ' '
self.sep = sep
series_ids = data.get('series_ids')
if series_ids:
self.series_ids = series_ids.strip().split()
else:
self.series_ids = []
self.modulus = data.getint('modulus', 4)
split_ids = data.get('split_ids', '')
self.split_ids = split_ids.strip().split()
filters = data.get('filters', None)
if filters is None:
filters = []
else:
filters = filters.split(';')
for i in range(len(filters)):
f = filters[i].strip().split()
k = f[0]
v = ' '.join(f[1:])
filters[i] = (k, v)
self.filters = filters
self.history_length = data.getint('history_length', 128)
self.future_length = data.getint('future_length', 0)
self.t_delta_cutoff = data.getfloat('t_delta_cutoff', None)
self.merge_cols = data.get('merge_cols', None)
if self.merge_cols is not None:
self.merge_cols = self.merge_cols.split()
###################
# Global Settings #
###################
self.outdir = global_settings.get('outdir', None)
if self.outdir is None:
self.outdir = global_settings.get('logdir', None)
if self.outdir is None:
self.outdir = './cdr_model/'
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
if not os.path.normpath(os.path.realpath(path)) == os.path.normpath(os.path.realpath(self.outdir + '/config.ini')):
shutil.copy2(path, self.outdir + '/config.ini')
self.use_gpu_if_available = global_settings.getboolean('use_gpu_if_available', True)
#################
# CDR Settings #
#################
self.global_cdr_settings = self.build_cdr_settings(cdr_settings, add_defaults=False)
############
# Model(s) #
############
# Add ablations and crossval folds
self.models = {}
model_fields = [m for m in config.keys() if m.startswith('model_')]
for model_field in model_fields:
model_name = model_field[6:]
formula = Formula(config[model_field]['formula'])
is_cdrnn = len(formula.nns_by_id) > 0
if not (model_name.startswith('LM') or model_name.startswith('GAM')):
reg_type = 'cdr'
else:
reg_type = model_name.split('_')[0]
model_config = config[model_field]
model_settings = self.build_cdr_settings(
model_config,
global_settings=self.global_cdr_settings,
is_cdr=reg_type=='cdr',
is_cdrnn=is_cdrnn
)
self.models[model_name] = model_settings
if reg_type == 'lme':
self.models[model_name]['correlated'] = config[model_field].getboolean('correlated', True)
if 'ablate' in config[model_field]:
for ablated in powerset(config[model_field]['ablate'].strip().split()):
ablated = list(ablated)
new_name = model_name + '!' + '!'.join(ablated)
formula = Formula(config[model_field]['formula'])
formula.ablate_impulses(ablated)
new_model = self.models[model_name].copy()
if reg_type == 'cdr':
new_model['formula'] = str(formula)
elif reg_type == 'lme':
new_model['formula'] = formula.to_lmer_formula_string(
z=False,
correlated=self.models[model_name]['correlated'],
transform_dirac=False
)
else:
raise ValueError('Ablation with reg_type "%s" not currently supported.' % reg_type)
new_model['ablated'] = set(ablated)
self.models[new_name] = new_model
self.ensembles = self.models.copy()
self.crossval_families = {}
self.expand_submodels()
self.irf_name_map = {
't_delta': 'Delay (s)',
'time_X': 'Timestamp (s)',
'X_time': 'Timestamp (s)',
'rate': 'Rate'
}
if 'irf_name_map' in config:
for x in config['irf_name_map']:
self.irf_name_map[x] = config['irf_name_map'][x]
def __getitem__(self, item):
if self.current_model is None:
return self.global_cdr_settings[item]
if self.current_model in self.models:
return self.models[self.current_model][item]
if self.current_model in self.ensembles:
return self.ensembles[self.current_model][item]
if self.current_model in self.crossval_families:
return self.crossval_families[self.current_model][item]
raise ValueError('There is no model named "%s" defined in the config file.' % self.current_model)
@property
def model_names(self):
return list(self.models.keys())
@property
def ensemble_names(self):
return list(self.ensembles.keys())
@property
def crossval_family_names(self):
return list(self.crossval_families.keys())
def get(self, item, default=None):
if (self.current_model is None and item in self.global_cdr_settings) or \
(self.current_model in self.models and item in self.models[self.current_model]) or \
(self.current_model in self.ensembles and item in self.ensembles[self.current_model]) or \
(self.current_model in self.crossval_families and item in self.crossval_families[self.current_model]):
return self[item]
return default
def __str__(self):
out = ''
V = vars(self)
for x in V:
out += '%s: %s\n' %(x, V[x])
return out
[docs]
def set_model(self, model_name=None):
"""
Change internal state to that of model named **model_name**.
``Config`` instances can store settings for multiple models.
``set_model()`` determines which model's settings are returned by ``Config`` getter methods.
:param model_name: ``str``; name of target model
:return: ``None``
"""
if model_name is None or \
model_name in self.models or \
model_name in self.ensembles or\
model_name in self.crossval_families:
self.current_model = model_name
else:
raise ValueError('There is no model named "%s" defined in the config file.' %model_name)
[docs]
def build_cdr_settings(self, settings, add_defaults=True, global_settings=None, is_cdr=True, is_cdrnn=False):
"""
Given a settings object parsed from a config file, compute CDR parameter dictionary.
:param settings: settings from a ``ConfigParser`` object.
:param add_defaults: ``bool``; whether to add default settings not explicitly specified in the config.
:param global_settings: ``dict`` or ``None``; dictionary of global defaults for parameters missing from **settings**.
:param is_cdr: ``bool``; whether this is a CDR(NN) model.
:param is_cdrnn: ``bool``; whether this is a CDRNN model.
:return: ``dict``; dictionary of settings key-value pairs.
"""
if global_settings is None:
global_settings = {}
out = {}
# Core fields
out['formula'] = settings.get('formula', None)
if is_cdr and out['formula']:
# Standardize the model string
out['formula'] = str(Formula(out['formula']))
# Model initialization keyword arguments
if is_cdr:
# Allowing settings to propagate for the wrong model type if specified
# allows the global config to specify defaults for multiple model types.
# Cross-type settings will only propagate if they are explicitly defined
# in the config (defaults are ignored).
# General initialization keyword arguments
for kwarg in MODEL_INITIALIZATION_KWARGS:
if add_defaults:
if kwarg.in_settings(settings) or kwarg.key not in global_settings:
out[kwarg.key] = kwarg.kwarg_from_config(settings, is_cdrnn=is_cdrnn)
else:
out[kwarg.key] = global_settings[kwarg.key]
elif kwarg.in_settings(settings):
out[kwarg.key] = kwarg.kwarg_from_config(settings, is_cdrnn=is_cdrnn)
if kwarg.key == 'plot_interactions' and kwarg.key in out and isinstance(out[kwarg.key], str):
out[kwarg.key] = out[kwarg.key].split()
out['ablated'] = set()
# Ensembling settings
n_ensemble = settings.get('n_ensemble', global_settings.get('n_ensemble', None))
if isinstance(n_ensemble, str):
if n_ensemble.lower() == 'none':
n_ensemble = None
else:
n_ensemble = int(n_ensemble)
out['n_ensemble'] = n_ensemble
# Cross validation settings
out['crossval_factor'] = settings.get('crossval_factor', global_settings.get('crossval_factor', None))
crossval_folds = settings.get('crossval_folds', global_settings.get('crossval_folds', None))
if crossval_folds is None:
crossval_folds = []
elif isinstance(crossval_folds, str):
crossval_folds = crossval_folds.split()
assert not out['crossval_factor'] or crossval_folds, 'crossval_folds must also be provided when crossval_factor is used.'
if out['crossval_factor']:
assert len(crossval_folds) > 1, 'Must have at least 2 folds for crossval'
out['crossval_folds'] = crossval_folds
out['crossval_use_dev_fold'] = settings.get(
'crossval_use_dev_fold',
global_settings.get('crossval_use_dev_fold', False)
)
if out['crossval_use_dev_fold']:
assert (len(crossval_folds)) > 2, 'Must have at least 3 folds when ``crossval_use_dev_fold`` is ``True``.'
out['crossval_fold'] = None
out['crossval_dev_fold'] = None
return out
[docs]
def expand_submodels(self):
"""
Expand models into cross-validation folds and/or ensembles.
return: ``None``
"""
model_names = self.model_names.copy()
for model_name in model_names:
self.set_model(model_name)
model_config = self.models[model_name]
crossval_factor = self['crossval_factor']
crossval_folds = self['crossval_folds']
crossval_use_dev_fold = self['crossval_use_dev_fold']
n_ensemble = self['n_ensemble']
if crossval_factor:
crossval_dev_folds = crossval_folds[-1:] + crossval_folds[:-1]
del self.models[model_name]
del self.ensembles[model_name]
self.crossval_families[model_name] = model_config
for fold, dev_fold in zip(crossval_folds, crossval_dev_folds):
_model_name = model_name + '.CV%s~%s' % (crossval_factor, fold)
_model_config = model_config.copy()
_model_config['crossval_fold'] = fold
if crossval_use_dev_fold:
_model_config['crossval_dev_fold'] = dev_fold
self.ensembles[_model_name] = _model_config
if n_ensemble is None:
self.models[_model_name] = _model_config
else:
for i in range(n_ensemble):
__model_name = _model_name + '.m%d' % i
self.models[__model_name] = _model_config
else:
if n_ensemble is not None:
del self.models[model_name]
for i in range(n_ensemble):
_model_name = model_name + '.m%d' % i
self.models[_model_name] = model_config
[docs]
class PlotConfig(object):
"""
Parses an \*.ini file and stores settings needed to define CDR plots
:param path: Path to \*.ini file
"""
def __init__(self, path=None):
if path is None:
self.settings_core, self.settings_other = self.build_plot_settings({})
else:
config = configparser.ConfigParser()
config.optionxform = str
assert os.path.exists(path), 'Config file %s does not exist' %path
config.read(path)
plot_settings = config['plot']
self.settings_core, self.settings_other = self.build_plot_settings(plot_settings)
def __getitem__(self, item):
if item in self.settings_core:
return self.settings_core[item]
return self.settings_other[item]
def __setitem__(self, key, value):
if key in PLOT_KEYS_CORE:
self.settings_core[key] = value
elif key in PLOT_KEYS_OTHER:
self.settings_other[key] = value
else:
raise ValueError('Attempted to set value for unrecognized plot kwarg %s' % key)
def get(self, item, default=None):
if item in self.settings_core:
return self.settings_core.get(item, default)
return self.settings_other.get(item, default)
[docs]
def build_plot_settings(self, settings):
"""
Given a settings object parsed from a config file, compute plot parameters.
:param settings: settings from a ``ConfigParser`` object.
:return: ``dict``; dictionary of settings key-value pairs.
"""
out_core = {}
out_other = {}
for kwarg in PLOT_KWARGS_CORE:
if kwarg.in_settings(settings):
val = kwarg.kwarg_from_config(settings)
if kwarg.key in ['responses', 'response_params', 'pred_names'] and val is not None:
val = val.split()
elif kwarg.key == 'prop_cycle_map' and val is not None:
val = val.split()
is_dict = len(val[-1].split(';')) == 2
if is_dict:
val_tmp = val
val = {}
for x in val_tmp:
k, v = x.split(';')
val[k] = int(v)
else:
val = [int(x) for x in val]
elif kwarg.key == 'ylim' and val is not None:
val = tuple(float(x) for x in val.split())
out_core[kwarg.key] = val
else:
out_core[kwarg.key] = kwarg.default_value
for kwarg in PLOT_KWARGS_OTHER:
if kwarg.in_settings(settings):
val = kwarg.kwarg_from_config(settings)
out_other[kwarg.key] = val
else:
out_other[kwarg.key] = kwarg.default_value
return out_core, out_other