
Merge pull request #934 from Unity-Technologies/develop-value-estimates-ppo

Develop value estimates ppo
GitHub, 7 years ago
Current commit: 1e21c143
10 files changed, 120 insertions(+), 17 deletions(-)
  1. python/communicator_objects/agent_action_proto_pb2.py (11 changes)
  2. python/unityagents/environment.py (22 changes)
  3. python/unitytrainers/bc/trainer.py (4 changes)
  4. python/unitytrainers/ppo/trainer.py (4 changes)
  5. python/unitytrainers/trainer_controller.py (12 changes)
  6. unity-environment/Assets/ML-Agents/Scripts/Academy.cs (1 change)
  7. unity-environment/Assets/ML-Agents/Scripts/Agent.cs (15 changes)
  8. unity-environment/Assets/ML-Agents/Scripts/Batcher.cs (2 changes)
  9. unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs (36 changes)
  10. unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs (30 changes)

python/communicator_objects/agent_action_proto_pb2.py (11 changes)


package='communicator_objects',
syntax='proto3',
serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'),
- serialized_pb=_b('\n-communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"R\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
+ serialized_pb=_b('\n-communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"a\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x12\r\n\x05value\x18\x04 \x01(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3')
)

message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value', full_name='communicator_objects.AgentActionProto.value', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],

oneofs=[
],
serialized_start=71,
- serialized_end=153,
+ serialized_end=168,
)
DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO
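
The regenerated descriptor above adds a fourth field, value (field number 4, type float), to AgentActionProto, which is why serialized_end moves from 153 to 168. A minimal round-trip sketch, not part of the diff, assuming python/ is on the import path so the module shown above is importable:

# Illustrative sketch (not part of the diff): the regenerated message now
# carries a per-agent value estimate alongside the action fields.
from communicator_objects.agent_action_proto_pb2 import AgentActionProto

action = AgentActionProto(
    vector_actions=[0.1, -0.2],
    text_actions="",
    memories=[],
    value=0.73,            # new field 4 introduced by this PR
)
payload = action.SerializeToString()
print(AgentActionProto.FromString(payload).value)   # ~0.73 (float32 rounding)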

python/unityagents/environment.py (22 changes)


else:
raise UnityEnvironmentException("No Unity environment is loaded.")
- def step(self, vector_action=None, memory=None, text_action=None) -> AllBrainInfo:
+ def step(self, vector_action=None, memory=None, text_action=None, value=None) -> AllBrainInfo:
"""
Provides the environment with an action, moves the environment dynamics forward accordingly, and returns
observation, state, and reward information to the agent.

vector_action = {} if vector_action is None else vector_action
memory = {} if memory is None else memory
text_action = {} if text_action is None else text_action
value = {} if value is None else value
if self._loaded and not self._global_done and self._global_done is not None:
if isinstance(vector_action, (int, np.int_, float, np.float_, list, np.ndarray)):
if self._num_external_brains == 1:

raise UnityActionException(
"There are no external brains in the environment, "
"step cannot take a value input")
if isinstance(value, (int, np.int_, float, np.float_, list, np.ndarray)):
if self._num_external_brains == 1:
value = {self._external_brain_names[0]: value}
elif self._num_external_brains > 1:
raise UnityActionException(
"You have {0} brains, you need to feed a dictionary of brain names as keys "
"and state/action value estimates as values".format(self._num_brains))
else:
raise UnityActionException(
"There are no external brains in the environment, "
"step cannot take a value input")
for brain_name in list(vector_action.keys()) + list(memory.keys()) + list(text_action.keys()):
if brain_name not in self._external_brain_names:

str(vector_action[b])))
outputs = self.communicator.exchange(
- self._generate_step_input(vector_action, memory, text_action)
+ self._generate_step_input(vector_action, memory, text_action, value)
)
if outputs is None:
raise KeyboardInterrupt

)
return _data, global_done
- def _generate_step_input(self, vector_action, memory, text_action) -> UnityRLInput:
+ def _generate_step_input(self, vector_action, memory, text_action, value) -> UnityRLInput:
rl_in = UnityRLInput()
for b in vector_action:
n_agents = self._n_agents[b]

action = AgentActionProto(
vector_actions=vector_action[b][i*_a_s: (i+1)*_a_s],
memories=memory[b][i*_m_s: (i+1)*_m_s],
- text_actions=text_action[b][i]
+ text_actions=text_action[b][i],
)
if b in value:
action.value = value[b][i]
rl_in.agent_actions[b].value.extend([action])
rl_in.command = 0
return self.wrap_unity_input(rl_in)
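
With the extended signature, UnityEnvironment.step() accepts an optional value argument that mirrors vector_action: a scalar or list when exactly one external brain is present, otherwise a dict keyed by brain name. A usage sketch, not part of the diff; the build name, brain name, and numbers are placeholders:

# Illustrative sketch (not part of the diff); "3DBall" and "PPOBrain" are
# placeholder names. The value dict mirrors the vector_action dict per brain.
from unityagents import UnityEnvironment

env = UnityEnvironment(file_name="3DBall")
info = env.reset(train_mode=True)

actions = {"PPOBrain": [0.1, -0.3]}        # vector actions for the brain's agents
value_estimates = {"PPOBrain": [0.42]}     # one value estimate per agent

# The new keyword ends up in AgentActionProto.value for each agent.
info = env.step(vector_action=actions, value=value_estimates)
env.close()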

python/unitytrainers/bc/trainer.py (4 changes)


agent_brain.memories = np.zeros((len(agent_brain.agents), self.m_size))
feed_dict[self.model.memory_in] = agent_brain.memories
agent_action, memories = self.sess.run(self.inference_run_list, feed_dict)
- return agent_action, memories, None, None
+ return agent_action, memories, None, None, None
- return agent_action, None, None, None
+ return agent_action, None, None, None, None
def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, take_action_outputs):
"""

python/unitytrainers/ppo/trainer.py (4 changes)


self.stats['entropy'].append(run_out[self.model.entropy].mean())
self.stats['learning_rate'].append(run_out[self.model.learning_rate])
if self.use_recurrent:
- return run_out[self.model.output], run_out[self.model.memory_out], None, run_out
+ return run_out[self.model.output], run_out[self.model.memory_out], None, run_out[self.model.value], run_out
- return run_out[self.model.output], None, None, run_out
+ return run_out[self.model.output], None, None, run_out[self.model.value], run_out
def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
"""

python/unitytrainers/trainer_controller.py (12 changes)


if self.trainers[brain_name].parameters["trainer"] == "imitation":
nodes += [scope + x for x in ["action"]]
else:
nodes += [scope + x for x in ["action", "value_estimate", "action_probs"]]
nodes += [scope + x for x in ["action", "value_estimate", "action_probs", "value_estimate"]]
if self.trainers[brain_name].parameters["use_recurrent"]:
nodes += [scope + x for x in ["recurrent_out", "memory_size"]]
if len(scopes) > 1:

for brain_name, trainer in self.trainers.items():
trainer.end_episode()
# Decide and take an action
- take_action_vector, take_action_memories, take_action_text, take_action_outputs = {}, {}, {}, {}
+ take_action_vector, \
+ take_action_memories, \
+ take_action_text, \
+ take_action_value, \
+ take_action_outputs \
+ = {}, {}, {}, {}, {}
take_action_value,
- text_action=take_action_text)
+ text_action=take_action_text, value=take_action_value)
for brain_name, trainer in self.trainers.items():
trainer.add_experiences(curr_info, new_info, take_action_outputs[brain_name])
trainer.process_experiences(curr_info, new_info)
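
The extra "value_estimate" entry in the export list means the frozen graph written for a brain exposes that output node under the brain's scope, the same name CoreBrainInternal.cs looks up further down. A quick check of an exported graph, not part of the diff; the file path and the "BallBrain" scope are hypothetical, using the TensorFlow 1.x API of that era:

# Illustrative sketch (not part of the diff): verify that an exported frozen
# graph contains the value_estimate node alongside action and action_probs.
import tensorflow as tf

graph_def = tf.GraphDef()
with open("models/ppo/BallBrain.bytes", "rb") as f:   # hypothetical export path
    graph_def.ParseFromString(f.read())

node_names = {node.name for node in graph_def.node}
print("BallBrain/value_estimate" in node_names)       # expected True for a PPO brain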

unity-environment/Assets/ML-Agents/Scripts/Academy.cs (1 change)


(MLAgents.CommunicatorObjects.BrainTypeProto)
brain.brainType));
}
academyParameters.EnvironmentParameters =
new MLAgents.CommunicatorObjects.EnvironmentParametersProto();

unity-environment/Assets/ML-Agents/Scripts/Agent.cs (15 changes)


public float[] vectorActions;
public string textActions;
public List<float> memories;
public float value;
}
/// <summary>

public void UpdateTextAction(string textActions)
{
action.textActions = textActions;
}
/// <summary>
/// Updates the value of the agent.
/// </summary>
/// <param name="textActions">Text actions.</param>
public void UpdateValueAction(float value)
{
action.value = value;
}
protected float GetValueEstimate()
{
return action.value;
}
/// <summary>

unity-environment/Assets/ML-Agents/Scripts/Batcher.cs (2 changes)


action.Memories.ToList());
agent.UpdateTextAction(
action.TextActions);
agent.UpdateValueAction(
action.Value);
}
}

unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs (36 changes)


byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"Ci1jb21tdW5pY2F0b3Jfb2JqZWN0cy9hZ2VudF9hY3Rpb25fcHJvdG8ucHJv",
"dG8SFGNvbW11bmljYXRvcl9vYmplY3RzIlIKEEFnZW50QWN0aW9uUHJvdG8S",
"dG8SFGNvbW11bmljYXRvcl9vYmplY3RzImEKEEFnZW50QWN0aW9uUHJvdG8S",
"EhAKCG1lbW9yaWVzGAMgAygCQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JP",
"YmplY3RzYgZwcm90bzM="));
"EhAKCG1lbW9yaWVzGAMgAygCEg0KBXZhbHVlGAQgASgCQh+qAhxNTEFnZW50",
"cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
- new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Memories" }, null, null, null)
+ new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentActionProto), global::MLAgents.CommunicatorObjects.AgentActionProto.Parser, new[]{ "VectorActions", "TextActions", "Memories", "Value" }, null, null, null)
}));
}
#endregion

vectorActions_ = other.vectorActions_.Clone();
textActions_ = other.textActions_;
memories_ = other.memories_.Clone();
value_ = other.value_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

get { return memories_; }
}
/// <summary>Field number for the "value" field.</summary>
public const int ValueFieldNumber = 4;
private float value_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public float Value {
get { return value_; }
set {
value_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentActionProto);

if(!vectorActions_.Equals(other.vectorActions_)) return false;
if (TextActions != other.TextActions) return false;
if(!memories_.Equals(other.memories_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Value, other.Value)) return false;
return Equals(_unknownFields, other._unknownFields);
}

hash ^= vectorActions_.GetHashCode();
if (TextActions.Length != 0) hash ^= TextActions.GetHashCode();
hash ^= memories_.GetHashCode();
if (Value != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Value);
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}

output.WriteString(TextActions);
}
memories_.WriteTo(output, _repeated_memories_codec);
if (Value != 0F) {
output.WriteRawTag(37);
output.WriteFloat(Value);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}

size += 1 + pb::CodedOutputStream.ComputeStringSize(TextActions);
}
size += memories_.CalculateSize(_repeated_memories_codec);
if (Value != 0F) {
size += 1 + 4;
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}

TextActions = other.TextActions;
}
memories_.Add(other.memories_);
if (other.Value != 0F) {
Value = other.Value;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

case 26:
case 29: {
memories_.AddEntriesFrom(input, _repeated_memories_codec);
break;
}
case 37: {
Value = input.ReadFloat();
break;
}
}
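
The "case 37" branch added to the C# parser corresponds to the protobuf wire tag for field 4 with fixed32 (float) encoding: (4 << 3) | 5 = 37. A sketch from the Python side, not part of the diff, showing the byte the C# reader will see:

# Illustrative sketch (not part of the diff): field 4 (value) is encoded with
# wire tag (4 << 3) | 5 = 37 = 0x25, the tag handled by "case 37" above.
from communicator_objects.agent_action_proto_pb2 import AgentActionProto

data = AgentActionProto(value=1.0).SerializeToString()
print(data.hex())   # "250000803f": tag 0x25, then little-endian float 1.0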

unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs (30 changes)


bool hasState;
bool hasBatchSize;
bool hasPrevAction;
bool hasValueEstimate;
float[,] inputState;
int[] inputPrevAction;
List<float[,,,]> observationMatrixList;

if (graph[graphScope + PreviousActionPlaceholderName] != null)
{
hasPrevAction = true;
}
if (graph[graphScope + "value_estimate"] != null)
{
hasValueEstimate = true;
}
}

runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
}
if (hasValueEstimate)
{
runner.Fetch(graph[graphScope + "value_estimate"][0]);
}
TFTensor[] networkOutput;
try
{

agent.UpdateMemoriesAction(m.ToList());
i++;
}
}
if (hasValueEstimate)
{
float[,] value_estimates = new float[currentBatchSize,1];
if (hasRecurrent)
{
value_estimates = networkOutput[2].GetValue() as float[,];
}
else
{
value_estimates = networkOutput[1].GetValue() as float[,];
}
var i = 0;
foreach (Agent agent in agentList)
{
agent.UpdateValueAction(value_estimates[i,0]);
}
}
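
The indexing above (networkOutput[2] with a recurrent model, networkOutput[1] otherwise) follows the fetch order: action first, then recurrent_out if present, then value_estimate. A runnable Python sketch of that indexing rule, not part of the diff, with fake arrays standing in for the fetched tensors:

# Illustrative sketch (not part of the diff): the output-indexing rule used by
# CoreBrainInternal above, with fake network outputs standing in for TF tensors.
import numpy as np

def pick_value_estimates(network_output, has_recurrent):
    # value_estimate is fetched last, after the action output and
    # (optionally) the recurrent output, so its index depends on has_recurrent.
    return network_output[2] if has_recurrent else network_output[1]

batch = 3
action_out = np.zeros((batch, 2))
memory_out = np.zeros((batch, 8))
value_out = np.arange(batch, dtype=np.float32).reshape(batch, 1)

print(pick_value_estimates([action_out, memory_out, value_out], has_recurrent=True))
print(pick_value_estimates([action_out, value_out], has_recurrent=False))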
