|
|
|
|
|
|
|
|
|
|
EPSILON = 1e-7 # Small value to avoid division by zero
|
|
|
|
|
|
|
# Report PyTorch's CPU thread-pool sizes for diagnostics.
# torch.get_num_threads() is the intra-op pool (parallelism inside one op).
print("Torch threads", torch.get_num_threads())
# torch.get_num_interop_threads() is the inter-op pool (parallelism across
# independent ops), so it must not be labelled "intra-op".
print("Torch inter-op threads", torch.get_num_interop_threads())
|
|
|
|
|
|
|
# torch.set_num_interop_threads(8) |
|
|
|
# torch.set_num_threads(6) |
|
|
|
|
|
|
|
|
|
|
|
class TorchPolicy(Policy): |
|
|
|
def __init__( |
|
|
|
|
|
|
|
|
|
|
self.inference_dict: Dict[str, tf.Tensor] = {} |
|
|
|
self.update_dict: Dict[str, tf.Tensor] = {} |
|
|
|
# TF defaults to 32-bit, so we use the same here. |
|
|
|
torch.set_default_tensor_type(torch.FloatTensor) |
|
|
|
|
|
|
|
reward_signal_configs = trainer_params["reward_signals"] |
|
|
|
self.stats_name_to_update_name = { |
|
|
|
"Losses/Value Loss": "value_loss", |
|
|
|
|
|
|
|
|
|
|
@timed |
|
|
|
def sample_actions(self, vec_obs, vis_obs, masks=None, memories=None, seq_len=1): |
|
|
|
dists, (value_heads, mean_value), memories = self.actor_critic( |
|
|
|
dists, memories = self.actor_critic.evaluate( |
|
|
|
vec_obs, vis_obs, masks, memories, seq_len |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
actions.squeeze_(-1) |
|
|
|
|
|
|
|
return actions, log_probs, entropies, value_heads, memories |
|
|
|
return actions, log_probs, entropies, memories |
|
|
|
|
|
|
|
def evaluate_actions( |
|
|
|
self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1 |
|
|
|
|
|
|
|
|
|
|
run_out = {} |
|
|
|
with torch.no_grad(): |
|
|
|
action, log_probs, entropy, value_heads, memories = self.sample_actions( |
|
|
|
action, log_probs, entropy, memories = self.sample_actions( |
|
|
|
vec_obs, vis_obs, masks=masks, memories=memories |
|
|
|
) |
|
|
|
run_out["action"] = action.detach().numpy() |
|
|
|
|
|
|
run_out["entropy"] = entropy.detach().numpy() |
|
|
|
run_out["value_heads"] = { |
|
|
|
name: t.detach().numpy() for name, t in value_heads.items() |
|
|
|
} |
|
|
|
run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0) |
|
|
|
run_out["learning_rate"] = 0.0 |
|
|
|
if self.use_recurrent: |
|
|
|
run_out["memories"] = memories.detach().numpy() |
|
|
|