
Move agent_id to Trajectory

/develop-newnormalization
Ervin Teng 5 years ago
Current commit
324d217b
7 changed files with 10 additions and 24 deletions
  1. ml-agents/mlagents/trainers/agent_processor.py (5 changes)
  2. ml-agents/mlagents/trainers/bc/trainer.py (4 changes)
  3. ml-agents/mlagents/trainers/ppo/policy.py (3 changes)
  4. ml-agents/mlagents/trainers/ppo/trainer.py (5 changes)
  5. ml-agents/mlagents/trainers/sac/trainer.py (2 changes)
  6. ml-agents/mlagents/trainers/tests/test_trajectory.py (4 changes)
  7. ml-agents/mlagents/trainers/trajectory.py (11 changes)
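In short, this commit hoists the agent_id field from the per-step AgentExperience NamedTuple up to the Trajectory NamedTuple that wraps those steps, so the ID is stored once per trajectory rather than duplicated on every step. A minimal before/after sketch (field lists abbreviated to what this diff shows; the real AgentExperience also carries observation, action, and reward fields):

import numpy as np
from typing import List, NamedTuple

# Before this commit: every step carried its own copy of the ID.
class AgentExperience(NamedTuple):
    prev_action: np.ndarray
    max_step: bool
    memory: np.ndarray
    agent_id: str  # duplicated across all steps of an episode

class Trajectory(NamedTuple):
    steps: List[AgentExperience]

# After this commit: the ID lives on the trajectory itself.
class AgentExperience(NamedTuple):
    prev_action: np.ndarray
    max_step: bool
    memory: np.ndarray

class Trajectory(NamedTuple):
    steps: List[AgentExperience]
    agent_id: str  # stored once for the whole trajectory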

ml-agents/mlagents/trainers/agent_processor.py (5 changes)


     action_mask=action_masks,
     prev_action=prev_action,
     max_step=max_step,
-    agent_id=agent_id,
     memory=memory,
 )
 # Add the value outputs if needed

     next_obs.append(next_info.visual_observations[i][next_idx])
 if self.policy.use_vec_obs:
     next_obs.append(next_info.vector_observations[next_idx])
-trajectory = Trajectory(steps=self.experience_buffers[agent_id])
+trajectory = Trajectory(
+    steps=self.experience_buffers[agent_id], agent_id=agent_id
+)
 # This will eventually be replaced with a queue
 self.trainer.process_trajectory(trajectory)
 self.experience_buffers[agent_id] = []
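For context, the surrounding processor logic buffers experiences per agent and flushes them into a Trajectory when an episode ends. A simplified sketch of that flow, reusing the abbreviated NamedTuples from the overview above (on the real AgentProcessor the buffer and trainer are attributes, not globals):

from collections import defaultdict
from typing import Callable, DefaultDict, List

experience_buffers: DefaultDict[str, List[AgentExperience]] = defaultdict(list)

def flush(agent_id: str, process_trajectory: Callable[[Trajectory], None]) -> None:
    # The ID now rides on the Trajectory instead of on each AgentExperience.
    trajectory = Trajectory(steps=experience_buffers[agent_id], agent_id=agent_id)
    process_trajectory(trajectory)  # direct hand-off today; a queue eventually
    experience_buffers[agent_id] = []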

ml-agents/mlagents/trainers/bc/trainer.py (4 changes)


 Takes a trajectory and processes it, putting it into the update buffer.
 Processing involves calculating value and advantage targets for model updating step.
 """
-agent_id = trajectory.steps[
-    -1
-].agent_id  # All the agents should have the same ID
+agent_id = trajectory.agent_id  # All the experiences should have the same ID
 agent_buffer_trajectory = trajectory_to_agentbuffer(trajectory)
 # Evaluate all reward functions

ml-agents/mlagents/trainers/ppo/policy.py (3 changes)


     return value_estimates

 def get_value_estimates(
-    self, experience: AgentExperience, done: bool
+    self, experience: AgentExperience, agent_id: str, done: bool
 ) -> Dict[str, float]:
     """
     Generates value estimates for bootstrapping.

     self.model.batch_size: 1,
     self.model.sequence_length: 1,
 }
-agent_id = experience.agent_id
 vec_vis_obs = split_obs(experience.obs)
 for i in range(len(vec_vis_obs.visual_observations)):
     feed_dict[self.model.visual_in[i]] = [vec_vis_obs.visual_observations[i]]
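Because an AgentExperience no longer knows which agent produced it, callers of get_value_estimates must now pass the ID explicitly. A stub showing only the new parameter order (illustrative; the real method evaluates the policy's value heads through TensorFlow):

from typing import Dict

class StubPolicy:
    # Mirrors the updated signature: (experience, agent_id, done).
    def get_value_estimates(
        self, experience, agent_id: str, done: bool
    ) -> Dict[str, float]:
        # A real policy runs its value head(s) on the experience's
        # observations; returning zero on true episode ends is the
        # usual bootstrapping convention.
        return {"extrinsic": 0.0 if done else 1.23}

print(StubPolicy().get_value_estimates(None, "agent_0", done=False))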

ml-agents/mlagents/trainers/ppo/trainer.py (5 changes)


 Takes a trajectory and processes it, putting it into the update buffer.
 Processing involves calculating value and advantage targets for model updating step.
 """
-agent_id = trajectory.steps[
-    -1
-].agent_id  # All the agents should have the same ID
+agent_id = trajectory.agent_id  # All the agents should have the same ID
 # Note that this agent buffer version of the traj. is one less than the len of the raw trajectory
 # for bootstrapping purposes.

 value_next = self.policy.get_value_estimates(
     trajectory.steps[-1],
+    agent_id,
     trajectory.steps[-1].done and not trajectory.steps[-1].max_step,
 )
 # Evaluate all reward functions
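The docstring's "value and advantage targets" are computed from these bootstrapped estimates. For reference, a generic Generalized Advantage Estimation (GAE) routine for that step; this is the standard algorithm, not code from this repo:

from typing import List

def gae(rewards: List[float], values: List[float], value_next: float,
        gamma: float = 0.99, lam: float = 0.95) -> List[float]:
    # value_next bootstraps the tail of a trajectory that was cut off
    # (max_step) before the episode truly ended.
    advantages = [0.0] * len(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        next_value = value_next if t == len(rewards) - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_value - values[t]
        last_adv = delta + gamma * lam * last_adv
        advantages[t] = last_adv
    return advantages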

ml-agents/mlagents/trainers/sac/trainer.py (2 changes)


 Takes a trajectory and processes it, putting it into the replay buffer.
 """
 last_step = trajectory.steps[-1]
-agent_id = last_step.agent_id  # All the agents should have the same ID
+agent_id = trajectory.agent_id  # All the agents should have the same ID
 # Note that this agent buffer version of the traj. is one less than the len of the raw trajectory
 # for bootstrapping purposes.

ml-agents/mlagents/trainers/tests/test_trajectory.py (4 changes)


     prev_action=prev_action,
     max_step=max_step,
     memory=memory,
-    agent_id=agent_id,
 )
 steps_list.append(experience)
 last_experience = AgentExperience(

     prev_action=prev_action,
     max_step=max_step_complete,
     memory=memory,
-    agent_id=agent_id,
 )
-return Trajectory(steps=steps_list)
+return Trajectory(steps=steps_list, agent_id=agent_id)

 @pytest.mark.parametrize("num_visual_obs", [0, 1, 2])
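The test helper builds a list of dummy steps and now attaches the ID once, at Trajectory construction. A condensed sketch of that pattern, assuming the abbreviated NamedTuples from the overview (the real test also fills in observation and action fields):

import numpy as np

def make_dummy_trajectory(length: int, agent_id: str) -> Trajectory:
    steps_list = []
    for i in range(length):
        steps_list.append(
            AgentExperience(
                prev_action=np.zeros(2),
                max_step=(i == length - 1),  # flag the final, truncated step
                memory=np.zeros(8),
            )
        )
    # agent_id is attached once, to the Trajectory, not to each step.
    return Trajectory(steps=steps_list, agent_id=agent_id)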

ml-agents/mlagents/trainers/trajectory.py (11 changes)


 import numpy as np

 from mlagents.trainers.buffer import AgentBuffer
-from mlagents.envs.exception import UnityException


 class AgentExperience(NamedTuple):

     prev_action: np.ndarray
     max_step: bool
     memory: np.ndarray
-    agent_id: str


 class SplitObservations(NamedTuple):

 class Trajectory(NamedTuple):
     steps: List[AgentExperience]
-
-
-class AgentProcessorException(UnityException):
-    """
-    Related to errors with the AgentProcessor.
-    """
-
-    pass
+    agent_id: str


 def split_obs(obs: List[np.ndarray]) -> SplitObservations:
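Net effect, as a usage sketch against the abbreviated definitions above: consumers identify the agent straight from the trajectory, which also removes the old implicit assumption that steps is non-empty.

import numpy as np

step = AgentExperience(prev_action=np.zeros(2), max_step=False, memory=np.zeros(8))
traj = Trajectory(steps=[step], agent_id="agent_0")

# Old pattern: traj.steps[-1].agent_id, which raised IndexError on an
# empty trajectory. The ID is now available unconditionally.
assert traj.agent_id == "agent_0"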
