Source code for digideep.environment.make_environment

"""
This module is inspired by `pytorch-a2c-ppo-acktr <https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/>`_.
"""

import numpy as np
import os
from collections import OrderedDict

from digideep.utility.toolbox import get_module, get_class
from digideep.utility.logging import logger

################################################
## Vital wrappers ##
####################
from gym.wrappers.monitor import Monitor as MonitorVideoRecorder

from .common.atari_wrappers import make_atari, wrap_deepmind

# from .common.vec_env.vec_monitor import VecMonitor
from .common.monitor import Monitor

from .common.vec_env.subproc_vec_env import SubprocVecEnv
from .common.vec_env.dummy_vec_env import DummyVecEnv

## Our essential wrappers
from .wrappers.save_state import VecSaveState
from .wrappers.random_state import VecRandomState

from .wrappers.adapter import WrapperDummyMultiAgent
from .wrappers.adapter import WrapperDummyDictObs
from .wrappers.adapter import WrapperFlattenObsDict
from .wrappers.adapter import VecObsRewInfoActWrapper
################################################

from gym import spaces

################################################
###      Importing Environment Packages      ###
################################################
import gym
# Even though we don't need dm_control to be loaded here, it helps in initializing glfw when using MuJoCo 1.5.
# import digideep.environment.dmc2gym

from gym.envs.registration import registry
################################################



#############################
##### Utility Functions #####
#############################
def space2config(S):
    """Function to convert space's characteristics into a config-space dict.
    """
    # S.__class__.__name__: "Discrete" / "Box"
    if isinstance(S, spaces.Discrete):
        typ = S.__class__.__name__
        dim = np.int32(S.n)
        lim = (np.nan, np.nan) # Discrete spaces do not have high/low
        config = {"typ":typ, "dim":dim, "lim":lim}
    elif isinstance(S, spaces.Box):
        typ = S.__class__.__name__
        dim = S.shape # S.shape[0]: This "[0]" only supports 1d arrays.
        lim = (S.low.tolist(), S.high.tolist())
        config = {"typ":typ, "dim":dim, "lim":lim}
    elif isinstance(S, spaces.Dict):
        config = OrderedDict()
        for k in S.spaces:
            config[k] = space2config(S.spaces[k])
    else:
        logger.fatal("Unknown type for space:", type(S))
        raise NotImplementedError
    return config
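# Illustrative usage of space2config (a sketch, assuming the imports above; exact
# reprs may vary slightly between gym/numpy versions):
#
#     >>> box = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
#     >>> space2config(box)
#     {'typ': 'Box', 'dim': (3,), 'lim': ([-1.0, -1.0, -1.0], [1.0, 1.0, 1.0])}
#
#     >>> space2config(spaces.Dict({"agent": spaces.Discrete(4)}))
#     OrderedDict([('agent', {'typ': 'Discrete', 'dim': 4, 'lim': (nan, nan)})])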
#################################
##### MakeEnvironment Class #####
#################################
class MakeEnvironment:
    """This class makes the environment and applies the requested wrappers to it.

    Tip:
        Except for the :class:`~digideep.environment.common.monitor.Monitor` wrapper, no wrapper
        is applied to the environment unless explicitly specified.
    """
    registered = False

    def __init__(self, session, mode, seed, **params):
        self.mode = mode   # train/test/eval
        self.seed = seed
        self.session = session
        self.params = params

        # Won't we have several environment registrations by this?
        if params["from_module"]:
            try:
                get_module(params["from_module"])
            except Exception as ex:
                logger.fatal("While importing user module:", ex)
                exit()
        # elif (params["from_params"]) and (not MakeEnvironment.registered):
        elif (params["from_params"]) and (not params["name"] in registry.env_specs):
            try:
                registry.register(**params["register_args"])
                MakeEnvironment.registered = True
            except Exception as ex:
                logger.fatal("While registering from parameters:", ex)
                exit()

        # After all of this, check whether the environment is registered in gym or not.
        if not params["name"] in registry.env_specs:
            logger.fatal("Environment '" + params["name"] + "' is not registered in the gym registry.")
            exit()
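    # A hypothetical `params` layout, inferred only from how the keys are consumed in this
    # class; the concrete values come from digideep's parameter files and may differ:
    #
    #     env_params = {
    #         "name": "Pendulum-v0",      # must end up registered in gym's registry
    #         "from_module": "",          # module to import if it registers the environment itself
    #         "from_params": False,       # if True, register the environment here via "register_args"
    #         "register_args": {},        # forwarded verbatim to registry.register(**register_args)
    #         "main_wrappers": {          # kwargs for the always-applied wrappers in make_env
    #             "Monitor": {},
    #             "WrapperDummyMultiAgent": {},
    #             "WrapperDummyDictObs": {},
    #         },
    #         "norm_wrappers": [],        # per-environment wrapper stack (see run_wrapper_stack)
    #         "vect_wrappers": [],        # vectorized wrapper stack (see run_wrapper_stack)
    #     }
    #     env_maker = MakeEnvironment(session, mode="train", seed=0, **env_params)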
    def make_env(self, rank, force_no_monitor=False, extra_env_kwargs={}):
        # import sys # For debugging
        def _f():
            # The header of gym.make(.): `def make(id, **kwargs)`
            env = gym.make(self.params["name"], **extra_env_kwargs)

            # TODO: Use gym.seeding to generate good and independent random seeds
            env.seed(self.seed + rank)

            ## Atari environment wrappers
            is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
            if is_atari:
                env = make_atari(self.params["name"])
            if is_atari and len(env.observation_space.shape) == 3:
                env = wrap_deepmind(env)

            ## Add monitoring wrappers (not optional).
            if not force_no_monitor:
                log_dir = os.path.join(self.session["path_monitor"], str(rank))
                if not self.session.dry_run:
                    env = Monitor(env, log_dir, **self.params["main_wrappers"]["Monitor"])
                else:
                    env = Monitor(env, "/tmp", **self.params["main_wrappers"]["Monitor"])

            ## Add a video recorder if mode == "eval".
            if self.mode == "eval" and not self.session.dry_run:
                videos_dir = os.path.join(self.session["path_videos"], str(rank))
                # force will be true when resuming training from a checkpoint.
                env = MonitorVideoRecorder(env, videos_dir, video_callable=lambda id: True, force=self.session.is_resumed)

            ## Dummy Dict action and observation
            if not isinstance(env.action_space, spaces.Dict):
                env = WrapperDummyMultiAgent(env, **self.params["main_wrappers"]["WrapperDummyMultiAgent"])
            if not isinstance(env.observation_space, spaces.Dict):
                env = WrapperDummyDictObs(env, **self.params["main_wrappers"]["WrapperDummyDictObs"])

            ## Flatten the observation_space (which is by now of spaces.Dict type):
            #     spaces.Dict({"obs1":spaces.Box, "obs2": spaces.Dict({"image1":spaces.Box, "sensor2":spaces.Discrete})})
            # will become:
            #     spaces.Dict({"/obs1":spaces.Box, "/obs2/image1":spaces.Box, "/obs2/sensor2":spaces.Discrete})
            env = WrapperFlattenObsDict(env)
            ## NOTE: We do not flatten the action_space, since we usually deal with 1-level dicts for actions.
            ##       If nested actions are to be considered, we can upgrade action_spaces to flattened spaces.

            ## Adding the arbitrary wrapper stack
            env = self.run_wrapper_stack(env, self.params["norm_wrappers"])

            return env
        return _f
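    # make_env returns a zero-argument thunk rather than a constructed environment, so that
    # SubprocVecEnv / DummyVecEnv can build each environment inside its own worker.
    # A minimal sketch (names are illustrative):
    #
    #     env_fn = env_maker.make_env(rank=0, force_no_monitor=True)
    #     env = env_fn()        # the wrapped gym.Env is only created at this point
    #     obs = env.reset()
    #     env.close()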
    def create_envs(self, num_workers=1, force_no_monitor=False, extra_env_kwargs={}):
        envs = [self.make_env(rank=idx, force_no_monitor=force_no_monitor, extra_env_kwargs=extra_env_kwargs)
                for idx in range(num_workers)]

        ## NOTE: We do not use DummyVecEnv when num_workers==1, to avoid running glfw.init() on the main process.
        if self.mode == "eval":
            envs = DummyVecEnv(envs)
        else:
            envs = SubprocVecEnv(envs)

        ## Handling random states
        envs = VecRandomState(envs)
        ## Converting the data structure of obs/rew/infos/actions:
        envs = VecObsRewInfoActWrapper(envs)
        ## Monitor seems to have more interesting features than VecMonitor, so we may not use VecMonitor.
        # envs = VecMonitor(envs, 'test.log')

        ## Adding the arbitrary wrapper stack
        envs = self.run_wrapper_stack(envs, self.params["vect_wrappers"])

        ## We must add VecSaveState as the last wrapper to save the state of stateful wrappers recursively.
        envs = VecSaveState(envs)
        return envs
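    # A usage sketch (illustrative): in "eval" mode the environments run in-process via
    # DummyVecEnv; otherwise they run in worker processes via SubprocVecEnv.
    #
    #     venv = env_maker.create_envs(num_workers=4)
    #     obs = venv.reset()    # observations arrive restructured by VecObsRewInfoActWrapper
    #     venv.close()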
    def run_wrapper_stack(self, env, stack):
        """Apply a series of wrappers.
        """
        for index in range(len(stack)):
            if stack[index]["enabled"]:
                wrapper_class = get_class(stack[index]["name"])
                # We pass mode to the wrapper as well, so the wrapper can adjust itself.
                if "request_for_args" in stack[index]:
                    for rfa in stack[index]["request_for_args"]:
                        logger("  Adding argument {} to the wrapper {}".format(rfa, stack[index]["name"]))
                        if rfa == "session_state":
                            if self.session:
                                stack[index]["args"]["session_state"] = self.session.state
                        # TODO: Move the "mode" to an optional parameter that can be requested!
                        # elif rfa == "mode":
                        #     stack[index]["args"]["mode"] = self.mode
                        else:
                            logger.fatal("  Argument {} not found!".format(rfa))
                            exit()
                env = wrapper_class(env, mode=self.mode, **stack[index]["args"])
        return env
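    # Expected shape of a wrapper-stack entry ("norm_wrappers" / "vect_wrappers");
    # the class path below is hypothetical and only illustrates the format:
    #
    #     norm_wrappers = [
    #         {
    #             "name": "some.module.SomeVecWrapper",     # import path resolved by get_class (hypothetical)
    #             "enabled": True,                          # disabled entries are skipped
    #             "args": {},                               # kwargs passed to the wrapper constructor
    #             "request_for_args": ["session_state"],    # optional: arguments injected by run_wrapper_stack
    #         },
    #     ]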
    def get_config(self):
        """This function generates a dict of interesting specifications of the environment.

        Note:
            Observation and action spaces can be nested ``spaces.Dict``.
        """
        # _f = self.make_env(rank=0, force_no_monitor=True)
        venv = self.create_envs(num_workers=1, force_no_monitor=True)
        venv.reset()

        config = {
            'env_type': venv.env_type,
            'action_space': space2config(venv.action_space),           # The type of the action space: Discrete, Box, ...
            'observation_space': space2config(venv.observation_space), # The observation space
            # 'reward_range': env.reward_range,
            'max_episode_steps': venv.spec.max_episode_steps,          # Maximum allowable steps
            # 'dt': env.dt() if hasattr(env, "dt") else None           # The delta t as the timestep of the environment.
        }
        venv.close()
        return config
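    # Approximate shape of the returned dict (values are illustrative):
    #
    #     {
    #         'env_type': ...,                 # reported by the vectorized environment
    #         'action_space': {...},           # space2config of the action space
    #         'observation_space': {...},      # space2config of the flattened observation space
    #         'max_episode_steps': 1000,       # from the environment's registered spec
    #     }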