浏览代码
MultiAgentGroup Interface (#4923)
MultiAgentGroup Interface (#4923)
* add SimpleMultiAgentGroup * add group reward field to agent and proto/develop/gail-srl-hack
GitHub
4 年前
当前提交
ddb01eb2
共有 20 个文件被更改,包括 618 次插入 和 28 次删除
-
65com.unity.ml-agents/Runtime/Agent.cs
-
2com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
-
67com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/AgentInfo.cs
-
7gym-unity/gym_unity/tests/test_gym.py
-
22ml-agents-envs/mlagents_envs/base_env.py
-
18ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.py
-
8ml-agents-envs/mlagents_envs/communicator_objects/agent_info_pb2.pyi
-
30ml-agents-envs/mlagents_envs/rpc_utils.py
-
4ml-agents-envs/mlagents_envs/tests/test_steps.py
-
10ml-agents/mlagents/trainers/tests/mock_brain.py
-
65ml-agents/mlagents/trainers/tests/simple_test_envs.py
-
2protobuf-definitions/proto/mlagents_envs/communicator_objects/agent_info.proto
-
26com.unity.ml-agents/Runtime/IMultiAgentGroup.cs
-
11com.unity.ml-agents/Runtime/IMultiAgentGroup.cs.meta
-
13com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs
-
11com.unity.ml-agents/Runtime/MultiAgentGroupIdCounter.cs.meta
-
143com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs
-
11com.unity.ml-agents/Runtime/SimpleMultiAgentGroup.cs.meta
-
120com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs
-
11com.unity.ml-agents/Tests/Editor/MultiAgentGroupTests.cs.meta
|
|||
namespace Unity.MLAgents
{
    /// <summary>
    /// Contract for an object that groups agents together in order to support
    /// multi-agent training.
    /// </summary>
    public interface IMultiAgentGroup
    {
        /// <summary>
        /// Gets the identifier of this MultiAgentGroup.
        /// </summary>
        /// <returns>
        /// The MultiAgentGroup ID.
        /// </returns>
        int GetId();

        /// <summary>
        /// Registers an agent with this MultiAgentGroup.
        /// </summary>
        /// <param name="agent">The agent to register.</param>
        void RegisterAgent(Agent agent);

        /// <summary>
        /// Unregisters an agent from this MultiAgentGroup.
        /// </summary>
        /// <param name="agent">The agent to unregister.</param>
        void UnregisterAgent(Agent agent);
    }
}
|
|||
fileFormatVersion: 2 |
|||
guid: 3744ac27d956e43e1a39c7ba2550ab82 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Threading; |
|||
|
|||
namespace Unity.MLAgents
{
    /// <summary>
    /// Hands out integer IDs for multi-agent groups from a single shared static counter.
    /// </summary>
    internal static class MultiAgentGroupIdCounter
    {
        // Last ID handed out. Starts at the default value 0, so the first ID returned is 1.
        static int s_Counter;

        /// <summary>
        /// Produces the next group ID.
        /// </summary>
        /// <returns>
        /// The value of the shared counter after it has been atomically incremented by one.
        /// </returns>
        public static int GetGroupId()
        {
            // Interlocked.Increment performs the read-modify-write as one atomic operation
            // and returns the incremented value.
            return Interlocked.Increment(ref s_Counter);
        }
    }
}
|
|||
fileFormatVersion: 2 |
|||
guid: 5661ffdb6c7704e84bc785572dcd5bd1 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using System.Linq; |
|||
using System.Collections.Generic; |
|||
|
|||
namespace Unity.MLAgents
{
    /// <summary>
    /// A basic class implementation of MultiAgentGroup.
    /// </summary>
    /// <remarks>
    /// The group keeps its registered agents in a set and subscribes
    /// <see cref="UnregisterAgent(Agent)"/> to each agent's OnAgentDisabled event, so an
    /// agent is removed from the group when that event fires.
    /// </remarks>
    internal class SimpleMultiAgentGroup : IMultiAgentGroup, IDisposable
    {
        // ID obtained from the shared counter when the group is constructed.
        readonly int m_Id = MultiAgentGroupIdCounter.GetGroupId();

        // Agents currently registered to this group. The set instance is never replaced.
        readonly HashSet<Agent> m_Agents = new HashSet<Agent>();


        /// <summary>
        /// Unregisters every agent that is currently registered to this MultiAgentGroup.
        /// </summary>
        public virtual void Dispose()
        {
            // UnregisterAgent removes the agent from m_Agents, so the set cannot be
            // enumerated here; repeatedly take whichever element is first instead.
            while (m_Agents.Count > 0)
            {
                UnregisterAgent(m_Agents.First());
            }
        }

        /// <inheritdoc />
        public virtual void RegisterAgent(Agent agent)
        {
            // Registering an agent that is already a member is a no-op.
            if (!m_Agents.Contains(agent))
            {
                // Assign the group on the agent before adding it to the set, so that if
                // this call throws the agent is not left in m_Agents.
                agent.SetMultiAgentGroup(this);
                m_Agents.Add(agent);
                agent.OnAgentDisabled += UnregisterAgent;
            }
        }

        /// <inheritdoc />
        public virtual void UnregisterAgent(Agent agent)
        {
            // Unregistering an agent that is not a member is a no-op.
            if (m_Agents.Contains(agent))
            {
                agent.SetMultiAgentGroup(null);
                m_Agents.Remove(agent);
                agent.OnAgentDisabled -= UnregisterAgent;
            }
        }

        /// <inheritdoc />
        public int GetId()
        {
            return m_Id;
        }

        /// <summary>
        /// Get list of all agents currently registered to this MultiAgentGroup.
        /// </summary>
        /// <returns>
        /// List of agents registered to the MultiAgentGroup.
        /// </returns>
        public IReadOnlyCollection<Agent> GetRegisteredAgents()
        {
            // Returns a read-only view of the live set, not a copy.
            return (IReadOnlyCollection<Agent>)m_Agents;
        }

        /// <summary>
        /// Increments the group rewards for all agents in this MultiAgentGroup.
        /// </summary>
        /// <remarks>
        /// This function increases or decreases the group rewards by a given amount for all agents
        /// in the group. Use <see cref="SetGroupReward(float)"/> to set the group reward assigned
        /// to the current step with a specific value rather than increasing or decreasing it.
        ///
        /// A positive group reward indicates the whole group's accomplishments or desired behaviors.
        /// Every agent in the group will receive the same group reward no matter whether the
        /// agent's act directly leads to the reward. Group rewards are meant to reinforce agents
        /// to act in the group's best interest instead of individual ones.
        /// Group rewards are treated differently than individual agent rewards during training, so
        /// calling AddGroupReward() is not equivalent to calling agent.AddReward() on each agent in the group.
        /// </remarks>
        /// <param name="reward">Incremental group reward value.</param>
        public void AddGroupReward(float reward)
        {
            foreach (var agent in m_Agents)
            {
                agent.AddGroupReward(reward);
            }
        }

        /// <summary>
        /// Set the group rewards for all agents in this MultiAgentGroup.
        /// </summary>
        /// <remarks>
        /// This function replaces any group rewards given during the current step for all agents in the group.
        /// Use <see cref="AddGroupReward(float)"/> to incrementally change the group reward rather than
        /// overriding it.
        ///
        /// A positive group reward indicates the whole group's accomplishments or desired behaviors.
        /// Every agent in the group will receive the same group reward no matter whether the
        /// agent's act directly leads to the reward. Group rewards are meant to reinforce agents
        /// to act in the group's best interest instead of individual ones.
        /// Group rewards are treated differently than individual agent rewards during training, so
        /// calling SetGroupReward() is not equivalent to calling agent.SetReward() on each agent in the group.
        /// </remarks>
        /// <param name="reward">The new value of the group reward.</param>
        public void SetGroupReward(float reward)
        {
            foreach (var agent in m_Agents)
            {
                agent.SetGroupReward(reward);
            }
        }

        /// <summary>
        /// End episodes for all agents in this MultiAgentGroup.
        /// </summary>
        /// <remarks>
        /// This should be used when the episode can no longer continue, such as when the group
        /// reaches the goal or fails at the task.
        /// </remarks>
        public void EndGroupEpisode()
        {
            // Iterate over a snapshot: UnregisterAgent is subscribed to every member's
            // OnAgentDisabled event and removes from m_Agents, so if anything triggered by
            // EndEpisode() disables an agent, enumerating the live set would throw
            // InvalidOperationException.
            foreach (var agent in m_Agents.ToArray())
            {
                // Skip agents that were unregistered by an earlier iteration.
                if (m_Agents.Contains(agent))
                {
                    agent.EndEpisode();
                }
            }
        }

        /// <summary>
        /// Indicate that the episode is over but not due to the "fault" of the group.
        /// This has the same end result as calling <see cref="EndGroupEpisode"/>, but has a
        /// slightly different effect on training.
        /// </summary>
        /// <remarks>
        /// This should be used when the episode could continue, but has gone on for
        /// a sufficient number of steps, such as if the environment hits some maximum number of steps.
        /// </remarks>
        public void GroupEpisodeInterrupted()
        {
            // Snapshot for the same reason as in EndGroupEpisode: the set may be modified
            // through OnAgentDisabled while the agents are being notified.
            foreach (var agent in m_Agents.ToArray())
            {
                // Skip agents that were unregistered by an earlier iteration.
                if (m_Agents.Contains(agent))
                {
                    agent.EpisodeInterrupted();
                }
            }
        }
    }
}
|
|||
fileFormatVersion: 2 |
|||
guid: 3454e3c3c70964dca93b63ee4b650095 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using Unity.MLAgents; |
|||
using System; |
|||
using System.Reflection; |
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using Unity; |
|||
|
|||
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Editor tests for <see cref="SimpleMultiAgentGroup"/>: registration, unregistration,
    /// group rewards, disposal and group ID uniqueness.
    /// </summary>
    public class MultiAgentGroupTests
    {
        // Tolerance for comparing accumulated float rewards; exact float equality is fragile.
        const float k_RewardTolerance = 1e-6f;

        /// <summary>
        /// Agent subclass that exposes non-public Agent state to the tests via reflection.
        /// </summary>
        class TestAgent : Agent
        {
            // Value of Agent's non-public instance field "m_GroupId".
            internal int _GroupId
            {
                get
                {
                    return (int)typeof(Agent).GetField("m_GroupId", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
                }
            }

            // Value of Agent's non-public instance field "m_GroupReward".
            internal float _GroupReward
            {
                get
                {
                    return (float)typeof(Agent).GetField("m_GroupReward", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
                }
            }

            // Delegate stored in Agent's non-public instance field "OnAgentDisabled";
            // null when nothing is subscribed.
            internal Action<Agent> _OnAgentDisabledActions
            {
                get
                {
                    return (Action<Agent>)typeof(Agent).GetField("OnAgentDisabled", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
                }
            }
        }

        [Test]
        public void TestRegisteredAgentGroupId()
        {
            var agentGo = new GameObject("TestAgent");
            agentGo.AddComponent<TestAgent>();
            var agent = agentGo.GetComponent<TestAgent>();

            // test register
            SimpleMultiAgentGroup agentGroup1 = new SimpleMultiAgentGroup();
            agentGroup1.RegisterAgent(agent);
            Assert.AreEqual(agentGroup1.GetId(), agent._GroupId);
            Assert.IsNotNull(agent._OnAgentDisabledActions);

            // should not be able to register to multiple groups
            SimpleMultiAgentGroup agentGroup2 = new SimpleMultiAgentGroup();
            Assert.Throws<UnityAgentsException>(
                () => agentGroup2.RegisterAgent(agent));
            Assert.AreEqual(agentGroup1.GetId(), agent._GroupId);

            // test unregister
            agentGroup1.UnregisterAgent(agent);
            Assert.AreEqual(0, agent._GroupId);
            Assert.IsNull(agent._OnAgentDisabledActions);

            // test register to another group after unregister
            agentGroup2.RegisterAgent(agent);
            Assert.AreEqual(agentGroup2.GetId(), agent._GroupId);
            Assert.IsNotNull(agent._OnAgentDisabledActions);
        }

        [Test]
        public void TestRegisterMultipleAgent()
        {
            var agentGo1 = new GameObject("TestAgent");
            agentGo1.AddComponent<TestAgent>();
            var agent1 = agentGo1.GetComponent<TestAgent>();
            var agentGo2 = new GameObject("TestAgent");
            agentGo2.AddComponent<TestAgent>();
            var agent2 = agentGo2.GetComponent<TestAgent>();

            // NUnit's Assert.AreEqual takes (expected, actual); keep that order so failure
            // messages report the right values.
            SimpleMultiAgentGroup agentGroup = new SimpleMultiAgentGroup();
            agentGroup.RegisterAgent(agent1); // register
            Assert.AreEqual(1, agentGroup.GetRegisteredAgents().Count);
            agentGroup.UnregisterAgent(agent2); // unregister non-member agent
            Assert.AreEqual(1, agentGroup.GetRegisteredAgents().Count);
            agentGroup.UnregisterAgent(agent1); // unregister
            Assert.AreEqual(0, agentGroup.GetRegisteredAgents().Count);
            agentGroup.RegisterAgent(agent1);
            agentGroup.RegisterAgent(agent1); // duplicated register
            Assert.AreEqual(1, agentGroup.GetRegisteredAgents().Count);
            agentGroup.RegisterAgent(agent2); // register another
            Assert.AreEqual(2, agentGroup.GetRegisteredAgents().Count);

            // test add/set group rewards
            agentGroup.AddGroupReward(0.1f);
            Assert.AreEqual(0.1f, agent1._GroupReward, k_RewardTolerance);
            agentGroup.AddGroupReward(0.5f);
            Assert.AreEqual(0.6f, agent1._GroupReward, k_RewardTolerance);
            agentGroup.SetGroupReward(0.3f);
            Assert.AreEqual(0.3f, agent1._GroupReward, k_RewardTolerance);
            // unregistered agent should not receive group reward
            agentGroup.UnregisterAgent(agent1);
            agentGroup.AddGroupReward(0.2f);
            Assert.AreEqual(0.3f, agent1._GroupReward, k_RewardTolerance);
            Assert.AreEqual(0.5f, agent2._GroupReward, k_RewardTolerance);

            // dispose group should automatically unregister all
            agentGroup.Dispose();
            Assert.AreEqual(0, agent1._GroupId);
            Assert.AreEqual(0, agent2._GroupId);
        }

        [Test]
        public void TestGroupIdCounter()
        {
            SimpleMultiAgentGroup group1 = new SimpleMultiAgentGroup();
            SimpleMultiAgentGroup group2 = new SimpleMultiAgentGroup();
            // id should be unique
            Assert.AreNotEqual(group1.GetId(), group2.GetId());
        }
    }
}
|
|||
fileFormatVersion: 2 |
|||
guid: ef0158fde748d478ca5ee3bbe22a4c9e |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
撰写
预览
正在加载...
取消
保存
Reference in new issue