浏览代码

fix test_tf_policy

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
97d94a83
共有 2 个文件被更改,包括 8 次插入6 次删除
  1. 12
      ml-agents/mlagents/trainers/policy/tf_policy.py
  2. 2
      ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py

12
ml-agents/mlagents/trainers/policy/tf_policy.py


self.save_memories(global_agent_ids, run_out.get("memory_out"))
# For Compatibility with buffer changes for hybrid action support
run_out["log_probs"] = {"action_probs": run_out["log_probs"]}
if self.behavior_spec.action_spec.is_continuous():
run_out["action"] = {"continuous_action": run_out["action"]}
else:
run_out["action"] = {"discrete_action": run_out["action"]}
if "log_probs" in run_out:
run_out["log_probs"] = {"action_probs": run_out["log_probs"]}
if "action" in run_out:
if self.behavior_spec.action_spec.is_continuous():
run_out["action"] = {"continuous_action": run_out["action"]}
else:
run_out["action"] = {"discrete_action": run_out["action"]}
return ActionInfo(
action=run_out.get("action"),
value=run_out.get("value"),

2
ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py


behavior_spec = basic_behavior_spec()
policy = FakePolicy(test_seed, behavior_spec, TrainerSettings(), "output")
policy_eval_out = {
"action": np.array([1.0], dtype=np.float32),
"action": {"continuous_action": np.array([1.0], dtype=np.float32)},
"memory_out": np.array([[2.5]], dtype=np.float32),
"value": np.array([1.1], dtype=np.float32),
}

正在加载...
取消
保存