Ervin Teng
4 年前
当前提交
c6904f86
共有 5 个文件被更改,包括 52 次插入 和 4 次删除
-
24 ml-agents/mlagents/trainers/coma/trainer.py
-
1 ml-agents/mlagents/trainers/settings.py
-
3 ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py
-
4 ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py
-
24 ml-agents/mlagents/trainers/torch/components/reward_providers/group_extrinsic_reward_provider.py
|
|||
import numpy as np |
|||
from typing import Dict |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer, BufferKey |
|||
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( |
|||
BaseRewardProvider, |
|||
) |
|||
|
|||
|
|||
class GroupExtrinsicRewardProvider(BaseRewardProvider):
    """Extrinsic reward provider for cooperative multi-agent training.

    Combines, per time step, an agent's own environment reward, the sum of
    its groupmates' rewards, and the shared group reward into a single
    extrinsic reward signal.
    """

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        """Return the combined extrinsic reward for each entry in the batch.

        :param mini_batch: AgentBuffer containing ENVIRONMENT_REWARDS,
            GROUPMATE_REWARDS (a per-step list of groupmate rewards), and
            GROUP_REWARD fields.
        :return: float32 array of per-step rewards,
            indiv + sum(groupmates) + group.
        """
        indiv_rewards = np.array(
            mini_batch[BufferKey.ENVIRONMENT_REWARDS], dtype=np.float32
        )
        groupmate_rewards_list = mini_batch[BufferKey.GROUPMATE_REWARDS]
        # BUGFIX: dtype was np.ndarray, which is not an element dtype and
        # made this an object-dtype array; use np.float32 so the addition
        # below stays a numeric float32 ndarray. sum(_rew) handles an empty
        # groupmate list (no groupmates at that step) by yielding 0.
        groupmate_rewards_sum = np.array(
            [sum(_rew) for _rew in groupmate_rewards_list], dtype=np.float32
        )
        group_rewards = np.array(mini_batch[BufferKey.GROUP_REWARD], dtype=np.float32)
        # Add all the group rewards to the individual rewards
        return indiv_rewards + groupmate_rewards_sum + group_rewards

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        """No-op: this provider has no learned parameters to update.

        :param mini_batch: unused.
        :return: empty stats dictionary.
        """
        return {}
撰写
预览
正在加载...
取消
保存
Reference in new issue