Function overwrite
Synopsis
def overwrite(main_config_obj, args)
Description
Overwrites parameters with input flags
Args: main_config_obj (ConfigClass): config instance; args (dict): arguments used to overwrite
Returns: ConfigClass: config instance
Source
Lines 96-175 in anyfig/anyfig_setup.py.
def overwrite(main_config_obj, args):
    """
    Overwrites parameters with input flags

    Arguments are applied shallowest-first (fewest '.' separators), so that
    replacing a whole nested config happens before any of that nested
    config's own attributes are overridden.

    Args:
        main_config_obj (ConfigClass): config instance
        args (dict): arguments used to overwrite. Keys are attribute paths,
            optionally dotted to reach nested configs (e.g. 'outer.inner');
            values are the new values to convert and assign.

    Returns:
        ConfigClass: config instance (the same object, modified in place)

    Raises:
        AssertionError: if a path component is not an attribute of the
            config being walked, or an outer component is not an Anyfig
            config class.
        RuntimeError: if the new value can't be converted to the type
            expected for that attribute.
    """
    # Sort on nested level to override shallow items first
    args = dict(sorted(args.items(), key=lambda item: item[0].count('.')))
    for argument_key, val in args.items():
        # Separate nested keys into outer and inner
        key_parts = argument_key.split('.')
        outer_keys, inner_key = key_parts[:-1], key_parts[-1]
        base_err_msg = f"Can't set '{argument_key} = {val}'"

        # Walk down the nested configs, checking every step of the path
        config_obj = main_config_obj
        config_class = type(config_obj).__name__
        for key_idx, key_part in enumerate(key_parts):
            err_msg = f"{base_err_msg}. '{key_part}' isn't an attribute in '{config_class}'"
            assert hasattr(config_obj, key_part), err_msg

            # Check if the config allows the argument
            figutils.check_allowed_input_argument(config_obj, key_part, argument_key)

            # Outer attributes must themselves be config classes to descend into
            if key_idx < len(outer_keys):
                config_obj = getattr(config_obj, key_part)
                config_class = type(config_obj).__name__
                # Report the path up to the offending component, not the full outer path
                outer_path = '.'.join(outer_keys[:key_idx + 1])
                err_msg = f"{base_err_msg}. '{outer_path}' isn't a registered Anyfig config class"
                assert figutils.is_config_class(config_obj), err_msg

        current_value = getattr(config_obj, inner_key)
        value_class = type(current_value)
        base_err_msg = f"Input argument '{argument_key}' with value {val} can't create an object of the expected type"

        # Create new anyfig class object
        if figutils.is_config_class(value_class):
            value_obj = create_config(val)

        # Create new object that follows the InterfaceField's rules
        elif issubclass(value_class, fields.InterfaceField):
            field = current_value
            # Input fields convert via their declared type pattern; other fields
            # reuse the type of their current value. The check must be on the
            # field instance -- `value_class` is a type and never an instance.
            if isinstance(field, fields.InputField):
                value_class = field.type_pattern
            else:
                value_class = type(field.value)
            try:
                val = value_class(val)
            except Exception as e:
                err_msg = f"{base_err_msg} {field.type_pattern}. {e}"
                raise RuntimeError(err_msg) from None
            field = field.update_value(inner_key, val, config_class)
            value_obj = field.finish_wrapping_phase(inner_key, config_class)

        # Create new object of previous value type with new value
        else:
            try:
                if isinstance(val, dict):  # Keyword specified cli-arguments
                    value_obj = value_class(**val)
                else:
                    value_obj = value_class(val)
            except Exception as e:
                err_msg = f"{base_err_msg} {value_class}. {e}"
                raise RuntimeError(err_msg) from None

        # Overwrite old value
        setattr(config_obj, inner_key, value_obj)
    return main_config_obj