|
|
|
|
|
|
class SharedActorCritic(SimpleActor, ActorCritic): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
observation_spec: List[ObservationSpec], |
|
|
|
observation_specs: List[ObservationSpec], |
|
|
|
network_settings: NetworkSettings, |
|
|
|
action_spec: ActionSpec, |
|
|
|
stream_names: List[str], |
|
|
|
|
|
|
self.use_lstm = network_settings.memory is not None |
|
|
|
super().__init__( |
|
|
|
observation_spec, |
|
|
|
observation_specs, |
|
|
|
network_settings, |
|
|
|
action_spec, |
|
|
|
conditional_sigma, |
|
|
|
|
|
|
class SeparateActorCritic(SimpleActor, ActorCritic): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
observation_spec: List[ObservationSpec], |
|
|
|
observation_specs: List[ObservationSpec], |
|
|
|
network_settings: NetworkSettings, |
|
|
|
action_spec: ActionSpec, |
|
|
|
stream_names: List[str], |
|
|
|
|
|
|
self.use_lstm = network_settings.memory is not None |
|
|
|
super().__init__( |
|
|
|
observation_spec, |
|
|
|
observation_specs, |
|
|
|
network_settings, |
|
|
|
action_spec, |
|
|
|
conditional_sigma, |
|
|
|
|
|
|
self.critic = ValueNetwork(stream_names, observation_spec, network_settings) |
|
|
|
self.critic = ValueNetwork(stream_names, observation_specs, network_settings) |
|
|
|
|
|
|
|
@property |
|
|
|
def memory_size(self) -> int: |
|
|
|