
MultiInputNetBody

/develop/action-slice
Andrew Cohen, 4 years ago
Current commit
98d647de
3 files changed, with 185 additions and 187 deletions
  1. ml-agents/mlagents/trainers/coma/optimizer_torch.py (194 changes)
  2. ml-agents/mlagents/trainers/ppo/trainer.py (38 changes)
  3. ml-agents/mlagents/trainers/torch/networks.py (140 changes)

ml-agents/mlagents/trainers/coma/optimizer_torch.py (194 changes)

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import TrainerSettings, PPOSettings
from mlagents.trainers.torch.networks import Critic, MultiInputNetworkBody
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.torch.utils import ModelUtils

class COMAValueNetwork(torch.nn.Module, Critic):
    def __init__(
        self,
        stream_names: List[str],
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        outputs_per_stream: int = 1,
    ):
        # Encoding is now delegated to the shared MultiInputNetworkBody.
        torch.nn.Module.__init__(self)
        self.network_body = MultiInputNetworkBody(
            observation_specs, network_settings, action_spec
        )
        if network_settings.memory is not None:
            encoding_size = network_settings.memory.memory_size // 2
        else:
            encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size
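
After this change the value network is pure composition: MultiInputNetworkBody (defined in networks.py below) produces the encoding, and COMAValueNetwork keeps only the value heads. A minimal pure-PyTorch sketch of that composition pattern; ToyBody and ToyValueNet are hypothetical stand-ins, not the ml-agents classes.

import torch

class ToyBody(torch.nn.Module):
    # Stand-in for MultiInputNetworkBody: encodes observations to one vector.
    def __init__(self, obs_size: int, h_size: int):
        super().__init__()
        self.encoder = torch.nn.Linear(obs_size, h_size)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.encoder(obs))

class ToyValueNet(torch.nn.Module):
    # Stand-in for COMAValueNetwork: delegates encoding, owns only value heads.
    def __init__(self, stream_names, obs_size: int, h_size: int):
        super().__init__()
        self.body = ToyBody(obs_size, h_size)
        self.heads = torch.nn.ModuleDict(
            {name: torch.nn.Linear(h_size, 1) for name in stream_names}
        )

    def forward(self, obs: torch.Tensor):
        encoding = self.body(obs)
        return {name: head(encoding) for name, head in self.heads.items()}

net = ToyValueNet(["extrinsic"], obs_size=8, h_size=16)
values = net(torch.randn(4, 8))  # {"extrinsic": tensor of shape (4, 1)}
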
    def update_normalization(self, buffer: AgentBuffer) -> None:
        self.network_body.update_normalization(buffer)

    def baseline(
        self,
        self_obs: List[List[torch.Tensor]],
        obs: List[List[torch.Tensor]],
        actions: List[AgentAction],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encoding, memories = self.network_body(
            obs_only=self_obs,
            obs=obs,
            actions=actions,
            memories=memories,
            sequence_length=sequence_length,
        )
        value_outputs, critic_mem_out = self.forward(encoding, memories, sequence_length)
        return value_outputs, critic_mem_out

    def critic_pass(
        self,
        obs: List[List[torch.Tensor]],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encoding, memories = self.network_body(
            obs_only=obs,
            obs=None,
            actions=None,
            memories=memories,
            sequence_length=sequence_length,
        )
        value_outputs, critic_mem_out = self.forward(encoding, memories, sequence_length)
        return value_outputs, critic_mem_out
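
Both public entry points now funnel into the same network body. critic_pass treats every agent as an observation-only entity, while baseline feeds the agent being evaluated as observation-only and its groupmates as observation-plus-action entities, which is what makes the result a counterfactual baseline for that agent. A call sketch with hypothetical variable names and illustrative shapes, not the ml-agents API contract:

# Plain value estimate over the group's observations:
values, _ = critic.critic_pass(all_agent_obs, memories=None, sequence_length=1)

# Counterfactual baseline for one agent: its own obs enter without an action,
# the other agents enter with their actions attached.
baseline_vals, _ = critic.baseline(
    current_agent_obs,    # List[List[torch.Tensor]] for the agent being updated
    other_agent_obs,      # List[List[torch.Tensor]] for its groupmates
    other_agent_actions,  # List[AgentAction], aligned with other_agent_obs
    memories=None,
    sequence_length=1,
)
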
    def forward(
        self,
        encoding: torch.Tensor,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        output = self.value_heads(encoding)
        return output, memories

    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        reward_signal_configs = trainer_settings.reward_signals
        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
        self._critic = COMAValueNetwork(
            reward_signal_names,
            policy.behavior_spec.observation_specs,
            network_settings=trainer_settings.network_settings,
            action_spec=policy.behavior_spec.action_spec,
        )
        params = list(self.policy.actor.parameters()) + list(
            self._critic.parameters()
        )

    @property
    def critic(self):
        return self._critic
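
The params list gathers actor and critic parameters so a single optimizer can update both; the learning-rate plumbing is elided in this hunk. A generic sketch of the joint-parameter pattern, with stand-in modules and an assumed learning rate:

import torch

actor = torch.nn.Linear(8, 2)   # stand-in for policy.actor
critic = torch.nn.Linear(8, 1)  # stand-in for self._critic
params = list(actor.parameters()) + list(critic.parameters())

# One Adam instance updates policy and critic jointly; 3e-4 is an assumed
# learning rate, not one taken from this commit.
optimizer = torch.optim.Adam(params, lr=3e-4)
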
    def coma_value_loss(  # renamed from ppo_value_loss
        self,
        values: Dict[str, torch.Tensor],
        old_values: Dict[str, torch.Tensor],

        value_loss = torch.mean(torch.stack(value_losses))
        return value_loss

    def policy_policy_loss(  # renamed from ppo_policy_loss
        self,
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
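
Only the tail of coma_value_loss survives this excerpt: a mean over a stack of per-reward-stream losses. Assuming it keeps the PPO-style clipped value loss per stream, which is the convention elsewhere in ml-agents but is not confirmed by the hunk above, a self-contained sketch:

import torch
from typing import Dict

def clipped_value_loss(
    values: Dict[str, torch.Tensor],
    old_values: Dict[str, torch.Tensor],
    returns: Dict[str, torch.Tensor],
    epsilon: float = 0.2,
) -> torch.Tensor:
    # Trust-region update on the value estimate, computed per reward stream
    # and then averaged, matching the torch.stack/torch.mean tail above.
    value_losses = []
    for name in values:
        clipped = old_values[name] + torch.clamp(
            values[name] - old_values[name], -epsilon, epsilon
        )
        v_unclipped = (returns[name] - values[name]) ** 2
        v_clipped = (returns[name] - clipped) ** 2
        value_losses.append(torch.mean(torch.max(v_unclipped, v_clipped)))
    return torch.mean(torch.stack(value_losses))

loss = clipped_value_loss(
    {"extrinsic": torch.tensor([1.0, 2.0])},
    {"extrinsic": torch.tensor([0.9, 1.8])},
    {"extrinsic": torch.tensor([1.5, 2.5])},
)
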
ml-agents/mlagents/trainers/ppo/trainer.py (38 changes)

    return self.policy


def discount_rewards(r, gamma=0.99, value_next=0.0):
    """
    Computes discounted sum of future rewards for use in updating value estimate.
    :return: discounted sum of future rewards as list.
    """
    discounted_r = np.zeros_like(r)
    running_add = value_next
    for t in reversed(range(0, r.size)):
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r


def lambd_return(r, value_estimates, gamma=0.99, lambd=0.8, value_next=0.0):
    """
    Computes lambda return.
    :param value_estimates: List of value estimates.
    :param lambd: n-step return weighting factor.
    :return: lambda return as a list.
    """
    returns = np.zeros_like(r)
    returns[-1] = r[-1] + gamma * value_next
    for t in reversed(range(0, r.size - 1)):
        returns[t] = (
            gamma * lambd * returns[t + 1]
            + r[t]
            + (1 - lambd) * gamma * value_estimates[t + 1]
        )
    return returns


def get_gae(rewards, value_estimates, value_next=0.0, gamma=0.99, lambd=0.95):
    """
    Computes generalized advantage estimate for use in updating policy.
    :param rewards: list of rewards for time-steps t to T.
    :param value_next: Value estimate for time-step T+1.
    :param value_estimates: list of value estimates for time-steps t to T.
    :param gamma: Discount factor.
    :param lambd: GAE weighing factor.
    :return: list of advantage estimates for time-steps t to T.
    """
    value_estimates = np.append(value_estimates, value_next)
    delta_t = rewards + gamma * value_estimates[1:] - value_estimates[:-1]
    advantage = discount_rewards(r=delta_t, gamma=gamma * lambd)
    return advantage
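
The two return estimators are related: with lambd=1.0 the lambda return ignores the intermediate value estimates and collapses to the bootstrapped discounted return that discount_rewards computes, and with lambd=0.0 it reduces to one-step TD targets. A quick numerical check, assuming the two functions above are in scope:

import numpy as np

r = np.array([1.0, 1.0, 1.0], dtype=np.float32)
v = np.array([0.5, 0.5, 0.5], dtype=np.float32)

# lambd=1.0 matches the discounted return bootstrapped by value_next.
assert np.allclose(
    lambd_return(r, v, gamma=0.99, lambd=1.0, value_next=0.5),
    discount_rewards(r, gamma=0.99, value_next=0.5),
)

# lambd=0.0 bootstraps one step ahead at every index: r[t] + gamma * v[t + 1].
td_targets = lambd_return(r, v, gamma=0.99, lambd=0.0, value_next=0.5)
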

ml-agents/mlagents/trainers/torch/networks.py (140 changes)

    return encoding, memories


class MultiInputNetworkBody(torch.nn.Module, Critic):
    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        # Scale network depending on num agents
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
        self.processors, _input_size = ModelUtils.create_input_processors(
            observation_specs,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )
        self.action_spec = action_spec
        # Modules for self-attention
        obs_only_ent_size = sum(_input_size)
        q_ent_size = (
            sum(_input_size)
            + sum(self.action_spec.discrete_branches)
            + self.action_spec.continuous_size
        )
        self.obs_encoder = EntityEmbedding(
            0, obs_only_ent_size, None, self.h_size, concat_self=False
        )
        self.obs_action_encoder = EntityEmbedding(
            0, q_ent_size, None, self.h_size, concat_self=False
        )
        self.self_attn = ResidualSelfAttention(self.h_size)
        self.linear_encoder = LinearEncoder(
            self.h_size,
            network_settings.num_layers,
            self.h_size,
            kernel_gain=(0.125 / self.h_size) ** 0.5,
        )
        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None  # type: ignore

    @property
    def memory_size(self) -> int:
        return self.lstm.memory_size if self.use_lstm else 0

    def update_normalization(self, buffer: AgentBuffer) -> None:
        obs = ObsUtil.from_buffer(buffer, len(self.processors))
        for vec_input, enc in zip(obs, self.processors):
            if isinstance(enc, VectorInput):
                enc.update_normalization(torch.as_tensor(vec_input))

    def copy_normalization(self, other_network: "MultiInputNetworkBody") -> None:
        if self.normalize:
            for n1, n2 in zip(self.processors, other_network.processors):
                if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
                    n1.copy_normalization(n2)

    def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor:
        """
        Get attention masks by grabbing an arbitrary obs across all the agents.
        Since these are raw obs, the padded values are still NaN.
        """
        only_first_obs = [_all_obs[0] for _all_obs in obs_tensors]
        obs_for_mask = torch.stack(only_first_obs, dim=1)
        # Get the mask from NaNs
        attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
        return attn_mask
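
Padded agent slots arrive as NaN observations, and this helper converts them into a float attention mask (1 marks a padded entity) before the NaNs are zeroed out for encoding. A toy reproduction of the masking step:

import torch

# Two agent slots; the second is padding, marked by NaNs in the raw obs.
obs_agent_1 = torch.tensor([[0.1, 0.2]])                       # (batch=1, obs_size=2)
obs_agent_2 = torch.tensor([[float("nan"), float("nan")]])
obs_for_mask = torch.stack([obs_agent_1, obs_agent_2], dim=1)  # (1, 2, 2)
attn_mask = torch.any(obs_for_mask.isnan(), dim=2).float()
print(attn_mask)  # tensor([[0., 1.]]) -- 1 flags the padded entity
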
    def forward(
        self,
        obs_only: List[List[torch.Tensor]],
        obs: List[List[torch.Tensor]],
        actions: List[AgentAction],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self_attn_masks = []
        self_attn_inputs = []
        concat_f_inp = []
        if obs is not None and actions is not None:
            # Encode groupmates whose actions are known (the "f" entities).
            for inputs, action in zip(obs, actions):
                encodes = []
                for idx, processor in enumerate(self.processors):
                    obs_input = inputs[idx]
                    obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
                    processed_obs = processor(obs_input)
                    encodes.append(processed_obs)
                cat_encodes = [
                    torch.cat(encodes, dim=-1),
                    action.to_flat(self.action_spec.discrete_branches),
                ]
                concat_f_inp.append(torch.cat(cat_encodes, dim=1))
        if concat_f_inp:
            f_inp = torch.stack(concat_f_inp, dim=1)
            self_attn_masks.append(self._get_masks_from_nans(obs))
            self_attn_inputs.append(self.obs_action_encoder(None, f_inp))

        # Encode observation-only agents (the "g" entities).
        concat_encoded_obs = []
        for inputs in obs_only:
            encodes = []
            for idx, processor in enumerate(self.processors):
                obs_input = inputs[idx]
                obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            concat_encoded_obs.append(torch.cat(encodes, dim=-1))
        g_inp = torch.stack(concat_encoded_obs, dim=1)
        self_attn_masks.append(self._get_masks_from_nans(obs_only))
        self_attn_inputs.append(self.obs_encoder(None, g_inp))

        encoded_entity = torch.cat(self_attn_inputs, dim=1)
        encoded_state = self.self_attn(encoded_entity, self_attn_masks)
        encoding = self.linear_encoder(encoded_state)
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
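
Each agent becomes one entity for the ResidualSelfAttention block: obs_action_encoder embeds agents whose actions are known (the "f" branch), obs_encoder embeds observation-only agents (the "g" branch), and the two entity sets are concatenated along dim 1 and attended over jointly; in the ml-agents layers the attention block then average-pools over entities, so the LinearEncoder sees a single (batch, h_size) vector. A call sketch with hypothetical variables and illustrative shapes:

# Group of three agents: two observation-only, one with its action attached.
# g_inp stacks to (batch, 2, obs_only_ent_size); f_inp to (batch, 1, q_ent_size).
# After the EntityEmbeddings and concatenation, encoded_entity is (batch, 3, h_size);
# after self_attn's masked pooling, encoded_state is (batch, h_size).
encoding, memories = body(
    obs_only=[agent1_obs, agent2_obs],  # List[List[torch.Tensor]], one list per agent
    obs=[agent3_obs],
    actions=[agent3_action],            # List[AgentAction], aligned with obs
    memories=None,
    sequence_length=1,
)
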
class Critic(abc.ABC):
    @abc.abstractmethod
    def update_normalization(self, buffer: AgentBuffer) -> None:
