浏览代码

Fix discrete actions and GridWorld

/develop/add-fire
Ervin Teng 4 年前
当前提交
68169434
共有 2 个文件被更改,包括 19 次插入11 次删除
  1. 25
      ml-agents/mlagents/trainers/models_torch.py
  2. 5
      ml-agents/mlagents/trainers/policy/torch_policy.py

25
ml-agents/mlagents/trainers/models_torch.py


MultiCategoricalDistribution,
)
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.models import EncoderType
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[

EPSILON = 1e-7
class EncoderType(Enum):
SIMPLE = "simple"
NATURE_CNN = "nature_cnn"
RESNET = "resnet"
class ActionType(Enum):

hidden = encoder(vis_input)
vis_embeds.append(hidden)
#embedding = vec_embeds[0]
# embedding = vec_embeds[0]
if len(vec_embeds) > 0:
vec_embeds = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
if len(vis_embeds) > 0:

vec_inputs, vis_inputs, masks, memories, sequence_length
)
sampled_actions = self.sample_action(dists)
return sampled_actions, dists[0].pdf(sampled_actions), self.version_number, self.memory_size, self.is_continuous_int, self.act_size_vector
return (
sampled_actions,
dists[0].pdf(sampled_actions),
self.version_number,
self.memory_size,
self.is_continuous_int,
self.act_size_vector,
)
class Critic(nn.Module):

self.layers = []
last_channel = initial_channels
for _, channel in enumerate(n_channels):
self.layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1))
self.layers.append(
nn.Conv2d(last_channel, channel, [3, 3], [1, 1], padding=1)
)
self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
height, width = pool_out_shape((height, width), 3)
for _ in range(n_blocks):

def forward(self, visual_obs):
batch_size = visual_obs.shape[0]
hidden = visual_obs
for idx, layer in enumerate(self.layers):
for layer in self.layers:
if isinstance(layer, nn.Module):
hidden = layer(hidden)
elif isinstance(layer, list):

EncoderType.NATURE_CNN: NatureVisualEncoder,
EncoderType.RESNET: ResNetVisualEncoder,
}
print(encoder_type, ENCODER_FUNCTION_BY_TYPE.get(encoder_type))
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)
@staticmethod

5
ml-agents/mlagents/trainers/policy/torch_policy.py


actions = self.actor_critic.sample_action(dists)
log_probs, entropies = self.actor_critic.get_probs_and_entropy(actions, dists)
actions = torch.squeeze(actions)
if self.use_continuous_act:
actions = actions[:, :, 0]
else:
actions = actions[:, 0, :]
return actions, log_probs, entropies, value_heads, memories

正在加载...
取消
保存