
update policy to not use critic

Branch: /develop/action-slice
Author: Andrew Cohen, 4 years ago
Commit: f73b9dba
2 files changed, 20 insertions(+), 26 deletions(-)
Changed files:
  1. ml-agents/mlagents/trainers/action_info.py (3 changes)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (43 changes)

ml-agents/mlagents/trainers/action_info.py (3 changes)


 class ActionInfo(NamedTuple):
     action: Any
     env_action: Any
-    value: Any
-        return ActionInfo([], [], [], {}, [])
+        return ActionInfo([], [], {}, [])
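For reference, a minimal sketch of the tuple after this commit: the value field and its placeholder in empty() are gone, leaving only the action payloads and bookkeeping fields. The outputs and agent_ids names come from the ActionInfo construction in torch_policy.py below; their annotations are simplified to Any for this sketch.

from typing import Any, NamedTuple

class ActionInfo(NamedTuple):
    action: Any
    env_action: Any
    outputs: Any     # the policy's full run_out dict (see torch_policy.py below)
    agent_ids: Any   # ids of the agents these actions were generated for

    @staticmethod
    def empty() -> "ActionInfo":
        # One placeholder fewer than before, since no value estimate is carried.
        return ActionInfo([], [], {}, [])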

ml-agents/mlagents/trainers/policy/torch_policy.py (43 changes)


 from mlagents_envs.timers import timed
 from mlagents.trainers.settings import TrainerSettings
-from mlagents.trainers.torch.networks import (
-    SharedActorCritic,
-    SeparateActorCritic,
-    GlobalSteps,
-)
+from mlagents.trainers.torch.networks import SimpleActor, GlobalSteps
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.buffer import AgentBuffer

         )  # could be much simpler if TorchPolicy is nn.Module
         self.grads = None
-        reward_signal_configs = trainer_settings.reward_signals
-        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
+        # reward_signal_configs = trainer_settings.reward_signals
+        # reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
-        if separate_critic:
-            ac_class = SeparateActorCritic
-        else:
-            ac_class = SharedActorCritic
-        self.actor_critic = ac_class(
+        ac_class = SimpleActor
+        # if separate_critic:
+        #     ac_class = SimpleActor
+        # else:
+        #     ac_class = SharedActorCritic
+        self.actor = ac_class(
-            stream_names=reward_signal_names,
             conditional_sigma=self.condition_sigma_on_obs,
             tanh_squash=tanh_squash,
         )
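The commented-out lines above are what previously told the actor-critic which value-head streams to create: they turn the enum-keyed reward-signal settings into plain string names. With the critic gone from the policy, the name list is presumably no longer needed here. A self-contained illustration of that comprehension pattern (the enum and settings dict below are stand-ins, not the real ML-Agents classes):

from enum import Enum

class RewardSignalType(Enum):
    # Stand-in for the reward-signal enum; .value is the string used to name a stream.
    EXTRINSIC = "extrinsic"
    CURIOSITY = "curiosity"

reward_signal_configs = {
    RewardSignalType.EXTRINSIC: {"gamma": 0.99},
    RewardSignalType.CURIOSITY: {"strength": 0.02},
}
reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
print(reward_signal_names)  # ['extrinsic', 'curiosity']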

-        self.m_size = self.actor_critic.memory_size
+        self.m_size = self.actor.memory_size
-        self.actor_critic.to(default_device())
+        self.actor.to(default_device())
         self._clip_action = not tanh_squash
     @property

         """
         if self.normalize:
-            self.actor_critic.update_normalization(buffer)
+            self.actor.update_normalization(buffer)
     @timed
     def sample_actions(

         :param seq_len: Sequence length when using RNN.
         :return: Tuple of AgentAction, ActionLogProbs, entropies, and output memories.
         """
-        actions, log_probs, entropies, memories = self.actor_critic.get_action_stats(
+        actions, log_probs, entropies, memories = self.actor.get_action_and_stats(
             obs, masks, memories, seq_len
         )
         return (actions, log_probs, entropies, memories)
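To make the renamed call concrete, here is a self-contained stub that mirrors the shape shown in the hunk: get_action_and_stats on the bare actor returns the same four values (actions, log-probs, entropies, memories) that get_action_stats on the actor-critic used to, just with no critic behind it. StubActor and the tensor shapes are illustrative stand-ins, not the real SimpleActor.

import torch

class StubActor:
    """Stand-in for SimpleActor, exposing only the method used above."""

    def get_action_and_stats(self, obs, masks=None, memories=None, seq_len=1):
        batch_size = obs.shape[0]
        actions = torch.zeros(batch_size, 2)   # dummy continuous actions
        log_probs = torch.zeros(batch_size)    # dummy log-probabilities
        entropies = torch.zeros(batch_size)    # dummy entropies
        return actions, log_probs, entropies, memories

def sample_actions(actor, obs, masks=None, memories=None, seq_len=1):
    # Mirrors the body in the hunk above: no value estimate is produced here.
    actions, log_probs, entropies, memories = actor.get_action_and_stats(
        obs, masks, memories, seq_len
    )
    return actions, log_probs, entropies, memories

acts, log_probs, ents, mems = sample_actions(StubActor(), torch.ones(4, 8))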

         masks: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
         seq_len: int = 1,
-    ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
-        log_probs, entropies, value_heads = self.actor_critic.get_stats_and_value(
+    ) -> Tuple[ActionLogProbs, torch.Tensor]:
+        log_probs, entropies = self.actor.get_stats(
-        return log_probs, entropies, value_heads
+        return log_probs, entropies
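The stats path changes the same way: the Dict[str, torch.Tensor] of value heads disappears from the return type, so callers unpack two values instead of three and must obtain value estimates from a critic they own themselves. A hypothetical caller-side sketch (in ML-Agents this method is TorchPolicy.evaluate_actions; the helper name below is made up):

def compute_policy_stats(policy, obs, actions, masks=None, memories=None, seq_len=1):
    # Before this commit the call also returned value heads:
    #   log_probs, entropies, value_heads = policy.evaluate_actions(...)
    log_probs, entropies = policy.evaluate_actions(
        obs, actions, masks=masks, memories=memories, seq_len=seq_len
    )
    return log_probs, entropies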
     @timed
     def evaluate(

         return ActionInfo(
             action=run_out.get("action"),
             env_action=run_out.get("env_action"),
-            value=run_out.get("value"),
             outputs=run_out,
             agent_ids=list(decision_requests.agent_id),
         )
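A small, self-contained sketch of the trimmed result: run_out still carries whatever the forward pass produced, but the ActionInfo handed back no longer has a value slot, so anything downstream that wants a value estimate has to get it from a critic elsewhere. The namedtuple and the run_out contents below are stand-ins for illustration.

from collections import namedtuple

# Sketch of the trimmed tuple (see action_info.py above).
ActionInfo = namedtuple("ActionInfo", ["action", "env_action", "outputs", "agent_ids"])

run_out = {"action": [0.1, -0.3], "env_action": [0.1, -0.3], "log_probs": [-1.2]}
action_info = ActionInfo(
    action=run_out.get("action"),
    env_action=run_out.get("env_action"),
    outputs=run_out,
    agent_ids=[0],
)
assert not hasattr(action_info, "value")  # the value field is gone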

         return self.get_current_step()
     def load_weights(self, values: List[np.ndarray]) -> None:
-        self.actor_critic.load_state_dict(values)
+        self.actor.load_state_dict(values)
-        return copy.deepcopy(self.actor_critic.state_dict())
+        return copy.deepcopy(self.actor.state_dict())
-        return {"Policy": self.actor_critic, "global_step": self.global_step}
+        return {"Policy": self.actor, "global_step": self.global_step}