import logging
from collections import defaultdict

import numpy as np

from mlagents.envs.brain import BrainInfo
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.trainers.agent_processor import ProcessingBuffer

# AgentBuffer import added for demonstration_buffer below; the module path is
# assumed from the same package layout as ProcessingBuffer.
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import Trajectory, trajectory_to_agentbuffer
from mlagents.trainers.trainer import Trainer

logger = logging.getLogger("mlagents.trainers")


class RLTrainer(Trainer):
    # NOTE: the excerpt begins mid-class; the class name, the __init__
    # signature, and the bookkeeping initializers below are reconstructed so
    # the methods that follow are runnable. Adjust to the surrounding Trainer
    # API if it differs.
    def __init__(self, brain, trainer_parameters, training, load, seed, run_id):
        super().__init__(brain, trainer_parameters, training, run_id)
        self.batches_per_epoch = trainer_parameters["batches_per_epoch"]
        # Per-agent bookkeeping used by add_experiences/process_experiences.
        self.cumulative_rewards = {}
        self.episode_steps = {}
        # Rewards collected per reward stream; defaultdict so "+=" on a new
        # agent_id starts from zero (see process_trajectory).
        self.collected_rewards = {"environment": defaultdict(float)}
        self.cumulative_returns_since_policy_update = []
        self.demonstration_buffer = AgentBuffer()
        self.evaluation_buffer = ProcessingBuffer()
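
        # A ProcessingBuffer maps each agent_id to its own AgentBuffer of
        # in-progress experience (an assumption based on how it is indexed
        # below, e.g. self.evaluation_buffer[agent_id].last_brain_info).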

    def add_experiences(
        self,
        curr_info: BrainInfo,
        next_info: BrainInfo,
        take_action_outputs: ActionInfoOutputs,
    ) -> None:
        """
        Adds experiences to each agent's experience history.
        :param curr_info: Current BrainInfo
        :param next_info: Next BrainInfo
        :param take_action_outputs: The outputs of the take action method.
        """
        # Used to collect information about student performance.
        for agent_id in curr_info.agents:
            self.evaluation_buffer[agent_id].last_brain_info = curr_info

        for agent_id in next_info.agents:
            stored_next_info = self.evaluation_buffer[agent_id].last_brain_info
            if stored_next_info is None:
                continue
            else:
                next_idx = next_info.agents.index(agent_id)
                if agent_id not in self.cumulative_rewards:
                    self.cumulative_rewards[agent_id] = 0
                self.cumulative_rewards[agent_id] += next_info.rewards[next_idx]
                if not next_info.local_done[next_idx]:
                    if agent_id not in self.episode_steps:
                        self.episode_steps[agent_id] = 0
                    self.episode_steps[agent_id] += 1
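
    # Usage sketch (illustrative; not part of the original module). A driver
    # loop is expected to call add_experiences() after each environment step,
    # then process_experiences() so finished episodes are summarized. The
    # `env`, `policy`, and `trainer` names here are hypothetical:
    #
    #   curr_info = env.reset()
    #   while training:
    #       outputs = policy.get_action(curr_info)
    #       next_info = env.step(outputs.action)
    #       trainer.add_experiences(curr_info, next_info, outputs)
    #       trainer.process_experiences(curr_info, next_info)
    #       curr_info = next_info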

    def process_experiences(
        self, current_info: BrainInfo, next_info: BrainInfo
    ) -> None:
        """
        Checks agent histories for processing condition, and processes them
        as necessary. Processing involves calculating value and advantage
        targets for the model updating step.
        :param current_info: Current BrainInfo
        :param next_info: Next BrainInfo
        """
        for l in range(len(next_info.agents)):
            if next_info.local_done[l]:
                agent_id = next_info.agents[l]
                self.stats["Environment/Cumulative Reward"].append(
                    self.cumulative_rewards.get(agent_id, 0)
                )
                self.stats["Environment/Episode Length"].append(
                    self.episode_steps.get(agent_id, 0)
                )
                self.reward_buffer.appendleft(
                    self.cumulative_rewards.get(agent_id, 0)
                )
                self.cumulative_rewards[agent_id] = 0
                self.episode_steps[agent_id] = 0
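
    # Design note (assumption): reward_buffer is expected to be a bounded
    # deque on the base Trainer (e.g. deque(maxlen=...)), so appendleft()
    # keeps the most recent episode returns at the front for consumers such
    # as curriculum or early-stopping logic.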

    def process_trajectory(self, trajectory: Trajectory) -> None:
        """
        Takes a trajectory and processes it, putting it into the update
        buffer. Processing involves calculating value and advantage targets
        for the model updating step.
        :param trajectory: The Trajectory to process.
        """
        agent_id = trajectory.steps[
            -1
        ].agent_id  # All the agents should have the same ID
        agent_buffer_trajectory = trajectory_to_agentbuffer(trajectory)

        # Evaluate all reward functions
        self.collected_rewards["environment"][agent_id] += np.sum(
            agent_buffer_trajectory["environment_rewards"]
        )

        if trajectory.steps[-1].done:
            self.stats["Environment/Episode Length"].append(
                self.episode_steps.get(agent_id, 0)
            )
            self.episode_steps[agent_id] = 0
            for name, rewards in self.collected_rewards.items():
                if name == "environment":
                    self.cumulative_returns_since_policy_update.append(
                        rewards.get(agent_id, 0)
                    )
                    self.stats["Environment/Cumulative Reward"].append(
                        rewards.get(agent_id, 0)
                    )
                    self.reward_buffer.appendleft(rewards.get(agent_id, 0))
                    rewards[agent_id] = 0
                else:
                    self.stats[self.policy.reward_signals[name].stat_name].append(
                        rewards.get(agent_id, 0)
                    )
                    rewards[agent_id] = 0
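
    # Illustrative sketch (not part of the original): the "value and advantage
    # targets" mentioned in process_trajectory are, in PPO-style trainers,
    # derived from trajectory rewards and value estimates. This helper shows a
    # plain-Python GAE computation; the gamma/lam defaults are assumptions,
    # and the real targets come from the policy's value head, not from here.
    @staticmethod
    def _gae_sketch(rewards, value_estimates, value_next=0.0, gamma=0.99, lam=0.95):
        advantages = [0.0] * len(rewards)
        gae = 0.0
        for t in reversed(range(len(rewards))):
            # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t)
            next_value = (
                value_estimates[t + 1] if t + 1 < len(rewards) else value_next
            )
            delta = rewards[t] + gamma * next_value - value_estimates[t]
            # GAE recursion: A_t = delta_t + gamma * lam * A_{t+1}
            gae = delta + gamma * lam * gae
            advantages[t] = gae
        # Value targets are the advantages plus the baseline value estimates.
        returns = [adv + v for adv, v in zip(advantages, value_estimates)]
        return advantages, returns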

    def end_episode(self):
        """