target critic for ppo

/develop/bisim-sac-transfer
yanchaosun, 4 years ago
Commit 36f36750
5 files changed: 67 insertions, 9 deletions
1. config/ppo_transfer/3DBall.yaml (2 changes)
2. config/ppo_transfer/3DBallHard.yaml (2 changes)
3. config/ppo_transfer/3DBallHardTransfer.yaml (8 changes)
4. ml-agents/mlagents/trainers/models.py (4 changes)
5. ml-agents/mlagents/trainers/ppo_transfer/optimizer.py (60 changes)

config/ppo_transfer/3DBall.yaml (2 changes)

  in_epoch_alter: false
  in_batch_alter: true
  use_op_buffer: false
- use_var_predict: false
+ use_var_predict: true
  with_prior: false
  predict_return: true
  use_bisim: false

config/ppo_transfer/3DBallHard.yaml (2 changes)

  in_epoch_alter: false
  in_batch_alter: true
  use_op_buffer: false
- use_var_predict: false
+ use_var_predict: true
  with_prior: false
  predict_return: true
  use_bisim: false

config/ppo_transfer/3DBallHardTransfer.yaml (8 changes)

  lambd: 0.95
  num_epoch: 3
  learning_rate_schedule: linear
- model_schedule: linear
+ model_schedule: constant
  encoder_layers: 1
  policy_layers: 1
  forward_layers: 1

  in_epoch_alter: false
  in_batch_alter: true
- use_op_buffer: true
- use_var_predict: false
+ use_op_buffer: false
+ use_var_predict: true
- transfer_path: "results/ball-novar/3DBall"
+ transfer_path: "results/ball-targv/3DBall"
  load_model: true
  train_model: false
  network_settings:

ml-agents/mlagents/trainers/models.py (4 changes)

      @staticmethod
      def create_value_heads(
-         stream_names: List[str], hidden_input: tf.Tensor
+         stream_names: List[str], hidden_input: tf.Tensor, reuse: bool = False
      ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
          """
          Creates one value estimator head for each reward signal in stream_names.

          """
          value_heads = {}
          for name in stream_names:
-             value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name))
+             value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name), reuse=reuse)
              value_heads[name] = value
          value = tf.reduce_mean(list(value_heads.values()), 0)
          return value_heads, value

ml-agents/mlagents/trainers/ppo_transfer/optimizer.py (60 changes)

- from typing import Optional, Any, Dict, cast
+ from typing import Optional, Any, Dict, cast, List, Tuple
  import numpy as np
  import os
  import copy

  from mlagents.trainers.policy.tf_policy import TFPolicy
  from mlagents.trainers.trajectory import SplitObservations
  from mlagents.trainers.components.reward_signals.curiosity.model import CuriosityModel
  from mlagents.trainers.policy.transfer_policy import TransferPolicy
  from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer

          self.old_log_probs = tf.reduce_sum(
              (tf.identity(self.all_old_log_probs)), axis=1, keepdims=True
          )
+         target_hidden_value = ModelUtils.create_vector_observation_encoder(
+             self.policy.targ_encoder,
+             h_size,
+             ModelUtils.swish,
+             num_layers,
+             scope=f"main_graph",
+             reuse=True,
+         )
+         self.target_value_heads, self.target_value = ModelUtils.create_value_heads(
+             self.stream_names, target_hidden_value, reuse=True
+         )

      def _create_dc_critic(
          self, h_size: int, num_layers: int, vis_encode_type: EncoderType

              axis=1,
              keepdims=True,
          )
+     def _get_value_estimates(
+         self,
+         next_obs: List[np.ndarray],
+         done: bool,
+         policy_memory: np.ndarray = None,
+         value_memory: np.ndarray = None,
+         prev_action: np.ndarray = None,
+     ) -> Dict[str, float]:
+         """
+         Generates value estimates for bootstrapping from the target critic.
+         :param next_obs: The observations from the next step, used for bootstrapping.
+         :param done: Whether or not this is the last element of the episode, in which case the value estimate will be 0.
+         :return: The value estimate dictionary with key being the name of the reward signal
+         and the value the corresponding value estimate.
+         """
+         feed_dict: Dict[tf.Tensor, Any] = {
+             self.policy.batch_size_ph: 1,
+             self.policy.sequence_length_ph: 1,
+         }
+         vec_vis_obs = SplitObservations.from_observations(next_obs)
+         for i in range(len(vec_vis_obs.visual_observations)):
+             feed_dict[self.policy.visual_in[i]] = [vec_vis_obs.visual_observations[i]]
+         if self.policy.vec_obs_size > 0:
+             feed_dict[self.policy.vector_in] = [vec_vis_obs.vector_observations]
+         if policy_memory is not None:
+             feed_dict[self.policy.memory_in] = policy_memory
+         if value_memory is not None:
+             feed_dict[self.memory_in] = value_memory
+         if prev_action is not None:
+             feed_dict[self.policy.prev_action] = [prev_action]
+         value_estimates = self.sess.run(self.target_value_heads, feed_dict)
+         value_estimates = {k: float(v) for k, v in value_estimates.items()}
+         # If we're done, reassign all of the value estimates that need terminal states.
+         if done:
+             for k in value_estimates:
+                 if self.reward_signals[k].use_terminal_states:
+                     value_estimates[k] = 0.0
+         return value_estimates
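For orientation, the dictionary returned by `_get_value_estimates` is the per-reward-signal bootstrap value a PPO trainer needs when a trajectory is truncated rather than terminated. A rough sketch of that use with a plain discounted-return helper and made-up data; the helper, the toy rewards, and the commented trainer call are all assumptions (ml-agents itself computes GAE advantages rather than this simple return):

from typing import Dict, List
import numpy as np

def discounted_returns(rewards: List[float], bootstrap: float, gamma: float = 0.99) -> np.ndarray:
    # G_t = r_t + gamma * G_{t+1}, seeding the tail with the critic's bootstrap value.
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = bootstrap
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Hypothetical end-of-trajectory call (variable names assumed):
# value_estimates = optimizer._get_value_estimates(next_obs=trajectory.next_obs,
#                                                  done=trajectory.done_reached)
value_estimates: Dict[str, float] = {"extrinsic": 0.37}  # e.g. from the target critic
rewards = [0.1, 0.1, 0.1, 1.0]                           # toy extrinsic rewards
print(discounted_returns(rewards, bootstrap=value_estimates["extrinsic"]))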