
hybrid behavior spec

/develop/actionmodel-csharp
Andrew Cohen, 4 years ago
Current commit
6587c911
2 files changed, 269 insertions(+), 0 deletions(-)
  1. ml-agents-envs/mlagents_envs/base_env.py (+40)
  2. ml-agents/mlagents/trainers/torch/networks.py (+229)

ml-agents-envs/mlagents_envs/base_env.py (+40)


DISCRETE = 0
CONTINUOUS = 1

class HybridBehaviorSpec(NamedTuple):
    observation_shapes: List[Tuple]
    continuous_action_shape: int
    discrete_action_shape: Tuple[int, ...]

    @property
    def discrete_action_size(self) -> int:
        return len(self.discrete_action_shape)

    @property
    def continuous_action_size(self) -> int:
        return self.continuous_action_shape

    @property
    def action_size(self) -> int:
        return self.discrete_action_size + self.continuous_action_size

    @property
    def discrete_action_branches(self) -> Optional[Tuple[int, ...]]:
        return self.discrete_action_shape  # type: ignore

    def create_empty_action(self, n_agents: int) -> np.ndarray:
        return np.zeros(
            (n_agents, self.discrete_action_size + self.continuous_action_size),
            dtype=np.float32,
        )

    def create_random_action(self, n_agents: int) -> np.ndarray:
        continuous_action = np.random.uniform(
            low=-1.0, high=1.0, size=(n_agents, self.continuous_action_size)
        ).astype(np.float32)
        branch_size = self.discrete_action_branches
        discrete_action = np.column_stack(
            [
                np.random.randint(
                    0,
                    branch_size[i],  # type: ignore
                    size=(n_agents,),
                    dtype=np.int32,
                )
                # One column per discrete branch, not per total action
                for i in range(self.discrete_action_size)
            ]
        )
        # np.concatenate takes a sequence of arrays; join along the action axis
        return np.concatenate([continuous_action, discrete_action], axis=1)
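
As a quick sanity check, here is a minimal usage sketch of the new spec; the observation shape and action sizes below are made up for illustration and are not part of this diff:

# Hypothetical spec: one 6-dim vector observation, 2 continuous actions,
# and two discrete branches of sizes 3 and 2.
spec = HybridBehaviorSpec(
    observation_shapes=[(6,)],
    continuous_action_shape=2,
    discrete_action_shape=(3, 2),
)
assert spec.continuous_action_size == 2
assert spec.discrete_action_size == 2
assert spec.action_size == 4
# Continuous columns first, then one column per discrete branch
assert spec.create_random_action(n_agents=5).shape == (5, 4)
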
class BehaviorSpec(NamedTuple):
    """

ml-agents/mlagents/trainers/torch/networks.py (+229)


"""
pass
class HybridSimpleActor(nn.Module, Actor):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        continuous_act_size: List[int],
        discrete_act_size: List[int],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.discrete_act_size = discrete_act_size
        self.continuous_act_size = continuous_act_size
        self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
        # self.is_continuous_int = torch.nn.Parameter(
        #     torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
        # )
        self.continuous_act_size_vector = torch.nn.Parameter(
            torch.Tensor(continuous_act_size)
        )
        self.discrete_act_size_vector = torch.nn.Parameter(
            torch.Tensor(discrete_act_size)
        )
        self.network_body = NetworkBody(observation_shapes, network_settings)
        if network_settings.memory is not None:
            self.encoding_size = network_settings.memory.memory_size // 2
        else:
            self.encoding_size = network_settings.hidden_units
        self.continuous_distribution = GaussianDistribution(
            self.encoding_size,
            continuous_act_size[0],
            conditional_sigma=conditional_sigma,
            tanh_squash=tanh_squash,
        )
        self.discrete_distribution = MultiCategoricalDistribution(
            self.encoding_size, discrete_act_size
        )

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size

    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
        self.network_body.update_normalization(vector_obs)

    def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
        actions = []
        for action_dist in dists:
            action = action_dist.sample()
            actions.append(action)
        return actions

    def get_dists(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        discrete_dists = self.discrete_distribution(encoding, masks)
        continuous_dists = self.continuous_distribution(encoding)
        # Flat list: one dist per discrete branch, then the continuous dist(s)
        return discrete_dists + continuous_dists, memories

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, int, int, int]:
        """
        Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
        """
        # TODO: This is bad right now
        dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
        # get_dists returns a flat list: discrete branch dists first, then continuous
        num_discrete = len(self.discrete_act_size)
        discrete_dists = dists[:num_discrete]
        continuous_dists = dists[num_discrete:]
        # Export all-branch log probs for the discrete part and sampled
        # values for the continuous part
        discrete_action_out = torch.cat(
            [branch.all_log_prob() for branch in discrete_dists], dim=-1
        )
        continuous_action_list = self.sample_action(continuous_dists)
        continuous_action_out = torch.cat(continuous_action_list, dim=-1)
        action_out = torch.cat([continuous_action_out, discrete_action_out], dim=-1)
        return (
            action_out,
            self.version_number,
            torch.Tensor([self.network_body.memory_size]),
            self.continuous_act_size_vector,
            self.discrete_act_size_vector,
        )
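
For context, a minimal sketch of how the flat distribution list from get_dists() might be consumed downstream; the actor instance, batch size, and observation shape here are assumptions, not part of this diff:

# Assumes `actor` was built with two discrete branches and one continuous
# Gaussian head; batch of 4 agents with a single 6-dim vector observation.
vec_obs = [torch.zeros(4, 6)]
dists, _ = actor.get_dists(vec_obs, vis_inputs=[])
num_discrete = len(actor.discrete_act_size)
discrete_dists = dists[:num_discrete]    # one dist per discrete branch
continuous_dists = dists[num_discrete:]  # the Gaussian dist(s)
sampled = [d.sample() for d in dists]    # per-dist action samples
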
class HybridSharedActorCritic(HybridSimpleActor, ActorCritic):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        continuous_act_size: List[int],
        discrete_act_size: List[int],
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__(
            observation_shapes,
            network_settings,
            continuous_act_size,
            discrete_act_size,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        self.value_heads = ValueHeads(stream_names, self.encoding_size)

    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories_out = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        return self.value_heads(encoding), memories_out

    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        # Hybrid actions: produce both the masked discrete branch dists and
        # the continuous dists, in the same order as get_dists()
        dists = self.discrete_distribution(encoding, masks) + self.continuous_distribution(
            encoding
        )
        value_outputs = self.value_heads(encoding)
        return dists, value_outputs, memories
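
Because the shared variant runs actor and critic off one trunk, a single pass yields both the action distributions and the per-stream values. A sketch under the same assumed setup as above; `shared_model` is an assumed HybridSharedActorCritic instance:

dists, values, _ = shared_model.get_dist_and_value(vec_obs, vis_inputs=[])
for stream_name, estimate in values.items():
    # one value estimate per agent, per reward stream
    print(stream_name, estimate.shape)
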
class HybridSeparateActorCritic(HybridSimpleActor, ActorCritic):
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        continuous_act_size: List[int],
        discrete_act_size: List[int],
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        # The critic keeps its own recurrent state, so track LSTM use here
        self.use_lstm = network_settings.memory is not None
        super().__init__(
            observation_shapes,
            network_settings,
            continuous_act_size,
            discrete_act_size,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        # Separate value network rather than value heads on the shared trunk
        self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size + self.critic.memory_size

    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for the critic
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        value_outputs, critic_mem_out = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Reassemble memories with the actor half unchanged
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out

    def get_dist_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        if self.use_lstm:
            # Split memories in half: front half for the actor, back half for the critic
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
        else:
            critic_mem = None
            actor_mem = None
        dists, actor_mem_outs = self.get_dists(
            vec_inputs,
            vis_inputs,
            memories=actor_mem,
            sequence_length=sequence_length,
            masks=masks,
        )
        value_outputs, critic_mem_outs = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
        else:
            mem_out = None
        return dists, value_outputs, mem_out
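
The separate variant keeps actor and critic recurrent state packed in one tensor: the front half belongs to the actor, the back half to the critic. A sketch of that split convention, with an assumed combined memory size of 256:

memories = torch.zeros(4, 256)  # (batch, actor memory + critic memory)
actor_mem, critic_mem = torch.split(memories, 256 // 2, dim=-1)
assert actor_mem.shape == critic_mem.shape == (4, 128)
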
################################################################################
######### Continuous xor Discrete cases ##########
################################################################################
class SimpleActor(nn.Module, Actor):
    def __init__(
        self,
