浏览代码

layer norm and weight decay with fixed architecture

/layernorm
Andrew Cohen 4 年前
当前提交
bc77c990
共有 7 个文件被更改,包括 235 次插入259 次删除
  1. 42
      Project/Assets/Bullet/Prefabs/BulletArea.prefab
  2. 228
      Project/Assets/Bullet/Scenes/Bullet.unity
  3. 4
      Project/Assets/Bullet/Scripts/DodgeAgent.cs
  4. 10
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  5. 196
      ml-agents/mlagents/trainers/torch/attention.py
  6. 2
      ml-agents/mlagents/trainers/torch/layers.py
  7. 12
      ml-agents/mlagents/trainers/torch/networks.py

42
Project/Assets/Bullet/Prefabs/BulletArea.prefab


m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 0
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 0
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_Modification:
m_TransformParent: {fileID: 7562565324271374322}
m_Modifications:
- target: {fileID: 968348991679639438, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: ground
value:
objectReference: {fileID: 7564015931936126656}
- target: {fileID: 968348991679639438, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: area
value:
objectReference: {fileID: 7565593112241766594}
- target: {fileID: 2272550728573618168, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: m_LocalPosition.x

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 968348991679639438, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: ground
value:
objectReference: {fileID: 7564015931936126656}
- target: {fileID: 968348991679639438, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: area
value:
objectReference: {fileID: 7565593112241766594}
- target: {fileID: 5229809530250264056, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: m_Name

type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 2602957484348800521, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: m_BrainParameters.VectorActionSize.Array.size
value: 1
objectReference: {fileID: 0}
- target: {fileID: 2602957484348800521, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: m_BrainParameters.m_ActionSpec.m_NumContinuousActions
value: 2
objectReference: {fileID: 0}
- target: {fileID: 2602957484348800521, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: m_BrainParameters.VectorActionSpaceType
value: 1
objectReference: {fileID: 0}
- target: {fileID: 2602957484348800521, guid: 16266ae2c040142b1a996878d94ec3fb,
type: 3}
propertyPath: m_BrainParameters.VectorActionSize.Array.data[0]
value: 2
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 9204493211719777360, guid: 16266ae2c040142b1a996878d94ec3fb, type: 3}

228
Project/Assets/Bullet/Scenes/Bullet.unity


m_EnableBakedLightmaps: 1
m_EnableRealtimeLightmaps: 1
m_LightmapEditorSettings:
serializedVersion: 10
serializedVersion: 12
m_Resolution: 2
m_BakeResolution: 40
m_AtlasSize: 1024

m_CompAOExponentDirect: 0
m_ExtractAmbientOcclusion: 0
m_Padding: 2
m_LightmapParameters: {fileID: 0}
m_LightmapsBakeMode: 1

m_PVRDirectSampleCount: 32
m_PVRSampleCount: 500
m_PVRBounces: 2
m_PVREnvironmentSampleCount: 500
m_PVREnvironmentReferencePointCount: 2048
m_PVRFilteringMode: 2
m_PVRDenoiserTypeDirect: 0
m_PVRDenoiserTypeIndirect: 0
m_PVRDenoiserTypeAO: 0
m_PVRFilteringMode: 1
m_PVREnvironmentMIS: 0
m_PVRCulling: 1
m_PVRFilteringGaussRadiusDirect: 1
m_PVRFilteringGaussRadiusIndirect: 5

m_PVRFilteringAtrousPositionSigmaAO: 1
m_ShowResolutionOverlay: 1
m_ExportTrainingData: 0
m_TrainingDataDestination: TrainingData
m_LightingDataAsset: {fileID: 112000002, guid: 03723c7f910c3423aa1974f1b9ce8392,
type: 2}
m_UseShadowmask: 1

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (7)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.x

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (7)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 7960826675695843403, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}

m_GameObject: {fileID: 255077123}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 1077351063, guid: f70555f144d8491a825f0804e09c671c, type: 3}
m_Script: {fileID: 11500000, guid: 4f231c4fb786f3946a6b90b886c48677, type: 3}
m_Name:
m_EditorClassIdentifier:
m_HorizontalAxis: Horizontal

m_GameObject: {fileID: 255077123}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: -619905303, guid: f70555f144d8491a825f0804e09c671c, type: 3}
m_Script: {fileID: 11500000, guid: 76c392e42b5098c458856cdf6ecaaaa1, type: 3}
m_Name:
m_EditorClassIdentifier:
m_FirstSelected: {fileID: 0}

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (1)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.x

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (1)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 7960826675695843403, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (5)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.x

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (5)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 7960826675695843403, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (2)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.x

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (2)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 7960826675695843403, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (4)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.x

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (4)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 7960826675695843403, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}

m_ClearFlags: 2
m_BackGroundColor: {r: 0.46666667, g: 0.5647059, b: 0.60784316, a: 1}
m_projectionMatrixMode: 1
m_GateFitMode: 2
m_FOVAxisMode: 0
m_GateFitMode: 2
m_FocalLength: 50
m_NormalizedViewPortRect:
serializedVersion: 2

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (8)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.x

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (8)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 7960826675695843403, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (3)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.x

type: 3}
propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (3)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 7960826675695843403, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (6)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.x

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Name
value: BulletArea (6)
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:
- {fileID: 7960826675695843403, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}
m_SourcePrefab: {fileID: 100100000, guid: e8cf070fe398a4cf49e159e022b4c70f, type: 3}

m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
- target: {fileID: 1261259436004593620, guid: e8cf070fe398a4cf49e159e022b4c70f,
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
propertyPath: m_Constraints
value: 0
objectReference: {fileID: 0}
- target: {fileID: 1790823965615916200, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_BehaviorType
value: 0
objectReference: {fileID: 0}
- target: {fileID: 6206438415594346407, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.y
value: 0.5
objectReference: {fileID: 0}
- target: {fileID: 6206438415594346407, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.z
value: 0
propertyPath: m_Name
value: BulletArea
objectReference: {fileID: 0}
- target: {fileID: 7562565324271374322, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}

propertyPath: m_LocalEulerAnglesHint.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7565593112241766594, guid: e8cf070fe398a4cf49e159e022b4c70f,
- target: {fileID: 1261259436004593620, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_Constraints
value: 0
objectReference: {fileID: 0}
- target: {fileID: 1790823965615916200, guid: e8cf070fe398a4cf49e159e022b4c70f,
propertyPath: m_Name
value: BulletArea
propertyPath: m_BehaviorType
value: 0
objectReference: {fileID: 0}
- target: {fileID: 6206438415594346407, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.y
value: 0.5
objectReference: {fileID: 0}
- target: {fileID: 6206438415594346407, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_LocalPosition.z
value: 0
objectReference: {fileID: 0}
- target: {fileID: 7807211597913259905, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}

- target: {fileID: 8381723296076531545, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
- target: {fileID: 3581941315059504425, guid: e8cf070fe398a4cf49e159e022b4c70f,
type: 3}
propertyPath: DecisionPeriod
value: 1
objectReference: {fileID: 0}
m_RemovedComponents:

4
Project/Assets/Bullet/Scripts/DodgeAgent.cs


Vector3 dirToGo = new Vector3(1,0,0) * forwardForce + new Vector3(0,0,1)*lateralForce;
m_AgentRb.AddForce(dirToGo * m_BulletSettings.agentRunSpeed,
ForceMode.VelocityChange);
//Vector3 dirToCenter = new Vector3((transform.position.x - area.transform.position.x) / 10f, 0f, (transform.position.z - area.transform.position.z) / 10f);
//AddReward(.001f / (dirToCenter.magnitude + .0000001f));
}
public override void Heuristic(in ActionBuffers actionsOut)

10
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


)
self.optimizer = torch.optim.Adam(
params, lr=self.trainer_settings.hyperparameters.learning_rate
params, lr=self.trainer_settings.hyperparameters.learning_rate, weight_decay=1E-6
)
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",

update_stats["Policy/" + name + '_mean'] = torch.mean(params).item()
update_stats["Policy/" + name + '_std'] = torch.std(params).item()
update_stats["Policy/" + name + '_grad_mag'] = torch.norm(params.grad).item()
#update_stats["Policy/" + name + '_grad_mean'] = torch.mean(params.grad).item()
#update_stats["Policy/" + name + '_grad_std'] = torch.std(params.grad).item()
update_stats["Policy/" + name + '_grad_mean'] = torch.mean(params.grad).item()
update_stats["Policy/" + name + '_grad_std'] = torch.std(params.grad).item()
#update_stats["Policy/" + name + '_grad_mean'] = torch.mean(params.grad).item()
#update_stats["Policy/" + name + '_grad_std'] = torch.std(params.grad).item()
update_stats["Policy/" + name + '_grad_mean'] = torch.mean(params.grad).item()
update_stats["Policy/" + name + '_grad_std'] = torch.std(params.grad).item()
for reward_provider in self.reward_signals.values():

196
ml-agents/mlagents/trainers/torch/attention.py


def __init__(
self,
query_size: int,
key_size: int,
value_size: int,
output_size: int,
self.output_size = output_size
#self.fc_q = linear_layer(
# query_size,
# self.n_heads * self.embedding_size,
# kernel_init=Initialization.KaimingHeNormal,
# kernel_gain=1.0,
# )
#self.fc_k = linear_layer(
# key_size,
# self.n_heads * self.embedding_size,
# kernel_init=Initialization.KaimingHeNormal,
# kernel_gain=1.0,
# )
#self.fc_v = linear_layer(
# value_size,
# self.n_heads * self.embedding_size,
# kernel_init=Initialization.KaimingHeNormal,
# kernel_gain=1.0,
# )
self.fc_q = torch.nn.Linear(query_size, self.n_heads * self.embedding_size)
self.fc_k = torch.nn.Linear(key_size, self.n_heads * self.embedding_size)
self.fc_v = torch.nn.Linear(value_size, self.n_heads * self.embedding_size)
# self.fc_q = torch.nn.linear(query_size, self.n_heads * self.embedding_size)
# self.fc_k = torch.nn.linear(key_size, self.n_heads * self.embedding_size)
# self.fc_v = torch.nn.linear(value_size, self.n_heads * self.embedding_size)
#self.fc_q = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
#self.fc_k = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
#self.fc_v = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
#self.fc_qk = torch.nn.Linear(self.embedding_size, 2 * self.embedding_size)
self.fc_q = torch.nn.Linear(self.embedding_size, self.embedding_size)
self.fc_k = torch.nn.Linear(self.embedding_size, self.embedding_size)
self.fc_v = torch.nn.Linear(self.embedding_size, self.embedding_size)
# self.fc_q = LinearEncoder(query_size, 2, self.n_heads * self.embedding_size)
# self.fc_k = LinearEncoder(key_size,2, self.n_heads * self.embedding_size)
# self.fc_v = LinearEncoder(value_size,2, self.n_heads * self.embedding_size)
self.n_heads * self.embedding_size, self.output_size
self.embedding_size, self.embedding_size
for layer in [self.fc_q, self.fc_k, self.fc_v, self.fc_out]:
#for layer in [self.fc_qk, self.fc_v, self.fc_out]:
torch.torch.nn.init.normal_(layer.weight, std = (0.125 / embedding_size ) ** 0.5)
self.embedding_norm = torch.nn.LayerNorm(embedding_size)
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
inp: torch.Tensor,
key_mask: Optional[torch.Tensor] = None,
number_of_keys: int = -1,
number_of_queries: int = -1,

n_q = number_of_queries if number_of_queries != -1 else query.size(1)
n_k = number_of_keys if number_of_keys != -1 else key.size(1)
query = self.fc_q(query) # (b, n_q, h*d)
key = self.fc_k(key) # (b, n_k, h*d)
value = self.fc_v(value) # (b, n_k, h*d)
inp = self.embedding_norm(inp)
query = self.fc_q(inp)
query = query.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads)
key = self.fc_k(inp)
key = key.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads)
#query = self.fc_q(inp)
#qk = self.fc_qk(inp)
#qk = qk.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads, 2)
#query, key = torch.split(qk, 1, dim=-1)
#query = torch.squeeze(query, dim=-1)
#key = torch.squeeze(key, dim=-1)
query = query.reshape(b, n_q, self.n_heads, self.embedding_size)
key = key.reshape(b, n_k, self.n_heads, self.embedding_size)
value = value.reshape(b, n_k, self.n_heads, self.embedding_size)
value = self.fc_v(inp) # (b, n_k, h*d)
value = value.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
#query = query.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads)
#key = key.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
#value = value.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
#query = self.fc_q(query) # (b, n_q, h*d)
#key = self.fc_k(key) # (b, n_k, h*d)
#value = self.fc_v(value) # (b, n_k, h*d)
query = query.permute([0, 2, 1, 3]) # (b, h, n_q, emb)
# The next few lines are equivalent to : key.permute([0, 2, 3, 1])

value_attention = value_attention.permute([0, 2, 1, 3]) # (b, n_q, h, emb)
value_attention = value_attention.reshape(
b, n_q, self.n_heads * self.embedding_size
b, n_q, self.embedding_size
out = self.fc_out(value_attention) # (b, n_q, emb)
#if out.requires_grad:
# out.register_hook(lambda x: print(x))

x_self_size: int,
entities_sizes: List[int],
embedding_size: int,
output_size: Optional[int] = None,
):
super().__init__()
self.self_size = x_self_size

[
# LinearEncoder(self.self_size + ent_size, 2, embedding_size)
# from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
# linear_layer(self.self_size + ent_size, embedding_size, Initialization.Normal, kernel_gain=1 / (self.self_size + ent_size) ** 0.5)
LinearEncoder(self.self_size + ent_size, 1, embedding_size, layer_norm=True)
# linear_layer(self.self_size + ent_size, embedding_size, Initialization.Normal, kernel_gain=(.125 / (self.self_size + ent_size)) ** 0.5)
linear_layer(ent_size, embedding_size, Initialization.Normal, kernel_gain=(.125 / (self.self_size + ent_size)) ** 0.5)
# LinearEncoder(self.self_size + ent_size, 1, embedding_size, layer_norm=False)
query_size=embedding_size,
key_size=embedding_size,
value_size=embedding_size,
output_size=embedding_size,
#self.residual_layer = LinearEncoder(embedding_size, 1, embedding_size)
self.res_norm = torch.nn.LayerNorm(embedding_size, elementwise_affine=True)
if output_size is None:
output_size = embedding_size
#self.residual_layer = torch.nn.Linear(
# embedding_size, embedding_size
#)
#torch.torch.nn.init.normal_(self.residual_layer.weight, std = 0.125 * embedding_size ** -0.5)
self.res_norm = torch.nn.LayerNorm(embedding_size)
def forward(
self,

for ent in entities:
self.entities_num_max_elements.append(ent.shape[1])
# Concatenate all observations with self
self_and_ent: List[torch.Tensor] = []
for num_entities, ent in zip(self.entities_num_max_elements, entities):
expanded_self = x_self.reshape(-1, 1, self.self_size)
# .repeat(
# 1, num_entities, 1
# )
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
#self_and_ent: List[torch.Tensor] = []
#for num_entities, ent in zip(self.entities_num_max_elements, entities):
# expanded_self = x_self.reshape(-1, 1, self.self_size)
# # .repeat(
# # 1, num_entities, 1
# # )
# expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
# self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, entities)],
output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
# Residual
#output = self.residual_layer(output) + qkv
#output += qkv
output = self.res_norm(output + qkv)
#output = self.res_norm(output)
# Average Pooling
output, _ = self.attention(qkv, mask, max_num_ent, max_num_ent)
#residual 1
output += qkv
output = self.res_norm(output)
#residual 2
#output = self.residual_layer(output) + output #qkv
# average pooling
#output = torch.cat([output, x_self], dim=1)
output = torch.cat([output, x_self], dim=1)
return output
@staticmethod

for ent in observations
]
return key_masks
class SmallestAttention(torch.nn.Module):
    """Pool a set of entities into one vector with learned softmax weights.

    Every entity is concatenated with the agent's own observation, encoded,
    scored by a single-output "importance" layer and then summed using the
    softmax of those scores as weights. The agent's own observation is
    appended to the pooled vector before it is returned.
    """

    def __init__(
        self,
        x_self_size: int,
        entities_sizes: List[int],
        embedding_size: int,
        output_size: Optional[int] = None,
    ):
        """Build one encoder per entity type plus the importance scorer.

        :param x_self_size: Length of the agent's own observation vector.
        :param entities_sizes: Feature length of each entity type.
        :param embedding_size: Output width of every entity encoder.
        :param output_size: Accepted for signature compatibility; it is not
            read anywhere in this class.
        """
        super().__init__()
        self.self_size = x_self_size
        self.entities_sizes = entities_sizes
        # Filled on the first forward() call from the entity tensors' dim 1.
        self.entities_num_max_elements: Optional[List[int]] = None
        # Each encoder consumes [self observation ; entity features].
        self.ent_encoders = torch.nn.ModuleList(
            [
                LinearEncoder(self.self_size + ent_size, 2, embedding_size)
                for ent_size in self.entities_sizes
            ]
        )
        # Maps each encoded entity to a single unnormalized importance score.
        self.importance_layer = LinearEncoder(embedding_size, 1, 1)

    def forward(
        self,
        x_self: torch.Tensor,
        entities: List[torch.Tensor],
        key_masks: List[torch.Tensor],
    ) -> torch.Tensor:
        """Return the importance-weighted sum of encoded entities, with x_self appended.

        :param x_self: Agent observation; reshaped to (-1, 1, self_size).
        :param entities: One rank-3 tensor per entity type; dim 1 indexes the
            entities and dim 2 their features.
        :param key_masks: One rank-2 tensor per entity type. Non-zero entries
            push the matching score down by 1e6 per unit, so those entities
            get (numerically) zero weight.
        :return: Tensor of shape (batch, embedding_size + self_size).
        """
        # Record the number of entities per type once; later calls reuse the
        # values seen on the first call.
        if self.entities_num_max_elements is None:
            self.entities_num_max_elements = [ent.shape[1] for ent in entities]

        # Pair every entity with a copy of the agent's own observation.
        self_and_ent: List[torch.Tensor] = []
        for num_entities, ent in zip(self.entities_num_max_elements, entities):
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            # Tile along the entity axis by concatenation.
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))

        # Encode each entity type and stack all entities along dim 1.
        qkv = torch.cat(
            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
            dim=1,
        )
        mask = torch.cat(key_masks, dim=1)

        # Score the entities, suppress the masked ones, normalize over entities.
        importance = self.importance_layer(qkv) + mask.unsqueeze(2) * -1e6
        importance = torch.softmax(importance, dim=1)

        # Weighted sum over the entity axis, then re-attach the self observation.
        weighted_qkv = qkv * importance
        output = torch.sum(weighted_qkv, dim=1)
        output = torch.cat([output, x_self], dim=1)
        return output

2
ml-agents/mlagents/trainers/torch/layers.py


kernel_gain=1.0,
)
)
self.layers.append(Swish())
self.layers.append(Swish())
self.seq_layers = torch.nn.Sequential(*self.layers)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:

12
ml-agents/mlagents/trainers/torch/networks.py


from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.attention import SimpleTransformer, SmallestAttention
from mlagents.trainers.torch.attention import SimpleTransformer
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

self.normalize = network_settings.normalize
self.use_lstm = network_settings.memory is not None
self.h_size = network_settings.hidden_units
self.n_embd = 128
self.m_size = (
network_settings.memory.memory_size
if network_settings.memory is not None

self.transformer = SimpleTransformer(
x_self_len,
entities_sizes,
self.h_size,
self.h_size,
self.n_embd,
total_enc_size = self.h_size #+ sum(self.embedding_sizes)
total_enc_size = self.n_embd + sum(self.embedding_sizes)
n_layers = 1
n_layers = 2
if self.use_fc:
self.transformer = None
total_enc_size = 80 + sum(self.embedding_sizes)

tens.retain_grad()
total_enc_size += encoded_act_size
self.linear_encoder = LinearEncoder(
total_enc_size, n_layers, self.h_size
total_enc_size, n_layers, self.h_size, layer_norm=False
)
for _,tens in list(self.linear_encoder.named_parameters()):
tens.retain_grad()

正在加载...
取消
保存