
Additional changes

/develop/centralizedcritic
Ervin Teng, 4 years ago
Commit ad439fb6
2 files changed: 7 insertions and 21 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (5 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (23 changes)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (5 changes)


         memory = torch.zeros([1, 1, self.policy.m_size])
         value_estimates, next_memory = self.policy.actor_critic.critic_pass(
-            current_obs, memory, sequence_length=batch.num_experiences, critic_obs=critic_obs
+            current_obs,
+            memory,
+            sequence_length=batch.num_experiences,
+            critic_obs=critic_obs,
         )
         next_value_estimate, _ = self.policy.actor_critic.critic_pass(
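
For context on the hunk above: the optimizer zeroes a recurrent memory of size m_size, then runs a critic pass over the whole batch to get per-step value estimates and the carried-forward memory. Below is a minimal, self-contained sketch of that call pattern; ToyCritic, its layers, and how memory is consumed are illustrative assumptions, not the actual ML-Agents implementation, and the critic_obs argument is omitted.

import torch
import torch.nn as nn
from typing import Tuple


class ToyCritic(nn.Module):
    # Hypothetical stand-in for actor_critic.critic_pass; the real
    # critic above also takes critic_obs, which this sketch skips.
    def __init__(self, obs_size: int, m_size: int):
        super().__init__()
        self.encoder = nn.Linear(obs_size + m_size, m_size)
        self.value_head = nn.Linear(m_size, 1)

    def critic_pass(
        self,
        obs: torch.Tensor,
        memory: torch.Tensor,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Broadcast the single memory row across the batch and encode.
        mem = memory.reshape(1, -1).expand(obs.shape[0], -1)
        encoding = torch.relu(self.encoder(torch.cat([obs, mem], dim=1)))
        # Per-step values, plus the encoding standing in for next memory.
        return self.value_head(encoding), encoding


m_size = 8
critic = ToyCritic(obs_size=4, m_size=m_size)
current_obs = torch.randn(32, 4)        # one batch of observations
memory = torch.zeros([1, 1, m_size])    # zeroed initial memory, as in the hunk
value_estimates, next_memory = critic.critic_pass(
    current_obs,
    memory,
    sequence_length=32,
)
print(value_estimates.shape)  # torch.Size([32, 1])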

ml-agents/mlagents/trainers/torch/networks.py (23 changes)


     def forward(
         self,
-        net_inputs: List[torch.Tensor],
+        inputs: List[torch.Tensor],
@@ ... @@
-            net_inputs, actions, memories, sequence_length
+            inputs, actions, memories, sequence_length
         )
         output = self.value_heads(encoding)
         return output, memories
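
The rename above (net_inputs to inputs) lands in a value-network forward that encodes the inputs, applies the value heads, and passes memories through. Below is a hedged, self-contained sketch of that shape; ToyValueNetwork and the per-stream ModuleDict are assumptions for illustration, not the real networks.py classes.

import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple


class ToyValueNetwork(nn.Module):
    def __init__(self, obs_size: int, stream_names: List[str]):
        super().__init__()
        self.encoder = nn.Linear(obs_size, 16)
        # One scalar value head per reward stream, keyed by name.
        self.value_heads = nn.ModuleDict(
            {name: nn.Linear(16, 1) for name in stream_names}
        )

    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        # Concatenate the observation tensors and encode them once.
        encoding = torch.relu(self.encoder(torch.cat(inputs, dim=1)))
        # Apply every value head; memories pass through unchanged here.
        output = {name: head(encoding) for name, head in self.value_heads.items()}
        return output, memories


net = ToyValueNetwork(obs_size=4, stream_names=["extrinsic"])
out, mems = net([torch.randn(5, 4)])
print(out["extrinsic"].shape)  # torch.Size([5, 1])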

     def forward(
         self,
-        net_inputs: List[List[torch.Tensor]],
-        inputs: List[torch.Tensor],
+        inputs: List[List[torch.Tensor]],
         actions: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
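
The List[List[torch.Tensor]] annotation above suggests a nested input: an outer list over agents, each with its own list of observation tensors, which a centralized critic can flatten into one joint encoding. The helper below is a guess at that shape contract, not code from this commit.

import torch
from typing import List


def flatten_agent_obs(inputs: List[List[torch.Tensor]]) -> torch.Tensor:
    # Concatenate each agent's observation tensors along the feature
    # axis, then join all agents into one joint critic input.
    per_agent = [torch.cat(obs_list, dim=1) for obs_list in inputs]
    return torch.cat(per_agent, dim=1)


agent_a = [torch.randn(5, 3), torch.randn(5, 2)]  # two obs tensors per agent
agent_b = [torch.randn(5, 3), torch.randn(5, 2)]
joint = flatten_agent_obs([agent_a, agent_b])
print(joint.shape)  # torch.Size([5, 10])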

             self.act_size_vector_deprecated,
         ]
         return tuple(export_out)
 
-    def get_action_stats(
-        self,
-        net_inputs: List[torch.Tensor],
-        masks: Optional[torch.Tensor] = None,
-        memories: Optional[torch.Tensor] = None,
-        sequence_length: int = 1,
-    ) -> Tuple[
-        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
-    ]:
-        encoding, memories = self.network_body(
-            net_inputs, memories=memories, sequence_length=sequence_length
-        )
-        action, log_probs, entropies = self.action_model(encoding, masks)
-        return action, log_probs, entropies, memories
 
 class SharedActorCritic(SimpleActor, ActorCritic):
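
For reference, the deleted get_action_stats shows the actor's sampling contract: encode the inputs through the network body, sample from the action model under optional masks, and return the action with its log-probabilities, entropies, and memories. (Its annotated return type lists five elements while the return statement yields four.) Below is a standalone sketch of that pattern; ToyActor and the Categorical action model are hypothetical stand-ins, not the ML-Agents classes.

import torch
import torch.nn as nn
from torch.distributions import Categorical
from typing import List, Optional, Tuple


class ToyActor(nn.Module):
    # Illustrative stand-in for the encode -> action_model flow above.
    def __init__(self, obs_size: int, num_actions: int):
        super().__init__()
        self.network_body = nn.Linear(obs_size, 16)
        self.action_model = nn.Linear(16, num_actions)

    def get_action_stats(
        self,
        inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        encoding = torch.relu(self.network_body(torch.cat(inputs, dim=1)))
        logits = self.action_model(encoding)
        if masks is not None:
            # Disallow masked-out actions by zeroing their probability.
            logits = logits.masked_fill(masks == 0, float("-inf"))
        dist = Categorical(logits=logits)
        action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), memories


actor = ToyActor(obs_size=4, num_actions=3)
action, log_probs, entropies, mems = actor.get_action_stats([torch.randn(6, 4)])
print(action.shape, log_probs.shape)  # torch.Size([6]) torch.Size([6])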
