
reset these to master

/active-variablespeed
HH, 4 years ago
Current commit
c72553c8
2 files changed, with 19 additions and 19 deletions
  1. ml-agents/mlagents/trainers/learn.py (14 changes)
  2. ml-agents/mlagents/trainers/ppo/trainer.py (24 changes)

ml-agents/mlagents/trainers/learn.py (14 changes)


def create_environment_factory(
    env_path: Optional[str],
    no_graphics: bool,
    seed: int,
    start_port: Optional[int],
    env_args: Optional[List[str]],
    log_folder: str,
) -> Callable[[int, List[SideChannel]], BaseEnv]:
    def create_unity_environment(
        worker_id: int, side_channels: List[SideChannel]
    ) -> UnityEnvironment:
        # Make sure that each environment gets a different seed
        env_seed = seed + worker_id
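
The seed arithmetic above is the key detail: each parallel worker derives env_seed = seed + worker_id, so no two simultaneous Unity instances sample identical episodes. Below is a minimal runnable sketch of the closure pattern this factory follows, with the real UnityEnvironment construction replaced by a plain dict so it runs standalone (the helper names and stand-in fields are ours, not ml-agents'):

from typing import Callable, List, Optional

def make_env_factory(
    env_path: Optional[str],
    no_graphics: bool,
    seed: int,
    start_port: Optional[int],
    env_args: Optional[List[str]],
    log_folder: str,
) -> Callable[[int, List[object]], dict]:
    # The factory closes over the launch settings; callers supply only
    # a worker_id and side channels when spawning each environment.
    def make_env(worker_id: int, side_channels: List[object]) -> dict:
        # Offset the base seed by worker_id so parallel workers diverge.
        env_seed = seed + worker_id
        # Hypothetical stand-in for the real UnityEnvironment call.
        return {"file_name": env_path, "seed": env_seed, "worker_id": worker_id}

    return make_env

# Two workers built from the same factory get distinct seeds.
factory = make_env_factory(None, True, 42, None, None, "results/logs")
assert factory(0, [])["seed"] == 42
assert factory(1, [])["seed"] == 43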

ml-agents/mlagents/trainers/ppo/trainer.py (24 changes)


"""The PPOTrainer is an implementation of the PPO algorithm."""
def __init__(
self,
brain_name: str,
reward_buff_cap: int,
trainer_settings: TrainerSettings,
training: bool,
load: bool,
seed: int,
artifact_path: str,
self,
brain_name: str,
reward_buff_cap: int,
trainer_settings: TrainerSettings,
training: bool,
load: bool,
seed: int,
artifact_path: str,
):
"""
Responsible for collecting experiences and training PPO model.
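
For orientation, a hedged sketch of how this constructor might be invoked; the argument values are illustrative, and the assumption that a bare TrainerSettings() is constructible with defaults comes from general ml-agents usage, not from this diff:

from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.settings import TrainerSettings

# Illustrative values only; a real run derives these from the
# trainer configuration YAML and the command-line options.
trainer = PPOTrainer(
    brain_name="3DBall",             # behavior name from the Unity scene
    reward_buff_cap=100,             # episodes kept for the rolling reward stat
    trainer_settings=TrainerSettings(),
    training=True,                   # collect experiences and update the policy
    load=False,                      # start fresh rather than resume a checkpoint
    seed=42,
    artifact_path="results/3DBall",  # checkpoints and summaries land here
)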

            agent_buffer_trajectory,
            trajectory.next_obs,
            trajectory.done_reached and not trajectory.interrupted,
        )
        for name, v in value_estimates.items():
            agent_buffer_trajectory[f"{name}_value_estimates"].extend(v)
            self._stats_reporter.add_stat(

        # Make sure batch_size is a multiple of sequence length. During training, we
        # will need to reshape the data into a batch_size x sequence_length tensor.
        batch_size = (
            self.hyperparameters.batch_size
            - self.hyperparameters.batch_size % self.policy.sequence_length
        )
        # Make sure there is at least one sequence
        batch_size = max(batch_size, self.policy.sequence_length)
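
The arithmetic in this hunk deserves a note: batch_size % sequence_length is the remainder that would not fill a complete sequence, subtracting it truncates the batch to the largest multiple of sequence_length, and the max() guard keeps at least one full sequence when the configured batch is smaller than the sequence length. A self-contained restatement (the function name is ours, not ml-agents'):

def effective_batch_size(batch_size: int, sequence_length: int) -> int:
    # Drop the trailing partial sequence, if any.
    rounded = batch_size - batch_size % sequence_length
    # Never shrink below a single full sequence.
    return max(rounded, sequence_length)

# 1000 samples with sequences of 64 -> 960, i.e. 15 full sequences.
assert effective_batch_size(1000, 64) == 960
# A batch smaller than one sequence is rounded up to a single sequence.
assert effective_batch_size(32, 64) == 64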

        return True

    def create_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
    ) -> TFPolicy:
        """
        Creates a PPO policy and adds it to the trainer's list of policies.
