|
|
|
|
|
|
WriteAdapter m_WriteAdapter = new WriteAdapter(); |
|
|
|
|
|
|
|
|
|
|
|
RewardProviderComponent m_RewardProviderComponent; |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Represents the reward the agent accumulated during the current step.
|
|
|
|
/// It is reset at the beginning of every step.
|
|
|
|
|
|
|
/// Additionally, the magnitude of the reward should not exceed 1.0
|
|
|
|
/// </summary>
|
|
|
|
public IRewardProvider rewardProvider |
|
|
|
{ |
|
|
|
get |
|
|
|
{ |
|
|
|
Debug.Assert(m_RewardProviderComponent != null, |
|
|
|
nameof(m_RewardProviderComponent) + " != null"); |
|
|
|
return m_RewardProviderComponent.GetRewardProvider(); |
|
|
|
} |
|
|
|
} |
|
|
|
public IRewardProvider rewardProvider; |
|
|
|
CumulativeRewardProvider DefaultRewardProvider |
|
|
|
CumulativeRewardProvider CumulativeRewardProvider |
|
|
|
{ |
|
|
|
get { return rewardProvider as CumulativeRewardProvider; } |
|
|
|
} |
|
|
|
|
|
|
{ |
|
|
|
return m_StepCount; |
|
|
|
} |
|
|
|
void WarnDefaultRewardProvider(string callee) { |
|
|
|
if (CumulativeRewardProvider == null) |
|
|
|
{ |
|
|
|
Debug.LogWarningFormat("the CumulativeRewardProvider is null and " + |
|
|
|
"method '{0}' was called. If your agent doesn't have the CumulativeRewardProvider," + |
|
|
|
"remove the call to '{0}'.", callee); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Resets the step reward and possibly the episode reward for the agent.
|
|
|
|
|
|
|
Debug.Assert(DefaultRewardProvider != null, "the DefaultRewardProvider is null and " + |
|
|
|
"method 'ResetReward' was called. If your agent doesn't have the CumulativeRewardProvider," + |
|
|
|
"remove the call from ResetReward."); |
|
|
|
DefaultRewardProvider.ResetReward(m_Done); |
|
|
|
WarnDefaultRewardProvider("ResetReward"); |
|
|
|
InternalResetReward(); |
|
|
|
} |
|
|
|
|
|
|
|
void InternalResetReward() |
|
|
|
{ |
|
|
|
CumulativeRewardProvider?.ResetReward(m_Done); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
/// <param name="reward">The new value of the reward.</param>
|
|
|
|
public void SetReward(float reward) |
|
|
|
{ |
|
|
|
Debug.Assert(DefaultRewardProvider != null, "the DefaultRewardProvider is null and " + |
|
|
|
"method 'SetReward' was called. If your agent doesn't have the CumulativeRewardProvider," + |
|
|
|
"remove the call from 'SetReward'."); |
|
|
|
DefaultRewardProvider.SetReward(reward); |
|
|
|
WarnDefaultRewardProvider("SetReward"); |
|
|
|
CumulativeRewardProvider?.SetReward(reward); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
public void AddReward(float increment) |
|
|
|
{ |
|
|
|
Debug.Assert(DefaultRewardProvider != null, "the DefaultRewardProvider is null and " + |
|
|
|
"method 'AddReward' was called. If your agent doesn't have the CumulativeRewardProvider," + |
|
|
|
"remove the call from 'AddReward'."); |
|
|
|
DefaultRewardProvider.AddReward(increment); |
|
|
|
WarnDefaultRewardProvider("AddReward"); |
|
|
|
CumulativeRewardProvider?.AddReward(increment); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
public float GetIncrementalReward() |
|
|
|
{ |
|
|
|
Debug.Assert(rewardProvider != null, "m_RewardProviderComponent is null and " + |
|
|
|
"method 'GetIncrementalReward' was called."); |
|
|
|
return rewardProvider.GetIncrementalReward(); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
void InitializeRewardProvider() |
|
|
|
{ |
|
|
|
// Look for a legacy reward provider.
|
|
|
|
m_RewardProviderComponent = GetComponent<RewardProviderComponent>(); |
|
|
|
if (m_RewardProviderComponent == null) |
|
|
|
var rewardProviderComponent = GetComponent<RewardProviderComponent>(); |
|
|
|
if (rewardProviderComponent == null) |
|
|
|
m_RewardProviderComponent = gameObject.AddComponent<CumulativeRewardProviderComponent>(); |
|
|
|
rewardProviderComponent = gameObject.AddComponent<CumulativeRewardProviderComponent>(); |
|
|
|
rewardProvider = rewardProviderComponent.GetRewardProvider(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
m_Info.observations.Clear(); |
|
|
|
m_ActionMasker.ResetMask(); |
|
|
|
UpdateSensors(); |
|
|
|
rewardProvider.RewardStep(); |
|
|
|
using (TimerStack.Instance.Scoped("CollectObservations")) |
|
|
|
{ |
|
|
|
CollectObservations(); |
|
|
|
|
|
|
m_Info.id = m_Id; |
|
|
|
|
|
|
|
m_Brain.RequestDecision(this); |
|
|
|
|
|
|
|
if (m_Recorder != null && m_Recorder.record && Application.isEditor) |
|
|
|
{ |
|
|
|
// This is a bit of a hack - if we're in inference mode, observations won't be generated
|
|
|
|
|
|
|
if (m_RequestDecision) |
|
|
|
{ |
|
|
|
SendInfoToBrain(); |
|
|
|
ResetReward(); |
|
|
|
InternalResetReward(); |
|
|
|
|
|
|
|
m_HasAlreadyReset = false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
if (m_Terminate) |
|
|
|
{ |
|
|
|
m_Terminate = false; |
|
|
|
ResetReward(); |
|
|
|
InternalResetReward(); |
|
|
|
m_Done = false; |
|
|
|
m_MaxStepReached = false; |
|
|
|
m_RequestDecision = false; |
|
|
|