Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

153 行
5.1 KiB

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;
/// <summary>
/// Agent for the Hallway scene. On every reset one of two blocks (orange or red)
/// is placed on the ground while the other is moved far below it, and the two
/// goals are assigned to opposite sides of the area. The agent is rewarded for
/// colliding with the goal whose colour matches the block that was shown.
/// </summary>
public class HallwayAgent : Agent
{
    // Ray-cast perception settings used by CollectObservations.
    const float RayDistance = 12f;

    // Hoisted so the arrays are not re-allocated on every observation step.
    // Assumes RayPerception.Perceive only reads its array arguments — TODO confirm.
    static readonly float[] RayAngles = { 20f, 60f, 90f, 120f, 160f };
    static readonly string[] DetectableObjects =
        { "orangeGoal", "redGoal", "orangeBlock", "redBlock", "wall" };

    // Degrees per second applied when a rotate action is taken.
    const float TurnSpeed = 150f;

    // How long the ground shows the success/failure material, in seconds.
    const float GroundFlashSeconds = 0.5f;

    // Local Y used for the block that is shown / the block that is moved out of the way.
    const float VisibleBlockHeight = 2f;
    const float HiddenBlockHeight = -1000f;

    public GameObject ground;
    public GameObject area;
    public GameObject orangeGoal;
    public GameObject redGoal;
    public GameObject orangeBlock;
    public GameObject redBlock;

    /// <summary>When false, CollectObservations adds no vector observations.</summary>
    public bool useVectorObs;

    RayPerception rayPer;
    Rigidbody agentRB;
    Material groundMaterial;
    Renderer groundRenderer;
    HallwayAcademy academy;

    // 0 = the orange block/goal is the correct one this episode, 1 = red.
    int selection;

    /// <summary>
    /// Caches the academy, the components on this object, and the ground's
    /// renderer and current material (so it can be restored after a flash).
    /// </summary>
    public override void InitializeAgent()
    {
        base.InitializeAgent();
        academy = FindObjectOfType<HallwayAcademy>();
        rayPer = GetComponent<RayPerception>();
        agentRB = GetComponent<Rigidbody>();
        groundRenderer = ground.GetComponent<Renderer>();
        groundMaterial = groundRenderer.material;
    }

    /// <summary>
    /// When <see cref="useVectorObs"/> is set, adds the normalized step count
    /// followed by the ray-perception results.
    /// </summary>
    public override void CollectObservations()
    {
        if (!useVectorObs)
        {
            return;
        }

        // maxStep == 0 would make the division yield NaN/Infinity; report 0
        // instead so the observation stays finite and the vector size is unchanged.
        int maxStep = agentParameters.maxStep;
        float normalizedStep = maxStep > 0 ? GetStepCount() / (float)maxStep : 0f;

        AddVectorObs(normalizedStep);
        AddVectorObs(rayPer.Perceive(RayDistance, RayAngles, DetectableObjects, 0f, 0f));
    }

    /// <summary>
    /// Shows <paramref name="mat"/> on the ground for <paramref name="time"/>
    /// seconds, then restores the material cached in InitializeAgent.
    /// </summary>
    IEnumerator GoalScoredSwapGroundMaterial(Material mat, float time)
    {
        groundRenderer.material = mat;
        yield return new WaitForSeconds(time);
        groundRenderer.material = groundMaterial;
    }

    /// <summary>
    /// Applies one action to the agent.
    /// Continuous space: act[0] is forward/backward and act[1] is rotation, each clamped to [-1, 1].
    /// Discrete space: act[0] selects 0 = forward, 1 = backward, 2 / 3 = rotate in
    /// opposite directions about the up axis; any other value does nothing.
    /// </summary>
    /// <param name="act">Action vector supplied by the brain.</param>
    public void MoveAgent(float[] act)
    {
        Vector3 dirToGo = Vector3.zero;
        Vector3 rotateDir = Vector3.zero;

        if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
        {
            dirToGo = transform.forward * Mathf.Clamp(act[0], -1f, 1f);
            rotateDir = transform.up * Mathf.Clamp(act[1], -1f, 1f);
        }
        else
        {
            switch (Mathf.FloorToInt(act[0]))
            {
                case 0:
                    dirToGo = transform.forward;
                    break;
                case 1:
                    dirToGo = -transform.forward;
                    break;
                case 2:
                    rotateDir = transform.up;
                    break;
                case 3:
                    rotateDir = -transform.up;
                    break;
            }
        }

        transform.Rotate(rotateDir, Time.deltaTime * TurnSpeed);
        agentRB.AddForce(dirToGo * academy.agentRunSpeed, ForceMode.VelocityChange);
    }

