Move one-hot out of policy and remove selected_actions

/develop/removeactionholder-onehot
Ervin Teng, 5 years ago
Commit 53c25fb1
10 changed files with 108 additions and 70 deletions
  1. ml-agents/mlagents/trainers/common/nn_policy.py (19 changes)
  2. ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (35 changes)
  3. ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (10 changes)
  4. ml-agents/mlagents/trainers/components/reward_signals/gail/model.py (19 changes)
  5. ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py (10 changes)
  6. ml-agents/mlagents/trainers/models.py (39 changes)
  7. ml-agents/mlagents/trainers/ppo/optimizer.py (17 changes)
  8. ml-agents/mlagents/trainers/sac/network.py (10 changes)
  9. ml-agents/mlagents/trainers/sac/optimizer.py (17 changes)
  10. ml-agents/mlagents/trainers/tf_policy.py (2 changes)

ml-agents/mlagents/trainers/common/nn_policy.py (19 changes)


  self.prev_action = tf.placeholder(
      shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
  )
- prev_action_oh = tf.concat(
-     [
-         tf.one_hot(self.prev_action[:, i], self.act_size[i])
-         for i in range(len(self.act_size))
-     ],
-     axis=1,
+ prev_action_oh = ModelUtils.to_onehot_tensor(
+     self.prev_action, self.act_size
  )
  hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)

  self.output = tf.identity(output)
  self.all_log_probs = tf.identity(normalized_logits, name="action")
- self.action_oh = tf.concat(
-     [
-         tf.one_hot(self.output[:, i], self.act_size[i])
-         for i in range(len(self.act_size))
-     ],
-     axis=1,
- )
- self.selected_actions = tf.stop_gradient(self.action_oh)
+ action_oh = ModelUtils.to_onehot_tensor(self.output, self.act_size)
  action_idx = [0] + list(np.cumsum(self.act_size))

  tf.stack(
      [
          -tf.nn.softmax_cross_entropy_with_logits_v2(
-             labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
+             labels=action_oh[:, action_idx[i] : action_idx[i + 1]],
              logits=normalized_logits[
                  :, action_idx[i] : action_idx[i + 1]
              ],
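For context, a minimal numpy sketch (not part of the commit; the branch logits and selected index are made up) of why the stacked term above is the log-probability of the chosen action for one branch:

    import numpy as np

    # One discrete branch with 3 actions; values are illustrative only.
    branch_logits = np.array([1.0, 2.0, 0.5])
    selected = 1
    one_hot = np.eye(len(branch_logits))[selected]

    # Log-softmax of the branch logits.
    log_probs = branch_logits - np.log(np.sum(np.exp(branch_logits)))

    # softmax_cross_entropy_with_logits_v2(labels=one_hot, logits=branch_logits)
    cross_entropy = -np.sum(one_hot * log_probs)

    # Negating it recovers the log-probability of the selected action.
    assert np.isclose(-cross_entropy, log_probs[selected])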

ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (35 changes)


"""
self.encoding_size = encoding_size
self.policy = policy
is_discrete = self.policy.brain.vector_action_space_type != "continuous"
self.selected_actions_ph = ModelUtils.create_action_input_placeholder(
self.policy.act_size, is_discrete
)
if is_discrete:
action_input = ModelUtils.to_onehot_tensor(
self.selected_actions_ph, self.policy.act_size
)
else:
action_input = self.selected_actions_ph
self.create_inverse_model(encoded_state, encoded_next_state)
self.create_forward_model(encoded_state, encoded_next_state)
self.create_inverse_model(action_input, encoded_state, encoded_next_state)
self.create_forward_model(action_input, encoded_state, encoded_next_state)
self.create_loss(learning_rate)
def create_curiosity_encoders(self) -> Tuple[tf.Tensor, tf.Tensor]:

  return encoded_state, encoded_next_state

  def create_inverse_model(
-     self, encoded_state: tf.Tensor, encoded_next_state: tf.Tensor
+     self,
+     action_input: tf.Tensor,
+     encoded_state: tf.Tensor,
+     encoded_next_state: tf.Tensor,

+ :param action_input: Tensor representing the current selected action.
  :param encoded_state: Tensor corresponding to encoded current state.
  :param encoded_next_state: Tensor corresponding to encoded next state.
  """

      hidden, self.policy.act_size[0], activation=None
  )
  squared_difference = tf.reduce_sum(
-     tf.squared_difference(pred_action, self.policy.selected_actions), axis=1
+     tf.squared_difference(pred_action, action_input), axis=1
  )
  self.inverse_loss = tf.reduce_mean(
      tf.dynamic_partition(squared_difference, self.policy.mask, 2)[1]

      axis=1,
  )
  cross_entropy = tf.reduce_sum(
-     -tf.log(pred_action + 1e-10) * self.policy.selected_actions, axis=1
+     -tf.log(pred_action + 1e-10) * action_input, axis=1
  )
  self.inverse_loss = tf.reduce_mean(
      tf.dynamic_partition(cross_entropy, self.policy.mask, 2)[1]

- self, encoded_state: tf.Tensor, encoded_next_state: tf.Tensor
+ self,
+ action_input: tf.Tensor,
+ encoded_state: tf.Tensor,
+ encoded_next_state: tf.Tensor,

+ :param action_input: Tensor representing the current selected action.

- combined_input = tf.concat(
-     [encoded_state, self.policy.selected_actions], axis=1
- )
+ combined_input = tf.concat([encoded_state, action_input], axis=1)
  hidden = tf.layers.dense(combined_input, 256, activation=ModelUtils.swish)
  pred_next_state = tf.layers.dense(
      hidden,
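To make the substitution concrete, here is a minimal numpy sketch (illustrative values only, not taken from the commit) of the two inverse-model loss terms that now consume action_input instead of policy.selected_actions:

    import numpy as np

    # Continuous case: squared difference between predicted and taken actions.
    pred_action = np.array([[0.1, -0.2]])
    action_input = np.array([[0.0, -0.25]])
    squared_difference = np.sum((pred_action - action_input) ** 2, axis=1)

    # Discrete case: cross-entropy between the predicted branch probabilities
    # and the one-hot of the taken action; 1e-10 guards against log(0).
    pred_probs = np.array([[0.7, 0.2, 0.1]])
    action_onehot = np.array([[0.0, 1.0, 0.0]])
    cross_entropy = np.sum(-np.log(pred_probs + 1e-10) * action_onehot, axis=1)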

ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (10 changes)


      feed_dict[self.policy.visual_in[i]] = _obs
      feed_dict[self.model.next_visual_in[i]] = _next_obs
- if self.policy.use_continuous_act:
-     feed_dict[self.policy.selected_actions] = mini_batch["actions"]
- else:
-     feed_dict[self.policy.output] = mini_batch["actions"]
+ feed_dict[self.model.selected_actions_ph] = mini_batch["actions"]
  unscaled_reward = self.policy.sess.run(
      self.model.intrinsic_reward, feed_dict=feed_dict
  )

      policy.sequence_length_ph: self.policy.sequence_length,
      policy.mask_input: mini_batch["masks"],
  }
- if self.policy.use_continuous_act:
-     feed_dict[policy.selected_actions] = mini_batch["actions"]
- else:
-     feed_dict[policy.output] = mini_batch["actions"]
+ feed_dict[self.model.selected_actions_ph] = mini_batch["actions"]
  if self.policy.use_vec_obs:
      feed_dict[policy.vector_in] = mini_batch["vector_obs"]
      feed_dict[self.model.next_vector_in] = mini_batch["next_vector_in"]

ml-agents/mlagents/trainers/components/reward_signals/gail/model.py (19 changes)


  self.encoding_size = encoding_size
  self.gradient_penalty_weight = gradient_penalty_weight
  self.use_vail = use_vail
- self.use_actions = use_actions  # True # Not using actions
+ self.use_actions = use_actions  # Not using actions
+ is_discrete = self.policy.brain.vector_action_space_type != "continuous"
+ self.selected_actions_ph = ModelUtils.create_action_input_placeholder(
+     self.policy.act_size, is_discrete
+ )
+ if is_discrete:
+     self.action_input = ModelUtils.to_onehot_tensor(
+         self.selected_actions_ph, self.policy.act_size
+     )
+ else:
+     self.action_input = self.selected_actions_ph
  self.noise: Optional[tf.Tensor] = None
  self.z: Optional[tf.Tensor] = None

      self.encoded_expert, self.expert_action, self.done_expert, reuse=False
  )
  self.policy_estimate, self.z_mean_policy, _ = self.create_encoder(
-     self.encoded_policy,
-     self.policy.selected_actions,
-     self.done_policy,
-     reuse=True,
+     self.encoded_policy, self.action_input, self.done_policy, reuse=True
  )
  self.mean_policy_estimate = tf.reduce_mean(self.policy_estimate)
  self.mean_expert_estimate = tf.reduce_mean(self.expert_estimate)

  for off-policy. Compute gradients w.r.t randomly interpolated input.
  """
  expert = [self.encoded_expert, self.expert_action, self.done_expert]
- policy = [self.encoded_policy, self.policy.selected_actions, self.done_policy]
+ policy = [self.encoded_policy, self.action_input, self.done_policy]
  interp = []
  for _expert_in, _policy_in in zip(expert, policy):
      alpha = tf.random_uniform(tf.shape(_expert_in))
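A minimal numpy sketch (illustrative values; the mixing formula shown is the standard gradient-penalty interpolation, not quoted from this file) of what the loop above builds for each expert/policy input pair:

    import numpy as np

    rng = np.random.default_rng(0)
    _expert_in = np.array([[1.0, 0.0, 0.0]])
    _policy_in = np.array([[0.0, 1.0, 0.0]])

    # Random mixing coefficient with the same shape as the inputs, mirroring
    # tf.random_uniform(tf.shape(_expert_in)).
    alpha = rng.uniform(size=_expert_in.shape)

    # Interpolated sample on which the gradient penalty is evaluated.
    interp = alpha * _expert_in + (1.0 - alpha) * _policy_in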

ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py (10 changes)


      _obs = mini_batch["visual_obs%d" % i]
      feed_dict[self.policy.visual_in[i]] = _obs
- if self.policy.use_continuous_act:
-     feed_dict[self.policy.selected_actions] = mini_batch["actions"]
- else:
-     feed_dict[self.policy.output] = mini_batch["actions"]
+ feed_dict[self.model.selected_actions_ph] = mini_batch["actions"]
  feed_dict[self.model.done_policy_holder] = np.array(
      mini_batch["done"]
  ).flatten()

  feed_dict[self.model.use_noise] = [1]
  feed_dict[self.model.action_in_expert] = np.array(mini_batch_demo["actions"])
- if self.policy.use_continuous_act:
-     feed_dict[policy.selected_actions] = mini_batch["actions"]
- else:
-     feed_dict[policy.output] = mini_batch["actions"]
+ feed_dict[self.model.selected_actions_ph] = mini_batch["actions"]
  if self.policy.use_vis_obs > 0:
      for i in range(len(policy.visual_in)):

ml-agents/mlagents/trainers/models.py (39 changes)


  :param stream_names: The list of reward signal names
  :param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top
  of the hidden input.
  :return: A tuple of (value heads, value) where the value heads are a dict of reward signal name to
  value output, and the value is an average of all of them.
  """
  value_heads = {}
  for name in stream_names:

  return value_heads, value

+ @staticmethod
+ def to_onehot_tensor(action: tf.Tensor, act_size: List[int]) -> tf.Tensor:
+     """
+     For discrete actions, converts the action tensor (an array of ints, one for each action
+     branch) to a one-hot representation. This could be useful e.g. to feed in as input to a
+     neural network.
+     :param action: Tensor that represents a branched discrete action. Length should be the
+     number of action branches, with the values as the action type.
+     :param act_size: List of ints that represent the number of actions for each branch.
+     :return: One-hot tensor of the action.
+     """
+     action_oh = tf.concat(
+         [tf.one_hot(action[:, i], act_size[i]) for i in range(len(act_size))],
+         axis=1,
+     )
+     return action_oh

+ @staticmethod
+ def create_action_input_placeholder(
+     act_size: List[int], is_discrete: bool = False
+ ) -> tf.Tensor:
+     """
+     Create a placeholder input for actions.
+     :param is_discrete: whether or not act_size describes discrete actions (otherwise, it is continuous).
+     :param act_size: List of ints that represent the number of actions for each branch.
+     :return: Placeholder for the specified action.
+     """
+     if is_discrete:
+         action_holder = tf.placeholder(
+             shape=[None, len(act_size)], dtype=tf.int32, name="action_holder"
+         )
+     else:
+         action_holder = tf.placeholder(
+             shape=[None, act_size[0]], dtype=tf.float32, name="action_holder"
+         )
+     return action_holder
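A short usage sketch (assuming this branch's ModelUtils; the act_size and the action batch are made up) of how a consumer combines the two new helpers to accept raw branch indices and one-hot them in-graph:

    import numpy as np
    import tensorflow as tf
    from mlagents.trainers.models import ModelUtils

    act_size = [2, 3]  # hypothetical two-branch discrete action space
    action_ph = ModelUtils.create_action_input_placeholder(act_size, is_discrete=True)
    action_oh = ModelUtils.to_onehot_tensor(action_ph, act_size)

    with tf.Session() as sess:
        one_hot = sess.run(
            action_oh, feed_dict={action_ph: np.array([[0, 2], [1, 0]])}
        )
        # one_hot has shape (2, 5): each branch one-hot encoded, then concatenated.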

ml-agents/mlagents/trainers/ppo/optimizer.py (17 changes)


      vis_encode_type,
  )[0]
+ self._discrete_action_ph = ModelUtils.create_action_input_placeholder(
+     self.policy.act_size, is_discrete=True
+ )
+ action_oh = ModelUtils.to_onehot_tensor(
+     self._discrete_action_ph, self.policy.act_size
+ )
  if self.policy.use_recurrent:
      hidden_value, memory_value_out = ModelUtils.create_recurrent_encoder(
          hidden_stream,

  tf.stack(
      [
          -tf.nn.softmax_cross_entropy_with_logits_v2(
-             labels=self.policy.action_oh[
-                 :, action_idx[i] : action_idx[i + 1]
-             ],
+             labels=action_oh[:, action_idx[i] : action_idx[i + 1]],
              logits=old_normalized_logits[
                  :, action_idx[i] : action_idx[i + 1]
              ],

"{}_value_estimates".format(name)
]
if self.policy.output_pre is not None and "actions_pre" in mini_batch:
feed_dict[self.policy.output_pre] = mini_batch["actions_pre"]
else:
feed_dict[self.policy.output] = mini_batch["actions"]
if not self.policy.use_continuous_act:
feed_dict[self._discrete_action_ph] = mini_batch["actions"]
if self.policy.use_recurrent:
feed_dict[self.policy.prev_action] = mini_batch["prev_action"]
feed_dict[self.policy.action_masks] = mini_batch["action_mask"]

ml-agents/mlagents/trainers/sac/network.py (10 changes)


      self.h_size,
      self.join_scopes(scope, "value"),
  )
- self.external_action_in = tf.placeholder(
-     shape=[None, self.policy.act_size[0]],
-     dtype=tf.float32,
-     name="external_action_in",
+ self.external_action_in = ModelUtils.create_action_input_placeholder(
+     self.policy.act_size, is_discrete=False
  )
  self.value_vars = self.get_vars(self.join_scopes(scope, "value"))
  if create_qs:
if create_qs:

      self.h_size,
      self.join_scopes(scope, "value"),
  )
+ self.external_action_in = ModelUtils.create_action_input_placeholder(
+     self.policy.act_size, is_discrete=True
+ )
  self.value_vars = self.get_vars("/".join([scope, "value"]))
  if create_qs:

ml-agents/mlagents/trainers/sac/optimizer.py (17 changes)


  )
  self._create_sac_optimizer_ops()
- self.selected_actions = (
-     self.policy.selected_actions
- )  # For GAIL and other reward signals
  if self.policy.normalize:
      target_update_norm = self.target_network.copy_normalization(
          self.policy.running_mean,

  )
  if discrete:
+     onehot_action = ModelUtils.to_onehot_tensor(
+         self.policy_network.external_action_in, self.policy.act_size
+     )

-         self.policy.action_oh * q1_streams[name]
+         onehot_action * q1_streams[name]

-         self.policy.action_oh * q2_streams[name]
+         onehot_action * q2_streams[name]
      )
      # Reduce each branch into scalar
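A minimal numpy sketch (hypothetical branch sizes and Q-values, not from this file) of what multiplying the Q-value stream by the one-hot action does before the per-branch reduction mentioned above:

    import numpy as np

    act_size = [2, 3]  # two branches: 2 and 3 actions
    # One-hot of a taken action: branch 0 chose action 0, branch 1 chose action 2.
    onehot_action = np.array([[1.0, 0.0, 0.0, 0.0, 1.0]])
    # One Q-value per discrete action, concatenated across branches.
    q_stream = np.array([[0.5, -0.1, 0.3, 0.2, 0.9]])

    # Only the Q-values of the actions actually taken survive the mask.
    masked_q = onehot_action * q_stream  # [[0.5, 0.0, 0.0, 0.0, 0.9]]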

  }
  for name in self.reward_signals:
      feed_dict[self.rewards_holders[name]] = batch["{}_rewards".format(name)]
- if self.policy.use_continuous_act:
-     feed_dict[self.policy_network.external_action_in] = batch["actions"]
- else:
-     feed_dict[policy.output] = batch["actions"]
+ feed_dict[self.policy_network.external_action_in] = batch["actions"]
+ if not self.policy.use_continuous_act:
      if self.policy.use_recurrent:
          feed_dict[policy.prev_action] = batch["prev_action"]
      feed_dict[policy.action_masks] = batch["action_mask"]

ml-agents/mlagents/trainers/tf_policy.py (2 changes)


  self.all_log_probs: tf.Tensor = None
  self.log_probs: Optional[tf.Tensor] = None
  self.entropy: Optional[tf.Tensor] = None
- self.action_oh: tf.Tensor = None
- self.selected_actions: Optional[tf.Tensor] = None
  self.action_masks: Optional[tf.Tensor] = None
  self.prev_action: Optional[tf.Tensor] = None
  self.memory_in: Optional[tf.Tensor] = None
