浏览代码

Return list instead of np array for make_mini_batch() (#2371)

Return list instead of np array for make_mini_batch() to reduce time copying data
/develop-generalizationTraining-TrainerController
GitHub 5 年前
当前提交
d7ebaae1
共有 11 个文件被更改,包括 151 次插入238 次删除
  1. 110
      ml-agents/mlagents/trainers/buffer.py
  2. 38
      ml-agents/mlagents/trainers/components/bc/module.py
  3. 48
      ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
  4. 6
      ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
  5. 56
      ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
  6. 7
      ml-agents/mlagents/trainers/ppo/models.py
  7. 63
      ml-agents/mlagents/trainers/ppo/policy.py
  8. 15
      ml-agents/mlagents/trainers/ppo/trainer.py
  9. 3
      ml-agents/mlagents/trainers/tests/mock_brain.py
  10. 28
      ml-agents/mlagents/trainers/tests/test_buffer.py
  11. 15
      ml-agents/mlagents/trainers/tests/test_reward_signals.py

110
ml-agents/mlagents/trainers/buffer.py


sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives
[[a,b],[b,c],[c,d],[d,e]]
"""
if training_length == 1:
# When the training length is 1, the method returns a list of elements,
# not a list of sequences of elements.
if sequential:
# The sequences will not have overlapping elements (this involves padding)
leftover = len(self) % training_length
# leftover is the number of elements in the first sequence (this sequence might need 0 padding)
# If batch_size is None : All the elements of the AgentBufferField are returned.
return np.array(self)
# retrieve the maximum number of elements
batch_size = len(self) // training_length + 1 * (leftover != 0)
# The maximum number of sequences taken from a list of length len(self) without overlapping
# with padding is equal to batch_size
if batch_size > (
len(self) // training_length + 1 * (leftover != 0)
):
raise BufferException(
"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
)
if batch_size * training_length > len(self):
padding = np.array(self[-1]) * self.padding_value
return np.array(
[padding] * (training_length - leftover) + self[:]
)
# return the batch_size last elements
if batch_size > len(self):
raise BufferException("Batch size requested is too large")
return np.array(self[-batch_size:])
return np.array(
self[len(self) - batch_size * training_length :]
)
# The training_length is not None, the method returns a list of SEQUENCES of elements
if not sequential:
# The sequences will have overlapping elements
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) - training_length + 1
# The number of sequences of length training_length taken from a list of len(self) elements
# with overlapping is equal to batch_size
if (len(self) - training_length + 1) < batch_size:
raise BufferException(
"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
)
tmp_list = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += [np.array(self[end - training_length : end])]
return np.array(tmp_list)
if sequential:
# The sequences will not have overlapping elements (this involves padding)
leftover = len(self) % training_length
# leftover is the number of elements in the first sequence (this sequence might need 0 padding)
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) // training_length + 1 * (
leftover != 0
)
# The maximum number of sequences taken from a list of length len(self) without overlapping
# with padding is equal to batch_size
if batch_size > (
len(self) // training_length + 1 * (leftover != 0)
):
raise BufferException(
"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
)
tmp_list = []
padding = np.array(self[-1]) * self.padding_value
# The padding is made with zeros and its shape is given by the shape of the last element
for end in range(
len(self), len(self) % training_length, -training_length
)[:batch_size]:
tmp_list += [np.array(self[end - training_length : end])]
if (leftover != 0) and (len(tmp_list) < batch_size):
tmp_list += [
np.array(
[padding] * (training_length - leftover)
+ self[:leftover]
)
]
tmp_list.reverse()
return np.array(tmp_list)
# The sequences will have overlapping elements
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) - training_length + 1
# The number of sequences of length training_length taken from a list of len(self) elements
# with overlapping is equal to batch_size
if (len(self) - training_length + 1) < batch_size:
raise BufferException(
"The batch size and training length requested for get_batch where"
" too large given the current number of data points."
)
tmp_list = []
for end in range(len(self) - batch_size + 1, len(self) + 1):
tmp_list += self[end - training_length : end]
return np.array(tmp_list)
def reset_field(self):
"""

length = len(self[key])
return True
def shuffle(self, key_list=None):
def shuffle(self, sequence_length, key_list=None):
Shuffles the fields in key_list in a consistent way: The reordering will
Shuffles the fields in key_list in a consistent way: The reordering will
be the same across fields.
:param key_list: The fields that must be shuffled.

raise BufferException(
"Unable to shuffle if the fields are not of same length"
)
s = np.arange(len(self[key_list[0]]))
s = np.arange(len(self[key_list[0]]) // sequence_length)
self[key][:] = [self[key][i] for i in s]
tmp = []
for i in s:
tmp += self[key][i * sequence_length : (i + 1) * sequence_length]
self[key][:] = tmp
def make_mini_batch(self, start, end):
"""

"""
mini_batch = {}
for key in self:
mini_batch[key] = np.array(self[key][start:end])
mini_batch[key] = self[key][start:end]
return mini_batch
def __init__(self):

38
ml-agents/mlagents/trainers/components/bc/module.py


n_epoch = self.num_epoch
for _ in range(n_epoch):
self.demonstration_buffer.update_buffer.shuffle()
self.demonstration_buffer.update_buffer.shuffle(
sequence_length=self.policy.sequence_length
)
for i in range(num_batches):
for i in range(num_batches // self.policy.sequence_length):
start = i * self.n_sequences
end = (i + 1) * self.n_sequences
start = i * self.n_sequences * self.policy.sequence_length
end = (i + 1) * self.n_sequences * self.policy.sequence_length
mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
run_out = self._update_batch(mini_batch_demo, self.n_sequences)
loss = run_out["loss"]

self.policy.model.batch_size: n_sequences,
self.policy.model.sequence_length: self.policy.sequence_length,
}
feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"]
feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"].reshape(
[-1, self.policy.model.brain.vector_action_space_size[0]]
)
feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"].reshape(
[-1, len(self.policy.model.brain.vector_action_space_size)]
)
feed_dict[self.policy.model.action_masks] = np.ones(
(
self.n_sequences,

if self.policy.model.brain.vector_observation_space_size > 0:
apparent_obs_size = (
self.policy.model.brain.vector_observation_space_size
* self.policy.model.brain.num_stacked_vector_observations
)
feed_dict[self.policy.model.vector_in] = mini_batch_demo[
"vector_obs"
].reshape([-1, apparent_obs_size])
feed_dict[self.policy.model.vector_in] = mini_batch_demo["vector_obs"]
visual_obs = mini_batch_demo["visual_obs%d" % i]
if self.policy.sequence_length > 1 and self.policy.use_recurrent:
(_batch, _seq, _w, _h, _c) = visual_obs.shape
feed_dict[self.policy.model.visual_in[i]] = visual_obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.policy.model.visual_in[i]] = visual_obs
feed_dict[self.policy.model.visual_in[i]] = mini_batch_demo[
"visual_obs%d" % i
]
if self.use_recurrent:
feed_dict[self.policy.model.memory_in] = np.zeros(
[self.n_sequences, self.policy.m_size]

"prev_action"
].reshape([-1, len(self.policy.model.act_size)])
]
network_out = self.policy.sess.run(
list(self.out_dict.values()), feed_dict=feed_dict

48
ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py


forward_total: List[float] = []
inverse_total: List[float] = []
for _ in range(self.num_epoch):
update_buffer.shuffle()
update_buffer.shuffle(sequence_length=self.policy.sequence_length)
buffer = update_buffer
for l in range(len(update_buffer["actions"]) // num_sequences):
start = l * num_sequences

feed_dict = {
self.policy.model.batch_size: num_sequences,
self.policy.model.sequence_length: self.policy.sequence_length,
self.policy.model.mask_input: mini_batch["masks"].flatten(),
self.policy.model.advantage: mini_batch["advantages"].reshape([-1, 1]),
self.policy.model.all_old_log_probs: mini_batch["action_probs"].reshape(
[-1, sum(self.policy.model.act_size)]
),
self.policy.model.mask_input: mini_batch["masks"],
self.policy.model.advantage: mini_batch["advantages"],
self.policy.model.all_old_log_probs: mini_batch["action_probs"],
feed_dict[self.policy.model.output_pre] = mini_batch["actions_pre"].reshape(
[-1, self.policy.model.act_size[0]]
)
feed_dict[self.policy.model.output_pre] = mini_batch["actions_pre"]
feed_dict[self.policy.model.action_holder] = mini_batch["actions"].reshape(
[-1, len(self.policy.model.act_size)]
)
feed_dict[self.policy.model.action_holder] = mini_batch["actions"]
feed_dict[self.policy.model.vector_in] = mini_batch["vector_obs"].reshape(
[-1, self.policy.vec_obs_size]
)
feed_dict[self.model.next_vector_in] = mini_batch["next_vector_in"].reshape(
[-1, self.policy.vec_obs_size]
)
feed_dict[self.policy.model.vector_in] = mini_batch["vector_obs"]
feed_dict[self.model.next_vector_in] = mini_batch["next_vector_in"]
_obs = mini_batch["visual_obs%d" % i]
if self.policy.sequence_length > 1 and self.policy.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.policy.model.visual_in[i]] = _obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.policy.model.visual_in[i]] = _obs
feed_dict[self.policy.model.visual_in[i]] = mini_batch[
"visual_obs%d" % i
]
_obs = mini_batch["next_visual_obs%d" % i]
if self.policy.sequence_length > 1 and self.policy.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[self.model.next_visual_in[i]] = _obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.model.next_visual_in[i]] = _obs
feed_dict[self.model.next_visual_in[i]] = mini_batch[
"next_visual_obs%d" % i
]
self.has_updated = True
run_out = self.policy._execute_model(feed_dict, self.update_dict)

