|
|
|
|
|
|
action_mask = np.split(action_mask, indices, axis=1) |
|
|
|
return BatchedStepResult(obs_list, rewards, done, max_step, agent_id, action_mask) |
|
|
|
|
|
|
|
|
|
|
|
def proto_from_batched_step_result(
    batched_step_result: BatchedStepResult,
) -> AgentInfoProto:
    """Convert a BatchedStepResult back into an AgentInfoProto.

    Inverse of the proto->BatchedStepResult conversion: copies the reward,
    done flag, max-step flag, agent id, action mask, and observations from
    the step result into a freshly constructed AgentInfoProto.

    NOTE(review): the original chunk contained two copies of this function's
    signature and two copies of the return statement (an apparent merge /
    reformat artifact that did not parse); this version keeps a single
    black-style signature and a single multi-line return.

    :param batched_step_result: the step result to serialize.
    :return: an AgentInfoProto populated from the step result's fields.
    """
    reward = batched_step_result.reward
    done = batched_step_result.done
    max_step_reached = batched_step_result.max_step
    # NOTE(review): the assignments for agent_id, action_mask, and
    # observations were not visible in the reviewed chunk; they are
    # reconstructed here from the corresponding BatchedStepResult
    # attributes — confirm against the original file, in particular
    # whether `obs` needs conversion to ObservationProto messages first.
    agent_id = batched_step_result.agent_id
    action_mask = batched_step_result.action_mask
    observations = batched_step_result.obs
    return AgentInfoProto(
        reward=reward,
        done=done,
        id=agent_id,
        max_step_reached=max_step_reached,
        action_mask=action_mask,
        observations=observations,
    )
|
|
|
|
|
|
|
|
|
|
|
def _generate_split_indices(dims): |
|
|
|
if len(dims) <= 1: |
|
|
|