浏览代码

made AgentAction take vectorAction and textAction (#397)

/develop-generalizationTraining-TrainerController
GitHub 7 年前
当前提交
1409236e
共有 16 个文件被更改,包括 62 次插入54 次删除
  1. 2
      unity-environment/Assets/ML-Agents/Editor/MLAgentsEditModeTest.cs
  2. 7
      unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs
  3. 6
      unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs
  4. 4
      unity-environment/Assets/ML-Agents/Examples/Area/Scripts/AreaAgent.cs
  5. 4
      unity-environment/Assets/ML-Agents/Examples/Area/Scripts/Push/PushAgent.cs
  6. 4
      unity-environment/Assets/ML-Agents/Examples/Area/Scripts/Wall/WallAgent.cs
  7. 4
      unity-environment/Assets/ML-Agents/Examples/Banana/Scripts/BananaAgent.cs
  8. 4
      unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs
  9. 6
      unity-environment/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs
  10. 45
      unity-environment/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgentConfigurable.cs
  11. 4
      unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  12. 4
      unity-environment/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
  13. 10
      unity-environment/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs
  14. 6
      unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
  15. 4
      unity-environment/Assets/ML-Agents/Scripts/Agent.cs
  16. 2
      unity-environment/Assets/ML-Agents/Template/Scripts/TemplateAgent.cs

2
unity-environment/Assets/ML-Agents/Editor/MLAgentsEditModeTest.cs


collectObservationsCalls += 1;
}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vetorAction, string textAction)
{
agentActionCalls += 1;
AddReward(0.1f);

7
unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DAgent.cs


}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
float action_z = 2f * Mathf.Clamp(act[0], -1f, 1f);
float action_z = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
float action_x = 2f * Mathf.Clamp(act[1], -1f, 1f);
float action_x = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
if ((gameObject.transform.rotation.x < 0.25f && action_x > 0f) ||
(gameObject.transform.rotation.x > -0.25f && action_x < 0f))
{

Done();
SetReward(-1f);
}
}

6
unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DHardAgent.cs


AddVectorObs((ball.transform.position.z - gameObject.transform.position.z));
}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
float action_z = 2f * Mathf.Clamp(act[0], -1f, 1f);
float action_z = 2f * Mathf.Clamp(vectorAction[0], -1f, 1f);
float action_x = 2f * Mathf.Clamp(act[1], -1f, 1f);
float action_x = 2f * Mathf.Clamp(vectorAction[1], -1f, 1f);
if ((gameObject.transform.rotation.x < 0.25f && action_x > 0f) ||
(gameObject.transform.rotation.x > -0.25f && action_x < 0f))
{

4
unity-environment/Assets/ML-Agents/Examples/Area/Scripts/AreaAgent.cs


}
}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
MoveAgent(act);
MoveAgent(vectorAction);
if (gameObject.transform.position.y < 0.0f || Mathf.Abs(gameObject.transform.position.x - area.transform.position.x) > 8f ||
Mathf.Abs(gameObject.transform.position.z + 5 - area.transform.position.z) > 8)

4
unity-environment/Assets/ML-Agents/Examples/Area/Scripts/Push/PushAgent.cs


}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
MoveAgent(act);
MoveAgent(vectorAction);
if (gameObject.transform.position.y < 0.0f || Mathf.Abs(gameObject.transform.position.x - area.transform.position.x) > 8f ||
Mathf.Abs(gameObject.transform.position.z + 5 - area.transform.position.z) > 8)

4
unity-environment/Assets/ML-Agents/Examples/Area/Scripts/Wall/WallAgent.cs


AddVectorObs(blockVelocity.z);
}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
MoveAgent(act);
MoveAgent(vectorAction);
if (gameObject.transform.position.y < 0.0f ||
Mathf.Abs(gameObject.transform.position.x - area.transform.position.x) > 8f ||

4
unity-environment/Assets/ML-Agents/Examples/Banana/Scripts/BananaAgent.cs


public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
MoveAgent(act);
MoveAgent(vectorAction);
}
public override void AgentReset()

4
unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs


AddVectorObs(position);
}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
float movement = act[0];
float movement = vectorAction[0];
int direction = 0;
if (movement == 0) { direction = -1; }
if (movement == 1) { direction = 1; }

6
unity-environment/Assets/ML-Agents/Examples/Bouncer/Scripts/BouncerAgent.cs


