浏览代码

write to proto

/test-recurrent-gail
Andrew Cohen 4 年前
当前提交
0cc2956d
共有 3 个文件被更改,包括 11 次插入0 次删除
  1. 9
      ml-agents-envs/mlagents_envs/rpc_utils.py
  2. 1
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  3. 1
      ml-agents/mlagents/trainers/tests/test_simple_rl.py

9
ml-agents-envs/mlagents_envs/rpc_utils.py


action_mask = np.split(action_mask, indices, axis=1)
return BatchedStepResult(obs_list, rewards, done, max_step, agent_id, action_mask)
@timed
def proto_from_batched_step_result(batched_step_result: BatchedStepResult) -> AgentInfoProto:
reward = batched_step_result.reward
done = batched_step_result.done
max_step_reached = batched_step_result.max_step
agent_id = batched_step_result.agent_id
action_mask = batched_step_result.action_mask
observations = batched_step_result.obs
return AgentInfoProto(reward=reward, done=done, id=agent_id, max_step_reached=max_step_reached, action_mask=action_mask, observations=observations)
def _generate_split_indices(dims):
if len(dims) <= 1:

1
ml-agents/mlagents/trainers/tests/simple_test_envs.py


BatchedStepResult,
ActionType,
)
from mlagents_envs.rpc_utils import proto_from_batched_step_result
OBS_SIZE = 1
STEP_SIZE = 0.1

1
ml-agents/mlagents/trainers/tests/test_simple_rl.py


_check_environment_trains(env, PPO_CONFIG)
<<<<<<< Updated upstream
@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_ppo(use_discrete):
env = Memory1DEnvironment([BRAIN_NAME], use_discrete=use_discrete)

正在加载...
取消
保存