浏览代码
[coma2] Make group extrinsic reward part of extrinsic (#5033)
[coma2] Make group extrinsic reward part of extrinsic (#5033)
* Make group extrinsic part of extrinsic
* Fix test and init
* Fix tests and bug
* Add baseline loss to TensorBoard
/develop/action-slice
GitHub
4 年前
当前提交
ba2af269
共有 10 个文件被更改,包括 122 次插入 和 50 次删除
-
4config/ppo/PushBlockCollab.yaml
-
31ml-agents/mlagents/trainers/coma/optimizer_torch.py
-
10ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
-
12ml-agents/mlagents/trainers/settings.py
-
24ml-agents/mlagents/trainers/tests/torch/test_coma.py
-
27ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
-
3ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py
-
33ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py
-
4ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py
-
24ml-agents/mlagents/trainers/torch/components/reward_providers/group_extrinsic_reward_provider.py
|
|||
import numpy as np |
|||
from typing import Dict |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer, BufferKey |
|||
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( |
|||
BaseRewardProvider, |
|||
) |
|||
|
|||
|
|||
class GroupExtrinsicRewardProvider(BaseRewardProvider):
    """Reward provider that sums individual, groupmate and group rewards.

    The reward for each entry of a mini batch is the entry's own environment
    reward, plus the sum of the rewards recorded for its groupmates, plus the
    group reward.
    """

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        """Compute the combined reward for every entry of ``mini_batch``.

        :param mini_batch: Buffer providing ``BufferKey.ENVIRONMENT_REWARDS``,
            ``BufferKey.GROUPMATE_REWARDS`` and ``BufferKey.GROUP_REWARD``.
        :return: Element-wise sum of the three reward sources as a float32
            array.
        """
        indiv_rewards = np.array(
            mini_batch[BufferKey.ENVIRONMENT_REWARDS], dtype=np.float32
        )
        # NOTE(review): assumes each GROUPMATE_REWARDS entry is an iterable of
        # the groupmates' rewards for that step -- confirm against the buffer.
        groupmate_rewards_list = mini_batch[BufferKey.GROUPMATE_REWARDS]
        # Use float32 like the other reward arrays. Passing ``np.ndarray`` as
        # the dtype makes NumPy fall back to an object array, which would turn
        # the returned sum into an object array as well.
        groupmate_rewards_sum = np.array(
            [sum(_rew) for _rew in groupmate_rewards_list], dtype=np.float32
        )
        group_rewards = np.array(mini_batch[BufferKey.GROUP_REWARD], dtype=np.float32)
        # Add all the group rewards to the individual rewards
        return indiv_rewards + groupmate_rewards_sum + group_rewards

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        """Return an empty stats dict; ``mini_batch`` is not used."""
        return {}
撰写
预览
正在加载...
取消
保存
Reference in new issue