浏览代码

deep copy transition list

/ai-hw-2021
Ruo-Ping Dong 4 年前
当前提交
3b267315
共有 2 个文件被更改,包括 37 次插入和 3 次删除
  1. 11
      com.unity.ml-agents/Runtime/Inference/TensorProxy.cs
  2. 29
      com.unity.ml-agents/Runtime/Training/ReplayBuffer.cs

11
com.unity.ml-agents/Runtime/Inference/TensorProxy.cs


}
}
}
public static TensorProxy DeepCopy(TensorProxy tensor)
{
return new TensorProxy
{
name = tensor.name,
valueType = tensor.valueType,
data = tensor.data.DeepCopy(),
shape = (long[]) tensor.shape.Clone()
};
}
}
}

29
com.unity.ml-agents/Runtime/Training/ReplayBuffer.cs


List<Transition> m_Buffer;
int m_CurrentIndex;
int m_MaxSize;
ITensorAllocator m_Allocator;
public ReplayBuffer(int maxSize)
public ReplayBuffer(int maxSize, ITensorAllocator allocator)
{
m_Buffer = new List<Transition>();
m_Buffer.Capacity = maxSize;

{
if (m_Buffer.Count < m_MaxSize)
{
m_Buffer.Add(new Transition() { state = state, action = info.storedActions, reward = info.reward, done = info.done, nextState = nextState });
m_Buffer.Add(new Transition() {
state = CopyTensorList(state),
action = info.storedActions,
reward = info.reward,
done = info.done,
nextState = CopyTensorList(nextState)
});
m_Buffer[m_CurrentIndex] = new Transition() { state = state, action = info.storedActions, reward = info.reward, done = info.done, nextState = nextState };
m_Buffer[m_CurrentIndex] = new Transition() {
state = CopyTensorList(state),
action = info.storedActions,
reward = info.reward,
done = info.done,
nextState = CopyTensorList(nextState)
};
}
m_CurrentIndex += 1;
m_CurrentIndex = m_CurrentIndex % m_MaxSize;

index.Add(random.Next(m_Buffer.Count));
}
return index.ToList();
}
IReadOnlyList<TensorProxy> CopyTensorList(IReadOnlyList<TensorProxy> inputList)
{
var newList = new List<TensorProxy>();
for (var i = 0; i < inputList.Count; i++)
{
newList.Add(TensorUtils.DeepCopy(inputList[i]));
}
return (IReadOnlyList<TensorProxy>) newList;
}
}
}
正在加载...
取消
保存