
Plurals

Branch: MLA-1734-demo-provider
Arthur Juliani, 4 years ago
Current commit: e3de0406
7 files changed, 13 insertions(+), 13 deletions(-)
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (2 changes)
  2. ml-agents/mlagents/trainers/tests/mock_brain.py (6 changes)
  3. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (2 changes)
  4. ml-agents/mlagents/trainers/tests/torch/test_ppo.py (2 changes)
  5. ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (2 changes)
  6. ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (2 changes)
  7. ml-agents/mlagents/trainers/torch/networks.py (10 changes)
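The commit is a mechanical rename: every parameter that receives a List[ObservationSpec] (one spec per sensor) goes from the singular observation_spec to the plural observation_specs. A minimal sketch of the updated calling pattern for the test helper touched below; the import locations are assumptions, and the concrete shapes are illustrative:

    from mlagents_envs.base_env import ActionSpec
    from mlagents.trainers.tests import mock_brain as mb
    # Assumed import path for this test utility; it is not shown in the diff.
    from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes

    # Before this commit the keyword was singular:
    #   mb.make_fake_trajectory(..., observation_spec=specs, ...)
    # After it, callers pass the plural keyword:
    trajectory = mb.make_fake_trajectory(
        length=15,  # number of experiences in the fake trajectory
        observation_specs=create_observation_specs_with_shapes([(1,)]),
        action_spec=ActionSpec.create_discrete((2,)),
        max_step_complete=True,
    )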

ml-agents/mlagents/trainers/policy/torch_policy.py (2 changes)


         else:
             ac_class = SharedActorCritic
         self.actor_critic = ac_class(
-            observation_spec=self.behavior_spec.observation_specs,
+            observation_specs=self.behavior_spec.observation_specs,
             network_settings=trainer_settings.network_settings,
             action_spec=behavior_spec.action_spec,
             stream_names=reward_signal_names,

ml-agents/mlagents/trainers/tests/mock_brain.py (6 changes)


 def make_fake_trajectory(
     length: int,
-    observation_spec: List[ObservationSpec],
+    observation_specs: List[ObservationSpec],
     action_spec: ActionSpec,
     max_step_complete: bool = False,
     memory_size: int = 10,

     action_size = action_spec.discrete_size + action_spec.continuous_size
     for _i in range(length - 1):
         obs = []
-        for obs_spec in observation_spec:
+        for obs_spec in observation_specs:
             obs.append(np.ones(obs_spec.shape, dtype=np.float32))
         reward = 1.0
         done = False

         )
         steps_list.append(experience)
     obs = []
-    for obs_spec in observation_spec:
+    for obs_spec in observation_specs:
         obs.append(np.ones(obs_spec.shape, dtype=np.float32))
     last_experience = AgentExperience(
         obs=obs,

ml-agents/mlagents/trainers/tests/test_rl_trainer.py (2 changes)


     checkpoint_interval = trainer.trainer_settings.checkpoint_interval
     trajectory = mb.make_fake_trajectory(
         length=time_horizon,
-        observation_spec=create_observation_specs_with_shapes([(1,)]),
+        observation_specs=create_observation_specs_with_shapes([(1,)]),
         max_step_complete=True,
         action_spec=ActionSpec.create_discrete((2,)),
     )

ml-agents/mlagents/trainers/tests/torch/test_ppo.py (2 changes)


     time_horizon = 15
     trajectory = make_fake_trajectory(
         length=time_horizon,
-        observation_spec=optimizer.policy.behavior_spec.observation_specs,
+        observation_specs=optimizer.policy.behavior_spec.observation_specs,
         action_spec=DISCRETE_ACTION_SPEC if discrete else CONTINUOUS_ACTION_SPEC,
         max_step_complete=True,
     )

ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (2 changes)


             memory=None,
         )
         self._state_encoder = NetworkBody(
-            specs.observation_spec, state_encoder_settings
+            specs.observation_specs, state_encoder_settings
         )
         self._action_flattener = ActionFlattener(self._action_spec)

ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (2 changes)


             vis_encode_type=EncoderType.SIMPLE,
             memory=None,
         )
-        self._encoder = NetworkBody(specs.observation_spec, state_encoder_settings)
+        self._encoder = NetworkBody(specs.observation_specs, state_encoder_settings)

     def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
         n_obs = len(self._encoder.processors)

ml-agents/mlagents/trainers/torch/networks.py (10 changes)


 class SharedActorCritic(SimpleActor, ActorCritic):
     def __init__(
         self,
-        observation_spec: List[ObservationSpec],
+        observation_specs: List[ObservationSpec],
         network_settings: NetworkSettings,
         action_spec: ActionSpec,
         stream_names: List[str],

         self.use_lstm = network_settings.memory is not None
         super().__init__(
-            observation_spec,
+            observation_specs,
             network_settings,
             action_spec,
             conditional_sigma,

 class SeparateActorCritic(SimpleActor, ActorCritic):
     def __init__(
         self,
-        observation_spec: List[ObservationSpec],
+        observation_specs: List[ObservationSpec],
         network_settings: NetworkSettings,
         action_spec: ActionSpec,
         stream_names: List[str],

         self.use_lstm = network_settings.memory is not None
         super().__init__(
-            observation_spec,
+            observation_specs,
             network_settings,
             action_spec,
             conditional_sigma,

-        self.critic = ValueNetwork(stream_names, observation_spec, network_settings)
+        self.critic = ValueNetwork(stream_names, observation_specs, network_settings)

     @property
     def memory_size(self) -> int:
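For code that builds these networks directly, the same rename applies to the positional observation argument of NetworkBody (and, per the hunks above, ValueNetwork and the actor-critic classes). A hypothetical construction sketch; the default NetworkSettings, the import paths, and the sensor shapes are all assumptions for illustration:

    from mlagents.trainers.settings import NetworkSettings
    from mlagents.trainers.torch.networks import NetworkBody
    # Assumed import path for this test utility; it is not shown in the diff.
    from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes

    # One ObservationSpec per sensor: an 8-dim vector plus an 84x84 RGB camera
    # (shapes chosen arbitrarily for the example).
    obs_specs = create_observation_specs_with_shapes([(8,), (84, 84, 3)])
    body = NetworkBody(obs_specs, NetworkSettings())  # plural name matches the list it receives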
