浏览代码

[MLA-825] Add default values for SideChannel IncomingMessages methods (#3751)

* C# SideChannels

* docstrings and python

* changelog

* Update com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs

Co-Authored-By: Chris Goy <goyenator@gmail.com>

* compile fix

* Update com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs

Co-Authored-By: Chris Goy <christopherg@unity3d.com>

Co-authored-by: Chris Goy <goyenator@gmail.com>
Co-authored-by: Chris Goy <christopherg@unity3d.com>
/develop/add-fire
GitHub 5 年前
当前提交
3249910d
共有 5 个文件被更改,包括 115 次插入14 次删除
  1. 1
      com.unity.ml-agents/CHANGELOG.md
  2. 44
      com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs
  3. 26
      com.unity.ml-agents/Tests/Editor/SideChannelTests.cs
  4. 38
      ml-agents-envs/mlagents_envs/side_channel/incoming_message.py
  5. 20
      ml-agents-envs/mlagents_envs/tests/test_side_channel.py

1
com.unity.ml-agents/CHANGELOG.md


- Format of console output has changed slightly and now matches the name of the model/summary directory. (#3630, #3616)
- Added a feature to allow sending stats from C# environments to TensorBoard (and other python StatsWriters). To do this from your code, use `SideChannelUtils.GetSideChannel<StatsSideChannel>().AddStat(key, value)` (#3660)
- Renamed 'Generalization' feature to 'Environment Parameter Randomization'.
- The read methods of SideChannel `IncomingMessage` now take an optional default value argument, which is returned when trying to read past the end of the message. (#3751)
- The way that UnityEnvironment decides the port was changed. If no port is specified, the behavior will depend on the `file_name` parameter. If it is `None`, 5004 (the editor port) will be used; otherwise 5005 (the base environment port) will be used.
- Fixed an issue where exceptions from environments provided a returncode of 0. (#3680)
- Running `mlagents-learn` with the same `--run-id` twice will no longer overwrite the existing files. (#3705)

44
com.unity.ml-agents/Runtime/SideChannels/IncomingMessage.cs


using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System;
using System.IO;
using System.Text;

}
/// <summary>
/// Read a boolan value from the message.
/// Read a boolean value from the message.
/// <param name="defaultValue">Default value to use if the end of the message is reached.</param>
public bool ReadBoolean()
public bool ReadBoolean(bool defaultValue = false)
return m_Reader.ReadBoolean();
return CanReadMore() ? m_Reader.ReadBoolean() : defaultValue;
/// <param name="defaultValue">Default value to use if the end of the message is reached.</param>
public int ReadInt32()
public int ReadInt32(int defaultValue = 0)
return m_Reader.ReadInt32();
return CanReadMore() ? m_Reader.ReadInt32() : defaultValue;
/// <param name="defaultValue">Default value to use if the end of the message is reached.</param>
public float ReadFloat32()
public float ReadFloat32(float defaultValue = 0.0f)
return m_Reader.ReadSingle();
return CanReadMore() ? m_Reader.ReadSingle() : defaultValue;
/// <param name="defaultValue">Default value to use if the end of the message is reached.</param>
public string ReadString()
public string ReadString(string defaultValue = default)
if (!CanReadMore())
{
return defaultValue;
}
var strLength = ReadInt32();
var str = Encoding.ASCII.GetString(m_Reader.ReadBytes(strLength));
return str;

/// Reads a list of floats from the message. The length of the list is stored in the message.
/// </summary>
/// <param name="defaultValue">Default value to use if the end of the message is reached.</param>
public IList<float> ReadFloatList()
public IList<float> ReadFloatList(IList<float> defaultValue = default)
if (!CanReadMore())
{
return defaultValue;
}
var len = ReadInt32();
var output = new float[len];
for (var i = 0; i < len; i++)

{
m_Reader?.Dispose();
m_Stream?.Dispose();
}
/// <summary>
/// Reports whether the underlying stream still has unread data.
/// </summary>
/// <returns>
/// True while the stream position is before the stream length; false once the end has been reached.
/// </returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
bool CanReadMore() => m_Stream.Position < m_Stream.Length;
}
}

26
com.unity.ml-agents/Tests/Editor/SideChannelTests.cs


Assert.AreEqual(stringVal, incomingMsg.ReadString());
Assert.AreEqual(floatListVal, incomingMsg.ReadFloatList());
}
[Test]
public void TestMessageReadDefaults()
{
    // Build an incoming message from an outgoing message that had nothing written to it,
    // so that every read below runs past the end of the message and must fall back to a default.
    IncomingMessage emptyIncoming;
    using (var emptyOutgoing = new OutgoingMessage())
    {
        emptyIncoming = new IncomingMessage(emptyOutgoing.ToByteArray());
    }

    // Each reader is checked twice: once with its implicit default, once with an explicit one.
    Assert.AreEqual(false, emptyIncoming.ReadBoolean());
    Assert.AreEqual(true, emptyIncoming.ReadBoolean(true));

    Assert.AreEqual(0, emptyIncoming.ReadInt32());
    Assert.AreEqual(42, emptyIncoming.ReadInt32(42));

    Assert.AreEqual(0.0f, emptyIncoming.ReadFloat32());
    Assert.AreEqual(1337.0f, emptyIncoming.ReadFloat32(1337.0f));

    Assert.AreEqual(default(string), emptyIncoming.ReadString());
    Assert.AreEqual("foo", emptyIncoming.ReadString("foo"));

    Assert.AreEqual(default(float[]), emptyIncoming.ReadFloatList());
    // Two distinct arrays on purpose: the comparison is by content, not by reference.
    var expectedFloats = new float[] { 1001, 1002 };
    var fallbackFloats = new float[] { 1001, 1002 };
    Assert.AreEqual(expectedFloats, emptyIncoming.ReadFloatList(fallbackFloats));
}
}
}

