
Hotfixes for Release 0.15.1 (#3698)

* [bug-fix] Increase height of wall in CrawlerStatic (#3650)

* [bug-fix] Improve performance for PPO with continuous actions (#3662)

* Corrected a typo in a name of a function (#3670)

OnEpsiodeBegin was corrected to OnEpisodeBegin in Migrating.md document

* Add Academy.AutomaticSteppingEnabled to migration (#3666)

* Fix editor port in Dockerfile (#3674)

* Hotfix memory leak on Python (#3664)

* Hotfix memory leak on Python

* Fixing

* Fixing a bug in the heuristic policy. A decision should not be requested when the agent is done

* [bug-fix] Make Python able to deal with 0-step episodes (#3671)

* adding some comments

Co-authored-by: Ervin T <ervin@unity3d.com>

* Remove vis_encode_type from list of required (#3677)

* Update changelog (#3678)

* Shorten timeout duration for environment close (#3679)

The timeout duration for closing an environment was set to the
same duration as the timeout when waiting ...
Branch: release-0.15.1
Current commit: ec278616
45 files changed, with 320 insertions and 147 deletions
 1. .pylintrc (2 changes)
 2. Dockerfile (6 changes)
 3. Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab (13 changes)
 4. README.md (3 changes)
 5. com.unity.ml-agents/CHANGELOG.md (11 changes)
 6. com.unity.ml-agents/Runtime/Academy.cs (2 changes)
 7. com.unity.ml-agents/Runtime/Agent.cs (3 changes)
 8. com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (17 changes)
 9. com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs (5 changes)
10. com.unity.ml-agents/package.json (2 changes)
11. docs/Migrating.md (4 changes)
12. gym-unity/gym_unity/__init__.py (2 changes)
13. gym-unity/gym_unity/envs/__init__.py (32 changes)
14. gym-unity/gym_unity/tests/test_gym.py (48 changes)
15. ml-agents-envs/mlagents_envs/__init__.py (2 changes)
16. ml-agents-envs/mlagents_envs/environment.py (23 changes)
17. ml-agents-envs/mlagents_envs/side_channel/outgoing_message.py (4 changes)
18. ml-agents-envs/mlagents_envs/side_channel/side_channel.py (4 changes)
19. ml-agents/mlagents/model_serialization.py (5 changes)
20. ml-agents/mlagents/trainers/__init__.py (2 changes)
21. ml-agents/mlagents/trainers/agent_processor.py (31 changes)
22. ml-agents/mlagents/trainers/components/reward_signals/__init__.py (5 changes)
23. ml-agents/mlagents/trainers/curriculum.py (4 changes)
24. ml-agents/mlagents/trainers/distributions.py (38 changes)
25. ml-agents/mlagents/trainers/env_manager.py (5 changes)
26. ml-agents/mlagents/trainers/ghost/trainer.py (5 changes)
27. ml-agents/mlagents/trainers/learn.py (19 changes)
28. ml-agents/mlagents/trainers/meta_curriculum.py (4 changes)
29. ml-agents/mlagents/trainers/policy/nn_policy.py (1 change)
30. ml-agents/mlagents/trainers/policy/tf_policy.py (23 changes)
31. ml-agents/mlagents/trainers/ppo/trainer.py (4 changes)
32. ml-agents/mlagents/trainers/sac/optimizer.py (5 changes)
33. ml-agents/mlagents/trainers/sac/trainer.py (6 changes)
34. ml-agents/mlagents/trainers/stats.py (7 changes)
35. ml-agents/mlagents/trainers/subprocess_env_manager.py (5 changes)
36. ml-agents/mlagents/trainers/tests/simple_test_envs.py (2 changes)
37. ml-agents/mlagents/trainers/tests/test_agent_processor.py (9 changes)
38. ml-agents/mlagents/trainers/tests/test_distributions.py (10 changes)
39. ml-agents/mlagents/trainers/tests/test_simple_rl.py (23 changes)
40. ml-agents/mlagents/trainers/trainer/trainer.py (5 changes)
41. ml-agents/mlagents/trainers/trainer_controller.py (4 changes)
42. ml-agents/mlagents/trainers/trainer_util.py (5 changes)
43. setup.cfg (1 change)
44. ml-agents-envs/mlagents_envs/logging_util.py (46 changes, new file)
45. ml-agents/mlagents/logging_util.py (10 changes, file removed)

.pylintrc (2 changes)

 # Appears to be https://github.com/PyCQA/pylint/issues/2981
 W0201,
+# Using the global statement
+W0603,

Dockerfile (6 changes)

 WORKDIR /ml-agents
 RUN pip install -e .
-# port 5005 is the port used in in Editor training.
-EXPOSE 5005
+# Port 5004 is the port used in in Editor training.
+# Environments will start from port 5005,
+# so allow enough ports for several environments.
+EXPOSE 5004-5050
 ENTRYPOINT ["mlagents-learn"]

Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab (13 changes)

 m_InferenceDevice: 0
 m_BehaviorType: 0
 m_BehaviorName: CrawlerStatic
-m_TeamID: 0
-m_useChildSensors: 1
+TeamId: 0
+m_UseChildSensors: 1
 --- !u!114 &114230237520033992
 MonoBehaviour:
 m_ObjectHideFlags: 0

 m_Script: {fileID: 11500000, guid: 2f37c30a5e8d04117947188818902ef3, type: 3}
 m_Name:
 m_EditorClassIdentifier:
 agentParameters:
 maxStep: 0
+hasUpgradedFromAgentParameters: 1
+maxStep: 5000
 target: {fileID: 4749909135913778}
 ground: {fileID: 4856650706546504}

 m_Name:
 m_EditorClassIdentifier:
 DecisionPeriod: 5
-RepeatAction: 0
+TakeActionsBetweenDecisions: 0
 offsetStep: 0
 --- !u!1 &1492926997393242
 GameObject:

 m_PrefabAsset: {fileID: 0}
 m_GameObject: {fileID: 1995322274649904}
 m_LocalRotation: {x: 0, y: -0, z: -0, w: 1}
-m_LocalPosition: {x: -0, y: 0.5, z: 0}
-m_LocalScale: {x: 0.01, y: 0.01, z: 0.01}
+m_LocalPosition: {x: -0, y: 1.5, z: 0}
+m_LocalScale: {x: 0.01, y: 0.03, z: 0.01}
 m_Children: []
 m_Father: {fileID: 4924174722017668}
 m_RootOrder: 1

README.md (3 changes)

 * Train using concurrent Unity environment instances
 ## Releases & Documentation
-**Our latest, stable release is 0.15.0. Click
+**Our latest, stable release is 0.15.1. Click
 [here](docs/Readme.md) to
 get started with the latest release of ML-Agents.**

 | **Version** | **Release Date** | **Source** | **Documentation** | **Download** |
 |:-------:|:------:|:-------------:|:-------:|:------------:|
 | **0.15.0** | March 18, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/0.15.0) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/0.15.0/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/0.15.0.zip) |
 | **0.14.1** | February 26, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/0.14.1) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/0.14.1/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/0.14.1.zip) |
 | **0.14.0** | February 13, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/0.14.0) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/0.14.0/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/0.14.0.zip) |
 | **0.13.1** | January 21, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/0.13.1) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/0.13.1/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/0.13.1.zip) |

com.unity.ml-agents/CHANGELOG.md (11 changes)

 and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
+## [0.15.1-preview] - 2020-03-30
+### Bug Fixes
+- Raise the wall in CrawlerStatic scene to prevent Agent from falling off. (#3650)
+- Fixed an issue where specifying `vis_encode_type` was required only for SAC. (#3677)
+- Fixed the reported entropy values for continuous actions (#3684)
+- Fixed an issue where switching models using `SetModel()` during training would use an excessive amount of memory. (#3664)
+- Environment subprocesses now close immediately on timeout or wrong API version. (#3679)
+- Fixed an issue in the gym wrapper that would raise an exception if an Agent called EndEpisode multiple times in the same step. (#3700)
+- Fixed an issue where logging output was not visible; logging levels are now set consistently (#3703).
 ## [0.15.0-preview] - 2020-03-18
 ### Major Changes
 - `Agent.CollectObservations` now takes a VectorSensor argument. (#3352, #3389)

com.unity.ml-agents/Runtime/Academy.cs (2 changes)

 /// Unity package version of com.unity.ml-agents.
 /// This must match the version string in package.json and is checked in a unit test.
 /// </summary>
-internal const string k_PackageVersion = "0.15.0-preview";
+internal const string k_PackageVersion = "0.15.1-preview";
 const int k_EditorTrainingPort = 5004;

com.unity.ml-agents/Runtime/Agent.cs (3 changes)

 void NotifyAgentDone(DoneReason doneReason)
 {
     m_Info.episodeId = m_EpisodeId;
     m_Info.reward = m_Reward;
     m_Info.done = true;
     m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;

+        // If everything is the same, don't make any changes.
+        return;
     }
+    NotifyAgentDone(DoneReason.Disabled);
     m_PolicyFactory.model = model;
     m_PolicyFactory.inferenceDevice = inferenceDevice;
     m_PolicyFactory.behaviorName = behaviorName;

com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (17 changes)

 {
     if (m_CurrentUnityRlOutput.AgentInfos.ContainsKey(behaviorName))
     {
-        if (output == null)
-        {
-            output = new UnityRLInitializationOutputProto();
-        }
-        var brainParameters = m_UnsentBrainKeys[behaviorName];
-        output.BrainParameters.Add(brainParameters.ToProto(behaviorName, true));
+        if (m_CurrentUnityRlOutput.AgentInfos[behaviorName].CalculateSize() > 0)
+        {
+            // Only send the BrainParameters if there is a non empty list of
+            // AgentInfos ready to be sent.
+            // This is to ensure that The Python side will always have a first
+            // observation when receiving the BrainParameters
+            if (output == null)
+            {
+                output = new UnityRLInitializationOutputProto();
+            }
+            var brainParameters = m_UnsentBrainKeys[behaviorName];
+            output.BrainParameters.Add(brainParameters.ToProto(behaviorName, true));
+        }
     }
 }

com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs (5 changes)

 public void RequestDecision(AgentInfo info, List<ISensor> sensors)
 {
     StepSensors(sensors);
-    m_LastDecision = m_Heuristic.Invoke();
+    if (!info.done)
+    {
+        m_LastDecision = m_Heuristic.Invoke();
+    }
 }
 /// <inheritdoc />

com.unity.ml-agents/package.json (2 changes)

 {
   "name": "com.unity.ml-agents",
   "displayName": "ML Agents",
-  "version": "0.15.0-preview",
+  "version": "0.15.1-preview",
   "unity": "2018.4",
   "description": "Add interactivity to your game with Machine Learning Agents trained using Deep Reinforcement Learning.",
   "dependencies": {

docs/Migrating.md (4 changes)

 * The interface for SideChannels was changed:
   * In C#, `OnMessageReceived` now takes a `IncomingMessage` argument, and `QueueMessageToSend` takes an `OutgoingMessage` argument.
   * In python, `on_message_received` now takes a `IncomingMessage` argument, and `queue_message_to_send` takes an `OutgoingMessage` argument.
+* Automatic stepping for Academy is now controlled from the AutomaticSteppingEnabled property.
 ### Steps to Migrate
 * Add the `using MLAgents.Sensors;` in addition to `using MLAgents;` on top of your Agent's script.

 * We strongly recommend replacing the following methods with their new equivalent as they will be removed in a later release:
   * `InitializeAgent()` to `Initialize()`
   * `AgentAction()` to `OnActionReceived()`
-  * `AgentReset()` to `OnEpsiodeBegin()`
+  * `AgentReset()` to `OnEpisodeBegin()`
+* Replace calls to Academy.EnableAutomaticStepping()/DisableAutomaticStepping() with Academy.AutomaticSteppingEnabled = true/false.
 ## Migrating from 0.13 to 0.14

gym-unity/gym_unity/__init__.py (2 changes)

-__version__ = "0.15.0"
+__version__ = "0.15.1"

gym-unity/gym_unity/envs/__init__.py (32 changes)

-import logging
 import itertools
 import numpy as np
 from typing import Any, Dict, List, Optional, Tuple, Union

 from mlagents_envs.environment import UnityEnvironment
 from mlagents_envs.base_env import BatchedStepResult
+from mlagents_envs import logging_util
 class UnityGymException(error.Error):

     pass
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger("gym_unity")
+logger = logging_util.get_logger(__name__)
+logging_util.set_log_level(logging_util.INFO)
 GymSingleStepResult = Tuple[np.ndarray, float, bool, Dict]
 GymMultiStepResult = Tuple[List[np.ndarray], List[float], List[bool], Dict]

 def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
     n_extra_agents = step_result.n_agents() - self._n_agents
-    if n_extra_agents < 0 or n_extra_agents > self._n_agents:
+    if n_extra_agents < 0:
         # or too many requested a decision
         raise UnityGymException(
             "The number of agents in the scene does not match the expected number."
         )

     # only cares about the ordering.
     for index, agent_id in enumerate(step_result.agent_id):
         if not self._previous_step_result.contains_agent(agent_id):
+            if step_result.done[index]:
+                # If the Agent is already done (e.g. it ended its episode twice in one step)
+                # Don't try to register it here.
+                continue
             # Register this agent, and get the reward of the previous agent that
             # was in its index, so that we can return it to the gym.
             last_reward = self.agent_mapper.register_new_agent_id(agent_id)

     """
     Declare the agent done with the corresponding final reward.
     """
-    gym_index = self._agent_id_to_gym_index.pop(agent_id)
-    self._done_agents_index_to_last_reward[gym_index] = reward
+    if agent_id in self._agent_id_to_gym_index:
+        gym_index = self._agent_id_to_gym_index.pop(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+    else:
+        # Agent was never registered in the first place (e.g. EndEpisode called multiple times)
+        pass
 def register_new_agent_id(self, agent_id: int) -> float:
     """

     self._gym_id_order = list(agent_ids)
 def mark_agent_done(self, agent_id: int, reward: float) -> None:
-    gym_index = self._gym_id_order.index(agent_id)
-    self._done_agents_index_to_last_reward[gym_index] = reward
-    self._gym_id_order[gym_index] = -1
+    try:
+        gym_index = self._gym_id_order.index(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+        self._gym_id_order[gym_index] = -1
+    except ValueError:
+        # Agent was never registered in the first place (e.g. EndEpisode called multiple times)
+        pass
 def register_new_agent_id(self, agent_id: int) -> float:
     original_index = self._gym_id_order.index(-1)
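The `try/except ValueError` relies on `list.index` raising `ValueError` for an id that was never registered, which is exactly how an agent that calls `EndEpisode` twice in one step shows up. A standalone sketch of that behavior (hypothetical ids, not the real mapper class):

    # Minimal sketch: list.index raises ValueError for an unregistered id,
    # which is what happens when EndEpisode is called twice in one step.
    gym_id_order = [1001, 1002, 1003]  # hypothetical registered agent ids

    def mark_agent_done(agent_id: int, reward: float) -> None:
        try:
            gym_index = gym_id_order.index(agent_id)
            gym_id_order[gym_index] = -1  # free the slot, as the mapper does
            print(f"agent {agent_id} done at index {gym_index} with reward {reward}")
        except ValueError:
            # Unknown agent id: it was never registered, so ignore it.
            print(f"ignoring unknown agent {agent_id}")

    mark_agent_done(1002, 1.0)   # known id: slot freed
    mark_agent_done(9999, -1.0)  # unknown id: safely ignored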

gym-unity/gym_unity/tests/test_gym.py (48 changes)

     assert expected_agent_id == agent_id
+@mock.patch("gym_unity.envs.UnityEnvironment")
+def test_sanitize_action_new_agent_done(mock_env):
+    mock_spec = create_mock_group_spec(
+        vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
+    )
+    mock_step = create_mock_vector_step_result(num_agents=3)
+    mock_step.agent_id = np.array(range(5))
+    setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
+    env = UnityEnv(" ", use_visual=False, multiagent=True)
+    received_step_result = create_mock_vector_step_result(num_agents=7)
+    received_step_result.agent_id = np.array(range(7))
+    # agent #3 (id = 2) is Done
+    # so is the "new" agent (id = 5)
+    done = [False] * 7
+    done[2] = True
+    done[5] = True
+    received_step_result.done = np.array(done)
+    sanitized_result = env._sanitize_info(received_step_result)
+    for expected_agent_id, agent_id in zip([0, 1, 6, 3, 4], sanitized_result.agent_id):
+        assert expected_agent_id == agent_id
+@mock.patch("gym_unity.envs.UnityEnvironment")
+def test_sanitize_action_single_agent_multiple_done(mock_env):
+    mock_spec = create_mock_group_spec(
+        vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
+    )
+    mock_step = create_mock_vector_step_result(num_agents=1)
+    mock_step.agent_id = np.array(range(1))
+    setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
+    env = UnityEnv(" ", use_visual=False, multiagent=False)
+    received_step_result = create_mock_vector_step_result(num_agents=3)
+    received_step_result.agent_id = np.array(range(3))
+    # original agent (id = 0) is Done
+    # so is the "new" agent (id = 1)
+    done = [True, True, False]
+    received_step_result.done = np.array(done)
+    sanitized_result = env._sanitize_info(received_step_result)
+    for expected_agent_id, agent_id in zip([2], sanitized_result.agent_id):
+        assert expected_agent_id == agent_id
 # Helper methods

     # Mark some agents as done with their last rewards.
     mapper.mark_agent_done(1001, 42.0)
     mapper.mark_agent_done(1004, 1337.0)
+    # Make sure we can handle an unknown agent id being marked done.
+    # This can happen when an agent ends an episode on the same step it starts.
+    mapper.mark_agent_done(9999, -1.0)
     # Now add new agents, and get the rewards of the agent they replaced.
     old_reward1 = mapper.register_new_agent_id(2001)

ml-agents-envs/mlagents_envs/__init__.py (2 changes)

-__version__ = "0.15.0"
+__version__ = "0.15.1"

ml-agents-envs/mlagents_envs/environment.py (23 changes)

 import atexit
 import glob
 import uuid
-import logging
 import numpy as np
 import os
 import subprocess

+from mlagents_envs.logging_util import get_logger
 from mlagents_envs.side_channel.side_channel import SideChannel, IncomingMessage
 from mlagents_envs.base_env import (

 import struct
-logger = logging.getLogger("mlagents_envs")
+logger = get_logger(__name__)
 class UnityEnvironment(BaseEnv):

     aca_output = self.send_academy_parameters(rl_init_parameters_in)
     aca_params = aca_output.rl_initialization_output
 except UnityTimeOutException:
-    self._close()
+    self._close(0)
-    self._close()
+    self._close(0)
     raise UnityEnvironmentException(
         f"The communication API version is not compatible between Unity and python. "
         f"Python API: {UnityEnvironment.API_VERSION}, Unity API: {unity_communicator_version}.\n "

 def executable_launcher(self, file_name, docker_training, no_graphics, args):
     launch_string = self.validate_environment_path(file_name)
     if launch_string is None:
-        self._close()
+        self._close(0)
         raise UnityEnvironmentException(
             f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
         )

     else:
         raise UnityEnvironmentException("No Unity environment is loaded.")
-def _close(self):
+def _close(self, timeout: Optional[int] = None) -> None:
+    """
+    Close the communicator and environment subprocess (if necessary).
+    :int timeout: [Optional] Number of seconds to wait for the environment to shut down before
+        force-killing it. Defaults to `self.timeout_wait`.
+    """
+    if timeout is None:
+        timeout = self.timeout_wait

-    self.proc1.wait(timeout=self.timeout_wait)
+    self.proc1.wait(timeout=timeout)
     signal_name = self.returncode_to_signal_name(self.proc1.returncode)
     signal_name = f" ({signal_name})" if signal_name else ""
     return_info = f"Environment shut down with return code {self.proc1.returncode}{signal_name}."
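The new parameter lets the error paths above pass `_close(0)` so a wedged environment subprocess is reaped immediately instead of after the full `timeout_wait`. A rough sketch of the wait-then-kill pattern this feeds (an illustration, not the file's exact cleanup code):

    import subprocess

    def shutdown(proc: subprocess.Popen, timeout: float) -> None:
        try:
            # With timeout=0 this raises immediately if the process is still
            # alive, which is what the _close(0) call sites above rely on.
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()  # force-kill the environment subprocess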

ml-agents-envs/mlagents_envs/side_channel/outgoing_message.py (4 changes)

 from typing import List
 import struct
-import logging
+from mlagents_envs.logging_util import get_logger
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 class OutgoingMessage:

ml-agents-envs/mlagents_envs/side_channel/side_channel.py (4 changes)

 from abc import ABC, abstractmethod
 from typing import List
 import uuid
-import logging
+from mlagents_envs.logging_util import get_logger
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 class SideChannel(ABC):

ml-agents/mlagents/model_serialization.py (5 changes)

 from distutils.util import strtobool
 import os
-import logging
 from typing import Any, List, Set, NamedTuple
 from distutils.version import LooseVersion

 from tensorflow.python.platform import gfile
 from tensorflow.python.framework import graph_util
+from mlagents_envs.logging_util import get_logger
 from mlagents.trainers import tensorflow_to_barracuda as tf2bc
 if LooseVersion(tf.__version__) < LooseVersion("1.12.0"):

+logger = get_logger(__name__)
-logger = logging.getLogger("mlagents.trainers")
 POSSIBLE_INPUT_NODES = frozenset(
 [

ml-agents/mlagents/trainers/__init__.py (2 changes)

-__version__ = "0.15.0"
+__version__ = "0.15.1"

ml-agents/mlagents/trainers/agent_processor.py (31 changes)

 import sys
-from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Any
 from collections import defaultdict, Counter, deque
 from mlagents_envs.base_env import BatchedStepResult, StepResult

 for _entropy in take_action_outputs["entropy"]:
     self.stats_reporter.add_stat("Policy/Entropy", _entropy)
-terminated_agents: Set[str] = set()
 # Make unique agent_ids that are global across workers
 action_global_agent_ids = [
     get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids

 stored_take_action_outputs = self.last_take_action_outputs.get(
     global_id, None
 )
 if stored_agent_step is not None and stored_take_action_outputs is not None:
     # We know the step is from the same worker, so use the local agent id.
     obs = stored_agent_step.obs

 traj_queue.put(trajectory)
 self.experience_buffers[global_id] = []
 if curr_agent_step.done:
+    # Record episode length for agents which have had at least
+    # 1 step. Done after reset ignored.
     self.stats_reporter.add_stat(
         "Environment/Cumulative Reward",
         self.episode_rewards.get(global_id, 0),

         self.episode_steps.get(global_id, 0),
     )
-    terminated_agents.add(global_id)
 elif not curr_agent_step.done:
     self.episode_steps[global_id] += 1

     batched_step_result.agent_id_to_index[_id],
 )
-for terminated_id in terminated_agents:
-    self._clean_agent_data(terminated_id)
+# Delete all done agents, regardless of if they had a 0-length episode.
+if curr_agent_step.done:
+    self._clean_agent_data(global_id)
 for _gid in action_global_agent_ids:
     # If the ID doesn't have a last step result, the agent just reset,

 """
 Removes the data for an Agent.
 """
-del self.experience_buffers[global_id]
-del self.last_take_action_outputs[global_id]
-del self.last_step_result[global_id]
-del self.episode_steps[global_id]
-del self.episode_rewards[global_id]
+self._safe_delete(self.experience_buffers, global_id)
+self._safe_delete(self.last_take_action_outputs, global_id)
+self._safe_delete(self.last_step_result, global_id)
+self._safe_delete(self.episode_steps, global_id)
+self._safe_delete(self.episode_rewards, global_id)
+def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None:
+    """
+    Safely removes data from a dictionary. If not found,
+    don't delete.
+    """
+    if key in my_dictionary:
+        del my_dictionary[key]
 def publish_trajectory_queue(
     self, trajectory_queue: "AgentManagerQueue[Trajectory]"

ml-agents/mlagents/trainers/components/reward_signals/__init__.py (5 changes)

-import logging
 from typing import Any, Dict, List
 from collections import namedtuple
 import numpy as np

+from mlagents_envs.logging_util import get_logger
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 RewardSignalResult = namedtuple(
     "RewardSignalResult", ["scaled_reward", "unscaled_reward"]

ml-agents/mlagents/trainers/curriculum.py (4 changes)

 from .exception import CurriculumConfigError, CurriculumLoadingError
-import logging
+from mlagents_envs.logging_util import get_logger
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class Curriculum:

ml-agents/mlagents/trainers/distributions.py (38 changes)

 act_size: List[int],
 reparameterize: bool = False,
 tanh_squash: bool = False,
+condition_sigma: bool = True,
 log_sigma_min: float = -20,
 log_sigma_max: float = 2,
 ):

 :param log_sigma_max: Maximum log standard deviation to clip by.
 """
 encoded = self._create_mu_log_sigma(
-    logits, act_size, log_sigma_min, log_sigma_max
+    logits,
+    act_size,
+    log_sigma_min,
+    log_sigma_max,
+    condition_sigma=condition_sigma,
 )
 self._sampled_policy = self._create_sampled_policy(encoded)
 if not reparameterize:

 act_size: List[int],
 log_sigma_min: float,
 log_sigma_max: float,
+condition_sigma: bool,
 ) -> "GaussianDistribution.MuSigmaTensors":
 mu = tf.layers.dense(

     reuse=tf.AUTO_REUSE,
 )
-# Policy-dependent log_sigma_sq
-log_sigma = tf.layers.dense(
-    logits,
-    act_size[0],
-    activation=None,
-    name="log_std",
-    kernel_initializer=ModelUtils.scaled_init(0.01),
-)
+if condition_sigma:
+    # Policy-dependent log_sigma_sq
+    log_sigma = tf.layers.dense(
+        logits,
+        act_size[0],
+        activation=None,
+        name="log_std",
+        kernel_initializer=ModelUtils.scaled_init(0.01),
+    )
+else:
+    log_sigma = tf.get_variable(
+        "log_std",
+        [act_size[0]],
+        dtype=tf.float32,
+        initializer=tf.zeros_initializer(),
+    )
 log_sigma = tf.clip_by_value(log_sigma, log_sigma_min, log_sigma_max)
 sigma = tf.exp(log_sigma)
 return self.MuSigmaTensors(mu, log_sigma, sigma)

 self, encoded: "GaussianDistribution.MuSigmaTensors"
 ) -> tf.Tensor:
 single_dim_entropy = 0.5 * tf.reduce_mean(
-    tf.log(2 * np.pi * np.e) + tf.square(encoded.log_sigma)
+    tf.log(2 * np.pi * np.e) + 2 * encoded.log_sigma
 )
 # Make entropy the right shape
 return tf.ones_like(tf.reshape(encoded.mu[:, 0], [-1])) * single_dim_entropy
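The corrected term matches the closed-form entropy of a diagonal Gaussian, H = 0.5 * (log(2*pi*e) + 2*log_sigma) per dimension; the old code squared the log standard deviation instead of doubling it. A quick NumPy check (not trainer code):

    import numpy as np

    log_sigma = 1.0  # the value the updated unit test feeds in
    correct = 0.5 * (np.log(2 * np.pi * np.e) + 2 * log_sigma)   # 0.5 * (2.8379 + 2) ~= 2.42
    buggy = 0.5 * (np.log(2 * np.pi * np.e) + log_sigma ** 2)    # 0.5 * (2.8379 + 1) ~= 1.92
    print(round(correct, 2))  # 2.42, the value asserted in test_distributions.py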

 Adjust probabilities for squashed sample before output
 """
-probs -= tf.log(1 - squashed_policy ** 2 + EPSILON)
-return probs
+adjusted_probs = probs - tf.log(1 - squashed_policy ** 2 + EPSILON)
+return adjusted_probs
 @property
 def total_log_probs(self) -> tf.Tensor:
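The adjustment itself is the standard tanh change-of-variables term, log p(a) = log p(u) - log(1 - tanh(u)^2 + EPSILON); the rewrite simply stops mutating `probs` in place. A small NumPy check of the underlying identity (my own sketch):

    import numpy as np

    EPSILON = 1e-6  # same role as the module-level constant in distributions.py
    u = 0.7
    a = np.tanh(u)

    # d(tanh)/du = 1 - tanh(u)^2, verified against a numerical derivative
    eps = 1e-6
    numeric = (np.tanh(u + eps) - np.tanh(u - eps)) / (2 * eps)
    assert abs(numeric - (1 - a ** 2)) < 1e-9

    # Change of variables: subtract log|da/du| from the pre-squash log-probability
    log_p_u = -0.5 * (u ** 2 + np.log(2 * np.pi))     # standard normal log-density
    log_p_a = log_p_u - np.log(1 - a ** 2 + EPSILON)  # what the adjustment above computes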

ml-agents/mlagents/trainers/env_manager.py (5 changes)

 from abc import ABC, abstractmethod
-import logging
 from typing import List, Dict, NamedTuple, Iterable
 from mlagents_envs.base_env import BatchedStepResult, AgentGroupSpec, AgentGroup
 from mlagents.trainers.brain import BrainParameters

+from mlagents_envs.logging_util import get_logger
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class EnvironmentStep(NamedTuple):

ml-agents/mlagents/trainers/ghost/trainer.py (5 changes)

 from typing import Deque, Dict, List, Any, cast
 import numpy as np
-import logging
+from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.brain import BrainParameters
 from mlagents.trainers.policy import Policy
 from mlagents.trainers.policy.tf_policy import TFPolicy

 from mlagents.trainers.agent_processor import AgentManagerQueue
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class GhostTrainer(Trainer):

ml-agents/mlagents/trainers/learn.py (19 changes)

 # # Unity ML-Agents Toolkit
-import logging
 import argparse
 import os

 from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
 from mlagents_envs.exception import UnityEnvironmentException
 from mlagents_envs.timers import hierarchical_timer, get_timer_tree
-from mlagents.logging_util import create_logger
+from mlagents_envs import logging_util
+logger = logging_util.get_logger(__name__)
 def _create_parser():

 with open(timing_path, "w") as f:
     json.dump(get_timer_tree(), f, indent=4)
 except FileNotFoundError:
-    logging.warning(
+    logger.warning(
         f"Unable to save to {timing_path}. Make sure the directory exists"
     )

 shutil.copyfile(src_f, dst_f)
 os.chmod(dst_f, 0o775)  # Make executable
 except Exception as e:
-    logging.getLogger("mlagents.trainers").info(e)
+    logger.info(e)
 env_path = "/ml-agents/{env_path}".format(env_path=env_path)
 return env_path

 print(get_version_string())
 if options.debug:
-    log_level = logging.DEBUG
+    log_level = logging_util.DEBUG
-    log_level = logging.INFO
+    log_level = logging_util.INFO
-trainer_logger = create_logger("mlagents.trainers", log_level)
+logging_util.set_log_level(log_level)
-trainer_logger.debug("Configuration for this run:")
-trainer_logger.debug(json.dumps(options._asdict(), indent=4))
+logger.debug("Configuration for this run:")
+logger.debug(json.dumps(options._asdict(), indent=4))
 run_seed = options.seed
 if options.cpu:

ml-agents/mlagents/trainers/meta_curriculum.py (4 changes)

 from typing import Dict, Set
 from mlagents.trainers.curriculum import Curriculum
-import logging
+from mlagents_envs.logging_util import get_logger
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class MetaCurriculum:

ml-agents/mlagents/trainers/policy/nn_policy.py (1 change)

 self.act_size,
 reparameterize=reparameterize,
 tanh_squash=tanh_squash,
+condition_sigma=condition_sigma_on_obs,
 )
 if tanh_squash:

ml-agents/mlagents/trainers/policy/tf_policy.py (23 changes)

-import logging
 from typing import Any, Dict, List, Optional
 import abc
 import numpy as np

+from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.policy import Policy
 from mlagents.trainers.action_info import ActionInfo
 from mlagents.trainers.trajectory import SplitObservations

-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class UnityPolicyException(UnityException):

 """
 if batched_step_result.n_agents() == 0:
     return ActionInfo.empty()
+agents_done = [
+    agent
+    for agent, done in zip(
+        batched_step_result.agent_id, batched_step_result.done
+    )
+    if done
+]
+self.remove_memories(agents_done)
+self.remove_previous_action(agents_done)
 global_agent_ids = [
     get_global_agent_id(worker_id, int(agent_id))

 def create_input_placeholders(self):
     with self.graph.as_default():
-        self.global_step, self.increment_step_op, self.steps_to_increment = (
-            ModelUtils.create_global_steps()
-        )
+        (
+            self.global_step,
+            self.increment_step_op,
+            self.steps_to_increment,
+        ) = ModelUtils.create_global_steps()
         self.visual_in = ModelUtils.create_visual_input_placeholders(
             self.brain.camera_resolutions
         )

ml-agents/mlagents/trainers/ppo/trainer.py (4 changes)

 # ## ML-Agent Learning (PPO)
 # Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347
-import logging
+from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.policy.nn_policy import NNPolicy
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
 from mlagents.trainers.brain import BrainParameters

 from mlagents.trainers.exception import UnityTrainerException
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class PPOTrainer(RLTrainer):

ml-agents/mlagents/trainers/sac/optimizer.py (5 changes)

-import logging
+from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.sac.network import SACPolicyNetwork, SACTargetNetwork
 from mlagents.trainers.models import LearningRateSchedule, EncoderType, ModelUtils
 from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer

 EPSILON = 1e-6  # Small value to avoid divide by zero
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 POLICY_SCOPE = ""
 TARGET_SCOPE = "target_network"

 "q1_loss": self.q1_loss,
 "q2_loss": self.q2_loss,
 "entropy_coef": self.ent_coef,
 "entropy": self.policy.entropy,
 "update_batch": self.update_batch_policy,
 "update_value": self.update_batch_value,
 "update_entropy": self.update_batch_entropy,

ml-agents/mlagents/trainers/sac/trainer.py (6 changes)

 # Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290
 # and implemented in https://github.com/hill-a/stable-baselines
-import logging
 from collections import defaultdict
 from typing import Dict
 import os

+from mlagents_envs.logging_util import get_logger
 from mlagents_envs.timers import timed
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.policy.nn_policy import NNPolicy

 from mlagents.trainers.exception import UnityTrainerException
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 BUFFER_TRUNCATE_PERCENT = 0.8

 "memory_size",
 "model_path",
 "reward_signals",
-"vis_encode_type",
 ]
 self._check_param_keys()

ml-agents/mlagents/trainers/stats.py (7 changes)

 import abc
 import csv
 import os
+from mlagents_envs.logging_util import get_logger
+from mlagents_envs.timers import set_gauge
+from mlagents.tf_utils import tf
-from mlagents.tf_utils import tf
-from mlagents_envs.timers import set_gauge
+logger = get_logger(__name__)
 class StatsSummary(NamedTuple):

ml-agents/mlagents/trainers/subprocess_env_manager.py (5 changes)

-import logging
 from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
 import cloudpickle

 from multiprocessing.connection import Connection
 from queue import Empty as EmptyQueueException
 from mlagents_envs.base_env import BaseEnv, AgentGroup
+from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult
 from mlagents_envs.timers import (
     TimerNode,

 from mlagents_envs.side_channel.side_channel import SideChannel
 from mlagents.trainers.brain_conversion_utils import group_spec_to_brain_parameters
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class EnvironmentCommand(NamedTuple):

ml-agents/mlagents/trainers/tests/simple_test_envs.py (2 changes)

 VIS_OBS_SIZE = (20, 20, 3)
 STEP_SIZE = 0.1
-TIME_PENALTY = 0.001
+TIME_PENALTY = 0.01
 MIN_STEPS = int(1.0 / STEP_SIZE) + 1
 SUCCESS_REWARD = 1.0 + MIN_STEPS * TIME_PENALTY
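Plugging the new constants in (a quick arithmetic check, not part of the test file):

    STEP_SIZE = 0.1
    TIME_PENALTY = 0.01
    MIN_STEPS = int(1.0 / STEP_SIZE) + 1             # 11, the shortest possible episode
    SUCCESS_REWARD = 1.0 + MIN_STEPS * TIME_PENALTY  # 1.0 + 0.11 = 1.11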

ml-agents/mlagents/trainers/tests/test_agent_processor.py (9 changes)

 assert len(processor.last_take_action_outputs.keys()) == 0
 assert len(processor.episode_steps.keys()) == 0
 assert len(processor.episode_rewards.keys()) == 0
+assert len(processor.last_step_result.keys()) == 0
+# check that steps with immediate dones don't add to dicts
+processor.add_experiences(mock_done_step, 0, ActionInfo.empty())
+assert len(processor.experience_buffers.keys()) == 0
+assert len(processor.last_take_action_outputs.keys()) == 0
+assert len(processor.episode_steps.keys()) == 0
+assert len(processor.episode_rewards.keys()) == 0
+assert len(processor.last_step_result.keys()) == 0
 def test_end_episode():

ml-agents/mlagents/trainers/tests/test_distributions.py (10 changes)

 def test_gaussian_distribution():
     with tf.Graph().as_default():
-        logits = tf.Variable(initial_value=[[0, 0]], trainable=True, dtype=tf.float32)
+        logits = tf.Variable(initial_value=[[1, 1]], trainable=True, dtype=tf.float32)
         distribution = GaussianDistribution(
             logits,
             act_size=VECTOR_ACTION_SPACE,

         assert out.shape[1] == VECTOR_ACTION_SPACE[0]
         output = sess.run([distribution.total_log_probs])
         assert output[0].shape[0] == 1
+        # Test entropy is correct
+        log_std_tensor = tf.get_default_graph().get_tensor_by_name(
+            "log_std/BiasAdd:0"
+        )
+        feed_dict = {log_std_tensor: [[1.0, 1.0]]}
+        entropy = sess.run([distribution.entropy], feed_dict=feed_dict)
+        # Entropy with log_std of 1.0 should be 2.42
+        assert pytest.approx(entropy[0], 0.01) == 2.42
 def test_tanh_distribution():

ml-agents/mlagents/trainers/tests/test_simple_rl.py (23 changes)

 lambd: 0.95
 learning_rate: 5.0e-3
 learning_rate_schedule: constant
-max_steps: 2000
+max_steps: 3000
 memory_size: 16
 normalize: false
 num_epoch: 3

 # Custom reward processors should be built within the test function and passed to _check_environment_trains
 # Default is average over the last 5 final rewards
 def default_reward_processor(rewards, last_n_rewards=5):
     rewards_to_use = rewards[-last_n_rewards:]
     # For debugging tests
     print("Last {} rewards:".format(last_n_rewards), rewards_to_use)
     return np.array(rewards[-last_n_rewards:], dtype=np.float32).mean()

 trainer_config,
 reward_processor=default_reward_processor,
 meta_curriculum=None,
-success_threshold=0.99,
+success_threshold=0.9,
 env_manager=None,
 ):
 # Create controller and begin training.

 if (
     success_threshold is not None
 ):  # For tests where we are just checking setup and not reward
     processed_rewards = [
         reward_processor(rewards) for rewards in env.final_rewards.values()
     ]

 def test_recurrent_ppo(use_discrete):
     env = Memory1DEnvironment([BRAIN_NAME], use_discrete=use_discrete)
     override_vals = {
-        "max_steps": 3000,
+        "max_steps": 5000,
         "learning_rate": 1e-3,
-    _check_environment_trains(env, config)
+    _check_environment_trains(env, config, success_threshold=0.9)
 @pytest.mark.parametrize("use_discrete", [True, False])

 @pytest.mark.parametrize("use_discrete", [True, False])
 def test_recurrent_sac(use_discrete):
     env = Memory1DEnvironment([BRAIN_NAME], use_discrete=use_discrete)
     override_vals = {"batch_size": 32, "use_recurrent": True, "max_steps": 2000}
     config = generate_config(SAC_CONFIG, override_vals)
     _check_environment_trains(env, config)
 @pytest.mark.parametrize("use_discrete", [True, False])
 def test_simple_ghost(use_discrete):
     env = Simple1DEnvironment(
         [BRAIN_NAME + "?team=0", BRAIN_NAME + "?team=1"], use_discrete=use_discrete

     processed_rewards = [
         default_reward_processor(rewards) for rewards in env.final_rewards.values()
     ]
-    success_threshold = 0.99
+    success_threshold = 0.9
     assert any(reward > success_threshold for reward in processed_rewards) and any(
         reward < success_threshold for reward in processed_rewards
     )

ml-agents/mlagents/trainers/trainer/trainer.py (5 changes)

 # # Unity ML-Agents Toolkit
-import logging
 from typing import Dict, List, Deque, Any
 import time
 import abc

 from collections import deque
 from mlagents_envs.timers import set_gauge
+from mlagents_envs.logging_util import get_logger
 from mlagents.model_serialization import export_policy_model, SerializationSettings
 from mlagents.trainers.policy.tf_policy import TFPolicy
 from mlagents.trainers.stats import StatsReporter

 from mlagents.trainers.exception import UnityTrainerException
 from mlagents_envs.timers import hierarchical_timer
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class Trainer(abc.ABC):

ml-agents/mlagents/trainers/trainer_controller.py (4 changes)

 import os
 import sys
-import logging
 from typing import Dict, Optional, Set
 from collections import defaultdict

+from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.env_manager import EnvManager
 from mlagents_envs.exception import (
     UnityEnvironmentException,

 self.trainer_factory = trainer_factory
 self.model_path = model_path
 self.summaries_dir = summaries_dir
-self.logger = logging.getLogger("mlagents.trainers")
+self.logger = get_logger(__name__)
 self.run_id = run_id
 self.save_freq = save_freq
 self.train_model = train

ml-agents/mlagents/trainers/trainer_util.py (5 changes)

 import os
 import yaml
 from typing import Any, Dict, TextIO
-import logging
+from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.meta_curriculum import MetaCurriculum
 from mlagents.trainers.exception import TrainerConfigError
 from mlagents.trainers.trainer import Trainer

 from mlagents.trainers.ghost.trainer import GhostTrainer
-logger = logging.getLogger("mlagents.trainers")
+logger = get_logger(__name__)
 class TrainerFactory:

setup.cfg (1 change)

 I200,
 banned-modules = tensorflow = use mlagents.tf_utils instead (it handles tf2 compat).
+    logging = use mlagents_envs.logging_util instead

ml-agents-envs/mlagents_envs/logging_util.py (46 changes, new file)

import logging  # noqa I251

CRITICAL = logging.CRITICAL
FATAL = logging.FATAL
ERROR = logging.ERROR
WARNING = logging.WARNING
INFO = logging.INFO
DEBUG = logging.DEBUG
NOTSET = logging.NOTSET

_loggers = set()
_log_level = NOTSET
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
LOG_FORMAT = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

def get_logger(name: str) -> logging.Logger:
    """
    Create a logger with the specified name. The logger will use the log level
    specified by set_log_level()
    """
    logger = logging.getLogger(name=name)
    # If we've already set the log level, make sure new loggers use it
    if _log_level != NOTSET:
        logger.setLevel(_log_level)
    # Keep track of this logger so that we can change the log level later
    _loggers.add(logger)
    return logger

def set_log_level(log_level: int) -> None:
    """
    Set the ML-Agents logging level. This will also configure the logging format (if it hasn't already been set).
    """
    global _log_level
    _log_level = log_level
    # Configure the log format.
    # In theory, this would be sufficient, but if another library calls logging.basicConfig
    # first, it doesn't have any effect.
    logging.basicConfig(level=_log_level, format=LOG_FORMAT, datefmt=DATE_FORMAT)
    for logger in _loggers:
        logger.setLevel(log_level)
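Usage of the new module is straightforward: each module requests a logger at import time, and the entry point sets the level once for all registered loggers. A minimal sketch:

    from mlagents_envs import logging_util
    from mlagents_envs.logging_util import get_logger, set_log_level

    logger = get_logger(__name__)      # registered in _loggers, so later level changes apply
    set_log_level(logging_util.DEBUG)  # configures the format and updates every registered logger
    logger.debug("visible now that the level is DEBUG")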

ml-agents/mlagents/logging_util.py (10 changes, file removed)

import logging

def create_logger(name, log_level):
    date_format = "%Y-%m-%d %H:%M:%S"
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=log_level, format=log_format, datefmt=date_format)
    logger = logging.getLogger(name=name)
    return logger