浏览代码

Removing some vis and vec fields from policy.py

/MLA-1734-demo-provider
vincentpierre 4 年前
当前提交
36cc4665
共有 3 个文件被更改,包括 20 次插入、24 次删除
  1. 21
      ml-agents/mlagents/trainers/policy/policy.py
  2. 14
      ml-agents/mlagents/trainers/policy/torch_policy.py
  3. 9
      ml-agents/mlagents/trainers/torch/model_serialization.py

21
ml-agents/mlagents/trainers/policy/policy.py


self.trainer_settings = trainer_settings
self.network_settings: NetworkSettings = trainer_settings.network_settings
self.seed = seed
self.act_size = (
list(self.behavior_spec.action_spec.discrete_branches)
if self.behavior_spec.action_spec.is_discrete()
else [self.behavior_spec.action_spec.continuous_size]
)
self.vec_obs_size = sum(
shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1
)
self.vis_obs_size = sum(
1 for shape in behavior_spec.observation_shapes if len(shape) == 3
)
self.use_continuous_act = self.behavior_spec.action_spec.is_continuous()
if self.normalize:
has_vec_obs = False
# Make sure there is at least one vector observation for normalization
for shape in behavior_spec.observation_shapes:
if len(shape) == 1:
has_vec_obs = True
break
if not has_vec_obs:
self.normalize = False
self.use_recurrent = self.network_settings.memory is not None
self.h_size = self.network_settings.hidden_units
num_layers = self.network_settings.num_layers

14
ml-agents/mlagents/trainers/policy/torch_policy.py


def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray:
mask = None
if self.behavior_spec.action_spec.discrete_size > 0:
mask = torch.ones([len(decision_requests), np.sum(self.act_size)])
num_discrete_flat = np.sum(self.behavior_spec.action_spec.discrete_branches)
mask = torch.ones([len(decision_requests), num_discrete_flat])
if decision_requests.action_mask is not None:
mask = torch.as_tensor(
1 - np.concatenate(decision_requests.action_mask, axis=1)

:param buffer: The buffer with the observations to add to the running estimate
of the distribution.
"""
if self.use_vec_obs and self.normalize:
if self.normalize:
self.actor_critic.update_normalization(buffer)
@timed

outputs=run_out,
agent_ids=list(decision_requests.agent_id),
)
@property
def use_vis_obs(self):
    """True when the policy receives at least one visual (camera) observation."""
    num_visual = self.vis_obs_size
    return num_visual > 0
@property
def use_vec_obs(self):
    """True when the policy receives a vector observation of nonzero length."""
    vector_length = self.vec_obs_size
    return vector_length > 0
def get_current_step(self):
"""

9
ml-agents/mlagents/trainers/torch/model_serialization.py


self.policy = policy
batch_dim = [1]
seq_len_dim = [1]
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
vec_obs_size = 0
for shape in self.policy.behavior_spec.observation_shapes:
if len(shape) == 1:
vec_obs_size += shape[0]
num_vis_obs = sum(1 for shape in self.policy.behavior_spec.observation_shapes if len(shape) == 3)
dummy_vec_obs = [torch.zeros(batch_dim + [vec_obs_size])]
# create input shape of NCHW
# (It's NHWC in self.policy.behavior_spec.observation_shapes)
dummy_vis_obs = [

self.input_names = (
["vector_observation"]
+ [f"visual_observation_{i}" for i in range(self.policy.vis_obs_size)]
+ [f"visual_observation_{i}" for i in range(num_vis_obs)]
+ ["action_masks", "memories"]
)
self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}

正在加载...
取消
保存