浏览代码

Some work for 2d obs

/exp-bullet-hell-trainer
vincentpierre 4 年前
当前提交
c2b50936
共有 2 个文件被更改,包括 6 次插入2 次删除
  1. 4
      com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs
  2. 4
      ml-agents-envs/mlagents_envs/rpc_utils.py

4
com.unity.ml-agents/Runtime/Sensors/ObservationWriter.cs


{
m_TensorShape = new TensorShape(m_Batch, shape[0]);
}
else if (shape.Length == 2)
{
m_TensorShape = new TensorShape(new int[]{m_Batch, 1, shape[0], shape[1]});
}
else
{
m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);

4
ml-agents-envs/mlagents_envs/rpc_utils.py


], # pylint: disable=unsubscriptable-object
) -> np.ndarray:
if len(agent_info_list) == 0:
return np.zeros((0, shape[0]), dtype=np.float32)
return np.zeros((0,) + shape, dtype=np.float32)
np_obs = np.array(
[
agent_obs.observations[obs_index].float_data.data

)
).reshape((len(agent_info_list), ) + shape)
_raise_on_nan_and_inf(np_obs, "observations")
return np_obs

正在加载...
取消
保存