浏览代码

Develop add fire exp framework (#4213)

* Experiment branch for comparing torch

* Updates and merging ervin changes

* improvements on experiment_torch.py

* Better printing of results

* preliminary gpu experiment

* Testing gpu

* Prepare to see a lot of commits, because I like my IDE and I am testing on a server and I am using git to sync the two

* Prepare to see a lot of commits, because I like my IDE and I am testing on a server and I am using git to sync the two

* _

* _

* _

* _

* _

* _

* _

* _

* Attempt at gpu on tf. Does not work

* _

* _

* _

* _

* _

* _

* _

* _

* _

* _

* _

* Fixing learn.py
/develop/add-fire
GitHub 5 年前
当前提交
05a11c96
共有 8 个文件被更改,包括 195 次插入41 次删除
  1. 34
      ml-agents/mlagents/trainers/learn.py
  2. 5
      ml-agents/mlagents/trainers/models_torch.py
  3. 4
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  4. 2
      ml-agents/mlagents/trainers/policy/nn_policy.py
  5. 62
      ml-agents/mlagents/trainers/policy/torch_policy.py
  6. 4
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  7. 14
      ml-agents/mlagents/trainers/ppo/trainer.py
  8. 111
      experiment_torch.py

34
ml-agents/mlagents/trainers/learn.py


)
from mlagents_envs import logging_util
from mlagents.trainers.ppo.trainer import TestingConfiguration
from mlagents_envs.registry import default_registry
logger = logging_util.get_logger(__name__)
TRAINING_STATUS_FILE_NAME = "training_status.json"

) -> UnityEnvironment:
# Make sure that each environment gets a different seed
env_seed = seed + worker_id
return UnityEnvironment(
file_name=env_path,
worker_id=worker_id,
seed=env_seed,
no_graphics=no_graphics,
base_port=start_port,
additional_args=env_args,
side_channels=side_channels,
log_folder=log_folder,
)
if TestingConfiguration.env_name == "":
return UnityEnvironment(
file_name=env_path,
worker_id=worker_id,
seed=env_seed,
no_graphics=no_graphics,
base_port=start_port,
additional_args=env_args,
side_channels=side_channels,
log_folder=log_folder,
)
else:
return default_registry[TestingConfiguration.env_name].make(
seed=env_seed,
no_graphics=no_graphics,
base_port=start_port,
worker_id=worker_id,
additional_args=env_args,
side_channels=side_channels,
log_folder=log_folder,
)
return create_unity_environment

5
ml-agents/mlagents/trainers/models_torch.py


if self.use_lstm:
embedding = embedding.view([sequence_length, -1, self.h_size])
memories = torch.split(memories, self.m_size // 2, dim=-1)
embedding, memories = self.lstm(embedding, memories)
embedding, memories = self.lstm(embedding.contiguous(), (memories[0].contiguous(), memories[1].contiguous()))
embedding = embedding.view([-1, self.m_size // 2])
memories = torch.cat(memories, dim=-1)
return embedding, memories

def forward(self, visual_obs):
conv_1 = torch.relu(self.conv1(visual_obs))
conv_2 = torch.relu(self.conv2(conv_1))
hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat])))
# hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat])))
hidden = torch.relu(self.dense(torch.reshape(conv_2,(-1, self.final_flat))))
return hidden

4
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


)
for name, estimate in value_estimates.items():
value_estimates[name] = estimate.detach().numpy()
next_value_estimate[name] = next_value_estimate[name].detach().numpy()
value_estimates[name] = estimate.detach().cpu().numpy()
next_value_estimate[name] = next_value_estimate[name].detach().cpu().numpy()
if done:
for k in next_value_estimate:

2
ml-agents/mlagents/trainers/policy/nn_policy.py


MultiCategoricalDistribution,
)
from mlagents.trainers.ppo.trainer import TestingConfiguration
EPSILON = 1e-6 # Small value to avoid divide by zero

62
ml-agents/mlagents/trainers/policy/torch_policy.py


from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.models_torch import ActorCritic
from mlagents.trainers.ppo.trainer import TestingConfiguration
EPSILON = 1e-7 # Small value to avoid divide by zero

# good explanation and usually shouldn't be touched.
self.log_std_min = -20
self.log_std_max = 2
if TestingConfiguration.device != "cpu":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
torch.set_default_tensor_type(torch.FloatTensor)
self.inference_dict: Dict[str, tf.Tensor] = {}
self.update_dict: Dict[str, tf.Tensor] = {}

