浏览代码

Reset StackingSensor when the Agent resets (#3727)

* sensors.Reset() WIP

* fix test implementations

* call reset from Agent
/develop/add-fire
GitHub 5 年前
当前提交
89237f96
共有 13 个文件被更改，包括 76 次插入和 0 次删除
  1. 3
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs
  2. 9
      com.unity.ml-agents/Runtime/Agent.cs
  3. 3
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  4. 6
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  5. 3
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  6. 3
      com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs
  7. 13
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  8. 6
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  9. 9
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  10. 1
      com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs
  11. 1
      com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs
  12. 1
      com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs
  13. 18
      com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs

3
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/SensorBase.cs


public void Update() {}
/// <inheritdoc/>
public void Reset() { } // No-op: this implementation does nothing when reset.
/// <inheritdoc/>
public virtual byte[] GetCompressedObservation()
{
return null;

9
com.unity.ml-agents/Runtime/Agent.cs


// Request the last decision with no callbacks
// We request a decision so Python knows the Agent is done immediately
m_Brain?.RequestDecision(m_Info, sensors);
ResetSensors();
// We also have to write to any DemonstrationWriters so that they get the "done" flag.
foreach (var demoWriter in DemonstrationWriters)

foreach (var sensor in sensors)
{
sensor.Update();
}
}
/// <summary>
/// Calls <c>Reset()</c> on each sensor held in <c>sensors</c>, in enumeration order.
/// </summary>
void ResetSensors()
{
    foreach (var currentSensor in sensors)
    {
        // Each sensor decides for itself what, if anything, needs clearing.
        currentSensor.Reset();
    }
}

3
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


public void Update() {}
/// <inheritdoc/>
public void Reset() { } // No-op: this implementation does nothing when reset.
/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return m_CompressionType;

6
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


void Update();
/// <summary>
/// Resets the internal state of the sensor. This is called at the end of an Agent's episode.
/// Most implementations can leave this empty.
/// </summary>
void Reset();
/// <summary>
/// Return the compression type being used. If no compression is used, return
/// <see cref="SensorCompressionType.None"/>.
/// </summary>

3
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


}
/// <inheritdoc/>
public void Reset() { } // No-op: this implementation does nothing when reset.
/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;

3
com.unity.ml-agents/Runtime/Sensors/RenderTextureSensor.cs


public void Update() {}
/// <inheritdoc/>
public void Reset() { } // No-op: this implementation does nothing when reset.
/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return m_CompressionType;

13
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


using System;
namespace MLAgents.Sensors
{
/// <summary>

{
m_WrappedSensor.Update();
m_CurrentIndex = (m_CurrentIndex + 1) % m_NumStackedObservations;
}
/// <inheritdoc/>
public void Reset()
{
    // Forward the reset to the wrapped sensor before wiping the stacked history.
    m_WrappedSensor.Reset();

    // Overwrite every stacked observation buffer with zeros.
    for (var stackIndex = 0; stackIndex < m_NumStackedObservations; ++stackIndex)
    {
        var stackedBuffer = m_StackedObservations[stackIndex];
        Array.Clear(stackedBuffer, 0, stackedBuffer.Length);
    }
}
/// <inheritdoc/>

6
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


}
/// <inheritdoc/>
// Resetting this sensor simply delegates to Clear().
public void Reset() => Clear();
/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;

9
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


public string sensorName;
public int numWriteCalls;
public int numCompressedCalls;
public int numResetCalls;
public SensorCompressionType compressionType = SensorCompressionType.None;
public TestSensor(string n)

}
public void Update() {}
public void Reset()
{
    // Count each call so tests can assert how many times the sensor was reset.
    numResetCalls += 1;
}
}
[TestFixture]

var expectedAgentActionForEpisode = 0;
var expectedCollectObsCalls = 0;
var expectedCollectObsCallsForEpisode = 0;
var expectedSensorResetCalls = 0;
for (var i = 0; i < 15; i++)
{

expectedAgentActionForEpisode = 0;
expectedCollectObsCallsForEpisode = 0;
expectedAgentStepCount = 0;
expectedSensorResetCalls++;
}
aca.EnvironmentStep();

Assert.AreEqual(expectedAgentActionForEpisode, agent1.agentActionCallsForEpisode);
Assert.AreEqual(expectedCollectObsCalls, agent1.collectObservationsCalls);
Assert.AreEqual(expectedCollectObsCallsForEpisode, agent1.collectObservationsCallsForEpisode);
Assert.AreEqual(expectedSensorResetCalls, agent1.sensor1.numResetCalls);
}
}

1
com.unity.ml-agents/Tests/Editor/ParameterLoaderTest.cs


}
public void Update() {}
public void Reset() { } // No-op: this test sensor does nothing when reset.
public SensorCompressionType GetCompressionType()
{

1
com.unity.ml-agents/Tests/Editor/Sensor/FloatVisualSensorTests.cs


}
public void Update() {}
public void Reset() { } // No-op: this test sensor does nothing when reset.
public SensorCompressionType GetCompressionType()
{

1
com.unity.ml-agents/Tests/Editor/Sensor/SensorShapeValidatorTests.cs


}
public void Update() { }
public void Reset() { } // No-op: this test sensor does nothing when reset.
public SensorCompressionType GetCompressionType()
{

18
com.unity.ml-agents/Tests/Editor/Sensor/StackingSensorTests.cs


// Check that if we don't call Update(), the same observations are produced
SensorTestHelper.CompareObservation(sensor, new[] {5f, 6f, 7f, 8f, 9f, 10f});
}
/// <summary>
/// Verifies that calling Reset() on a StackingSensor discards previously stacked
/// observations, so only the observation added after the reset is non-zero.
/// </summary>
[Test]
public void TestStackingReset()
{
    // A VectorSensor constructed with 2, wrapped in a StackingSensor constructed with 3.
    var wrapped = new VectorSensor(2);
    ISensor sensor = new StackingSensor(wrapped, 3);

    // First observation: the earlier slots are still zero.
    wrapped.AddObservation(new[] {1f, 2f});
    var expectedAfterFirst = new[] {0f, 0f, 0f, 0f, 1f, 2f};
    SensorTestHelper.CompareObservation(sensor, expectedAfterFirst);

    // After Update(), the second observation is stacked behind the first.
    sensor.Update();
    wrapped.AddObservation(new[] {3f, 4f});
    var expectedAfterSecond = new[] {0f, 0f, 1f, 2f, 3f, 4f};
    SensorTestHelper.CompareObservation(sensor, expectedAfterSecond);

    // After Reset(), the earlier observations are gone and only the new one remains.
    sensor.Reset();
    wrapped.AddObservation(new[] {5f, 6f});
    var expectedAfterReset = new[] {0f, 0f, 0f, 0f, 5f, 6f};
    SensorTestHelper.CompareObservation(sensor, expectedAfterReset);
}
}
}
正在加载...
取消
保存