Source code for digideep.pipeline.runner

import gc
import sys
import time
import signal

from digideep.environment import Explorer
from digideep.utility.logging import logger
from digideep.utility.toolbox import seed_all, get_class, get_module, set_rng_state, get_rng_state
from digideep.utility.profiling import profiler, KeepTime
from digideep.utility.monitoring import monitor
from collections import OrderedDict as odict

# Runner should be independent of torch, gym, dm_control, etc.

class Runner:
    """This class controls the main flow of the program. The main components of the class are:

    * explorer: A dictionary containing :class:`~digideep.environment.explorer.Explorer` instances
      for the three modes of ``train``, ``test``, and ``eval``. An
      :class:`~digideep.environment.explorer.Explorer` is a class which handles running simulations
      concurrently in several environments.
    * memory: The component responsible for storing the trajectories generated by the explorers.
    * agents: A dictionary containing all agents in the environment.

    This class also prints the :class:`~digideep.utility.profiling.Profiler` and
    :class:`~digideep.utility.monitoring.Monitor` information. The main serialization burden is
    also on this class; the rest of the classes only need to implement the ``state_dict`` and
    ``load_state_dict`` functions for serialization.

    Caution:
        The lines of code for testing while training are commented out.
    """
    def __init__(self, params):
        self.params = params
        self.state = {}
        self.state["i_frame"] = 0
        # self.state["i_rolls"] = 0
        self.state["i_cycle"] = 0
        self.state["i_epoch"] = 0
        self.state["loading"] = False
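    # A minimal usage sketch (illustrative only; how `params` and `session` are
    # built is up to the surrounding pipeline, not this class):
    #
    #     runner = Runner(params)     # params: a nested dict of parameters
    #     runner.start(session)       # instantiate components, load states if resuming
    #     runner.train()              # or runner.enjoy() for evaluation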
    def lazy_connect_signal(self):
        # Connect shell signals.
        signal.signal(signal.SIGUSR1, self.on_sigusr1_received)
        signal.signal(signal.SIGINT, self.on_sigint_received)
    def on_sigint_received(self, signalNumber, frame):
        print("")  # To print on the next line, after where ^C is printed.
        self.ctrl_c_count += 1
        if self.ctrl_c_count == 1:
            logger.fatal("Received CTRL+C. Will terminate process after cycle is over.")
            logger.fatal("Press CTRL+C one more time to exit without saving.")
            self.ready_for_termination = True
            self.save_major_checkpoint = True
        elif self.ctrl_c_count == 2:
            # NOTE: Kill all subprocesses
            logger.fatal("Received CTRL+C for the second time. Will terminate immediately.")
            self.ready_for_termination = True
            self.save_major_checkpoint = False
            sys.exit(1)
    def on_sigusr1_received(self, signalNumber, frame):
        logger.fatal("Received SIGUSR1 signal. Will terminate process after cycle is over.")
        self.ready_for_termination = True
        self.save_major_checkpoint = True
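    # For reference, SIGUSR1 can be sent from a shell to request a graceful
    # stop (a major checkpoint is saved once the current cycle finishes):
    #
    #     kill -USR1 <pid-of-main-process>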
    def lazy_init(self):
        """Initialization of attributes which are not part of the object state.
        These need lazy initialization so that they are initialized properly
        when loading from a checkpoint.
        """
        self.time_start = time.time()
        logger.fatal("Execution (max) timer started ...")
        self.save_major_checkpoint = False
        self.ready_for_termination = False
        self.iterations = 0
        profiler.reset()
        monitor.reset()
        self.monitor_epoch()
        # Ignore interrupt signals for subprocesses.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        self.ctrl_c_count = 0
    def monitor_epoch(self):
        monitor.set_meta_key("epoch", self.state["i_epoch"])
    def start(self, session):
        """A function to initialize the objects and load their states (if loading from a checkpoint).
        This function must be called before using the :func:`train` and :func:`enjoy` functions.

        If we are starting from scratch, we will:

        * Instantiate all internal components using the parameters.

        If we are loading from a saved checkpoint, we will:

        * Instantiate all internal components using the old parameters.
        * Load all state dicts.
        * (OPTIONAL) Override parameters.
        """
        # Up to now, the states of the runner itself are already loaded. The objects' states, however, are not.
        self.lazy_init()
        self.session = session
        # TODO: Load random states to continue random number generation from where it paused.
        seed_all(**self.params["runner"]["randargs"])
        # The order here matters:
        self.instantiate()
        self.load()
        self.override()
        # NOTE: We connect signals lazily so the handlers are not inherited by the child processes.
        self.lazy_connect_signal()
        # NOTE: We set this state for the future: all future loadings
        #       will involve actual loading of states.
        self.state["loading"] = True
    def instantiate(self):
        """This function will instantiate the memory, the explorers, and the agents with their specific parameters."""
        ## Instantiate Memory
        self.memory = {}
        for memory_name in self.params["memory"]:
            memory_class = get_class(self.params["memory"][memory_name]["type"])
            self.memory[memory_name] = memory_class(self.session,
                                                    mode=memory_name,
                                                    **self.params["memory"][memory_name]["args"])

        ## Instantiate Agents
        self.agents = {}
        for agent_name in self.params["agents"]:
            agent_class = get_class(self.params["agents"][agent_name]["type"])
            self.agents[agent_name] = agent_class(self.session, self.memory, **self.params["agents"][agent_name])

        ## Instantiate Explorers
        # All explorers: train/test/eval
        explorer_list = list(self.params["explorer"].keys())
        assert "train" in explorer_list, "'train' mode explorer is not defined in the explorer parameters."
        assert "test" in explorer_list, "'test' mode explorer is not defined in the explorer parameters."
        assert "eval" in explorer_list, "'eval' mode explorer is not defined in the explorer parameters."

        self.explorer = {}
        explorer_list.remove("eval")
        for e in explorer_list:
            self.explorer[e] = Explorer(self.session, agents=self.agents, **self.params["explorer"][e])

        # "eval" must be created as the last explorer to avoid GLFW connection to X11 issues.
        # NOTE: Creation of the "eval" explorer is conditioned on the session being in play mode.
        #       This makes sure that no connections to X11 exist in the main thread during training.
        if self.session.is_playing:
            self.explorer["eval"] = Explorer(self.session, agents=self.agents, **self.params["explorer"]["eval"])
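    # The shape of `params` that instantiate() expects can be read off the
    # accesses above; schematically (all values are placeholders):
    #
    #     params = {
    #         "memory":   {"train": {"type": "<module:Class>", "args": {...}}, ...},
    #         "agents":   {"agent1": {"type": "<module:Class>", ...}, ...},
    #         "explorer": {"train": {...}, "test": {...}, "eval": {...}},
    #         "runner":   {...},   # consumed by train()/save()/test()/termination_check()
    #     }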
    ###############################################################
    ### SERIALIZATION ###
    #####################
    def state_dict(self):
        """This function will return the states of all internal objects:

        * Agents
        * Explorers (all modes; previously only ``train``)
        * Memory

        Todo:
            Memory should be dumped in a separate file, since it can get really large.
            Moreover, dumping it should be optional.
        """
        random_state = get_rng_state()

        agents_state = {}
        for agent_name in self.agents:
            agents_state[agent_name] = self.agents[agent_name].state_dict()

        explorer_state = {}
        for explorer_name in self.explorer:
            ## We used to save only the states of the "train" explorer;
            ## the states of explorer["test"] and explorer["eval"] were not important for us:
            # if not explorer_name in ["train"]:
            #     continue
            # OLD LOGIC:
            # if explorer_name in ["test", "eval"]:
            #     continue
            explorer_state[explorer_name] = self.explorer[explorer_name].state_dict()

        memory_state = {}
        for memory_name in self.memory:
            memory_state[memory_name] = self.memory[memory_name].state_dict()

        return {'random_state':random_state, 'agents':agents_state, 'explorer':explorer_state, 'memory':memory_state}
    def load_state_dict(self, state_dict):
        """This function will load the states of the internal objects:

        * Agents
        * Explorers (all modes except ``eval``)
        * Memory
        """
        random_state = state_dict['random_state']
        set_rng_state(random_state)

        agents_state = state_dict['agents']
        for agent_name in agents_state:
            self.agents[agent_name].load_state_dict(agents_state[agent_name])

        explorer_state = state_dict['explorer']
        for explorer_name in explorer_state:
            if explorer_name == "eval":
                continue
            self.explorer[explorer_name].load_state_dict(explorer_state[explorer_name])

        memory_state = state_dict['memory']
        for memory_name in memory_state:
            self.memory[memory_name].load_state_dict(memory_state[memory_name])
        ## NEWER LOGIC
        # We intentionally update the states of the test/eval explorers with the state of the
        # "train" explorer. We are only interested in the states of the reward/observation normalizers.
        # for explorer_name in self.explorer:
        #     ## NOTE: Which environments must we reset?
        #     self.explorer[explorer_name].reset()
        #     # if explorer_name in explorer_state:
        #     #     continue
        #     logger.warn("Loading explorer '{}' states from 'train'.".format(explorer_name))
        #     self._sync_normalizations(source_explorer="train", target_explorer=explorer_name)

        ## OLD LOGIC
        # # We intentionally update the states of the test/eval explorers with the state of the
        # # "train" explorer. We are only interested in the states of the reward/observation normalizers.
        # self._sync_normalizations(source_explorer="train", target_explorer="test")
        # self._sync_normalizations(source_explorer="train", target_explorer="eval")
        # # self.explorer["test"].load_state_dict(self.explorer["train"].state_dict())
        # # self.explorer["eval"].load_state_dict(self.explorer["train"].state_dict())
        # self.explorer["test"].reset()
        # self.explorer["eval"].reset()
        ###
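    # Sketch of a state round trip (illustrative; in the pipeline the actual
    # file I/O is done by session.save_states/load_states with torch.save,
    # per the SAVE RUNNER notes below):
    #
    #     state = runner.state_dict()          # nested plain dict
    #     ...                                  # persisted as states.pt
    #     runner.load_state_dict(state)        # restores agents/explorers/memory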
    def override(self):
        pass
    #####################
    ###  SAVE RUNNER  ###
    #####################
    # UPON SAVING/LOADING THE RUNNER WITH THE SELF.SAVE FUNCTION:
    # * save --> self.state_dict --> session.save_states --> torch.save --> states.pt
    #        |-> session.save_runner --> self.__getstate__ --> pickle.dump --> runner.pt
    # * pickle.load --> __setstate__
    #   ... Later on ...
    #   --> self.start --> self.instantiate --> self.load --> session.load_states
    #                                       --> self.load_state_dict --> self.override
    #
    # The __setstate__ and __getstate__ functions are used for loading/saving the "runner"
    # through pickle.dump / pickle.load.
    def __getstate__(self):
        """This function is used by ``pickle.dump`` when we save the :class:`Runner`.
        This saves the ``params`` and ``state`` of the runner.
        """
        # This is at the time of pickling.
        state = {'params':self.params, 'state':self.state}
        return state
    def __setstate__(self, state):
        """This function is used by ``pickle.load`` when we load the :class:`Runner`."""
        # state['state']['loading'] = True
        self.__dict__.update(state)
    ###
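    # Illustrative pickle round trip following the flow charted above
    # (the file name "runner.pt" is taken from those notes; paths are schematic):
    #
    #     import pickle
    #     with open("runner.pt", "wb") as f:
    #         pickle.dump(runner, f)           # -> __getstate__: params + state only
    #     with open("runner.pt", "rb") as f:
    #         runner = pickle.load(f)          # -> __setstate__
    #     runner.start(session)                # re-instantiates components, loads states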
    def save_final_checkpoint(self):
        self.save(forced=True)
        # Store snapshots for all memories only if the simulation ended gracefully.
        for memory_name in self.memory:
            if hasattr(self.memory[memory_name], "save_snapshot"):
                self.memory[memory_name].save_snapshot(self.state["i_epoch"])
    def save(self, forced=False):
        """This is a high-level function for saving both the states of the objects and the runner
        object itself. It uses helper functions from :class:`~digideep.pipeline.session.Session`.
        """
        if forced or (self.state["i_epoch"] % self.params["runner"]["save_int"] == 0):
            ## 1. state_dict: saved with torch.save
            self.session.save_states(self.state_dict(), self.state["i_epoch"])
            ## 2. runner: saved with pickle.dump
            self.session.save_runner(self, self.state["i_epoch"])
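    # For example, with save_int = 10 the modulo test above fires at epochs
    # 0, 10, 20, ...; forced=True (used by save_final_checkpoint) bypasses
    # the interval entirely.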
    def load(self):
        """This function is used by the :func:`start` function to load the states of the internal
        objects from the checkpoint and update the objects' state dicts. It does not directly
        work with files.
        """
        if self.state["loading"]:
            state_dict = self.session.load_states()
            self.load_state_dict(state_dict)
            self.load_memory()
            # We leave loading = True. All future loadings will be either resume or play.
    def load_memory(self):
        if self.session.is_resumed:
            for memory_name in self.memory:
                if hasattr(self.memory[memory_name], "load_snapshot"):
                    self.memory[memory_name].load_snapshot()
###############################################################
    def train_cycle(self):
        # 1. Do Experiment
        with KeepTime("train"):
            chunk = self.explorer["train"].update()
        # 2. Store Result
        with KeepTime("store"):
            self.memory["train"].store(chunk)
        # 3. Update Agent
        with KeepTime("update"):
            for agent_name in self.agents:
                with KeepTime(agent_name):
                    self.agents[agent_name].update()
    def train(self):
        """The function that runs the training loop.

        See Also:
            :ref:`ref-how-runner-works`
        """
        try:
            # while self.state["i_epoch"] < self.state["n_epochs"]:
            while (self.state["i_epoch"] < self.params["runner"]["n_epochs"]) and not self.termination_check():
                self.state["i_cycle"] = 0
                while self.state["i_cycle"] < self.params["runner"]["n_cycles"]:
                    with KeepTime("/"):
                        self.train_cycle()
                    self.state["i_cycle"] += 1
                # End of Cycle
                self.state["i_epoch"] += 1
                self.monitor_epoch()
                self.iterations += 1

                # NOTE: We may save/test after each cycle or at intervals.
                # 1. Perform the test
                self.test()
                # 2. Log
                self.log()
                # 3. Save
                self.save()

                # Free up memory from garbage.
                gc.collect()  # Garbage Collection
        except (KeyboardInterrupt, SystemExit):
            logger.fatal('Operation stopped by the user ...')
        finally:
            self.finalize()
    def termination_check(self):
        termination = self.ready_for_termination
        if self.params["runner"]["max_time"]:
            if time.time() - self.time_start >= self.params["runner"]["max_time"] * 3600:
                self.save_major_checkpoint = True
                termination = True
                logger.fatal('Simulation exceeded the maximum allowed execution time ...')
        if self.params["runner"]["max_iter"]:
            # TODO: Should be current_epoch - initial_epoch >= max_iter: ...
            if self.iterations >= self.params["runner"]["max_iter"]:
                self.save_major_checkpoint = True
                termination = True
                logger.fatal('Simulation exceeded the maximum allowed number of iterations ...')
        return termination
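    # The "runner" parameters consumed by the loop above, schematically
    # (names taken from the accesses in this file; values are examples only):
    #
    #     params["runner"] = {
    #         "n_epochs": 1000, "n_cycles": 10,   # train() loop bounds
    #         "save_int": 10,                     # save() interval (epochs)
    #         "test_act": True, "test_int": 10,   # test() activation/interval
    #         "max_time": 24,                     # hours; see termination_check()
    #         "max_iter": None,                   # iterations since start; optional
    #         "randargs": {...},                  # forwarded to seed_all()
    #     }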
    def finalize(self, save=True):
        logger.fatal('End of operation ...')
        # Mark the session as done if we have gone through all epochs.
        # if self.state["i_epoch"] == self.state["n_epochs"]:
        if self.state["i_epoch"] == self.params["runner"]["n_epochs"]:
            self.session.mark_as_done()
            self.save_major_checkpoint = True
        if save and self.save_major_checkpoint:
            self.save_final_checkpoint()
            # self.save_major_checkpoint = False
        # Close all explorers gracefully:
        for key in self.explorer:
            self.explorer[key].close()
    def test(self):
        if self.params["runner"]["test_act"]:
            if self.state["i_epoch"] % self.params["runner"]["test_int"] == 0:
                with KeepTime("/"):
                    with KeepTime("test"):
                        # Make the environment states of the train/test explorers exactly the same.
                        self._sync_normalizations(source_explorer="train", target_explorer="test")
                        # self.explorer["test"].load_state_dict(self.explorer["train"].state_dict())
                        self.explorer["test"].reset()
                        # TODO: Update until "win_size" episodes get executed, i.e. until
                        #       self.explorer["test"].state["n_episode"] reaches "win_size".
                        #       Make sure that n_steps is 1. If num_worker > 1 it is possible
                        #       that we get more than the required test episodes; the rest will
                        #       be reported with the next test run.
                        self.explorer["test"].update()
    def enjoy(self):  # i.e. eval
        """This function evaluates the current policy in the environment. It only runs the explorer in a loop.

        .. code-block:: python

            # Do a cycle
            while not done:
                # Explore
                explorer["eval"].update()
                log()
        """
        # TODO: We need a more elegant mechanism to handle this import.
        import glfw
        glfw.init()

        try:
            self._sync_normalizations(source_explorer="train", target_explorer="eval")
            self.explorer["eval"].reset()
            while True:
                # Cycles
                self.state["i_cycle"] = 0
                while self.state["i_cycle"] < self.params["runner"]["n_cycles"]:
                    with KeepTime("/"):
                        # 1. Do Experiment
                        with KeepTime("eval"):
                            self.explorer["eval"].update()
                        # 2. Log
                        self.log()
                    self.state["i_cycle"] += 1
        except (KeyboardInterrupt, SystemExit):
            logger.fatal('Operation stopped by the user ...')
        finally:
            self.finalize(save=False)
    def custom(self):
        raise NotImplementedError()
    #####################
    def _sync_normalizations(self, source_explorer, target_explorer):
        state_dict = self.explorer[source_explorer].state_dict()
        # digideep.environment.wrappers.normalizers:VecNormalizeObsDict      Observation normalization states
        # digideep.environment.wrappers.normalizers:VecNormalizeRew          Reward normalization states
        # digideep.environment.wrappers.random_state:VecRandomState          Random generator states
        # digideep.environment.common.vec_env.subproc_vec_env:SubprocVecEnv  Physical states
        keys = ["digideep.environment.wrappers.normalizers:VecNormalizeObsDict",
                "digideep.environment.wrappers.normalizers:VecNormalizeRew"]
        state_dict_mod = {}
        for k in keys:
            if k in state_dict["envs"]:
                state_dict_mod[k] = state_dict["envs"][k]
        self.explorer[target_explorer].envs.load_state_dict(state_dict_mod)

    #####################
    ## Logging Summary ##
    #####################
    def log(self):
        """The log function prints a summary of:

        * Frame rate and simulated frames.
        * Variables sent to the :class:`~digideep.utility.monitoring.Monitor`.
        * Profiling information, i.e. the timings registered in the :class:`~digideep.utility.profiling.Profiler`.
        """
        # monitor.get_meta_key("frame")
        # monitor.get_meta_key("episode")
        # monitor.get_meta_key("epoch")
        frame = monitor.get_meta_key("frame")
        episode = monitor.get_meta_key("episode")

        n_frame = frame - self.state["i_frame"]
        self.state["i_frame"] = frame

        elapsed = profiler.get_time_overall("/")
        overall = int(n_frame / elapsed)

        logger("---------------------------------------------------------")
        logger("Epoch({cycle:3d}cy)={epoch:4d} | Frame={frame:4.1e} | Episodes={episode:4.1e} | Overall({n_frame:4.1e}F/{e_time:4.1f}s)={freq:4d}Hz".format(
                cycle=self.params["runner"]["n_cycles"],
                epoch=self.state["i_epoch"],
                frame=frame,
                episode=episode,
                n_frame=n_frame,
                e_time=elapsed,
                freq=overall))

        # Print monitoring information:
        logger("MONITORING:\n" + str(monitor))
        monitor.dump()
        monitor.reset()

        # Print profiling information:
        logger("PROFILING:\n" + str(profiler))
        meta = odict({"epoch":self.state["i_epoch"], "frame":frame, "episode":episode})
        profiler.dump(meta)
        profiler.reset()

        print("")