浏览代码

Fix SeparateActorCritic export

/develop/add-fire/memoryclass
Ervin Teng 4 年前
当前提交
9ae22c61
共有 3 个文件被更改,包括 25 次插入19 次删除
  1. 11
      ml-agents/mlagents/trainers/policy/torch_policy.py
  2. 2
      ml-agents/mlagents/trainers/torch/model_serialization.py
  3. 31
      ml-agents/mlagents/trainers/torch/networks.py

11
ml-agents/mlagents/trainers/policy/torch_policy.py


conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
# Save the m_size needed for export
self._export_m_size = self.m_size
# m_size needed for training is determined by network, not trainer settings
@property
def export_memory_size(self) -> int:
"""
Returns the memory size of the exported ONNX policy. This only includes the memory
of the Actor and not any auxillary networks.
"""
return self._export_m_size
def split_decision_step(self, decision_requests):
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)

2
ml-agents/mlagents/trainers/torch/model_serialization.py


else []
)
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.export_memory_size])
# Need to pass all posslible inputs since currently keyword arguments is not
# supported by torch.nn.export()

31
ml-agents/mlagents/trainers/torch/networks.py


from typing import Callable, List, Dict, Tuple, Optional
import attr
import abc
import torch

if self.use_lstm:
# Resize to (batch, sequence length, encoding size)
encoding = encoding.reshape([-1, sequence_length, self.h_size])
# memories = torch.split(memories, self.m_size // 2, dim=-1)
# memories = torch.cat(memories, dim=-1)
return encoding, memories

# Give the Actor only half the memories. Note we previously validate
# that memory_size must be a multiple of 4.
self.use_lstm = network_settings.memory is not None
if network_settings.memory is not None:
self.half_mem_size = network_settings.memory.memory_size // 2
new_memory_settings = attr.evolve(
network_settings.memory, memory_size=self.half_mem_size
)
use_network_settings = attr.evolve(
network_settings, memory=new_memory_settings
)
else:
use_network_settings = network_settings
self.half_mem_size = 0
# if network_settings.memory is not None:
# self.half_mem_size = network_settings.memory.memory_size // 2
# new_memory_settings = attr.evolve(
# network_settings.memory, memory_size=self.half_mem_size
# )
# use_network_settings = attr.evolve(
# network_settings, memory=new_memory_settings
# )
# else:
# use_network_settings = network_settings
# self.half_mem_size = 0
use_network_settings,
network_settings,
act_type,
act_size,
conditional_sigma,

self.critic = ValueNetwork(
stream_names, observation_shapes, use_network_settings
)
self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)
@property
def memory_size(self) -> int:

正在加载...
取消
保存