|
|
|
|
|
|
) |
|
|
|
|
|
|
|
def create_tf_policy( |
|
|
|
self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec |
|
|
|
self, |
|
|
|
parsed_behavior_id: BehaviorIdentifiers, |
|
|
|
behavior_spec: BehaviorSpec, |
|
|
|
create_graph: bool = False, |
|
|
|
) -> TFPolicy: |
|
|
|
policy = TFPolicy( |
|
|
|
self.seed, |
|
|
|
|
|
|
reparameterize=True, |
|
|
|
create_tf_graph=False, |
|
|
|
create_tf_graph=create_graph, |
|
|
|
) |
|
|
|
self.maybe_load_replay_buffer() |
|
|
|
return policy |
|
|
|