浏览代码

Separate Actor/Critic, remove ActorCritics

/develop/action-slice
Andrew Cohen 4 年前
当前提交
eeabb974
共有 1 个文件被更改,包括 59 次插入234 次删除
  1. 293
      ml-agents/mlagents/trainers/torch/networks.py

293
ml-agents/mlagents/trainers/torch/networks.py


encoding_size = network_settings.hidden_units
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)
def update_normalization(self, buffer: AgentBuffer) -> None:
self.network_body.update_normalization(buffer)
def critic_pass(
self,
inputs: List[torch.Tensor],
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
Get value outputs for the given obs.
:param inputs: List of inputs as tensors.
:param memories: Tensor of memories, if using memory. Otherwise, None.
:returns: Dict of reward stream to output tensor for values.
"""
value_outputs, critic_mem_out = self.forward(
inputs, memories=memories, sequence_length=sequence_length
)
return value_outputs, critic_mem_out
def forward(
self,
inputs: List[torch.Tensor],

"""
pass
def get_action_stats(
def get_action_and_stats(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,

"""
Returns sampled actions.
If memory is enabled, return the memories as well.
:param vec_inputs: A List of vector inputs as tensors.
:param vis_inputs: A List of visual inputs as tensors.
:param inputs: A List of inputs as tensors.
:param masks: If using discrete actions, a Tensor of action masks.
:param memories: If using memory, a Tensor of initial memories.
:param sequence_length: If using memory, the sequence length.

pass
@abc.abstractmethod
def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
var_len_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[Union[int, torch.Tensor], ...]:
"""
Forward pass of the Actor for inference. This is required for export to ONNX, and
the inputs and outputs of this method should not be changed without a respective change
in the ONNX export code.
"""
pass
class ActorCritic(Actor):
@abc.abstractmethod
def critic_pass(
def get_stats(
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
Get value outputs for the given obs.
:param inputs: List of inputs as tensors.
:param memories: Tensor of memories, if using memory. Otherwise, None.
:returns: Dict of reward stream to output tensor for values.
"""
pass
@abc.abstractmethod
def get_action_stats_and_value(
self,
inputs: List[torch.Tensor],
actions: AgentAction,
) -> Tuple[
AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
) -> Tuple[ActionLogProbs, torch.Tensor]:
Returns sampled actions and value estimates.
Returns log_probs for actions and entropies.
:param inputs: A List of vector inputs as tensors.
:param inputs: A List of inputs as tensors.
:param actions: AgentAction of actions.
:return: A Tuple of AgentAction, ActionLogProbs, entropies, Dict of reward signal
name to value estimate, and memories. Memories will be None if not using memory.
:return: A Tuple of AgentAction, ActionLogProbs, entropies, and memories.
Memories will be None if not using memory.
@abc.abstractproperty
def memory_size(self):
@abc.abstractmethod
def forward(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
var_len_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[Union[int, torch.Tensor], ...]:
Returns the size of the memory (same size used as input and output in the other
methods) used by this Actor.
Forward pass of the Actor for inference. This is required for export to ONNX, and
the inputs and outputs of this method should not be changed without a respective change
in the ONNX export code.
"""
pass

def update_normalization(self, buffer: AgentBuffer) -> None:
self.network_body.update_normalization(buffer)
def get_action_stats(
def get_action_and_stats(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,

)
action, log_probs, entropies = self.action_model(encoding, masks)
return action, log_probs, entropies, memories
def get_stats(
self,
inputs: List[torch.Tensor],
actions: AgentAction,
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[ActionLogProbs, torch.Tensor]:
encoding, actor_mem_outs = self.network_body(
inputs, memories=memories, sequence_length=sequence_length
)
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
return log_probs, entropies
def forward(
self,

self.act_size_vector_deprecated,
]
return tuple(export_out)
class SharedActorCritic(SimpleActor, ActorCritic):
def __init__(
self,
observation_specs: List[ObservationSpec],
network_settings: NetworkSettings,
action_spec: ActionSpec,
stream_names: List[str],
conditional_sigma: bool = False,
tanh_squash: bool = False,
):
self.use_lstm = network_settings.memory is not None
super().__init__(
observation_specs,
network_settings,
action_spec,
conditional_sigma,
tanh_squash,
)
self.stream_names = stream_names
self.value_heads = ValueHeads(stream_names, self.encoding_size)
def critic_pass(
self,
inputs: List[torch.Tensor],
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
encoding, memories_out = self.network_body(
inputs, memories=memories, sequence_length=sequence_length
)
return self.value_heads(encoding), memories_out
def get_stats_and_value(
self,
inputs: List[torch.Tensor],
actions: AgentAction,
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
encoding, memories = self.network_body(
inputs, memories=memories, sequence_length=sequence_length
)
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
value_outputs = self.value_heads(encoding)
return log_probs, entropies, value_outputs
def get_action_stats_and_value(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[
AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
encoding, memories = self.network_body(
inputs, memories=memories, sequence_length=sequence_length
)
action, log_probs, entropies = self.action_model(encoding, masks)
value_outputs = self.value_heads(encoding)
return action, log_probs, entropies, value_outputs, memories
class SeparateActorCritic(SimpleActor, ActorCritic):
def __init__(
self,
observation_specs: List[ObservationSpec],
network_settings: NetworkSettings,
action_spec: ActionSpec,
stream_names: List[str],
conditional_sigma: bool = False,
tanh_squash: bool = False,
):
self.use_lstm = network_settings.memory is not None
super().__init__(
observation_specs,
network_settings,
action_spec,
conditional_sigma,
tanh_squash,
)
self.stream_names = stream_names
self.critic = ValueNetwork(stream_names, observation_specs, network_settings)
@property
def memory_size(self) -> int:
return self.network_body.memory_size + self.critic.memory_size
def _get_actor_critic_mem(
self, memories: Optional[torch.Tensor] = None
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.use_lstm and memories is not None:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
actor_mem, critic_mem = actor_mem.contiguous(), critic_mem.contiguous()
else:
critic_mem = None
actor_mem = None
return actor_mem, critic_mem
def critic_pass(
self,
inputs: List[torch.Tensor],
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
value_outputs, critic_mem_out = self.critic(
inputs, memories=critic_mem, sequence_length=sequence_length
)
if actor_mem is not None:
# Make memories with the actor mem unchanged
memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
else:
memories_out = None
return value_outputs, memories_out
def get_stats_and_value(
self,
inputs: List[torch.Tensor],
actions: AgentAction,
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
encoding, actor_mem_outs = self.network_body(
inputs, memories=actor_mem, sequence_length=sequence_length
)
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
value_outputs, critic_mem_outs = self.critic(
inputs, memories=critic_mem, sequence_length=sequence_length
)
return log_probs, entropies, value_outputs
def get_action_stats(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
action, log_probs, entropies, actor_mem_out = super().get_action_stats(
inputs, masks=masks, memories=actor_mem, sequence_length=sequence_length
)
if critic_mem is not None:
# Make memories with the actor mem unchanged
memories_out = torch.cat([actor_mem_out, critic_mem], dim=-1)
else:
memories_out = None
return action, log_probs, entropies, memories_out
def get_action_stats_and_value(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[
AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
encoding, actor_mem_outs = self.network_body(
inputs, memories=actor_mem, sequence_length=sequence_length
)
action, log_probs, entropies = self.action_model(encoding, masks)
value_outputs, critic_mem_outs = self.critic(
inputs, memories=critic_mem, sequence_length=sequence_length
)
if self.use_lstm:
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
else:
mem_out = None
return action, log_probs, entropies, value_outputs, mem_out
def update_normalization(self, buffer: AgentBuffer) -> None:
super().update_normalization(buffer)
self.critic.network_body.update_normalization(buffer)
class GlobalSteps(nn.Module):

正在加载...
取消
保存