
Merge branch 'master' into develop-windows-delay

develop/windows-delay
Ruo-Ping Dong, 4 years ago
Commit 953cb6bb

12 changed files, with 45 insertions and 18 deletions
1. gym-unity/gym_unity/tests/test_gym.py (4 changed lines)
2. ml-agents/mlagents/trainers/policy/torch_policy.py (8 changed lines)
3. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changed lines)
4. ml-agents/mlagents/trainers/ppo/trainer.py (2 changed lines)
5. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py (2 changed lines)
6. ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (2 changed lines)
7. ml-agents/mlagents/trainers/tests/torch/test_utils.py (2 changed lines)
8. ml-agents/mlagents/trainers/torch/distributions.py (10 changed lines)
9. ml-agents/mlagents/trainers/torch/encoders.py (8 changed lines)
10. ml-agents/mlagents/trainers/torch/layers.py (12 changed lines)
11. ml-agents/mlagents/trainers/torch/model_serialization.py (6 changed lines)
12. ml-agents/mlagents/trainers/torch/networks.py (5 changed lines)

gym-unity/gym_unity/tests/test_gym.py (4 changed lines)


  assert isinstance(done, (bool, np.bool_))
  assert isinstance(info, dict)
- # check behaviour for allow_multiple_obs = False
+ # check behavior for allow_multiple_obs = False
  env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
  assert isinstance(env, UnityToGymWrapper)
  assert isinstance(env.observation_space, spaces.Box)

  assert isinstance(done, (bool, np.bool_))
  assert isinstance(info, dict)
- # check behaviour for allow_multiple_obs = False
+ # check behavior for allow_multiple_obs = False
  env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
  assert isinstance(env, UnityToGymWrapper)
  assert isinstance(env.observation_space, spaces.Box)

ml-agents/mlagents/trainers/policy/torch_policy.py (8 changed lines)


      conditional_sigma=self.condition_sigma_on_obs,
      tanh_squash=tanh_squash,
  )
+ self._clip_action = not tanh_squash
  # Save the m_size needed for export
  self._export_m_size = self.m_size
  # m_size needed for training is determined by network, not trainer settings

  action, log_probs, entropy, memories = self.sample_actions(
      vec_obs, vis_obs, masks=masks, memories=memories
  )
- run_out["action"] = ModelUtils.to_numpy(action)
+ if self._clip_action and self.use_continuous_act:
+     clipped_action = torch.clamp(action, -3, 3) / 3
+ else:
+     clipped_action = action
+ run_out["action"] = ModelUtils.to_numpy(clipped_action)
  # Todo - make pre_action difference
  run_out["log_probs"] = ModelUtils.to_numpy(log_probs)
  run_out["entropy"] = ModelUtils.to_numpy(entropy)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changed lines)


  vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
  act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
  if self.policy.use_continuous_act:
-     actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
+     actions = ModelUtils.list_to_tensor(batch["actions_pre"]).unsqueeze(-1)
  else:
      actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
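
Because the clipped action is what gets stored as "actions", the PPO update now reads the pre-clipped sample ("actions_pre") for continuous control: the log-probability in the importance ratio must be evaluated on the value that was actually drawn from the Gaussian. A hedged illustration of that distinction in plain torch (not the ml-agents buffer API):

```python
import torch
from torch.distributions import Normal

dist = Normal(loc=torch.zeros(3), scale=torch.ones(3))
raw_action = dist.sample()                          # what the policy sampled ("actions_pre")
env_action = torch.clamp(raw_action, -3, 3) / 3     # what the environment received ("actions")

# The PPO ratio pi_new(a|s) / pi_old(a|s) is only well defined for the raw sample;
# evaluating log_prob on the clipped, rescaled action would compare different variables.
print(dist.log_prob(raw_action))
```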

ml-agents/mlagents/trainers/ppo/trainer.py (2 changed lines)


      behavior_spec,
      self.trainer_settings,
      condition_sigma_on_obs=False,  # Faster training for PPO
-     separate_critic=behavior_spec.action_spec.is_continuous(),
+     separate_critic=True,  # Match network architecture with TF
  )
  return policy
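
`separate_critic=True` gives the value function its own encoder instead of sharing the actor's trunk, and it now does so for discrete as well as continuous actions to match the TF architecture. A toy sketch of the two layouts (illustrative classes, not ml-agents' actual actor-critic modules):

```python
import torch
from torch import nn

class ToySharedAC(nn.Module):
    """One encoder feeds both the policy head and the value head."""
    def __init__(self, obs_size: int, act_size: int, hidden: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_size, hidden), nn.ReLU())
        self.policy_head = nn.Linear(hidden, act_size)
        self.value_head = nn.Linear(hidden, 1)

    def forward(self, obs):
        h = self.encoder(obs)
        return self.policy_head(h), self.value_head(h)

class ToySeparateAC(nn.Module):
    """The critic owns an independent encoder, so value updates do not disturb the actor's features."""
    def __init__(self, obs_size: int, act_size: int, hidden: int = 64):
        super().__init__()
        self.actor = nn.Sequential(nn.Linear(obs_size, hidden), nn.ReLU(), nn.Linear(hidden, act_size))
        self.critic = nn.Sequential(nn.Linear(obs_size, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, obs):
        return self.actor(obs), self.critic(obs)

obs = torch.zeros(4, 8)
logits, value = ToySeparateAC(8, 2)(obs)
print(logits.shape, value.shape)   # torch.Size([4, 2]) torch.Size([4, 1])
```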

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py (2 changed lines)


  buffer = create_agent_buffer(behavior_spec, 5)
  curiosity_rp.update(buffer)
  reward_old = curiosity_rp.evaluate(buffer)[0]
- for _ in range(10):
+ for _ in range(20):
      curiosity_rp.update(buffer)
  reward_new = curiosity_rp.evaluate(buffer)[0]
  assert reward_new < reward_old

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (2 changed lines)


      PPO_TORCH_CONFIG,
      hyperparameters=new_hyperparams,
      network_settings=new_network_settings,
-     max_steps=5000,
+     max_steps=6000,
  )
  check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)

ml-agents/mlagents/trainers/tests/torch/test_utils.py (2 changed lines)


      action_list, dist_list
  )
  assert log_probs.shape == (1, 2, 2)
- assert entropies.shape == (1, 2, 2)
+ assert entropies.shape == (1, 1, 2)
  assert all_probs is None
  for log_prob in log_probs.flatten():
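
The expected entropy shape drops from (1, 2, 2) to (1, 1, 2) because the Gaussian entropy is now averaged over the action dimension with keepdim=True (see the distributions.py change below). A quick shape check of that keepdim reduction in plain torch:

```python
import torch

per_dim_entropy = torch.rand(1, 2, 2)                        # one value per action dimension
reduced = torch.mean(per_dim_entropy, dim=1, keepdim=True)   # average over dim 1, keep it as size 1
print(reduced.shape)                                         # torch.Size([1, 1, 2])
```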

ml-agents/mlagents/trainers/torch/distributions.py (10 changed lines)


      return torch.exp(log_prob)

  def entropy(self):
-     return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON)
+     return torch.mean(
+         0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON),
+         dim=1,
+         keepdim=True,
+     )  # Use equivalent behavior to TF

  class TanhGaussianDistInstance(GaussianDistInstance):

      hidden_size,
      num_outputs,
      kernel_init=Initialization.KaimingHeNormal,
-     kernel_gain=0.1,
+     kernel_gain=0.2,
      bias_init=Initialization.Zero,
  )
  self.tanh_squash = tanh_squash

      num_outputs,
      kernel_init=Initialization.KaimingHeNormal,
-     kernel_gain=0.1,
+     kernel_gain=0.2,
      bias_init=Initialization.Zero,
  )
  else:
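
For reference, the differential entropy of a Gaussian N(mu, sigma^2) is 0.5 * ln(2 * pi * e * sigma^2). The diff keeps the repository's existing std-based expression (the comment notes it mirrors the TF trainers) and now averages it over the action dimensions, so each distribution instance reports a single entropy per batch entry. A small sketch of that reduction in plain torch, outside the ml-agents classes:

```python
import math
import torch

EPSILON = 1e-7
std = torch.tensor([[0.5, 1.0, 2.0]])   # shape (batch=1, act_size=3)

per_dim = 0.5 * torch.log(2 * math.pi * math.e * std + EPSILON)   # same expression as the diff
entropy = torch.mean(per_dim, dim=1, keepdim=True)                # one value per batch entry
print(per_dim.shape, entropy.shape)                               # torch.Size([1, 3]) torch.Size([1, 1])
```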

ml-agents/mlagents/trainers/torch/encoders.py (8 changed lines)


      self.final_flat,
      self.h_size,
      kernel_init=Initialization.KaimingHeNormal,
-     kernel_gain=1.0,
+     kernel_gain=1.41,  # Use ReLU gain
  ),
  nn.LeakyReLU(),
  )

      self.final_flat,
      self.h_size,
      kernel_init=Initialization.KaimingHeNormal,
-     kernel_gain=1.0,
+     kernel_gain=1.41,  # Use ReLU gain
  ),
  nn.LeakyReLU(),
  )

      self.final_flat,
      self.h_size,
      kernel_init=Initialization.KaimingHeNormal,
-     kernel_gain=1.0,
+     kernel_gain=1.41,  # Use ReLU gain
  ),
  nn.LeakyReLU(),
  )

      n_channels[-1] * height * width,
      output_size,
      kernel_init=Initialization.KaimingHeNormal,
-     kernel_gain=1.0,
+     kernel_gain=1.41,  # Use ReLU gain
  )
  self.sequential = nn.Sequential(*layers)
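
1.41 is approximately sqrt(2), the He initialization gain recommended for ReLU-family activations, so the dense layers feeding LeakyReLU keep roughly unit variance at the start of training. A hedged illustration of how the gain multiplier combines with a Kaiming "linear" init (plain torch, not the ml-agents linear_layer helper):

```python
import math
import torch
from torch import nn

fan_in, fan_out = 256, 128
layer = nn.Linear(fan_in, fan_out)

# kaiming_normal_ with nonlinearity="linear" draws weights with std = 1/sqrt(fan_in);
# multiplying by 1.41 ~ sqrt(2) reproduces the usual ReLU gain, std ~ sqrt(2/fan_in).
nn.init.kaiming_normal_(layer.weight, nonlinearity="linear")
layer.weight.data *= 1.41

print(layer.weight.std().item())   # ~ 0.088
print(math.sqrt(2 / fan_in))       # 0.0883...
```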

ml-agents/mlagents/trainers/torch/layers.py (12 changed lines)


  :param output_size: The size of the output tensor
  :param kernel_init: The Initialization to use for the weights of the layer
  :param kernel_gain: The multiplier for the weights of the kernel. Note that in
- TensorFlow, calling variance_scaling with scale 0.01 is equivalent to calling
- KaimingHeNormal with kernel_gain of 0.1
+ TensorFlow, the gain is square-rooted. Therefore calling with scale 0.01 is equivalent to calling
+ KaimingHeNormal with kernel_gain of 0.1

- _init_methods[kernel_init](layer.weight.data)
+ if (
+     kernel_init == Initialization.KaimingHeNormal
+     or kernel_init == Initialization.KaimingHeUniform
+ ):
+     _init_methods[kernel_init](layer.weight.data, nonlinearity="linear")
+ else:
+     _init_methods[kernel_init](layer.weight.data)
  layer.weight.data *= kernel_gain
  _init_methods[bias_init](layer.bias.data)
  return layer
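
The rewritten docstring points out that TF's variance_scaling scale is effectively square-rooted when mapped to a Kaiming gain: variance_scaling(scale=s, mode="fan_in") targets variance s/fan_in, while Kaiming-normal with nonlinearity="linear" gives std 1/sqrt(fan_in), so multiplying by kernel_gain = sqrt(s) matches it (scale 0.01 corresponds to gain 0.1). A quick numerical check of that correspondence in plain torch:

```python
import torch
from torch import nn

fan_in = 10_000              # large fan-in so the empirical std is stable
scale = 0.01                 # TF variance_scaling scale
kernel_gain = scale ** 0.5   # 0.1, the equivalent gain per the docstring

w = torch.empty(64, fan_in)
nn.init.kaiming_normal_(w, nonlinearity="linear")   # std = 1/sqrt(fan_in)
w *= kernel_gain                                    # std = sqrt(scale / fan_in)

print(w.std().item())              # ~ 0.001
print((scale / fan_in) ** 0.5)     # 0.001, what variance_scaling(scale=0.01) targets
```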

ml-agents/mlagents/trainers/torch/model_serialization.py (6 changed lines)


+ This implementation is thread safe.
  """
  # local is_exporting flag for each thread
+ # global lock shared among all threads, to make sure only one thread is exporting at a time
+ _lock = threading.Lock()
+ self._lock.acquire()
+ self._lock.release()
  @staticmethod
  def is_exporting():
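
Only fragments of the class are visible here (a thread-local flag comment, a module-level lock, and the acquire/release calls), so the following is a hedged sketch of the overall pattern rather than the actual ml-agents implementation: a context manager that sets a per-thread exporting flag while a single shared lock ensures only one thread exports at a time.

```python
import threading

class exporting_guard:
    """Per-thread is_exporting flag plus one global lock (illustrative names)."""

    # local is_exporting flag for each thread
    _local_data = threading.local()
    # global lock shared among all threads, so only one thread exports at a time
    _lock = threading.Lock()

    def __enter__(self):
        self._lock.acquire()
        exporting_guard._local_data.is_exporting = True

    def __exit__(self, *args):
        exporting_guard._local_data.is_exporting = False
        self._lock.release()

    @staticmethod
    def is_exporting() -> bool:
        return getattr(exporting_guard._local_data, "is_exporting", False)

with exporting_guard():
    assert exporting_guard.is_exporting()
assert not exporting_guard.is_exporting()
```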

ml-agents/mlagents/trainers/torch/networks.py (5 changed lines)


  self.distribution = MultiCategoricalDistribution(
      self.encoding_size, self.action_spec.discrete_branches
  )
+ # During training, clipping is done in TorchPolicy, but we need to clip before ONNX
+ # export as well.
+ self._clip_action_on_export = not tanh_squash

  @property
  def memory_size(self) -> int:

  if self.action_spec.is_continuous():
      action_list = self.sample_action(dists)
      action_out = torch.stack(action_list, dim=-1)
+     if self._clip_action_on_export:
+         action_out = torch.clamp(action_out, -3, 3) / 3
  else:
      action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
  return (
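
Training-time clipping happens in TorchPolicy, but an exported ONNX/Barracuda model runs without that Python wrapper, so the clamp has to live inside the exported graph whenever tanh squashing is off. A minimal hedged sketch of exporting a module whose forward already contains the clamp (generic torch.onnx usage; the module here is illustrative, not the ml-agents actor):

```python
import torch
from torch import nn

class TinyActor(nn.Module):
    def __init__(self, clip_action_on_export: bool = True):
        super().__init__()
        self.mu = nn.Linear(8, 2)
        self._clip_action_on_export = clip_action_on_export

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        action = self.mu(obs)
        if self._clip_action_on_export:
            # Baked into the graph: the exported model emits actions in [-1, 1]
            # with no Python post-processing required at inference time.
            action = torch.clamp(action, -3, 3) / 3
        return action

model = TinyActor()
dummy_obs = torch.zeros(1, 8)
torch.onnx.export(model, dummy_obs, "tiny_actor.onnx",
                  input_names=["obs"], output_names=["action"])
```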
