浏览代码

Fix environment factory pickling on Windows (#1912)

SubprocessUnityEnvironment sends an environment factory function to
each worker which it can use to create a UnityEnvironment to interact
with. We use Python's standard multiprocessing library, which pickles
all data sent to the subprocess.  The built-in pickle library doesn't
pickle function objects on Windows machines (tested with Python 3.6 on
Windows 10 Pro).

This PR adds cloudpickle as a dependency in order to serialize the
environment factory. Other implementations of subprocess environments
do the same:
https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py
/develop-generalizationTraining-TrainerController
GitHub 6 年前
当前提交
d906273a
共有 2 个文件被更改,包括 10 次插入4 次删除
  1. 10
      ml-agents-envs/mlagents/envs/subprocess_environment.py
  2. 4
      ml-agents-envs/setup.py

10
ml-agents-envs/mlagents/envs/subprocess_environment.py


from typing import *
import copy
import numpy as np
import cloudpickle
from mlagents.envs import UnityEnvironment
from multiprocessing import Process, Pipe

self.process.join()
def worker(parent_conn: Connection, env_factory: Callable[[int], UnityEnvironment], worker_id: int):
def worker(parent_conn: Connection, pickled_env_factory: str, worker_id: int):
env_factory: Callable[[int], UnityEnvironment] = cloudpickle.loads(pickled_env_factory)
env = env_factory(worker_id)
def _send_response(cmd_name, payload):

env_factory: Callable[[int], BaseUnityEnvironment]
) -> UnityEnvWorker:
parent_conn, child_conn = Pipe()
child_process = Process(target=worker, args=(child_conn, env_factory, worker_id))
# Need to use cloudpickle for the env factory function since function objects aren't picklable
# on Windows as of Python 3.6.
pickled_env_factory = cloudpickle.dumps(env_factory)
child_process = Process(target=worker, args=(child_conn, pickled_env_factory, worker_id))
child_process.start()
return UnityEnvWorker(child_process, worker_id, parent_conn)

4
ml-agents-envs/setup.py


'numpy>=1.13.3,<=1.16.1',
'pytest>=3.2.2,<4.0.0',
'protobuf>=3.6,<3.7',
'grpcio>=1.11.0,<1.12.0'],
'grpcio>=1.11.0,<1.12.0',
'cloudpickle==0.8.1'],
python_requires=">=3.5,<3.8",
)
正在加载...
取消
保存