Compare commits

...
This merge request has changes that conflict with the target branch.
/ml-agents-envs/mlagents_envs/rpc_utils.py
/ml-agents-envs/mlagents_envs/base_env.py
/ml-agents/mlagents/trainers/policy/torch_policy.py
/ml-agents/mlagents/trainers/policy/policy.py
/ml-agents/mlagents/trainers/ppo/optimizer_torch.py
/ml-agents/mlagents/trainers/ppo/trainer.py
/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
/ml-agents/mlagents/trainers/tests/torch/test_networks.py
/ml-agents/mlagents/trainers/tests/simple_test_envs.py
/ml-agents/mlagents/trainers/torch/distributions.py
/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
/ml-agents/mlagents/trainers/torch/model_serialization.py
/ml-agents/mlagents/trainers/torch/utils.py
/ml-agents/mlagents/trainers/torch/networks.py
/ml-agents/mlagents/trainers/tests/torch/test_hybrid.py
/ml-agents/mlagents/trainers/torch/action_model.py
/ml-agents/mlagents/trainers/policy/tf_policy.py

1 commit

Author  SHA1  Message  Commit date
Andrew Cohen  6e23bafd  ActionFlattener Refactor  4 years ago
19 files changed, with 619 insertions and 377 deletions
  1. ml-agents-envs/mlagents_envs/base_env.py (143 changed lines)
  2. ml-agents-envs/mlagents_envs/rpc_utils.py (19 changed lines)
  3. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (7 changed lines)
  4. ml-agents/mlagents/trainers/ppo/trainer.py (1 changed line)
  5. ml-agents/mlagents/trainers/policy/tf_policy.py (4 changed lines)
  6. ml-agents/mlagents/trainers/policy/torch_policy.py (72 changed lines)
  7. ml-agents/mlagents/trainers/policy/policy.py (15 changed lines)
  8. ml-agents/mlagents/trainers/tests/simple_test_envs.py (106 changed lines)
  9. ml-agents/mlagents/trainers/tests/torch/test_networks.py (6 changed lines)
  10. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py (124 changed lines)
  11. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py (5 changed lines)
  12. ml-agents/mlagents/trainers/torch/model_serialization.py (2 changed lines)
  13. ml-agents/mlagents/trainers/torch/distributions.py (37 changed lines)
  14. ml-agents/mlagents/trainers/torch/networks.py (231 changed lines)
  15. ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (2 changed lines)
  16. ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (2 changed lines)
  17. ml-agents/mlagents/trainers/torch/utils.py (40 changed lines)
  18. ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (93 changed lines)
  19. ml-agents/mlagents/trainers/torch/action_model.py (87 changed lines)

143
ml-agents-envs/mlagents_envs/base_env.py


BehaviorName = str
class ActionBuffer(NamedTuple):
"""
Contains continuous and discrete actions as numpy arrays.
"""
continuous: np.ndarray
discrete: np.ndarray
class DecisionStep(NamedTuple):
"""
Contains the data a single Agent collected since the last

class ActionType(Enum):
DISCRETE = 0
CONTINUOUS = 1
class BehaviorSpec(NamedTuple):
"""
A NamedTuple containing information about the observation and action
spaces for a group of Agents under the same behavior.
- observation_shapes is a List of Tuples of int : Each Tuple corresponds
to an observation's dimensions. The shape tuples have the same ordering as
the ordering of the DecisionSteps and TerminalSteps.
- action_type is the type of data of the action. it can be discrete or
continuous. If discrete, the action tensors are expected to be int32. If
continuous, the actions are expected to be float32.
- action_shape is:
- An int in continuous action space corresponding to the number of
floats that constitute the action.
- A Tuple of int in discrete action space where each int corresponds to
the number of discrete actions available to the agent.
"""
HYBRID = 2
observation_shapes: List[Tuple]
action_type: ActionType
action_shape: Union[int, Tuple[int, ...]]
class ActionSpec(NamedTuple):
num_continuous_actions: int
discrete_branch_sizes: Tuple[int]
# For backwards compatibility
return self.action_type == ActionType.DISCRETE
return self.discrete_action_size > 0
# For backwards compatibility
return self.action_type == ActionType.CONTINUOUS
return self.continuous_action_size > 0
def action_size(self) -> int:
"""
Returns the dimension of the action.
- In the continuous case, will return the number of continuous actions.
- In the (multi-)discrete case, will return the number of action branches.
"""
if self.action_type == ActionType.DISCRETE:
return len(self.action_shape) # type: ignore
else:
return self.action_shape # type: ignore
def discrete_action_branches(self) -> Optional[Tuple[int, ...]]:
return self.discrete_branch_sizes # type: ignore
def discrete_action_branches(self) -> Optional[Tuple[int, ...]]:
"""
Returns a Tuple of int corresponding to the number of possible actions
for each branch (only for discrete actions). Will return None
for continuous actions.
"""
if self.action_type == ActionType.DISCRETE:
return self.action_shape # type: ignore
else:
return None
def discrete_action_size(self) -> int:
return len(self.discrete_branch_sizes)
def create_empty_action(self, n_agents: int) -> np.ndarray:
"""
Generates a numpy array corresponding to an empty action (all zeros)
for a number of agents.
:param n_agents: The number of agents that will have actions generated
"""
if self.action_type == ActionType.DISCRETE:
return np.zeros((n_agents, self.action_size), dtype=np.int32)
else:
return np.zeros((n_agents, self.action_size), dtype=np.float32)
@property
def continuous_action_size(self) -> int:
return self.num_continuous_actions
@property
def action_size(self) -> int:
return self.discrete_action_size + self.continuous_action_size
def create_empty_action(self, n_agents: int) -> Tuple[np.ndarray, np.ndarray]:
return ActionBuffer(
np.zeros((n_agents, self.continuous_action_size), dtype=np.float32),
np.zeros((n_agents, self.discrete_action_size), dtype=np.int32),
)
"""
Generates a numpy array corresponding to a random action (either discrete
or continuous) for a number of agents.
:param n_agents: The number of agents that will have actions generated
:param generator: The random number generator used for creating random action
"""
if self.is_action_continuous():
action = np.random.uniform(
low=-1.0, high=1.0, size=(n_agents, self.action_size)
).astype(np.float32)
return action
elif self.is_action_discrete():
branch_size = self.discrete_action_branches
action = np.column_stack(
[
np.random.randint(
0,
branch_size[i], # type: ignore
size=(n_agents),
dtype=np.int32,
)
for i in range(self.action_size)
]
)
return action
continuous_action = np.random.uniform(
low=-1.0, high=1.0, size=(n_agents, self.continuous_action_size)
).astype(np.float32)
branch_size = self.discrete_action_branches
discrete_action = np.column_stack(
[
np.random.randint(
0,
branch_size[i], # type: ignore
size=(n_agents),
dtype=np.int32,
)
for i in range(self.discrete_action_size)
]
)
return ActionBuffer(continuous_action, discrete_action)
class BehaviorSpec(NamedTuple):
observation_shapes: List[Tuple]
action_spec: ActionSpec
class BehaviorMapping(Mapping):
def __init__(self, specs: Dict[BehaviorName, BehaviorSpec]):
self._dict = specs

"""
@abstractmethod
def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None:
def set_actions(
self, behavior_name: BehaviorName, action: Union[ActionBuffer, np.ndarray]
) -> None:
"""
Sets the action for all of the agents in the simulation for the next
step. The Actions must be in the same order as the order received in

@abstractmethod
def set_action_for_agent(
self, behavior_name: BehaviorName, agent_id: AgentId, action: np.ndarray
self,
behavior_name: BehaviorName,
agent_id: AgentId,
action: Union[ActionBuffer, np.ndarray],
) -> None:
"""
Sets the action for one of the agents in the simulation for the next

19
ml-agents-envs/mlagents_envs/rpc_utils.py


from mlagents_envs.base_env import (
ActionSpec,
ActionType,
DecisionSteps,
TerminalSteps,
)

:return: BehaviorSpec object.
"""
observation_shape = [tuple(obs.shape) for obs in agent_info.observations]
action_type = (
ActionType.DISCRETE
if brain_param_proto.vector_action_space_type == 0
else ActionType.CONTINUOUS
)
if action_type == ActionType.CONTINUOUS:
action_shape: Union[
int, Tuple[int, ...]
] = brain_param_proto.vector_action_size[0]
else:
action_shape = tuple(brain_param_proto.vector_action_size)
return BehaviorSpec(observation_shape, action_type, action_shape)
action_spec = brain_param_proto.action_spec
action_spec = ActionSpec(action_spec.num_continuous_actions,
tuple(branch for branch in action_spec.discrete_branch_sizes)
)
return BehaviorSpec(observation_shape, action_spec)
class OffsetBytesIO:

7
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
if self.policy.use_continuous_act:
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
else:
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
#discrete_actions = ModelUtils.list_to_tensor(batch["actions"][self.policy.continuous_act_size:], dtype=torch.long)
memories = [
ModelUtils.list_to_tensor(batch["memory"][i])

vis_obs.append(vis_ob)
else:
vis_obs = []
log_probs, entropy, values = self.policy.evaluate_actions(
vec_obs,
vis_obs,
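The commented-out line above hints at where this is going: the buffer stores one concatenated action row per agent, and the optimizer slices it into a float continuous part and an integer discrete part (Categorical log-probs need torch.long indices). A standalone sketch of that slicing, with the column layout and continuous_act_size assumed from the diff:

import torch

continuous_act_size = 2                                      # assumed; taken from the policy's ActionSpec
actions = torch.tensor([[0.1, -0.3, 2.0, 1.0],
                        [0.5,  0.2, 0.0, 2.0]])              # (batch, continuous + discrete) as stored in the buffer

continuous_actions = actions[:, :continuous_act_size]        # stays float32 for the Gaussian log-prob
discrete_actions = actions[:, continuous_act_size:].long()   # cast to int64 for the Categorical log-prob
assert continuous_actions.dtype == torch.float32 and discrete_actions.dtype == torch.int64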

1
ml-agents/mlagents/trainers/ppo/trainer.py


behavior_spec,
self.trainer_settings,
condition_sigma_on_obs=False, # Faster training for PPO
separate_critic=behavior_spec.is_action_continuous(),
)
return policy

4
ml-agents/mlagents/trainers/policy/tf_policy.py


reparameterize,
condition_sigma_on_obs,
)
if self.continuous_act_size > 0 and len(self.discrete_act_size) > 0:
raise UnityPolicyException(
"Tensorflow does not support mixed action spaces. Please run with --torch."
)
# for ghost trainer save/load snapshots
self.assign_phs: List[tf.Tensor] = []
self.assign_ops: List[tf.Operation] = []

72
ml-agents/mlagents/trainers/policy/torch_policy.py


self.actor_critic = ac_class(
observation_shapes=self.behavior_spec.observation_shapes,
network_settings=trainer_settings.network_settings,
act_type=behavior_spec.action_type,
act_size=self.act_size,
action_spec=self.behavior_spec.action_spec,
stream_names=reward_signal_names,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,

) -> Tuple[SplitObservations, np.ndarray]:
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
mask = None
if not self.use_continuous_act:
mask = torch.ones([len(decision_requests), np.sum(self.act_size)])
if decision_requests.action_mask is not None:
mask = torch.as_tensor(
1 - np.concatenate(decision_requests.action_mask, axis=1)
)
if self.discrete_act_size > 0:
mask = torch.ones([len(decision_requests), np.sum(self.discrete_act_branches)])
if decision_requests.action_mask is not None:
mask = torch.as_tensor(
1 - np.concatenate(decision_requests.action_mask, axis=1)
)
return vec_vis_obs, mask
def update_normalization(self, vector_obs: np.ndarray) -> None:

memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
all_log_probs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
:param vec_obs: List of vector observations.
:param vis_obs: List of visual observations.
:param masks: Loss masks for RNN, else None.
:param memories: Input memories when using RNN, else None.
:param seq_len: Sequence length when using RNN.
:return: Tuple of actions, log probabilities (dependent on all_log_probs), entropies, and
output memories, all as Torch Tensors.
if memories is None:
dists, memories = self.actor_critic.get_dists(
vec_obs, vis_obs, masks, memories, seq_len
)
else:
# If we're using LSTM, we need to execute the values to get the critic memories
dists, _, memories = self.actor_critic.get_dist_and_value(
vec_obs, vis_obs, masks, memories, seq_len
)
action_list = self.actor_critic.sample_action(dists)
log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
action_list, dists
actions, log_probs, entropies, value_heads, memories = self.actor_critic.get_action_stats_and_value(
vec_obs, vis_obs, masks, memories, seq_len
actions = torch.stack(action_list, dim=-1)
if self.use_continuous_act:
actions = actions[:, :, 0]
else:
actions = actions[:, 0, :]
return (actions, all_logs if all_log_probs else log_probs, entropies, memories)
return (
actions,
log_probs,
entropies,
value_heads,
memories,
)
def evaluate_actions(
self,

memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
dists, value_heads, _ = self.actor_critic.get_dist_and_value(
vec_obs, vis_obs, masks, memories, seq_len
)
action_list = [actions[..., i] for i in range(actions.shape[-1])]
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)
log_probs, entropies, value_heads = self.actor_critic.get_stats_and_value(
vec_obs, vis_obs, actions, masks, memories, seq_len
)
return log_probs, entropies, value_heads
@timed

run_out = {}
with torch.no_grad():
action, log_probs, entropy, memories = self.sample_actions(
action, log_probs, entropy, value_heads, memories = self.sample_actions(
run_out["action"] = ModelUtils.to_numpy(action)
# Todo - make pre_action difference
# Todo - make pre_action difference
run_out["action"] = ModelUtils.to_numpy(action)
run_out["value_heads"] = {
name: ModelUtils.to_numpy(t) for name, t in value_heads.items()
}
run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
run_out["learning_rate"] = 0.0
if self.use_recurrent:
run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)

15
ml-agents/mlagents/trainers/policy/policy.py


condition_sigma_on_obs: bool = True,
):
self.behavior_spec = behavior_spec
self.action_spec = behavior_spec.action_spec
# For mixed action spaces
self.continuous_act_size = self.action_spec.continuous_action_size
self.discrete_act_size = self.action_spec.discrete_action_size
self.discrete_act_branches = self.action_spec.discrete_action_branches
list(behavior_spec.discrete_action_branches)
if behavior_spec.is_action_discrete()
else [behavior_spec.action_size]
list(self.action_spec.discrete_action_branches)
if self.action_spec.is_action_discrete()
else [self.action_spec.action_size]
)
self.vec_obs_size = sum(
shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1

)
self.use_continuous_act = behavior_spec.is_action_continuous()
self.num_branches = self.behavior_spec.action_size
self.use_continuous_act = self.action_spec.is_action_continuous()
self.num_branches = self.action_spec.action_size
self.previous_action_dict: Dict[str, np.array] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.normalize = trainer_settings.network_settings.normalize
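With the ActionSpec attached to the BehaviorSpec, the policy can cache both action sizes instead of branching on a single ActionType. A small illustrative sketch of that bookkeeping (attribute names follow the diff; the real Policy also tracks observation sizes, memories, and normalization):

from typing import NamedTuple, Tuple

class ActionSpec(NamedTuple):                 # minimal mock of the spec introduced in base_env.py
    num_continuous_actions: int
    discrete_branch_sizes: Tuple[int, ...]

class PolicySketch:
    """Illustration only: caching hybrid action sizes from an ActionSpec."""
    def __init__(self, action_spec: ActionSpec):
        self.action_spec = action_spec
        self.continuous_act_size = action_spec.num_continuous_actions
        self.discrete_act_size = len(action_spec.discrete_branch_sizes)
        self.discrete_act_branches = action_spec.discrete_branch_sizes
        # Both sizes can be non-zero at once, which the old use_continuous_act flag could not express.

p = PolicySketch(ActionSpec(3, (2, 2)))
assert (p.continuous_act_size, p.discrete_act_size) == (3, 2)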

106
ml-agents/mlagents/trainers/tests/simple_test_envs.py


import numpy as np
from mlagents_envs.base_env import (
ActionSpec,
BaseEnv,
BehaviorSpec,
DecisionSteps,

BehaviorName,
ActionBuffer,
)
from mlagents_envs.tests.test_rpc_utils import proto_from_steps_and_action
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (

self.vis_obs_size = vis_obs_size
self.vec_obs_size = vec_obs_size
action_type = ActionType.DISCRETE if use_discrete else ActionType.CONTINUOUS
self.behavior_spec = BehaviorSpec(
self._make_obs_spec(),
action_type,
tuple(2 for _ in range(action_size)) if use_discrete else action_size,
)
if use_discrete:
self.behavior_spec = BehaviorSpec(
self._make_obs_spec(), ActionSpec(0, tuple(2 for _ in range(action_size)))
)
else:
self.behavior_spec = BehaviorSpec(
self._make_obs_spec(), ActionSpec(action_size, tuple())
)
self.action_size = action_size
self.names = brain_names
self.positions: Dict[str, List[float]] = {}

def close(self):
pass
class HybridEnvironment(SimpleEnvironment):
def __init__(
self,
brain_names,
step_size=STEP_SIZE,
num_visual=0,
num_vector=1,
vis_obs_size=VIS_OBS_SIZE,
vec_obs_size=OBS_SIZE,
continuous_action_size=1,
discrete_action_size=1,
):
super().__init__(brain_names, False)
self.continuous_env = SimpleEnvironment(
brain_names,
False,
step_size,
num_visual,
num_vector,
vis_obs_size,
vec_obs_size,
continuous_action_size,
)
self.discrete_env = SimpleEnvironment(
brain_names,
True,
step_size,
num_visual,
num_vector,
vis_obs_size,
vec_obs_size,
discrete_action_size,
)
super().__init__(
brain_names,
True, # This is needed for env to generate masks correctly
step_size=step_size,
num_visual=num_visual,
num_vector=num_vector,
action_size=discrete_action_size, # This is needed for env to generate masks correctly
)
# Number of steps to reveal the goal for. Lower is harder. Should be
# less than 1/step_size to force agent to use memory
self.behavior_spec = BehaviorSpec(
self._make_obs_spec(),
ActionSpec(continuous_action_size, tuple(2 for _ in range(discrete_action_size))),
)
self.continuous_action_size = continuous_action_size
self.discrete_action_size = discrete_action_size
self.continuous_action = {}
self.discrete_action = {}
def step(self) -> None:
assert all(action is not None for action in self.continuous_env.action.values())
assert all(action is not None for action in self.discrete_env.action.values())
for name in self.names:
cont_done = self.continuous_env._take_action(name)
disc_done = self.discrete_env._take_action(name)
all_done = cont_done and disc_done
if all_done:
reward = 0
for _pos in (
self.continuous_env.positions[name]
+ self.discrete_env.positions[name]
):
reward += (SUCCESS_REWARD * _pos * self.goal[name]) / len(
self.continuous_env.positions[name]
+ self.discrete_env.positions[name]
)
else:
reward = -TIME_PENALTY
self.rewards[name] += reward
self.step_result[name] = self._make_batched_step(name, all_done, reward)
def reset(self) -> None: # type: ignore
super().reset()
self.continuous_env.reset()
self.discrete_env.reset()
self.continuous_env.goal = self.goal
self.discrete_env.goal = self.goal
def set_actions(self, behavior_name: BehaviorName, action) -> None:
# print(action, self.goal[behavior_name])
continuous_action = action[:, : self.continuous_action_size]
discrete_action = action[:, self.continuous_action_size :]
self.continuous_env.set_actions(behavior_name, continuous_action)
self.discrete_env.set_actions(behavior_name, discrete_action)
class MemoryEnvironment(SimpleEnvironment):
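HybridEnvironment.set_actions above receives a single array per behavior whose leading columns are the continuous action and whose trailing columns are the discrete branches, then routes each slice to the wrapped continuous and discrete sub-environments. A numpy sketch of that column convention (sizes are illustrative):

import numpy as np

continuous_action_size, discrete_action_size, n_agents = 1, 2, 3

# One row per agent: [continuous..., discrete...], the layout the hybrid trainer emits.
action = np.hstack([
    np.random.uniform(-1.0, 1.0, size=(n_agents, continuous_action_size)).astype(np.float32),
    np.random.randint(0, 2, size=(n_agents, discrete_action_size)).astype(np.float32),
])

continuous_part = action[:, :continuous_action_size]   # forwarded to the continuous sub-env
discrete_part = action[:, continuous_action_size:]     # forwarded to the discrete sub-env
assert continuous_part.shape == (n_agents, 1) and discrete_part.shape == (n_agents, 2)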

6
ml-agents/mlagents/trainers/tests/torch/test_networks.py


[sample_obs], [], masks=masks
)
for act in actions:
# This is different from above for ONNX export
if action_type == ActionType.CONTINUOUS:
assert act.shape == (act_size[0], 1)
else:
assert act.shape == tuple(act_size)
assert act.shape == tuple(act_size)
assert mem_size == 0
assert is_cont == int(action_type == ActionType.CONTINUOUS)

124
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py


GAILRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents.trainers.settings import GAILSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

SEED = [42]
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = GAILRewardProvider(behavior_spec, gail_settings)
assert gail_rp.name == "GAIL"
#@pytest.mark.parametrize(
# "behavior_spec", [BehaviorSpec([(8,)], ActionSpec(2, tuple()))]
#)
#def test_construction(behavior_spec: BehaviorSpec) -> None:
# gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
# gail_rp = GAILRewardProvider(behavior_spec, gail_settings)
# assert gail_rp.name == "GAIL"
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = create_reward_provider(
RewardSignalType.GAIL, behavior_spec, gail_settings
)
assert gail_rp.name == "GAIL"
#@pytest.mark.parametrize(
# "behavior_spec", [BehaviorSpec([(8,)], ActionSpec(2, tuple()))]
#)
#def test_factory(behavior_spec: BehaviorSpec) -> None:
# gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
# gail_rp = create_reward_provider(
# RewardSignalType.GAIL, behavior_spec, gail_settings
# )
# assert gail_rp.name == "GAIL"
@pytest.mark.parametrize("seed", SEED)

BehaviorSpec([(8,), (24, 26, 1)], ActionType.CONTINUOUS, 2),
BehaviorSpec([(50,)], ActionType.DISCRETE, (2, 3, 3, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
BehaviorSpec([(8,), (24, 26, 1)], ActionSpec(2, tuple())),
BehaviorSpec([(50,)], ActionSpec(0, (2, 3, 3, 3))),
BehaviorSpec([(10,)], ActionSpec(0, (20,))),
],
)
@pytest.mark.parametrize("use_actions", [False, True])

assert (
reward_policy < init_reward_policy
) # Non-expert reward getting worse as network trains
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3, 3, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
],
)
@pytest.mark.parametrize("use_actions", [False, True])
@patch(
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer"
)
def test_reward_decreases_vail(
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
buffer_expert = create_agent_buffer(behavior_spec, 1000)
buffer_policy = create_agent_buffer(behavior_spec, 1000)
demo_to_buffer.return_value = None, buffer_expert
gail_settings = GAILSettings(
demo_path="", learning_rate=0.005, use_vail=True, use_actions=use_actions
)
DiscriminatorNetwork.initial_beta = 0.0
# we must set the initial value of beta to 0 for testing
# If we do not, the kl-loss will dominate early and will block the estimator
gail_rp = create_reward_provider(
RewardSignalType.GAIL, behavior_spec, gail_settings
)
for _ in range(200):
gail_rp.update(buffer_policy)
reward_expert = gail_rp.evaluate(buffer_expert)[0]
reward_policy = gail_rp.evaluate(buffer_policy)[0]
assert reward_expert >= 0 # GAIL / VAIL reward always positive
assert reward_policy >= 0
reward_expert = gail_rp.evaluate(buffer_expert)[0]
reward_policy = gail_rp.evaluate(buffer_policy)[0]
assert reward_expert > reward_policy # Expert reward greater than non-expert reward
#
#
#@pytest.mark.parametrize("seed", SEED)
#@pytest.mark.parametrize(
# "behavior_spec",
# [
# BehaviorSpec([(8,)], ActionSpec(2, tuple())),
# BehaviorSpec([(10,)], ActionSpec(0, (2, 3, 3, 3))),
# BehaviorSpec([(10,)], ActionSpec(0, (20,))),
# ],
#)
#@pytest.mark.parametrize("use_actions", [False, True])
#@patch(
# "mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer"
#)
#def test_reward_decreases_vail(
# demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
#) -> None:
# np.random.seed(seed)
# torch.manual_seed(seed)
# buffer_expert = create_agent_buffer(behavior_spec, 1000)
# buffer_policy = create_agent_buffer(behavior_spec, 1000)
# demo_to_buffer.return_value = None, buffer_expert
# gail_settings = GAILSettings(
# demo_path="", learning_rate=0.005, use_vail=True, use_actions=use_actions
# )
# DiscriminatorNetwork.initial_beta = 0.0
# # we must set the initial value of beta to 0 for testing
# # If we do not, the kl-loss will dominate early and will block the estimator
# gail_rp = create_reward_provider(
# RewardSignalType.GAIL, behavior_spec, gail_settings
# )
#
# for _ in range(200):
# gail_rp.update(buffer_policy)
# reward_expert = gail_rp.evaluate(buffer_expert)[0]
# reward_policy = gail_rp.evaluate(buffer_policy)[0]
# assert reward_expert >= 0 # GAIL / VAIL reward always positive
# assert reward_policy >= 0
# reward_expert = gail_rp.evaluate(buffer_expert)[0]
# reward_policy = gail_rp.evaluate(buffer_policy)[0]
# assert reward_expert > reward_policy # Expert reward greater than non-expert reward

5
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py


next_observations = [
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes
]
action = behavior_spec.create_random_action(1)[0, :]
action_buffer = behavior_spec.action_spec.create_random_action(1)
#action = behavior_spec.action_spec.create_random_action(1)[0, :]
action = np.concatenate([action_buffer.continuous, action_buffer.discrete], axis=1)
print(action)
for _ in range(number):
curr_split_obs = SplitObservations.from_observations(curr_observations)
next_split_obs = SplitObservations.from_observations(next_observations)

2
ml-agents/mlagents/trainers/torch/model_serialization.py


for shape in self.policy.behavior_spec.observation_shapes
if len(shape) == 3
]
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.discrete_act_branches)])
dummy_memories = torch.zeros(
batch_dim + seq_len_dim + [self.policy.export_memory_size]
)

37
ml-agents/mlagents/trainers/torch/distributions.py


import abc
from typing import List
from typing import List, Tuple
from mlagents.torch_utils import torch, nn
import numpy as np
import math

"""
pass
@abc.abstractmethod
def exported_model_output(self) -> torch.Tensor:
"""
Returns the tensor to be exported to ONNX for the distribution
"""
pass
@abc.abstractmethod
def structure_action(self, action: torch.Tensor) -> torch.Tensor:
"""
Return the structured action to be passed to the trainer
"""
pass
class DiscreteDistInstance(DistInstance):
@abc.abstractmethod

def entropy(self):
return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON)
def exported_model_output(self):
return self.sample()
def structure_action(self, action):
return action[:, :, 0]
class TanhGaussianDistInstance(GaussianDistInstance):
def __init__(self, mean, std):

).squeeze(-1)
def log_prob(self, value):
return torch.log(self.pdf(value))
return torch.log(self.pdf(value)).unsqueeze(-1)
return -torch.sum(self.probs * torch.log(self.probs), dim=-1)
return -torch.sum(self.probs * torch.log(self.probs), dim=-1).unsqueeze(-1)
def exported_model_output(self):
return self.all_log_prob()
def structure_action(self, action):
return action[:, 0, :].type(torch.float)
class GaussianDistribution(nn.Module):

torch.zeros(1, num_outputs, requires_grad=True)
)
def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
log_sigma = self.log_sigma
log_sigma = self.log_sigma.expand(inputs.shape[0], -1)
if self.tanh_squash:
return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
else:
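Several of the changes above add an explicit trailing dimension (unsqueeze(-1)) to discrete log-probs and entropies so that continuous and discrete terms can be concatenated with torch.cat(dim=1) rather than stacked. A standalone shape check, using plain torch.distributions in place of the project's DistInstance classes:

import torch
from torch import distributions

batch, n_cont = 4, 2
gauss = distributions.Normal(torch.zeros(batch, n_cont), torch.ones(batch, n_cont))
cont_sample = gauss.sample()                               # (batch, n_cont)
cont_log_prob = gauss.log_prob(cont_sample)                # (batch, n_cont)

cat = distributions.Categorical(logits=torch.zeros(batch, 3))
disc_sample = cat.sample()                                 # (batch,)
disc_log_prob = cat.log_prob(disc_sample).unsqueeze(-1)    # (batch, 1) so it concatenates cleanly

# Mirrors the switch from torch.stack(..., dim=-1) to torch.cat(..., dim=1) in get_probs_and_entropy.
hybrid_log_probs = torch.cat([cont_log_prob, disc_log_prob], dim=1)
assert hybrid_log_probs.shape == (batch, n_cont + 1)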

231
ml-agents/mlagents/trainers/torch/networks.py


from mlagents.torch_utils import torch, nn
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,
DistInstance,
)
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.distributions import DistInstance
from mlagents.trainers.torch.action_model import ActionModel
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads

pass
@abc.abstractmethod
def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
"""
Takes a List of Distribution instances and samples an action from each.
"""
pass
@abc.abstractmethod
def get_dists(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
"""
Returns distributions from this Actor, from which actions can be sampled.
If memory is enabled, return the memories as well.
:param vec_inputs: A List of vector inputs as tensors.
:param vis_inputs: A List of visual inputs as tensors.
:param masks: If using discrete actions, a Tensor of action masks.
:param memories: If using memory, a Tensor of initial memories.
:param sequence_length: If using memory, the sequence length.
:return: A Tuple of a List of action distribution instances, and memories.
Memories will be None if not using memory.
"""
pass
@abc.abstractmethod
def forward(
self,
vec_inputs: List[torch.Tensor],

pass
@abc.abstractmethod
def get_dist_and_value(
def get_action_stats_and_value(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],

) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
) -> Tuple[List[DistInstance], List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
"""
Returns distributions, from which actions can be sampled, and value estimates.
If memory is enabled, return the memories as well.

self,
observation_shapes: List[Tuple[int, ...]],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
action_spec: ActionSpec,
self.act_type = act_type
self.act_size = act_size
self.discrete_act_size = action_spec.discrete_action_size
self.discrete_act_branches = action_spec.discrete_action_branches
self.continuous_act_size = action_spec.continuous_action_size
self.is_continuous_int = torch.nn.Parameter(
torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
)
torch.Tensor([sum(act_size)]), requires_grad=False
torch.Tensor(action_spec.action_size)
)
self.is_continuous_int = torch.nn.Parameter(
torch.Tensor([int(self.continuous_act_size > 0)])
)
self.network_body = NetworkBody(observation_shapes, network_settings)
if network_settings.memory is not None:

if self.act_type == ActionType.CONTINUOUS:
self.distribution = GaussianDistribution(
self.encoding_size,
act_size[0],
conditional_sigma=conditional_sigma,
tanh_squash=tanh_squash,
)
else:
self.distribution = MultiCategoricalDistribution(
self.encoding_size, act_size
)
self.action_model = ActionModel(
self.encoding_size,
action_spec,
conditional_sigma=conditional_sigma,
tanh_squash=tanh_squash,
)
@property
def memory_size(self) -> int:

self.network_body.update_normalization(vector_obs)
def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
actions = []
for action_dist in dists:
action = action_dist.sample()
actions.append(action)
return actions
def get_dists(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[List[DistInstance], Optional[torch.Tensor]]:
encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
if self.act_type == ActionType.CONTINUOUS:
dists = self.distribution(encoding)
else:
dists = self.distribution(encoding, masks)
return dists, memories
def forward(
self,
vec_inputs: List[torch.Tensor],

"""
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
"""
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
if self.act_type == ActionType.CONTINUOUS:
action_list = self.sample_action(dists)
action_out = torch.stack(action_list, dim=-1)
else:
action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
encoding, memories_out = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=1
)
# TODO: How this is written depends on how the inference model is structured
action_out = self.action_model.get_action_out(encoding, masks)
return (
action_out,
self.version_number,

self,
observation_shapes: List[Tuple[int, ...]],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
action_spec: ActionSpec,
self.use_lstm = network_settings.memory is not None
act_type,
act_size,
action_spec,
conditional_sigma,
tanh_squash,
)

)
return self.value_heads(encoding), memories_out
def get_dist_and_value(
def get_stats_and_value(
actions: torch.Tensor,
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
if self.act_type == ActionType.CONTINUOUS:
dists = self.distribution(encoding)
else:
dists = self.distribution(encoding, masks=masks)
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
value_outputs = self.value_heads(encoding)
return log_probs, entropies, value_outputs
def get_action_stats_and_value(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
action, log_probs, entropies = self.action_model(encoding, masks)
return dists, value_outputs, memories
return action, log_probs, entropies, value_outputs, memories
class SeparateActorCritic(SimpleActor, ActorCritic):

network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
action_spec: ActionSpec,
# Give the Actor only half the memories. Note we previously validate
# that memory_size must be a multiple of 4.
act_type,
act_size,
action_spec,
conditional_sigma,
tanh_squash,
)

memories_out = None
return value_outputs, memories_out
def get_dist_and_value(
def get_stats_and_value(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
actions: torch.Tensor,
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
if self.use_lstm:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
else:
critic_mem = None
actor_mem = None
encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
value_outputs, critic_mem_outs = self.critic(
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)
return log_probs, entropies, value_outputs
def get_action_stats_and_value(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],

) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
if self.use_lstm:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)

dists, actor_mem_outs = self.get_dists(
vec_inputs,
vis_inputs,
memories=actor_mem,
sequence_length=sequence_length,
masks=masks,
encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
action, log_probs, entropies = self.action_model(encoding, masks)
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
return dists, value_outputs, mem_out
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
super().update_normalization(vector_obs)
self.critic.network_body.update_normalization(vector_obs)
return action, log_probs, entropies, value_outputs, mem_out
def __init__(self):
super().__init__()
self.__global_step = nn.Parameter(torch.Tensor([0]), requires_grad=False)
def __init__(self):
super().__init__()
self.__global_step = nn.Parameter(torch.Tensor([0]), requires_grad=False)
@property
def current_step(self):
return int(self.__global_step.item())
@property
def current_step(self):
return int(self.__global_step.item())
@current_step.setter
def current_step(self, value):
self.__global_step[:] = value
@current_step.setter
def current_step(self, value):
self.__global_step[:] = value
def increment(self, value):
self.__global_step += value
def increment(self, value):
self.__global_step += value
def __init__(self, lr):
# Todo: add learning rate decay
super().__init__()
self.learning_rate = torch.Tensor([lr])
def __init__(self, lr):
# Todo: add learning rate decay
super().__init__()
self.learning_rate = torch.Tensor([lr])

2
ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py


specs.observation_shapes, state_encoder_settings
)
self._action_flattener = ModelUtils.ActionFlattener(specs)
self._action_flattener = ModelUtils.ActionFlattener(specs.action_spec)
self.inverse_model_action_prediction = torch.nn.Sequential(
LinearEncoder(2 * settings.encoding_size, 1, 256),

2
ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py


vis_encode_type=EncoderType.SIMPLE,
memory=None,
)
self._action_flattener = ModelUtils.ActionFlattener(specs)
self._action_flattener = ModelUtils.ActionFlattener(specs.action_spec)
unencoded_size = (
self._action_flattener.flattened_size + 1 if settings.use_actions else 0
) # +1 is for dones

40
ml-agents/mlagents/trainers/torch/utils.py


)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import BehaviorSpec
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

}
class ActionFlattener:
def __init__(self, behavior_spec: BehaviorSpec):
self._specs = behavior_spec
def __init__(self, action_spec: ActionSpec):
self._specs = action_spec
if self._specs.is_action_continuous():
return self._specs.action_size
else:
return sum(self._specs.discrete_action_branches)
return self._specs.continuous_action_size + sum(self._specs.discrete_action_branches)
if self._specs.is_action_continuous():
return action
else:
return torch.cat(
_cont = action[: self._specs.continuous_action_size]
_disc = action[self._specs.continuous_action_size :]
_disc = torch.cat(
torch.as_tensor(action, dtype=torch.long),
torch.as_tensor(_disc, dtype=torch.long),
return torch.cat([_cont, _disc], dim=1)
#if self._specs.is_action_continuous():
# return action
#else:
# return torch.cat(
# ModelUtils.actions_to_onehot(
# torch.as_tensor(action, dtype=torch.long),
# self._specs.discrete_action_branches,
# ),
# dim=1,
# )
@staticmethod
def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:

for action, action_dist in zip(action_list, dists):
log_prob = action_dist.log_prob(action)
log_probs_list.append(log_prob)
entropies_list.append(action_dist.entropy())
entropy = action_dist.entropy()
entropies_list.append(entropy)
log_probs = torch.stack(log_probs_list, dim=-1)
entropies = torch.stack(entropies_list, dim=-1)
log_probs = torch.cat(log_probs_list, dim=1)
entropies = torch.cat(entropies_list, dim=1)
if not all_probs_list:
log_probs = log_probs.squeeze(-1)
entropies = entropies.squeeze(-1)
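The refactored ActionFlattener now takes an ActionSpec, reports flattened_size = continuous_action_size + sum(discrete_action_branches), and its forward appears to pass the continuous slice through while one-hot encoding each discrete branch before concatenating. A self-contained sketch of that flattening (branch sizes are illustrative, and F.one_hot stands in for ModelUtils.actions_to_onehot):

import torch
import torch.nn.functional as F

continuous_size = 2
branches = (3, 2)                        # two discrete branches with 3 and 2 options

def flatten_hybrid_action(action: torch.Tensor) -> torch.Tensor:
    # Flatten (batch, continuous + n_branches) actions to (batch, continuous + sum(branches)):
    # continuous columns pass through, each discrete column becomes a one-hot block.
    cont = action[:, :continuous_size]
    disc = action[:, continuous_size:].long()
    one_hots = [F.one_hot(disc[:, i], num_classes=b).float() for i, b in enumerate(branches)]
    return torch.cat([cont] + one_hots, dim=1)

batch = torch.tensor([[0.5, -0.2, 2.0, 1.0]])   # two continuous floats, branch choices 2 and 1
flat = flatten_hybrid_action(batch)
assert flat.shape == (1, continuous_size + sum(branches))   # 2 + 3 + 2 = 7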

93
ml-agents/mlagents/trainers/tests/torch/test_hybrid.py


import pytest
import attr
from mlagents.trainers.tests.simple_test_envs import (
SimpleEnvironment,
HybridEnvironment,
MemoryEnvironment,
RecordEnvironment,
)
from mlagents.trainers.demo_loader import write_demo
from mlagents.trainers.settings import (
NetworkSettings,
SelfPlaySettings,
BehavioralCloningSettings,
GAILSettings,
RewardSignalType,
EncoderType,
FrameworkType,
)
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
)
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.space_type_pb2 import discrete, continuous
from mlagents.trainers.tests.dummy_config import ppo_dummy_config, sac_dummy_config
from mlagents.trainers.tests.check_env_trains import (
check_environment_trains,
default_reward_processor,
)
BRAIN_NAME = "1D"
PPO_TORCH_CONFIG = attr.evolve(ppo_dummy_config(), framework=FrameworkType.PYTORCH)
SAC_TORCH_CONFIG = attr.evolve(sac_dummy_config(), framework=FrameworkType.PYTORCH)
# @pytest.mark.parametrize("use_discrete", [True, False])
# def test_simple_ppo(use_discrete):
# env = SimpleEnvironment([BRAIN_NAME], use_discrete=use_discrete)
# config = attr.evolve(PPO_TORCH_CONFIG)
# _check_environment_trains(env, {BRAIN_NAME: config})
def test_hybrid_ppo():
env = HybridEnvironment(
[BRAIN_NAME], continuous_action_size=1, discrete_action_size=1, step_size=0.8
)
new_hyperparams = attr.evolve(
PPO_TORCH_CONFIG.hyperparameters, batch_size=32, buffer_size=1280
)
config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=10000)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)
def test_conthybrid_ppo():
env = HybridEnvironment(
[BRAIN_NAME], continuous_action_size=1, discrete_action_size=0, step_size=0.8
)
config = attr.evolve(PPO_TORCH_CONFIG)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)
def test_dischybrid_ppo():
env = HybridEnvironment(
[BRAIN_NAME], continuous_action_size=0, discrete_action_size=1, step_size=0.8
)
config = attr.evolve(PPO_TORCH_CONFIG)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)
def test_3cdhybrid_ppo():
env = HybridEnvironment([BRAIN_NAME], continuous_action_size=2, discrete_action_size=1, step_size=0.8)
new_hyperparams = attr.evolve(
PPO_TORCH_CONFIG.hyperparameters, batch_size=128, buffer_size=1280, beta=0.01
)
config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=10000)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)
def test_3ddhybrid_ppo():
env = HybridEnvironment(
[BRAIN_NAME], continuous_action_size=1, discrete_action_size=2, step_size=0.8
)
new_hyperparams = attr.evolve(
PPO_TORCH_CONFIG.hyperparameters, batch_size=128, buffer_size=1280, beta=0.01
)
config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=10000)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)

87
ml-agents/mlagents/trainers/torch/action_model.py


import abc
from typing import List, Tuple
from mlagents.torch_utils import torch, nn
import numpy as np
import math
from mlagents.trainers.torch.layers import linear_layer, Initialization
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance, GaussianDistribution, MultiCategoricalDistribution
from mlagents.trainers.torch.utils import ModelUtils
from mlagents_envs.base_env import ActionSpec
EPSILON = 1e-7 # Small value to avoid divide by zero
class ActionModel(nn.Module):
def __init__(
self,
hidden_size: int,
action_spec: ActionSpec,
conditional_sigma: bool = False,
tanh_squash: bool = False,
):
super().__init__()
self.encoding_size = hidden_size
self.continuous_act_size = action_spec.continuous_action_size
self.discrete_act_branches = action_spec.discrete_action_branches
self.discrete_act_size = action_spec.discrete_action_size
self.action_spec = action_spec
self._split_list : List[int] = []
self._distributions = torch.nn.ModuleList()
if self.continuous_act_size > 0:
self._distributions.append(GaussianDistribution(
self.encoding_size,
self.continuous_act_size,
conditional_sigma=conditional_sigma,
tanh_squash=tanh_squash,
)
)
self._split_list.append(self.continuous_act_size)
if self.discrete_act_size > 0:
self._distributions.append(MultiCategoricalDistribution(self.encoding_size, self.discrete_act_branches))
self._split_list += [1 for _ in range(self.discrete_act_size)]
def _sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
"""
Samples actions from list of distribution instances
"""
actions = []
for action_dist in dists:
action = action_dist.sample()
actions.append(action)
return actions
def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[List[DistInstance], List[DiscreteDistInstance]]:
distribution_instances: List[DistInstance] = []
for distribution in self._distributions:
dist_instances = distribution(inputs, masks)
for dist_instance in dist_instances:
distribution_instances.append(dist_instance)
return distribution_instances
def evaluate(self, inputs: torch.Tensor, masks: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
dists = self._get_dists(inputs, masks)
split_actions = torch.split(actions, self._split_list, dim=1)
action_lists : List[torch.Tensor] = []
for split_action in split_actions:
action_list = [split_action[..., i] for i in range(split_action.shape[-1])]
action_lists += action_list
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_lists, dists)
return log_probs, entropies
def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
dists = self._get_dists(inputs, masks)
return torch.cat([dist.exported_model_output() for dist in dists], dim=1)
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
dists = self._get_dists(inputs, masks)
action_outs : List[torch.Tensor] = []
action_lists = self._sample_action(dists)
for action_list, dist in zip(action_lists, dists):
action_out = action_list.unsqueeze(-1)
action_outs.append(dist.structure_action(action_out))
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_lists, dists)
action = torch.cat(action_outs, dim=1)
return (action, log_probs, entropies)
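For orientation, here is a heavily simplified, self-contained analogue of what ActionModel does: a Gaussian head for the continuous slice, one Categorical head per discrete branch, and actions, log-probs, and entropies concatenated along dim=1. Action masking, conditional sigma, tanh squashing, and the project's distribution wrappers are omitted, and every layer name here is illustrative:

import torch
from torch import nn, distributions

class HybridActionHead(nn.Module):
    # Minimal stand-in for ActionModel: one Gaussian head plus one Categorical head per branch.
    def __init__(self, hidden_size: int, continuous_size: int, branches):
        super().__init__()
        self.mu = nn.Linear(hidden_size, continuous_size)
        self.log_sigma = nn.Parameter(torch.zeros(1, continuous_size))
        self.branch_logits = nn.ModuleList(nn.Linear(hidden_size, b) for b in branches)

    def forward(self, encoding: torch.Tensor):
        cont_dist = distributions.Normal(self.mu(encoding), self.log_sigma.exp())
        cont_action = cont_dist.sample()                          # (batch, continuous_size)
        log_probs = [cont_dist.log_prob(cont_action)]
        entropies = [cont_dist.entropy()]
        disc_actions = []
        for head in self.branch_logits:
            dist = distributions.Categorical(logits=head(encoding))
            a = dist.sample()                                     # (batch,)
            disc_actions.append(a.float().unsqueeze(-1))
            log_probs.append(dist.log_prob(a).unsqueeze(-1))
            entropies.append(dist.entropy().unsqueeze(-1))
        action = torch.cat([cont_action] + disc_actions, dim=1)   # (batch, continuous + n_branches)
        return action, torch.cat(log_probs, dim=1), torch.cat(entropies, dim=1)

head = HybridActionHead(hidden_size=16, continuous_size=2, branches=(3, 2))
action, log_probs, entropies = head(torch.randn(5, 16))
assert action.shape == (5, 4) and log_probs.shape == (5, 4) and entropies.shape == (5, 4)

The real class additionally builds a _split_list so that evaluate() can carve a stored hybrid action back into per-distribution pieces before computing log-probs.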