
separate tensors for disc/cont

Branch: /develop/actionmodel-csharp
Andrew Cohen, 4 years ago
Current commit: 72cd0d39
2 files changed, 16 insertions(+), 17 deletions(-)
  1. ml-agents/mlagents/trainers/torch/distributions.py (15 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (18 changes)

ml-agents/mlagents/trainers/torch/distributions.py (15 changes)


 import abc
-from typing import List
+from typing import List, Tuple
 from mlagents.torch_utils import torch, nn
 import numpy as np
 import math

             self.discrete_distributions.append(
                 MultiCategoricalDistribution(self.encoding_size, discrete_act_size)
             )
         else:
             self.discrete_distribution = None

-    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
-        distributions: List[DistInstance] = []
+    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[List[DistInstance], List[DiscreteDistInstance]]:
+        continuous_distributions: List[DistInstance] = []
+        discrete_distributions: List[DiscreteDistInstance] = []
-            distributions += continuous_dist(inputs)
+            continuous_distributions += continuous_dist(inputs)
-            distributions += discrete_dist(inputs, masks)
-        return distributions
+            discrete_distributions += discrete_dist(inputs, masks)
+        return continuous_distributions, discrete_distributions
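For context, the core of this commit is that forward() now returns a tuple of two lists rather than one flat List[DistInstance], so continuous and discrete distributions can be consumed separately. Below is a minimal, runnable sketch of that contract; the class, head layout, and masking scheme are illustrative assumptions, not ml-agents code:

import torch
from torch import nn
from typing import List, Tuple

class ToyActionModel(nn.Module):
    # Toy stand-in (NOT ml-agents code) illustrating the split return contract.
    def __init__(self, encoding_size: int, continuous_size: int, branch_sizes: List[int]):
        super().__init__()
        self.mu = nn.Linear(encoding_size, continuous_size)  # continuous (Gaussian) head
        self.branches = nn.ModuleList(
            nn.Linear(encoding_size, b) for b in branch_sizes  # one head per discrete branch
        )

    def forward(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[List[torch.distributions.Distribution], List[torch.distributions.Distribution]]:
        # Continuous distributions are built from the encoding alone...
        continuous = [torch.distributions.Normal(self.mu(inputs), torch.ones(1))]
        # ...while discrete ones also consume the action masks, as in the diff above.
        discrete = []
        offset = 0
        for branch in self.branches:
            logits = branch(inputs)
            branch_mask = masks[:, offset : offset + logits.shape[1]]
            offset += logits.shape[1]
            # Masked-out actions get -inf logits so they can never be sampled.
            discrete.append(
                torch.distributions.Categorical(
                    logits=logits.masked_fill(branch_mask == 0, float("-inf"))
                )
            )
        return continuous, discrete

model = ToyActionModel(encoding_size=8, continuous_size=2, branch_sizes=[3, 2])
cont, disc = model(torch.randn(4, 8), torch.ones(4, 5))  # 5 = 3 + 2 masked logits

Returning the two families separately lets callers handle continuous and discrete log-probabilities independently, which a single mixed list made awkward.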

ml-agents/mlagents/trainers/torch/networks.py (18 changes)


         encoding, memories = self.network_body(
             vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
         )
-        dists = self.distribution(encoding, masks)
-        return dists, memories
+        continuous_dists, discrete_dists = self.distribution(encoding, masks)
+        return continuous_dists, discrete_dists, memories

     def forward(
         self,

"""
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
"""
# TODO: This is bad right now
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
action_out = torch.cat([dist.exported_model_output() for dist in dists], dim=1)
# TODO: How this is written depends on how the inference model is structured
continuous_dists, discrete_dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
action_out = torch.cat([dist.exported_model_output() for dist in continuous_dists + discrete_dists], dim=1)
return (
action_out,
self.version_number,
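The ONNX export path flattens the two lists back into one action tensor, continuous outputs first. A small sketch of that ordering, with dummy tensors standing in for exported_model_output() (the shapes here are made up):

import torch

batch = 4
# Dummy stand-ins for dist.exported_model_output(): one (batch, dims) tensor each.
continuous_out = [torch.zeros(batch, 2)]                       # one Gaussian head
discrete_out = [torch.zeros(batch, 3), torch.zeros(batch, 2)]  # two categorical branches

# As in the hunk above: list concatenation puts continuous before discrete,
# then torch.cat joins everything along the action dimension.
action_out = torch.cat(continuous_out + discrete_out, dim=1)
assert action_out.shape == (batch, 7)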

         encoding, memories = self.network_body(
             vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
         )
-        dists = self.distribution(encoding, masks)
+        continuous_dists, discrete_dists = self.distribution(encoding, masks)
-        return dists, value_outputs, memories
+        return continuous_dists, discrete_dists, value_outputs, memories

 class SeparateActorCritic(HybridSimpleActor, ActorCritic):
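Every consumer of this method now unpacks four values instead of three. A tiny runnable sketch of the widened return; the stub name and shapes below are assumptions, not ml-agents code:

from typing import List, Tuple
import torch

def get_dist_and_value_stub() -> Tuple[List, List, torch.Tensor, torch.Tensor]:
    # Stub with the new four-element return shape (names/shapes assumed).
    cont = [torch.distributions.Normal(torch.zeros(4, 2), torch.ones(4, 2))]
    disc = [torch.distributions.Categorical(logits=torch.zeros(4, 3))]
    return cont, disc, torch.zeros(4, 1), torch.zeros(1, 4, 128)

# Old callers unpacked (dists, value_outputs, memories); they now unpack four:
continuous_dists, discrete_dists, value_outputs, memories = get_dist_and_value_stub()
log_probs = [d.log_prob(d.sample()) for d in continuous_dists + discrete_dists]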

         else:
             critic_mem = None
             actor_mem = None
-        dists, actor_mem_outs = self.get_dists(
+        continuous_dists, discrete_dists, actor_mem_outs = self.get_dists(
             vec_inputs,
             vis_inputs,
             memories=actor_mem,

             mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
         else:
             mem_out = None
-        return dists, value_outputs, mem_out
+        return continuous_dists, discrete_dists, value_outputs, mem_out
################################################################################
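Finally, on the recurrent path, SeparateActorCritic keeps distinct actor and critic memories and, as in the hunk above, rejoins them along the last axis before returning. A self-contained sketch of that concatenation (the sizes are arbitrary):

import torch

batch, mem_half = 4, 64
actor_mem_outs = torch.zeros(1, batch, mem_half)   # actor's recurrent state
critic_mem_outs = torch.zeros(1, batch, mem_half)  # critic's recurrent state

# Same torch.cat as in the diff: the two halves ride side by side in one tensor.
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
assert mem_out.shape == (1, batch, 2 * mem_half)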
