
Running LSTM for SAC

/develop/add-fire/memoryclass
Ervin Teng, 4 years ago
Current commit
37f986c8
1 file changed, 85 insertions(+), 15 deletions(-)
ml-agents/mlagents/trainers/sac/optimizer_torch.py (100 lines changed)


import numpy as np
from typing import Dict, List, Mapping, cast, Tuple
from typing import Dict, List, Mapping, cast, Tuple, Optional
import attr
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType

self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
actions: torch.Tensor = None,
actions: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions)
q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions)
q1_out, _ = self.q1_network(
vec_inputs,
vis_inputs,
actions=actions,
memories=memories,
sequence_length=sequence_length,
)
q2_out, _ = self.q2_network(
vec_inputs,
vis_inputs,
actions=actions,
memories=memories,
sequence_length=sequence_length,
)
return q1_out, q2_out
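The change above threads two new arguments, memories and sequence_length, into both Q streams before returning q1_out and q2_out. Below is a minimal, self-contained sketch of how an LSTM-backed Q stream typically consumes those arguments: the flat batch is folded back into contiguous sequences, and memories seeds the LSTM's hidden and cell state. This is an illustrative stand-in, not the ml-agents NetworkBody; the class and argument names are assumptions.

```python
import torch
from torch import nn
from typing import Optional, Tuple


class LSTMQStream(nn.Module):
    """Toy Q stream: encoder -> LSTM -> scalar head. A sketch of how a flat
    batch plus `memories` and `sequence_length` become an LSTM call; not the
    actual ml-agents network."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.encoder = nn.Linear(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.q_head = nn.Linear(hidden_size, 1)

    def forward(
        self,
        inputs: torch.Tensor,                     # (batch, input_size), batch = n_seq * sequence_length
        memories: Optional[torch.Tensor] = None,  # (1, n_seq, 2 * hidden): h and c concatenated
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden = self.encoder(inputs)
        # Fold the flat batch back into trajectories so the LSTM sees
        # `sequence_length` consecutive steps per sequence.
        hidden = hidden.reshape(-1, sequence_length, hidden.shape[-1])
        if memories is None:
            memories = torch.zeros(1, hidden.shape[0], 2 * self.lstm.hidden_size)
        h0, c0 = torch.split(memories, self.lstm.hidden_size, dim=-1)
        out, (hn, cn) = self.lstm(hidden, (h0.contiguous(), c0.contiguous()))
        q = self.q_head(out.reshape(-1, out.shape[-1]))
        return q, torch.cat([hn, cn], dim=-1)
```

Passing the same memories and sequence_length to both q1_network and q2_network, as the diff does, keeps the two streams conditioned on the same recurrent state for each sampled sequence.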
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):

for name in self.stream_names
}
# Critics should have 1/2 of the memory of the policy
critic_memory = policy_network_settings.memory
if critic_memory is not None:
critic_memory = attr.evolve(
critic_memory, memory_size=critic_memory.memory_size // 2
)
value_network_settings = attr.evolve(
policy_network_settings, memory=critic_memory
)
policy_network_settings,
value_network_settings,
policy_network_settings,
value_network_settings,
)
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
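The new block above derives the critic's settings from the policy's by halving memory_size. attr.evolve copies an attrs instance with selected fields replaced, so the policy's settings object is never mutated. A minimal sketch with stand-in settings classes (illustrative stand-ins for the attrs-based settings in ml-agents, not the real definitions):

```python
import attr
from typing import Optional


@attr.s(auto_attribs=True)
class MemorySettings:
    sequence_length: int = 64
    memory_size: int = 256


@attr.s(auto_attribs=True)
class NetworkSettings:
    hidden_units: int = 128
    memory: Optional[MemorySettings] = None


policy_settings = NetworkSettings(hidden_units=256, memory=MemorySettings(memory_size=256))

# Copy-with-override: the critic gets half of the policy's memory_size,
# and the original policy settings are left untouched.
critic_memory = policy_settings.memory
if critic_memory is not None:
    critic_memory = attr.evolve(critic_memory, memory_size=critic_memory.memory_size // 2)
value_settings = attr.evolve(policy_settings, memory=critic_memory)

print(value_settings.memory.memory_size)   # 128
print(policy_settings.memory.memory_size)  # still 256
```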

v_backup = min_policy_qs[name] - torch.sum(
_ent_coef * log_probs, dim=1
)
# print(log_probs, v_backup, _ent_coef, loss_masks)
value_loss = 0.5 * torch.mean(
loss_masks * torch.nn.functional.mse_loss(values[name], v_backup)
)
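For context, the value head here is regressed toward the soft value backup V_target(s) = min_i Q_i(s, a) - sum over branches of alpha * log pi(a|s), masked so that padded steps in truncated sequences do not contribute. A small self-contained sketch; the shapes, the alpha values, and the explicit reduction="none" masking are assumptions for illustration, not a quote of the optimizer:

```python
import torch

# Hypothetical shapes: 6 timesteps, 2 discrete action branches.
min_policy_q = torch.rand(6)              # min over the two Q streams, per timestep
log_probs = torch.log(torch.rand(6, 2))   # log pi(a|s), per branch
ent_coef = torch.tensor([0.2, 0.2])       # one alpha per branch
loss_masks = torch.tensor([1, 1, 1, 1, 0, 0], dtype=torch.float32)  # 0 = padding

# Soft value backup: the entropy term is folded into the target.
v_backup = min_policy_q - torch.sum(ent_coef * log_probs, dim=1)

values = torch.rand(6)  # current value-head predictions
# Per-step squared error, masked so padded steps don't pull on the value head.
per_step = torch.nn.functional.mse_loss(values, v_backup.detach(), reduction="none")
value_loss = 0.5 * torch.sum(loss_masks * per_step) / torch.clamp(loss_masks.sum(), min=1.0)
```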

else:
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
memories = [
memories_list = [
if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
offset = 1 if self.policy.sequence_length > 1 else 0
next_memories_list = [
ModelUtils.list_to_tensor(
batch["memory"][i][: self.policy.m_size // 2]
) # only pass value part of memory to target network
for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
]
if len(memories_list) > 0:
memories = torch.stack(memories_list).unsqueeze(0)
next_memories = torch.stack(next_memories_list).unsqueeze(0)
else:
memories = None
next_memories = None
# Q network memories are 0'ed out, since we don't have them during inference.
q_memories = torch.zeros(
(memories.shape[0], memories.shape[1], memories.shape[2] // 2)
)
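The block above rebuilds per-sequence initial memories from the flat replay batch: one saved memory per sequence_length steps for the policy/value network, the memory one step later (value half only, per the diff's comment) for the target network, and zeros for the Q networks, which never carry memories at inference time. A worked sketch with toy sizes; sequence_length, m_size, and the tensors are made up for illustration:

```python
import torch

# Toy sizes: 2 sequences of length 3, full memory of 8 floats per step.
sequence_length = 3
m_size = 8
batch_memory = [torch.arange(m_size, dtype=torch.float32) + 10 * t for t in range(6)]

# One initial memory per sequence: take every sequence_length-th saved state.
memories_list = [batch_memory[i] for i in range(0, len(batch_memory), sequence_length)]

# For the target network, take the memory at the next step of each sequence,
# and only its value half; guard against sequence_length == 1.
offset = 1 if sequence_length > 1 else 0
next_memories_list = [
    batch_memory[i][: m_size // 2]
    for i in range(offset, len(batch_memory), sequence_length)
]

memories = torch.stack(memories_list).unsqueeze(0)            # (1, n_seq, m_size)
next_memories = torch.stack(next_memories_list).unsqueeze(0)  # (1, n_seq, m_size // 2)

# The Q networks don't keep memories during inference, so they get zeros
# of the value-half width.
q_memories = torch.zeros((memories.shape[0], memories.shape[1], memories.shape[2] // 2))
print(memories.shape, next_memories.shape, q_memories.shape)
```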
vis_obs: List[torch.Tensor] = []
next_vis_obs: List[torch.Tensor] = []
if self.policy.use_vis_obs:

)
if self.policy.use_continuous_act:
squeezed_actions = actions.squeeze(-1)
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions)
q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions)
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
sampled_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
squeezed_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs)
q1_out, q2_out = self.value_network(vec_obs, vis_obs)
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
target_values, _ = self.target_network(next_vec_obs, next_vis_obs)
target_values, _ = self.target_network(
next_vec_obs,
next_vis_obs,
memories=next_memories,
sequence_length=self.policy.sequence_length,
)
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32)
use_discrete = not self.policy.use_continuous_act
dones = ModelUtils.list_to_tensor(batch["done"])
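Elsewhere in the update, the target_values computed with next_memories feed the standard SAC critic target, with dones cutting the bootstrap and masks dropping padded steps. A hedged sketch with placeholder numbers, showing the standard form rather than quoting the optimizer:

```python
import torch

gamma = 0.99
rewards = torch.tensor([1.0, 0.0, 0.5, 0.0])
dones = torch.tensor([0.0, 0.0, 0.0, 1.0])
masks = torch.tensor([1.0, 1.0, 1.0, 1.0])
target_values = torch.tensor([2.0, 1.5, 1.0, 0.7])  # from the target network
q_estimates = torch.tensor([1.2, 1.1, 0.9, 0.4])    # from one Q stream

# Standard SAC critic target: r + gamma * (1 - done) * V_target(s').
q_backup = rewards + gamma * (1.0 - dones) * target_values

# Masked mean-squared error so padded timesteps don't count.
per_step = torch.nn.functional.mse_loss(q_estimates, q_backup.detach(), reduction="none")
q_loss = 0.5 * torch.sum(masks * per_step) / torch.clamp(masks.sum(), min=1.0)
```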
