from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, SACSettings, FrameworkType
from mlagents.trainers.torch.components.reward_providers import BaseRewardProvider
|
|
# Bootstrap using the last step rather than the bootstrap step if max step is reached.
# Set last element to duplicate obs and remove dones.
if last_step.interrupted:
    vec_vis_obs = SplitObservations.from_observations(last_step.obs)
    for i, obs in enumerate(vec_vis_obs.visual_observations):
        agent_buffer_trajectory["next_visual_obs%d" % i][-1] = obs
    if vec_vis_obs.vector_observations.size > 1:
        agent_buffer_trajectory["next_vector_in"][
            -1
        ] = vec_vis_obs.vector_observations
    agent_buffer_trajectory["next_obs"] = last_step.obs
    agent_buffer_trajectory["done"][-1] = False
|
# Append to update buffer |
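# A minimal sketch of the step the comment above refers to (an assumption based
# on the common RLTrainer pattern, not verbatim from this file): the per-agent
# trajectory buffer is resequenced to the policy's sequence length and appended
# to the trainer-wide update buffer that SAC later samples minibatches from.
# `self.update_buffer` and `self.policy.sequence_length` are assumed to come
# from RLTrainer and TorchPolicy respectively.
agent_buffer_trajectory.resequence_and_append(
    self.update_buffer, training_length=self.policy.sequence_length
)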
|
|
|