# Fragment of a GAIL reward provider in the Unity ML-Agents style. The lines
# below belong to the provider's __init__; the enclosing class and the usual
# imports (typing.Dict, numpy as np, torch, AgentBuffer, demo_to_buffer,
# ModelUtils) are assumed to exist in the full file.

        # Load the expert demonstrations into a buffer we can sample from.
        _, self._demo_buffer = demo_to_buffer(
            settings.demo_path, 1, specs
        )  # The 1 is supposed to be the sequence length, but we do not have access to it here.
        params = list(self._discriminator_network.parameters())
        # Anneal the discriminator learning rate from its initial value down to
        # a small floor (1e-10) over settings.decay_steps updates.
        self.decay_learning_rate = ModelUtils.DecayedValue(
            settings.learning_rate_schedule,
            settings.learning_rate,
            1e-10,
            settings.decay_steps,
        )
        self.optimizer = torch.optim.Adam(params, lr=settings.learning_rate)
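
    # Sketch of how the decayed value behaves (numbers hypothetical; assumes
    # ScheduleType from mlagents.trainers.settings). A LINEAR schedule
    # interpolates from the initial rate down to the floor over decay_steps and
    # then stays there; CONSTANT always returns the initial rate.
    #
    #   lr = ModelUtils.DecayedValue(ScheduleType.LINEAR, 3e-4, 1e-10, 10000)
    #   lr.get_value(0)      # 3e-4
    #   lr.get_value(5000)   # ~1.5e-4
    #   lr.get_value(10000)  # 1e-10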

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        with torch.no_grad():
            estimates, _ = self._discriminator_network.compute_estimate(
                mini_batch, use_vail_noise=False
            )
            # Turn the discriminator estimate into a positive imitation reward.
            return ModelUtils.to_numpy(
                -torch.log(
                    1.0
                    - estimates.squeeze(dim=1)
                    * (1.0 - self._discriminator_network.EPSILON)
                )
            )
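
    # Reward intuition (illustrative values): for an estimate d in (0, 1) the
    # reward is roughly -log(1 - d), so
    #   d = 0.1 -> ~0.105,  d = 0.5 -> ~0.693,  d = 0.9 -> ~2.303
    # i.e. the more expert-like a transition looks, the larger the reward;
    # EPSILON keeps the log finite as d approaches 1.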

    def update(
        self, mini_batch: AgentBuffer, global_step: int
    ) -> Dict[str, np.ndarray]:
        decay_lr = self.decay_learning_rate.get_value(global_step)
        # Sample an expert batch of the same size as the policy mini batch.
        expert_batch = self._demo_buffer.sample_mini_batch(
            mini_batch.num_experiences, 1
        )
        loss, stats_dict = self._discriminator_network.compute_loss(
            mini_batch, expert_batch
        )
        # Apply the decayed rate, then take one discriminator gradient step.
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return stats_dict
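
    # Call-site sketch (hypothetical trainer loop): the trainer passes its
    # running step count so the schedule can anneal the learning rate.
    #
    #   stats = gail_provider.update(policy_batch, global_step=trainer_step)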