
remove self.action_spec from policy/bc

/develop/action-spec-gym
Andrew Cohen, 4 years ago
Commit 6cf54bf2
3 files changed: 16 insertions(+), 12 deletions(-)
  1. ml-agents/mlagents/trainers/policy/policy.py (16 changes)
  2. ml-agents/mlagents/trainers/policy/tf_policy.py (9 changes)
  3. ml-agents/mlagents/trainers/torch/components/bc/module.py (3 changes)

ml-agents/mlagents/trainers/policy/policy.py (16 changes)


      condition_sigma_on_obs: bool = True,
  ):
      self.behavior_spec = behavior_spec
-     self.action_spec = behavior_spec.action_spec
-     if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
+     if (
+         self.behavior_spec.action_spec.continuous_size > 0
+         and self.behavior_spec.action_spec.discrete_size > 0
+     ):
-         list(self.action_spec.discrete_branches)
-         if self.action_spec.is_discrete()
-         else [self.action_spec.size]
+         list(self.behavior_spec.action_spec.discrete_branches)
+         if self.behavior_spec.action_spec.is_discrete()
+         else [self.behavior_spec.action_spec.size]
      )
      self.vec_obs_size = sum(
          shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1
      )
-     self.use_continuous_act = self.action_spec.is_continuous()
-     self.num_branches = self.action_spec.size
+     self.use_continuous_act = self.behavior_spec.action_spec.is_continuous()
+     self.num_branches = self.behavior_spec.action_spec.size
      self.previous_action_dict: Dict[str, np.array] = {}
      self.memory_dict: Dict[str, np.ndarray] = {}
      self.normalize = trainer_settings.network_settings.normalize
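
The policy.py hunk drops the cached self.action_spec alias and reads the spec through self.behavior_spec everywhere. A minimal sketch of the pattern, with simplified stand-ins for the mlagents_envs ActionSpec/BehaviorSpec types (not the real classes):

from dataclasses import dataclass
from typing import Tuple


@dataclass
class ActionSpec:
    # Simplified stand-in for the mlagents_envs ActionSpec.
    continuous_size: int
    discrete_branches: Tuple[int, ...]

    @property
    def discrete_size(self) -> int:
        return len(self.discrete_branches)

    def is_continuous(self) -> bool:
        return self.continuous_size > 0 and self.discrete_size == 0

    def is_discrete(self) -> bool:
        return self.discrete_size > 0 and self.continuous_size == 0


@dataclass
class BehaviorSpec:
    action_spec: ActionSpec


class Policy:
    def __init__(self, behavior_spec: BehaviorSpec):
        # behavior_spec is the single stored reference; action_spec is always
        # reached through it, so no second alias can drift out of sync.
        self.behavior_spec = behavior_spec
        self.use_continuous_act = self.behavior_spec.action_spec.is_continuous()


policy = Policy(BehaviorSpec(ActionSpec(continuous_size=0, discrete_branches=(3, 2))))
assert not policy.use_continuous_act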

ml-agents/mlagents/trainers/policy/tf_policy.py (9 changes)


          feed_dict[self.vector_in] = vec_vis_obs.vector_observations
      if not self.use_continuous_act:
          mask = np.ones(
-             (len(batched_step_result), sum(self.action_spec.discrete_branches)),
+             (
+                 len(batched_step_result),
+                 sum(self.behavior_spec.action_spec.discrete_branches),
+             ),
              dtype=np.float32,
          )
          if batched_step_result.action_mask is not None:

      self.mask = tf.cast(self.mask_input, tf.int32)
      tf.Variable(
-         int(self.action_spec.is_continuous()),
+         int(self.behavior_spec.action_spec.is_continuous()),
          name="is_continuous_control",
          trainable=False,
          dtype=tf.int32,

      tf.Variable(
          self.m_size, name="memory_size", trainable=False, dtype=tf.int32
      )
-     if self.action_spec.is_continuous():
+     if self.behavior_spec.action_spec.is_continuous():
          tf.Variable(
              self.act_size[0],
              name="action_output_shape",
ml-agents/mlagents/trainers/torch/components/bc/module.py (3 changes)


          for the pretrainer.
      """
      self.policy = policy
-     self.action_spec = policy.action_spec
      self._anneal_steps = settings.steps
      self.current_lr = policy_learning_rate * settings.strength

      np.ones(
          (
              self.n_sequences * self.policy.sequence_length,
-             sum(self.action_spec.discrete_branches),
+             sum(self.policy.behavior_spec.action_spec.discrete_branches),
          ),
          dtype=np.float32,
      )
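
The bc/module.py change follows the same rule: the module no longer caches policy.action_spec and instead dereferences through self.policy whenever it needs the branch sizes. A hedged sketch of that access pattern (the class shape and the sequence_length attribute mirror the diff; everything else, including the method name, is illustrative):

import numpy as np
from types import SimpleNamespace


class BCModule:
    # Illustrative skeleton, not the real ml-agents BCModule.
    def __init__(self, policy, n_sequences: int):
        self.policy = policy  # the spec is always read through the policy
        self.n_sequences = n_sequences

    def default_action_masks(self) -> np.ndarray:
        # One all-ones mask row per timestep in the sampled sequences.
        spec = self.policy.behavior_spec.action_spec
        return np.ones(
            (
                self.n_sequences * self.policy.sequence_length,
                sum(spec.discrete_branches),
            ),
            dtype=np.float32,
        )


# Stub policy carrying just the attributes the sketch touches.
policy = SimpleNamespace(
    sequence_length=8,
    behavior_spec=SimpleNamespace(
        action_spec=SimpleNamespace(discrete_branches=(3, 2))
    ),
)
assert BCModule(policy, n_sequences=2).default_action_masks().shape == (16, 5)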
