|
|
|
|
|
|
from typing import Any, Dict, List |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
import copy |
|
|
|
|
|
|
|
from mlagents.trainers.action_info import ActionInfo |
|
|
|
from mlagents.trainers.behavior_id_utils import get_global_agent_id |
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
def get_weights(self) -> List[np.ndarray]:
    """Return an independent snapshot of the actor-critic's state.

    The value is a deep copy of ``self.actor_critic.state_dict()``, so
    later changes to the live state do not alter a snapshot that a
    caller is holding (and vice versa).

    NOTE(review): the annotation says ``List[np.ndarray]`` but the value
    returned is whatever ``state_dict()`` produces, which is presumably a
    name -> tensor mapping -- confirm and tighten the annotation.

    :return: Deep copy of the actor-critic's ``state_dict()``.
    """
    # Copy before handing out: returning the state dict directly would
    # share its contents with the caller.
    return copy.deepcopy(self.actor_critic.state_dict())
|
|
|
|
|
|
|
def get_modules(self):
    """Collect the objects this policy exposes, keyed by name.

    :return: A dict with two entries: ``"Policy"`` mapped to
        ``self.actor_critic`` and ``"global_step"`` mapped to
        ``self.global_step``.
    """
    exposed = {}
    exposed["Policy"] = self.actor_critic
    exposed["global_step"] = self.global_step
    return exposed