
Move some common logic to buffer class

/develop-newnormalization
Ervin Teng, 5 years ago
Commit c9116ed2
3 files changed: 35 insertions(+), 14 deletions(-)
1. ml-agents/mlagents/trainers/buffer.py (29 changes)
2. ml-agents/mlagents/trainers/ppo/trainer.py (10 changes)
3. ml-agents/mlagents/trainers/sac/trainer.py (10 changes)

ml-agents/mlagents/trainers/buffer.py (29 changes)


         for _key in self.keys():
             self[_key] = self[_key][current_length - max_length :]
 
+    def resequence_and_append(
+        self,
+        target_buffer: "AgentBuffer",
+        key_list: List[str] = None,
+        batch_size: int = None,
+        training_length: int = None,
+    ) -> None:
+        """
+        Takes in a batch size and training length (sequence length), and appends this AgentBuffer to target_buffer
+        properly padded for LSTM use. Optionally, use key_list to restrict which fields are inserted into the new
+        buffer.
+        :param target_buffer: The buffer to which to append the samples.
+        :param key_list: The fields that must be added. If None: all fields will be appended.
+        :param batch_size: The number of elements that must be appended. If None: all of them will be.
+        :param training_length: The length of the samples that must be appended. If None: only takes one element.
+        """
+        if key_list is None:
+            key_list = list(self.keys())
+        if not self.check_length(key_list):
+            raise BufferException(
+                "The fields {0} were not of the same length".format(key_list)
+            )
+        for field_key in key_list:
+            target_buffer[field_key].extend(
+                self[field_key].get_batch(
+                    batch_size=batch_size, training_length=training_length
+                )
+            )
+
     @property
     def num_experiences(self) -> int:
         """

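For context, here is a minimal usage sketch of the new method. The buffer construction, field names, and values below are made up for illustration; only resequence_and_append, check_length, get_batch, extend, and num_experiences come from this diff, and the exact field semantics may differ slightly from the library.

from mlagents.trainers.buffer import AgentBuffer

# Hypothetical source (one trajectory) and target (update) buffers; the
# field names "rewards" and "done" are arbitrary examples.
trajectory_buffer = AgentBuffer()
update_buffer = AgentBuffer()
for step in range(6):
    trajectory_buffer["rewards"].append(float(step))
    trajectory_buffer["done"].append(step == 5)

# Append the whole trajectory to the update buffer, resequenced into
# LSTM-friendly chunks of length 3. check_length guards against fields
# of mismatched length before anything is copied.
trajectory_buffer.resequence_and_append(
    update_buffer, key_list=["rewards", "done"], training_length=3
)
print(update_buffer.num_experiences)  # expected: 6, the length of the appended fields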
ml-agents/mlagents/trainers/ppo/trainer.py (10 changes)


agent_buffer_trajectory["advantages"].set(global_advantages)
agent_buffer_trajectory["discounted_returns"].set(global_returns)
# Append to update buffer
key_list = agent_buffer_trajectory.keys()
for field_key in key_list:
self.update_buffer[field_key].extend(
agent_buffer_trajectory[field_key].get_batch(
batch_size=None, training_length=self.policy.sequence_length
)
)
agent_buffer_trajectory.resequence_and_append(
self.update_buffer, training_length=self.policy.sequence_length
)
if trajectory.steps[-1].done:
self.stats["Environment/Episode Length"].append(

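Both the PPO and SAC trainers now hand the resequencing entirely to the buffer, passing training_length=self.policy.sequence_length. The "properly padded for LSTM use" in the docstring refers to slicing a flat trajectory into fixed-length sequences and padding when the trajectory length is not a multiple of the sequence length. A simplified, self-contained illustration of that general idea follows; it is not the library's get_batch implementation (which also supports a batch_size limit, per its signature above), just a sketch of the technique.

from typing import List

def pad_to_sequences(values: List[float], training_length: int, padding: float = 0.0) -> List[float]:
    # Pad the front of `values` so its length becomes a multiple of
    # `training_length`; the result can then be split into equal-length
    # sequences for an LSTM. Illustration only.
    remainder = len(values) % training_length
    if remainder == 0:
        return list(values)
    return [padding] * (training_length - remainder) + list(values)

print(pad_to_sequences([1.0, 2.0, 3.0, 4.0, 5.0], training_length=3))
# -> [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], i.e. two sequences of length 3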
ml-agents/mlagents/trainers/sac/trainer.py (10 changes)


agent_buffer_trajectory["done"][-1] = False
# Append to update buffer
key_list = agent_buffer_trajectory.keys()
for field_key in key_list:
self.update_buffer[field_key].extend(
agent_buffer_trajectory[field_key].get_batch(
batch_size=None, training_length=self.policy.sequence_length
)
)
agent_buffer_trajectory.resequence_and_append(
self.update_buffer, training_length=self.policy.sequence_length
)
if trajectory.steps[-1].done:
self.stats["Environment/Episode Length"].append(
