            trainer_params.get("vis_encode_type", "simple")
        )
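        # Soft-update (Polyak averaging) coefficient for the target network:
        # target_params <- tau * source_params + (1 - tau) * target_params.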
        self.tau = trainer_params.get("tau", 0.005)
        m_size = self.policy.m_size
        self.init_entcoef = trainer_params.get("init_entcoef", 1.0)
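        # (init_entcoef is the initial value of SAC's learned entropy coefficient;
        # larger values encourage more exploration early in training.)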
        stream_names = self.reward_signals.keys()
        # Used to reduce "survivor bonus" when using Curiosity or GAIL.
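        # (With always-positive intrinsic rewards, longer episodes collect more
        # reward, so terminal states can be ignored in the value backup for those
        # streams to avoid rewarding mere survival.)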
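        # The policy network below holds the Q1, Q2, and value heads used in the SAC
        # update; the target network provides the bootstrapped value estimates and is
        # soft-updated toward the value network using tau.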
        self.policy_network = SACPolicyNetwork(
            policy=self.policy,
            m_size=self.policy.m_size,  # 3x policy.m_size
            h_size=h_size,
            normalize=self.policy.normalize,
            use_recurrent=self.policy.use_recurrent,
        )
        self.target_network = SACTargetNetwork(
            policy=self.policy,
            m_size=self.policy.m_size,  # 1x policy.m_size
            h_size=h_size,
            normalize=self.policy.normalize,
            use_recurrent=self.policy.use_recurrent,
        )
        # The optimizer's m_size is 3 times the policy (Q1, Q2, and Value)
        self.m_size = 3 * self.policy.m_size
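        # (With a recurrent policy, the optimizer carries three concatenated copies of
        # the policy's memory, one each for Q1, Q2, and Value; the target network only
        # carries the single value memory.)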
        self.create_inputs_and_outputs()
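        # Build the learning-rate schedule: "linear" decays from lr toward zero over
        # max_step steps, while "constant" keeps it fixed.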
        self.learning_rate = LearningModel.create_learning_rate(
            lr_schedule, lr, self.policy.global_step, int(max_step)
        )
        # Add optimizer-owned outputs to the policy's inference dict.
        self.policy.inference_dict["learning_rate"] = self.learning_rate
        if self.policy.use_recurrent:
            self.policy.inference_dict["optimizer_memory_out"] = self.memory_out

    def create_inputs_and_outputs(self) -> None:
        """
        Assign the optimizer's inputs and outputs to those of its policy and
        target networks.
        """
        if self.policy.use_recurrent:
            self.memory_in = self.policy_network.memory_in
            self.memory_out = self.policy_network.memory_out
            # For Barracuda
            self.inference_memory_out = tf.identity(
                self.policy_network.policy_memory_out, name="recurrent_out"
            )
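            # (The named identity exposes a stable "recurrent_out" tensor in the
            # exported graph for Unity's Barracuda inference engine to look up.)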
            if not self.policy.use_continuous_act:
                self.prev_action = self.policy_network.prev_action
            self.next_memory_in = self.target_network.memory_in

            stats_needed.update(self.reward_signals[name].stats_name_to_update_name)
    def construct_feed_dict(
        self, policy: TFPolicy, batch: AgentBuffer, num_sequences: int
    ) -> Dict[tf.Tensor, Any]:
""" |
|
|
|
Builds the feed dict for updating the SAC model. |
|
|
|
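        # Next-step observations feed the target network, which produces the
        # bootstrapped value estimates used in the SAC backup.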
        for i, _ in enumerate(self.next_visual_in):
            _obs = batch["next_visual_obs%d" % i]
            feed_dict[self.next_visual_in[i]] = _obs
        if self.policy.use_recurrent:
            mem_in = [
                batch["memory"][i]
                for i in range(0, len(batch["memory"]), self.policy.sequence_length)
            ]
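            # (Only the memory saved at the first step of each sequence is used to
            # initialize the policy's recurrent state.)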
            feed_dict[policy.memory_in] = mem_in
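            # (The optimizer-side networks do not persist their LSTM states in the
            # replay buffer, so their memories are zero-initialized for each update.)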
            feed_dict[self.policy_network.memory_in] = self._make_zero_mem(
                self.m_size, batch.num_experiences
            )
            feed_dict[self.target_network.memory_in] = self._make_zero_mem(
                self.policy.m_size, batch.num_experiences
            )
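        # Done flags let the model zero out the bootstrapped next-state value at
        # episode boundaries.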
        feed_dict[self.dones_holder] = batch["done"]
        return feed_dict