浏览代码

[rewardProviders] First stab a reward provider implementation.

/main/reward-providers
Christopher Goy 5 年前
当前提交
3a355570
共有 16 个文件被更改,包括 218 次插入29 次删除
  1. 78
      UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
  2. 6
      UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs
  3. 9
      UnitySDK/Assets/ML-Agents/Scripts/Policy/BehaviorParameters.cs
  4. 7
      UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs
  5. 3
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider.meta
  6. 15
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/BaseRewardProviderComponent.cs
  7. 3
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/BaseRewardProviderComponent.cs.meta
  8. 15
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/IRewardProvider.cs
  9. 3
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/IRewardProvider.cs.meta
  10. 69
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/LegacyRewardProvider.cs
  11. 3
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/LegacyRewardProvider.cs.meta
  12. 7
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/LegacyRewardProviderComponent.cs
  13. 3
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/LegacyRewardProviderComponent.cs.meta
  14. 23
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/TypedRewardProviderComponent.cs
  15. 3
      UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/TypedRewardProviderComponent.cs.meta

78
UnitySDK/Assets/ML-Agents/Scripts/Agent.cs


using System.Collections.Generic;
using UnityEngine;
using Barracuda;
using MLAgents.RewardProvider;
using MLAgents.Sensor;
using UnityEngine.Serialization;

/// Current Agent action (message sent from Brain).
AgentAction m_Action;
/// Represents the reward the agent accumulated during the current step.
/// It is reset to 0 at the beginning of every step.
/// Should be set to a positive value when the agent performs a "good"
/// action that we wish to reinforce/reward, and set to a negative value
/// when the agent performs a "bad" action that we wish to punish/deter.
/// Additionally, the magnitude of the reward should not exceed 1.0
float m_Reward;
/// Keeps track of the cumulative reward in this episode.
float m_CumulativeReward;

WriteAdapter m_WriteAdapter = new WriteAdapter();
/// Represents the reward the agent accumulated during the current step.
/// It is reset to 0 at the beginning of every step.
/// Should be set to a positive value when the agent performs a "good"
/// action that we wish to reinforce/reward, and set to a negative value
/// when the agent performs a "bad" action that we wish to punish/deter.
/// Additionally, the magnitude of the reward should not exceed 1.0
IRewardProvider m_RewardProvider;
/// <summary>
/// Here for ease of upgrading from the old reward system.
/// </summary>
LegacyRewardProvider m_LegacyRewardProvider;
/// MonoBehaviour function that is called when the attached GameObject
/// becomes enabled or active.
void OnEnable()

academy.DecideAction += DecideAction;
academy.AgentAct += AgentStep;
academy.AgentForceReset += _AgentReset;
InitializeRewardProvider();
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic);
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic, m_RewardProvider);
ResetData();
InitializeAgent();
InitializeSensors();

{
m_PolicyFactory.GiveModel(behaviorName, model, inferenceDevice);
m_Brain?.Dispose();
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic);
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic, m_RewardProvider);
}
/// <summary>

/// </summary>
public void ResetReward()
{
m_Reward = 0f;
if (m_Done)
{
m_CumulativeReward = 0f;
}
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " +
"legacy method 'ResetReward' was called.");
m_LegacyRewardProvider.ResetReward(m_Done);
}
/// <summary>

/// <param name="reward">The new value of the reward.</param>
public void SetReward(float reward)
{
m_CumulativeReward += (reward - m_Reward);
m_Reward = reward;
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " +
"legacy method 'SetReward' was called.");
m_LegacyRewardProvider.SetReward(reward);
}
/// <summary>

public void AddReward(float increment)
{
m_Reward += increment;
m_CumulativeReward += increment;
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " +
"legacy method 'AddReward' was called.");
m_LegacyRewardProvider.AddReward(increment);
}
/// <summary>

public float GetReward()
{
return m_Reward;
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " +
"legacy method 'GetReward' was called.");
return m_LegacyRewardProvider.GetIncrementalReward();
}
/// <summary>

