浏览代码

Merge pull request #4803 from Unity-Technologies/develop-remove-vec-vis-fields

Remove some vis and vec fields from Policy
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
a272bff0
共有 3 个文件被更改,包括 15 次插入和 26 次删除
  1. 14
      ml-agents/mlagents/trainers/policy/policy.py
  2. 14
      ml-agents/mlagents/trainers/policy/torch_policy.py
  3. 13
      ml-agents/mlagents/trainers/torch/model_serialization.py

14
ml-agents/mlagents/trainers/policy/policy.py


self.trainer_settings = trainer_settings
self.network_settings: NetworkSettings = trainer_settings.network_settings
self.seed = seed
self.act_size = (
list(self.behavior_spec.action_spec.discrete_branches)
if self.behavior_spec.action_spec.is_discrete()
else [self.behavior_spec.action_spec.continuous_size]
)
self.vec_obs_size = sum(
sen_spec.shape[0]
for sen_spec in behavior_spec.sensor_specs
if len(sen_spec.shape) == 1
)
self.vis_obs_size = sum(
1 for sen_spec in behavior_spec.sensor_specs if len(sen_spec.shape) == 3
)
self.use_continuous_act = self.behavior_spec.action_spec.is_continuous()
self.previous_action_dict: Dict[str, np.ndarray] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.normalize = trainer_settings.network_settings.normalize

14
ml-agents/mlagents/trainers/policy/torch_policy.py


def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray:
mask = None
if self.behavior_spec.action_spec.discrete_size > 0:
mask = torch.ones([len(decision_requests), np.sum(self.act_size)])
num_discrete_flat = np.sum(self.behavior_spec.action_spec.discrete_branches)
mask = torch.ones([len(decision_requests), num_discrete_flat])
if decision_requests.action_mask is not None:
mask = torch.as_tensor(
1 - np.concatenate(decision_requests.action_mask, axis=1)

:param buffer: The buffer with the observations to add to the running estimate
of the distribution.
"""
if self.use_vec_obs and self.normalize:
if self.normalize:
self.actor_critic.update_normalization(buffer)
@timed

outputs=run_out,
agent_ids=list(decision_requests.agent_id),
)
@property
def use_vis_obs(self):
    """True when this policy consumes at least one visual observation."""
    has_visual = self.vis_obs_size > 0
    return has_visual
@property
def use_vec_obs(self):
    """True when this policy consumes a non-empty vector observation."""
    has_vector = self.vec_obs_size > 0
    return has_vector
def get_current_step(self):
"""

13
ml-agents/mlagents/trainers/torch/model_serialization.py


self.policy = policy
batch_dim = [1]
seq_len_dim = [1]
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
vec_obs_size = 0
for sens_spec in self.policy.behavior_spec.sensor_specs:
if len(sens_spec.shape) == 1:
vec_obs_size += sens_spec.shape[0]
num_vis_obs = sum(
1
for sens_spec in self.policy.behavior_spec.sensor_specs
if len(sens_spec.shape) == 3
)
dummy_vec_obs = [torch.zeros(batch_dim + [vec_obs_size])]
# create input shape of NCHW
# (It's NHWC in self.policy.behavior_spec.sensor_specs.shape)
dummy_vis_obs = [

self.input_names = (
["vector_observation"]
+ [f"visual_observation_{i}" for i in range(self.policy.vis_obs_size)]
+ [f"visual_observation_{i}" for i in range(num_vis_obs)]
+ ["action_masks", "memories"]
)
self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}

正在加载...
取消
保存