
add ActionBuffers and utils

/develop/action-spec-gym
Andrew Cohen, 4 years ago
Current commit
3f771e61
10 files changed, 126 insertions and 81 deletions
  1. ml-agents-envs/mlagents_envs/base_env.py (77 changes)
  2. ml-agents-envs/mlagents_envs/environment.py (10 changes)
  3. ml-agents/mlagents/trainers/agent_processor.py (2 changes)
  4. ml-agents/mlagents/trainers/buffer.py (2 changes)
  5. ml-agents/mlagents/trainers/policy/policy.py (30 changes)
  6. ml-agents/mlagents/trainers/policy/torch_policy.py (25 changes)
  7. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (5 changes)
  8. ml-agents/mlagents/trainers/tests/simple_test_envs.py (11 changes)
  9. ml-agents/mlagents/trainers/torch/utils.py (39 changes)
  10. ml-agents/mlagents/trainers/trajectory.py (6 changes)

ml-agents-envs/mlagents_envs/base_env.py (77 changes)


 )

+class ActionBuffers(NamedTuple):
+    """
+    A NamedTuple whose fields correspond to actions of different types.
+    Continuous and discrete actions are numpy arrays.
+    """
+
+    continuous: np.ndarray
+    discrete: np.ndarray
+
 class ActionSpec(NamedTuple):
     """
     A NamedTuple containing utility functions and information about the action spaces

     """
         return len(self.discrete_branches)
-    def create_empty(self, n_agents: int) -> np.ndarray:
+    def create_empty(self, n_agents: int) -> ActionBuffers:
-        Generates a numpy array corresponding to an empty action (all zeros)
+        Generates ActionBuffers corresponding to an empty action (all zeros)
-        if self.is_continuous():
-            return np.zeros((n_agents, self.continuous_size), dtype=np.float32)
-        return np.zeros((n_agents, self.discrete_size), dtype=np.int32)
+        return ActionBuffers(
+            np.zeros((n_agents, self.continuous_size), dtype=np.float32),
+            np.zeros((n_agents, self.discrete_size), dtype=np.int32),
+        )
-    def create_random(self, n_agents: int) -> np.ndarray:
+    def create_random(self, n_agents: int) -> ActionBuffers:
-        Generates a numpy array corresponding to a random action (either discrete
+        Generates ActionBuffers corresponding to a random action (either discrete
-        if self.is_continuous():
-            action = np.random.uniform(
-                low=-1.0, high=1.0, size=(n_agents, self.continuous_size)
-            ).astype(np.float32)
-        else:
-            branch_size = self.discrete_branches
-            action = np.column_stack(
+        continuous_action = np.random.uniform(
+            low=-1.0, high=1.0, size=(n_agents, self.continuous_size)
+        ).astype(np.float32)
+        discrete_action = np.column_stack(
-                    branch_size[i],  # type: ignore
+                    self.discrete_branches[i],  # type: ignore
                     size=(n_agents),
                     dtype=np.int32,
                 )

-        return action
+        return ActionBuffers(continuous_action, discrete_action)
-        self, actions: np.ndarray, n_agents: int, name: str
-    ) -> np.ndarray:
+        self, actions: ActionBuffers, n_agents: int, name: str
+    ) -> ActionBuffers:
-        if self.continuous_size > 0:
-            _size = self.continuous_size
-        else:
-            _size = self.discrete_size
-        _expected_shape = (n_agents, _size)
-        if actions.shape != _expected_shape:
+        _expected_shape = (n_agents, self.continuous_size)
+        if actions.continuous.shape != _expected_shape:
-                f"The behavior {name} needs an input of dimension "
+                f"The behavior {name} needs a continuous input of dimension "
-        _expected_type = np.float32 if self.is_continuous() else np.int32
-        if actions.dtype != _expected_type:
-            actions = actions.astype(_expected_type)
+        _expected_shape = (n_agents, self.discrete_size)
+        if actions.discrete.shape != _expected_shape:
+            raise UnityActionException(
+                f"The behavior {name} needs a discrete input of dimension "
+                f"{_expected_shape} for (<number of agents>, <action size>) but "
+                f"received input of dimension {actions.shape}"
+            )
+        if actions.continuous.dtype != np.float32:
+            actions.continuous = actions.continuous.astype(np.float32)
+        if actions.discrete.dtype != np.int32:
+            actions.discrete = actions.discrete.astype(np.int32)
         return actions
     @staticmethod

     """
     @abstractmethod
-    def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None:
+    def set_actions(self, behavior_name: BehaviorName, action: ActionBuffers) -> None:
-        :param action: A two dimensional np.ndarray corresponding to the action
-            (either int or float)
+        :param action: ActionBuffers tuple of continuous and/or discrete action
-        self, behavior_name: BehaviorName, agent_id: AgentId, action: np.ndarray
+        self, behavior_name: BehaviorName, agent_id: AgentId, action: ActionBuffers
     ) -> None:
         """
         Sets the action for one of the agents in the simulation for the next

-        :param action: A one dimensional np.ndarray corresponding to the action
-            (either int or float)
+        :param action: ActionBuffers tuple of continuous and/or discrete action
         """
     @abstractmethod
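For reference, a minimal standalone sketch (numpy only, with local stand-ins rather than the mlagents_envs classes) of what create_empty and create_random now return via the ActionBuffers tuple; the continuous_size=2, discrete_branches=(3, 2) spec below is an assumed example, not taken from this commit.

from typing import NamedTuple, Tuple
import numpy as np

class ActionBuffers(NamedTuple):
    # Mirrors the NamedTuple added in base_env.py: one array per action type.
    continuous: np.ndarray
    discrete: np.ndarray

def create_empty(n_agents: int, continuous_size: int, discrete_branches: Tuple[int, ...]) -> ActionBuffers:
    # All-zero continuous and discrete actions, matching ActionSpec.create_empty above.
    return ActionBuffers(
        np.zeros((n_agents, continuous_size), dtype=np.float32),
        np.zeros((n_agents, len(discrete_branches)), dtype=np.int32),
    )

def create_random(n_agents: int, continuous_size: int, discrete_branches: Tuple[int, ...]) -> ActionBuffers:
    # Uniform continuous actions plus one random index per discrete branch.
    continuous = np.random.uniform(-1.0, 1.0, size=(n_agents, continuous_size)).astype(np.float32)
    discrete = np.column_stack(
        [np.random.randint(0, branch, size=n_agents, dtype=np.int32) for branch in discrete_branches]
    )
    return ActionBuffers(continuous, discrete)

buffers = create_random(n_agents=4, continuous_size=2, discrete_branches=(3, 2))
print(buffers.continuous.shape, buffers.discrete.shape)  # (4, 2) (4, 2)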

ml-agents-envs/mlagents_envs/environment.py (10 changes)


     DecisionSteps,
     TerminalSteps,
     BehaviorSpec,
+    ActionBuffers,
     BehaviorName,
     AgentId,
     BehaviorMapping,

         self._env_state: Dict[str, Tuple[DecisionSteps, TerminalSteps]] = {}
         self._env_specs: Dict[str, BehaviorSpec] = {}
-        self._env_actions: Dict[str, np.ndarray] = {}
+        self._env_actions: Dict[str, ActionBuffers] = {}
         self._is_first_message = True
         self._update_behavior_specs(aca_output)

f"agent group in the environment"
)
def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None:
def set_actions(self, behavior_name: BehaviorName, action: ActionBuffers) -> None:
self._assert_behavior_exists(behavior_name)
if behavior_name not in self._env_state:
return

self._env_actions[behavior_name] = action
def set_action_for_agent(
self, behavior_name: BehaviorName, agent_id: AgentId, action: np.ndarray
self, behavior_name: BehaviorName, agent_id: AgentId, action: ActionBuffers
) -> None:
self._assert_behavior_exists(behavior_name)
if behavior_name not in self._env_state:

@timed
def _generate_step_input(
self, vector_action: Dict[str, np.ndarray]
self, vector_action: Dict[str, ActionBuffers]
) -> UnityInputProto:
rl_in = UnityRLInputProto()
for b in vector_action:

for i in range(n_agents):
#TODO: extend to AgentBuffers
action = AgentActionProto(vector_actions=vector_action[b][i])
rl_in.agent_actions[b].value.extend([action])
rl_in.command = STEP
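As a hedged usage sketch of the updated API: a driver loop that builds one ActionBuffers per behavior and hands it to set_actions. It assumes a connected UnityEnvironment whose behavior_specs expose the action_spec.create_random helper changed above; the surrounding loop and function name are illustrative.

from mlagents_envs.environment import UnityEnvironment

def step_with_random_actions(env: UnityEnvironment) -> None:
    # For each behavior, submit an ActionBuffers (float32 continuous array,
    # int32 discrete array) instead of a single np.ndarray, then advance the sim.
    for behavior_name, spec in env.behavior_specs.items():
        decision_steps, _ = env.get_steps(behavior_name)
        action = spec.action_spec.create_random(len(decision_steps))  # returns ActionBuffers per this commit
        env.set_actions(behavior_name, action)
    env.step()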

ml-agents/mlagents/trainers/agent_processor.py (2 changes)


             action_pre = None
         action_probs = stored_take_action_outputs["log_probs"][idx]
         action_mask = stored_decision_step.action_mask
-        prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
+        prev_action = self.policy.retrieve_previous_action([global_id])  # [0, :]
         experience = AgentExperience(
             obs=obs,
             reward=step.reward,

ml-agents/mlagents/trainers/buffer.py (2 changes)


 class AgentBufferField(list):
     """
-    AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to his
+    AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to its
     AgentBufferField with the append method.
     """

ml-agents/mlagents/trainers/policy/policy.py (30 changes)


 from mlagents_envs.exception import UnityException
 from mlagents.trainers.action_info import ActionInfo
-from mlagents_envs.base_env import BehaviorSpec
+from mlagents_envs.base_env import BehaviorSpec, ActionBuffers
 from mlagents.trainers.settings import TrainerSettings, NetworkSettings

             1 for shape in behavior_spec.observation_shapes if len(shape) == 3
         )
         self.use_continuous_act = self.behavior_spec.action_spec.is_continuous()
+        # This line will be removed in the ActionBuffer change
         self.num_branches = (
             self.behavior_spec.action_spec.continuous_size
             + self.behavior_spec.action_spec.discrete_size
         )
-        self.previous_action_dict: Dict[str, np.array] = {}
+        self.previous_action_dict: Dict[str, ActionBuffers] = {}
         self.memory_dict: Dict[str, np.ndarray] = {}
         self.normalize = trainer_settings.network_settings.normalize
         self.use_recurrent = self.network_settings.memory is not None

         if agent_id in self.memory_dict:
             self.memory_dict.pop(agent_id)
-    def make_empty_previous_action(self, num_agents):
+    def make_empty_previous_action(self, num_agents) -> ActionBuffers:
-        :return: Numpy array of zeros.
+        :return: ActionBuffers of zeros.
-        return np.zeros((num_agents, self.num_branches), dtype=np.int)
+        return self.behavior_spec.action_spec.create_empty(num_agents)
-        self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
+        self, agent_ids: List[str], action_buffers: Optional[ActionBuffers]
-        if action_matrix is None:
+        if action_buffers is None:
-            self.previous_action_dict[agent_id] = action_matrix[index, :]
+            self.previous_action_dict[agent_id] = action_buffers
-    def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
-        action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
+    def retrieve_previous_action(self, agent_ids: List[str]) -> ActionBuffers:
+        action_buffers = self.behavior_spec.action_spec.create_empty(len(agent_ids))
-            action_matrix[index, :] = self.previous_action_dict[agent_id]
-        return action_matrix
+            for action, previous_action in zip(
+                action_buffers, self.previous_action_dict[agent_id]
+            ):
+                action[index, :] = previous_action
+        return action_buffers
     def remove_previous_action(self, agent_ids):
         for agent_id in agent_ids:
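A numpy-only sketch of the bookkeeping this section is moving toward: previous_action_dict keyed by agent id, one ActionBuffers per agent, filled back into batch buffers on retrieval. The helper names and per-row slicing are illustrative assumptions; the commit as shown currently stores the whole ActionBuffers batch per agent id.

from typing import Dict, List, NamedTuple
import numpy as np

class ActionBuffers(NamedTuple):
    continuous: np.ndarray
    discrete: np.ndarray

previous_action_dict: Dict[str, ActionBuffers] = {}

def save_previous_actions(agent_ids: List[str], buffers: ActionBuffers) -> None:
    # Store each agent's row from both buffers as its own ActionBuffers entry.
    for index, agent_id in enumerate(agent_ids):
        previous_action_dict[agent_id] = ActionBuffers(
            buffers.continuous[index], buffers.discrete[index]
        )

def retrieve_previous_actions(
    agent_ids: List[str], continuous_size: int, discrete_size: int
) -> ActionBuffers:
    # Start from all-zero buffers (as create_empty would) and fill rows for known agents.
    out = ActionBuffers(
        np.zeros((len(agent_ids), continuous_size), dtype=np.float32),
        np.zeros((len(agent_ids), discrete_size), dtype=np.int32),
    )
    for index, agent_id in enumerate(agent_ids):
        if agent_id in previous_action_dict:
            prev = previous_action_dict[agent_id]
            out.continuous[index] = prev.continuous
            out.discrete[index] = prev.discrete
    return out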

ml-agents/mlagents/trainers/policy/torch_policy.py (25 changes)


         memories: Optional[torch.Tensor] = None,
         seq_len: int = 1,
         all_log_probs: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         :param vec_obs: List of vector observations.
         :param vis_obs: List of visual observations.

         log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
             action_list, dists
         )
-        actions = torch.stack(action_list, dim=-1)
-        if self.use_continuous_act:
-            actions = actions[:, :, 0]
-        else:
-            actions = actions[:, 0, :]
+        # actions = torch.stack(action_list, dim=-1)
+        # if self.use_continuous_act:
+        #     actions = actions[:, :, 0]
+        # else:
+        #     actions = actions[:, 0, :]
-            actions,
+            action_list,
             all_logs if all_log_probs else log_probs,
             entropy_sum,
             memories,

         self,
         vec_obs: torch.Tensor,
         vis_obs: torch.Tensor,
-        actions: torch.Tensor,
+        actions: List[torch.Tensor],
         masks: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
         seq_len: int = 1,

         )
-        action_list = [actions[..., i] for i in range(actions.shape[-1])]
-        log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)
+        # action_list = [actions[..., i] for i in range(actions.shape[-1])]
+        # log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)
+        log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(actions, dists)
         # Use the sum of entropy across actions, not the mean
         entropy_sum = torch.sum(entropies, dim=1)
         return log_probs, entropy_sum, value_heads

         action, log_probs, entropy, memories = self.sample_actions(
             vec_obs, vis_obs, masks=masks, memories=memories
         )
-        run_out["action"] = ModelUtils.to_numpy(action)
-        run_out["pre_action"] = ModelUtils.to_numpy(action)
+        run_out["action"] = ModelUtils.to_action_buffers(
+            action, self.behavior_spec.action_spec
+        )
+        run_out["pre_action"] = ModelUtils.to_action_buffers(
+            action, self.behavior_spec.action_spec
+        )
+        # TODO: make pre_action different from action
         run_out["log_probs"] = ModelUtils.to_numpy(log_probs)
         run_out["entropy"] = ModelUtils.to_numpy(entropy)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (5 changes)


         vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
         act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
-        if self.policy.use_continuous_act:
-            actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
-        else:
-            actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
+        actions = ModelUtils.action_buffers_to_tensor_list(
+            batch["actions"], self.policy.behavior_spec.action_spec
+        )
         memories = [
             ModelUtils.list_to_tensor(batch["memory"][i])

ml-agents/mlagents/trainers/tests/simple_test_envs.py (11 changes)


 from mlagents_envs.base_env import (
     ActionSpec,
+    ActionBuffers,
     BaseEnv,
     BehaviorSpec,
     DecisionSteps,

     def _take_action(self, name: str) -> bool:
         deltas = []
-        for _act in self.action[name][0]:
-            if self.discrete:
-                deltas.append(1 if _act else -1)
-            else:
-                deltas.append(_act)
+        _act = self.action[name]
+        for _disc in _act.discrete:
+            deltas.append(1 if _disc else -1)
+        for _cont in _act.continuous:
+            deltas.append(_cont)
         for i, _delta in enumerate(deltas):
             _delta = clamp(_delta, -self.step_size, self.step_size)
             self.positions[name][i] += _delta
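A tiny self-contained version of the new delta computation in _take_action, assuming the stored action is a single ActionBuffers row (1-D arrays): discrete branches contribute plus or minus one step, continuous values pass through, and everything is clamped. The 0.2 step size is an assumed example value.

from typing import List, NamedTuple
import numpy as np

class ActionBuffers(NamedTuple):
    continuous: np.ndarray
    discrete: np.ndarray

def compute_deltas(act: ActionBuffers, step_size: float = 0.2) -> List[float]:
    # Discrete actions push +step/-step, continuous actions move by their (clamped) value.
    deltas = [1.0 if _disc else -1.0 for _disc in act.discrete]
    deltas += [float(_cont) for _cont in act.continuous]
    return [max(-step_size, min(step_size, d)) for d in deltas]

print(compute_deltas(ActionBuffers(np.array([0.05, -0.7]), np.array([1, 0]))))
# [0.2, -0.2, 0.05, -0.2]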

ml-agents/mlagents/trainers/torch/utils.py (39 changes)


 )
 from mlagents.trainers.settings import EncoderType, ScheduleType
 from mlagents.trainers.exception import UnityTrainerException
-from mlagents_envs.base_env import ActionSpec
+from mlagents_envs.base_env import ActionSpec, ActionBuffers
 from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

         nn.ModuleList(vector_encoders),
         total_processed_size,
     )

+    @staticmethod
+    def to_action_buffers(
+        actions: List[torch.Tensor], action_spec: ActionSpec
+    ) -> ActionBuffers:
+        """
+        Converts a list of action Tensors to an ActionBuffers tuple. Implicitly
+        assumes the order of actions in 'actions' is continuous, then discrete.
+        """
+        continuous_action: np.ndarray = np.array([])
+        discrete_action_list: List[np.ndarray] = []
+        discrete_action: np.ndarray = np.array([])
+        # Offset into the discrete actions, depending on the presence of continuous actions.
+        _offset = 0
+        if action_spec.continuous_size > 0:
+            continuous_action = actions[0].detach().cpu().numpy()
+            _offset = 1
+        if action_spec.discrete_size > 0:
+            for _disc in range(action_spec.discrete_size):
+                discrete_action_list.append(actions[_disc + _offset].detach().cpu().numpy())
+            discrete_action = np.array(discrete_action_list)
+        return ActionBuffers(continuous_action, discrete_action)

+    @staticmethod
+    def action_buffers_to_tensor_list(
+        action_buffers: ActionBuffers,
+        action_spec: ActionSpec,
+        dtype: Optional[torch.dtype] = None,
+    ) -> List[torch.Tensor]:
+        """
+        Converts ActionBuffers fields into a list of tensors.
+        """
+        action_tensors: List[torch.Tensor] = []
+        if action_spec.continuous_size > 0:
+            action_tensors.append(
+                torch.as_tensor(np.asanyarray(action_buffers.continuous), dtype=dtype)
+            )
+        if action_spec.discrete_size > 0:
+            for _disc in action_buffers.discrete:
+                action_tensors.append(torch.as_tensor(np.asanyarray(_disc), dtype=dtype))
+        return action_tensors

     @staticmethod
     def list_to_tensor(
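A self-contained sketch of the round trip the two helpers above aim at: a list of torch action tensors (continuous first, then one per discrete branch) into numpy buffers and back. It mirrors to_action_buffers / action_buffers_to_tensor_list but is written against a local ActionBuffers stand-in rather than the mlagents_envs class, with explicit sizes instead of an ActionSpec.

from typing import List, NamedTuple
import numpy as np
import torch

class ActionBuffers(NamedTuple):
    continuous: np.ndarray
    discrete: np.ndarray  # shape (num_branches, n_agents) in this sketch

def tensors_to_buffers(actions: List[torch.Tensor], continuous_size: int, discrete_size: int) -> ActionBuffers:
    # Assumes the ordering used in this commit: continuous tensor first, then discrete branches.
    offset = 1 if continuous_size > 0 else 0
    continuous = actions[0].detach().cpu().numpy() if continuous_size > 0 else np.array([])
    discrete = np.array([actions[offset + i].detach().cpu().numpy() for i in range(discrete_size)])
    return ActionBuffers(continuous, discrete)

def buffers_to_tensors(buffers: ActionBuffers) -> List[torch.Tensor]:
    # Inverse direction: one float tensor for the continuous block, one long tensor per branch.
    tensors = [torch.as_tensor(buffers.continuous, dtype=torch.float32)]
    tensors += [torch.as_tensor(branch, dtype=torch.long) for branch in buffers.discrete]
    return tensors

acts = [torch.randn(4, 2), torch.randint(0, 3, (4,)), torch.randint(0, 2, (4,))]
buffers = tensors_to_buffers(acts, continuous_size=2, discrete_size=2)
print([t.shape for t in buffers_to_tensors(buffers)])  # [(4, 2), (4,), (4,)]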

ml-agents/mlagents/trainers/trajectory.py (6 changes)


 from mlagents.trainers.buffer import AgentBuffer
+from mlagents_envs.base_env import ActionBuffers
-    action: np.ndarray
+    action: ActionBuffers
-    action_pre: np.ndarray  # TODO: Remove this
+    action_pre: ActionBuffers  # TODO: Remove this
     action_mask: np.ndarray
     prev_action: np.ndarray
     interrupted: bool
