
Remove another if statement

/develop/add-fire/exp
Ervin Teng, 5 years ago
Current commit
2fae31e6
3 changed files with 16 additions and 35 deletions
  1. ml-agents/mlagents/trainers/distributions_torch.py (2 changes)
  2. ml-agents/mlagents/trainers/models_torch.py (30 changes)
  3. ml-agents/mlagents/trainers/policy/torch_policy.py (19 changes)

ml-agents/mlagents/trainers/distributions_torch.py (2 changes)


            torch.zeros(1, num_outputs, requires_grad=True)
        )
-   def forward(self, inputs):
+   def forward(self, inputs, masks):
        mu = self.mu(inputs)
        if self.conditional_sigma:
            log_sigma = self.log_sigma(inputs)
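
The only change here is the extra masks parameter on the continuous (Gaussian) head, so every distribution module shares one call signature and callers no longer branch on action type (see the models_torch.py change below). A minimal, self-contained sketch of that idea; the class name, sizes, and the returned torch.distributions.Normal are illustrative stand-ins, not the ml-agents classes:

import torch
from torch import nn

class GaussianHeadSketch(nn.Module):
    """Illustrative Gaussian action head that accepts (and ignores) masks."""

    def __init__(self, hidden_size=16, num_outputs=2, conditional_sigma=False):
        super().__init__()
        self.conditional_sigma = conditional_sigma
        self.mu = nn.Linear(hidden_size, num_outputs)
        if conditional_sigma:
            # Sigma predicted from the input, mirroring conditional_sigma above.
            self.log_sigma = nn.Linear(hidden_size, num_outputs)
        else:
            # Otherwise a single learned parameter per action dimension.
            self.log_sigma = nn.Parameter(
                torch.zeros(1, num_outputs, requires_grad=True)
            )

    def forward(self, inputs, masks=None):
        # Continuous actions have nothing to mask; the parameter exists only
        # so callers can pass masks uniformly to any distribution head.
        mu = self.mu(inputs)
        log_sigma = self.log_sigma(inputs) if self.conditional_sigma else self.log_sigma
        return torch.distributions.Normal(mu, log_sigma.exp())

dist = GaussianHeadSketch()(torch.zeros(4, 16))
action = dist.sample()  # shape (4, 2)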

ml-agents/mlagents/trainers/models_torch.py (30 changes)


            hidden = encoder(vis_input)
            vis_embeds.append(hidden)
-       if len(vec_embeds) > 0:
-           vec_embeds = torch.cat(vec_embeds)
-       if len(vis_embeds) > 0:
-           vis_embeds = torch.cat(vis_embeds)
-       if len(vec_embeds) > 0 and len(vis_embeds) > 0:
-           embedding = torch.cat([vec_embeds, vis_embeds])
-       elif len(vec_embeds) > 0:
-           embedding = vec_embeds
-       else:
-           embedding = vis_embeds
+       embedding = torch.cat(vec_embeds + vis_embeds)
        if self.use_lstm:
            embedding = embedding.reshape([sequence_length, -1, self.h_size])
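
The single torch.cat line covers all three of the removed branches because the Python list concatenation happens before the tensor concatenation: whichever modality is absent contributes an empty list and simply disappears. A small sketch with made-up shapes (both encoders are assumed to emit same-sized embeddings, and at least one observation type must still exist, just as before):

import torch

vec_embeds = [torch.ones(4, 8)]   # stand-in for encoded vector observations
vis_embeds = []                   # no visual observations in this case

# `+` joins the Python lists first, then torch.cat concatenates whatever
# tensors are actually present -- no if/elif/else needed.
embedding = torch.cat(vec_embeds + vis_embeds)
print(embedding.shape)            # torch.Size([4, 8])

vis_embeds = [torch.ones(4, 8)]   # now both modalities are present
embedding = torch.cat(vec_embeds + vis_embeds)
print(embedding.shape)            # torch.Size([8, 8]), concatenated along dim 0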

        entropies = entropies.squeeze(-1)
        return log_probs, entropies
+   def evaluate(
+       self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
+   ):
+       embedding, memories = self.network_body(
+           vec_inputs, vis_inputs, memories, sequence_length
+       )
+       dists = self.distribution(embedding, masks=masks)
+       return dists, memories
    def forward(
        self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
    ):

        value_outputs = self.critic(vec_inputs, vis_inputs)
-       if self.act_type == ActionType.CONTINUOUS:
-           dists = self.distribution(embedding)
-       else:
-           dists = self.distribution(embedding, masks=masks)
-       # if self.separate_critic:
-       #     value_outputs = self.critic(vec_inputs, vis_inputs)
-       # else:
-       #     value_outputs = self.value_heads(embedding)
+       dists = self.distribution(embedding, masks=masks)
        return dists, value_outputs, memories
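
The net effect in this file is a split between two entry points: evaluate() produces only the action distributions (and recurrent memories), while forward() additionally runs the critic, and both now pass masks unconditionally. A condensed, self-contained sketch of that shape; the linear layers and the Normal distribution are placeholders for the real network_body, distribution, and critic modules:

import torch
from torch import nn

class ActorCriticSketch(nn.Module):
    def __init__(self, obs_size=8, act_size=2, hidden=16):
        super().__init__()
        self.network_body = nn.Linear(obs_size, hidden)   # placeholder encoder
        self.mu = nn.Linear(hidden, act_size)             # placeholder distribution head
        self.log_sigma = nn.Parameter(torch.zeros(1, act_size))
        self.critic = nn.Linear(hidden, 1)                # placeholder value head

    def distribution(self, embedding, masks=None):
        # masks are accepted everywhere and ignored for continuous actions.
        return torch.distributions.Normal(self.mu(embedding), self.log_sigma.exp())

    def evaluate(self, obs, masks=None):
        # Sampling path: distributions only, no value estimates.
        embedding = torch.relu(self.network_body(obs))
        return self.distribution(embedding, masks=masks)

    def forward(self, obs, masks=None):
        # Training path: distributions plus value outputs.
        embedding = torch.relu(self.network_body(obs))
        dists = self.distribution(embedding, masks=masks)
        value_outputs = self.critic(embedding)
        return dists, value_outputs

ac = ActorCriticSketch()
obs = torch.zeros(4, 8)
dists = ac.evaluate(obs)      # what the policy uses to sample actions
dists, values = ac(obs)       # what the trainer uses for updates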

ml-agents/mlagents/trainers/policy/torch_policy.py (19 changes)


EPSILON = 1e-7 # Small value to avoid divide by zero
print("Torch threads", torch.get_num_threads())
print("Torch intra-op threads", torch.get_num_interop_threads())
# torch.set_num_interop_threads(8)
# torch.set_num_threads(6)
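
The thread prints and the commented-out setters use PyTorch's standard CPU threading controls. A brief sketch of pinning them explicitly; the counts 6 and 8 are taken from the commented-out lines and are experiment-specific, not recommendations:

import torch

# The inter-op pool must be sized before parallel work starts, which is
# why such calls sit at module import time.
torch.set_num_interop_threads(8)   # parallelism across independent ops
torch.set_num_threads(6)           # parallelism inside a single op

print("Torch threads", torch.get_num_threads())
print("Torch inter-op threads", torch.get_num_interop_threads())
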
class TorchPolicy(Policy):
    def __init__(

        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.update_dict: Dict[str, tf.Tensor] = {}
        # TF defaults to 32-bit, so we use the same here.
        torch.set_default_tensor_type(torch.FloatTensor)
        reward_signal_configs = trainer_params["reward_signals"]
        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",

    @timed
    def sample_actions(self, vec_obs, vis_obs, masks=None, memories=None, seq_len=1):
-       dists, (value_heads, mean_value), memories = self.actor_critic(
+       dists, memories = self.actor_critic.evaluate(
            vec_obs, vis_obs, masks, memories, seq_len
        )

        actions.squeeze_(-1)
-       return actions, log_probs, entropies, value_heads, memories
+       return actions, log_probs, entropies, memories
    def evaluate_actions(
        self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1

        run_out = {}
        with torch.no_grad():
-           action, log_probs, entropy, value_heads, memories = self.sample_actions(
+           action, log_probs, entropy, memories = self.sample_actions(
                vec_obs, vis_obs, masks=masks, memories=memories
            )
        run_out["action"] = action.detach().numpy()

run_out["entropy"] = entropy.detach().numpy()
run_out["value_heads"] = {
name: t.detach().numpy() for name, t in value_heads.items()
}
run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
run_out["learning_rate"] = 0.0
if self.use_recurrent:
run_out["memories"] = memories.detach().numpy()
