Function overwrite

Synopsis

def overwrite(main_config_obj, args)

Description

Overwrites parameters with input flags

Args:
  main_config_obj (ConfigClass): config instance
  args (dict): arguments used to overwrite

Returns: ConfigClass: config instance

Source

Lines 96-175 in anyfig/anyfig_setup.py.

def overwrite(main_config_obj, args):
  """
  Overwrites parameters with input flags

  Each key in ``args`` is a dot-separated attribute path (e.g. ``'a.b.c'``)
  into ``main_config_obj``. Every intermediate attribute on the path must be
  a registered Anyfig config class. The final attribute is replaced by a new
  object built from the input value and the type of the attribute's current
  value.

  Args:
      main_config_obj (ConfigClass): config instance
      args (dict): arguments used to overwrite. Keys are dot-separated
          attribute paths, values are the new (raw) values

  Returns:
      ConfigClass: the same config instance, modified in place

  Raises:
      AssertionError: if an attribute on the path doesn't exist, or an
          intermediate attribute isn't a registered config class
      RuntimeError: if the input value can't be converted to the expected type
  """

  # Sort on nested level to override shallow items first. A shallow argument
  # (e.g. 'a') may replace a whole sub-config, which would discard any deeper
  # override (e.g. 'a.b') applied before it.
  args = dict(sorted(args.items(), key=lambda item: item[0].count('.')))
  for argument_key, val in args.items():
    # Separate nested keys into outer and inner
    key_parts = argument_key.split('.')
    outer_keys = key_parts[:-1]
    inner_key = key_parts[-1]
    base_err_msg = f"Can't set '{argument_key} = {val}'"

    # Walk the path. Check that every attribute exists and is allowed as an
    # input argument, and that every outer attribute is a config class.
    config_obj = main_config_obj
    config_class = type(config_obj).__name__

    for key_idx, key_part in enumerate(key_parts):
      err_msg = f"{base_err_msg}. '{key_part}' isn't an attribute in '{config_class}'"
      assert hasattr(config_obj, key_part), err_msg

      # Check if the config allows the argument
      figutils.check_allowed_input_argument(config_obj, key_part, argument_key)

      # Descend into outer attributes, which must be config classes
      if key_idx < len(outer_keys):
        config_obj = getattr(config_obj, key_part)
        config_class = type(config_obj).__name__
        # Report the path up to the offending attribute, not the whole outer path
        failing_path = '.'.join(key_parts[:key_idx + 1])
        err_msg = f"{base_err_msg}. '{failing_path}' isn't a registered Anyfig config class"
        assert figutils.is_config_class(config_obj), err_msg

    old_value = getattr(config_obj, inner_key)
    value_class = type(old_value)
    base_err_msg = f"Input argument '{argument_key}' with value {val} can't create an object of the expected type"

    # Create new anyfig class object
    if figutils.is_config_class(value_class):
      value_obj = create_config(val)

    # Create new object that follows the InterfaceField's rules
    elif issubclass(value_class, fields.InterfaceField):
      field = old_value

      # Check the field *instance*. value_class is a type object, so
      # isinstance(value_class, InputField) was always False and the
      # type_pattern branch could never run.
      if isinstance(field, fields.InputField):
        value_class = field.type_pattern
      else:
        value_class = type(field.value)

      try:
        val = value_class(val)
      except Exception as e:
        # Report the type actually attempted. Non-InputField fields may not
        # have a type_pattern attribute at all.
        err_msg = f"{base_err_msg} {value_class}. {e}"
        raise RuntimeError(err_msg) from None
      field = field.update_value(inner_key, val, config_class)
      value_obj = field.finish_wrapping_phase(inner_key, config_class)

    # Create new object of previous value type with new value
    else:
      try:
        if isinstance(val, dict):  # Keyword specified cli-arguments
          value_obj = value_class(**val)
        else:
          value_obj = value_class(val)

      except Exception as e:
        err_msg = f"{base_err_msg} {value_class}. {e}"
        raise RuntimeError(err_msg) from None

    # Overwrite old value
    setattr(config_obj, inner_key, value_obj)

  return main_config_obj







Add Discussion as Guest

Log in