using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;
using System.Linq;
// Class contains all necessary environment parameters
// to be defined and sent to external agent
// Selects how a Brain decides actions. UpdateCoreBrains() resolves each
// value by name to a CoreBrain implementation ("CoreBrain" + value), so
// enum member names must match the corresponding CoreBrain class names.
// The Internal type (in-editor TensorFlow inference) only exists when the
// ENABLE_TENSORFLOW define is set.
#if ENABLE_TENSORFLOW
public enum BrainType
{
Player,
Heuristic,
External,
Internal
}
#else
public enum BrainType
{
Player,
Heuristic,
External,
}
#endif
/// Whether a space is encoded as integers (discrete) or as a vector of
/// real-valued floats (continuous). Used by BrainParameters for both the
/// state space and the action space.
public enum StateType
{
    discrete,
    continuous
}
/** Only needs to be modified in the brain's inspector.
* Defines what is the resolution of the camera
*/
[System.Serializable]
public struct resolution
{
// Serializable so these fields can be edited in the Inspector and stored
// per-Brain in BrainParameters.cameraResolutions.
public int width;
/**< \brief The width of the observation in pixels */
public int height;
/**< \brief The height of the observation in pixels */
public bool blackAndWhite;
/**< \brief If true, the image will be in black and white.
* If false, it will be in colors RGB */
}
/** Should be modified via the Editor Inspector.
* Defines brain-specific parameters
*/
[System.Serializable]
public class BrainParameters
{
    [Tooltip("Length of state vector for brain (In Continuous state space)." +
        "Or number of possible values (in Discrete state space).")]
    public int stateSize = 1;
    /**< \brief If continuous : The length of the float vector that represents
     * the state
     *
     If discrete : The number of possible values the state can take*/

    [Tooltip("Number of states that will be stacked before being fed to the neural network.")]
    [Range(1, 10)]
    public int stackedStates = 1;

    [Tooltip("Length of action vector for brain (In Continuous state space)." +
        "Or number of possible values (in Discrete action space).")]
    public int actionSize = 1;
    /**< \brief If continuous : The length of the float vector that represents the action
     *
     If discrete : The number of possible values the action can take*/

    [Tooltip("Length of memory vector for brain. Used with Recurrent networks.")]
    public int memorySize = 0;
    /**< \brief The length of the float vector that holds the memory for the agent */

    [Tooltip("Describes height, width, and whether to greyscale visual observations for the Brain.")]
    public resolution[] cameraResolutions;
    /**<\brief The list of observation resolutions for the brain */

    [Tooltip("A list of strings used to name the available actions for the Brain.")]
    public string[] actionDescriptions;
    /**< \brief The list of strings describing what the actions correspond to */

    // NOTE(review): the two tooltips below were previously swapped (the action
    // field's tooltip described the state vector and vice versa); fixed so the
    // Inspector help matches the field it annotates.
    [Tooltip("Corresponds to whether action vector contains a single integer (Discrete) " +
        "or a series of real-valued floats (Continuous).")]
    public StateType actionSpaceType = StateType.discrete;
    /**< \brief Defines if the action is discrete or continuous */

