|
|
|
|
|
|
import uuid |
|
|
|
import pytest |
|
|
|
from mlagents_envs.side_channel.engine_configuration_channel import ( |
|
|
|
EngineConfigurationChannel, |
|
|
|
EngineConfig, |
|
|
|
) |
|
|
|
from mlagents_envs.side_channel.environment_parameters_channel import ( |
|
|
|
EnvironmentParametersChannel, |
|
|
|
) |
|
|
|
from mlagents_envs.side_channel.stats_side_channel import ( |
|
|
|
StatsSideChannel, |
|
|
|
StatsAggregationMethod, |
|
|
|
) |
|
|
|
from mlagents_envs.exception import ( |
|
|
|
UnitySideChannelException, |
|
|
|
UnityCommunicationException, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class IntChannel(SideChannel): |
|
|
|
|
|
|
# Test reading with defaults |
|
|
|
assert [] == msg_in.read_float32_list() |
|
|
|
assert val == msg_in.read_float32_list(default_value=val) |
|
|
|
|
|
|
|
|
|
|
|
def test_engine_configuration(): |
|
|
|
sender = EngineConfigurationChannel() |
|
|
|
# We use a raw bytes channel to interpred the data |
|
|
|
receiver = RawBytesChannel(sender.channel_id) |
|
|
|
|
|
|
|
config = EngineConfig.default_config() |
|
|
|
sender.set_configuration(config) |
|
|
|
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender}) |
|
|
|
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data) |
|
|
|
|
|
|
|
received_data = receiver.get_and_clear_received_messages() |
|
|
|
assert len(received_data) == 5 # 5 different messages one for each setting |
|
|
|
|
|
|
|
sent_time_scale = 4.5 |
|
|
|
sender.set_configuration_parameters(time_scale=sent_time_scale) |
|
|
|
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender}) |
|
|
|
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data) |
|
|
|
|
|
|
|
message = IncomingMessage(receiver.get_and_clear_received_messages()[0]) |
|
|
|
message.read_int32() |
|
|
|
time_scale = message.read_float32() |
|
|
|
assert time_scale == sent_time_scale |
|
|
|
|
|
|
|
with pytest.raises(UnitySideChannelException): |
|
|
|
sender.set_configuration_parameters(width=None, height=42) |
|
|
|
|
|
|
|
with pytest.raises(UnityCommunicationException): |
|
|
|
# try to send data to the EngineConfigurationChannel |
|
|
|
sender.set_configuration_parameters(time_scale=sent_time_scale) |
|
|
|
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender}) |
|
|
|
UnityEnvironment._parse_side_channel_message( |
|
|
|
{receiver.channel_id: sender}, data |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_environment_parameters(): |
|
|
|
sender = EnvironmentParametersChannel() |
|
|
|
# We use a raw bytes channel to interpred the data |
|
|
|
receiver = RawBytesChannel(sender.channel_id) |
|
|
|
|
|
|
|
sender.set_float_parameter("param-1", 0.1) |
|
|
|
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender}) |
|
|
|
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data) |
|
|
|
|
|
|
|
message = IncomingMessage(receiver.get_and_clear_received_messages()[0]) |
|
|
|
key = message.read_string() |
|
|
|
dtype = message.read_int32() |
|
|
|
value = message.read_float32() |
|
|
|
assert key == "param-1" |
|
|
|
assert dtype == EnvironmentParametersChannel.EnvironmentDataTypes.FLOAT |
|
|
|
assert value - 0.1 < 1e-8 |
|
|
|
|
|
|
|
sender.set_float_parameter("param-1", 0.1) |
|
|
|
sender.set_float_parameter("param-2", 0.1) |
|
|
|
sender.set_float_parameter("param-3", 0.1) |
|
|
|
|
|
|
|
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender}) |
|
|
|
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data) |
|
|
|
|
|
|
|
assert len(receiver.get_and_clear_received_messages()) == 3 |
|
|
|
|
|
|
|
with pytest.raises(UnityCommunicationException): |
|
|
|
# try to send data to the EngineConfigurationChannel |
|
|
|
sender.set_float_parameter("param-1", 0.1) |
|
|
|
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender}) |
|
|
|
UnityEnvironment._parse_side_channel_message( |
|
|
|
{receiver.channel_id: sender}, data |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_stats_channel(): |
|
|
|
receiver = StatsSideChannel() |
|
|
|
message = OutgoingMessage() |
|
|
|
message.write_string("stats-1") |
|
|
|
message.write_float32(42.0) |
|
|
|
message.write_int32(1) # corresponds to StatsAggregationMethod.MOST_RECENT |
|
|
|
|
|
|
|
receiver.on_message_received(IncomingMessage(message.buffer)) |
|
|
|
|
|
|
|
stats = receiver.get_and_reset_stats() |
|
|
|
|
|
|
|
assert len(stats) == 1 |
|
|
|
val, method = stats["stats-1"] |
|
|
|
assert val - 42.0 < 1e-8 |
|
|
|
assert method == StatsAggregationMethod.MOST_RECENT |