浏览代码

action model and network tests

/develop/actionmodel-csharp
Andrew Cohen 4 年前
当前提交
d984af1f
共有 2 个文件被更改,包括 100 次插入13 次删除
  1. 32
      ml-agents/mlagents/trainers/tests/torch/test_networks.py
  2. 81
      ml-agents/mlagents/trainers/tests/torch/test_action_model.py

32
ml-agents/mlagents/trainers/tests/torch/test_networks.py


from mlagents.trainers.torch.networks import (
NetworkBody,
ValueNetwork,
SimpleActor,
from mlagents.trainers.torch.distributions import (
GaussianDistInstance,
CategoricalDistInstance,
)
from mlagents_envs.base_env import ActionSpec

memory=NetworkSettings.MemorySettings() if lstm else None
)
obs_shapes = [(obs_size,)]
act_size = [2]
act_size = 2
mask = torch.ones([1, act_size * 2])
action_spec = ActionSpec.create_continuous(act_size[0])
# action_spec = ActionSpec.create_continuous(act_size[0])
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
actor = ac_type(obs_shapes, network_settings, action_spec, stream_names)
if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))

else:
assert value_out[stream].shape == (1,)
# Test get_dist_and_value
dists, value_out, mem_out = actor.get_dist_and_value(
[sample_obs], [], memories=memories
# Test get_action_stats_and_value
action, log_probs, entropies, value_out, mem_out = actor.get_action_stats_and_value(
[sample_obs], [], memories=memories, masks=mask
if lstm:
assert action.continuous_tensor.shape == (64, 2)
else:
assert action.continuous_tensor.shape == (1, 2)
assert len(action.discrete_list) == 2
for _disc in action.discrete_list:
if lstm:
assert _disc.shape == (64, 1)
else:
assert _disc.shape == (1, 1)
for dist in dists:
assert isinstance(dist, GaussianDistInstance)
for stream in stream_names:
if lstm:
assert value_out[stream].shape == (network_settings.memory.sequence_length,)

81
ml-agents/mlagents/trainers/tests/torch/test_action_model.py


import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.action_model import ActionModel, DistInstances
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.distributions import (
GaussianDistInstance,
CategoricalDistInstance,
)
from mlagents_envs.base_env import ActionSpec
def create_action_model(inp_size, act_size):
    """Build an ``ActionModel`` plus an all-ones mask for the tests below.

    The ``ActionSpec`` is constructed from ``act_size`` and a tuple that
    repeats ``act_size`` exactly ``act_size`` times (presumably the
    continuous size and the discrete branch sizes -- confirm against
    ``ActionSpec``).

    :param inp_size: Input size passed to ``ActionModel``.
    :param act_size: Value used for both parts of the action spec.
    :return: Tuple ``(action_model, mask)`` where ``mask`` is a tensor of
        ones with shape ``(1, act_size * 2)``.
    """
    branch_sizes = tuple(act_size for _ in range(act_size))
    spec = ActionSpec(act_size, branch_sizes)
    model = ActionModel(inp_size, spec)
    # NOTE(review): the mask width is act_size * 2, while the branch sizes sum
    # to act_size * act_size; the two only agree for act_size == 2 -- confirm
    # the intended width before calling this with another size.
    all_ones_mask = torch.ones([1, act_size * 2])
    return model, all_ones_mask
def test_get_dists():
    """``_get_dists`` returns one Gaussian and two Categorical distributions."""
    input_dim, action_dim = 4, 2
    model, mask = create_action_model(input_dim, action_dim)
    ones_input = torch.ones((1, input_dim))

    dist_instances = model._get_dists(ones_input, masks=mask)

    # Continuous part is a single Gaussian instance.
    assert isinstance(dist_instances.continuous, GaussianDistInstance)
    # Discrete part holds two entries, each of them Categorical.
    assert len(dist_instances.discrete) == 2
    assert all(
        isinstance(branch_dist, CategoricalDistInstance)
        for branch_dist in dist_instances.discrete
    )
def test_sample_action():
    """``_sample_action`` yields a (1, 2) continuous tensor and two (1, 1) discrete tensors."""
    input_dim, action_dim = 4, 2
    model, mask = create_action_model(input_dim, action_dim)
    ones_input = torch.ones((1, input_dim))

    dist_instances = model._get_dists(ones_input, masks=mask)
    sampled = model._sample_action(dist_instances)

    assert sampled.continuous_tensor.shape == (1, 2)
    assert len(sampled.discrete_list) == 2
    # Every discrete entry is a single sampled index for the single batch row.
    assert all(branch.shape == (1, 1) for branch in sampled.discrete_list)
def test_get_probs_and_entropy():
    """Check shapes and values returned by ``ActionModel._get_probs_and_entropy``.

    Hand-built distributions (one Gaussian built from zeros and ones, two
    identical Categoricals built from ``test_prob``) and a fixed
    ``AgentAction`` are fed to the model; the test then verifies:

    * the shapes of the continuous, per-action discrete and all-action
      discrete log-probabilities,
    * that each continuous log-prob is about -0.919 (standard normal at 0),
    * that discrete action 0 scores higher than discrete action 1,
    * the entropy values in the first row of ``entropies``.
    """
    inp_size = 4
    act_size = 2
    # Only the model is needed here; the mask returned by the helper is unused.
    action_model, _ = create_action_model(inp_size, act_size)

    _continuous_dist = GaussianDistInstance(torch.zeros((1, 2)), torch.ones((1, 2)))
    # The first entry carries most of the weight, so index 0 must score above
    # index 1 in the comparison further down.
    test_prob = torch.tensor([[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)])
    _discrete_dist_list = [
        CategoricalDistInstance(test_prob),
        CategoricalDistInstance(test_prob),
    ]
    dist_tuple = DistInstances(_continuous_dist, _discrete_dist_list)

    # Continuous action at 0; the first discrete branch picks index 0 and the
    # second picks index 1.
    agent_action = AgentAction(
        torch.zeros((1, 2)), [torch.tensor([0]), torch.tensor([1])]
    )

    log_probs, entropies = action_model._get_probs_and_entropy(agent_action, dist_tuple)

    assert log_probs.continuous_tensor.shape == (1, 2)
    assert len(log_probs.discrete_list) == 2
    for _disc in log_probs.discrete_list:
        assert _disc.shape == (1,)
    assert len(log_probs.all_discrete_list) == 2
    for _disc in log_probs.all_discrete_list:
        assert _disc.shape == (1, 2)

    for clp in log_probs.continuous_tensor[0]:
        # Log prob of standard normal at 0
        assert clp == pytest.approx(-0.919, abs=0.01)

    assert log_probs.discrete_list[0] > log_probs.discrete_list[1]

    # 1.4189 ~= 0.5 * ln(2 * pi * e), the entropy of a unit-variance normal.
    # 0.6191 matches the entropy of softmax([0.9, 0.1]) -- presumably
    # CategoricalDistInstance treats its input as logits; TODO confirm.
    for ent, val in zip(entropies[0], [1.4189, 1.4189, 0.6191, 0.6191]):
        assert ent == pytest.approx(val, abs=0.01)
正在加载...
取消
保存