
fix action_spec refs

/develop/action-spec-gym
Andrew Cohen, 4 years ago
Commit 97dfa142
7 files changed, with 15 insertions and 12 deletions
  1. ml-agents/mlagents/trainers/sac/optimizer_torch.py (2 changes)
  2. ml-agents/mlagents/trainers/tests/torch/test_policy.py (13 changes)
  3. ml-agents/mlagents/trainers/tf/components/bc/model.py (2 changes)
  4. ml-agents/mlagents/trainers/tf/components/bc/module.py (4 changes)
  5. ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/model.py (2 changes)
  6. ml-agents/mlagents/trainers/tf/components/reward_signals/gail/model.py (2 changes)
  7. ml-agents/mlagents/trainers/torch/model_serialization.py (2 changes)
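For reference, the change applied across all of these files is a single access-pattern fix: the policy's ActionSpec is now read through policy.behavior_spec rather than directly off the policy. A minimal sketch of that chain, assuming NamedTuple-style constructors in mlagents_envs.base_env around this branch (only the attribute names that appear in the diff below come from the commit; the construction itself is illustrative):

    # Illustrative only: the ActionSpec/BehaviorSpec construction is an assumed
    # approximation of mlagents_envs.base_env at the time of this branch.
    from mlagents_envs.base_env import ActionSpec, BehaviorSpec

    action_spec = ActionSpec(continuous_size=2, discrete_branches=())
    behavior_spec = BehaviorSpec(observation_shapes=[(8,)], action_spec=action_spec)

    # Old call sites:  policy.action_spec.continuous_size
    # New call sites:  policy.behavior_spec.action_spec.continuous_size
    assert behavior_spec.action_spec.is_continuous()
    assert behavior_spec.action_spec.continuous_size == 2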

ml-agents/mlagents/trainers/sac/optimizer_torch.py (2 changes)


self.stream_names,
self.policy.behavior_spec.observation_shapes,
policy_network_settings,
- self.policy.action_spec,
+ self.policy.behavior_spec.action_spec,
)
self.target_network = ValueNetwork(

ml-agents/mlagents/trainers/tests/torch/test_policy.py (13 changes)


memories=memories,
seq_len=policy.sequence_length,
)
- assert log_probs.shape == (64, policy.action_spec.size)
- assert entropy.shape == (64, policy.action_spec.size)
+ assert log_probs.shape == (64, policy.behavior_spec.action_spec.size)
+ assert entropy.shape == (64, policy.behavior_spec.action_spec.size)
for val in values.values():
assert val.shape == (64,)

all_log_probs=not policy.use_continuous_act,
)
if discrete:
- assert log_probs.shape == (64, sum(policy.action_spec.discrete_branches))
+ assert log_probs.shape == (
+     64,
+     sum(policy.behavior_spec.action_spec.discrete_branches),
+ )
- assert log_probs.shape == (64, policy.action_spec.continuous_size)
- assert entropies.shape == (64, policy.action_spec.size)
+ assert log_probs.shape == (64, policy.behavior_spec.action_spec.continuous_size)
+ assert entropies.shape == (64, policy.behavior_spec.action_spec.size)
if rnn:
assert memories.shape == (1, 1, policy.m_size)
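As a plain-arithmetic illustration of the discrete-branch assertion above (the branch sizes are made up, not taken from the test):

    # Hypothetical discrete spec with two branches of 3 and 2 actions
    discrete_branches = (3, 2)
    batch_size = 64
    # With all_log_probs requested, log_probs carries one entry per discrete
    # action across all branches, so the expected width is the branch sum
    expected_log_probs_shape = (batch_size, sum(discrete_branches))  # (64, 5)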

ml-agents/mlagents/trainers/tf/components/bc/model.py (2 changes)


self.done_expert = tf.placeholder(shape=[None, 1], dtype=tf.float32)
self.done_policy = tf.placeholder(shape=[None, 1], dtype=tf.float32)
- if self.policy.action_spec.is_continuous():
+ if self.policy.behavior_spec.action_spec.is_continuous():
action_length = self.policy.act_size[0]
self.action_in_expert = tf.placeholder(
shape=[None, action_length], dtype=tf.float32

ml-agents/mlagents/trainers/tf/components/bc/module.py (4 changes)


self.policy.sequence_length_ph: self.policy.sequence_length,
}
feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"]
- if self.policy.action_spec.is_discrete():
+ if self.policy.behavior_spec.action_spec.is_discrete():
- sum(self.policy.action_spec.discrete_branches),
+ sum(self.policy.behavior_spec.action_spec.discrete_branches),
),
dtype=np.float32,
)

ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/model.py (2 changes)


"""
combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=ModelUtils.swish)
- if self.policy.action_spec.is_continuous():
+ if self.policy.behavior_spec.action_spec.is_continuous():
pred_action = tf.layers.dense(
hidden, self.policy.act_size[0], activation=None
)

ml-agents/mlagents/trainers/tf/components/reward_signals/gail/model.py (2 changes)


self.done_expert = tf.expand_dims(self.done_expert_holder, -1)
self.done_policy = tf.expand_dims(self.done_policy_holder, -1)
- if self.policy.action_spec.is_continuous():
+ if self.policy.behavior_spec.action_spec.is_continuous():
action_length = self.policy.act_size[0]
self.action_in_expert = tf.placeholder(
shape=[None, action_length], dtype=tf.float32

ml-agents/mlagents/trainers/torch/model_serialization.py (2 changes)


if len(shape) == 3
]
dummy_masks = torch.ones(
- batch_dim + [sum(self.policy.action_spec.discrete_branches)]
+ batch_dim + [sum(self.policy.behavior_spec.action_spec.discrete_branches)]
)
dummy_memories = torch.zeros(
batch_dim + seq_len_dim + [self.policy.export_memory_size]
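
A small sketch of the dummy action-mask input built here for export, assuming a batch dimension of 1 and made-up branch sizes (only the sum(discrete_branches) width comes from the diff):

    import torch

    batch_dim = [1]                 # assumed export batch dimension
    discrete_branches = (3, 2)      # hypothetical branch sizes
    # One mask entry per discrete action across all branches
    dummy_masks = torch.ones(batch_dim + [sum(discrete_branches)])  # shape (1, 5)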
