浏览代码

added some side channel tests (#3937)

/docs-update
GitHub 4 年前
当前提交
21f0029f
共有 1 个文件被更改,包括 103 次插入0 次删除
  1. 103
      ml-agents-envs/mlagents_envs/tests/test_side_channel.py

103
ml-agents-envs/mlagents_envs/tests/test_side_channel.py


import uuid
import pytest
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
EngineConfig,
)
from mlagents_envs.side_channel.environment_parameters_channel import (
EnvironmentParametersChannel,
)
from mlagents_envs.side_channel.stats_side_channel import (
StatsSideChannel,
StatsAggregationMethod,
)
from mlagents_envs.exception import (
UnitySideChannelException,
UnityCommunicationException,
)
class IntChannel(SideChannel):

# Test reading with defaults
assert [] == msg_in.read_float32_list()
assert val == msg_in.read_float32_list(default_value=val)
def test_engine_configuration():
sender = EngineConfigurationChannel()
# We use a raw bytes channel to interpred the data
receiver = RawBytesChannel(sender.channel_id)
config = EngineConfig.default_config()
sender.set_configuration(config)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
received_data = receiver.get_and_clear_received_messages()
assert len(received_data) == 5 # 5 different messages one for each setting
sent_time_scale = 4.5
sender.set_configuration_parameters(time_scale=sent_time_scale)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
message.read_int32()
time_scale = message.read_float32()
assert time_scale == sent_time_scale
with pytest.raises(UnitySideChannelException):
sender.set_configuration_parameters(width=None, height=42)
with pytest.raises(UnityCommunicationException):
# try to send data to the EngineConfigurationChannel
sender.set_configuration_parameters(time_scale=sent_time_scale)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_id: sender}, data
)
def test_environment_parameters():
sender = EnvironmentParametersChannel()
# We use a raw bytes channel to interpred the data
receiver = RawBytesChannel(sender.channel_id)
sender.set_float_parameter("param-1", 0.1)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
key = message.read_string()
dtype = message.read_int32()
value = message.read_float32()
assert key == "param-1"
assert dtype == EnvironmentParametersChannel.EnvironmentDataTypes.FLOAT
assert value - 0.1 < 1e-8
sender.set_float_parameter("param-1", 0.1)
sender.set_float_parameter("param-2", 0.1)
sender.set_float_parameter("param-3", 0.1)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
assert len(receiver.get_and_clear_received_messages()) == 3
with pytest.raises(UnityCommunicationException):
# try to send data to the EngineConfigurationChannel
sender.set_float_parameter("param-1", 0.1)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_id: sender}, data
)
def test_stats_channel():
receiver = StatsSideChannel()
message = OutgoingMessage()
message.write_string("stats-1")
message.write_float32(42.0)
message.write_int32(1) # corresponds to StatsAggregationMethod.MOST_RECENT
receiver.on_message_received(IncomingMessage(message.buffer))
stats = receiver.get_and_reset_stats()
assert len(stats) == 1
val, method = stats["stats-1"]
assert val - 42.0 < 1e-8
assert method == StatsAggregationMethod.MOST_RECENT
正在加载...
取消
保存