|
|
|
|
|
|
action_mask = np.split(action_mask, indices, axis=1) |
|
|
|
return BatchedStepResult(obs_list, rewards, done, max_step, agent_id, action_mask) |
|
|
|
|
|
|
|
@timed |
|
|
|
def proto_from_batched_step_result(batched_step_result: BatchedStepResult) -> AgentInfoProto: |
|
|
|
reward = batched_step_result.reward |
|
|
|
done = batched_step_result.done |
|
|
|
max_step_reached = batched_step_result.max_step |
|
|
|
agent_id = batched_step_result.agent_id |
|
|
|
action_mask = batched_step_result.action_mask |
|
|
|
observations = batched_step_result.obs |
|
|
|
return AgentInfoProto(reward=reward, done=done, id=agent_id, max_step_reached=max_step_reached, action_mask=action_mask, observations=observations) |
|
|
|
|
|
|
|
def _generate_split_indices(dims): |
|
|
|
if len(dims) <= 1: |
|
|
|