浏览代码

Merge branch 'develop-add-fire' into develop-add-fire-halfentropy

/develop/add-fire/halfentropy
Ervin Teng 4 年前
当前提交
b2872adf
共有 3 个文件被更改,包括 7 次插入5 次删除
  1. 4
      ml-agents/mlagents/trainers/tests/torch/test_decoders.py
  2. 2
      ml-agents/mlagents/trainers/tests/torch/test_distributions.py
  3. 6
      ml-agents/mlagents/trainers/tests/torch/test_utils.py

4
ml-agents/mlagents/trainers/tests/torch/test_decoders.py


# Test default 1 value per head
value_heads = ValueHeads(stream_names, input_size)
input_data = torch.ones((batch_size, input_size))
value_out, _ = value_heads(input_data) # Note: mean value will be removed shortly
value_out = value_heads(input_data) # Note: mean value will be removed shortly
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size,)

output_size = 4
value_heads = ValueHeads(stream_names, input_size, output_size)
input_data = torch.ones((batch_size, input_size))
value_out, _ = value_heads(input_data)
value_out = value_heads(input_data)
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size, output_size)

2
ml-agents/mlagents/trainers/tests/torch/test_distributions.py


def test_tanh_gaussian_dist_instance():
torch.manual_seed(0)
act_size = 4
dist_instance = GaussianDistInstance(
dist_instance = TanhGaussianDistInstance(
torch.zeros(1, act_size), torch.ones(1, act_size)
)
for _ in range(10):

6
ml-agents/mlagents/trainers/tests/torch/test_utils.py


for encoder_type in EncoderType:
good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type]
vis_input = torch.ones((1, 3, good_size, good_size))
ModelUtils._check_resolution_for_encoder(vis_input, encoder_type)
ModelUtils._check_resolution_for_encoder(good_size, good_size, encoder_type)
enc_func = ModelUtils.get_encoder_for_type(encoder_type)
enc = enc_func(good_size, good_size, 3, 1)
enc.forward(vis_input)

with pytest.raises(UnityTrainerException):
# Make sure we'd hit a friendly error during model setup time.
ModelUtils._check_resolution_for_encoder(vis_input, encoder_type)
ModelUtils._check_resolution_for_encoder(
bad_size, bad_size, encoder_type
)
enc = enc_func(bad_size, bad_size, 3, 1)
enc.forward(vis_input)

正在加载...
取消
保存