from typing import Any, Dict, List, Tuple

import numpy as np

# Optimizer, TFPolicy, AgentBuffer, and the tf module come from the surrounding
# ml-agents package; their exact import paths depend on the release in use.


class TFOptimizer(Optimizer):  # pylint: disable=W0223
    def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]):
        super().__init__(policy)
        self.sess = policy.sess
        self.policy = policy
        self.update_dict: Dict[str, tf.Tensor] = {}

    def get_trajectory_value_estimates(
        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
        """
        Gets the value estimates for an entire trajectory of samples, except for the final one.
        :param batch: An AgentBuffer that represents the trajectory. Assumed to be created from
            Trajectory.to_agentbuffer().
        :param next_obs: The next observation after the trajectory is done, for use with bootstrapping.
        :param done: Whether or not this trajectory is terminal, in which case the final value estimate will be 0.
        :returns: Two dicts that represent the trajectory value estimates for each reward signal (str to np.ndarray)
            and the final value estimate of next_obs for each reward signal (str to float).
        """
        feed_dict: Dict[tf.Tensor, Any] = {
            self.policy.batch_size_ph: batch.num_experiences,
            self.policy.sequence_length_ph: batch.num_experiences,  # We want to feed the data batch-wise, not time-wise.
            # (additional feed entries for the trajectory observations are omitted from this excerpt)
        }

        # We do this in a separate step to feed the memory outs - a further optimization would
        # be to append to the obs before running sess.run.
        final_value_estimates = self._get_value_estimates(
            next_obs, done  # the full implementation also passes the memory outputs and previous action here
        )
        # (the full implementation returns the per-step estimates together with final_value_estimates)

    def _get_value_estimates(
        self,
        next_obs: List[np.ndarray],
        done: bool,
        policy_memory: np.ndarray = None,
        value_memory: np.ndarray = None,
        prev_action: np.ndarray = None,
    ) -> Dict[str, float]:
""" |
|
|
|
Generates value estimates for bootstrapping. |
|
|
|
Generates value estimates for bootstrapping. Called by get_trajectory_value_extimates |
|
|
|
:param policy_memory: Memory output of the policy at the prior timestep. |
|
|
|
:param value_memory: Memory output of the value network at the prior timestep. |
|
|
|
:prev_action: The last action before this observation. |
|
|
|
:return: The value estimate dictionary with key being the name of the reward signal and the value the |
|
|
|
corresponding value estimate. |
|
|
|
""" |
|
|
|
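

# A minimal usage sketch, not part of this module: trainers use a concrete subclass of
# TFOptimizer (this class leaves some Optimizer methods abstract, hence the pylint disable
# above) and call get_trajectory_value_estimates once a trajectory is complete, to obtain
# per-step value estimates plus a bootstrap value for the observation after the last step.
# The names optimizer, trajectory, last_next_obs, and trajectory_ended below are
# hypothetical stand-ins for trainer-side objects.
#
#     batch = trajectory.to_agentbuffer()              # as assumed by the docstring above
#     value_estimates, bootstrap_values = optimizer.get_trajectory_value_estimates(
#         batch, next_obs=last_next_obs, done=trajectory_ended
#     )
#     # value_estimates:  {reward signal name: np.ndarray of per-step value estimates}
#     # bootstrap_values: {reward signal name: float value estimate of next_obs; 0.0 when done}
#     # The bootstrap values play the role of V(s_T) when the trainer computes returns/advantages.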