|
|
|
|
|
|
self.critic.network_body.update_normalization(vector_obs) |
|
|
|
self.actor.network_body.update_normalization(vector_obs) |
|
|
|
|
|
|
|
def execute_model(self, vec_obs, vis_obs, masks=None): |
|
|
|
def execute_model(self, vec_obs, vis_obs, masks=None, actions=None): |
|
|
|
actions = [] |
|
|
|
if actions is None: |
|
|
|
generate_actions = True |
|
|
|
actions = [] |
|
|
|
else: |
|
|
|
generate_actions = False |
|
|
|
for action_dist in action_dists: |
|
|
|
action = action_dist.sample() |
|
|
|
actions.append(action) |
|
|
|
for idx, action_dist in enumerate(action_dists): |
|
|
|
if generate_actions: |
|
|
|
action = action_dist.sample() |
|
|
|
actions.append(action) |
|
|
|
else: |
|
|
|
action = actions[idx] |
|
|
|
actions = torch.stack(actions) |
|
|
|
log_probs = torch.stack(log_probs).squeeze(0) |
|
|
|
entropies = torch.stack(entropies).squeeze(0) |
|
|
|
|
|
|
|
if generate_actions: |
|
|
|
actions = torch.stack(actions, dim=-1) |
|
|
|
log_probs = torch.stack(log_probs, dim=-1) |
|
|
|
entropies = torch.stack(entropies, dim=-1) |
|
|
|
if self.act_type == "continuous": |
|
|
|
if generate_actions: |
|
|
|
actions = actions.squeeze(-1) |
|
|
|
log_probs = log_probs.squeeze(-1) |
|
|
|
entropies = entropies.squeeze(-1) |
|
|
|
value_heads, mean_value = self.critic(vec_obs, vis_obs) |
|
|
|
return actions, log_probs, entropies, value_heads |
|
|
|
|
|
|
|
|
|
|
vec_obs, vis_obs, masks = self.split_decision_step(decision_requests) |
|
|
|
vec_obs = [torch.Tensor(vec_obs)] |
|
|
|
vis_obs = [torch.Tensor(vis_ob) for vis_ob in vis_obs] |
|
|
|
masks = torch.Tensor(masks) |
|
|
|
if masks is not None: |
|
|
|
masks = torch.Tensor(masks) |
|
|
|
run_out = {} |
|
|
|
action, log_probs, entropy, value_heads = self.execute_model( |
|
|
|
vec_obs, vis_obs, masks |
|
|
|