
fixing the tests

/develop/rm-rf-new-models
vincentpierre, 4 years ago
Current commit
93ca1409
17 changed files with 327 additions and 191 deletions
  1. ml-agents/mlagents/trainers/demo_loader.py (2 changes)
  2. ml-agents/mlagents/trainers/policy/torch_policy.py (4 changes)
  3. ml-agents/mlagents/trainers/ppo/trainer.py (1 change)
  4. ml-agents/mlagents/trainers/tests/test_trajectory.py (31 changes)
  5. ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (10 changes)
  6. ml-agents/mlagents/trainers/tests/torch/test_encoders.py (2 changes)
  7. ml-agents/mlagents/trainers/tests/torch/test_networks.py (15 changes)
  8. ml-agents/mlagents/trainers/tests/torch/test_policy.py (21 changes)
  9. ml-agents/mlagents/trainers/tests/torch/test_utils.py (18 changes)
  10. ml-agents/mlagents/trainers/torch/components/bc/module.py (27 changes)
  11. ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (39 changes)
  12. ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (72 changes)
  13. ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (21 changes)
  14. ml-agents/mlagents/trainers/torch/networks.py (18 changes)
  15. ml-agents/mlagents/trainers/torch/utils.py (17 changes)
  16. ml-agents/mlagents/trainers/trajectory.py (5 changes)
  17. ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider copy.py (215 changes)

ml-agents/mlagents/trainers/demo_loader.py (2 changes)


demo_raw_buffer["done"].append(next_done)
demo_raw_buffer["rewards"].append(next_reward)
for i, obs in enumerate(current_obs):
demo_raw_buffer[ObsUtil.get_name_at(i)].append(obs[i])
demo_raw_buffer[ObsUtil.get_name_at(i)].append(obs)
demo_raw_buffer["actions"].append(current_pair_info.action_info.vector_actions)
demo_raw_buffer["prev_action"].append(previous_action)
if next_done:

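For context on the change above: instead of splitting demonstrations into vector and visual columns, each observation of a step is now appended under its own indexed buffer key. A minimal sketch of that pattern, with a plain dict-of-lists standing in for AgentBuffer and a hypothetical get_name_at helper (the updated tests below suggest keys of the form "obs_0", "obs_1", ...):

# Illustrative sketch, not part of the commit.
from collections import defaultdict
from typing import List

import numpy as np


def get_name_at(index: int) -> str:
    # Hypothetical stand-in for ObsUtil.get_name_at.
    return f"obs_{index}"


# Plain dict-of-lists standing in for AgentBuffer.
demo_raw_buffer = defaultdict(list)

current_obs: List[np.ndarray] = [
    np.zeros((8,), dtype=np.float32),          # a vector observation
    np.zeros((84, 84, 3), dtype=np.float32),   # a visual observation
]

# Every observation is appended under its own indexed key, regardless of rank.
for i, obs in enumerate(current_obs):
    demo_raw_buffer[get_name_at(i)].append(obs)

assert len(demo_raw_buffer["obs_0"]) == 1
assert demo_raw_buffer["obs_1"][0].shape == (84, 84, 3)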
ml-agents/mlagents/trainers/policy/torch_policy.py (4 changes)


entropies, and output memories, all as Torch Tensors.
"""
if memories is None:
dists, memories = self.actor_critic.get_dists(
obs, masks, memories, seq_len
)
dists, memories = self.actor_critic.get_dists(obs, masks, memories, seq_len)
else:
# If we're using LSTM, we need to execute the values to get the critic memories
dists, _, memories = self.actor_critic.get_dist_and_value(

ml-agents/mlagents/trainers/ppo/trainer.py (1 change)


from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, PPOSettings
from mlagents.trainers.buffer import AgentBuffer
logger = get_logger(__name__)

ml-agents/mlagents/trainers/tests/test_trajectory.py (31 changes)


import numpy as np
import pytest
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.tests.mock_brain import make_fake_trajectory
from mlagents_envs.base_env import ActionSpec

@pytest.mark.parametrize("num_visual_obs", [0, 1, 2])
@pytest.mark.parametrize("num_vec_obs", [0, 1])
def test_split_obs(num_visual_obs, num_vec_obs):
obs = []
for _ in range(num_visual_obs):
obs.append(np.ones((84, 84, 3), dtype=np.float32))
for _ in range(num_vec_obs):
obs.append(np.ones(VEC_OBS_SIZE, dtype=np.float32))
split_observations = SplitObservations.from_observations(obs)
if num_vec_obs == 1:
assert len(split_observations.vector_observations) == VEC_OBS_SIZE
else:
assert len(split_observations.vector_observations) == 0
# Assert the number of vector observations.
assert len(split_observations.visual_observations) == num_visual_obs
"next_visual_obs0",
"visual_obs0",
"vector_obs",
"next_vector_in",
"next_obs_0",
"next_obs_1",
"obs_0",
"obs_1",
"memory",
"masks",
"done",

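The test above keeps SplitObservations' rank-based split while the buffer key names move from "vector_obs"/"visual_obs0"/"next_vector_in" to the indexed "obs_0"/"obs_1"/"next_obs_0" form. A rough, self-contained approximation of the rank-based split being tested (not the real class):

# Illustrative sketch, not part of the commit.
from typing import List, NamedTuple

import numpy as np


class Split(NamedTuple):
    vector_observations: np.ndarray
    visual_observations: List[np.ndarray]


def split_observations(observations: List[np.ndarray]) -> Split:
    # Concatenate rank-1 arrays into one vector block; keep rank-3 arrays as visual obs.
    vec = [o for o in observations if o.ndim == 1]
    vis = [o for o in observations if o.ndim == 3]
    vector = np.concatenate(vec) if vec else np.zeros(0, dtype=np.float32)
    return Split(vector_observations=vector, visual_observations=vis)


obs = [np.ones((84, 84, 3), dtype=np.float32), np.ones(8, dtype=np.float32)]
split = split_observations(obs)
assert len(split.vector_observations) == 8
assert len(split.visual_observations) == 1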
ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (10 changes)


from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.torch.test_policy import create_policy_mock
from mlagents.trainers.torch.utils import ModelUtils
def test_register(tmp_path):

decision_step, _ = mb.create_steps_from_behavior_spec(
policy1.behavior_spec, num_agents=1
)
vec_vis_obs, masks = policy1._split_decision_step(decision_step)
vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)]
vis_obs = [torch.as_tensor(vis_ob) for vis_ob in vec_vis_obs.visual_observations]
obs, masks = policy1._split_decision_step(decision_step)
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
obs, masks=masks, memories=memories, all_log_probs=True
vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
obs, masks=masks, memories=memories, all_log_probs=True
)
np.testing.assert_array_equal(log_probs1, log_probs2)

ml-agents/mlagents/trainers/tests/torch/test_encoders.py (2 changes)


num_outputs = 128
enc = vis_class(image_size[0], image_size[1], image_size[2], num_outputs)
# Note: NCHW not NHWC
sample_input = torch.ones((1, image_size[2], image_size[0], image_size[1]))
sample_input = torch.ones((1, image_size[0], image_size[1], image_size[2]))
encoding = enc(sample_input)
assert encoding.shape == (1, num_outputs)

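The updated encoder test feeds NHWC-shaped input (batch, height, width, channels) rather than NCHW. A toy encoder sketch showing what that implies if the implementation builds on torch.nn.Conv2d, which expects NCHW; the internal permute is an assumption about how the real mlagents encoders could handle the layout:

# Illustrative sketch, not part of the commit.
import torch
from torch import nn


class TinyVisualEncoder(nn.Module):
    # Toy visual encoder that accepts NHWC input, like the updated test expects.
    def __init__(self, height: int, width: int, channels: int, output_size: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, 16, kernel_size=3, stride=2)
        conv_h = (height - 3) // 2 + 1
        conv_w = (width - 3) // 2 + 1
        self.dense = nn.Linear(16 * conv_h * conv_w, output_size)

    def forward(self, nhwc: torch.Tensor) -> torch.Tensor:
        nchw = nhwc.permute(0, 3, 1, 2)  # NHWC -> NCHW, since Conv2d expects NCHW
        hidden = torch.relu(self.conv(nchw))
        return self.dense(hidden.reshape(hidden.shape[0], -1))


enc = TinyVisualEncoder(36, 36, 3, 128)
sample_input = torch.ones((1, 36, 36, 3))  # NHWC, matching the new test input
assert enc(sample_input).shape == (1, 128)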
ml-agents/mlagents/trainers/tests/torch/test_networks.py (15 changes)


sample_act = 0.1 * torch.ones((1, 2))
for _ in range(300):
encoded, _ = networkbody([sample_obs], [], sample_act)
encoded, _ = networkbody([sample_obs], sample_act)
assert encoded.shape == (1, network_settings.hidden_units)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))

sample_obs = torch.ones((1, seq_len, obs_size))
for _ in range(200):
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 12))
encoded, _ = networkbody([sample_obs], memories=torch.ones(1, seq_len, 12))
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()

optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
sample_obs = 0.1 * torch.ones((1, 84, 84, 3))
sample_vec_obs = torch.ones((1, vec_obs_size))
obs = [sample_vec_obs] + [sample_obs]
encoded, _ = networkbody([sample_vec_obs], [sample_obs])
encoded, _ = networkbody(obs)
assert encoded.shape == (1, network_settings.hidden_units)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))

for _ in range(50):
sample_obs = torch.ones((1, obs_size))
values, _ = value_net([sample_obs], [])
values, _ = value_net([sample_obs])
loss = 0
for s_name in stream_names:
assert values[s_name].shape == (1, num_outputs)

actor = SimpleActor(obs_shapes, network_settings, action_spec)
# Test get_dist
sample_obs = torch.ones((1, obs_size))
dists, _ = actor.get_dists([sample_obs], [], masks=masks)
dists, _ = actor.get_dists([sample_obs], masks=masks)
for dist in dists:
if use_discrete:
assert isinstance(dist, CategoricalDistInstance)

# memories isn't always set to None, the network should be able to
# deal with that.
# Test critic pass
value_out, memories_out = actor.critic_pass([sample_obs], [], memories=memories)
value_out, memories_out = actor.critic_pass([sample_obs], memories=memories)
for stream in stream_names:
if lstm:
assert value_out[stream].shape == (network_settings.memory.sequence_length,)

# Test get_dist_and_value
dists, value_out, mem_out = actor.get_dist_and_value(
[sample_obs], [], memories=memories
[sample_obs], memories=memories
)
if mem_out is not None:
assert mem_out.shape == memories.shape

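These tests now pass a single list of observation tensors to the network body instead of separate vector and visual lists. A minimal stand-in (not the real NetworkBody) that mirrors the new calling convention:

# Illustrative sketch, not part of the commit.
from typing import List, Optional

import torch
from torch import nn


class ToyNetworkBody(nn.Module):
    # Encodes every observation in a single list and concatenates the embeddings.
    def __init__(self, obs_sizes: List[int], hidden_units: int = 128):
        super().__init__()
        self.encoders = nn.ModuleList([nn.Linear(size, hidden_units) for size in obs_sizes])
        self.body = nn.Linear(hidden_units * len(obs_sizes), hidden_units)

    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ):
        encoded = torch.cat([enc(x) for enc, x in zip(self.encoders, inputs)], dim=1)
        return torch.relu(self.body(encoded)), memories


networkbody = ToyNetworkBody([12, 6])
sample_obs = [torch.ones((1, 12)), torch.ones((1, 6))]
encoded, _ = networkbody(sample_obs)  # one observation list, no separate vis_inputs argument
assert encoded.shape == (1, 128)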
ml-agents/mlagents/trainers/tests/torch/test_policy.py (21 changes)


from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.settings import TrainerSettings, NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil
VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8

TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)
vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
vis_obs = []
for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors):
vis_ob = ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
vis_obs.append(vis_ob)
obs = ObsUtil.from_buffer(buffer, len(policy.behavior_spec.observation_shapes))
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
memories = [
ModelUtils.list_to_tensor(buffer["memory"][i])

memories = torch.stack(memories).unsqueeze(0)
log_probs, entropy, values = policy.evaluate_actions(
vec_obs,
vis_obs,
obs,
masks=act_masks,
actions=actions,
memories=memories,

TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)
vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
vis_obs = []
for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors):
vis_ob = ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
vis_obs.append(vis_ob)
obs = ObsUtil.from_buffer(buffer, len(policy.behavior_spec.observation_shapes))
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
memories = [
ModelUtils.list_to_tensor(buffer["memory"][i])

entropies,
memories,
) = policy.sample_actions(
vec_obs,
vis_obs,
obs,
masks=act_masks,
memories=memories,
seq_len=policy.sequence_length,

ml-agents/mlagents/trainers/tests/torch/test_utils.py (18 changes)


for encoder_type in EncoderType:
good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type]
vis_input = torch.ones((1, 3, good_size, good_size))
vis_input = torch.ones((1, good_size, good_size, 3))
ModelUtils._check_resolution_for_encoder(good_size, good_size, encoder_type)
enc_func = ModelUtils.get_encoder_for_type(encoder_type)
enc = enc_func(good_size, good_size, 3, 1)

with pytest.raises(Exception):
bad_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1
vis_input = torch.ones((1, 3, bad_size, bad_size))
vis_input = torch.ones((1, bad_size, bad_size, 3))
with pytest.raises(UnityTrainerException):
# Make sure we'd hit a friendly error during model setup time.

for _ in range(num_visual):
obs_shapes.append(vis_obs_shape)
h_size = 128
vis_enc, vec_enc, total_output = ModelUtils.create_input_processors(
encoders, embedding_sizes = ModelUtils.create_input_processors(
vec_enc = list(vec_enc)
vis_enc = list(vis_enc)
assert len(vec_enc) == (1 if num_vector >= 1 else 0)
total_output = sum(embedding_sizes)
vec_enc = []
vis_enc = []
for i, enc in enumerate(encoders):
if len(obs_shapes[i]) == 1:
vec_enc.append(enc)
else:
vis_enc.append(enc)
assert len(vec_enc) == num_vector
assert len(vis_enc) == num_visual
assert total_output == int(num_visual * h_size + vec_obs_shape[0] * num_vector)
if num_vector > 0:

ml-agents/mlagents/trainers/torch/components/bc/module.py (27 changes)


from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.settings import BehavioralCloningSettings, ScheduleType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.buffer import AgentBuffer
class BCModule:

_, self.demonstration_buffer = demo_to_buffer(
settings.demo_path, policy.sequence_length, policy.behavior_spec
)
self.batch_size = (
settings.batch_size if settings.batch_size else default_batch_size
)

return bc_loss
def _update_batch(
self, mini_batch_demo: Dict[str, np.ndarray], n_sequences: int
self, mini_batch_demo: AgentBuffer, n_sequences: int
vec_obs = [ModelUtils.list_to_tensor(mini_batch_demo["vector_obs"])]
obs = ObsUtil.from_buffer(
mini_batch_demo, len(self.policy.behavior_spec.observation_shapes)
)
# Convert to tensors
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
print("\n\n\n\n", obs, obs[0].shape)
act_masks = None
if self.policy.use_continuous_act:
expert_actions = ModelUtils.list_to_tensor(mini_batch_demo["actions"])

if self.policy.use_recurrent:
memories = torch.zeros(1, self.n_sequences, self.policy.m_size)
if self.policy.use_vis_obs:
vis_obs = []
for idx, _ in enumerate(
self.policy.actor_critic.network_body.visual_processors
):
vis_ob = ModelUtils.list_to_tensor(
mini_batch_demo["visual_obs%d" % idx]
)
vis_obs.append(vis_ob)
else:
vis_obs = []
(
selected_actions,
clipped_actions,

) = self.policy.sample_actions(
vec_obs,
vis_obs,
obs,
masks=act_masks,
memories=memories,
seq_len=self.policy.sequence_length,

ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (39 changes)


from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.layers import LinearEncoder, linear_layer
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.trajectory import ObsUtil
class CuriosityRewardProvider(BaseRewardProvider):

"""
Extracts the current state embedding from a mini_batch.
"""
n_vis = len(self._state_encoder.visual_processors)
hidden, _ = self._state_encoder.forward(
vec_inputs=[
ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
],
vis_inputs=[
ModelUtils.list_to_tensor(
mini_batch["visual_obs%d" % i], dtype=torch.float
)
for i in range(n_vis)
],
)
n_obs = len(self._state_encoder.encoders)
obs = ObsUtil.from_buffer(mini_batch, n_obs)
# Convert to tensors
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
hidden, _ = self._state_encoder.forward(obs)
return hidden
def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:

n_vis = len(self._state_encoder.visual_processors)
hidden, _ = self._state_encoder.forward(
vec_inputs=[
ModelUtils.list_to_tensor(
mini_batch["next_vector_in"], dtype=torch.float
)
],
vis_inputs=[
ModelUtils.list_to_tensor(
mini_batch["next_visual_obs%d" % i], dtype=torch.float
)
for i in range(n_vis)
],
)
n_obs = len(self._state_encoder.encoders)
obs = ObsUtil.from_buffer_next(mini_batch, n_obs)
# Convert to tensors
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
hidden, _ = self._state_encoder.forward(obs)
return hidden
def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:

ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (72 changes)


from typing import Optional, Dict, List, Tuple
from typing import Optional, Dict, List
import numpy as np
from mlagents.torch_utils import torch, default_device

from mlagents.trainers.torch.layers import linear_layer, Initialization
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.trajectory import ObsUtil
class GAILRewardProvider(BaseRewardProvider):

torch.as_tensor(mini_batch["actions"], dtype=torch.float)
)
def get_state_inputs(
self, mini_batch: AgentBuffer
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def get_state_inputs(self, mini_batch: AgentBuffer) -> List[torch.Tensor]:
n_vis = len(self.encoder.visual_processors)
n_vec = len(self.encoder.vector_processors)
vec_inputs = (
[ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)]
if n_vec > 0
else []
)
vis_inputs = [
ModelUtils.list_to_tensor(mini_batch["visual_obs%d" % i], dtype=torch.float)
for i in range(n_vis)
]
return vec_inputs, vis_inputs
n_obs = len(self.encoder.encoders)
obs = ObsUtil.from_buffer(mini_batch, n_obs)
# Convert to tensors
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
return obs
def compute_estimate(
self, mini_batch: AgentBuffer, use_vail_noise: bool = False

:param use_vail_noise: Only when using VAIL : If true, will sample the code, if
false, will return the mean of the code.
"""
vec_inputs, vis_inputs = self.get_state_inputs(mini_batch)
inputs = self.get_state_inputs(mini_batch)
hidden, _ = self.encoder(vec_inputs, vis_inputs, action_inputs)
hidden, _ = self.encoder(inputs, action_inputs)
hidden, _ = self.encoder(vec_inputs, vis_inputs)
hidden, _ = self.encoder(inputs)
z_mu: Optional[torch.Tensor] = None
if self._settings.use_vail:
z_mu = self._z_mu_layer(hidden)

Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability esp.
for off-policy. Compute gradients w.r.t randomly interpolated input.
"""
policy_vec_inputs, policy_vis_inputs = self.get_state_inputs(policy_batch)
expert_vec_inputs, expert_vis_inputs = self.get_state_inputs(expert_batch)
interp_vec_inputs = []
for policy_vec_input, expert_vec_input in zip(
policy_vec_inputs, expert_vec_inputs
):
obs_epsilon = torch.rand(policy_vec_input.shape)
interp_vec_input = (
obs_epsilon * policy_vec_input + (1 - obs_epsilon) * expert_vec_input
)
interp_vec_input.requires_grad = True # For gradient calculation
interp_vec_inputs.append(interp_vec_input)
interp_vis_inputs = []
for policy_vis_input, expert_vis_input in zip(
policy_vis_inputs, expert_vis_inputs
):
obs_epsilon = torch.rand(policy_vis_input.shape)
interp_vis_input = (
obs_epsilon * policy_vis_input + (1 - obs_epsilon) * expert_vis_input
)
interp_vis_input.requires_grad = True # For gradient calculation
interp_vis_inputs.append(interp_vis_input)
policy_inputs = self.get_state_inputs(policy_batch)
expert_inputs = self.get_state_inputs(expert_batch)
interp_inputs = []
for policy_input, expert_input in zip(policy_inputs, expert_inputs):
obs_epsilon = torch.rand(policy_input.shape)
interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
interp_input.requires_grad = True # For gradient calculation
interp_inputs.append(interp_input)
if self._settings.use_actions:
policy_action = self.get_action_input(policy_batch)
expert_action = self.get_action_input(expert_batch)

dim=1,
)
action_inputs.requires_grad = True
hidden, _ = self.encoder(
interp_vec_inputs, interp_vis_inputs, action_inputs
)
encoder_input = tuple(
interp_vec_inputs + interp_vis_inputs + [action_inputs]
)
hidden, _ = self.encoder(interp_inputs, action_inputs)
encoder_input = tuple(interp_inputs + [action_inputs])
hidden, _ = self.encoder(interp_vec_inputs, interp_vis_inputs)
encoder_input = tuple(interp_vec_inputs + interp_vis_inputs)
hidden, _ = self.encoder(interp_inputs)
encoder_input = tuple(interp_inputs)
if self._settings.use_vail:
use_vail_noise = True
z_mu = self._z_mu_layer(hidden)

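The refactor above collapses the two interpolation loops (vector and visual) into one loop over a single observation list before computing the gradient penalty. A self-contained sketch of that interpolation and the penalty from arXiv:1704.00028, using plain PyTorch; the linear discriminator and batch shapes are placeholders, not the real GAIL encoder:

# Illustrative sketch, not part of the commit.
from typing import List

import torch
from torch import nn


def gradient_penalty(
    discriminator: nn.Module,
    policy_inputs: List[torch.Tensor],
    expert_inputs: List[torch.Tensor],
) -> torch.Tensor:
    # WGAN-GP style penalty on randomly interpolated observations (arXiv:1704.00028).
    interp_inputs = []
    for policy_input, expert_input in zip(policy_inputs, expert_inputs):
        obs_epsilon = torch.rand(policy_input.shape)
        interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
        interp_input.requires_grad_(True)  # needed to differentiate w.r.t. the input
        interp_inputs.append(interp_input)
    estimate = discriminator(torch.cat(interp_inputs, dim=1)).squeeze(1).sum()
    gradients = torch.autograd.grad(estimate, tuple(interp_inputs), create_graph=True)
    grad_norm = torch.cat(
        [g.reshape(g.shape[0], -1) for g in gradients], dim=1
    ).norm(2, dim=1)
    return torch.mean((grad_norm - 1) ** 2)


disc = nn.Linear(8 + 4, 1)  # placeholder discriminator over the concatenated observations
policy_batch = [torch.randn(16, 8), torch.randn(16, 4)]
expert_batch = [torch.randn(16, 8), torch.randn(16, 4)]
penalty = gradient_penalty(disc, policy_batch, expert_batch)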
ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (21 changes)


from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.trajectory import ObsUtil
class RNDRewardProvider(BaseRewardProvider):

self._encoder = NetworkBody(specs.observation_shapes, state_encoder_settings)
def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
n_vis = len(self._encoder.visual_processors)
hidden, _ = self._encoder.forward(
vec_inputs=[
ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
],
vis_inputs=[
ModelUtils.list_to_tensor(
mini_batch["visual_obs%d" % i], dtype=torch.float
)
for i in range(n_vis)
],
)
self._encoder.update_normalization(torch.tensor(mini_batch["vector_obs"]))
n_obs = len(self._encoder.encoders)
obs = ObsUtil.from_buffer(mini_batch, n_obs)
# Convert to tensors
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
hidden, _ = self._encoder.forward(obs)
self._encoder.update_normalization(mini_batch)
return hidden

ml-agents/mlagents/trainers/torch/networks.py (18 changes)


@abc.abstractmethod
def forward(
self,
inputs: List[torch.Tensor],
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, int, int, int, int]:

def memory_size(self) -> int:
return self.network_body.memory_size
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
self.network_body.update_normalization(vector_obs)
def update_normalization(self, buffer: AgentBuffer) -> None:
self.network_body.update_normalization(buffer)
def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
actions = []

critic_mem = None
actor_mem = None
dists, actor_mem_outs = self.get_dists(
inputs,
memories=actor_mem,
sequence_length=sequence_length,
masks=masks,
inputs, memories=actor_mem, sequence_length=sequence_length, masks=masks
)
value_outputs, critic_mem_outs = self.critic(
inputs, memories=critic_mem, sequence_length=sequence_length

mem_out = None
return dists, value_outputs, mem_out
def update_normalization(self, vector_obs: AgentBuffer) -> None:
super().update_normalization(vector_obs)
self.critic.network_body.update_normalization(vector_obs)
def update_normalization(self, buffer: AgentBuffer) -> None:
super().update_normalization(buffer)
self.critic.network_body.update_normalization(buffer)
class GlobalSteps(nn.Module):

ml-agents/mlagents/trainers/torch/utils.py (17 changes)


shape: Tuple[int, ...],
normalize: bool,
h_size: int,
vis_encode_type: EncoderType
vis_encode_type: EncoderType,
) -> Tuple[nn.Module, int]:
"""
Returns the encoder and the size of the generated embedding

return (VectorInput(shape[0], normalize), shape[0])
if len(shape) == 2:
raise UnityTrainerException(f"Unsupported shape of {shape} for observation")
if len(shape) == 3:
ModelUtils._check_resolution_for_encoder(
shape[0], shape[1], vis_encode_type

raise UnityTrainerException(f"Unsupported shape of {shape} for observation")
@staticmethod
def create_input_processors(

normalize: bool = False,
) -> Tuple[nn.ModuleList, nn.ModuleList, int]:
) -> Tuple[nn.ModuleList, List[int]]:
conditioining network on other values (e.g. actions for a Q function)
conditioning network on other values (e.g. actions for a Q function)
:param h_size: Number of hidden units per layer.
:param vis_encode_type: Type of visual encoder to use.
:param unnormalized_inputs: Vector inputs that should not be normalized, and added to the vector

encoders: List[nn.Module] = []
embedding_sizes: List[int] = []
for i, dimension in enumerate(observation_shapes):
encoder, embedding_size = ModelUtils.get_encoder_for_obs(dimension, normalize, h_size, vis_encode_type)
for dimension in observation_shapes:
encoder, embedding_size = ModelUtils.get_encoder_for_obs(
dimension, normalize, h_size, vis_encode_type
)
return (encoders, embedding_sizes)
return (nn.ModuleList(encoders), embedding_sizes)
@staticmethod
def list_to_tensor(

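create_input_processors now returns a flat nn.ModuleList of encoders plus a list of embedding sizes, one entry per observation shape, and lets callers sum the sizes themselves. A rough sketch of that shape-driven dispatch with simplified stand-in encoders:

# Illustrative sketch, not part of the commit.
from typing import List, Tuple

from torch import nn


def create_input_processors(
    observation_shapes: List[Tuple[int, ...]], h_size: int
) -> Tuple[nn.ModuleList, List[int]]:
    encoders: List[nn.Module] = []
    embedding_sizes: List[int] = []
    for shape in observation_shapes:
        if len(shape) == 1:
            # Vector observation: pass-through encoder, embedding size = observation size.
            encoders.append(nn.Identity())
            embedding_sizes.append(shape[0])
        elif len(shape) == 3:
            # Visual observation: simplified conv encoder with a fixed embedding size.
            encoders.append(
                nn.Sequential(
                    nn.Conv2d(shape[2], 16, kernel_size=3, stride=2),
                    nn.AdaptiveAvgPool2d(1),
                    nn.Flatten(),
                    nn.Linear(16, h_size),
                )
            )
            embedding_sizes.append(h_size)
        else:
            raise ValueError(f"Unsupported shape of {shape} for observation")
    return nn.ModuleList(encoders), embedding_sizes


encoders, embedding_sizes = create_input_processors([(8,), (36, 36, 3)], h_size=128)
total_output = sum(embedding_sizes)  # callers now sum the embedding sizes themselves
assert total_output == 8 + 128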
ml-agents/mlagents/trainers/trajectory.py (5 changes)


class ObsUtil:
result = []
result: List[np.array] = []
for obs in observations:
if len(obs.shape) == rank:
result += [obs]

@staticmethod
def from_buffer(batch: AgentBuffer, num_obs: int) -> List[np.array]:
result = []
result: List[np.array] = []
for i in range(num_obs):
result.append(batch[ObsUtil.get_name_at(i)])
return result

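ObsUtil.from_buffer simply pulls the per-index observation columns back out of a buffer. A minimal sketch with a plain dict standing in for AgentBuffer; the "obs_<i>" key naming is inferred from the updated tests:

# Illustrative sketch, not part of the commit.
from typing import Dict, List

import numpy as np

Buffer = Dict[str, List[np.ndarray]]  # plain dict standing in for AgentBuffer


def get_name_at(index: int) -> str:
    # Hypothetical key naming, inferred from the updated tests ("obs_0", "obs_1", ...).
    return f"obs_{index}"


def from_buffer(batch: Buffer, num_obs: int) -> List[List[np.ndarray]]:
    # Collect the column for every observation index, in order.
    return [batch[get_name_at(i)] for i in range(num_obs)]


batch: Buffer = {
    "obs_0": [np.ones(8, dtype=np.float32)] * 4,
    "obs_1": [np.ones((84, 84, 3), dtype=np.float32)] * 4,
}
columns = from_buffer(batch, num_obs=2)
assert len(columns) == 2 and len(columns[0]) == 4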
ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider copy.py (215 changes)


import numpy as np
from typing import Dict
from mlagents.torch_utils import torch, default_device
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
BaseRewardProvider,
)
from mlagents.trainers.settings import CuriositySettings
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.layers import LinearEncoder, linear_layer
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.trajectory import ObsUtil
class CuriosityRewardProvider(BaseRewardProvider):
beta = 0.2 # Forward vs Inverse loss weight
loss_multiplier = 10.0 # Loss multiplier
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
super().__init__(specs, settings)
self._ignore_done = True
self._network = CuriosityNetwork(specs, settings)
self._network.to(default_device())
self.optimizer = torch.optim.Adam(
self._network.parameters(), lr=settings.learning_rate
)
self._has_updated_once = False
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
with torch.no_grad():
rewards = ModelUtils.to_numpy(self._network.compute_reward(mini_batch))
rewards = np.minimum(rewards, 1.0 / self.strength)
return rewards * self._has_updated_once
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
self._has_updated_once = True
forward_loss = self._network.compute_forward_loss(mini_batch)
inverse_loss = self._network.compute_inverse_loss(mini_batch)
loss = self.loss_multiplier * (
self.beta * forward_loss + (1.0 - self.beta) * inverse_loss
)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return {
"Losses/Curiosity Forward Loss": forward_loss.item(),
"Losses/Curiosity Inverse Loss": inverse_loss.item(),
}
def get_modules(self):
return {f"Module:{self.name}": self._network}
class CuriosityNetwork(torch.nn.Module):
EPSILON = 1e-10
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
super().__init__()
self._action_spec = specs.action_spec
state_encoder_settings = NetworkSettings(
normalize=False,
hidden_units=settings.encoding_size,
num_layers=2,
vis_encode_type=EncoderType.SIMPLE,
memory=None,
)
self._state_encoder = NetworkBody(
specs.observation_shapes, state_encoder_settings
)
self._action_flattener = ModelUtils.ActionFlattener(self._action_spec)
self.inverse_model_action_prediction = torch.nn.Sequential(
LinearEncoder(2 * settings.encoding_size, 1, 256),
linear_layer(256, self._action_flattener.flattened_size),
)
self.forward_model_next_state_prediction = torch.nn.Sequential(
LinearEncoder(
settings.encoding_size + self._action_flattener.flattened_size, 1, 256
),
linear_layer(256, settings.encoding_size),
)
def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Extracts the current state embedding from a mini_batch.
"""
n_obs = len(self._state_encoder.encoders)
obs = ObsUtil.from_buffer(mini_batch, n_obs)
# Convert to tensors
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
hidden, _ = self._state_encoder.forward(obs)
return hidden
def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Extracts the next state embedding from a mini_batch.
"""
n_obs = len(self._state_encoder.encoders)
obs = ObsUtil.from_buffer_next(mini_batch, n_obs)
# Convert to tensors
obs = [ModelUtils.list_to_tensor(obs) for obs in obs]
hidden, _ = self._state_encoder.forward(obs)
return hidden
def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
In the continuous case, returns the predicted action.
In the discrete case, returns the logits.
"""
inverse_model_input = torch.cat(
(self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
)
hidden = self.inverse_model_action_prediction(inverse_model_input)
if self._action_spec.is_continuous():
return hidden
else:
branches = ModelUtils.break_into_branches(
hidden, self._action_spec.discrete_branches
)
branches = [torch.softmax(b, dim=1) for b in branches]
return torch.cat(branches, dim=1)
def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Uses the current state embedding and the action of the mini_batch to predict
the next state embedding.
"""
if self._action_spec.is_continuous():
action = ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
else:
action = torch.cat(
ModelUtils.actions_to_onehot(
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
self._action_spec.discrete_branches,
),
dim=1,
)
forward_model_input = torch.cat(
(self.get_current_state(mini_batch), action), dim=1
)
return self.forward_model_next_state_prediction(forward_model_input)
def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Computes the inverse loss for a mini_batch. Corresponds to the error on the
action prediction (given the current and next state).
"""
predicted_action = self.predict_action(mini_batch)
if self._action_spec.is_continuous():
sq_difference = (
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
- predicted_action
) ** 2
sq_difference = torch.sum(sq_difference, dim=1)
return torch.mean(
ModelUtils.dynamic_partition(
sq_difference,
ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
2,
)[1]
)
else:
true_action = torch.cat(
ModelUtils.actions_to_onehot(
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
self._action_spec.discrete_branches,
),
dim=1,
)
cross_entropy = torch.sum(
-torch.log(predicted_action + self.EPSILON) * true_action, dim=1
)
return torch.mean(
ModelUtils.dynamic_partition(
cross_entropy,
ModelUtils.list_to_tensor(
mini_batch["masks"], dtype=torch.float
), # use masks not action_masks
2,
)[1]
)
def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Calculates the curiosity reward for the mini_batch. Corresponds to the error
between the predicted and actual next state.
"""
predicted_next_state = self.predict_next_state(mini_batch)
target = self.get_next_state(mini_batch)
sq_difference = 0.5 * (target - predicted_next_state) ** 2
sq_difference = torch.sum(sq_difference, dim=1)
return sq_difference
def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Computes the loss for the next state prediction
"""
return torch.mean(
ModelUtils.dynamic_partition(
self.compute_reward(mini_batch),
ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
2,
)[1]
)
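The losses above are averaged only over active steps via ModelUtils.dynamic_partition(...)[1] and then combined as loss_multiplier * (beta * forward_loss + (1 - beta) * inverse_loss). A small sketch of an equivalent masked mean and that weighting, using boolean indexing as a stand-in for dynamic_partition (assuming it selects the mask == 1 partition):

# Illustrative sketch, not part of the commit.
import torch


def masked_mean(values: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    # Average only the entries whose mask is 1; assumed to roughly match what
    # ModelUtils.dynamic_partition(values, masks, 2)[1] selects.
    return values[masks.bool()].mean()


beta = 0.2             # forward vs. inverse loss weight, as in the class constants above
loss_multiplier = 10.0

per_step_forward = torch.tensor([0.4, 0.2, 0.9, 0.1])
per_step_inverse = torch.tensor([0.3, 0.8, 0.5, 0.2])
masks = torch.tensor([1.0, 1.0, 0.0, 1.0])  # the third step is padding and is ignored

forward_loss = masked_mean(per_step_forward, masks)
inverse_loss = masked_mean(per_step_inverse, masks)
loss = loss_multiplier * (beta * forward_loss + (1.0 - beta) * inverse_loss)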