浏览代码

Clean up memory_size logic

/develop/add-fire/memoryclass
Ervin Teng 4 年前
当前提交
1656d290
共有 1 个文件被更改,包括 10 次插入8 次删除
  1. 18
      ml-agents/mlagents/trainers/torch/networks.py

18
ml-agents/mlagents/trainers/torch/networks.py


for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders):
n1.copy_normalization(n2)
@property
def memory_size(self) -> int:
return self.lstm.memory_size if self.use_lstm else 0
def forward(
self,
vec_inputs: List[torch.Tensor],

else:
encoding_size = network_settings.hidden_units
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)
@property
def memory_size(self) -> int:
return self.network_body.memory_size
def forward(
self,

@property
def memory_size(self) -> int:
if self.network_body.lstm is not None:
return self.network_body.lstm.memory_size
else:
return 0
return self.network_body.memory_size
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
self.network_body.update_normalization(vector_obs)

@property
def memory_size(self) -> int:
if self.network_body.lstm is not None:
return 2 * self.network_body.lstm.memory_size
else:
return 0
return self.network_body.memory_size + self.critic.memory_size
def critic_pass(
self,

正在加载...
取消
保存