浏览代码

use LinearEncoder in curiosity and clean up (#4444)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
6986fb10
共有 1 个文件被更改,包括 6 次插入、8 次删除
  1. 14
      ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py

14
ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py


from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.layers import linear_layer, Swish
from mlagents.trainers.torch.layers import LinearEncoder, linear_layer
from mlagents.trainers.settings import NetworkSettings, EncoderType

self._action_flattener = ModelUtils.ActionFlattener(specs)
self.inverse_model_action_predition = torch.nn.Sequential(
linear_layer(2 * settings.encoding_size, 256),
Swish(),
self.inverse_model_action_prediction = torch.nn.Sequential(
LinearEncoder(2 * settings.encoding_size, 1, 256),
linear_layer(
settings.encoding_size + self._action_flattener.flattened_size, 256
LinearEncoder(
settings.encoding_size + self._action_flattener.flattened_size, 1, 256
Swish(),
linear_layer(256, settings.encoding_size),
)

inverse_model_input = torch.cat(
(self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
)
hidden = self.inverse_model_action_predition(inverse_model_input)
hidden = self.inverse_model_action_prediction(inverse_model_input)
if self._policy_specs.is_action_continuous():
return hidden
else:

正在加载...
取消
保存