using System;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Defines the structure of an action space to be used by the actuator system.
/// Instances are created through <see cref="MakeContinuous"/> (a continuous action space)
/// or <see cref="MakeDiscrete"/> (a branched discrete action space).
/// </summary>
public readonly struct ActionSpec
{
    /// <summary>
    /// An array of branch sizes for the action space.
    /// <para>
    /// For a discrete action space, the number of branches is the length of the array and
    /// each index contains the size of that branch. The cumulative sum of all branch sizes
    /// can be retrieved via <see cref="SumOfDiscreteBranchSizes"/>.
    /// </para>
    /// <para>
    /// For a continuous action space this is null.
    /// </para>
    /// </summary>
    /// <remarks>
    /// The array passed to <see cref="MakeDiscrete"/> is stored by reference, not copied.
    /// <see cref="SumOfDiscreteBranchSizes"/> is computed once at construction, so mutating
    /// the array afterwards leaves that sum out of date.
    /// </remarks>
    public readonly int[] BranchSizes;

    /// <summary>
    /// The number of actions for a continuous action space (0 for a discrete one).
    /// </summary>
    public int NumContinuousActions { get; }

    /// <summary>
    /// The number of branches for a discrete action space, i.e. the length of
    /// <see cref="BranchSizes"/> (0 for a continuous one).
    /// </summary>
    public int NumDiscreteActions { get; }

    /// <summary>
    /// The total number of discrete actions, computed as the sum of all discrete
    /// branch sizes (0 when <see cref="BranchSizes"/> is null).
    /// </summary>
    public int SumOfDiscreteBranchSizes { get; }

    /// <summary>
    /// Creates a continuous <see cref="ActionSpec"/> with the given number of actions.
    /// </summary>
    /// <param name="numActions">The number of continuous actions available. Must not be negative.</param>
    /// <returns>A continuous ActionSpec initialized with the number of actions available.</returns>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="numActions"/> is negative.</exception>
    public static ActionSpec MakeContinuous(int numActions)
    {
        if (numActions < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(numActions), numActions,
                "The number of continuous actions must not be negative.");
        }

        return new ActionSpec(numActions, 0);
    }

    /// <summary>
    /// Creates a discrete <see cref="ActionSpec"/> from the array of branch sizes that
    /// represents the action space.
    /// </summary>
    /// <param name="branchSizes">
    /// The array of branch sizes for the discrete action space. Each index contains the
    /// number of actions available for that branch. The array is stored by reference.
    /// </param>
    /// <returns>A discrete ActionSpec initialized with the array of branch sizes.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="branchSizes"/> is null.</exception>
    public static ActionSpec MakeDiscrete(int[] branchSizes)
    {
        // Fail with a descriptive argument error rather than a NullReferenceException
        // from dereferencing Length below.
        if (branchSizes == null)
        {
            throw new ArgumentNullException(nameof(branchSizes));
        }

        return new ActionSpec(0, branchSizes.Length, branchSizes);
    }

    /// <summary>
    /// Initializes all members. Callers go through the factory methods above.
    /// </summary>
    /// <param name="numContinuousActions">The number of continuous actions.</param>
    /// <param name="numDiscreteActions">The number of discrete branches.</param>
    /// <param name="branchSizes">The discrete branch sizes, or null for a continuous space.</param>
    private ActionSpec(int numContinuousActions, int numDiscreteActions, int[] branchSizes = null)
    {
        NumContinuousActions = numContinuousActions;
        NumDiscreteActions = numDiscreteActions;
        BranchSizes = branchSizes;
        // A null array (continuous space) contributes no discrete actions.
        SumOfDiscreteBranchSizes = branchSizes?.Sum() ?? 0;
    }
}
}