separate_critic=self.use_continuous_act,
)
self.actor_critic.to(TestingConfiguration.device)
def split_decision_step(self, decision_requests):
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
mask = None

action, log_probs, entropy, value_heads, memories = self.sample_actions(
vec_obs, vis_obs, masks=masks, memories=memories
)
run_out["action"] = action.detach().numpy()
run_out["pre_action"] = action.detach().numpy()
run_out["action"] = action.detach().cpu().numpy()
run_out["pre_action"] = action.detach().cpu().numpy()
run_out["log_probs"] = log_probs.detach().numpy()
run_out["entropy"] = entropy.detach().numpy()
run_out["log_probs"] = log_probs.detach().cpu().numpy()
run_out["entropy"] = entropy.detach().cpu().numpy()
name: t.detach().numpy() for name, t in value_heads.items()
name: t.detach().cpu().numpy() for name, t in value_heads.items()
run_out["memories"] = memories.detach().numpy()
run_out["memories"] = memories.detach().cpu().numpy()
self.actor_critic.update_normalization(vec_obs)
return run_out

self.actor_critic.load_state_dict(torch.load(load_path))
def export_model(self, step=0):
fake_vec_obs = [torch.zeros([1] + [self.brain.vector_observation_space_size])]
fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
fake_masks = torch.ones([1] + self.actor_critic.act_size)
# fake_memories = torch.zeros([1] + [self.m_size])
export_path = "./model-" + str(step) + ".onnx"
output_names = ["action", "action_probs"]
input_names = ["vector_observation", "action_mask"]
dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]}
onnx.export(
self.actor_critic,
(fake_vec_obs, fake_vis_obs, fake_masks),
export_path,
verbose=True,
opset_version=12,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
try:
fake_vec_obs = [torch.zeros([1] + [self.brain.vector_observation_space_size])]
fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
fake_masks = torch.ones([1] + self.actor_critic.act_size)
# fake_memories = torch.zeros([1] + [self.m_size])
export_path = "./model-" + str(step) + ".onnx"
output_names = ["action", "action_probs"]
input_names = ["vector_observation", "action_mask"]
dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]}
onnx.export(
self.actor_critic,
(fake_vec_obs, fake_vis_obs, fake_masks),
export_path,
verbose=True,
opset_version=12,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
except:
print("Could not export torch model")
return
@property
def vis_obs_size(self):

4
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


self.optimizer.step()
update_stats = {
"Losses/Policy Loss": abs(policy_loss.detach().numpy()),
"Losses/Value Loss": value_loss.detach().numpy(),
"Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()),
"Losses/Value Loss": value_loss.detach().cpu().numpy(),
}
return update_stats

14
ml-agents/mlagents/trainers/ppo/trainer.py


# ## ML-Agent Learning (PPO)
# Contains an implementation of PPO as described in: https://arxiv.org/abs/1707.06347
class TestingConfiguration:
use_torch = False
max_steps = 0
env_name = ""
device = "cpu"
from collections import defaultdict
from typing import cast

from mlagents.trainers.settings import TrainerSettings, PPOSettings
logger = get_logger(__name__)
class PPOTrainer(RLTrainer):

)
self.load = load
self.seed = seed
self.framework = "torch"
self.framework = "torch" if TestingConfiguration.use_torch else "tf"
if TestingConfiguration.max_steps > 0:
self.trainer_settings.max_steps = TestingConfiguration.max_steps
self.policy: Policy = None # type: ignore
def _process_trajectory(self, trajectory: Trajectory) -> None:

111
experiment_torch.py