public float GetCumulativeReward()
{
return m_CumulativeReward;
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " +
"legacy method 'GetCumulativeReward' was called.");
return m_LegacyRewardProvider.GetCumulativeReward();
}
/// <summary>

m_VectorSensorBuffer = new float[numFloatObservations];
}
void InitializeRewardProvider()
{
// Look for a legacy reward provider.
var rewardProviderComponent = GetComponent<BaseRewardProviderComponent>();
if (rewardProviderComponent != null)
{
m_RewardProvider = rewardProviderComponent.GetRewardProvider();
}
if (m_RewardProvider == null)
{
var legacyRewardProviderComponent = gameObject.AddComponent<LegacyRewardProviderComponent>();
m_RewardProvider = legacyRewardProviderComponent.GetTypedRewardProvider();
}
m_LegacyRewardProvider = m_RewardProvider as LegacyRewardProvider;
}
/// <summary>
/// Sends the Agent info to the linked Brain.
/// </summary>

// var param = m_PolicyFactory.brainParameters; // look, no brain params!
m_Info.reward = m_Reward;
m_Info.reward = m_RewardProvider.GetIncrementalReward();
// TODO(cgoy): Decouple Agent/Policy.
m_Brain.RequestDecision(this);
if (m_Recorder != null && m_Recorder.record && Application.isEditor)

6
UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs


using Barracuda;
using System.Collections.Generic;
using MLAgents.InferenceBrain;
using MLAgents.RewardProvider;
namespace MLAgents
{

/// Sensor shapes for the associated Agents. All Agents must have the same shapes for their Sensors.
/// </summary>
List<int[]> m_SensorShapes;
IRewardProvider m_RewardProvider;
InferenceDevice inferenceDevice)
InferenceDevice inferenceDevice,
IRewardProvider rewardProvider)
m_RewardProvider = rewardProvider;
}
/// <inheritdoc />

9
UnitySDK/Assets/ML-Agents/Scripts/Policy/BehaviorParameters.cs


