浏览代码

Fix some issues with pdf

/develop/add-fire
Arthur Juliani 4 年前
当前提交
e14eb72b
共有 2 个文件被更改，包括 15 次插入和 5 次删除
  1. 7
      ml-agents/mlagents/trainers/distributions_torch.py
  2. 13
      ml-agents/mlagents/trainers/models_torch.py

7
ml-agents/mlagents/trainers/distributions_torch.py


def sample(self):
    """Draw a sample as ``mean + noise * std``.

    The noise comes from ``torch.randn_like(self.mean)``, so it is
    standard-normal and has the same shape as ``self.mean``.

    :return: Tensor holding the sampled values.
    """
    noise = torch.randn_like(self.mean)
    return self.mean + noise * self.std
def pdf(self, value):
def log_prob(self, value):
var = self.std ** 2
log_scale = self.std.log()
return (

)
def log_prob(self, value):
return torch.log(self.pdf(value))
def pdf(self, value):
    """Return the probability density of ``value``.

    The density is obtained by exponentiating ``self.log_prob(value)``
    rather than being computed directly.

    :param value: Point(s) at which to evaluate the density; passed
        unchanged to ``self.log_prob``.
    :return: ``exp(log_prob(value))``.
    """
    return torch.exp(self.log_prob(value))
def entropy(self):
    """Return the element-wise differential entropy of the Gaussian.

    For a normal distribution with standard deviation ``std`` the entropy
    is ``0.5 * log(2 * pi * e * std ** 2)``, i.e. ``log(std)`` plus a
    constant. Taking ``log(2 * pi * e * std)`` instead omits both the
    square on ``std`` and the factor of one half.

    :return: Tensor with one entropy value per element of ``self.std``.
    """
    return 0.5 * torch.log(2 * math.pi * math.e * self.std ** 2)

13
ml-agents/mlagents/trainers/models_torch.py


import torch
from torch import nn
from mlagents.trainers.distributions_torch import GaussianDistribution, CategoricalDistInstance
from mlagents.trainers.distributions_torch import (
GaussianDistribution,
MultiCategoricalDistribution,
)
from mlagents.trainers.exception import UnityTrainerException
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[

hidden = encoder(vis_input)
vis_embeds.append(hidden)
#embedding = vec_embeds[0]
if len(vec_embeds) > 0:
vec_embeds = torch.stack(vec_embeds, dim=-1).sum(dim=-1)
if len(vis_embeds) > 0:

super(ActorCritic, self).__init__()
self.act_type = ActionType.from_str(act_type)
self.act_size = act_size
self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
self.is_continuous_int = torch.nn.Parameter(torch.Tensor([1]))
self.act_size_vector = torch.nn.Parameter(torch.Tensor(act_size))
self.separate_critic = separate_critic
self.network_body = NetworkBody(
vector_sizes,

vec_inputs, vis_inputs, masks, memories, sequence_length
)
sampled_actions = self.sample_action(dists)
return sampled_actions, dists[0].pdf(sampled_actions)
return sampled_actions, dists[0].pdf(sampled_actions), self.version_number, self.memory_size, self.is_continuous_int, self.act_size_vector
class Critic(nn.Module):

正在加载...
取消
保存