浏览代码

Fix np float32 errors

/develop-newnormalization
Ervin Teng 5 年前
当前提交
77aea4cd
共有 2 个文件被更改,包括 10 次插入10 次删除
  1. 2
      ml-agents/mlagents/trainers/agent_processor.py
  2. 18
      ml-agents/mlagents/trainers/tests/test_trajectory.py

2
ml-agents/mlagents/trainers/agent_processor.py


self.last_take_action_outputs[agent_id] = take_action_outputs
# Store the environment reward
tmp_environment = np.array(next_info.rewards)
tmp_environment = np.array(next_info.rewards, dtype=np.float32)
for agent_id in next_info.agents:
stored_info = self.last_brain_info.get(agent_id, None)

18
ml-agents/mlagents/trainers/tests/test_trajectory.py


for i in range(length - 1):
obs = []
for i in range(num_vis_obs):
obs.append(np.ones((84, 84, 3)))
obs.append(np.ones(vec_obs_size))
obs.append(np.ones((84, 84, 3), dtype=np.float32))
obs.append(np.ones(vec_obs_size, dtype=np.float32))
action = np.zeros(action_space)
action_probs = np.ones(action_space)
action_pre = np.zeros(action_space)
action_mask = np.ones(action_space)
prev_action = np.ones(action_space)
action = np.zeros(action_space, dtype=np.float32)
action_probs = np.ones(action_space, dtype=np.float32)
action_pre = np.zeros(action_space, dtype=np.float32)
action_mask = np.ones(action_space, dtype=np.float32)
prev_action = np.ones(action_space, dtype=np.float32)
memory = np.ones(10)
memory = np.ones(10, dtype=np.float32)
agent_id = "test_agent"
experience = AgentExperience(
obs=obs,

for i in range(num_visual_obs):
obs.append(np.ones((84, 84, 3), dtype=np.float32))
for i in range(num_vec_obs):
obs.append(np.ones(VEC_OBS_SIZE))
obs.append(np.ones(VEC_OBS_SIZE, dtype=np.float32))
split_observations = split_obs(obs)
if num_vec_obs == 1:

正在加载...
取消
保存