浏览代码

Fix some bugs for visual obs

/develop-newnormalization
Ervin Teng 5 年前
当前提交
c7632aa7
共有 3 个文件被更改,包括 18 次插入和 15 次删除
  1. 14
      ml-agents/mlagents/trainers/agent_processor.py
  2. 2
      ml-agents/mlagents/trainers/ppo/policy.py
  3. 17
      ml-agents/mlagents/trainers/trajectory.py

14
ml-agents/mlagents/trainers/agent_processor.py


def __init__(self, trainer: Trainer):
self.experience_buffers: Dict[str, List] = defaultdict(list)
self.last_brain_info: Dict[str, BrainInfo] = defaultdict(BrainInfo)
self.last_brain_info: Dict[str, BrainInfo] = {}
self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = defaultdict(
ActionInfoOutputs
)

tmp_environment = np.array(next_info.rewards)
for agent_id in next_info.agents:
stored_info = self.last_brain_info[agent_id]
stored_take_action_outputs = self.last_take_action_outputs[agent_id]
stored_info = self.last_brain_info.get(agent_id, None)
stored_take_action_outputs = self.last_take_action_outputs.get(
agent_id, None
)
if stored_info is not None:
idx = stored_info.agents.index(agent_id)
next_idx = next_info.agents.index(agent_id)

max_step = next_info.max_reached[next_idx]
# Add the outputs of the last eval
action = take_action_outputs["action"][idx]
action = stored_take_action_outputs["action"][idx]
action_pre = take_action_outputs["pre_action"][idx]
action_pre = stored_take_action_outputs["pre_action"][idx]
action_probs = take_action_outputs["log_probs"][idx]
action_probs = stored_take_action_outputs["log_probs"][idx]
action_masks = stored_info.action_masks[idx]
prev_action = self.policy.retrieve_previous_action([agent_id])[0, :]

2
ml-agents/mlagents/trainers/ppo/policy.py


if not self.use_continuous_act and self.use_recurrent:
feed_dict[self.model.prev_action] = batch["prev_action"]
value_estimates = self.sess.run(self.model.value_heads, feed_dict)
value_estimates = {k: np.squeeze(v) for k, v in value_estimates.items()}
value_estimates = {k: np.squeeze(v, axis=1) for k, v in value_estimates.items()}
return value_estimates

17
ml-agents/mlagents/trainers/trajectory.py


vec_obs_indices.append(index)
if len(observation.shape) == 3:
vis_obs_indices.append(index)
vec_obs = np.concatenate([obs[i] for i in vec_obs_indices], axis=0)
vec_obs = (
np.concatenate([obs[i] for i in vec_obs_indices], axis=0)
if len(vec_obs_indices) > 0
else np.array([], dtype=np.float32)
)
vis_obs = [obs[i] for i in vis_obs_indices]
return SplitObservations(vector_observations=vec_obs, visual_observations=vis_obs)

agent_buffer_trajectory["next_visual_obs%d" % i].append(
next_vec_vis_obs.visual_observations[i]
)
if vec_vis_obs.vector_observations.size > 0:
agent_buffer_trajectory["vector_obs"].append(
vec_vis_obs.vector_observations
)
agent_buffer_trajectory["next_vector_in"].append(
next_vec_vis_obs.vector_observations
)
agent_buffer_trajectory["vector_obs"].append(vec_vis_obs.vector_observations)
agent_buffer_trajectory["next_vector_in"].append(
next_vec_vis_obs.vector_observations
)
if exp.memory:
agent_buffer_trajectory["memory"].append(exp.memory)

正在加载...
取消
保存