using Barracuda;
using System;
using System.Collections.Generic;
using MLAgents.RewardProvider;
using UnityEngine;
namespace MLAgents

}
public IPolicy GeneratePolicy(Func<float[]> heuristic)
public IPolicy GeneratePolicy(Func<float[]> heuristic, IRewardProvider rewardProvider)
{
switch (m_BehaviorType)
{

return new BarracudaPolicy(m_BrainParameters, m_Model, m_InferenceDevice);
return new BarracudaPolicy(m_BrainParameters, m_Model, m_InferenceDevice, rewardProvider);
return new RemotePolicy(m_BrainParameters, behaviorName);
return new RemotePolicy(m_BrainParameters, m_BehaviorName, rewardProvider);
return new BarracudaPolicy(m_BrainParameters, m_Model, m_InferenceDevice);
return new BarracudaPolicy(m_BrainParameters, m_Model, m_InferenceDevice, rewardProvider);
}
else
{

7
UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs


using UnityEngine;
using System.Collections.Generic;
using MLAgents.RewardProvider;
namespace MLAgents
{

/// </summary>
List<int[]> m_SensorShapes;
IRewardProvider m_RewardProvider;
string behaviorName)
string behaviorName,
IRewardProvider rewardProvider)
m_RewardProvider = rewardProvider;
var aca = Object.FindObjectOfType<Academy>();
aca.LazyInitialization();
m_Communicator = aca.Communicator;

3
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider.meta


fileFormatVersion: 2
guid: 332fe3ab963e4b33bc528e8f5b2c82a7
timeCreated: 1575329166

15
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/BaseRewardProviderComponent.cs


using System;
using UnityEngine;
namespace MLAgents.RewardProvider
{
public class BaseRewardProviderComponent: MonoBehaviour
{
IRewardProvider m_RewardProvider;
public virtual IRewardProvider GetRewardProvider()
{
return m_RewardProvider;
}
}
}

3
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/BaseRewardProviderComponent.cs.meta


fileFormatVersion: 2
guid: a4db05acdb9445f6843616362564f7c2
timeCreated: 1576018618

15
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/IRewardProvider.cs


namespace MLAgents.RewardProvider
{
/// <summary>
/// Reward providers allow users to provide rewards for Agent behavior during training in order to
/// give hints on what types of actions are "better" than others based on an Agent's previous observation.
/// </summary>
public interface IRewardProvider
{
/// <summary>
/// Get an incremental reward to pass along to a trainer.
/// </summary>
/// <returns></returns>
float GetIncrementalReward();
}
}

3
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/IRewardProvider.cs.meta


fileFormatVersion: 2
guid: bed12564f4c74e3e964fdb763ce73213
timeCreated: 1575329472

69
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/LegacyRewardProvider.cs


namespace MLAgents.RewardProvider
{
/// <summary>
/// A legacy reward provider that can be used in an Agent as a way to easily upgrade
/// from the old reward system.
/// </summary>
public class LegacyRewardProvider : IRewardProvider
{
float m_IncrementalReward;
float m_CumulativeReward;
public float GetIncrementalReward()
{
return m_IncrementalReward;
}
/// <summary>
/// Resets the step reward and possibly the episode reward for the agent.
/// </summary>
public void ResetReward(bool done = false)
{
m_IncrementalReward = 0f;
if (done)
{
m_CumulativeReward = 0f;
}
}
/// <summary>
/// Overrides the current step reward of the agent and updates the episode
/// reward accordingly.
/// </summary>
/// <param name="reward">The new value of the reward.</param>
public void SetReward(float reward)
{
m_CumulativeReward += (reward - m_IncrementalReward);
m_IncrementalReward = reward;
}
/// <summary>
/// Increments the step and episode rewards by the provided value.
/// </summary>
/// <param name="increment">Incremental reward value.</param>
public void AddReward(float increment)
{
m_IncrementalReward += increment;
m_CumulativeReward += increment;
}
/// <summary>
/// Retrieves the step reward for the Agent.
/// </summary>
/// <returns>The step reward.</returns>
public float GetReward()
{
return m_IncrementalReward;
}
/// <summary>
/// Retrieves the episode reward for the Agent.
/// </summary>
/// <returns>The episode reward.</returns>
public float GetCumulativeReward()
{
return m_CumulativeReward;
}
}
}

3
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/LegacyRewardProvider.cs.meta


fileFormatVersion: 2
guid: 46aa889302734e5ca844235a4f69ff29
timeCreated: 1576015982

7
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/LegacyRewardProviderComponent.cs


namespace MLAgents.RewardProvider
{
public class LegacyRewardProviderComponent : TypedRewardProviderComponent<LegacyRewardProvider>
{
}
}

3
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/LegacyRewardProviderComponent.cs.meta


fileFormatVersion: 2
guid: 9ff175c2b68f41e5b5aa045010677f61
timeCreated: 1576019305

23
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/TypedRewardProviderComponent.cs


namespace MLAgents.RewardProvider
{
/// <summary>
/// A typed reward provider that provides easy, typed access to RewardProvider implementations.
/// Subclasses should
/// </summary>
/// <typeparam name="T"></typeparam>
public class TypedRewardProviderComponent<T> : BaseRewardProviderComponent
where T : IRewardProvider, new()
{
T m_TypedRewardProvider = new T();
public T GetTypedRewardProvider()
{
return m_TypedRewardProvider;
}
public override IRewardProvider GetRewardProvider()
{
return m_TypedRewardProvider;
}
}
}

3
UnitySDK/Assets/ML-Agents/Scripts/RewardProvider/TypedRewardProviderComponent.cs.meta


fileFormatVersion: 2
guid: d5d464dcbd314f68a335a1b9b37c6c6e
timeCreated: 1576019945
正在加载...
取消
保存