6
ml-agents/mlagents/trainers/components/reward_signals/gail/model.py


"""
Creates the input layers for the discriminator
"""
self.done_expert = tf.placeholder(shape=[None, 1], dtype=tf.float32)
self.done_policy = tf.placeholder(shape=[None, 1], dtype=tf.float32)
self.done_expert_holder = tf.placeholder(shape=[None], dtype=tf.float32)
self.done_policy_holder = tf.placeholder(shape=[None], dtype=tf.float32)
self.done_expert = tf.expand_dims(self.done_expert_holder, -1)
self.done_policy = tf.expand_dims(self.done_policy_holder, -1)
if self.policy_model.brain.vector_action_space_type == "continuous":
action_length = self.policy_model.act_size[0]

56
ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py


n_epoch = self.num_epoch
for _epoch in range(n_epoch):
self.demonstration_buffer.update_buffer.shuffle()
update_buffer.shuffle()
self.demonstration_buffer.update_buffer.shuffle(
sequence_length=self.policy.sequence_length
)
update_buffer.shuffle(sequence_length=self.policy.sequence_length)
if max_batches == 0:
num_batches = possible_batches
else:

:return: Output from update process.
"""
feed_dict: Dict[tf.Tensor, Any] = {
self.model.done_expert: mini_batch_demo["done"].reshape([-1, 1]),
self.model.done_policy: mini_batch_policy["done"].reshape([-1, 1]),
self.model.done_expert_holder: mini_batch_demo["done"],
self.model.done_policy_holder: mini_batch_policy["done"],
feed_dict[self.model.action_in_expert] = np.array(mini_batch_demo["actions"])
feed_dict[self.policy.model.selected_actions] = mini_batch_policy[
"actions"
].reshape([-1, self.policy.model.act_size[0]])
feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"].reshape(
[-1, self.policy.model.act_size[0]]
)
feed_dict[self.policy.model.selected_actions] = mini_batch_policy["actions"]
feed_dict[self.policy.model.action_holder] = mini_batch_policy[
"actions"
].reshape([-1, len(self.policy.model.act_size)])
feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"].reshape(
[-1, len(self.policy.model.act_size)]
)
feed_dict[self.policy.model.action_holder] = mini_batch_policy["actions"]
policy_obs = mini_batch_policy["visual_obs%d" % i]
if self.policy.sequence_length > 1 and self.policy.use_recurrent:
(_batch, _seq, _w, _h, _c) = policy_obs.shape
feed_dict[self.policy.model.visual_in[i]] = policy_obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.policy.model.visual_in[i]] = policy_obs
demo_obs = mini_batch_demo["visual_obs%d" % i]
if self.policy.sequence_length > 1 and self.policy.use_recurrent:
(_batch, _seq, _w, _h, _c) = demo_obs.shape
feed_dict[self.model.expert_visual_in[i]] = demo_obs.reshape(
[-1, _w, _h, _c]
)
else:
feed_dict[self.model.expert_visual_in[i]] = demo_obs
feed_dict[self.policy.model.visual_in[i]] = mini_batch_policy[
"visual_obs%d" % i
]
feed_dict[self.model.expert_visual_in[i]] = mini_batch_demo[
"visual_obs%d" % i
]
feed_dict[self.policy.model.vector_in] = mini_batch_policy[
"vector_obs"
].reshape([-1, self.policy.vec_obs_size])
feed_dict[self.model.obs_in_expert] = mini_batch_demo["vector_obs"].reshape(
[-1, self.policy.vec_obs_size]
)
feed_dict[self.policy.model.vector_in] = mini_batch_policy["vector_obs"]
feed_dict[self.model.obs_in_expert] = mini_batch_demo["vector_obs"]
out_dict = {
"gail_loss": self.model.loss,

7
ml-agents/mlagents/trainers/ppo/models.py


self.returns_holders[name] = returns_holder
self.old_values[name] = old_value
self.advantage = tf.placeholder(
shape=[None, 1], dtype=tf.float32, name="advantages"
shape=[None], dtype=tf.float32, name="advantages"
advantage = tf.expand_dims(self.advantage, -1)
self.learning_rate = tf.train.polynomial_decay(
lr, self.global_step, max_step, 1e-10, power=1.0
)

self.value_loss = tf.reduce_mean(value_losses)
r_theta = tf.exp(probs - old_probs)
p_opt_a = r_theta * self.advantage
p_opt_a = r_theta * advantage
* self.advantage
* advantage
)
self.policy_loss = -tf.reduce_mean(
tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.mask, 2)[1]

63
ml-agents/mlagents/trainers/ppo/policy.py


def construct_feed_dict(self, model, mini_batch, num_sequences):
feed_dict = {
model.batch_size: num_sequences,
model.sequence_length: self.sequence_length,
model.mask_input: mini_batch["masks"].flatten(),
model.advantage: mini_batch["advantages"].reshape([-1, 1]),
model.all_old_log_probs: mini_batch["action_probs"].reshape(
[-1, sum(model.act_size)]
),
self.model.batch_size: num_sequences,
self.model.sequence_length: self.sequence_length,
self.model.mask_input: mini_batch["masks"],
self.model.advantage: mini_batch["advantages"],
self.model.all_old_log_probs: mini_batch["action_probs"],
].flatten()
]
].flatten()
]
feed_dict[model.output_pre] = mini_batch["actions_pre"].reshape(
[-1, model.act_size[0]]
)
feed_dict[model.epsilon] = mini_batch["random_normal_epsilon"].reshape(
[-1, model.act_size[0]]
)
feed_dict[model.output_pre] = mini_batch["actions_pre"]
feed_dict[model.epsilon] = mini_batch["random_normal_epsilon"]
feed_dict[model.action_holder] = mini_batch["actions"].reshape(
[-1, len(model.act_size)]
)
feed_dict[model.action_holder] = mini_batch["actions"]
feed_dict[model.prev_action] = mini_batch["prev_action"].reshape(
[-1, len(model.act_size)]
)
feed_dict[model.action_masks] = mini_batch["action_mask"].reshape(
[-1, sum(self.brain.vector_action_space_size)]
)
feed_dict[model.prev_action] = mini_batch["prev_action"]
feed_dict[model.action_masks] = mini_batch["action_mask"]
feed_dict[model.vector_in] = mini_batch["vector_obs"].reshape(
[-1, self.vec_obs_size]
)
if model.vis_obs_size > 0:
for i, _ in enumerate(model.visual_in):
_obs = mini_batch["visual_obs%d" % i]
if self.sequence_length > 1 and self.use_recurrent:
(_batch, _seq, _w, _h, _c) = _obs.shape
feed_dict[model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
else:
feed_dict[model.visual_in[i]] = _obs
feed_dict[model.vector_in] = mini_batch["vector_obs"]
if self.model.vis_obs_size > 0:
for i, _ in enumerate(self.model.visual_in):
feed_dict[model.visual_in[i]] = mini_batch["visual_obs%d" % i]
mem_in = mini_batch["memory"][:, 0, :]
mem_in = [
mini_batch["memory"][i]
for i in range(0, len(mini_batch["memory"]), self.sequence_length)
]
feed_dict[model.memory_in] = mem_in
return feed_dict

brain_info.memories = self.make_empty_memory(len(brain_info.agents))
feed_dict[self.model.memory_in] = [brain_info.memories[idx]]
if not self.use_continuous_act and self.use_recurrent:
feed_dict[self.model.prev_action] = brain_info.previous_vector_actions[
idx
].reshape([-1, len(self.model.act_size)])
feed_dict[self.model.prev_action] = [
brain_info.previous_vector_actions[idx]
]
value_estimates = self.sess.run(self.model.value_heads, feed_dict)
value_estimates = {k: float(v) for k, v in value_estimates.items()}

15
ml-agents/mlagents/trainers/ppo/trainer.py


:return: A boolean corresponding to whether or not update_model() can be run
"""
size_of_buffer = len(self.training_buffer.update_buffer["actions"])
return size_of_buffer > max(
int(self.trainer_parameters["buffer_size"] / self.policy.sequence_length), 1
)
return size_of_buffer > self.trainer_parameters["buffer_size"]
def update_policy(self):
"""

mean_return=float(np.mean(self.cumulative_returns_since_policy_update)),
)
self.cumulative_returns_since_policy_update = []
batch_size = self.trainer_parameters["batch_size"]
n_sequences = max(
int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
)

)
num_epoch = self.trainer_parameters["num_epoch"]
for _ in range(num_epoch):
self.training_buffer.update_buffer.shuffle()
self.training_buffer.update_buffer.shuffle(
sequence_length=self.policy.sequence_length
)
len(self.training_buffer.update_buffer["actions"]) // n_sequences
0, len(self.training_buffer.update_buffer["actions"]), batch_size
start = l * n_sequences
end = (l + 1) * n_sequences
buffer.make_mini_batch(start, end), n_sequences
buffer.make_mini_batch(l, l + batch_size), n_sequences
)
value_total.append(run_out["value_loss"])
policy_total.append(np.abs(run_out["policy_loss"]))

3
ml-agents/mlagents/trainers/tests/mock_brain.py


num_vis_observations=0,
num_vector_acts=2,
discrete=False,
num_discrete_branches=1,
):
"""
Creates a mock BrainInfo with observations. Imitates constant

)
if discrete:
mock_braininfo.return_value.previous_vector_actions = np.array(
num_agents * [1 * [0.5]]
num_agents * [num_discrete_branches * [0.5]]
)
mock_braininfo.return_value.action_masks = np.array(
num_agents * [num_vector_acts * [1.0]]

28
ml-agents/mlagents/trainers/tests/test_buffer.py


a = b[1]["vector_observation"].get_batch(
batch_size=2, training_length=1, sequential=True
)
assert_array(a, np.array([[171, 172, 173], [181, 182, 183]]))
assert_array(np.array(a), np.array([[171, 172, 173], [181, 182, 183]]))
a,
np.array(a),
[[231, 232, 233], [241, 242, 243], [251, 252, 253]],
[[261, 262, 263], [271, 272, 273], [281, 282, 283]],
[231, 232, 233],
[241, 242, 243],
[251, 252, 253],
[261, 262, 263],
[271, 272, 273],
[281, 282, 283],
]
),
)

