
jit for continuous control

/develop/jit
Ruo-Ping Dong, 4 years ago
Current commit: f5dee9d1
7 changed files with 112 additions and 84 deletions
  1. ml-agents/mlagents/trainers/model_saver/torch_model_saver.py (2 changes)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (7 changes)
  3. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changes)
  4. ml-agents/mlagents/trainers/ppo/trainer.py (4 changes)
  5. ml-agents/mlagents/trainers/torch/distributions.py (48 changes)
  6. ml-agents/mlagents/trainers/torch/encoders.py (2 changes)
  7. ml-agents/mlagents/trainers/torch/networks.py (131 changes)

ml-agents/mlagents/trainers/model_saver/torch_model_saver.py (2 changes)


          }
          torch.save(state_dict, f"{checkpoint_path}.pt")
+         torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
-         self.export(checkpoint_path, behavior_name)
+         # self.export(checkpoint_path, behavior_name)
          return checkpoint_path

      def export(self, output_filepath: str, behavior_name: str) -> None:
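
The new save writes a second copy of the weights to a fixed checkpoint.pt, and the ONNX export call is commented out, presumably because export has to change once the module is scripted. For reference, a scripted module can be serialized whole with torch.jit.save; a minimal sketch (the TinyActor module and file name are hypothetical, not part of this commit):

import torch
from torch import nn

class TinyActor(nn.Module):
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return torch.tanh(obs)

scripted = torch.jit.script(TinyActor())
torch.jit.save(scripted, "tiny_actor.pt")   # one archive holding code and weights
restored = torch.jit.load("tiny_actor.pt")  # loadable without the Python class definition
print(restored(torch.zeros(1, 8)))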

ml-agents/mlagents/trainers/policy/torch_policy.py (7 changes)


              conditional_sigma=self.condition_sigma_on_obs,
              tanh_squash=tanh_squash,
          )
+         use_jit = True
+         if use_jit:
+             self.actor_critic = torch.jit.script(self.actor_critic)
+             print(self.actor_critic)
          # Save the m_size needed for export
          self._export_m_size = self.m_size
          # m_size needed for training is determined by network, not trainer settings
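
The use_jit block routes the whole actor-critic through torch.jit.script, which recursively compiles forward, everything forward calls, and any methods marked @torch.jit.export; the print dumps the resulting ScriptModule. A standalone sketch of the same call on a toy actor (MLPActor is illustrative, not this repo's class):

import torch
from torch import nn

class MLPActor(nn.Module):
    def __init__(self, obs_size: int, act_size: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(obs_size, 64), nn.ReLU(), nn.Linear(64, act_size)
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.body(obs))

actor = torch.jit.script(MLPActor(obs_size=8, act_size=2))
print(actor.code)                # the generated TorchScript for forward
out = actor(torch.randn(16, 8))  # used exactly like the eager module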

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changes)


          ]
          if len(memories) > 0:
              memories = torch.stack(memories).unsqueeze(0)
+         else:
+             memories = None
          if self.policy.use_vis_obs:
              vis_obs = []
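
The added else branch keeps the variable's type consistent for the scripted policy: TorchScript treats Tensor and Optional[Tensor] as distinct types, and every code path must produce a value of the declared type. The same pattern as a scripted function (stack_memories is a hypothetical helper):

import torch
from typing import List, Optional

@torch.jit.script
def stack_memories(mems: List[torch.Tensor]) -> Optional[torch.Tensor]:
    # Both branches must satisfy the declared Optional[Tensor] return type.
    if len(mems) > 0:
        return torch.stack(mems).unsqueeze(0)  # (1, num_sequences, m_size)
    return None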

ml-agents/mlagents/trainers/ppo/trainer.py (4 changes)


          agent_buffer_trajectory = trajectory.to_agentbuffer()
          # Update the normalization
-         if self.is_training:
-             self.policy.update_normalization(agent_buffer_trajectory["vector_obs"])
+         # if self.is_training:
+         #     self.policy.update_normalization(agent_buffer_trajectory["vector_obs"])
          # Get all value estimates
          value_estimates, value_next = self.optimizer.get_trajectory_value_estimates(
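
Disabling update_normalization is consistent with how scripting works: after torch.jit.script, only forward, methods it reaches, and methods marked @torch.jit.export survive on the ScriptModule, so a helper like this either needs the decorator or has to be bypassed while prototyping. A sketch of the decorator route (NormalizedBody and its update rule are assumptions, not this repo's code):

import torch
from torch import nn

class NormalizedBody(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x - self.running_mean

    @torch.jit.export
    def update_normalization(self, batch: torch.Tensor) -> None:
        # Exported so the method survives scripting; without the decorator
        # the scripted module would not expose it at all.
        self.running_mean.add_((batch.mean(dim=0) - self.running_mean) * 0.01)

scripted = torch.jit.script(NormalizedBody())
scripted.update_normalization(torch.randn(32, 4))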

ml-agents/mlagents/trainers/torch/distributions.py (48 changes)


  import abc
  from typing import List
+ import typing
  import torch
  from torch import nn
  import numpy as np

- EPSILON = 1e-7  # Small value to avoid divide by zero
+ # EPSILON = 1e-7  # Small value to avoid divide by zero

  class DistInstance(nn.Module, abc.ABC):

      pass

- class GaussianDistInstance(DistInstance):
+ class GaussianDistInstance:
-         super().__init__()
+         # super().__init__()
          self.mean = mean
          self.std = std

      def log_prob(self, value):
          var = self.std ** 2
-         log_scale = torch.log(self.std + EPSILON)
+         log_scale = torch.log(self.std)
-             -((value - self.mean) ** 2) / (2 * var + EPSILON)
+             -((value - self.mean) ** 2) / (2 * var)
              - log_scale
              - math.log(math.sqrt(2 * math.pi))
          )

          return torch.exp(log_prob)

      def entropy(self):
-         return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON)
+         return 0.5 * torch.log(2 * math.pi * math.e * self.std)

  class TanhGaussianDistInstance(GaussianDistInstance):

          return squashed

      def _inverse_tanh(self, value):
-         capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
-         return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)
+         capped_value = torch.clamp(value, -1, 1)
+         return 0.5 * torch.log((1 + capped_value) / (1 - capped_value))

      def log_prob(self, value):
          unsquashed = self.transform.inv(value)

- class CategoricalDistInstance(DiscreteDistInstance):
+ class CategoricalDistInstance:
-         super().__init__()
+         # super().__init__()
          self.logits = logits
          self.probs = torch.softmax(self.logits, dim=-1)

              torch.zeros(1, num_outputs, requires_grad=True)
          )

-     def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
+     def forward(self, inputs: torch.Tensor):
-         if self.conditional_sigma:
-             log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
-         else:
-             log_sigma = self.log_sigma
-         if self.tanh_squash:
-             return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
-         else:
-             return [GaussianDistInstance(mu, torch.exp(log_sigma))]
+         # if self.conditional_sigma:
+         #     log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
+         # else:
+         log_sigma = self.log_sigma
+         # if self.tanh_squash:
+         #     return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
+         # else:
+         return [GaussianDistInstance(mu, torch.exp(log_sigma))]

  class MultiCategoricalDistribution(nn.Module):

      def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
          raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask
          normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1)
-         normalized_logits = torch.log(normalized_probs + EPSILON)
+         normalized_logits = torch.log(normalized_probs)

-         start = int(np.sum(self.act_sizes[:idx]))
-         end = int(np.sum(self.act_sizes[: idx + 1]))
+         start = torch.sum(torch.tensor(self.act_sizes[:idx]))
+         end = torch.sum(torch.tensor(self.act_sizes[: idx + 1]))

-     def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
+     def forward(self, inputs: torch.Tensor, masks: torch.Tensor):
          # Todo - Support multiple branches in mask code
          branch_distributions = []
          masks = self._split_masks(masks)
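
Most of the edits above strip constructs early TorchScript could not compile: the abc.ABC bases and super().__init__() calls on the distribution instances, the module-level EPSILON global, and np.sum (swapped for torch.sum, since numpy is not scriptable). Note the EPSILON terms are dropped outright rather than inlined, which trades away the divide-by-zero guard. What remains are plain, fully annotated classes, which TorchScript compiles directly; a self-contained sketch of that shape (ScriptableGaussian is illustrative, not the file's exact class):

import math
import torch

@torch.jit.script
class ScriptableGaussian(object):
    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        # No nn.Module or abc.ABC base and no super() call:
        # TorchScript compiles this as a plain value class.
        self.mean = mean
        self.std = std

    def sample(self) -> torch.Tensor:
        return self.mean + torch.randn_like(self.std) * self.std

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        var = self.std ** 2
        return (
            -((value - self.mean) ** 2) / (2 * var)
            - torch.log(self.std)
            - math.log(math.sqrt(2 * math.pi))
        )

d = ScriptableGaussian(torch.zeros(3), torch.ones(3))
print(d.log_prob(d.sample()))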

ml-agents/mlagents/trainers/torch/encoders.py (2 changes)


              nn.LeakyReLU(),
          )

-     def forward(self, visual_obs: torch.Tensor) -> None:
+     def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
          hidden = self.conv_layers(visual_obs)
          hidden = torch.reshape(hidden, (-1, self.final_flat))
          hidden = self.dense(hidden)
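
The old "-> None" annotation was a pre-existing typo that eager PyTorch silently ignored; torch.jit.script actually type-checks annotations, so the declared return type must match what the body returns. A minimal illustration (DenseBlock is a toy module):

import torch
from torch import nn

class DenseBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 4)

    # Annotated "-> None", this forward would fail to script,
    # because the body returns a Tensor.
    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        return self.dense(visual_obs)

scripted = torch.jit.script(DenseBlock())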

ml-agents/mlagents/trainers/torch/networks.py (131 changes)


  from typing import Callable, List, Dict, Tuple, Optional
  import abc
+ import typing
  import torch
  from torch import nn

      GaussianDistribution,
      MultiCategoricalDistribution,
      DistInstance,
+     GaussianDistInstance,
+     CategoricalDistInstance,
  )
  from mlagents.trainers.settings import NetworkSettings
  from mlagents.trainers.torch.utils import ModelUtils

          else:
              self.lstm = None  # type: ignore
+         self.memory_size = 0

      def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
          for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
              vec_enc.update_normalization(vec_input)

          for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders):
              n1.copy_normalization(n2)

-     @property
-     def memory_size(self) -> int:
-         return self.lstm.memory_size if self.use_lstm else 0
+     # @property
+     # def memory_size(self) -> int:
+     #     return self.lstm.memory_size if self.use_lstm else 0

      def forward(
          self,

          memories: Optional[torch.Tensor] = None,
          sequence_length: int = 1,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-             if actions is not None:
-                 hidden = encoder(vec_input, actions)
-             else:
-                 hidden = encoder(vec_input)
+             # if actions is not None:
+             #     hidden = encoder(vec_input, actions)
+             # else:
+             #     hidden = encoder(vec_input)
+             hidden = encoder(vec_input)

-             if not torch.onnx.is_in_onnx_export():
-                 vis_input = vis_input.permute([0, 3, 1, 2])
+             # if not torch.onnx.is_in_onnx_export():
+             #     vis_input = vis_input.permute([0, 3, 1, 2])
              hidden = encoder(vis_input)
              encodes.append(hidden)

          for _enc in encodes[1:]:
              encoding += _enc

-         if self.use_lstm:
-             # Resize to (batch, sequence length, encoding size)
-             encoding = encoding.reshape([-1, sequence_length, self.h_size])
-             encoding, memories = self.lstm(encoding, memories)
-             encoding = encoding.reshape([-1, self.m_size // 2])
+         # if self.use_lstm:
+         #     # Resize to (batch, sequence length, encoding size)
+         #     encoding = encoding.reshape([-1, sequence_length, self.h_size])
+         #     encoding, memories = self.lstm(encoding, memories)
+         #     encoding = encoding.reshape([-1, self.m_size // 2])
          return encoding, memories
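
Two recurring changes start in this file. Every memory_size @property becomes a plain attribute, apparently because TorchScript did not support Python properties on modules when this branch was written, and return annotations widen to Optional[torch.Tensor] wherever memories may be None. Both patterns on a toy body (Body and its sizes are assumed):

import torch
from torch import nn
from typing import Optional, Tuple

class Body(nn.Module):
    def __init__(self, use_lstm: bool, m_size: int):
        super().__init__()
        # Plain int attribute instead of @property: scriptable, and still
        # readable from Python as scripted.memory_size.
        self.memory_size = m_size if use_lstm else 0

    def forward(
        self, x: torch.Tensor, memories: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Optional is required in the annotation once memories can be None.
        return x, memories

scripted = torch.jit.script(Body(use_lstm=False, m_size=256))
print(scripted.memory_size)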

          encoding_size = network_settings.hidden_units
          self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)
+         self.memory_size = self.network_body.memory_size

-     @property
-     def memory_size(self) -> int:
-         return self.network_body.memory_size
+     # @property
+     # def memory_size(self) -> int:
+     #     return self.network_body.memory_size

      def forward(
          self,

          memories: Optional[torch.Tensor] = None,
          sequence_length: int = 1,
-     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+     ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
          encoding, memories = self.network_body(
              vec_inputs, vis_inputs, actions, memories, sequence_length
          )

          """
          pass

-     @abc.abstractproperty
-     def memory_size(self):
-         """
-         Returns the size of the memory (same size used as input and output in the other
-         methods) used by this Actor.
-         """
-         pass
+     # @abc.abstractproperty
+     # def memory_size(self):
+     #     """
+     #     Returns the size of the memory (same size used as input and output in the other
+     #     methods) used by this Actor.
+     #     """
+     #     pass

  class SimpleActor(nn.Module, Actor):

          self.distribution = MultiCategoricalDistribution(
              self.encoding_size, act_size
          )
+         self.memory_size = 0

-     @property
-     def memory_size(self) -> int:
-         return self.network_body.memory_size
+     # @property
+     # def memory_size(self) -> int:
+     #     return self.network_body.memory_size

-     def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
+     @torch.jit.export
+     def sample_action(self, dists: List[GaussianDistInstance]) -> List[torch.Tensor]:
          actions = []
          for action_dist in dists:
              action = action_dist.sample()
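
sample_action gains @torch.jit.export so it stays callable on the scripted module, and its parameter narrows from List[DistInstance] to List[GaussianDistInstance]: TorchScript lists need a single concrete element type, and dispatching over the abstract base is not supported (which also pins this branch to continuous actions). A compilable sketch of the narrowed signature (Gaussian and Actor are toys echoing, not copying, the file):

import torch
from torch import nn
from typing import List

@torch.jit.script
class Gaussian(object):
    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean = mean
        self.std = std

    def sample(self) -> torch.Tensor:
        return self.mean + torch.randn_like(self.std) * self.std

class Actor(nn.Module):
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return obs

    @torch.jit.export
    def sample_action(self, dists: List[Gaussian]) -> List[torch.Tensor]:
        # Concrete element type; the abstract DistInstance would not compile.
        actions = []
        for d in dists:
            actions.append(d.sample())
        return actions

scripted = torch.jit.script(Actor())
acts = scripted.sample_action([Gaussian(torch.zeros(2), torch.ones(2))])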

          masks: Optional[torch.Tensor] = None,
          memories: Optional[torch.Tensor] = None,
          sequence_length: int = 1,
-     ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
+     ):
-         if self.act_type == ActionType.CONTINUOUS:
-             dists = self.distribution(encoding)
-         else:
-             dists = self.distribution(encoding, masks)
+         ##################
+         # if self.act_type == ActionType.CONTINUOUS:
+         dists = self.distribution(encoding)
+         # else:
+         #     dists = self.distribution(encoding, masks)

+     @torch.jit.ignore
      def forward(
          self,
          vec_inputs: List[torch.Tensor],

          self.stream_names = stream_names
          self.value_heads = ValueHeads(stream_names, self.encoding_size)

+     @torch.jit.export
      def critic_pass(
          self,
          vec_inputs: List[torch.Tensor],

          encoding, memories = self.network_body(
              vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
          )
-         if self.act_type == ActionType.CONTINUOUS:
-             dists = self.distribution(encoding)
-         else:
-             dists = self.distribution(encoding, masks=masks)
+         # if self.act_type == ActionType.CONTINUOUS:
+         #     dists = self.distribution(encoding)
+         # else:
+         dists = self.distribution(encoding, masks=masks)
          value_outputs = self.value_heads(encoding)
          return dists, value_outputs, memories
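
The two decorators above mark the compilation boundary: @torch.jit.export methods become extra entry points on the ScriptModule, while @torch.jit.ignore leaves a method (here forward, which mainly serves the ONNX export path) as eager Python that the compiler skips. A toy module showing both (illustrative only):

import torch
from torch import nn
from typing import Dict

class ToyCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.value = nn.Linear(8, 1)

    @torch.jit.export
    def critic_pass(self, obs: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Compiled, and callable on the scripted module.
        return {"extrinsic": self.value(obs)}

    @torch.jit.ignore
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Skipped by the compiler; still runs as ordinary Python if called.
        return self.value(obs)

scripted = torch.jit.script(ToyCritic())
values = scripted.critic_pass(torch.randn(4, 8))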

          )
          self.stream_names = stream_names
          self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)
+         # self.critic = torch.jit.script(ValueNetwork(stream_names, observation_shapes, network_settings))
+         self.memory_size = 0

-     @property
-     def memory_size(self) -> int:
-         return self.network_body.memory_size + self.critic.memory_size
+     # @property
+     # def memory_size(self) -> int:
+     #     return self.network_body.memory_size + self.critic.memory_size

+     @torch.jit.export
      def critic_pass(
          self,
          vec_inputs: List[torch.Tensor],

-     ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+     ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
-         if self.use_lstm:
-             # Use only the back half of memories for critic
-             actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
+         # if self.use_lstm:
+         #     # Use only the back half of memories for critic
+         #     actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
          value_outputs, critic_mem_out = self.critic(
              vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
          )

          memories_out = None
          return value_outputs, memories_out

+     @torch.jit.export
      def get_dist_and_value(
          self,
          vec_inputs: List[torch.Tensor],

          sequence_length: int = 1,
-     ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
-         if self.use_lstm:
-             # Use only the back half of memories for critic and actor
-             actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
-         else:
-             critic_mem = None
-             actor_mem = None
+     ):
+         # if self.use_lstm:
+         #     # Use only the back half of memories for critic and actor
+         #     actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
+         # else:
+         critic_mem = None
+         actor_mem = None
          dists, actor_mem_outs = self.get_dists(
              vec_inputs,
              vis_inputs,

          value_outputs, critic_mem_outs = self.critic(
              vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
          )
-         if self.use_lstm:
-             mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
-         else:
-             mem_out = None
+         # if self.use_lstm:
+         #     mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
+         # else:
+         mem_out = None
          return dists, value_outputs, mem_out
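
The commented torch.cat was the LSTM memory merge; a scriptable version needs an explicit None check, because TorchScript only narrows an Optional after an "is not None" test (type refinement). A sketch of how the merge could stay compilable (merge_memories is a hypothetical helper, not part of the commit):

import torch
from typing import Optional

@torch.jit.script
def merge_memories(
    actor_mem: Optional[torch.Tensor], critic_mem: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    # The "is not None" test refines Optional[Tensor] to Tensor,
    # so torch.cat type-checks.
    if actor_mem is not None and critic_mem is not None:
        return torch.cat([actor_mem, critic_mem], dim=-1)
    return None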
