Move 'take_action' into Policy class (#1669)
This refactor is part of the Actor-Trainer separation. Since policies will be distributed across actors in separate processes which share a single trainer, taking an action should be the responsibility of the policy.

This change also makes a few smaller changes:

* Combines `take_action` logic between trainers, making it more generic
* Adds an `ActionInfo` data class to be more explicit about the data returned by the policy; used only by TrainerController and the policy for now
* Moves trainer stats logic out of `take_action` and into `add_experiences`
* Renames `take_action` to `get_action`
GitHub · 6 years ago · current commit c258b1c3
9 files changed, with 146 insertions and 87 deletions
- ml-agents/mlagents/trainers/__init__.py (6 lines changed)
- ml-agents/mlagents/trainers/bc/trainer.py (17 lines changed)
- ml-agents/mlagents/trainers/policy.py (24 lines changed)
- ml-agents/mlagents/trainers/ppo/trainer.py (29 lines changed)
- ml-agents/mlagents/trainers/trainer.py (14 lines changed)
- ml-agents/mlagents/trainers/trainer_controller.py (47 lines changed)
- ml-agents/tests/trainers/test_trainer_controller.py (36 lines changed)
- ml-agents/mlagents/trainers/action_info.py (9 lines changed)
- ml-agents/tests/trainers/test_policy.py (51 lines changed)
ml-agents/mlagents/trainers/action_info.py

from typing import NamedTuple, Any, Dict, Optional


class ActionInfo(NamedTuple):
    action: Any
    memory: Any
    text: Any
    value: Any
    outputs: Optional[Dict[str, Any]]
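The policy.py diff itself is collapsed on this page, but the commit message and the tests in test_policy.py below pin down the shape of the renamed method: `get_action` wraps the raw `evaluate()` outputs in the `ActionInfo` defined above, and returns an empty `ActionInfo` when there are no agents to act for. A minimal sketch under those assumptions, not the actual diff:

# Minimal sketch of the renamed Policy.get_action, inferred from the commit
# message, the ActionInfo class above, and the tests below. The real
# policy.py diff is not expanded on this page, so details may differ.
from mlagents.trainers.action_info import ActionInfo


class Policy:
    def __init__(self, seed, brain, trainer_parameters):
        self.seed = seed
        self.brain = brain
        self.trainer_parameters = trainer_parameters

    def evaluate(self, brain_info):
        # In the real class this runs the model and returns a dict of
        # outputs; the tests below replace it with a MagicMock.
        raise NotImplementedError

    def get_action(self, brain_info) -> ActionInfo:
        """Decide actions for the agents in brain_info and wrap the raw
        evaluate() outputs in an ActionInfo tuple."""
        if len(brain_info.agents) == 0:
            # No agents to act for: return an empty ActionInfo.
            return ActionInfo([], [], [], None, None)
        run_out = self.evaluate(brain_info)
        return ActionInfo(
            action=run_out.get('action'),
            memory=run_out.get('memory_out'),
            text=run_out.get('text'),
            value=run_out.get('value'),
            outputs=run_out,
        )

Because ActionInfo is a NamedTuple, callers get named field access while the tests can still compare an entire return value against an expected tuple in a single assert.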
ml-agents/tests/trainers/test_policy.py

import numpy as np
from unittest.mock import MagicMock

from mlagents.envs import BrainInfo
from mlagents.trainers.policy import *
from mlagents.trainers.action_info import ActionInfo


def basic_mock_brain():
    mock_brain = MagicMock()
    mock_brain.vector_action_space_type = "continuous"
    return mock_brain


def basic_params():
    return {
        "use_recurrent": False,
        "model_path": "my/path"
    }


def test_take_action_returns_empty_with_no_agents():
    test_seed = 3
    policy = Policy(test_seed, basic_mock_brain(), basic_params())
    no_agent_brain_info = BrainInfo([], [], [], agents=[])
    result = policy.get_action(no_agent_brain_info)
    assert result == ActionInfo([], [], [], None, None)


def test_take_action_returns_nones_on_missing_values():
    test_seed = 3
    policy = Policy(test_seed, basic_mock_brain(), basic_params())
    policy.evaluate = MagicMock(return_value={})
    brain_info_with_agents = BrainInfo([], [], [], agents=['an-agent-id'])
    result = policy.get_action(brain_info_with_agents)
    assert result == ActionInfo(None, None, None, None, {})


def test_take_action_returns_action_info_when_available():
    test_seed = 3
    policy = Policy(test_seed, basic_mock_brain(), basic_params())
    policy_eval_out = {
        'action': np.array([1.0]),
        'memory_out': np.array([2.5]),
        'value': np.array([1.1])
    }
    policy.evaluate = MagicMock(return_value=policy_eval_out)
    brain_info_with_agents = BrainInfo([], [], [], agents=['an-agent-id'])
    result = policy.get_action(brain_info_with_agents)
    expected = ActionInfo(
        policy_eval_out['action'],
        policy_eval_out['memory_out'],
        None,
        policy_eval_out['value'],
        policy_eval_out
    )
    assert result == expected
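The trainer_controller.py and trainer.py diffs are also collapsed here. Going only by the commit message, the controller now asks the policy for an ActionInfo and hands the raw outputs to `add_experiences`, which is where the stats logic moved. The sketch below is purely illustrative; the `env.step` keyword names and the `add_experiences` signature are assumptions, not copied from this change:

# Hypothetical illustration only: the controller-side names below
# (env, trainer, step_once) are assumptions, not part of this diff.
def step_once(env, policy, trainer, brain_name, curr_info):
    action_info = policy.get_action(curr_info)
    new_all_info = env.step(
        vector_action={brain_name: action_info.action},
        memory={brain_name: action_info.memory},
        text_action={brain_name: action_info.text},
        value={brain_name: action_info.value},
    )
    # Per the commit message, per-step stats collection now happens in
    # add_experiences, which receives the raw evaluate() outputs via
    # ActionInfo.outputs instead of inside the removed take_action helper.
    trainer.add_experiences(curr_info, new_all_info[brain_name], action_info.outputs)
    return new_all_info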