
Enable GAIL to decay

/develop/decaygail
Ervin Teng, 3 years ago
Commit fd3f05b9
9 files changed, 39 insertions(+), 9 deletions(-)
1. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (4 changes)
2. ml-agents/mlagents/trainers/sac/optimizer_torch.py (11 changes)
3. ml-agents/mlagents/trainers/sac/trainer.py (2 changes)
4. ml-agents/mlagents/trainers/settings.py (2 changes)
5. ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py (5 changes)
6. ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (4 changes)
7. ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py (4 changes)
8. ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (12 changes)
9. ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (4 changes)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (4 changes)

         }
         for reward_provider in self.reward_signals.values():
-            update_stats.update(reward_provider.update(batch))
+            update_stats.update(
+                reward_provider.update(batch, self.policy.get_current_step())
+            )
         return update_stats

ml-agents/mlagents/trainers/sac/optimizer_torch.py (11 changes)

         return update_stats

     def update_reward_signals(
-        self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int
+        self,
+        reward_signal_minibatches: Mapping[str, AgentBuffer],
+        num_sequences: int,
+        global_step: int,

-            update_stats.update(self.reward_signals[name].update(update_buffer))
+            update_stats.update(
+                self.reward_signals[name].update(
+                    update_buffer, self.policy.get_current_step()
+                )
+            )
         return update_stats

     def get_modules(self):

ml-agents/mlagents/trainers/sac/trainer.py (2 changes)

                         sequence_length=self.policy.sequence_length,
                     )
             update_stats = self.optimizer.update_reward_signals(
-                reward_signal_minibatches, n_sequences
+                reward_signal_minibatches, n_sequences, self.step
             )
             for stat_name, value in update_stats.items():
                 batch_update_stats[stat_name].append(value)

ml-agents/mlagents/trainers/settings.py (2 changes)

 class GAILSettings(RewardSignalSettings):
     encoding_size: int = 64
     learning_rate: float = 3e-4
+    learning_rate_schedule: ScheduleType = ScheduleType.CONSTANT
+    decay_steps: int = 0
     use_actions: bool = False
     use_vail: bool = False
     demo_path: str = attr.ib(kw_only=True)
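
A minimal sketch of how the two new fields could be set in code, assuming the attrs-based GAILSettings shown above; the demo path, schedule choice, and step count below are illustrative values, not taken from this commit:

from mlagents.trainers.settings import GAILSettings, ScheduleType

gail_settings = GAILSettings(
    demo_path="Project/demos/ExpertPyramid.demo",  # hypothetical demo file
    learning_rate=3e-4,
    learning_rate_schedule=ScheduleType.LINEAR,  # decay instead of the CONSTANT default
    decay_steps=500000,  # steps over which the discriminator learning rate decays
)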

ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py (5 changes)

         )

     @abstractmethod
-    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
+    def update(
+        self, mini_batch: AgentBuffer, global_step: int
+    ) -> Dict[str, np.ndarray]:

+        :param global_step: The trainer's global step. Used to decay reward signals over time.
         :return: A dictionary from string to stats values
         """
         raise NotImplementedError(
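
To show what the widened interface asks of implementers, here is a minimal sketch of a custom provider against the new signature; ConstantRewardProvider and its constant reward are hypothetical, for illustration only:

from typing import Dict
import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)

class ConstantRewardProvider(BaseRewardProvider):
    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        # Hypothetical: reward every experience in the batch with 1.0.
        return np.ones(mini_batch.num_experiences, dtype=np.float32)

    def update(
        self, mini_batch: AgentBuffer, global_step: int
    ) -> Dict[str, np.ndarray]:
        # Providers with nothing to train can ignore global_step,
        # as the extrinsic provider does.
        return {}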

ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (4 changes)

         rewards = np.minimum(rewards, 1.0 / self.strength)
         return rewards * self._has_updated_once

-    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
+    def update(
+        self, mini_batch: AgentBuffer, global_step: int
+    ) -> Dict[str, np.ndarray]:
         self._has_updated_once = True
         forward_loss = self._network.compute_forward_loss(mini_batch)
         inverse_loss = self._network.compute_inverse_loss(mini_batch)

ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py (4 changes)

     def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
         return np.array(mini_batch[BufferKey.ENVIRONMENT_REWARDS], dtype=np.float32)

-    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
+    def update(
+        self, mini_batch: AgentBuffer, global_step: int
+    ) -> Dict[str, np.ndarray]:
         return {}

ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (12 changes)

             settings.demo_path, 1, specs
         )  # This is supposed to be the sequence length but we do not have access here
         params = list(self._discriminator_network.parameters())
+        self.decay_learning_rate = ModelUtils.DecayedValue(
+            settings.learning_rate_schedule,
+            settings.learning_rate,
+            1e-10,
+            settings.decay_steps,
+        )
         self.optimizer = torch.optim.Adam(params, lr=settings.learning_rate)

     def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:

             )
         )

-    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
+    def update(
+        self, mini_batch: AgentBuffer, global_step: int
+    ) -> Dict[str, np.ndarray]:
+        decay_lr = self.decay_learning_rate.get_value(global_step)
         expert_batch = self._demo_buffer.sample_mini_batch(
             mini_batch.num_experiences, 1
         )

+        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
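
With this change, the discriminator learning rate is looked up from the schedule at every update and pushed into the Adam optimizer. As a rough standalone illustration of what a linear schedule of this shape computes (a re-implementation of the idea for clarity, not ml-agents' ModelUtils.DecayedValue):

def linear_decay(
    initial_value: float, min_value: float, max_step: int, global_step: int
) -> float:
    # Assumed behavior: decay_steps == 0 means "do not decay".
    if max_step <= 0:
        return initial_value
    remaining = max(max_step - global_step, 0) / float(max_step)
    return max(min_value, initial_value * remaining)

# Illustrative numbers: learning_rate=3e-4, min=1e-10, decay_steps=500000.
print(linear_decay(3e-4, 1e-10, 500_000, 0))        # 0.0003
print(linear_decay(3e-4, 1e-10, 500_000, 250_000))  # 0.00015
print(linear_decay(3e-4, 1e-10, 500_000, 500_000))  # 1e-10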

ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (4 changes)

         rewards = torch.sum((prediction - target) ** 2, dim=1)
         return rewards.detach().cpu().numpy()

-    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
+    def update(
+        self, mini_batch: AgentBuffer, global_step: int
+    ) -> Dict[str, np.ndarray]:
         with torch.no_grad():
             target = self._random_network(mini_batch)
         prediction = self._training_network(mini_batch)
