您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
151 行
5.8 KiB
151 行
5.8 KiB
from typing import List, NamedTuple, Optional

import attr
import numpy as np

from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.base_env import ActionTuple
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
|
|
|
|
|
|
@attr.s(auto_attribs=True)
class AgentExperience:
    """
    Data recorded for a single agent at a single step.

    The field order below defines the positional argument order of the
    attrs-generated ``__init__``; do not reorder fields.
    """

    # Observations for this step.
    obs: List[np.ndarray]
    # Observations of the collaborating agents (presumably one list of arrays
    # per agent -- confirm against the producer). Trajectory.to_agentbuffer
    # stores these under "critic_obs" and, lagged by one step, "comm_obs".
    collab_obs: List[List[np.ndarray]]
    reward: float
    done: bool
    action: ActionTuple
    action_probs: LogProbsTuple
    # May be None; Trajectory.to_agentbuffer then falls back to an all-ones
    # mask. Here False marks an available action (the buffer stores 1 instead).
    action_mask: Optional[np.ndarray]
    prev_action: np.ndarray
    interrupted: bool
    # May be None; such steps contribute no "memory" entry in
    # Trajectory.to_agentbuffer.
    memory: Optional[np.ndarray]
|
|
|
|
|
|
class SplitObservations(NamedTuple):
    """Observations partitioned by rank into a vector part and visual parts."""

    vector_observations: np.ndarray
    visual_observations: List[np.ndarray]

    @staticmethod
    def from_observations(obs: List[np.ndarray]) -> "SplitObservations":
        """
        Build a SplitObservations from a flat list of observation arrays, so
        the vector and visual parts can be accessed directly instead of
        scanning the list over and over.

        Arrays of rank 1 or 2 are vector observations (unbatched or batched)
        and are concatenated, along axis 1 when batched and axis 0 otherwise.
        Arrays of rank 3 or 4 are visual observations (unbatched or batched)
        and are kept as a list. Arrays of any other rank are ignored. Whether
        the input counts as batched is decided from the rank of the *last*
        array alone.

        :param obs: List of numpy arrays (observation).
        :returns: A SplitObservations object. With no vector observations the
            vector part is an empty float32 array, of shape ``(batch, 0)`` if
            batched and ``(0,)`` otherwise. When ``obs`` itself is empty, the
            vector part is an empty Python list.
        """
        vectors: List[np.ndarray] = []
        visuals: List[np.ndarray] = []
        last_obs = None
        # The loop variable doubles as "last observation seen"; it remains
        # None when obs is empty.
        for last_obs in obs:
            rank = len(last_obs.shape)
            if rank in (1, 2):
                vectors.append(last_obs)
            elif rank in (3, 4):
                visuals.append(last_obs)

        if last_obs is not None:
            batched = len(last_obs.shape) in (2, 4)
            if vectors:
                vec_obs = np.concatenate(vectors, axis=1 if batched else 0)
            elif batched:
                vec_obs = np.zeros((last_obs.shape[0], 0), dtype=np.float32)
            else:
                vec_obs = np.array([], dtype=np.float32)
        else:
            vec_obs = []

        return SplitObservations(
            vector_observations=vec_obs, visual_observations=visuals
        )
|
|
|
|
|
|
class Trajectory(NamedTuple):
    """
    A sequence of consecutive AgentExperience steps for one agent, together
    with the observations that follow the final step.
    """

    steps: List[AgentExperience]
    # Observation following the trajectory, for bootstrapping
    next_obs: List[np.ndarray]
    next_collab_obs: List[List[np.ndarray]]
    agent_id: str
    behavior_id: str

    def to_agentbuffer(self) -> AgentBuffer:
        """
        Converts this Trajectory to an AgentBuffer.

        Each step appends one entry to every field, with one exception:
        "memory" only receives an entry for steps whose memory is not None.
        For step ``t``:

        - "obs" holds the observation at ``t`` and "next_obs" the one at
          ``t + 1``. For the final step, "next_obs" is ``self.next_obs``.
        - "critic_obs" holds step ``t``'s collab_obs.
        - "comm_obs" holds step ``t - 1``'s collab_obs, so it lags by one
          step. At ``t == 0`` it holds an all-zeros copy of step 0's
          collab_obs.

        :returns: The populated AgentBuffer.
        :raises IndexError: If the trajectory has no steps.
        """
        agent_buffer_trajectory = AgentBuffer()
        last_step = len(self.steps) - 1
        curr_obs = self.steps[0].obs
        for step, exp in enumerate(self.steps):
            if step == 0:
                # This initial all-zeros entry creates the one-step offset for
                # comms: "comm_obs" entry t ends up holding step t - 1's
                # collab_obs.
                dummy = [
                    [np.zeros_like(col_ob) for col_ob in agent_obs]
                    for agent_obs in exp.collab_obs
                ]
                agent_buffer_trajectory["comm_obs"].append(dummy)

            if step < last_step:
                next_obs = self.steps[step + 1].obs
            else:
                next_obs = self.next_obs
            agent_buffer_trajectory["obs"].append(curr_obs)
            agent_buffer_trajectory["next_obs"].append(next_obs)
            agent_buffer_trajectory["critic_obs"].append(exp.collab_obs)
            # The final step's collab_obs is not appended to "comm_obs". The
            # dummy above already took one slot, so skipping it keeps this
            # field the same length as the others.
            if step < last_step:
                agent_buffer_trajectory["comm_obs"].append(exp.collab_obs)

            if exp.memory is not None:
                agent_buffer_trajectory["memory"].append(exp.memory)

            agent_buffer_trajectory["masks"].append(1.0)
            agent_buffer_trajectory["done"].append(exp.done)

            # Adds the log prob and action of continuous/discrete separately
            agent_buffer_trajectory["continuous_action"].append(exp.action.continuous)
            agent_buffer_trajectory["discrete_action"].append(exp.action.discrete)
            agent_buffer_trajectory["continuous_log_probs"].append(
                exp.action_probs.continuous
            )
            agent_buffer_trajectory["discrete_log_probs"].append(
                exp.action_probs.discrete
            )

            # Store action masks. In the buffer 1 means active, while in
            # AgentExperience False means active, hence the inversion.
            if exp.action_mask is not None:
                mask = 1 - np.concatenate(exp.action_mask)
                agent_buffer_trajectory["action_mask"].append(mask, padding_value=1)
            else:
                # This should never be needed unless the environment somehow
                # doesn't supply the action mask in a discrete space.
                action_shape = exp.action.discrete.shape
                agent_buffer_trajectory["action_mask"].append(
                    np.ones(action_shape, dtype=np.float32), padding_value=1
                )

            agent_buffer_trajectory["prev_action"].append(exp.prev_action)
            agent_buffer_trajectory["environment_rewards"].append(exp.reward)

            # Store the next obs as the current
            curr_obs = next_obs
        return agent_buffer_trajectory

    @property
    def done_reached(self) -> bool:
        """
        Returns true if trajectory is terminated with a Done, i.e. the final
        step's ``done`` flag.
        """
        return self.steps[-1].done

    @property
    def interrupted(self) -> bool:
        """
        Returns true if trajectory was terminated because max steps was
        reached, i.e. the final step's ``interrupted`` flag.
        """
        return self.steps[-1].interrupted
|