浏览代码

update tennis reward function

/asymm-envs
Andrew Cohen 5 年前
当前提交
5d659946
共有 3 个文件被更改，包括 36 次插入和 23 次删除
  1. 40
      Project/Assets/ML-Agents/Examples/Tennis/Scripts/HitWall.cs
  2. 14
      Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
  3. 5
      Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisArea.cs

40
Project/Assets/ML-Agents/Examples/Tennis/Scripts/HitWall.cs


public class HitWall : MonoBehaviour
{
public GameObject areaObject;
public int lastAgentHit;
int m_LastAgentHit;
public enum FloorHit
{

FloorBHit
}
public FloorHit lastFloorHit;
FloorHit m_LastFloorHit;
TennisArea m_Area;
TennisAgent m_AgentA;

m_Area = areaObject.GetComponent<TennisArea>();
m_AgentA = m_Area.agentA.GetComponent<TennisAgent>();
m_AgentB = m_Area.agentB.GetComponent<TennisAgent>();
}
/// <summary>
/// Puts the ball's rally-tracking state back to its start-of-point values.
/// TennisArea invokes this through its cached HitWall reference.
/// </summary>
public void ResetPoint()
{
    // -1 matches neither agent id (0 / 1) that the collision checks compare against.
    m_LastAgentHit = -1;
    // A fresh point starts in the Service state until the ball touches a floor or an agent.
    m_LastFloorHit = FloorHit.Service;
}
void Reset()

m_Area.MatchReset();
lastFloorHit = FloorHit.Service;
lastAgentHit = -1;
m_AgentA.SetReward(1);
m_AgentA.SetReward(1 + m_AgentA.timePenalty);
m_AgentB.SetReward(-1);
m_AgentA.score += 1;
Reset();

void AgentBWins()
{
m_AgentA.SetReward(-1);
m_AgentB.SetReward(1);
m_AgentB.SetReward(1 + m_AgentB.timePenalty);
m_AgentB.score += 1;
Reset();

if (collision.gameObject.name == "wallA")
{
// Agent A hits into wall or agent B hit a winner
if (lastAgentHit == 0 || lastFloorHit == FloorHit.FloorAHit)
if (m_LastAgentHit == 0 || m_LastFloorHit == FloorHit.FloorAHit)
{
AgentBWins();
}

else if (collision.gameObject.name == "wallB")
{
// Agent B hits into wall or agent A hit a winner
if (lastAgentHit == 1 || lastFloorHit == FloorHit.FloorBHit)
if (m_LastAgentHit == 1 || m_LastFloorHit == FloorHit.FloorBHit)
{
AgentAWins();
}

else if (collision.gameObject.name == "floorA")
{
// Agent A hits into floor, double bounce or service
if (lastAgentHit == 0 || lastFloorHit == FloorHit.FloorAHit || lastFloorHit == FloorHit.Service)
if (m_LastAgentHit == 0 || m_LastFloorHit == FloorHit.FloorAHit || m_LastFloorHit == FloorHit.Service)
lastFloorHit = FloorHit.FloorAHit;
m_LastFloorHit = FloorHit.FloorAHit;
if (lastAgentHit == 1 || lastFloorHit == FloorHit.FloorBHit || lastFloorHit == FloorHit.Service)
if (m_LastAgentHit == 1 || m_LastFloorHit == FloorHit.FloorBHit || m_LastFloorHit == FloorHit.Service)
lastFloorHit = FloorHit.FloorBHit;
m_LastFloorHit = FloorHit.FloorBHit;
}
}
}

if (lastAgentHit == 0)
if (m_LastAgentHit == 0)
{
AgentBWins();
}

lastAgentHit = 0;
lastFloorHit = FloorHit.FloorHitUnset;
m_LastAgentHit = 0;
m_LastFloorHit = FloorHit.FloorHitUnset;
if (lastAgentHit == 1)
if (m_LastAgentHit == 1)
lastAgentHit = 1;
lastFloorHit = FloorHit.FloorHitUnset;
m_LastAgentHit = 1;
m_LastFloorHit = FloorHit.FloorHitUnset;
}
}
}

14
Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs


public float angle;
public float scale;
[HideInInspector]
public float timePenalty = 0;
TennisArea m_Area;
float m_InvertMult;
FloatPropertiesChannel m_ResetParams;

{
m_AgentRb = GetComponent<Rigidbody>();
m_BallRb = ball.GetComponent<Rigidbody>();
m_Area = myArea.GetComponent<TennisArea>();
var canvas = GameObject.Find(k_CanvasName);
GameObject scoreBoard;
m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();

transform.position.y,
transform.position.z);
}
AddReward(-1f / 3000f);
timePenalty += -1f / 3000f;
m_TextComponent.text = score.ToString();
}

/// <summary>
/// Episode-start hook: clears this agent's accumulated time penalty,
/// conditionally resets the shared tennis area, then re-applies the
/// reset parameters via SetResetParameters().
/// </summary>
public override void OnEpisodeBegin()
{
// The penalty is accumulated elsewhere in this class (-1/3000 per step, alongside
// AddReward) and is added to the winner's reward in HitWall, so it must restart at 0.
timePenalty = 0;
// NOTE(review): presumably m_InvertMult is 1f for only one of the two agents, so the
// area is reset once per episode rather than twice — confirm where it is assigned.
if (m_InvertMult == 1f)
{
m_Area.MatchReset();
}
// Runs after the optional area reset; its body is not visible in this view.
SetResetParameters();
}

5
Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisArea.cs


public GameObject agentA;
public GameObject agentB;
Rigidbody m_BallRb;
HitWall m_BallScript;
m_BallScript = ball.GetComponent<HitWall>();
MatchReset();
}

}
m_BallRb.velocity = new Vector3(0f, 0f, 0f);
ball.transform.localScale = new Vector3(.5f, .5f, .5f);
ball.GetComponent<HitWall>().lastAgentHit = -1;
m_BallScript.ResetPoint();
}
void FixedUpdate()

正在加载...
取消
保存