
Some more progress - still broken

/develop/centralizedcritic/counterfact
Ervin Teng, 4 years ago
Current commit 092ea232
3 files changed, 25 insertions and 22 deletions
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (6 changed lines)
  2. ml-agents/mlagents/trainers/ppo/trainer.py (4 changed lines)
  3. ml-agents/mlagents/trainers/torch/networks.py (37 changed lines)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (6 changed lines)


 from typing import Dict, Optional, Tuple, List
 from mlagents.torch_utils import torch
+from mlagents.trainers.torch.agent_action import AgentAction
 import numpy as np
 from mlagents.trainers.buffer import AgentBuffer

         memory = torch.zeros([1, 1, self.policy.m_size])
+        actions = AgentAction.from_dict(batch)
         next_obs = [obs.unsqueeze(0) for obs in next_obs]
         critic_obs = TeamObsUtil.from_buffer(batch, n_obs)

         value_estimates, marg_val_estimates, next_memory = self.policy.actor_critic.critic_pass(
             current_obs,
+            actions,

+            # Actions is a hack here, we need the next actions
-            next_obs, next_memory, sequence_length=1, critic_obs=next_critic_obs
+            next_obs, actions, next_memory, sequence_length=1, critic_obs=next_critic_obs
         )
         for name, estimate in value_estimates.items():
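For orientation, the trainer-side flow this hunk is building toward seems to be: batch the stored actions alongside the observations, hand both to the centralized critic to get per-step value estimates, and reuse the last stored action when bootstrapping from the final observation (the "hack" the comment calls out, since the true next action is not available). Below is a minimal, self-contained sketch of that flow in plain PyTorch; toy_critic and every shape here are made-up stand-ins, not the ml-agents API.

import torch

def toy_critic(obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the centralized critic: one scalar per step,
    # computed from the concatenated observation and one-hot action.
    return torch.cat([obs, actions], dim=-1).sum(dim=-1, keepdim=True)

# A fake trajectory: T steps of observations and one-hot discrete actions.
T, obs_size, act_size = 5, 4, 3
obs = torch.randn(T, obs_size)
actions = torch.nn.functional.one_hot(torch.randint(0, act_size, (T,)), act_size).float()
next_obs = torch.randn(1, obs_size)  # final observation, used only for bootstrapping

# Per-step value estimates, conditioned on the actions that were actually taken.
value_estimates = toy_critic(obs, actions)

# Bootstrap value: the next action is unknown, so reuse the last stored action
# as a placeholder (this mirrors the "Actions is a hack here" comment).
next_value_estimate = toy_critic(next_obs, actions[-1:])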

ml-agents/mlagents/trainers/ppo/trainer.py (4 changed lines)


     :param lambd: GAE weighing factor.
     :return: list of advantage estimates for time-steps t to T.
     """
-    value_estimates = np.append(value_estimates, value_next)
-    delta_t = rewards + gamma * value_estimates[1:] - marginalized_value_estimates
-    advantage = discount_rewards(r=delta_t, gamma=gamma * lambd)
+    advantage = value_estimates - marginalized_value_estimates
     return advantage
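The replacement drops the GAE-style discounted sum entirely: the advantage is now just the difference between the action-conditioned value estimates and the marginalized (own-action-independent) estimates, i.e. a counterfactual-baseline advantage. A small NumPy sketch of that arithmetic, with invented numbers rather than real trainer output:

import numpy as np

# Hypothetical per-step outputs of the centralized critic.
value_estimates = np.array([1.2, 0.8, 1.5, 0.9])                # conditioned on the taken action
marginalized_value_estimates = np.array([1.0, 1.0, 1.0, 1.0])   # own action marginalized out

# How much better the taken action was than the action-marginalized baseline.
advantage = value_estimates - marginalized_value_estimates
print(advantage)  # [ 0.2 -0.2  0.5 -0.1]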

ml-agents/mlagents/trainers/torch/networks.py (37 changed lines)


         sensor_specs: List[SensorSpec],
         network_settings: NetworkSettings,
+        action_spec: ActionSpec,
         num_obs_heads: int = 1,
     ):
         super().__init__()
         self.normalize = network_settings.normalize

     def __init__(
         self,
         stream_names: List[str],
-        observation_shapes: List[Tuple[int, ...]],
+        observation_shapes: List[SensorSpec],
-        encoded_act_size: int = 0,
+        action_spec: ActionSpec,

-            observation_shapes, network_settings, encoded_act_size=encoded_act_size
+            observation_shapes, network_settings, action_spec=action_spec
         )
         if network_settings.memory is not None:
             encoding_size = network_settings.memory.memory_size // 2

     def forward(
         self,
-        inputs: List[List[torch.Tensor]],
-        actions: Optional[torch.Tensor] = None,
+        value_inputs: List[List[torch.Tensor]],
+        q_inputs: List[List[torch.Tensor]],
+        q_actions: List[AgentAction] = None,

-            inputs, actions, memories, sequence_length
+            value_inputs, q_inputs, q_actions, memories, sequence_length
         )
         output = self.value_heads(encoding)
         return output, memories

         )
         self.stream_names = stream_names
         self.critic = CentralizedValueNetwork(
-            stream_names, sensor_specs, network_settings
+            stream_names, sensor_specs, network_settings, action_spec=action_spec
         )
     @property

     def critic_pass(
         self,
         inputs: List[torch.Tensor],
+        actions: AgentAction,
         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
         critic_obs: List[List[torch.Tensor]] = None,

         if critic_obs is not None and critic_obs:
             all_net_inputs.extend(critic_obs)
-            mar_value_outputs, _ = self.critic(
-                critic_obs, memories=critic_mem, sequence_length=sequence_length
-            )
-        else:
-            mar_value_outputs = None
+        mar_value_outputs, _ = self.critic(
+            all_net_inputs, [], [], memories=critic_mem, sequence_length=sequence_length
+        )

-            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
+            critic_obs, [inputs], [actions], memories=critic_mem, sequence_length=sequence_length
         )
-        if mar_value_outputs is None:
-            mar_value_outputs = value_outputs

         all_net_inputs = [inputs]
         if critic_obs is not None and critic_obs:
             all_net_inputs.extend(critic_obs)
-            mar_value_outputs, _ = self.critic(
-                critic_obs, memories=critic_mem, sequence_length=sequence_length
-            )
+        else:
+            critic_obs = []
+        mar_value_outputs, _ = self.critic(
+            all_net_inputs, [], [], memories=critic_mem, sequence_length=sequence_length
+        )

-            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
+            [inputs], critic_obs, actions, memories=critic_mem, sequence_length=sequence_length
         )
         return log_probs, entropies, value_outputs, mar_value_outputs
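Pulling the network-side pieces together: the centralized value network's forward now separates "value" inputs (observations encoded on their own) from "Q" inputs (observations encoded together with the matching agent's actions), and critic_pass evaluates it twice, once with every observation in the value slot and empty Q slots to get the marginalized baseline, and once with the agent's own observation and action in a Q slot to get the action-conditioned estimate. A self-contained toy version of that split; ToySplitCritic and all of its shapes are illustrative assumptions, not the ml-agents classes.

import torch
import torch.nn as nn
from typing import List


class ToySplitCritic(nn.Module):
    """Toy centralized critic: "value" inputs are encoded alone, "Q" inputs are
    encoded together with the acting agent's one-hot action, and everything is
    summed into a single value estimate."""

    def __init__(self, obs_size: int, act_size: int, hidden: int = 32):
        super().__init__()
        self.value_enc = nn.Sequential(nn.Linear(obs_size, hidden), nn.ReLU())
        self.q_enc = nn.Sequential(nn.Linear(obs_size + act_size, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, 1)

    def forward(
        self,
        value_inputs: List[torch.Tensor],
        q_inputs: List[torch.Tensor],
        q_actions: List[torch.Tensor],
    ) -> torch.Tensor:
        encodings = [self.value_enc(obs) for obs in value_inputs]
        encodings += [
            self.q_enc(torch.cat([obs, act], dim=-1))
            for obs, act in zip(q_inputs, q_actions)
        ]
        return self.head(torch.stack(encodings, dim=0).sum(dim=0))


critic = ToySplitCritic(obs_size=4, act_size=2)
my_obs = torch.randn(1, 4)
teammate_obs = torch.randn(1, 4)
my_action = torch.tensor([[1.0, 0.0]])

# Marginalized baseline: every observation in the value slot, no actions.
baseline = critic([my_obs, teammate_obs], [], [])

# Action-conditioned estimate: own obs and action in the Q slot, teammate obs in the value slot.
value = critic([teammate_obs], [my_obs], [my_action])

advantage = value - baseline

Here value - baseline plays the same role as the value_estimates - marginalized_value_estimates difference computed in trainer.py above.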