    [Tooltip("Corresponds to whether state vector contains a single integer (Discrete)" +
        " or a series of real-valued floats (Continuous).")]
    public StateType stateSpaceType = StateType.continuous;
    /**< \brief Defines if the state is discrete or continuous */
}
[HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Agents-Editor-Interface.md#brain")]
/**
 * Contains all high-level Brain logic.
 * Add this component to an empty GameObject in your scene and drag this
 * GameObject into your Academy to make it a child in the hierarchy.
 * Contains a set of CoreBrains, which each correspond to a different method
 * for deciding actions.
 */
public class Brain : MonoBehaviour
{
    // Current agent info, keyed by the agent's unique id.
    // NOTE(review): generic type arguments were missing from this file
    // (stripped angle brackets); restored from how each collection is used.
    public Dictionary<int, List<float>> currentStates = new Dictionary<int, List<float>>(32);
    public Dictionary<int, List<Camera>> currentCameras = new Dictionary<int, List<Camera>>(32);
    public Dictionary<int, float> currentRewards = new Dictionary<int, float>(32);
    public Dictionary<int, bool> currentDones = new Dictionary<int, bool>(32);
    public Dictionary<int, bool> currentMaxes = new Dictionary<int, bool>(32);
    public Dictionary<int, float[]> currentActions = new Dictionary<int, float[]>(32);
    public Dictionary<int, float[]> currentMemories = new Dictionary<int, float[]>(32);

    [Tooltip("Define state, observation, and action spaces for the Brain.")]
    /**< \brief Defines brain specific parameters such as the state size*/
    public BrainParameters brainParameters = new BrainParameters();

    /**< \brief Defines what is the type of the brain :
     * External / Internal / Player / Heuristic*/
    [Tooltip("Describes how the Brain will decide actions.")]
    public BrainType brainType;

    [HideInInspector]
    /**< \brief Keeps track of the agents which subscribe to this brain*/
    public Dictionary<int, Agent> agents = new Dictionary<int, Agent>();

    // One ScriptableObject per BrainType value, indexed by (int)BrainType.
    [SerializeField]
    ScriptableObject[] CoreBrains;

    /**< \brief Reference to the current CoreBrain used by the brain*/
    public CoreBrain coreBrain;

    // Ensures the coreBrains are not duplicated when the brain GameObject is
    // duplicated (compared against gameObject.GetInstanceID() below).
    [SerializeField]
    private int instanceID;

    /// Ensures the brain has an up to date array of coreBrains
    /** Is called when the inspector is modified and into InitializeBrain.
     * If the brain gameObject was just created, it generates a list of
     * coreBrains (one for each brainType) */
    public void UpdateCoreBrains()
    {
        // If CoreBrains is null, this means the Brain object was just
        // instantiated and we create instances of each CoreBrain
        if (CoreBrains == null)
        {
            CoreBrains = new ScriptableObject[System.Enum.GetValues(typeof(BrainType)).Length];
            foreach (BrainType bt in System.Enum.GetValues(typeof(BrainType)))
            {
                CoreBrains[(int)bt] = ScriptableObject.CreateInstance("CoreBrain" + bt.ToString());
            }
        }
        else
        {
            // Re-create any CoreBrain slot that was lost (null entries)
            foreach (BrainType bt in System.Enum.GetValues(typeof(BrainType)))
            {
                if ((int)bt >= CoreBrains.Length)
                    break;
                if (CoreBrains[(int)bt] == null)
                {
                    CoreBrains[(int)bt] = ScriptableObject.CreateInstance("CoreBrain" + bt.ToString());
                }
            }
        }
        // If the length of CoreBrains does not match the number of BrainTypes,
        // we increase the length of CoreBrains
        if (CoreBrains.Length < System.Enum.GetValues(typeof(BrainType)).Length)
        {
            ScriptableObject[] new_CoreBrains = new ScriptableObject[System.Enum.GetValues(typeof(BrainType)).Length];
            foreach (BrainType bt in System.Enum.GetValues(typeof(BrainType)))
            {
                if ((int)bt < CoreBrains.Length)
                {
                    new_CoreBrains[(int)bt] = CoreBrains[(int)bt];
                }
                else
                {
                    new_CoreBrains[(int)bt] = ScriptableObject.CreateInstance("CoreBrain" + bt.ToString());
                }
            }
            CoreBrains = new_CoreBrains;
        }
        // If the stored instanceID does not match the current instanceID,
        // this means that the Brain GameObject was duplicated, and
        // we need to make a new copy of each CoreBrain
        if (instanceID != gameObject.GetInstanceID())
        {
            foreach (BrainType bt in System.Enum.GetValues(typeof(BrainType)))
            {
                if (CoreBrains[(int)bt] == null)
                {
                    CoreBrains[(int)bt] = ScriptableObject.CreateInstance("CoreBrain" + bt.ToString());
                }
                else
                {
                    CoreBrains[(int)bt] = ScriptableObject.Instantiate(CoreBrains[(int)bt]);
                }
            }
            instanceID = gameObject.GetInstanceID();
        }
        // The coreBrain to display is the one defined in brainType
        coreBrain = (CoreBrain)CoreBrains[(int)brainType];
        coreBrain.SetBrain(this);
    }

    /// This is called by the Academy at the start of the environment.
    public void InitializeBrain()
    {
        UpdateCoreBrains();
        coreBrain.InitializeCoreBrain();
    }

    /// Collects states, observations, rewards, done flags, max-step flags,
    /// actions and memories for every subscribed agent in a single pass,
    /// refilling all the current* dictionaries.
    public void CollectEverything()
    {
        currentStates.Clear();
        currentCameras.Clear();
        currentRewards.Clear();
        currentDones.Clear();
        currentMaxes.Clear();
        currentActions.Clear();
        currentMemories.Clear();
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            idAgent.Value.SetCumulativeReward();
            List<float> states = idAgent.Value.ClearAndCollectState();
            // Validate the collected state length against the brain parameters.
            if ((states.Count != brainParameters.stateSize * brainParameters.stackedStates) && (brainParameters.stateSpaceType == StateType.continuous))
            {
                throw new UnityAgentsException(string.Format(@"The number of states does not match for agent {0}:
Was expecting {1} continuous states but received {2}.", idAgent.Value.gameObject.name, brainParameters.stateSize, states.Count));
            }
            if ((states.Count != brainParameters.stackedStates) && (brainParameters.stateSpaceType == StateType.discrete))
            {
                throw new UnityAgentsException(string.Format(@"The number of states does not match for agent {0}:
Was expecting 1 discrete states but received {1}.", idAgent.Value.gameObject.name, states.Count));
            }
            List<Camera> observations = idAgent.Value.observations;
            if (observations.Count < brainParameters.cameraResolutions.Count())
            {
                throw new UnityAgentsException(string.Format(@"The number of observations does not match for agent {0}:
Was expecting at least {1} observation but received {2}.", idAgent.Value.gameObject.name, brainParameters.cameraResolutions.Count(), observations.Count));
            }
            currentStates.Add(idAgent.Key, states);
            currentCameras.Add(idAgent.Key, observations);
            currentRewards.Add(idAgent.Key, idAgent.Value.reward);
            currentDones.Add(idAgent.Key, idAgent.Value.done);
            currentMaxes.Add(idAgent.Key, idAgent.Value.maxStepReached);
            currentActions.Add(idAgent.Key, idAgent.Value.agentStoredAction);
            currentMemories.Add(idAgent.Key, idAgent.Value.memory);
        }
    }

    /// Collects the states of all the agents which subscribe to this brain
    /// and returns a dictionary {id -> state}
    public Dictionary<int, List<float>> CollectStates()
    {
        currentStates.Clear();
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            idAgent.Value.SetCumulativeReward();
            List<float> states = idAgent.Value.ClearAndCollectState();
            if ((states.Count != brainParameters.stateSize * brainParameters.stackedStates) && (brainParameters.stateSpaceType == StateType.continuous))
            {
                throw new UnityAgentsException(string.Format(@"The number of states does not match for agent {0}:
Was expecting {1} continuous states but received {2}.", idAgent.Value.gameObject.name, brainParameters.stateSize, states.Count));
            }
            if ((states.Count != brainParameters.stackedStates) && (brainParameters.stateSpaceType == StateType.discrete))
            {
                throw new UnityAgentsException(string.Format(@"The number of states does not match for agent {0}:
Was expecting 1 discrete states but received {1}.", idAgent.Value.gameObject.name, states.Count));
            }
            currentStates.Add(idAgent.Key, states);
        }
        return currentStates;
    }

    /// Collects the observations of all the agents which subscribe to this
    /// brain and returns a dictionary {id -> Camera}
    public Dictionary<int, List<Camera>> CollectObservations()
    {
        currentCameras.Clear();
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            List<Camera> observations = idAgent.Value.observations;
            if (observations.Count < brainParameters.cameraResolutions.Count())
            {
                throw new UnityAgentsException(string.Format(@"The number of observations does not match for agent {0}:
Was expecting at least {1} observation but received {2}.", idAgent.Value.gameObject.name, brainParameters.cameraResolutions.Count(), observations.Count));
            }
            currentCameras.Add(idAgent.Key, observations);
        }
        return currentCameras;
    }

    /// Collects the rewards of all the agents which subscribe to this brain
    /// and returns a dictionary {id -> reward}
    public Dictionary<int, float> CollectRewards()
    {
        currentRewards.Clear();
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            currentRewards.Add(idAgent.Key, idAgent.Value.reward);
        }
        return currentRewards;
    }

    /// Collects the done flag of all the agents which subscribe to this brain
    /// and returns a dictionary {id -> done}
    public Dictionary<int, bool> CollectDones()
    {
        currentDones.Clear();
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            currentDones.Add(idAgent.Key, idAgent.Value.done);
        }
        return currentDones;
    }

    /// Collects the maxStepReached flag of all the agents which subscribe to
    /// this brain and returns a dictionary {id -> maxStepReached}
    public Dictionary<int, bool> CollectMaxes()
    {
        currentMaxes.Clear();
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            currentMaxes.Add(idAgent.Key, idAgent.Value.maxStepReached);
        }
        return currentMaxes;
    }

    /// Collects the actions of all the agents which subscribe to this brain
    /// and returns a dictionary {id -> action}
    public Dictionary<int, float[]> CollectActions()
    {
        currentActions.Clear();
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            currentActions.Add(idAgent.Key, idAgent.Value.agentStoredAction);
        }
        return currentActions;
    }

    /// Collects the memories of all the agents which subscribe to this brain
    /// and returns a dictionary {id -> memories}
    public Dictionary<int, float[]> CollectMemories()
    {
        currentMemories.Clear();
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            currentMemories.Add(idAgent.Key, idAgent.Value.memory);
        }
        return currentMemories;
    }

    /// Takes a dictionary {id -> memories} and sends the memories to the
    /// corresponding agents
    public void SendMemories(Dictionary<int, float[]> memories)
    {
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            idAgent.Value.memory = memories[idAgent.Key];
        }
    }

    /// Takes a dictionary {id -> actions} and sends the actions to the
    /// corresponding agents
    public void SendActions(Dictionary<int, float[]> actions)
    {
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            //Add a check here to see if the component was destroyed ?
            idAgent.Value.UpdateAction(actions[idAgent.Key]);
        }
    }

    /// Takes a dictionary {id -> values} and sends the values to the
    /// corresponding agents
    public void SendValues(Dictionary<int, float> values)
    {
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            //Add a check here to see if the component was destroyed ?
            idAgent.Value.value = values[idAgent.Key];
        }
    }

    ///Sets all the agents which subscribe to the brain to done
    public void SendDone()
    {
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            idAgent.Value.done = true;
        }
    }

    ///Sets all the agents which subscribe to the brain to maxStepReached
    public void SendMaxReached()
    {
        foreach (KeyValuePair<int, Agent> idAgent in agents)
        {
            idAgent.Value.maxStepReached = true;
        }
    }

    /// Uses coreBrain to call SendState on the CoreBrain
    public void SendState()
    {
        coreBrain.SendState();
    }

    /// Uses coreBrain to call decideAction on the CoreBrain
    public void DecideAction()
    {
        coreBrain.DecideAction();
    }

    /// \brief Is used by the Academy to send a step message to all the agents
    /// which are not done
    public void Step()
    {
        // Iterate over a snapshot: agent.Step() may add/remove agents.
        var agentsToIterate = agents.Values.ToList();
        foreach (Agent agent in agentsToIterate)
        {
            if (!agent.done)
            {
                agent.Step();
            }
        }
    }

    /// Is used by the Academy to reset the agents if they are done
    public void ResetIfDone()
    {
        // Snapshot for the same reason as Step(): callbacks may mutate agents.
        var agentsToIterate = agents.Values.ToList();
        foreach (Agent agent in agentsToIterate)
        {
            if (agent.done)
            {
                if (!agent.resetOnDone)
                {
                    agent.AgentOnDone();
                }
                else
                {
                    agent.Reset();
                }
            }
        }
    }

    /// Is used by the Academy to reset all agents
    public void Reset()
    {
        foreach (Agent agent in agents.Values)
        {
            agent.Reset();
            agent.done = false;
            agent.maxStepReached = false;
        }
    }

    /// \brief Is used by the Academy to reset the done flag and the rewards of
    /// the agents that subscribe to the brain
    public void ResetDoneAndReward()
    {
        foreach (Agent agent in agents.Values)
        {
            if (!agent.done || agent.resetOnDone)
            {
                agent.ResetReward();
                agent.done = false;
                agent.maxStepReached = false;
            }
        }
    }

    /** Contains logic for converting a camera component into a Texture2D. */
    public Texture2D ObservationToTex(Camera camera, int width, int height)
    {
        Camera cam = camera;
        Rect oldRec = camera.rect;
        cam.rect = new Rect(0f, 0f, 1f, 1f);
        bool supportsAntialiasing = false;
        bool needsRescale = false;
        var depth = 24;
        var format = RenderTextureFormat.Default;
        var readWrite = RenderTextureReadWrite.Default;
        var antiAliasing = (supportsAntialiasing) ? Mathf.Max(1, QualitySettings.antiAliasing) : 1;
        var finalRT =
            RenderTexture.GetTemporary(width, height, depth, format, readWrite, antiAliasing);
        var renderRT = (!needsRescale) ? finalRT :
            RenderTexture.GetTemporary(width, height, depth, format, readWrite, antiAliasing);
        var tex = new Texture2D(width, height, TextureFormat.RGB24, false);
        var prevActiveRT = RenderTexture.active;
        var prevCameraRT = cam.targetTexture;
        // render to offscreen texture (readonly from CPU side)
        RenderTexture.active = renderRT;
        cam.targetTexture = renderRT;
        cam.Render();
        if (needsRescale)
        {
            // NOTE(review): dead branch while needsRescale is hard-coded false.
            RenderTexture.active = finalRT;
            Graphics.Blit(renderRT, finalRT);
            RenderTexture.ReleaseTemporary(renderRT);
        }
        tex.ReadPixels(new Rect(0, 0, tex.width, tex.height), 0, 0);
        tex.Apply();
        // Restore the camera and render-texture state we clobbered above.
        cam.targetTexture = prevCameraRT;
        cam.rect = oldRec;
        RenderTexture.active = prevActiveRT;
        RenderTexture.ReleaseTemporary(finalRT);
        return tex;
    }

    /// Contains logic to convert the agent's cameras into observation list
    /// (as list of float arrays)
    public List<float[,,,]> GetObservationMatrixList(List<int> agent_keys)
    {
        var observation_matrix_list = new List<float[,,,]>();
        Dictionary<int, List<Camera>> observations = CollectObservations();
        for (int obs_number = 0; obs_number < brainParameters.cameraResolutions.Length; obs_number++)
        {
            var width = brainParameters.cameraResolutions[obs_number].width;
            var height = brainParameters.cameraResolutions[obs_number].height;
            var bw = brainParameters.cameraResolutions[obs_number].blackAndWhite;
            // 1 channel for greyscale, 3 (RGB) otherwise.
            var pixels = 0;
            if (bw)
                pixels = 1;
            else
                pixels = 3;
            float[,,,] observation_matrix = new float[agent_keys.Count
                , height
                , width
                , pixels];
            var i = 0;
            foreach (int k in agent_keys)
            {
                Camera agent_obs = observations[k][obs_number];
                Texture2D tex = ObservationToTex(agent_obs, width, height);
                for (int w = 0; w < width; w++)
                {
                    for (int h = 0; h < height; h++)
                    {
                        Color c = tex.GetPixel(w, h);
                        // Rows are flipped (tex.height - h - 1) so index 0 is
                        // the top of the image.
                        if (!bw)
                        {
                            observation_matrix[i, tex.height - h - 1, w, 0] = c.r;
                            observation_matrix[i, tex.height - h - 1, w, 1] = c.g;
                            observation_matrix[i, tex.height - h - 1, w, 2] = c.b;
                        }
                        else
                        {
                            observation_matrix[i, tex.height - h - 1, w, 0] = (c.r + c.g + c.b) / 3;
                        }
                    }
                }
                UnityEngine.Object.DestroyImmediate(tex);
                // NOTE(review): UnloadUnusedAssets per agent per observation is
                // costly; consider hoisting out of the loop — verify no other
                // code depends on this timing before moving it.
                Resources.UnloadUnusedAssets();
                i++;
            }
            observation_matrix_list.Add(observation_matrix);
        }
        return observation_matrix_list;
    }
}