
Seems to speed it up

/develop/add-fire/exp
Ervin Teng, 4 years ago
Current commit: 565f92ef
2 files changed, 15 insertions(+), 21 deletions(-)
  1. ml-agents/mlagents/trainers/models_torch.py (20 changes)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (16 changes)
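Taken together, the hunks below suggest where the speedup comes from: the policy's default tensor type moves from 64-bit DoubleTensor to 32-bit FloatTensor, get_dist_and_value appears to be renamed to forward so the actor-critic module can be called directly, the separate_critic branch is bypassed in favor of always using the critic network, and thread-count diagnostics are added for CPU profiling.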

ml-agents/mlagents/trainers/models_torch.py (20 changes)


         entropies = entropies.squeeze(-1)
         return log_probs, entropies

-    def get_dist_and_value(
+    def forward(

-        if self.separate_critic:
-            value_outputs = self.critic(vec_inputs, vis_inputs)
-        else:
-            value_outputs = self.value_heads(embedding)
+        value_outputs = self.critic(vec_inputs, vis_inputs)
+        # if self.separate_critic:
+        #     value_outputs = self.critic(vec_inputs, vis_inputs)
+        # else:
+        #     value_outputs = self.value_heads(embedding)

-    def forward(
-        self, vec_inputs, vis_inputs, masks=None, memories=None, sequence_length=1
-    ):
-        dists, value_outputs, memories = self.get_dist_and_value(
-            vec_inputs, vis_inputs, masks, memories, sequence_length
-        )
-        sampled_actions = self.sample_action(dists)
-        return sampled_actions, memories

 class Critic(nn.Module):
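Why the policy-side call can change from self.actor_critic.get_dist_and_value(...) to self.actor_critic(...): PyTorch's nn.Module.__call__ dispatches to forward() (and also runs any registered hooks), so once the method is named forward the module can be invoked directly. A minimal, self-contained sketch with a made-up TinyActorCritic class, not the ml-agents code:

import torch
from torch import nn


class TinyActorCritic(nn.Module):
    # Hypothetical stand-in for the real ActorCritic; forward returns a pair,
    # the way the renamed method returns dists and value outputs.
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(4, 2)

    def forward(self, obs):
        logits = self.body(obs)
        value = logits.mean(dim=-1, keepdim=True)
        return logits, value


model = TinyActorCritic()
obs = torch.zeros(1, 4)
logits, value = model(obs)            # nn.Module.__call__ -> forward, plus hooks
logits2, value2 = model.forward(obs)  # same result, but skips the hook machinery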

ml-agents/mlagents/trainers/policy/torch_policy.py (16 changes)


 EPSILON = 1e-7  # Small value to avoid divide by zero
+print("Torch threads", torch.get_num_threads())
+print("Torch intra-op threads", torch.get_num_interop_threads())
+# torch.set_num_interop_threads(8)
+# torch.set_num_threads(6)
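The prints above report PyTorch's CPU thread pools; for reference, torch.get_num_threads() is the intra-op pool (threads used inside a single op) and torch.get_num_interop_threads() is the inter-op pool (independent ops run in parallel). A small sketch of how the commented-out knobs would be used, assuming the 6/8 values are just this experiment's guesses:

import torch

# Inter-op threads must be set before any parallel work has started,
# so these calls belong at import/startup time.
# torch.set_num_interop_threads(8)  # parallelism across independent ops
# torch.set_num_threads(6)          # parallelism inside a single op's kernels

print("intra-op threads:", torch.get_num_threads())
print("inter-op threads:", torch.get_num_interop_threads())
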
 class TorchPolicy(Policy):
     def __init__(

         self.inference_dict: Dict[str, tf.Tensor] = {}
         self.update_dict: Dict[str, tf.Tensor] = {}
-        torch.set_default_tensor_type(torch.DoubleTensor)
+        # TF defaults to 32-bit, so we use the same here.
+        torch.set_default_tensor_type(torch.FloatTensor)
         reward_signal_configs = trainer_params["reward_signals"]
         self.stats_name_to_update_name = {
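A quick demonstration of what the DoubleTensor-to-FloatTensor switch changes (assumption: much of the observed speedup comes from doing the math in float32 rather than float64, matching TensorFlow's 32-bit default):

import torch

torch.set_default_tensor_type(torch.DoubleTensor)
print(torch.zeros(2).dtype)  # torch.float64

torch.set_default_tensor_type(torch.FloatTensor)
print(torch.zeros(2).dtype)  # torch.float32, the new default for fresh tensors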

     @timed
     def sample_actions(self, vec_obs, vis_obs, masks=None, memories=None, seq_len=1):
-        dists, (
-            value_heads,
-            mean_value,
-        ), memories = self.actor_critic.get_dist_and_value(
+        dists, (value_heads, mean_value), memories = self.actor_critic(
             vec_obs, vis_obs, masks, memories, seq_len
         )

     def evaluate_actions(
         self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
     ):
-        dists, (value_heads, mean_value), _ = self.actor_critic.get_dist_and_value(
+        dists, (value_heads, mean_value), _ = self.actor_critic(
             vec_obs, vis_obs, masks, memories, seq_len
         )
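For context on what evaluate_actions consumes downstream (the log_probs and entropies returned in models_torch.py), here is a hypothetical plain-PyTorch sketch, not the ml-agents implementation: the actor's distributions are queried for the log-probability of previously taken actions and for their entropy.

import torch
from torch.distributions import Categorical

logits = torch.randn(5, 3)          # stand-in for one action branch's logits
dist = Categorical(logits=logits)
actions = dist.sample()

log_probs = dist.log_prob(actions)  # feeds the policy-gradient loss
entropies = dist.entropy()          # feeds the entropy bonus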

         run_out["learning_rate"] = 0.0
         if self.use_recurrent:
             run_out["memories"] = memories.detach().numpy()
         self.actor_critic.update_normalization(vec_obs)
         return run_out

     def get_action(
