|
|
|
|
|
|
) # could be much simpler if TorchPolicy is nn.Module |
|
|
|
self.grads = None |
|
|
|
|
|
|
|
# reward_signal_configs = trainer_settings.reward_signals |
|
|
|
# reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] |
|
|
|
|
|
|
|
ac_class = SimpleActor |
|
|
|
# if separate_critic: |
|
|
|
# ac_class = SimpleActor |
|
|
|
# else: |
|
|
|
# ac_class = SharedActorCritic |
|
|
|
self.actor = ac_class( |
|
|
|
self.actor = SimpleActor( |
|
|
|
observation_specs=self.behavior_spec.observation_specs, |
|
|
|
network_settings=trainer_settings.network_settings, |
|
|
|
action_spec=behavior_spec.action_spec, |
|
|
|