浏览代码

fixed recurrent prev_action issue

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
3c65b964
共有 4 个文件被更改,包括 5 次插入和 3 次删除
  1. 2
      ml-agents/mlagents/trainers/agent_processor.py
  2. 2
      ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
  3. 2
      ml-agents/mlagents/trainers/sac/optimizer_tf.py
  4. 2
      ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py

2
ml-agents/mlagents/trainers/agent_processor.py


action_mask = stored_decision_step.action_mask
prev_action = self.policy.retrieve_previous_action([global_id])
for _prev_act_type, _prev_act in prev_action.items():
prev_action[_prev_act_type] = _prev_act[0, :]
experience = AgentExperience(
obs=obs,
reward=step.reward,

2
ml-agents/mlagents/trainers/optimizer/tf_optimizer.py


]
feed_dict[self.memory_in] = [np.zeros((self.m_size), dtype=np.float32)]
if self.policy.prev_action is not None:
feed_dict[self.policy.prev_action] = batch["prev_action"]
feed_dict[self.policy.prev_action] = batch["prev_discrete_action"]
if self.policy.use_recurrent:
value_estimates, policy_mem, value_mem = self.sess.run(

2
ml-agents/mlagents/trainers/sac/optimizer_tf.py


else:
feed_dict[policy.output] = batch["discrete_action"]
if self.policy.use_recurrent:
feed_dict[policy.prev_action] = batch["prev_action"]
feed_dict[policy.prev_action] = batch["prev_discrete_action"]
feed_dict[policy.action_masks] = batch["action_mask"]
if self.policy.use_vec_obs:
feed_dict[policy.vector_in] = batch["vector_obs"]

2
ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py


_check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.5)
@pytest.mark.parametrize("use_discrete", [True])
@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_ppo(use_discrete):
env = MemoryEnvironment([BRAIN_NAME], use_discrete=use_discrete)
new_network_settings = attr.evolve(

正在加载...
取消
保存