
Experiment with JIT compiler

Branch: /develop/add-fire/jit
Ervin Teng, 4 years ago
Current commit: 72180f9b
5 files changed, 104 insertions and 62 deletions
  1. ml-agents/mlagents/trainers/distributions_torch.py (13 changes)
  2. ml-agents/mlagents/trainers/models_torch.py (102 changes)
  3. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (6 changes)
  4. ml-agents/mlagents/trainers/policy/torch_policy.py (36 changes)
  5. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (9 changes)

ml-agents/mlagents/trainers/distributions_torch.py (13 changes)


 class GaussianDistribution(nn.Module):
-    def __init__(self, hidden_size, num_outputs, conditional_sigma=False, **kwargs):
-        super(GaussianDistribution, self).__init__(**kwargs)
+    def __init__(self, hidden_size, num_outputs, conditional_sigma=False):
+        super(GaussianDistribution, self).__init__()
         self.conditional_sigma = conditional_sigma
         self.mu = nn.Linear(hidden_size, num_outputs)
         nn.init.xavier_uniform_(self.mu.weight, gain=0.01)

             torch.zeros(1, num_outputs, requires_grad=True)
         )
+    @torch.jit.ignore
-        if self.conditional_sigma:
-            log_sigma = self.log_sigma(inputs)
-        else:
-            log_sigma = self.log_sigma
+        # if self.conditional_sigma:
+        #     log_sigma = self.log_sigma(inputs)
+        # else:
+        log_sigma = self.log_sigma
         return [distributions.normal.Normal(loc=mu, scale=torch.exp(log_sigma))]
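Note on the changes above: torch.jit.script compiles a module's forward and any method marked @torch.jit.export, while @torch.jit.ignore leaves a method as plain eager Python. That matters here because torch.distributions objects cannot be represented in TorchScript, and because log_sigma appears to be either a Linear layer or a Parameter depending on conditional_sigma, which is not a type-stable attribute for the compiler; hence the decorator and the commented-out branch. A minimal sketch of the ignore mechanism, using made-up module and method names rather than the classes above:

import torch
import torch.nn as nn

class MuOnly(nn.Module):
    def __init__(self, hidden_size: int, num_outputs: int):
        super().__init__()
        self.mu = nn.Linear(hidden_size, num_outputs)

    @torch.jit.ignore
    def log_shape(self, x: torch.Tensor) -> None:
        # Arbitrary eager-only Python; TorchScript never compiles this body.
        print("input shape:", list(x.shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.log_shape(x)  # dispatched back to the Python interpreter at runtime
        return self.mu(x)

scripted = torch.jit.script(MuOnly(8, 2))  # compiles forward, skips log_shape
out = scripted(torch.randn(1, 8))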

ml-agents/mlagents/trainers/models_torch.py (102 changes)


                     h_size,
                 )
             )
+        self.vector_normalizers = nn.ModuleList(self.vector_normalizers)
+        self.vector_encoders = nn.ModuleList(self.vector_encoders)
+        self.visual_encoders = nn.ModuleList(self.visual_encoders)
         if use_lstm:

         for idx, vec_input in enumerate(vec_inputs):
             self.vector_normalizers[idx].update(vec_input)

-    def forward(self, vec_inputs, vis_inputs, memories=None, sequence_length=1):
+    def forward(
+        self,
+        vec_inputs,
+        vis_inputs,
+        memories=torch.tensor(1),
+        sequence_length=torch.tensor(1),
+    ):
-            if self.normalize:
-                vec_input = self.vector_normalizers[idx](vec_input)
+            # if self.normalize:
+            #     vec_input = self.vector_normalizers[idx](vec_input)
             hidden = encoder(vec_input)
             vec_embeds.append(hidden)

         embedding = torch.cat(vec_embeds + vis_embeds)
-        if self.use_lstm:
-            embedding = embedding.reshape([sequence_length, -1, self.h_size])
-            memories = torch.split(memories, self.m_size // 2, dim=-1)
-            embedding, memories = self.lstm(embedding, memories)
-            embedding = embedding.reshape([-1, self.m_size // 2])
-            memories = torch.cat(memories, dim=-1)
-        return embedding, memories
+        # if self.use_lstm:
+        #     embedding = embedding.reshape([sequence_length, -1, self.h_size])
+        #     memories = torch.split(memories, self.m_size // 2, dim=-1)
+        #     embedding, memories = self.lstm(embedding, memories)
+        #     embedding = embedding.reshape([-1, self.m_size // 2])
+        #     memories = torch.cat(memories, dim=-1)
+        return embedding, embedding
 class ActorCritic(nn.Module):

         self.stream_names = stream_names
         self.value_heads = ValueHeads(stream_names, embedding_size)

-    @torch.jit.ignore
-    def critic_pass(self, vec_inputs, vis_inputs, memories=None):
-        if self.separate_critic:
-            return self.critic(vec_inputs, vis_inputs)
-        else:
-            embedding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
-            return self.value_heads(embedding)
+    @torch.jit.export
+    def critic_pass(self, vec_inputs, vis_inputs, memories=torch.tensor(1)):
+        # if self.separate_critic:
+        value, mean_value = self.critic(vec_inputs, vis_inputs)
+        return {"extrinsic": value}, mean_value
+        # else:
+        #     embedding, _ = self.network_body(vec_inputs, vis_inputs, memories=memories)
+        #     return {"extrinsic" : self.value_heads(embedding)}

+    @torch.jit.ignore
     def sample_action(self, dists):
         actions = []
         for action_dist in dists:

         return actions

+    @torch.jit.ignore
     def get_probs_and_entropy(self, actions, dists):
         log_probs = []
         entropies = []

         entropies = entropies.squeeze(-1)
         return log_probs, entropies

+    @torch.jit.ignore
     def evaluate(
         self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
     ):

         return dists, memories

+    @torch.jit.export
+    def jit_forward(
+        self,
+        vec_inputs,
+        vis_inputs,
+        masks=torch.tensor(1),
+        memories=torch.tensor(1),
+        sequence_length=torch.tensor(1),
+    ):
+        fut = torch.jit._fork(
+            self.network_body, vec_inputs, vis_inputs, memories, sequence_length
+        )
+        embedding, memories = torch.jit._wait(fut)
+        value_outputs = self.critic_pass(vec_inputs, vis_inputs, memories)
+        return embedding, value_outputs, memories
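The jit_forward method above becomes the scripted entry point: @torch.jit.export asks the compiler to compile it even though it is not named forward, and torch.jit._fork / torch.jit._wait (the internal spellings of torch.jit.fork / torch.jit.wait) schedule the network body as an asynchronous task on the inter-op thread pool and then join its result. A small sketch of that fork/wait pattern with placeholder submodules, independent of ActorCritic:

import torch
import torch.nn as nn

class ForkedPair(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(4, 8)  # stand-in for the network body
        self.head = nn.Linear(4, 1)  # stand-in for a value head

    def forward(self, x: torch.Tensor):
        fut = torch.jit.fork(self.body, x)  # schedule the body on the inter-op thread pool
        value = self.head(x)                # other work can run while it executes
        embedding = torch.jit.wait(fut)     # block until the forked call finishes
        return embedding, value

scripted = torch.jit.script(ForkedPair())
embedding, value = scripted(torch.randn(2, 4))

The parallelism only kicks in once the module is scripted; in plain eager Python the forked call runs inline.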
+    @torch.jit.ignore
     def forward(
-        self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
+        self,
+        vec_inputs,
+        vis_inputs,
+        masks=torch.tensor(1),
+        memories=torch.tensor(1),
+        sequence_length=torch.tensor(1),
     ):
-        embedding, memories = self.network_body(
-            vec_inputs, vis_inputs, memories, sequence_length
+        embedding, value_outputs, memories = self.jit_forward(
+            vec_inputs, vis_inputs, masks, memories, sequence_length
         )
-        value_outputs = self.critic(vec_inputs, vis_inputs)
-        dists = self.distribution(embedding, masks=masks)
+        dists = self.get_dist(embedding, masks)

+    @torch.jit.ignore
+    def get_dist(self, embedding, masks):
+        return self.distribution(embedding, masks=masks)
 class Critic(nn.Module):
     def __init__(

 class Normalizer(nn.Module):
-    def __init__(self, vec_obs_size, **kwargs):
-        super(Normalizer, self).__init__(**kwargs)
+    def __init__(self, vec_obs_size):
+        super(Normalizer, self).__init__()
         self.normalization_steps = torch.tensor(1)
         self.running_mean = torch.zeros(vec_obs_size)
         self.running_variance = torch.ones(vec_obs_size)

         for name in stream_names:
             value = nn.Linear(input_size, 1)
             self.value_heads[name] = value
+            self.value = value
+        self.value_outputs = nn.ModuleDict({})
-        value_outputs = {}
-        for stream_name, _ in self.value_heads.items():
-            value_outputs[stream_name] = self.value_heads[stream_name](hidden).squeeze(
-                -1
-            )
-        return (
-            value_outputs,
-            torch.mean(torch.stack(list(value_outputs.values())), dim=0),
-        )
+        # self.__delattr__
+        # for stream_name, head in self.value_heads.items():
+        #     self.value_outputs[stream_name] = head(hidden).squeeze(
+        #         -1
+        #     )
+        return (self.value(hidden).squeeze(-1), self.value(hidden).squeeze(-1))

 class VectorEncoder(nn.Module):
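A pattern repeated throughout this file is replacing None defaults (memories=None, sequence_length=1) with one-element tensor sentinels such as torch.tensor(1), presumably so every argument of the scripted methods stays statically typed as Tensor. The more idiomatic TorchScript alternative is an explicit Optional[Tensor] annotation with a None check, which the compiler also accepts; a small self-contained sketch with placeholder names:

import torch
import torch.nn as nn
from typing import Optional

class MaybeRecurrent(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor, memories: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden = self.encoder(x)
        # TorchScript narrows Optional[Tensor] to Tensor after the explicit None check.
        if memories is not None:
            hidden = hidden + memories
        return hidden

scripted = torch.jit.script(MaybeRecurrent())
print(scripted(torch.ones(1, 4)))                    # memories omitted
print(scripted(torch.ones(1, 4), torch.zeros(1, 4)))  # memories supplied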

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (6 changes)


     def get_trajectory_value_estimates(
         self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
     ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
-        vector_obs = [torch.as_tensor(batch["vector_obs"])]
+        vector_obs = torch.as_tensor([batch["vector_obs"]])
         if self.policy.use_vis_obs:
             visual_obs = []
             for idx, _ in enumerate(

                 visual_obs.append(visual_ob)
         else:
-            visual_obs = []
+            visual_obs = torch.as_tensor([])
-        next_obs = [torch.as_tensor(next_obs).unsqueeze(0)]
+        next_obs = torch.as_tensor([np.expand_dims(next_obs, 0)])
         next_memory = torch.zeros([1, 1, self.policy.m_size])
         value_estimates, mean_value = self.policy.actor_critic.critic_pass(
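The conversions above swap list-of-tensor inputs (one entry per observation stream) for single stacked tensors before they reach the scripted critic_pass, whose exported signature carries no annotations and is therefore assumed by TorchScript to take plain Tensors rather than List[Tensor]. A quick sketch of the resulting shapes, with a made-up numpy batch standing in for the AgentBuffer fields:

import numpy as np
import torch

batch_vector_obs = np.random.rand(32, 8).astype(np.float32)  # hypothetical batch: 32 steps, 8 features

as_list = [torch.as_tensor(batch_vector_obs)]    # old style: Python list of per-stream tensors
as_tensor = torch.as_tensor([batch_vector_obs])  # new style: one tensor with a leading stream axis

print(as_list[0].shape)  # torch.Size([32, 8])
print(as_tensor.shape)   # torch.Size([1, 32, 8])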

ml-agents/mlagents/trainers/policy/torch_policy.py (36 changes)


 from mlagents.trainers.brain import BrainParameters
 from mlagents.trainers.models_torch import EncoderType, ActorCritic

+torch.set_num_interop_threads(6)
 EPSILON = 1e-7  # Small value to avoid divide by zero

             "Losses/Policy Loss": "policy_loss",
         }
-        self.actor_critic = ActorCritic(
-            h_size=int(trainer_params["hidden_units"]),
-            act_type=self.act_type,
-            vector_sizes=[brain.vector_observation_space_size],
-            act_size=brain.vector_action_space_size,
-            normalize=trainer_params["normalize"],
-            num_layers=int(trainer_params["num_layers"]),
-            m_size=trainer_params["memory_size"],
-            use_lstm=self.use_recurrent,
-            visual_sizes=brain.camera_resolutions,
-            vis_encode_type=EncoderType(
-                trainer_params.get("vis_encode_type", "simple")
-            ),
-            stream_names=list(reward_signal_configs.keys()),
-            separate_critic=self.use_continuous_act,
+        self.actor_critic = torch.jit.script(
+            ActorCritic(
+                h_size=int(trainer_params["hidden_units"]),
+                act_type=self.act_type,
+                vector_sizes=[brain.vector_observation_space_size],
+                act_size=brain.vector_action_space_size,
+                normalize=trainer_params["normalize"],
+                num_layers=int(trainer_params["num_layers"]),
+                m_size=trainer_params["memory_size"],
+                use_lstm=self.use_recurrent,
+                visual_sizes=brain.camera_resolutions,
+                vis_encode_type=EncoderType(
+                    trainer_params.get("vis_encode_type", "simple")
+                ),
+                stream_names=list(reward_signal_configs.keys()),
+                separate_critic=self.use_continuous_act,
+            )
         )
+        print(self.actor_critic)

     def split_decision_step(self, decision_requests):
         vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)

         self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
     ):
         dists, (value_heads, mean_value), _ = self.actor_critic(
-            vec_obs, vis_obs, masks, memories, seq_len
+            vec_obs, vis_obs, masks, memories, torch.as_tensor(seq_len)
         )
         log_probs, entropies = self.actor_critic.get_probs_and_entropy(actions, dists)
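The policy-level change is mostly mechanical: the ActorCritic is built exactly as before and then wrapped in torch.jit.script, torch.set_num_interop_threads(6) sizes the thread pool that scripted fork/wait tasks run on, and seq_len is converted with torch.as_tensor to match the tensor-typed sequence_length parameter. A stripped-down sketch of the same wiring, with a placeholder module standing in for ActorCritic:

import torch
import torch.nn as nn

# Must be called before any inter-op parallel work starts; 6 matches the value in the diff.
torch.set_num_interop_threads(6)

class PlaceholderActorCritic(nn.Module):
    def __init__(self, h_size: int, act_size: int):
        super().__init__()
        self.body = nn.Linear(h_size, h_size)
        self.head = nn.Linear(h_size, act_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.body(x)))

actor_critic = torch.jit.script(PlaceholderActorCritic(h_size=128, act_size=2))
print(actor_critic)       # a RecursiveScriptModule, mirroring print(self.actor_critic) above
print(actor_critic.code)  # the TorchScript generated for forward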

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (9 changes)


             old_values[name] = torch.as_tensor(batch["{}_value_estimates".format(name)])
             returns[name] = torch.as_tensor(batch["{}_returns".format(name)])

-        vec_obs = [torch.as_tensor(batch["vector_obs"])]
+        vec_obs = torch.as_tensor([batch["vector_obs"]])
         act_masks = torch.as_tensor(batch["action_mask"])
         if self.policy.use_continuous_act:
             actions = torch.as_tensor(batch["actions"]).unsqueeze(-1)

             torch.as_tensor(batch["memory"][i])
             for i in range(0, len(batch["memory"]), self.policy.sequence_length)
         ]
-        if len(memories) > 0:
-            memories = torch.stack(memories).unsqueeze(0)
+        memories = torch.as_tensor([1])
+        # if len(memories) > 0:
+        #     memories = torch.stack(memories).unsqueeze(0)
         if self.policy.use_vis_obs:
             vis_obs = []

                 vis_ob = torch.as_tensor(batch["visual_obs%d" % idx])
                 vis_obs.append(vis_ob)
         else:
-            vis_obs = []
+            vis_obs = torch.as_tensor([])
         log_probs, entropy, values = self.policy.evaluate_actions(
             vec_obs,
             vis_obs,
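Finally, the recurrent path is stubbed out here as well: the stacked-memory construction is commented out and memories becomes the one-element sentinel torch.as_tensor([1]), matching the tensor-sentinel convention in models_torch.py. For reference, the commented-out lines take the memory at the start of each sequence from a flat per-step buffer and add a batch axis; a tiny sketch with made-up sizes:

import torch

sequence_length = 4
flat_memories = [torch.randn(3) for _ in range(8)]  # hypothetical: 8 steps, memory size 3

# One memory per sequence, then a leading batch dimension.
memories = [flat_memories[i] for i in range(0, len(flat_memories), sequence_length)]
memories = torch.stack(memories).unsqueeze(0)
print(memories.shape)  # torch.Size([1, 2, 3])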