import json
import os
import torch
import tensorflow as tf
import argparse
from mlagents.trainers.learn import run_cli, parse_command_line
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.ppo.trainer import TestingConfiguration
from mlagents_envs.timers import _thread_timer_stacks
def run_experiment(name:str, steps:int, use_torch:bool, num_torch_threads:int, use_gpu:bool, num_envs :int= 1, config_name=None):
TestingConfiguration.env_name = name
TestingConfiguration.max_steps = steps
TestingConfiguration.use_torch = use_torch
TestingConfiguration.device = "cuda:0" if use_gpu else "cpu"
if use_gpu:
tf.device("/GPU:0")
else:
tf.device("/device:CPU:0")
if (not torch.cuda.is_available() and use_gpu):
return name, str(steps), str(use_torch), str(num_torch_threads), str(num_envs), str(use_gpu), "na","na","na","na","na","na","na"
if config_name is None:
config_name = name
run_options = parse_command_line([f"config/ppo/{config_name}.yaml", "--num-envs", f"{num_envs}"])
run_options.checkpoint_settings.run_id = f"{name}_test_" +str(steps) +"_"+("torch" if use_torch else "tf")
run_options.checkpoint_settings.force = True
# run_options.env_settings.num_envs = num_envs
for trainer_settings in run_options.behaviors.values():
trainer_settings.threaded = False
timers_path = os.path.join("results", run_options.checkpoint_settings.run_id, "run_logs", "timers.json")
if use_torch:
torch.set_num_threads(num_torch_threads)
run_cli(run_options)
StatsReporter.writers.clear()
StatsReporter.stats_dict.clear()
_thread_timer_stacks.clear()
with open(timers_path) as timers_json_file:
timers_json = json.load(timers_json_file)
total = timers_json["total"]
tc_advance = timers_json["children"]["TrainerController.start_learning"]["children"]["TrainerController.advance"]
evaluate = timers_json["children"]["TrainerController.start_learning"]["children"]["TrainerController.advance"]["children"]["env_step"]["children"]["SubprocessEnvManager._take_step"]["children"]
update = timers_json["children"]["TrainerController.start_learning"]["children"]["TrainerController.advance"]["children"]["trainer_advance"]["children"]["_update_policy"]["children"]
tc_advance_total = tc_advance["total"]
tc_advance_count = tc_advance["count"]
if use_torch:
update_total = update["TorchPPOOptimizer.update"]["total"]
evaluate_total = evaluate["TorchPolicy.evaluate"]["total"]
update_count = update["TorchPPOOptimizer.update"]["count"]
evaluate_count = evaluate["TorchPolicy.evaluate"]["count"]
else:
update_total = update["TFPPOOptimizer.update"]["total"]
evaluate_total = evaluate["NNPolicy.evaluate"]["total"]
update_count = update["TFPPOOptimizer.update"]["count"]
evaluate_count= evaluate["NNPolicy.evaluate"]["count"]
# todo: do total / count
return name, str(steps), str(use_torch), str(num_torch_threads), str(num_envs), str(use_gpu), str(total), str(tc_advance_total), str(tc_advance_count), str(update_total), str(update_count), str(evaluate_total), str(evaluate_count)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--steps", default=25000, type=int, help="The number of steps")
parser.add_argument("--num-envs", default=1, type=int, help="The number of envs")
parser.add_argument("--gpu", default = False, action="store_true", help="If true, will use the GPU")
parser.add_argument("--threads", default=False, action="store_true", help="If true, will try both 1 and 8 threads for torch")
parser.add_argument("--ball", default=False, action="store_true", help="If true, will only do 3dball")
args = parser.parse_args()
if args.gpu:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
envs_config_tuples = [("3DBall", "3DBall"), ("GridWorld", "GridWorld"), ("PushBlock", "PushBlock"), ("Hallway", "Hallway"), ("CrawlerStaticTarget", "CrawlerStatic"), ("VisualHallway", "VisualHallway")]
if args.ball:
envs_config_tuples=[("3DBall", "3DBall")]
labels = ("name", "steps", "use_torch", "num_torch_threads", "num_envs", "use_gpu" , "total", "tc_advance_total", "tc_advance_count", "update_total", "update_count", "evaluate_total", "evaluate_count")
results = []
results.append(labels)
f = open(f"result_data_steps_{args.steps}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt", "w")
f.write(" ".join(labels)+ "\n")
for env_config in envs_config_tuples:
data = run_experiment(name = env_config[0], steps=args.steps, use_torch=True, num_torch_threads=1, use_gpu=args.gpu, num_envs = args.num_envs, config_name=env_config[1])
results.append(data)
f.write(" ".join(data) + "\n")
if args.threads:
data = run_experiment(name = env_config[0], steps=args.steps, use_torch=True, num_torch_threads=8, use_gpu=args.gpu, num_envs = args.num_envs, config_name=env_config[1])
results.append(data)
f.write(" ".join(data)+ "\n")
data = run_experiment(name = env_config[0], steps=args.steps, use_torch=False, num_torch_threads=1, use_gpu=args.gpu, num_envs = args.num_envs, config_name=env_config[1])
results.append(data)
f.write(" ".join(data)+ "\n")
for r in results:
print(*r)
f.close()
if __name__ == "__main__":
main()
正在加载...
取消
保存