Source code for digideep.utility.toolbox

import numpy as np
import torch, random
from digideep.utility.logging import logger
import inspect
import os
import json, yaml
import copy

def dump_dict_as_json(filename, dic, sort_keys=False):
    """
    This function dumps a python dictionary in ``json`` format to a file.

    Args:
        filename (path): The address of the file.
        dic (dict): The dictionary to be dumped in json format. It should be json-serializable.
        sort_keys (bool, False): Will sort the dictionary by its keys before dumping to the file.
    """
    f = open(filename, 'w')
    f.write(json.dumps(dic, indent=2, sort_keys=sort_keys))
    f.close()
def load_json_as_dict(filename):
    f = open(filename, 'r')
    try:
        dic = json.load(f)
    except json.JSONDecodeError as exc:
        print(exc)
        dic = {}
    f.close()
    return dic
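# Example (an illustrative sketch, not part of the original module; the path
# and dictionary below are hypothetical):
#
#   >>> dump_dict_as_json("/tmp/params.json", {"lr": 0.001, "gamma": 0.99}, sort_keys=True)
#   >>> load_json_as_dict("/tmp/params.json")
#   {'gamma': 0.99, 'lr': 0.001}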
def dump_dict_as_yaml(filename, dic):
    f = open(filename, 'w')
    f.write(yaml.dump(dic, indent=2))
    f.close()
def load_yaml_as_dict(filename):
    f = open(filename, 'r')
    try:
        # TODO: PyYAML can become faster by using CLoader (can we replace UnsafeLoader safely with it?):
        #   from yaml import CLoader as Loader, CDumper as Dumper
        #   https://stackoverflow.com/questions/27743711/can-i-speedup-yaml
        #   https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
        # dic = yaml.load(f, Loader=yaml.FullLoader)
        dic = yaml.load(f, Loader=yaml.UnsafeLoader)
    except yaml.YAMLError as exc:
        print(exc)
        dic = {}
    f.close()
    return dic
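# Example (illustrative sketch; path and content are hypothetical). Note that
# ``yaml.UnsafeLoader`` can construct arbitrary Python objects, so files from
# untrusted sources should rather be loaded with ``yaml.SafeLoader``:
#
#   >>> dump_dict_as_yaml("/tmp/params.yaml", {"lr": 0.001, "steps": 1000})
#   >>> load_yaml_as_dict("/tmp/params.yaml")
#   {'lr': 0.001, 'steps': 1000}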
def seed_all(seed, cuda_deterministic=False):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if cuda_deterministic and torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
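# Example (illustrative sketch): seeding all RNGs makes NumPy/torch sampling
# reproducible across runs; ``cuda_deterministic=True`` additionally forces
# deterministic cuDNN kernels at some performance cost.
#
#   >>> seed_all(42)
#   >>> np.random.rand()  # same value after every call to seed_all(42)
#   0.3745401188473625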
def set_rng_state(states):
    np.random.set_state(states['numpy'])
    random.setstate(states['random'])
    torch_states = torch.tensor(states['torch'])
    torch_cuda_states = [torch.tensor(s) for s in states['torch_cuda']]
    torch.set_rng_state(torch_states)
    torch.cuda.set_rng_state_all(torch_cuda_states)
    # logger.fatal(">> 2 rands after load:", np.random.rand(), np.random.rand())
def get_rng_state():
    states = {}
    states['numpy'] = np.random.get_state()
    states['random'] = random.getstate()
    states['torch'] = torch.get_rng_state().numpy()
    states['torch_cuda'] = [s.numpy() for s in torch.cuda.get_rng_state_all()]
    # logger.fatal(">> 2 rands after save:", np.random.rand(), np.random.rand())
    return states
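# Example (illustrative sketch): get_rng_state/set_rng_state snapshot and later
# restore the global RNG states, e.g. when checkpointing a training session, so
# that sampling resumes from exactly the same point.
#
#   >>> snapshot = get_rng_state()
#   >>> a = np.random.rand()
#   >>> set_rng_state(snapshot)
#   >>> b = np.random.rand()
#   >>> a == b
#   True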
def get_module(addr):
    """
    Return a module by using only its name.

    Args:
        addr (str): The name of the module, which should be in the format
            ``MODULENAME[.SUBMODULE1[.SUBMODULE2[...]]]``
    """
    parts = addr.split('.')
    module = ".".join(parts)
    m = __import__(module)
    for comp in parts[1:]:
        m = getattr(m, comp)
    return m
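# Example (illustrative sketch): the returned object is the (sub)module itself.
#
#   >>> m = get_module("os.path")
#   >>> m.join("a", "b")
#   'a/b'  # on a POSIX system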
def get_class(addr):
    """
    Return a class (or function) by using only its name.

    Args:
        addr (str): The name of the class/function, which should be in the format
            ``MODULENAME[.SUBMODULE1[.SUBMODULE2[...]]].CLASSNAME``
    """
    parts = addr.split('.')
    module = ".".join(parts[:-1])
    m = __import__(module)
    for comp in parts[1:]:
        m = getattr(m, comp)
    return m
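# Example (illustrative sketch): get_class resolves a dotted path to the class
# (or function) object itself, which can then be instantiated or called.
#
#   >>> opt_class = get_class("torch.optim.Adam")
#   >>> opt_class is torch.optim.Adam
#   True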
def count_parameters(model):
    """
    Counts the number of trainable parameters in a PyTorch model.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
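# Example (illustrative sketch): only parameters with ``requires_grad=True``
# are counted, so frozen layers do not contribute to the total.
#
#   >>> model = torch.nn.Linear(10, 2)  # 10*2 weights + 2 biases
#   >>> count_parameters(model)
#   22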
# def match_key(dict_target, dict_source, key, default):
#     if key in dict_source:
#         dict_target[key] = dict_source[key]
#         del dict_source[key]
#     else:
#         dict_target[key] = default
def strict_update(dict_target, dict_source):
    result = copy.deepcopy(dict_target)
    for key in dict_source:
        if key not in dict_target:
            logger.warn("The provided parameter '{}' was not available in the source dictionary.".format(key))
            # continue
        result[key] = dict_source[key]
    return result
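# Example (illustrative sketch): strict_update overrides the entries of
# ``dict_target`` with those of ``dict_source``, warning about any key that
# does not already exist in ``dict_target``, and returns a new dictionary
# without mutating either argument.
#
#   >>> defaults = {"lr": 0.001, "gamma": 0.99}
#   >>> strict_update(defaults, {"lr": 0.01})
#   {'lr': 0.01, 'gamma': 0.99}
#   >>> defaults
#   {'lr': 0.001, 'gamma': 0.99}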