assert_array(
a,
np.array(a),
[[251, 252, 253], [261, 262, 263], [271, 272, 273]],
[[261, 262, 263], [271, 272, 273], [281, 282, 283]],
[251, 252, 253],
[261, 262, 263],
[271, 272, 273],
[261, 262, 263],
[271, 272, 273],
[281, 282, 283],
]
),
)

b.append_update_buffer(2, batch_size=None, training_length=2)
assert len(b.update_buffer["action"]) == 10
assert np.array(b.update_buffer["action"]).shape == (10, 2, 2)
assert len(b.update_buffer["action"]) == 20
assert np.array(b.update_buffer["action"]).shape == (20, 2)
assert c["action"].shape == (1, 2, 2)
assert np.array(c["action"]).shape == (1, 2)

15
ml-agents/mlagents/trainers/tests/test_reward_signals.py


VECTOR_ACTION_SPACE = [2]
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [2]
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 20
NUM_AGENTS = 12

DISCRETE_ACTION_SPACE if use_discrete else VECTOR_ACTION_SPACE
),
discrete=use_discrete,
num_discrete_branches=len(DISCRETE_ACTION_SPACE),
)
else:
mock_brain = mb.create_mock_brainparams(

DISCRETE_ACTION_SPACE if use_discrete else VECTOR_ACTION_SPACE
),
discrete=use_discrete,
num_discrete_branches=len(DISCRETE_ACTION_SPACE),
)
mb.setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
env = mock_env()

@mock.patch("mlagents.envs.UnityEnvironment")
def test_gail_dc(mock_env, dummy_config, gail_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, gail_dummy_config, False, True, False
)
reward_signal_eval(env, policy, "gail")
reward_signal_update(env, policy, "gail")
@mock.patch("mlagents.envs.UnityEnvironment")
def test_gail_visual(mock_env, dummy_config, gail_dummy_config):
def test_gail_dc_visual(mock_env, dummy_config, gail_dummy_config):
gail_dummy_config["gail"]["demo_path"] = (
os.path.dirname(os.path.abspath(__file__)) + "/testdcvis.demo"
)

正在加载...
取消
保存