浏览代码

Fix SAC interrupted condition and typing

/develop/unified-obs
Ervin Teng 4 年前
当前提交
1db21cbb
共有 2 个文件被更改,包括 5 次插入和 10 次删除
  1. 10
      ml-agents/mlagents/trainers/sac/trainer.py
  2. 5
      ml-agents/mlagents/trainers/torch/encoders.py

10
ml-agents/mlagents/trainers/sac/trainer.py


from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, SACSettings, FrameworkType
from mlagents.trainers.torch.components.reward_providers import BaseRewardProvider

# Bootstrap using the last step rather than the bootstrap step if max step is reached.
# Set last element to duplicate obs and remove dones.
if last_step.interrupted:
vec_vis_obs = SplitObservations.from_observations(last_step.obs)
for i, obs in enumerate(vec_vis_obs.visual_observations):
agent_buffer_trajectory["next_visual_obs%d" % i][-1] = obs
if vec_vis_obs.vector_observations.size > 1:
agent_buffer_trajectory["next_vector_in"][
-1
] = vec_vis_obs.vector_observations
agent_buffer_trajectory["next_obs"] = last_step.obs
agent_buffer_trajectory["done"][-1] = False
# Append to update buffer

5
ml-agents/mlagents/trainers/torch/encoders.py


return inputs
def copy_normalization(self, other_input: "InputProcessor") -> None:
if self.normalizer is not None and other_input.normalizer is not None:
self.normalizer.copy_from(other_input.normalizer)
if isinstance(other_input, VectorInput):
if self.normalizer is not None and other_input.normalizer is not None:
self.normalizer.copy_from(other_input.normalizer)
def update_normalization(self, inputs: torch.Tensor) -> None:
if self.normalizer is not None:

正在加载...
取消
保存