
Merge remote-tracking branch 'origin/develop-critic-optimizer' into develop-critic-optimizer

Branch: /develop/action-slice
Ervin Teng, 4 years ago
Current commit: 24ee4bd5
2 files changed, 3 insertions(+), 11 deletions(-)
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (10 changes)
  2. ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (4 changes)

ml-agents/mlagents/trainers/policy/torch_policy.py (10 changes)

         )  # could be much simpler if TorchPolicy is nn.Module
-        self.grads = None
-        # reward_signal_configs = trainer_settings.reward_signals
-        # reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
-        ac_class = SimpleActor
-        # if separate_critic:
-        #     ac_class = SimpleActor
-        # else:
-        #     ac_class = SharedActorCritic
-        self.actor = ac_class(
+        self.actor = SimpleActor(
             observation_specs=self.behavior_spec.observation_specs,
             network_settings=trainer_settings.network_settings,
             action_spec=behavior_spec.action_spec,

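This hunk removes the commented-out switch between SimpleActor and SharedActorCritic, so the policy now unconditionally builds a SimpleActor; consistent with the branch name, the critic presumably lives in the optimizer rather than in the policy. Below is a minimal sketch of the resulting constructor call. The import path and the surrounding class skeleton are assumptions for illustration; only the keyword arguments visible in the hunk are reproduced.

    # Sketch only, not the actual ml-agents source. The import path is assumed
    # from the ml-agents package layout of the same era.
    from mlagents.trainers.torch.networks import SimpleActor

    class TorchPolicySketch:
        def __init__(self, behavior_spec, trainer_settings):
            self.behavior_spec = behavior_spec
            # The old separate_critic branch is gone: the policy always owns a
            # plain SimpleActor after this change.
            self.actor = SimpleActor(
                observation_specs=self.behavior_spec.observation_specs,
                network_settings=trainer_settings.network_settings,
                action_spec=behavior_spec.action_spec,
            )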
ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (4 changes)

     """
     Make sure two policies have the same output for the same input.
     """
-    policy1.actor_critic = policy1.actor_critic.to(default_device())
-    policy2.actor_critic = policy2.actor_critic.to(default_device())
+    policy1.actor = policy1.actor.to(default_device())
+    policy2.actor = policy2.actor.to(default_device())
     decision_step, _ = mb.create_steps_from_behavior_spec(
         policy1.behavior_spec, num_agents=1