The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents.

import os
from typing import List

import numpy as np

from mlagents_envs.base_env import ActionTuple, BehaviorSpec, ActionSpec
from mlagents_envs.communicator_objects.agent_info_action_pair_pb2 import (
    AgentInfoActionPairProto,
)
from mlagents_envs.rpc_utils import steps_from_proto
from mlagents.trainers.demonstrations.demonstration_provider import (
    DemonstrationProvider,
    DemonstrationExperience,
    DemonstrationTrajectory,
)
from mlagents.trainers.demonstrations.demonstration_proto_utils import (
    load_demonstration,
)


class LocalDemonstrationProvider(DemonstrationProvider):
    """Loads local .demo files and serves their contents as demonstration trajectories."""

    def __init__(self, file_path: str):
        super().__init__()
        demo_paths = self._get_demo_files(file_path)
        behavior_spec, info_action_pairs = load_demonstration(demo_paths)
        self._behavior_spec = behavior_spec
        self._info_action_pairs = info_action_pairs

    def get_behavior_spec(self) -> BehaviorSpec:
        return self._behavior_spec

    def pop_trajectories(self) -> List[DemonstrationTrajectory]:
        trajectories = LocalDemonstrationProvider._info_action_pairs_to_trajectories(
            self._behavior_spec, self._info_action_pairs
        )
        self._info_action_pairs = []
        return trajectories

    @staticmethod
    def _get_demo_files(path: str) -> List[str]:
        """
        Retrieves the demonstration file(s) from a path.

        :param path: Path of a demonstration file or directory.
        :return: List of demonstration files.

        Raises errors if |path| is invalid.
        """
        if os.path.isfile(path):
            if not path.endswith(".demo"):
                raise ValueError("The path provided is not a '.demo' file.")
            return [path]
        elif os.path.isdir(path):
            paths = [
                os.path.join(path, name)
                for name in os.listdir(path)
                if name.endswith(".demo")
            ]
            if not paths:
                raise ValueError(
                    "There are no '.demo' files in the provided directory."
                )
            return paths
        else:
            raise FileNotFoundError(
                f"The demonstration file or directory {path} does not exist."
            )

    @staticmethod
    def _info_action_pairs_to_trajectories(
        behavior_spec: BehaviorSpec, info_action_pairs: List[AgentInfoActionPairProto]
    ) -> List[DemonstrationTrajectory]:
        trajectories_out: List[DemonstrationTrajectory] = []
        current_experiences: List[DemonstrationExperience] = []
        previous_action = np.zeros(
            behavior_spec.action_spec.continuous_size, dtype=np.float32
        )  # TODO or discrete?
        for pair_index, pair in enumerate(info_action_pairs):
            # Extract the observations from the decision/terminal steps.
            current_decision_step, current_terminal_step = steps_from_proto(
                [pair.agent_info], behavior_spec
            )
            if len(current_terminal_step) == 1:
                obs = list(current_terminal_step.values())[0].obs
            else:
                obs = list(current_decision_step.values())[0].obs
            action_tuple = LocalDemonstrationProvider._get_action_tuple(
                pair, behavior_spec.action_spec
            )
            exp = DemonstrationExperience(
                obs=obs,
                reward=pair.agent_info.reward,  # TODO next step's reward?
                done=pair.agent_info.done,
                action=action_tuple,
                prev_action=previous_action,
                interrupted=pair.agent_info.max_step_reached,
            )
            current_experiences.append(exp)
            previous_action = np.array(
                pair.action_info.vector_actions_deprecated, dtype=np.float32
            )
            # Close out the current trajectory when the episode ends or the
            # demonstration runs out of pairs.
            if pair.agent_info.done or pair_index == len(info_action_pairs) - 1:
                trajectories_out.append(
                    DemonstrationTrajectory(experiences=current_experiences)
                )
                current_experiences = []
        return trajectories_out

    @staticmethod
    def _get_action_tuple(
        pair: AgentInfoActionPairProto, action_spec: ActionSpec
    ) -> ActionTuple:
        continuous_actions = None
        discrete_actions = None
        if (
            len(pair.action_info.continuous_actions) == 0
            and len(pair.action_info.discrete_actions) == 0
        ):
            # Only the deprecated field is populated; interpret it as
            # continuous or discrete based on the action spec.
            if action_spec.continuous_size > 0:
                continuous_actions = pair.action_info.vector_actions_deprecated
            else:
                discrete_actions = pair.action_info.vector_actions_deprecated
        else:
            if action_spec.continuous_size > 0:
                continuous_actions = pair.action_info.continuous_actions
            if action_spec.discrete_size > 0:
                discrete_actions = pair.action_info.discrete_actions

        # TODO 2D?
        continuous_np = (
            np.array(continuous_actions, dtype=np.float32)
            if continuous_actions
            else None
        )
        # Discrete actions are branch indices, so store them as integers.
        discrete_np = (
            np.array(discrete_actions, dtype=np.int32) if discrete_actions else None
        )
        return ActionTuple(continuous_np, discrete_np)
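
A minimal usage sketch of the provider, assuming a demonstration recorded with Unity's Demonstration Recorder component; the "demos/expert.demo" path is hypothetical, and any .demo file (or directory of .demo files) will do:

    provider = LocalDemonstrationProvider("demos/expert.demo")

    spec = provider.get_behavior_spec()
    print(f"Continuous action size: {spec.action_spec.continuous_size}")
    print(f"Discrete branches: {spec.action_spec.discrete_size}")

    for trajectory in provider.pop_trajectories():
        print(f"Trajectory with {len(trajectory.experiences)} experiences")

Note that pop_trajectories() drains the internal buffer of info/action pairs after converting it, so a second call returns an empty list.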