38
ml-agents-envs/mlagents_envs/side_channel/incoming_message.py


self.buffer = buffer
self.offset = offset
def read_bool(self) -> bool:
def read_bool(self, default_value: bool = False) -> bool:
:param default_value: Default value to use if the end of the message is reached.
:return: The value read from the message, or the default value if the end was reached.
if self._at_end_of_buffer():
return default_value
def read_int32(self) -> int:
def read_int32(self, default_value: int = 0) -> int:
:param default_value: Default value to use if the end of the message is reached.
:return: The value read from the message, or the default value if the end was reached.
if self._at_end_of_buffer():
return default_value
def read_float32(self) -> float:
def read_float32(self, default_value: float = 0.0) -> float:
:param default_value: Default value to use if the end of the message is reached.
:return: The value read from the message, or the default value if the end was reached.
if self._at_end_of_buffer():
return default_value
def read_float32_list(self) -> List[float]:
def read_float32_list(self, default_value: List[float] = None) -> List[float]:
:param default_value: Default value to use if the end of the message is reached.
:return: The value read from the message, or the default value if the end was reached.
if self._at_end_of_buffer():
return [] if default_value is None else default_value
list_len = self.read_int32()
output = []
for _ in range(list_len):

def read_string(self) -> str:
def read_string(self, default_value: str = "") -> str:
:param default_value: Default value to use if the end of the message is reached.
:return: The value read from the message, or the default value if the end was reached.
if self._at_end_of_buffer():
return default_value
encoded_str_len = self.read_int32()
val = self.buffer[self.offset : self.offset + encoded_str_len].decode("ascii")
self.offset += encoded_str_len

Get a copy of the internal bytes used by the message.
"""
return bytearray(self.buffer)
def _at_end_of_buffer(self) -> bool:
return self.offset >= len(self.buffer)

20
ml-agents-envs/mlagents_envs/tests/test_side_channel.py


read_vals.append(msg_in.read_bool())
assert vals == read_vals
# Test reading with defaults
assert msg_in.read_bool() is False
assert msg_in.read_bool(default_value=True) is True
def test_message_int32():
val = 1337

read_val = msg_in.read_int32()
assert val == read_val
# Test reading with defaults
assert 0 == msg_in.read_int32()
assert val == msg_in.read_int32(default_value=val)
def test_message_float32():
val = 42.0

# These won't be exactly equal in general, since python floats are 64-bit.
assert val == read_val
# Test reading with defaults
assert 0.0 == msg_in.read_float32()
assert val == msg_in.read_float32(default_value=val)
def test_message_string():
val = "mlagents!"

read_val = msg_in.read_string()
assert val == read_val
# Test reading with defaults
assert "" == msg_in.read_string()
assert val == msg_in.read_string(default_value=val)
def test_message_float_list():
val = [1.0, 3.0, 9.0]

read_val = msg_in.read_float32_list()
# These won't be exactly equal in general, since python floats are 64-bit.
assert val == read_val
# Test reading with defaults
assert [] == msg_in.read_float32_list()
assert val == msg_in.read_float32_list(default_value=val)
正在加载...
取消
保存