from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents.trainers.policy import Policy
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec, ActionBuffers
from mlagents_envs.timers import timed
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.torch.networks import GlobalSteps
from mlagents.trainers.torch.utils import ModelUtils


class TorchPolicy(Policy):
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_settings: TrainerSettings,
    ):
        super().__init__(seed, behavior_spec, trainer_settings)
        self.global_step = (
            GlobalSteps()
        )  # could be much simpler if TorchPolicy is nn.Module
        self.grads = None
        # Keyed by agent ID; holds ActionBuffers rather than the flat numpy
        # matrix used by the base Policy, so hybrid (continuous + discrete)
        # actions survive the round trip.
        self.previous_action_dict: Dict[str, ActionBuffers] = {}

        reward_signal_configs = trainer_settings.reward_signals
        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
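        # Illustrative only: with the default trainer configuration, which
        # defines a single "extrinsic" reward signal, this evaluates to
        #   reward_signal_names == ["extrinsic"]
        # since each RewardSignalType key's enum value is its lowercase name.
        # The actor-critic network, with one value stream per reward signal,
        # is then constructed as self.actor_critic (construction not shown).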

    def evaluate_actions(
        self,
        vec_obs: torch.Tensor,
        vis_obs: torch.Tensor,
        actions: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        seq_len: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Computes log probabilities, entropy, and value estimates for a batch
        of actions, passed as one tensor per action segment.
        """
        ...

    @timed
    def evaluate(
        self, decision_requests: DecisionSteps, global_agent_ids: List[str]
    ) -> Dict[str, Any]:
        # Split the batched observations and fetch recurrent memories; these
        # helpers live on this class and the base Policy.
        vec_vis_obs, masks = self._split_decision_step(decision_requests)
        vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)]
        vis_obs = [torch.as_tensor(vis_ob) for vis_ob in vec_vis_obs.visual_observations]
        memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(0)

        run_out = {}
        with torch.no_grad():
            action, log_probs, entropy, value_heads, memories = self.sample_actions(
                vec_obs, vis_obs, masks=masks, memories=memories
            )
        # TODO: make pre_action different from action; they are identical for now.
        run_out["action"] = ModelUtils.to_action_buffers(action, self.action_spec)
        run_out["pre_action"] = ModelUtils.to_action_buffers(action, self.action_spec)
        run_out["log_probs"] = ModelUtils.to_numpy(log_probs)
        run_out["entropy"] = ModelUtils.to_numpy(entropy)
        run_out["value_heads"] = {
            name: ModelUtils.to_numpy(t) for name, t in value_heads.items()
        }
        return run_out
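
    # Illustrative only: for a behavior with a single "extrinsic" reward
    # signal, evaluate()'s run_out is expected to look roughly like
    #   {
    #       "action": ActionBuffers(...),      # batched continuous/discrete arrays
    #       "pre_action": ActionBuffers(...),  # identical until the TODO above lands
    #       "log_probs": np.ndarray,
    #       "entropy": np.ndarray,
    #       "value_heads": {"extrinsic": np.ndarray},
    #   }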

    def get_modules(self):
        return {"Policy": self.actor_critic, "global_step": self.global_step}
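
    # Illustrative only (hypothetical saver wiring): a checkpointing helper
    # can persist the policy by walking this mapping, e.g.
    #   state = {name: mod.state_dict() for name, mod in policy.get_modules().items()}
    # which is why the GlobalSteps counter is exposed alongside the actor-critic.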

    # Overriding for use of ActionBuffers in torch
    def make_empty_previous_action(self, num_agents: int) -> ActionBuffers:
        """
        Creates an empty previous action for use with RNNs and discrete control.
        :param num_agents: Number of agents.
        :return: ActionBuffers of zeros matching this policy's action spec.
        """
        return self.action_spec.create_empty_action(num_agents)
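
    # Illustrative only, assuming an action spec with two continuous dimensions
    # and one discrete branch: make_empty_previous_action(3) should wrap
    # zero-filled arrays of roughly shape (3, 2) (continuous) and (3, 1)
    # (discrete); the exact layout is owned by ActionSpec.create_empty_action.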

    def save_previous_action(
        self, agent_ids: List[str], action_buffers: Optional[ActionBuffers]
    ) -> None:
        if action_buffers is None:
            return
        # Note: the whole batched ActionBuffers is stored under every agent ID;
        # unlike the per-row numpy version in the base Policy, no per-agent
        # slicing happens here.
        for agent_id in agent_ids:
            self.previous_action_dict[agent_id] = action_buffers

    def retrieve_previous_action(self, agent_ids: List[str]) -> ActionBuffers:
        action_buffers = self.action_spec.create_empty_action(len(agent_ids))
        for agent_id in agent_ids:
            # The last matching entry wins and replaces the empty buffers outright.
            if agent_id in self.previous_action_dict:
                action_buffers = self.previous_action_dict[agent_id]
        return action_buffers
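
    # Illustrative only: retrieve_previous_action never raises for unknown
    # agents. With nothing saved it returns the empty (zero) buffers; if any
    # requested agent has a saved entry, that whole batched buffer is returned,
    # per the caveat noted in save_previous_action:
    #   policy.retrieve_previous_action(["a", "b"])  # zeros if neither has acted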

    def remove_previous_action(self, agent_ids: List[str]) -> None:
        for agent_id in agent_ids:
            if agent_id in self.previous_action_dict:
                self.previous_action_dict.pop(agent_id)
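
    # Illustrative lifecycle (hypothetical trainer wiring): after each decision
    # step the trainer would call
    #   policy.save_previous_action(agent_ids, run_out["action"])
    # feed policy.retrieve_previous_action(agent_ids) into the next recurrent
    # encoding, and call policy.remove_previous_action(agent_ids) once those
    # agents' episodes end.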