
Feature/previous text action (#375)

* [Previous Text Actions] Renamed previous_action to previous_vector_action and added previous_text_action to BrainInfo

* [Semantics] Carried the new previous_vector_action semantics over to the trainers

Branch: develop-generalizationTraining-TrainerController
Committed by GitHub, 7 years ago
Current commit: dcf58f75
5 changed files with 26 additions and 18 deletions
  1. python/unityagents/brain.py (5 changes)
  2. python/unityagents/environment.py (19 changes)
  3. python/unitytrainers/bc/trainer.py (2 changes)
  4. python/unitytrainers/ppo/trainer.py (6 changes)
  5. unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs (12 changes)

python/unityagents/brain.py (5 changes)


 class BrainInfo:
     def __init__(self, visual_observation, vector_observation, text_observations, memory=None,
                  reward=None, agents=None, local_done=None,
-                 action=None, max_reached=None):
+                 vector_action=None, text_action=None, max_reached=None):
         """
         Describes experience at current step of all agents linked to a brain.
         """

         self.local_done = local_done
         self.max_reached = max_reached
         self.agents = agents
-        self.previous_actions = action
+        self.previous_vector_actions = vector_action
+        self.previous_text_actions = text_action

 AllBrainInfo = Dict[str, BrainInfo]
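
After this change, consumers read previous_vector_actions instead of previous_actions, and can additionally read the new previous_text_actions field. Below is a minimal, self-contained sketch of the new field layout; BrainInfoSketch is a hypothetical stand-in for illustration only, not the class shipped in python/unityagents/brain.py.

import numpy as np

# Hypothetical stand-in mirroring the renamed/added BrainInfo fields after this commit.
class BrainInfoSketch:
    def __init__(self, vector_action=None, text_action=None, agents=None):
        self.agents = agents or []
        # Renamed: previous_actions -> previous_vector_actions
        self.previous_vector_actions = vector_action
        # New: per-agent text actions taken on the previous step
        self.previous_text_actions = text_action

info = BrainInfoSketch(vector_action=np.zeros((2, 3)),
                       text_action=["jump", ""],
                       agents=[0, 1])
print(info.previous_vector_actions.shape)  # (2, 3)
print(info.previous_text_actions)          # ['jump', '']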

python/unityagents/environment.py (19 changes)


         for _b in self._brain_names:
             if _b not in self._data:
                 self._data[_b] = BrainInfo([], np.array([]), [], np.array([]),
-                                           [], [], [], np.array([]), max_reached=[])
+                                           [], [], [], np.array([]), [], max_reached=[])
         return self._data

         b = state_dict["brain_name"]
         n_agent = len(state_dict["agents"])

                 vector_obs = np.array(state_dict["vectorObservations"]).reshape(
                     (n_agent, self._brains[b].num_stacked_vector_observations))
             except UnityActionException:
-                raise UnityActionException("Brain {0} has an invalid state. "
-                                           "Expecting {1} {2} state but received {3}."
+                raise UnityActionException("Brain {0} has an invalid vector observation. "
+                                           "Expecting {1} {2} vector observations but received {3}."
-                                           len(state_dict["states"])))
+                                           len(state_dict["vectorObservations"])))
             memories = np.array(state_dict["memories"]).reshape((n_agent, -1))
             text_obs = state_dict["textObservations"]

             maxes = state_dict["maxes"]
-            vector_actions = np.array(state_dict["vectorActions"]).reshape((n_agent, -1))
+            vector_actions = np.array(state_dict["previousVectorActions"]).reshape((n_agent, -1))
+            text_actions = state_dict["previousTextActions"]
+            text_actions = []
             observations = []
             for o in range(self._brains[b].number_visual_observations):
                 obs_n = []

                 observations.append(np.array(obs_n))
             self._data[b] = BrainInfo(observations, vector_obs, text_obs, memories, rewards,
-                                      agents, dones, vector_actions, max_reached=maxes)
+                                      agents, dones, vector_actions, text_actions, max_reached=maxes)

     def _send_action(self, vector_action, memory, text_action):
         """

             else:
+                if text_action[b] is None:
+                    text_action[b] = []
+                else:
+                    text_action[b] = [""] * n_agent
+                if isinstance(text_action[b], str):
+                    text_action[b] = [text_action[b]] * n_agent
+                if not ((len(text_action[b]) == n_agent) or len(text_action[b]) == 0):
+                    raise UnityActionException(
+                        "There was a mismatch between the provided text_action and environment's expectation: "

python/unitytrainers/bc/trainer.py (2 changes)


                 if stored_info_expert.memories.shape[1] == 0:
                     stored_info_expert.memories = np.zeros((len(stored_info_expert.agents), self.m_size))
                 self.training_buffer[agent_id]['memory'].append(stored_info_expert.memories[idx])
-                self.training_buffer[agent_id]['actions'].append(next_info_expert.previous_actions[next_idx])
+                self.training_buffer[agent_id]['actions'].append(next_info_expert.previous_vector_actions[next_idx])
         info_student = curr_info[self.brain_name]
         next_info_student = next_info[self.brain_name]
         for agent_id in info_student.agents:
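
The behavioral-cloning hunk above only swaps in the renamed field. A small illustrative sketch of that buffering pattern follows, with a plain dict standing in for the trainer's training_buffer; all names here are illustrative, not the trainer's actual API.

from collections import defaultdict
import numpy as np

training_buffer = defaultdict(lambda: defaultdict(list))

def record_expert_step(agent_id, idx, next_idx, stored_info_expert, next_info_expert, m_size=8):
    # Pad memories with zeros when the brain carries no recurrent state, as in the diff.
    if stored_info_expert.memories.shape[1] == 0:
        stored_info_expert.memories = np.zeros((len(stored_info_expert.agents), m_size))
    training_buffer[agent_id]['memory'].append(stored_info_expert.memories[idx])
    # Renamed field: previous_vector_actions instead of previous_actions.
    training_buffer[agent_id]['actions'].append(next_info_expert.previous_vector_actions[next_idx])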

python/unitytrainers/ppo/trainer.py (6 changes)


         if self.is_continuous:
             run_list.append(self.model.epsilon)
         elif self.use_recurrent:
-            feed_dict[self.model.prev_action] = np.reshape(curr_brain_info.previous_actions, [-1])
+            feed_dict[self.model.prev_action] = np.reshape(curr_brain_info.previous_vector_actions, [-1])
         if self.use_observations:
             for i, _ in enumerate(curr_brain_info.visual_observations):
                 feed_dict[self.model.observation_in[i]] = curr_brain_info.visual_observations[i]

         a_dist = stored_take_action_outputs[self.model.all_probs]
         value = stored_take_action_outputs[self.model.value]
         self.training_buffer[agent_id]['actions'].append(actions[idx])
-        self.training_buffer[agent_id]['prev_action'].append(stored_info.previous_actions[idx])
+        self.training_buffer[agent_id]['prev_action'].append(stored_info.previous_vector_actions[idx])
         self.training_buffer[agent_id]['masks'].append(1.0)
         self.training_buffer[agent_id]['rewards'].append(next_info.rewards[next_idx])
         self.training_buffer[agent_id]['action_probs'].append(a_dist[idx])

             info.memories = np.zeros((len(info.vector_observations), self.m_size))
             feed_dict[self.model.memory_in] = info.memories
         if not self.is_continuous and self.use_recurrent:
-            feed_dict[self.model.prev_action] = np.reshape(info.previous_actions, [-1])
+            feed_dict[self.model.prev_action] = np.reshape(info.previous_vector_actions, [-1])
         value_next = self.sess.run(self.model.value, feed_dict)[l]
         agent_id = info.agents[l]
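
The PPO changes are the same rename applied to the recurrent prev_action feed: for discrete-action, recurrent policies the previous vector action is flattened and fed to the model's prev_action placeholder. A short sketch of that wiring under assumed model and placeholder names:

import numpy as np

def add_prev_action(feed_dict, prev_action_placeholder, brain_info,
                    is_continuous, use_recurrent):
    # Only discrete, recurrent policies consume the previous action, as in the diff.
    if not is_continuous and use_recurrent:
        # previous_vector_actions has shape (n_agents, action_dims); the placeholder
        # expects a flat vector, hence the reshape to [-1].
        feed_dict[prev_action_placeholder] = np.reshape(
            brain_info.previous_vector_actions, [-1])
    return feed_dict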

unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs (12 changes)


         public List<int> agents;
         public List<float> vectorObservations;
         public List<float> rewards;
-        public List<float> vectorActions;
+        public List<float> previousVectorActions;
+        public List<string> previousTextActions;
         public List<float> memories;
         public List<string> textObservations;
         public List<bool> dones;

         sMessage.rewards = new List<float>(defaultNumAgents);
         sMessage.memories = new List<float>(defaultNumAgents * defaultNumObservations);
         sMessage.dones = new List<bool>(defaultNumAgents);
-        sMessage.vectorActions = new List<float>(defaultNumAgents * defaultNumObservations);
+        sMessage.previousVectorActions = new List<float>(defaultNumAgents * defaultNumObservations);
+        sMessage.previousTextActions = new List<string>(defaultNumAgents);
         sMessage.maxes = new List<bool>(defaultNumAgents);
         sMessage.textObservations = new List<string>(defaultNumAgents);

         sMessage.rewards.Clear();
         sMessage.memories.Clear();
         sMessage.dones.Clear();
-        sMessage.vectorActions.Clear();
+        sMessage.previousVectorActions.Clear();
+        sMessage.previousTextActions.Clear();
         sMessage.maxes.Clear();
         sMessage.textObservations.Clear();

             for (int j = 0; j < memorySize - agentInfo[agent].memories.Count; j++)
                 sMessage.memories.Add(0f);
             sMessage.dones.Add(agentInfo[agent].done);
-            sMessage.vectorActions.AddRange(agentInfo[agent].StoredVectorActions.ToList());
+            sMessage.previousVectorActions.AddRange(agentInfo[agent].StoredVectorActions.ToList());
+            sMessage.previousTextActions.Add(agentInfo[agent].StoredTextActions);
             sMessage.maxes.Add(agentInfo[agent].maxStepReached);
             sMessage.textObservations.Add(agentInfo[agent].textObservation);
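
On the Python side, _get_state reads these renamed and added fields out of the per-step message. Below is a rough sketch of the payload shape the receiving code expects; the field names are taken from the diff, while the concrete values are made up for illustration.

# Illustrative per-step message for a two-agent brain; values are invented.
state_dict = {
    "brain_name": "PlayerBrain",
    "agents": [0, 1],
    "vectorObservations": [0.1, 0.2, 0.3, 0.4],
    "memories": [0.0, 0.0],
    "textObservations": ["", ""],
    "rewards": [0.0, 1.0],
    "dones": [False, False],
    "maxes": [False, False],
    "previousVectorActions": [0.0, 1.0, 1.0, 0.0],  # replaces the old "vectorActions" field
    "previousTextActions": ["", "jump"],            # new field added by this commit
}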