AddVectorObs(banana.transform.position.z / 25f);
}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
float x = Mathf.Clamp(act[0], -1, 1);
float z = Mathf.Clamp(act[1], -1, 1);
float x = Mathf.Clamp(vectorAction[0], -1, 1);
float z = Mathf.Clamp(vectorAction[1], -1, 1);
rb.velocity = new Vector3(x, 0, z) ;
if (rb.velocity.magnitude < 0.01f){
AddReward(-1);

45
unity-environment/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgentConfigurable.cs


}
}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
for (int k = 0; k < act.Length; k++)
for (int k = 0; k < vectorAction.Length; k++)
act[k] = Mathf.Clamp(act[k], -1f, 1f);
vectorAction[k] = Mathf.Clamp(vectorAction[k], -1f, 1f);
limbRBs[0].AddTorque(-limbs[0].transform.right * strength * act[0]);
limbRBs[1].AddTorque(-limbs[1].transform.right * strength * act[1]);
limbRBs[2].AddTorque(-limbs[2].transform.right * strength * act[2]);
limbRBs[3].AddTorque(-limbs[3].transform.right * strength * act[3]);
limbRBs[0].AddTorque(-body.transform.up * strength * act[4]);
limbRBs[1].AddTorque(-body.transform.up * strength * act[5]);
limbRBs[2].AddTorque(-body.transform.up * strength * act[6]);
limbRBs[3].AddTorque(-body.transform.up * strength * act[7]);
limbRBs[4].AddTorque(-limbs[4].transform.right * strength * act[8]);
limbRBs[5].AddTorque(-limbs[5].transform.right * strength * act[9]);
limbRBs[6].AddTorque(-limbs[6].transform.right * strength * act[10]);
limbRBs[7].AddTorque(-limbs[7].transform.right * strength * act[11]);
limbRBs[0].AddTorque(-limbs[0].transform.right * strength * vectorAction[0]);
limbRBs[1].AddTorque(-limbs[1].transform.right * strength * vectorAction[1]);
limbRBs[2].AddTorque(-limbs[2].transform.right * strength * vectorAction[2]);
limbRBs[3].AddTorque(-limbs[3].transform.right * strength * vectorAction[3]);
limbRBs[0].AddTorque(-body.transform.up * strength * vectorAction[4]);
limbRBs[1].AddTorque(-body.transform.up * strength * vectorAction[5]);
limbRBs[2].AddTorque(-body.transform.up * strength * vectorAction[6]);
limbRBs[3].AddTorque(-body.transform.up * strength * vectorAction[7]);
limbRBs[4].AddTorque(-limbs[4].transform.right * strength * vectorAction[8]);
limbRBs[5].AddTorque(-limbs[5].transform.right * strength * vectorAction[9]);
limbRBs[6].AddTorque(-limbs[6].transform.right * strength * vectorAction[10]);
limbRBs[7].AddTorque(-limbs[7].transform.right * strength * vectorAction[11]);
float torque_penalty = act[0] * act[0] + act[1] * act[1] + act[2] * act[2] + act[3] * act[3]
+ act[4] * act[4] + act[5] * act[5] + act[6] * act[6] + act[7] * act[7]
+ act[8] * act[8] + act[9] * act[9] + act[10] * act[10] + act[11] * act[11];
float torque_penalty = vectorAction[0] * vectorAction[0] +
vectorAction[1] * vectorAction[1] +
vectorAction[2] * vectorAction[2] +
vectorAction[3] * vectorAction[3] +
vectorAction[4] * vectorAction[4] +
vectorAction[5] * vectorAction[5] +
vectorAction[6] * vectorAction[6] +
vectorAction[7] * vectorAction[7] +
vectorAction[8] * vectorAction[8] +
vectorAction[9] * vectorAction[9] +
vectorAction[10] * vectorAction[10] +
vectorAction[11] * vectorAction[11];
if (!IsDone())
{

4
unity-environment/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


}
// to be implemented by the developer
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
int action = Mathf.FloorToInt(act[0]);
int action = Mathf.FloorToInt(vectorAction[0]);
// 0 - Forward, 1 - Backward, 2 - Left, 3 - Right
Vector3 targetPos = transform.position;

4
unity-environment/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs


agentRB.AddForce(dirToGo * academy.agentRunSpeed, ForceMode.VelocityChange); //GO
}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
MoveAgent(act); //perform agent actions
MoveAgent(vectorAction); //perform agent actions
bool fail = false; // did the agent or block get pushed off the edge?
if (!Physics.Raycast(agentRB.position, Vector3.down, 20)) //if the agent has gone over the edge, we done.

10
unity-environment/Assets/ML-Agents/Examples/Reacher/Scripts/ReacherAgent.cs


/// <summary>
/// The agent's four actions correspond to torques on each of the two joints.
/// </summary>
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
float torque_x = Mathf.Clamp(act[0], -1, 1) * 100f;
float torque_z = Mathf.Clamp(act[1], -1, 1) * 100f;
float torque_x = Mathf.Clamp(vectorAction[0], -1, 1) * 100f;
float torque_z = Mathf.Clamp(vectorAction[1], -1, 1) * 100f;
torque_x = Mathf.Clamp(act[2], -1, 1) * 100f;
torque_z = Mathf.Clamp(act[3], -1, 1) * 100f;
torque_x = Mathf.Clamp(vectorAction[2], -1, 1) * 100f;
torque_z = Mathf.Clamp(vectorAction[3], -1, 1) * 100f;
rbB.AddTorque(new Vector3(torque_x, 0f, torque_z));
}

6
unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs


}
// to be implemented by the developer
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
moveX = 0.25f * Mathf.Clamp(act[0], -1f, 1f) * invertMult;
if (Mathf.Clamp(act[1], -1f, 1f) > 0f && gameObject.transform.position.y - transform.parent.transform.position.y < -1.5f)
moveX = 0.25f * Mathf.Clamp(vectorAction[0], -1f, 1f) * invertMult;
if (Mathf.Clamp(vectorAction[1], -1f, 1f) > 0f && gameObject.transform.position.y - transform.parent.transform.position.y < -1.5f)
{
moveY = 0.5f;
gameObject.GetComponent<Rigidbody>().velocity = new Vector3(GetComponent<Rigidbody>().velocity.x, moveY * 12f, 0f);

4
unity-environment/Assets/ML-Agents/Scripts/Agent.cs


/// </summary>
/// <param name="action">The action the agent receives
/// from the brain.</param>
public virtual void AgentAction(float[] action)
public virtual void AgentAction(float[] vectorAction, string textAction)
{
}

if ((requestAction) && (brain != null))
{
requestAction = false;
AgentAction(_action.vectorActions);
AgentAction(_action.vectorActions, _action.textActions);
}
if ((stepCounter >= agentParameters.maxStep)

2
unity-environment/Assets/ML-Agents/Template/Scripts/TemplateAgent.cs


}
public override void AgentAction(float[] act)
public override void AgentAction(float[] vectorAction, string textAction)
{
}

正在加载...
取消
保存