using System.Collections.Generic;
using System.IO.Abstractions;
using System.Text.RegularExpressions;
using MLAgents.Sensor;
using UnityEngine;

namespace MLAgents
{
    /// <summary>
    /// Demonstration Recorder Component. While <see cref="record"/> is enabled in the
    /// Editor, it creates a <see cref="DemonstrationStore"/> for the attached
    /// <see cref="Agent"/> and forwards experiences to it through
    /// <see cref="WriteExperience"/>.
    /// </summary>
    [RequireComponent(typeof(Agent))]
    [AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)]
    public class DemonstrationRecorder : MonoBehaviour
    {
        /// <summary>
        /// Whether to record. Only acted on automatically when running in the Editor.
        /// </summary>
        public bool record;

        /// <summary>
        /// Name of the demonstration. Sanitized in place (see <see cref="SanitizeName"/>)
        /// when the store is initialized.
        /// </summary>
        public string demonstrationName;

        Agent m_RecordingAgent;
        DemonstrationStore m_DemoStore;

        /// <summary>
        /// Maximum length of a sanitized demonstration name; longer names would
        /// overflow the demonstration metadata.
        /// </summary>
        public const int MaxNameLength = 16;

        // Matches every character that is not an ASCII letter, digit, space or hyphen.
        // Built once rather than on every SanitizeName call.
        static readonly Regex s_InvalidNameCharacters = new Regex("[^a-zA-Z0-9 -]");

        void Start()
        {
            if (Application.isEditor && record)
            {
                InitializeDemoStore();
            }
        }

        void Update()
        {
            // Lets recording be switched on while already in play mode.
            if (Application.isEditor && record && m_DemoStore == null)
            {
                InitializeDemoStore();
            }
        }

        /// <summary>
        /// Creates the demonstration store used for recording. A store created by a
        /// previous call is closed first instead of being silently dropped.
        /// </summary>
        /// <param name="fileSystem">
        /// File system handed to the <see cref="DemonstrationStore"/> constructor;
        /// <c>null</c> is passed through as-is (presumably selects the real file
        /// system — confirm in DemonstrationStore).
        /// </param>
        public void InitializeDemoStore(IFileSystem fileSystem = null)
        {
            // Re-initializing must not abandon an open store.
            Close();

            m_RecordingAgent = GetComponent<Agent>();
            m_DemoStore = new DemonstrationStore(fileSystem);
            var behaviorParams = GetComponent<BehaviorParameters>();
            demonstrationName = SanitizeName(demonstrationName, MaxNameLength);
            m_DemoStore.Initialize(
                demonstrationName,
                behaviorParams.brainParameters,
                behaviorParams.fullyQualifiedBehaviorName);
            Monitor.Log("Recording Demonstration of Agent: ", m_RecordingAgent.name);
        }

        /// <summary>
        /// Removes every character except ASCII letters, digits, spaces and hyphens from
        /// the demonstration name, then truncates the result to
        /// <paramref name="maxNameLength"/> characters.
        /// </summary>
        /// <param name="demoName">Raw name; <c>null</c> is treated as an empty name.</param>
        /// <param name="maxNameLength">Maximum length of the returned name.</param>
        /// <returns>The sanitized, possibly truncated name.</returns>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="maxNameLength"/> is negative and the sanitized name is non-empty.
        /// </exception>
        public static string SanitizeName(string demoName, int maxNameLength)
        {
            var sanitized = s_InvalidNameCharacters.Replace(demoName ?? string.Empty, string.Empty);

            // If the string is too long, it will overflow the metadata.
            if (sanitized.Length > maxNameLength)
            {
                sanitized = sanitized.Substring(0, maxNameLength);
            }

            return sanitized;
        }

        /// <summary>
        /// Forwards the agent's <see cref="AgentInfo"/> and sensors to the demonstration
        /// store. Does nothing when no store is open.
        /// </summary>
        /// <param name="info">Agent information for the current step.</param>
        /// <param name="sensors">Sensors whose observations are recorded alongside it.</param>
        public void WriteExperience(AgentInfo info, List<ISensor> sensors)
        {
            m_DemoStore?.Record(info, sensors);
        }

        /// <summary>
        /// Closes the demonstration store if one is open; safe to call repeatedly.
        /// Note: while <see cref="record"/> is still true in the Editor, the next
        /// Update will initialize a new store.
        /// </summary>
        public void Close()
        {
            if (m_DemoStore != null)
            {
                m_DemoStore.Close();
                m_DemoStore = null;
            }
        }

        /// <summary>
        /// Closes the demonstration store on application quit. Unconditional (Close is a
        /// no-op without a store) so the store is still closed when <see cref="record"/>
        /// was switched off mid-session or <see cref="InitializeDemoStore"/> was called
        /// outside the Editor.
        /// </summary>
        void OnApplicationQuit()
        {
            Close();
        }

        /// <summary>
        /// Closes the demonstration store when this component is destroyed, not only
        /// when the application quits.
        /// </summary>
        void OnDestroy()
        {
            Close();
        }
    }
}