您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
94 行
2.5 KiB
94 行
2.5 KiB
using System.Collections;
|
|
using System.Collections.Generic;
|
|
using UnityEngine;
|
|
|
|
public class BouncerAgent : Agent {
|
|
|
|
[Header("Bouncer Specific")]
|
|
public GameObject banana;
|
|
Rigidbody rb;
|
|
public float strength = 10f;
|
|
float jumpCooldown = 0f;
|
|
int numberJumps = 20;
|
|
int jumpLeft = 20;
|
|
|
|
public override void InitializeAgent()
|
|
{
|
|
rb = gameObject.GetComponent<Rigidbody>();
|
|
}
|
|
|
|
public override void CollectObservations()
|
|
{
|
|
AddVectorObs(gameObject.transform.localPosition);
|
|
AddVectorObs(banana.transform.localPosition);
|
|
}
|
|
|
|
public override void AgentAction(float[] vectorAction, string textAction)
|
|
{
|
|
float x = Mathf.Clamp(vectorAction[0], -1, 1);
|
|
float y = Mathf.Clamp(vectorAction[1], 0, 1);
|
|
float z = Mathf.Clamp(vectorAction[2], -1, 1);
|
|
rb.AddForce( new Vector3(x, y+1, z) *strength);
|
|
|
|
AddReward(-0.05f * (
|
|
vectorAction[0] * vectorAction[0] +
|
|
vectorAction[1] * vectorAction[1] +
|
|
vectorAction[2] * vectorAction[2]) / 3f);
|
|
|
|
}
|
|
|
|
public override void AgentReset()
|
|
{
|
|
|
|
gameObject.transform.localPosition = new Vector3(
|
|
(1 - 2 * Random.value) *5, 2, (1 - 2 * Random.value)*5);
|
|
rb.velocity = default(Vector3);
|
|
GameObject environment = gameObject.transform.parent.gameObject;
|
|
BouncerBanana[] bananas =
|
|
environment.GetComponentsInChildren<BouncerBanana>();
|
|
foreach (BouncerBanana bb in bananas)
|
|
{
|
|
bb.Respawn();
|
|
}
|
|
jumpLeft = numberJumps;
|
|
}
|
|
|
|
public override void AgentOnDone()
|
|
{
|
|
|
|
}
|
|
|
|
private void FixedUpdate()
|
|
{
|
|
if ((Physics.Raycast(transform.position, new Vector3(0f,-1f,0f), 0.51f))
|
|
&& jumpCooldown <= 0f)
|
|
{
|
|
RequestDecision();
|
|
jumpLeft -= 1;
|
|
jumpCooldown = 0.1f;
|
|
rb.velocity = default(Vector3);
|
|
}
|
|
jumpCooldown -= Time.fixedDeltaTime;
|
|
if (gameObject.transform.position.y < -1)
|
|
{
|
|
AddReward(-1);
|
|
Done();
|
|
return;
|
|
}
|
|
if ((gameObject.transform.localPosition.x < -19)
|
|
||(gameObject.transform.localPosition.x >19)
|
|
|| (gameObject.transform.localPosition.z < -19)
|
|
|| (gameObject.transform.localPosition.z > 19)
|
|
)
|
|
{
|
|
AddReward(-1);
|
|
Done();
|
|
return;
|
|
}
|
|
if (jumpLeft == 0f)
|
|
{
|
|
Done();
|
|
}
|
|
|
|
}
|
|
}
|