# Source code for digideep.environment.common.vec_env.subproc_vec_env

"""
The MIT License

Copyright (c) 2017 OpenAI (http://openai.com)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import numpy as np
from multiprocessing import Process, Pipe
from . import VecEnv, CloudpickleWrapper

def worker(remote, parent_remote, env_fn_wrapper):
    """Serve commands for one environment inside a subprocess.

    The environment is created with ``env_fn_wrapper.x()``. The worker then
    loops: it receives ``(cmd, data)`` tuples on ``remote`` and sends exactly
    one reply per command, until it receives ``'close'``.

    Commands and their replies:

    * ``'step'``: ``(ob, reward, done, info)`` from ``env.step(data)``. When
      ``done`` is true, the environment is reset immediately and ``ob`` is
      replaced by the observation returned from ``env.reset()``.
    * ``'reset'``: the observation returned from ``env.reset()``.
    * ``'render'``: ``env.render(mode='rgb_array')``.
    * ``'close'``: no reply; ``remote`` is closed and the loop exits.
    * ``'get_spaces'``: ``(env.observation_space, env.action_space)``.
    * ``'get_specs'``: ``env.spec``.
    * ``'get_type'``: ``env.unwrapped.__module__``.
    * ``'get_env_state'`` / ``'set_env_state'``: the result of
      ``env.unwrapped.get_env_state()`` / ``env.unwrapped.set_env_state(data)``,
      or ``None`` when the unwrapped env has no such method.
    * ``'get_rng_state'`` / ``'set_rng_state'``: the result of
      ``env.unwrapped.np_random.get_state()`` / ``.set_state(data)``.

    Args:
        remote: This worker's end of the pipe, used to receive commands and
            send replies.
        parent_remote: The parent's end of the same pipe. The worker never
            uses it and closes it immediately.
        env_fn_wrapper: Object whose ``x`` attribute is a zero-argument
            callable that builds the environment.

    Raises:
        NotImplementedError: If a command not listed above is received. The
            environment is still closed by the ``finally`` clause.
    """
    # The worker only talks through ``remote``; release its handle to the
    # parent's end of the pipe.
    parent_remote.close()
    env = env_fn_wrapper.x()
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == 'step':
                ob, reward, done, info = env.step(data)
                if done:
                    # Auto-reset: the observation sent back belongs to the
                    # next episode, while reward/done/info describe the step
                    # that just finished.
                    ob = env.reset()
                remote.send((ob, reward, done, info))
            elif cmd == 'reset':
                ob = env.reset()
                remote.send(ob)
            elif cmd == 'render':
                remote.send(env.render(mode='rgb_array'))
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces':
                remote.send((env.observation_space, env.action_space))
            elif cmd == 'get_specs':
                remote.send(env.spec)
            elif cmd == 'get_type':
                remote.send(env.unwrapped.__module__)
            elif cmd == 'get_env_state':
                # Optional hook: environments without it report None.
                if hasattr(env.unwrapped, "get_env_state"):
                    remote.send(env.unwrapped.get_env_state())
                else:
                    remote.send(None)
            elif cmd == 'set_env_state':
                # Optional hook: environments without it report None.
                if hasattr(env.unwrapped, "set_env_state"):
                    remote.send(env.unwrapped.set_env_state(data))
                else:
                    remote.send(None)
            elif cmd == 'get_rng_state':
                remote.send(env.unwrapped.np_random.get_state())
            elif cmd == 'set_rng_state':
                remote.send(env.unwrapped.np_random.set_state(data))
            else:
                raise NotImplementedError(
                    "SubprocVecEnv worker: unknown command {!r}".format(cmd))
    except KeyboardInterrupt:
        print('SubprocVecEnv worker: got KeyboardInterrupt')
    finally:
        env.close()
class SubprocVecEnv(VecEnv):
    """
    VecEnv that runs multiple environments in parallel in subprocesses and
    communicates with them via pipes.

    Recommended to use when num_envs > 1 and step() can be a bottleneck.

    Each environment runs in its own ``Process`` executing ``worker``. The
    parent keeps one connection per environment in ``self.remotes`` and sends
    ``(cmd, data)`` requests over it, reading one reply per request.
    """

    def __init__(self, env_fns, spaces=None):
        """
        Arguments:

        env_fns: sequence of callables - functions that create environments to
            run in subprocesses. Need to be cloud-pickleable.
        spaces: unused; kept for signature compatibility.

        Raises:
            ValueError: if ``env_fns`` is empty.
        """
        self.waiting = False  # True between step_async() and step_wait()
        self.closed = False
        nenvs = len(env_fns)
        if nenvs == 0:
            raise ValueError("SubprocVecEnv needs at least one environment function")
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        # The worker ends are used by the subprocesses; close the parent's
        # handles to them.
        for remote in self.work_remotes:
            remote.close()

        self.viewer = None
        self._delayed_init_flag = False

        # Only the first worker is queried; its spaces, spec and type are used
        # for the whole vector.
        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        # Here we get the current spec of the env. But later we will update it in '_delayed_init'
        self.remotes[0].send(('get_specs', None))
        spec = self.remotes[0].recv()
        # Identify the environment type (whatever the worker reports for 'get_type').
        self.remotes[0].send(('get_type', None))
        env_type = self.remotes[0].recv()
        VecEnv.__init__(self, nenvs, observation_space, action_space, spec, env_type)

    def _delayed_init(self):
        """Refresh ``self.spec`` from the first environment, once.

        Some environments are late in creating their spec, so the spec
        captured in ``__init__`` may be incomplete. The new spec's instance
        attributes (``__dict__``) are merged into the existing spec object in
        place, so references to ``self.spec`` taken earlier see the update.
        If no spec existed at construction time, the new one is adopted as is.
        If the environment still reports no spec, ``self.spec`` is left
        unchanged.
        """
        if self._delayed_init_flag:
            return
        self._delayed_init_flag = True
        self.remotes[0].send(('get_specs', None))
        spec = self.remotes[0].recv()
        if spec is None:
            # Nothing to merge.
            return
        if self.spec is None:
            # There is no existing object to update in place.
            self.spec = spec
        else:
            # NOTE: only instance attributes are copied; class-level
            # attributes/properties of the new spec are not.
            self.spec.__dict__.update(spec.__dict__)

    def _request_all(self, cmd, payloads=None):
        """Send ``cmd`` to every worker and return their replies in worker order.

        ``payloads`` supplies one data item per worker. When omitted, every
        worker receives ``None`` as data.
        """
        self._assert_not_closed()
        if payloads is None:
            payloads = [None] * len(self.remotes)
        for remote, payload in zip(self.remotes, payloads):
            remote.send((cmd, payload))
        return [remote.recv() for remote in self.remotes]

    def step_async(self, actions):
        """Send one action to each worker without waiting for the results."""
        self._assert_not_closed()
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """Collect the step results requested by ``step_async``.

        Returns:
            ``(obs, rews, dones, infos)``: the first three stacked with
            ``np.stack``; ``infos`` as a tuple of per-environment infos.
        """
        self._assert_not_closed()
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Reset every environment and return the stacked observations.

        After the first reset the spec is refreshed once (see ``_delayed_init``).
        """
        result = np.stack(self._request_all('reset'))
        self._delayed_init()
        return result

    def close_extras(self):
        """Shut down all workers and release the parent's pipe ends."""
        self.closed = True
        if self.waiting:
            # Consume the outstanding step replies before sending 'close'.
            for remote in self.remotes:
                remote.recv()
            self.waiting = False
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        for remote in self.remotes:
            remote.close()

    def get_images(self):
        """Return a list with one ``render`` reply per environment."""
        return self._request_all('render')

    def _assert_not_closed(self):
        assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"

    def set_rng_state(self, states):
        """Send one RNG state per worker; return the workers' replies."""
        return self._request_all('set_rng_state', states)

    def get_rng_state(self):
        """Return a list with each worker's RNG state."""
        return self._request_all('get_rng_state')

    def state_dict(self):
        """Return a list with each worker's environment state (``None`` if unsupported)."""
        return self._request_all('get_env_state')

    def load_state_dict(self, state_dicts):
        """Send one environment state per worker; return the workers' replies."""
        return self._request_all('set_env_state', state_dicts)