Browse code

ignoring Instance of 'AbstractContextManager' has no 'enter_context' member (no-member)

/develop/action-spec-gym
Andrew Cohen, 4 years ago
Current commit
8013e544
Showing 15 changed files with 101 additions and 85 deletions
 1. ml-agents/mlagents/trainers/demo_loader.py (24 changes)
 2. ml-agents/mlagents/trainers/policy/policy.py (9 changes)
 3. ml-agents/mlagents/trainers/policy/torch_policy.py (3 changes)
 4. ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)
 5. ml-agents/mlagents/trainers/sac/optimizer_torch.py (20 changes)
 6. ml-agents/mlagents/trainers/tests/mock_brain.py (36 changes)
 7. ml-agents/mlagents/trainers/tests/simple_test_envs.py (13 changes)
 8. ml-agents/mlagents/trainers/tests/torch/test_policy.py (13 changes)
 9. ml-agents/mlagents/trainers/torch/components/bc/module.py (3 changes)
 10. ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (16 changes)
 11. ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (3 changes)
 12. ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (1 change)
 13. ml-agents/mlagents/trainers/torch/model_serialization.py (4 changes)
 14. ml-agents/mlagents/trainers/torch/networks.py (33 changes)
 15. ml-agents/mlagents/trainers/torch/utils.py (6 changes)

ml-agents/mlagents/trainers/demo_loader.py (24 changes)

 demo_buffer = make_demo_buffer(info_action_pair, behavior_spec, sequence_length)
 if expected_behavior_spec:
     # check action dimensions in demonstration match
-    if behavior_spec.action_shape != expected_behavior_spec.action_shape:
+    if (
+        behavior_spec.action_spec.continuous_action_size
+        != expected_behavior_spec.action_spec.continuous_action_size
+    ):
-            "The action dimensions {} in demonstration do not match the policy's {}.".format(
-                behavior_spec.action_shape, expected_behavior_spec.action_shape
+            "The continuous action dimensions {} in demonstration do not match the policy's {}.".format(
+                behavior_spec.action_spec.continuous_action_size,
+                expected_behavior_spec.action_spec.continuous_action_size,
     # check the action types in demonstration match
-    if behavior_spec.action_type != expected_behavior_spec.action_type:
+    if (
+        behavior_spec.action_spec.discrete_action_branches
+        != expected_behavior_spec.action_spec.discrete_action_branches
+    ):
-            "The action type of {} in demonstration do not match the policy's {}.".format(
-                behavior_spec.action_type, expected_behavior_spec.action_type
+            "The discrete action branches {} in demonstration do not match the policy's {}.".format(
+                behavior_spec.action_spec.discrete_action_branches,
+                expected_behavior_spec.action_spec.discrete_action_branches,
     # check observations match
     if len(behavior_spec.observation_shapes) != len(
         expected_behavior_spec.observation_shapes
     ):
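
For orientation, the compatibility check this hunk introduces can be run outside ml-agents. The sketch below is a minimal stand-in: SpecSketch and check_demo_compatibility are hypothetical names chosen to mirror the continuous_action_size / discrete_action_branches attributes used above, not the real API.

from typing import NamedTuple, Tuple

class SpecSketch(NamedTuple):
    # Hypothetical stand-in for the action_spec attributes referenced in the hunk.
    continuous_action_size: int
    discrete_action_branches: Tuple[int, ...]

def check_demo_compatibility(demo: SpecSketch, policy: SpecSketch) -> None:
    # Compare continuous dimensions first, then discrete branching, as the
    # updated demo_loader does.
    if demo.continuous_action_size != policy.continuous_action_size:
        raise RuntimeError(
            "The continuous action dimensions {} in demonstration do not match "
            "the policy's {}.".format(
                demo.continuous_action_size, policy.continuous_action_size
            )
        )
    if demo.discrete_action_branches != policy.discrete_action_branches:
        raise RuntimeError(
            "The discrete action branches {} in demonstration do not match "
            "the policy's {}.".format(
                demo.discrete_action_branches, policy.discrete_action_branches
            )
        )

# A matching two-branch discrete pair passes silently.
check_demo_compatibility(SpecSketch(0, (3, 2)), SpecSketch(0, (3, 2)))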

ml-agents/mlagents/trainers/policy/policy.py (9 changes)

 self.trainer_settings = trainer_settings
 self.network_settings: NetworkSettings = trainer_settings.network_settings
 self.seed = seed
+# For mixed action spaces
+self.continuous_act_size = self.action_spec.continuous_action_size
+self.discrete_act_size = self.action_spec.discrete_action_size
+self.discrete_act_branches = self.action_spec.discrete_action_branches
+if (
+    self.action_spec.continuous_action_size > 0
+    and self.action_spec.discrete_action_size > 0
+):
+    raise UnityPolicyException("Trainers do not support mixed action spaces.")
 self.act_size = (
     list(self.action_spec.discrete_action_branches)
     if self.action_spec.is_action_discrete()
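
The guard added here is easy to exercise in isolation. In the sketch below, SpecSketch is a hypothetical stand-in for the new action spec, and derive_act_size mirrors how the policy rejects mixed spaces and then picks act_size from the discrete branches or the continuous dimension.

from typing import List, NamedTuple, Tuple

class SpecSketch(NamedTuple):
    # Hypothetical stand-in for the action_spec attributes referenced above.
    continuous_action_size: int
    discrete_action_branches: Tuple[int, ...]

    @property
    def discrete_action_size(self) -> int:
        return len(self.discrete_action_branches)

    def is_action_discrete(self) -> bool:
        return self.discrete_action_size > 0

def derive_act_size(spec: SpecSketch) -> List[int]:
    # Trainers reject specs that mix continuous and discrete actions, then pick
    # act_size from whichever side is in use.
    if spec.continuous_action_size > 0 and spec.discrete_action_size > 0:
        raise ValueError("Trainers do not support mixed action spaces.")
    if spec.is_action_discrete():
        return list(spec.discrete_action_branches)
    return [spec.continuous_action_size]

print(derive_act_size(SpecSketch(0, (3, 3, 2))))  # [3, 3, 2]
print(derive_act_size(SpecSketch(4, ())))         # [4]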

ml-agents/mlagents/trainers/policy/torch_policy.py (3 changes)

 self.actor_critic = ac_class(
     observation_shapes=self.behavior_spec.observation_shapes,
     network_settings=trainer_settings.network_settings,
-    act_type=behavior_spec.action_type,
-    act_size=self.act_size,
+    action_spec=behavior_spec.action_spec,
     stream_names=reward_signal_names,
     conditional_sigma=self.condition_sigma_on_obs,
     tanh_squash=tanh_squash,

ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)

     behavior_spec,
     self.trainer_settings,
     condition_sigma_on_obs=False,  # Faster training for PPO
-    separate_critic=behavior_spec.is_action_continuous(),
+    separate_critic=behavior_spec.action_spec.is_action_continuous(),
 )
 return policy

ml-agents/mlagents/trainers/sac/optimizer_torch.py (20 changes)

 from mlagents.torch_utils import torch, nn, default_device
 from mlagents_envs.logging_util import get_logger
-from mlagents_envs.base_env import ActionType
+from mlagents_envs.base_env import ActionType, ActionSpec
 from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
 from mlagents.trainers.policy.torch_policy import TorchPolicy
 from mlagents.trainers.settings import NetworkSettings

     stream_names: List[str],
     observation_shapes: List[Tuple[int, ...]],
     network_settings: NetworkSettings,
-    act_type: ActionType,
-    act_size: List[int],
+    action_spec: ActionSpec,
-if act_type == ActionType.CONTINUOUS:
+self.action_spec = action_spec
+if self.action_spec.is_action_continuous():
+    self.act_type = ActionType.CONTINUOUS
+    self.act_size = self.action_spec.continuous_action_size
-    num_action_ins = sum(act_size)
+    num_action_ins = self.act_size
-    num_value_outs = sum(act_size)
+    self.act_type = ActionType.DISCRETE
+    self.act_size = self.action_spec.discrete_action_branches
+    num_value_outs = sum(self.act_size)
     num_action_ins = 0
 self.q1_network = ValueNetwork(
     stream_names,

     self.stream_names,
     self.policy.behavior_spec.observation_shapes,
     policy_network_settings,
-    self.policy.behavior_spec.action_type,
-    self.act_size,
+    self.policy.action_spec,
 )
 self.target_network = ValueNetwork(
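
A stand-alone sketch of the Q-network sizing this hunk rewrites: SpecSketch is a hypothetical stand-in, and the single value output in the continuous case is an assumption about context the hunk does not show.

from typing import NamedTuple, Tuple

class SpecSketch(NamedTuple):
    # Hypothetical stand-in for the action_spec fields used by the SAC optimizer.
    continuous_action_size: int
    discrete_action_branches: Tuple[int, ...]

    def is_action_continuous(self) -> bool:
        return self.continuous_action_size > 0

def q_network_sizes(spec: SpecSketch) -> Tuple[int, int]:
    # Return (num_action_ins, num_value_outs). Continuous control feeds the
    # action vector into the Q network and reads one value out (the single
    # output is an assumption; the hunk only shows the discrete side of
    # num_value_outs). Discrete control feeds no action in and reads one Q
    # value per option across all branches.
    if spec.is_action_continuous():
        return spec.continuous_action_size, 1
    return 0, sum(spec.discrete_action_branches)

print(q_network_sizes(SpecSketch(3, ())))      # (3, 1)
print(q_network_sizes(SpecSketch(0, (3, 2))))  # (0, 5)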

ml-agents/mlagents/trainers/tests/mock_brain.py (36 changes)

     DecisionSteps,
     TerminalSteps,
     BehaviorSpec,
-    ActionType,
+    ActionSpec,
 )

 reward = np.array(num_agents * [1.0], dtype=np.float32)
 interrupted = np.array(num_agents * [False], dtype=np.bool)
 agent_id = np.arange(num_agents, dtype=np.int32)
-behavior_spec = BehaviorSpec(
-    observation_shapes,
-    ActionType.DISCRETE if discrete else ActionType.CONTINUOUS,
-    action_shape,
-)
+if discrete:
+    action_spec = ActionSpec(0, action_shape)
+else:
+    action_spec = ActionSpec(action_shape, ())
+behavior_spec = BehaviorSpec(observation_shapes, action_spec)
 if done:
     return (
         DecisionSteps.empty(behavior_spec),

 def create_steps_from_behavior_spec(
     behavior_spec: BehaviorSpec, num_agents: int = 1
 ) -> Tuple[DecisionSteps, TerminalSteps]:
+    action_spec = behavior_spec.action_spec
+    is_discrete = action_spec.is_action_discrete()
-        action_shape=behavior_spec.action_shape,
-        discrete=behavior_spec.is_action_discrete(),
+        action_shape=action_spec.discrete_action_branches
+        if is_discrete
+        else action_spec.continuous_action_size,
+        discrete=is_discrete,
     )

     memory_size: int = 10,
     exclude_key_list: List[str] = None,
 ) -> AgentBuffer:
-    action_space = behavior_spec.action_shape
-    is_discrete = behavior_spec.is_action_discrete()
+    is_discrete = behavior_spec.action_spec.is_action_discrete()
+    if is_discrete:
+        action_space = behavior_spec.action_spec.discrete_action_branches
+    else:
+        action_space = behavior_spec.action_spec.continuous_action_size
     trajectory = make_fake_trajectory(
         length,
         behavior_spec.observation_shapes,

 def setup_test_behavior_specs(
     use_discrete=True, use_visual=False, vector_action_space=2, vector_obs_space=8
 ):
+    if use_discrete:
+        action_spec = ActionSpec(0, tuple(vector_action_space))
+    else:
+        action_spec = ActionSpec(vector_action_space, ())
-        [(84, 84, 3)] * int(use_visual) + [(vector_obs_space,)],
-        ActionType.DISCRETE if use_discrete else ActionType.CONTINUOUS,
-        tuple(vector_action_space) if use_discrete else vector_action_space,
+        [(84, 84, 3)] * int(use_visual) + [(vector_obs_space,)], action_spec
     )
     return behavior_spec
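
The construction pattern repeated across these test helpers (ActionSpec(0, branches) for discrete, ActionSpec(size, ()) for continuous) can be illustrated with a stand-alone stand-in; make_spec and SpecSketch below are hypothetical names, not the ml-agents API.

from typing import NamedTuple, Tuple, Union

class SpecSketch(NamedTuple):
    # Hypothetical stand-in mirroring ActionSpec(continuous_size, discrete_branches).
    continuous_action_size: int
    discrete_action_branches: Tuple[int, ...]

def make_spec(action_shape: Union[int, Tuple[int, ...]], discrete: bool) -> SpecSketch:
    # Discrete specs carry a tuple of branch sizes and no continuous dimensions;
    # continuous specs carry a size and an empty branch tuple.
    if discrete:
        return SpecSketch(0, tuple(action_shape))
    return SpecSketch(int(action_shape), ())

print(make_spec((2, 2), discrete=True))   # SpecSketch(continuous_action_size=0, discrete_action_branches=(2, 2))
print(make_spec(4, discrete=False))       # SpecSketch(continuous_action_size=4, discrete_action_branches=())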

ml-agents/mlagents/trainers/tests/simple_test_envs.py (13 changes)

 import numpy as np
 from mlagents_envs.base_env import (
+    ActionSpec,
-    ActionType,
     BehaviorMapping,
 )
 from mlagents_envs.tests.test_rpc_utils import proto_from_steps_and_action

 self.num_vector = num_vector
 self.vis_obs_size = vis_obs_size
 self.vec_obs_size = vec_obs_size
-action_type = ActionType.DISCRETE if use_discrete else ActionType.CONTINUOUS
-self.behavior_spec = BehaviorSpec(
-    self._make_obs_spec(),
-    action_type,
-    tuple(2 for _ in range(action_size)) if use_discrete else action_size,
-)
+if use_discrete:
+    action_spec = ActionSpec(0, tuple(2 for _ in range(action_size)))
+else:
+    action_spec = ActionSpec(action_size, ())
+self.behavior_spec = BehaviorSpec(self._make_obs_spec(), action_spec)
 self.action_size = action_size
 self.names = brain_names
 self.positions: Dict[str, List[float]] = {}

ml-agents/mlagents/trainers/tests/torch/test_policy.py (13 changes)

     memories=memories,
     seq_len=policy.sequence_length,
 )
-assert log_probs.shape == (64, policy.behavior_spec.action_size)
-assert entropy.shape == (64, policy.behavior_spec.action_size)
+assert log_probs.shape == (64, policy.action_spec.action_size)
+assert entropy.shape == (64, policy.action_spec.action_size)
 for val in values.values():
     assert val.shape == (64,)

     all_log_probs=not policy.use_continuous_act,
 )
 if discrete:
-    assert log_probs.shape == (
-        64,
-        sum(policy.behavior_spec.discrete_action_branches),
-    )
+    assert log_probs.shape == (64, sum(policy.action_spec.discrete_action_branches))
-    assert log_probs.shape == (64, policy.behavior_spec.action_shape)
-    assert entropies.shape == (64, policy.behavior_spec.action_size)
+    assert log_probs.shape == (64, policy.action_spec.continuous_action_size)
+    assert entropies.shape == (64, policy.action_spec.action_size)
 if rnn:
     assert memories.shape == (1, 1, policy.m_size)

ml-agents/mlagents/trainers/torch/components/bc/module.py (3 changes)

     for the pretrainer.
     """
     self.policy = policy
+    self.action_spec = policy.action_spec
     self._anneal_steps = settings.steps
     self.current_lr = policy_learning_rate * settings.strength

 np.ones(
     (
         self.n_sequences * self.policy.sequence_length,
-        sum(self.policy.behavior_spec.discrete_action_branches),
+        sum(self.action_spec.discrete_action_branches),
     ),
     dtype=np.float32,
 )

ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (16 changes)

 def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
     super().__init__()
-    self._policy_specs = specs
+    self._action_spec = specs.action_spec
     state_encoder_settings = NetworkSettings(
         normalize=False,
         hidden_units=settings.encoding_size,

     specs.observation_shapes, state_encoder_settings
 )
-self._action_flattener = ModelUtils.ActionFlattener(specs)
+self._action_flattener = ModelUtils.ActionFlattener(self._action_spec)
 self.inverse_model_action_prediction = torch.nn.Sequential(
     LinearEncoder(2 * settings.encoding_size, 1, 256),

     (self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
 )
 hidden = self.inverse_model_action_prediction(inverse_model_input)
-if self._policy_specs.is_action_continuous():
+if self._action_spec.is_action_continuous():
-        hidden, self._policy_specs.discrete_action_branches
+        hidden, self._action_spec.discrete_action_branches
     )
     branches = [torch.softmax(b, dim=1) for b in branches]
     return torch.cat(branches, dim=1)

 Uses the current state embedding and the action of the mini_batch to predict
 the next state embedding.
 """
-if self._policy_specs.is_action_continuous():
+if self._action_spec.is_action_continuous():
-        self._policy_specs.discrete_action_branches,
+        self._action_spec.discrete_action_branches,
     ),
     dim=1,
 )

 action prediction (given the current and next state).
 """
 predicted_action = self.predict_action(mini_batch)
-if self._policy_specs.is_action_continuous():
+if self._action_spec.is_action_continuous():
     sq_difference = (
         ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
         - predicted_action

 true_action = torch.cat(
     ModelUtils.actions_to_onehot(
         ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
-        self._policy_specs.discrete_action_branches,
+        self._action_spec.discrete_action_branches,
     ),
     dim=1,
 )
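
The discrete path above one-hot-encodes the true actions branch by branch before comparing them to the inverse model's prediction. A minimal stand-alone version of that conversion follows; the function name is a stand-in for illustration, not the ModelUtils API.

import numpy as np
from typing import List, Tuple

def actions_to_onehot_sketch(
    actions: np.ndarray, branches: Tuple[int, ...]
) -> List[np.ndarray]:
    # Column i of `actions` holds the chosen index for branch i; each branch
    # becomes its own one-hot block so the blocks can be concatenated afterwards.
    onehots = []
    for i, size in enumerate(branches):
        block = np.zeros((actions.shape[0], size), dtype=np.float32)
        block[np.arange(actions.shape[0]), actions[:, i]] = 1.0
        onehots.append(block)
    return onehots

acts = np.array([[1, 0], [2, 1]])  # two agents, branch sizes (3, 2)
flat = np.concatenate(actions_to_onehot_sketch(acts, (3, 2)), axis=1)
print(flat.shape)                  # (2, 5)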

ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (3 changes)

 def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
     super().__init__()
-    self._policy_specs = specs
     self._use_vail = settings.use_vail
     self._settings = settings

     vis_encode_type=EncoderType.SIMPLE,
     memory=None,
 )
-self._action_flattener = ModelUtils.ActionFlattener(specs)
+self._action_flattener = ModelUtils.ActionFlattener(specs.action_spec)
 unencoded_size = (
     self._action_flattener.flattened_size + 1 if settings.use_actions else 0
 )  # +1 is for dones

ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (1 change)

 def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
     super().__init__()
-    self._policy_specs = specs
     state_encoder_settings = NetworkSettings(
         normalize=True,
         hidden_units=settings.encoding_size,

ml-agents/mlagents/trainers/torch/model_serialization.py (4 changes)

     for shape in self.policy.behavior_spec.observation_shapes
     if len(shape) == 3
 ]
-dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
+dummy_masks = torch.ones(
+    batch_dim + [sum(self.policy.action_spec.discrete_action_branches)]
+)
 dummy_memories = torch.zeros(
     batch_dim + seq_len_dim + [self.policy.export_memory_size]
 )
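
The exported model's action-mask placeholder is now sized from the spec's branch tuple rather than from actor_critic.act_size. A small stand-alone sketch of that sizing (names are illustrative, not the real serialization code):

import numpy as np

def dummy_action_mask(batch_size: int, discrete_action_branches=(3, 2)) -> np.ndarray:
    # One mask entry per discrete option across all branches, so the placeholder
    # width is sum(branches); all-ones means "every action allowed".
    return np.ones((batch_size, sum(discrete_action_branches)), dtype=np.float32)

print(dummy_action_mask(1).shape)  # (1, 5)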

ml-agents/mlagents/trainers/torch/networks.py (33 changes)

 from mlagents.torch_utils import torch, nn
-from mlagents_envs.base_env import ActionType
+from mlagents_envs.base_env import ActionType, ActionSpec
 from mlagents.trainers.torch.distributions import (
     GaussianDistribution,
     MultiCategoricalDistribution,

     self,
     observation_shapes: List[Tuple[int, ...]],
     network_settings: NetworkSettings,
-    act_type: ActionType,
-    act_size: List[int],
+    action_spec: ActionSpec,
-self.act_type = act_type
-self.act_size = act_size
+self.action_spec = action_spec
+if self.action_spec.is_action_continuous():
+    self.act_type = ActionType.CONTINUOUS
+else:
+    self.act_type = ActionType.DISCRETE
+self.act_size = self.action_spec.action_size
-    torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
+    torch.Tensor([int(self.act_type == ActionType.CONTINUOUS)])
-    torch.Tensor([sum(act_size)]), requires_grad=False
+    torch.Tensor([self.action_spec.action_size]), requires_grad=False
 )
 self.network_body = NetworkBody(observation_shapes, network_settings)
 if network_settings.memory is not None:

 if self.act_type == ActionType.CONTINUOUS:
     self.distribution = GaussianDistribution(
         self.encoding_size,
-        act_size[0],
+        self.action_spec.continuous_action_size,
-        self.encoding_size, act_size
+        self.encoding_size, self.action_spec.discrete_action_branches
     )
 @property

     self,
     observation_shapes: List[Tuple[int, ...]],
     network_settings: NetworkSettings,
-    act_type: ActionType,
-    act_size: List[int],
+    action_spec: ActionSpec,
     stream_names: List[str],
     conditional_sigma: bool = False,
     tanh_squash: bool = False,

     network_settings,
-    act_type,
-    act_size,
+    action_spec,
     conditional_sigma,
     tanh_squash,
 )

     self,
     observation_shapes: List[Tuple[int, ...]],
     network_settings: NetworkSettings,
-    act_type: ActionType,
-    act_size: List[int],
+    action_spec: ActionSpec,
     stream_names: List[str],
     conditional_sigma: bool = False,
     tanh_squash: bool = False,

 super().__init__(
     observation_shapes,
     network_settings,
-    act_type,
-    act_size,
+    action_spec,
     conditional_sigma,
     tanh_squash,
 )
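
The actor constructor now derives its distribution head from the spec instead of an explicit act_type/act_size pair: a Gaussian head sized by the continuous dimension, or a multi-categorical head sized by the branch tuple. The sketch below only illustrates that branching with a hypothetical SpecSketch stand-in; it does not build the real torch distributions.

from typing import NamedTuple, Tuple

class SpecSketch(NamedTuple):
    # Hypothetical stand-in for the action_spec now passed to the actor classes.
    continuous_action_size: int
    discrete_action_branches: Tuple[int, ...]

    def is_action_continuous(self) -> bool:
        return self.continuous_action_size > 0

def choose_distribution(spec: SpecSketch) -> str:
    # Mirror the branching in the actor constructor, returning a description of
    # the head that would be built.
    if spec.is_action_continuous():
        return "GaussianDistribution(encoding_size, {})".format(
            spec.continuous_action_size
        )
    return "MultiCategoricalDistribution(encoding_size, {})".format(
        spec.discrete_action_branches
    )

print(choose_distribution(SpecSketch(2, ())))     # Gaussian head, 2 dimensions
print(choose_distribution(SpecSketch(0, (3, 3)))) # multi-categorical head, branches (3, 3)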

ml-agents/mlagents/trainers/torch/utils.py (6 changes)

 )
 from mlagents.trainers.settings import EncoderType, ScheduleType
 from mlagents.trainers.exception import UnityTrainerException
-from mlagents_envs.base_env import BehaviorSpec
+from mlagents_envs.base_env import ActionSpec
 from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

 }
 class ActionFlattener:
-    def __init__(self, behavior_spec: BehaviorSpec):
-        self._specs = behavior_spec
+    def __init__(self, action_spec: ActionSpec):
+        self._specs = action_spec
     @property
     def flattened_size(self) -> int:
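
ActionFlattener now takes the action spec directly. The stand-alone sketch below assumes flattened_size is the continuous dimension for continuous specs and the summed one-hot width for discrete ones, which is consistent with how the GAIL hunk uses it; SpecSketch and ActionFlattenerSketch are hypothetical stand-ins, not the ml-agents classes.

from typing import NamedTuple, Tuple

class SpecSketch(NamedTuple):
    # Hypothetical stand-in for the ActionSpec the flattener now receives.
    continuous_action_size: int
    discrete_action_branches: Tuple[int, ...]

    def is_action_continuous(self) -> bool:
        return self.continuous_action_size > 0

class ActionFlattenerSketch:
    # Keep only the action spec, which is all that is needed to compute the
    # flattened width (continuous size, or the sum of one-hot branch widths).
    def __init__(self, action_spec: SpecSketch):
        self._specs = action_spec

    @property
    def flattened_size(self) -> int:
        if self._specs.is_action_continuous():
            return self._specs.continuous_action_size
        return sum(self._specs.discrete_action_branches)

print(ActionFlattenerSketch(SpecSketch(0, (3, 2))).flattened_size)  # 5
print(ActionFlattenerSketch(SpecSketch(4, ())).flattened_size)      # 4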