    /// <summary>
    /// Applies a small per-step penalty that sums to -1 over maxStep steps,
    /// then moves the agent according to <paramref name="vectorAction"/>.
    /// </summary>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Skip the penalty when maxStep is 0: -1f / 0 would add -Infinity.
        int maxStep = agentParameters.maxStep;
        if (maxStep > 0)
        {
            AddReward(-1f / maxStep);
        }

        MoveAgent(vectorAction);
    }

    /// <summary>
    /// Ends the episode when the agent touches either goal: reward 1 and the
    /// academy's goal-scored material if the goal matches <see cref="selection"/>,
    /// otherwise reward -0.1 and the academy's fail material.
    /// </summary>
    void OnCollisionEnter(Collision col)
    {
        bool hitOrange = col.gameObject.CompareTag("orangeGoal");
        bool hitRed = !hitOrange && col.gameObject.CompareTag("redGoal");
        if (!hitOrange && !hitRed)
        {
            return;
        }

        bool correctGoal = (selection == 0 && hitOrange) || (selection == 1 && hitRed);
        if (correctGoal)
        {
            SetReward(1f);
            StartCoroutine(GoalScoredSwapGroundMaterial(academy.goalScoredMaterial, GroundFlashSeconds));
        }
        else
        {
            SetReward(-0.1f);
            StartCoroutine(GoalScoredSwapGroundMaterial(academy.failMaterial, GroundFlashSeconds));
        }

        Done();
    }

    /// <summary>
    /// Starts a new episode: picks which colour is correct, shows the matching
    /// block and moves the other far below the ground, re-spawns the agent with a
    /// random position and heading, and randomly assigns the two goals to the
    /// +X / -X goal positions.
    /// </summary>
    public override void AgentReset()
    {
        const float agentOffset = -15f;
        const float blockOffset = 0f;
        Vector3 groundPos = ground.transform.position;

        // The order of Random calls below is kept as in the original so seeded
        // runs produce the same layouts.
        selection = Random.Range(0, 2);
        if (selection == 0)
        {
            orangeBlock.transform.position =
                new Vector3(Random.Range(-3f, 3f), VisibleBlockHeight, blockOffset + Random.Range(-5f, 5f))
                + groundPos;
            redBlock.transform.position =
                new Vector3(0f, HiddenBlockHeight, blockOffset + Random.Range(-5f, 5f))
                + groundPos;
        }
        else
        {
            orangeBlock.transform.position =
                new Vector3(0f, HiddenBlockHeight, blockOffset + Random.Range(-5f, 5f))
                + groundPos;
            // NOTE(review): unlike the orange block, the visible red block gets no
            // random X offset. Left unchanged to keep the environment's behaviour;
            // confirm whether the asymmetry is intended.
            redBlock.transform.position =
                new Vector3(0f, VisibleBlockHeight, blockOffset + Random.Range(-5f, 5f))
                + groundPos;
        }

        transform.position =
            new Vector3(Random.Range(-3f, 3f), 1f, agentOffset + Random.Range(-5f, 5f))
            + groundPos;
        transform.rotation = Quaternion.Euler(0f, Random.Range(0f, 360f), 0f);

        // Assign zero directly: multiplying by 0 would keep NaN/Infinity components.
        agentRB.velocity = Vector3.zero;

        Vector3 areaPos = area.transform.position;
        Vector3 positiveXGoalPos = new Vector3(7f, 0.5f, 9f) + areaPos;
        Vector3 negativeXGoalPos = new Vector3(-7f, 0.5f, 9f) + areaPos;

        bool orangeOnPositiveX = Random.Range(0, 2) == 0;
        orangeGoal.transform.position = orangeOnPositiveX ? positiveXGoalPos : negativeXGoalPos;
        redGoal.transform.position = orangeOnPositiveX ? negativeXGoalPos : positiveXGoalPos;
    }
}