|
|
|
|
|
|
using System.Collections.Generic; |
|
|
|
using UnityEngine; |
|
|
|
using Barracuda; |
|
|
|
using MLAgents.RewardProvider; |
|
|
|
using MLAgents.Sensor; |
|
|
|
using UnityEngine.Serialization; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Current Agent action (message sent from Brain).
|
|
|
|
AgentAction m_Action; |
|
|
|
|
|
|
|
/// Represents the reward the agent accumulated during the current step.
|
|
|
|
/// It is reset to 0 at the beginning of every step.
|
|
|
|
/// Should be set to a positive value when the agent performs a "good"
|
|
|
|
/// action that we wish to reinforce/reward, and set to a negative value
|
|
|
|
/// when the agent performs a "bad" action that we wish to punish/deter.
|
|
|
|
/// Additionally, the magnitude of the reward should not exceed 1.0
|
|
|
|
float m_Reward; |
|
|
|
|
|
|
|
|
|
|
|
/// Keeps track of the cumulative reward in this episode.
|
|
|
|
float m_CumulativeReward; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
WriteAdapter m_WriteAdapter = new WriteAdapter(); |
|
|
|
|
|
|
|
|
|
|
|
/// Represents the reward the agent accumulated during the current step.
|
|
|
|
/// It is reset to 0 at the beginning of every step.
|
|
|
|
/// Should be set to a positive value when the agent performs a "good"
|
|
|
|
/// action that we wish to reinforce/reward, and set to a negative value
|
|
|
|
/// when the agent performs a "bad" action that we wish to punish/deter.
|
|
|
|
/// Additionally, the magnitude of the reward should not exceed 1.0
|
|
|
|
IRewardProvider m_RewardProvider; |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Here for ease of upgrading from the old reward system.
|
|
|
|
/// </summary>
|
|
|
|
LegacyRewardProvider m_LegacyRewardProvider; |
|
|
|
|
|
|
|
|
|
|
|
/// MonoBehaviour function that is called when the attached GameObject
|
|
|
|
/// becomes enabled or active.
|
|
|
|
void OnEnable() |
|
|
|
|
|
|
academy.DecideAction += DecideAction; |
|
|
|
academy.AgentAct += AgentStep; |
|
|
|
academy.AgentForceReset += _AgentReset; |
|
|
|
InitializeRewardProvider(); |
|
|
|
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic); |
|
|
|
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic, m_RewardProvider); |
|
|
|
ResetData(); |
|
|
|
InitializeAgent(); |
|
|
|
InitializeSensors(); |
|
|
|
|
|
|
{ |
|
|
|
m_PolicyFactory.GiveModel(behaviorName, model, inferenceDevice); |
|
|
|
m_Brain?.Dispose(); |
|
|
|
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic); |
|
|
|
m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic, m_RewardProvider); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
/// </summary>
|
|
|
|
public void ResetReward() |
|
|
|
{ |
|
|
|
m_Reward = 0f; |
|
|
|
if (m_Done) |
|
|
|
{ |
|
|
|
m_CumulativeReward = 0f; |
|
|
|
} |
|
|
|
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " + |
|
|
|
"legacy method 'ResetReward' was called."); |
|
|
|
m_LegacyRewardProvider.ResetReward(m_Done); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
/// <param name="reward">The new value of the reward.</param>
|
|
|
|
public void SetReward(float reward) |
|
|
|
{ |
|
|
|
m_CumulativeReward += (reward - m_Reward); |
|
|
|
m_Reward = reward; |
|
|
|
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " + |
|
|
|
"legacy method 'SetReward' was called."); |
|
|
|
m_LegacyRewardProvider.SetReward(reward); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
public void AddReward(float increment) |
|
|
|
{ |
|
|
|
m_Reward += increment; |
|
|
|
m_CumulativeReward += increment; |
|
|
|
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " + |
|
|
|
"legacy method 'AddReward' was called."); |
|
|
|
m_LegacyRewardProvider.AddReward(increment); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
public float GetReward() |
|
|
|
{ |
|
|
|
return m_Reward; |
|
|
|
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " + |
|
|
|
"legacy method 'GetReward' was called."); |
|
|
|
return m_LegacyRewardProvider.GetIncrementalReward(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
public float GetCumulativeReward() |
|
|
|
{ |
|
|
|
return m_CumulativeReward; |
|
|
|
Debug.Assert(m_LegacyRewardProvider != null, "LegacyRewardProvider is null and " + |
|
|
|
"legacy method 'GetCumulativeReward' was called."); |
|
|
|
return m_LegacyRewardProvider.GetCumulativeReward(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
m_VectorSensorBuffer = new float[numFloatObservations]; |
|
|
|
} |
|
|
|
|
|
|
|
void InitializeRewardProvider() |
|
|
|
{ |
|
|
|
// Look for a legacy reward provider.
|
|
|
|
var rewardProviderComponent = GetComponent<BaseRewardProviderComponent>(); |
|
|
|
if (rewardProviderComponent != null) |
|
|
|
{ |
|
|
|
m_RewardProvider = rewardProviderComponent.GetRewardProvider(); |
|
|
|
} |
|
|
|
|
|
|
|
if (m_RewardProvider == null) |
|
|
|
{ |
|
|
|
var legacyRewardProviderComponent = gameObject.AddComponent<LegacyRewardProviderComponent>(); |
|
|
|
m_RewardProvider = legacyRewardProviderComponent.GetTypedRewardProvider(); |
|
|
|
} |
|
|
|
|
|
|
|
m_LegacyRewardProvider = m_RewardProvider as LegacyRewardProvider; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Sends the Agent info to the linked Brain.
|
|
|
|
/// </summary>
|
|
|
|
|
|
|
|
|
|
|
// var param = m_PolicyFactory.brainParameters; // look, no brain params!
|
|
|
|
|
|
|
|
m_Info.reward = m_Reward; |
|
|
|
m_Info.reward = m_RewardProvider.GetIncrementalReward(); |
|
|
|
// TODO(cgoy): Decouple Agent/Policy.
|
|
|
|
m_Brain.RequestDecision(this); |
|
|
|
|
|
|
|
if (m_Recorder != null && m_Recorder.record && Application.isEditor) |
|
|
|