Andrew Cohen
4 years ago
Current commit 06e4356c
64 files changed, with 977 additions and 561 deletions
   2  com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
  12  com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
  53  com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  30  com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
  18  com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
   2  com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
  13  com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
  34  com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
  22  com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
   5  com.unity.ml-agents/CHANGELOG.md
   2  docs/Getting-Started.md
   2  docs/Training-Configuration-File.md
   2  docs/Training-on-Microsoft-Azure.md
   4  docs/Using-Docker.md
   2  docs/Using-Tensorboard.md
   4  docs/localized/zh-CN/docs/Getting-Started-with-Balance-Ball.md
   2  ml-agents-envs/setup.py
   2  ml-agents/mlagents/model_serialization.py
   5  ml-agents/mlagents/trainers/agent_processor.py
   2  ml-agents/mlagents/trainers/components/reward_signals/__init__.py
   2  ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py
   2  ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
  29  ml-agents/mlagents/trainers/env_manager.py
  11  ml-agents/mlagents/trainers/ghost/trainer.py
   3  ml-agents/mlagents/trainers/optimizer/optimizer.py
   2  ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
 166  ml-agents/mlagents/trainers/policy/policy.py
 334  ml-agents/mlagents/trainers/policy/tf_policy.py
   2  ml-agents/mlagents/trainers/ppo/optimizer.py
  21  ml-agents/mlagents/trainers/ppo/trainer.py
   4  ml-agents/mlagents/trainers/sac/network.py
   6  ml-agents/mlagents/trainers/sac/optimizer.py
  17  ml-agents/mlagents/trainers/sac/trainer.py
  12  ml-agents/mlagents/trainers/settings.py
  17  ml-agents/mlagents/trainers/stats.py
   2  ml-agents/mlagents/trainers/tests/test_barracuda_converter.py
   7  ml-agents/mlagents/trainers/tests/test_bcmodule.py
   2  ml-agents/mlagents/trainers/tests/test_distributions.py
   2  ml-agents/mlagents/trainers/tests/test_models.py
  17  ml-agents/mlagents/trainers/tests/test_nn_policy.py
  15  ml-agents/mlagents/trainers/tests/test_ppo.py
   6  ml-agents/mlagents/trainers/tests/test_reward_signals.py
  13  ml-agents/mlagents/trainers/tests/test_sac.py
   3  ml-agents/mlagents/trainers/tests/test_simple_rl.py
   2  ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py
   3  ml-agents/mlagents/trainers/tests/test_trainer_controller.py
   4  ml-agents/mlagents/trainers/trainer/rl_trainer.py
   9  ml-agents/mlagents/trainers/trainer/trainer.py
  36  ml-agents/mlagents/trainers/trainer_controller.py
   2  ml-agents/mlagents/trainers/tf/tensorflow_to_barracuda.py
  13  ml-agents/mlagents/trainers/tf/models.py
   2  ml-agents/mlagents/trainers/tf/distributions.py
 147  com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs
  11  com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta
  27  com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs
  11  com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta
  62  com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs
  11  com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta
   0  ml-agents/mlagents/trainers/tf/__init__.py
 285  ml-agents/mlagents/trainers/policy/nn_policy.py
   0  /ml-agents/mlagents/trainers/tf/tensorflow_to_barracuda.py
   0  /ml-agents/mlagents/trainers/tf/models.py
   0  /ml-agents/mlagents/trainers/tf/distributions.py
ml-agents/mlagents/trainers/policy/policy.py

from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import numpy as np

from mlagents_envs.exception import UnityException
from mlagents.model_serialization import SerializationSettings
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings, NetworkSettings


class UnityPolicyException(UnityException):
    """
    Related to errors with the Trainer.
    """

    pass


class Policy(ABC):
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_settings: TrainerSettings,
        model_path: str,
        load: bool = False,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
    ):
        self.behavior_spec = behavior_spec
        self.trainer_settings = trainer_settings
        self.network_settings: NetworkSettings = trainer_settings.network_settings
        self.seed = seed
        self.act_size = (
            list(behavior_spec.discrete_action_branches)
            if behavior_spec.is_action_discrete()
            else [behavior_spec.action_size]
        )
        self.vec_obs_size = sum(
            shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1
        )
        self.vis_obs_size = sum(
            1 for shape in behavior_spec.observation_shapes if len(shape) == 3
        )
        self.model_path = model_path
        self.initialize_path = self.trainer_settings.init_path
        self._keep_checkpoints = self.trainer_settings.keep_checkpoints
        self.use_continuous_act = behavior_spec.is_action_continuous()
        self.num_branches = self.behavior_spec.action_size
        self.previous_action_dict: Dict[str, np.array] = {}
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.normalize = trainer_settings.network_settings.normalize
        self.use_recurrent = self.network_settings.memory is not None
        self.load = load
        self.h_size = self.network_settings.hidden_units
        num_layers = self.network_settings.num_layers
        if num_layers < 1:
            num_layers = 1
        self.num_layers = num_layers

        self.vis_encode_type = self.network_settings.vis_encode_type
        self.tanh_squash = tanh_squash
        self.reparameterize = reparameterize
        self.condition_sigma_on_obs = condition_sigma_on_obs

        self.m_size = 0
        self.sequence_length = 1
        if self.network_settings.memory is not None:
            self.m_size = self.network_settings.memory.memory_size
            self.sequence_length = self.network_settings.memory.sequence_length

        # Non-exposed parameters; these aren't exposed because they don't have a
        # good explanation and usually shouldn't be touched.
        self.log_std_min = -20
        self.log_std_max = 2

    def make_empty_memory(self, num_agents):
        """
        Creates empty memory for use with RNNs
        :param num_agents: Number of agents.
        :return: Numpy array of zeros.
        """
        return np.zeros((num_agents, self.m_size), dtype=np.float32)

    def save_memories(
        self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
    ) -> None:
        if memory_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.memory_dict:
                memory_matrix[index, :] = self.memory_dict[agent_id]
        return memory_matrix

    def remove_memories(self, agent_ids):
        for agent_id in agent_ids:
            if agent_id in self.memory_dict:
                self.memory_dict.pop(agent_id)

    def make_empty_previous_action(self, num_agents):
        """
        Creates empty previous action for use with RNNs and discrete control
        :param num_agents: Number of agents.
        :return: Numpy array of zeros.
        """
        return np.zeros((num_agents, self.num_branches), dtype=np.int)

    def save_previous_action(
        self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
    ) -> None:
        if action_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.previous_action_dict[agent_id] = action_matrix[index, :]

    def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
        action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.previous_action_dict:
                action_matrix[index, :] = self.previous_action_dict[agent_id]
        return action_matrix

    def remove_previous_action(self, agent_ids):
        for agent_id in agent_ids:
            if agent_id in self.previous_action_dict:
                self.previous_action_dict.pop(agent_id)

    @abstractmethod
    def update_normalization(self, vector_obs: np.ndarray) -> None:
        pass

    @abstractmethod
    def increment_step(self, n_steps):
        pass

    @abstractmethod
    def get_current_step(self):
        pass

    @abstractmethod
    def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
        pass

    @abstractmethod
    def save(self, output_filepath: str, settings: SerializationSettings) -> None:
        pass

    @abstractmethod
    def load_weights(self, values: List[np.ndarray]) -> None:
        pass

    @abstractmethod
    def get_weights(self) -> List[np.ndarray]:
        return []

    @abstractmethod
    def init_load_weights(self) -> None:
        pass
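
Note (not part of the commit): Policy is now an abstract base class, so anything that subclasses it must supply the normalization, step-counting, checkpointing, and weight-access methods declared above. The sketch below is a minimal, hypothetical subclass written only to illustrate that required surface; the real implementations in this commit are TFPolicy and NNPolicy, and the InMemoryPolicy name and its in-memory weight store are assumptions made for the example.

# Illustrative sketch only: a hypothetical concrete Policy that keeps weights
# in memory, showing which abstract methods a subclass must implement.
from typing import List

import numpy as np

from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.policy.policy import Policy


class InMemoryPolicy(Policy):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._step = 0
        self._weights: List[np.ndarray] = []

    def update_normalization(self, vector_obs: np.ndarray) -> None:
        pass  # this sketch has no normalizer to update

    def increment_step(self, n_steps):
        self._step += n_steps
        return self._step

    def get_current_step(self):
        return self._step

    def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
        np.savez(checkpoint_path, *self._weights)  # persist whatever weights we hold

    def save(self, output_filepath: str, settings: SerializationSettings) -> None:
        np.savez(output_filepath, *self._weights)

    def load_weights(self, values: List[np.ndarray]) -> None:
        self._weights = [v.copy() for v in values]

    def get_weights(self) -> List[np.ndarray]:
        return [v.copy() for v in self._weights]

    def init_load_weights(self) -> None:
        pass  # nothing to prepare for the in-memory store
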
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs

#if UNITY_2020_1_OR_NEWER

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    public class ArticulationBodyJointExtractor : IJointExtractor
    {
        ArticulationBody m_Body;

        public ArticulationBodyJointExtractor(ArticulationBody body)
        {
            m_Body = body;
        }

        public int NumObservations(PhysicsSensorSettings settings)
        {
            return NumObservations(m_Body, settings);
        }

        public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings)
        {
            if (body == null || body.isRoot)
            {
                return 0;
            }

            var totalCount = 0;
            if (settings.UseJointPositionsAndAngles)
            {
                switch (body.jointType)
                {
                    case ArticulationJointType.RevoluteJoint:
                    case ArticulationJointType.SphericalJoint:
                        // Both RevoluteJoint and SphericalJoint have all angular components.
                        // We use sine and cosine of the angles for the observations.
                        totalCount += 2 * body.dofCount;
                        break;
                    case ArticulationJointType.FixedJoint:
                        // Since a FixedJoint can't move, there aren't any interesting observations for it.
                        break;
                    case ArticulationJointType.PrismaticJoint:
                        // One linear component
                        totalCount += body.dofCount;
                        break;
                }
            }

            if (settings.UseJointForces)
            {
                totalCount += body.dofCount;
            }

            return totalCount;
        }

        public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
        {
            if (m_Body == null || m_Body.isRoot)
            {
                return 0;
            }

            var currentOffset = offset;

            // Write joint positions
            if (settings.UseJointPositionsAndAngles)
            {
                switch (m_Body.jointType)
                {
                    case ArticulationJointType.RevoluteJoint:
                    case ArticulationJointType.SphericalJoint:
                        // All joint positions are angular
                        for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
                        {
                            var jointRotationRads = m_Body.jointPosition[dofIndex];
                            writer[currentOffset++] = Mathf.Sin(jointRotationRads);
                            writer[currentOffset++] = Mathf.Cos(jointRotationRads);
                        }
                        break;
                    case ArticulationJointType.FixedJoint:
                        // No observations
                        break;
                    case ArticulationJointType.PrismaticJoint:
                        writer[currentOffset++] = GetPrismaticValue();
                        break;
                }
            }

            if (settings.UseJointForces)
            {
                for (var dofIndex = 0; dofIndex < m_Body.dofCount; dofIndex++)
                {
                    // Take tanh to keep the value in [-1, 1]
                    writer[currentOffset++] = (float) System.Math.Tanh(m_Body.jointForce[dofIndex]);
                }
            }

            return currentOffset - offset;
        }

        float GetPrismaticValue()
        {
            // Prismatic joints should have at most one free axis.
            bool limited = false;
            var drive = m_Body.xDrive;
            if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion)
            {
                drive = m_Body.xDrive;
                limited = true;
            }
            else if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion)
            {
                drive = m_Body.yDrive;
                limited = true;
            }
            else if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion)
            {
                drive = m_Body.zDrive;
                limited = true;
            }

            var jointPos = m_Body.jointPosition[0];
            if (limited)
            {
                // If the motion is limited, interpolate between the limits.
                var upperLimit = drive.upperLimit;
                var lowerLimit = drive.lowerLimit;
                if (upperLimit <= lowerLimit)
                {
                    // Invalid limits (probably equal), so don't try to lerp
                    return 0;
                }
                var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos);

                // Convert [0, 1] -> [-1, 1]
                var normalized = 2.0f * invLerped - 1.0f;
                return normalized;
            }
            // Take tanh() to keep the value in [-1, 1]
            return (float) System.Math.Tanh(jointPos);
        }
    }
}
#endif
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyJointExtractor.cs.meta

fileFormatVersion: 2
guid: 238d15f867b9c4ced9cef331b7420b27
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:
com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs

using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Interface for generating observations from a physical joint or constraint.
    /// </summary>
    public interface IJointExtractor
    {
        /// <summary>
        /// Determine the number of observations that would be generated for the particular joint
        /// using the provided PhysicsSensorSettings.
        /// </summary>
        /// <param name="settings"></param>
        /// <returns>Number of floats that will be written.</returns>
        int NumObservations(PhysicsSensorSettings settings);

        /// <summary>
        /// Write the observations to the ObservationWriter, starting at the specified offset.
        /// </summary>
        /// <param name="settings"></param>
        /// <param name="writer"></param>
        /// <param name="offset"></param>
        /// <returns>Number of floats that were written.</returns>
        int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset);
    }
}
com.unity.ml-agents.extensions/Runtime/Sensors/IJointExtractor.cs.meta

fileFormatVersion: 2
guid: 2d2a01ea194334a4682d5c8cad4a956b
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    public class RigidBodyJointExtractor : IJointExtractor
    {
        Rigidbody m_Body;
        Joint m_Joint;

        public RigidBodyJointExtractor(Rigidbody body)
        {
            m_Body = body;
            m_Joint = m_Body?.GetComponent<Joint>();
        }

        public int NumObservations(PhysicsSensorSettings settings)
        {
            return NumObservations(m_Body, m_Joint, settings);
        }

        public static int NumObservations(Rigidbody body, Joint joint, PhysicsSensorSettings settings)
        {
            if (body == null || joint == null)
            {
                return 0;
            }

            var numObservations = 0;
            if (settings.UseJointForces)
            {
                // 3 force and 3 torque values
                numObservations += 6;
            }

            return numObservations;
        }

        public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
        {
            if (m_Body == null || m_Joint == null)
            {
                return 0;
            }

            var currentOffset = offset;
            if (settings.UseJointForces)
            {
                // Take tanh of the forces and torques to ensure they're in [-1, 1]
                writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.x);
                writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.y);
                writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentForce.z);

                writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.x);
                writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.y);
                writer[currentOffset++] = (float)System.Math.Tanh(m_Joint.currentTorque.z);
            }
            return currentOffset - offset;
        }
    }
}
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyJointExtractor.cs.meta

fileFormatVersion: 2
guid: 5014d7ab95c6a44469f447b8a7019746
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:
ml-agents/mlagents/trainers/policy/nn_policy.py

from typing import Any, Dict, Optional, List
from mlagents.tf_utils import tf
from mlagents_envs.timers import timed
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
from mlagents.trainers.models import EncoderType
from mlagents.trainers.models import ModelUtils
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
)

EPSILON = 1e-6  # Small value to avoid divide by zero


class NNPolicy(TFPolicy):
    def __init__(
        self,
        seed: int,
        behavior_spec: BehaviorSpec,
        trainer_params: TrainerSettings,
        is_training: bool,
        model_path: str,
        load: bool,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
        create_tf_graph: bool = True,
    ):
        """
        Policy that uses a multilayer perceptron to map the observations to actions. Could
        also use a CNN to encode visual input prior to the MLP. Supports discrete and
        continuous action spaces, as well as recurrent networks.
        :param seed: Random seed.
        :param behavior_spec: Assigned BehaviorSpec object.
        :param trainer_params: Defined training parameters.
        :param is_training: Whether the model should be trained.
        :param load: Whether a pre-trained model will be loaded or a new one created.
        :param model_path: Path where the model should be saved and loaded.
        :param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
        :param reparameterize: Whether we are using the resampling trick to update the policy in continuous output.
        """
        super().__init__(seed, behavior_spec, trainer_params, model_path, load)
        self.grads = None
        self.update_batch: Optional[tf.Operation] = None
        num_layers = self.network_settings.num_layers
        self.h_size = self.network_settings.hidden_units
        if num_layers < 1:
            num_layers = 1
        self.num_layers = num_layers
        self.vis_encode_type = self.network_settings.vis_encode_type
        self.tanh_squash = tanh_squash
        self.reparameterize = reparameterize
        self.condition_sigma_on_obs = condition_sigma_on_obs
        self.trainable_variables: List[tf.Variable] = []

        # Non-exposed parameters; these aren't exposed because they don't have a
        # good explanation and usually shouldn't be touched.
        self.log_std_min = -20
        self.log_std_max = 2
        if create_tf_graph:
            self.create_tf_graph()

    def get_trainable_variables(self) -> List[tf.Variable]:
        """
        Returns a List of the trainable variables in this policy. If create_tf_graph hasn't been called,
        returns an empty list.
        """
        return self.trainable_variables

    def create_tf_graph(self) -> None:
        """
        Builds the tensorflow graph needed for this policy.
        """
        with self.graph.as_default():
            tf.set_random_seed(self.seed)
            _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if len(_vars) > 0:
                # We assume the first thing created in the graph is the Policy. If
                # already populated, don't create more tensors.
                return

            self.create_input_placeholders()
            encoded = self._create_encoder(
                self.visual_in,
                self.processed_vector_in,
                self.h_size,
                self.num_layers,
                self.vis_encode_type,
            )
            if self.use_continuous_act:
                self._create_cc_actor(
                    encoded,
                    self.tanh_squash,
                    self.reparameterize,
                    self.condition_sigma_on_obs,
                )
                self.saliency = tf.reduce_mean(
                    tf.square(tf.gradients(self.output, self.vector_in)), axis=1
                )
            else:
                self._create_dc_actor(encoded)
                self.saliency = tf.reduce_mean(
                    tf.square(tf.gradients(self.output_pre, self.vector_in)), axis=1
                )

            self.trainable_variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
            )
            self.trainable_variables += tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm"
            )  # LSTMs need to be root scope for Barracuda export

        self.inference_dict: Dict[str, tf.Tensor] = {
            "action": self.output,
            "log_probs": self.all_log_probs,
            "entropy": self.entropy,
        }
        if self.use_continuous_act:
            self.inference_dict["pre_action"] = self.output_pre
        if self.use_recurrent:
            self.inference_dict["memory_out"] = self.memory_out

        # We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
        # it will re-load the full graph
        self._initialize_graph()

    @timed
    def evaluate(
        self, decision_requests: DecisionSteps, global_agent_ids: List[str]
    ) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
        :param decision_requests: DecisionSteps object containing inputs.
        :param global_agent_ids: The global (with worker ID) agent ids of the data in the batched_step_result.
        :return: Outputs from network as defined by self.inference_dict.
        """
        feed_dict = {
            self.batch_size_ph: len(decision_requests),
            self.sequence_length_ph: 1,
        }
        if self.use_recurrent:
            if not self.use_continuous_act:
                feed_dict[self.prev_action] = self.retrieve_previous_action(
                    global_agent_ids
                )
            feed_dict[self.memory_in] = self.retrieve_memories(global_agent_ids)
        feed_dict = self.fill_eval_dict(feed_dict, decision_requests)
        run_out = self._execute_model(feed_dict, self.inference_dict)
        return run_out

    def _create_encoder(
        self,
        visual_in: List[tf.Tensor],
        vector_in: tf.Tensor,
        h_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
    ) -> tf.Tensor:
        """
        Creates an encoder for visual and vector observations.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: Type of visual encoder to use if visual input.
        :return: The hidden layer (tf.Tensor) after the encoder.
        """
        with tf.variable_scope("policy"):
            encoded = ModelUtils.create_observation_streams(
                self.visual_in,
                self.processed_vector_in,
                1,
                h_size,
                num_layers,
                vis_encode_type,
            )[0]
        return encoded

    def _create_cc_actor(
        self,
        encoded: tf.Tensor,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
    ) -> None:
        """
        Creates Continuous control actor-critic model.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: Type of visual encoder to use if visual input.
        :param tanh_squash: Whether to use a tanh function, or a clipped output.
        :param reparameterize: Whether we are using the resampling trick to update the policy.
        """
        if self.use_recurrent:
            self.memory_in = tf.placeholder(
                shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
            )
            hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
                encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
            )

            self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
        else:
            hidden_policy = encoded

        with tf.variable_scope("policy"):
            distribution = GaussianDistribution(
                hidden_policy,
                self.act_size,
                reparameterize=reparameterize,
                tanh_squash=tanh_squash,
                condition_sigma=condition_sigma_on_obs,
            )

        if tanh_squash:
            self.output_pre = distribution.sample
            self.output = tf.identity(self.output_pre, name="action")
        else:
            self.output_pre = distribution.sample
            # Clip and scale output to ensure actions are always within [-1, 1] range.
            output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
            self.output = tf.identity(output_post, name="action")

        self.selected_actions = tf.stop_gradient(self.output)

        self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs")
        self.entropy = distribution.entropy

        # We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
        self.total_log_probs = distribution.total_log_probs

    def _create_dc_actor(self, encoded: tf.Tensor) -> None:
        """
        Creates Discrete control actor-critic model.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: Type of visual encoder to use if visual input.
        """
        if self.use_recurrent:
            self.prev_action = tf.placeholder(
                shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
            )
            prev_action_oh = tf.concat(
                [
                    tf.one_hot(self.prev_action[:, i], self.act_size[i])
                    for i in range(len(self.act_size))
                ],
                axis=1,
            )
            hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)

            self.memory_in = tf.placeholder(
                shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
            )
            hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
                hidden_policy,
                self.memory_in,
                self.sequence_length_ph,
                name="lstm_policy",
            )

            self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
        else:
            hidden_policy = encoded

        self.action_masks = tf.placeholder(
            shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
        )

        with tf.variable_scope("policy"):
            distribution = MultiCategoricalDistribution(
                hidden_policy, self.act_size, self.action_masks
            )
        # It's important that we are able to feed_dict a value into this tensor to get the
        # right one-hot encoding, so we can't do identity on it.
        self.output = distribution.sample
        self.all_log_probs = tf.identity(distribution.log_probs, name="action")
        self.selected_actions = tf.stop_gradient(
            distribution.sample_onehot
        )  # In discrete, these are onehot
        self.entropy = distribution.entropy
        self.total_log_probs = distribution.total_log_probs
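
Note (not part of the commit): NNPolicy builds its TensorFlow graph at construction time (when create_tf_graph is True) and is then driven by a trainer roughly as sketched below. This is an illustration under assumptions, not the actual trainer code: `spec`, `settings`, and `steps` stand in for the BehaviorSpec, TrainerSettings, and DecisionSteps objects the surrounding trainer provides, the model_path is hypothetical, and the global agent ids are simplified (the real trainer prefixes them with a worker id).

# Illustrative sketch only; assumes `spec` (BehaviorSpec), `settings` (TrainerSettings),
# and `steps` (DecisionSteps) already exist in the calling trainer code.
from mlagents.trainers.policy.nn_policy import NNPolicy

policy = NNPolicy(
    seed=0,
    behavior_spec=spec,
    trainer_params=settings,
    is_training=True,
    model_path="./results/example_run/Behavior",  # hypothetical path
    load=False,
)

# These ids are only used as dictionary keys by the memory/previous-action
# helpers inherited from the Policy base class.
global_ids = [str(agent_id) for agent_id in steps.agent_id]

run_out = policy.evaluate(steps, global_ids)                  # runs self.inference_dict through the graph
policy.save_memories(global_ids, run_out.get("memory_out"))   # no-op when the policy isn't recurrent
actions = run_out["action"]                                   # forwarded back to the environment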