PyTorch Agent Net library
简介
Ptan是一个简化RL的库,它主要目标是实现两个问题的平衡:
- 导入库函数,只需要一行命令,就像OpenAI的baselines一样
- 从头开始实现
我们既不想一行命令直接调包,也不想从头开始实现一切。
模块
- Agent:
- ActionSelector
- ExperienceSource
- ExperienceSourceBuffer
- others
Action Selector
简介
将network的输出转换成具体的action。常用的有
- Argmax:用于使用Q网络的方法,生成离散action。
- Policy-based:网络输出logits或者normalizaed distribution,从这个distribution中采样。
Action Selector被Agent使用,常用的有:
- ArgmaxActionSelector
- EpsilonGreedyActionSector
- ProbabilityActionSelector
基类
ActionSelector
1 | class ActionSelector: |
子类
ArgmaxActionSelector
1 | class ArgmaxActionSelector(ActionSelector): |
EpsilonGreedyActionSector
1 | class EpsilonGreedyActionSector(ActionSelector): |
ProbabilityActionSelector
1 | class ProbabilityActionSelector(ActionSelector): |
对比三个ActionSelector
两个GreedySelector:Argmax和EpsilonGreedy,输入都需要是q值,输出是action。
而Probability需要的输入是概率,输出是动作。
其他
EpsilonTraker
用来记录epsilon的变化。
Agent
将observation转换为actions,常见的三种方法如下:
- Q-function:预测当前observation下所有可能采取的action的$Q$值,选择$\arg \max Q(s)$作为action。
- Policy-based:预测$\pi(s)$的概率分布,从分布中采样。
- Continuous Contrl:预测连续控制参数$\mu(s)$,直接输出action。
基类
BaseAgent
1 | class BaseAgent: |
子类
DQNAgent
1 | class DQNAgent(BaseAgent): |
PolicyAgent
输入的model产生离散动作的policy distribution,Policy distribution可以是logtis或者normalized distribution。
PolicyAgent调用probability action selector对这个distribution进行采样 。PolicyAgent其实就是将model和action selector组装在了一起。
1 | class PolicyAgent(BaseAgent): |
ActorCriticAgent
1 | class ActorCriticAgent |
其他
default_states_preprocessor
1 | def default_states_preprocessor(states): |
TargetNet
1 | class TargetNet |
Experience Source
Agent不断的和env进行交互产生一系列的trajectories,Experience可以将这些交互存储起来,重复利用。Experience的主要作用有:
- 支持batch,利用GPU的并行计算提高训练效率
- 可以对transitions或者trajectory进行预处理。比如n-step DQN。
- ???
常见的ExperienceSource有:
- ExperienceSource
- ExperienceSourceFirstLast
- ExperienceSourceRollouts
- ExperienceReplayBuffer: :DQN中几乎不会使用刚刚获得的experience samples,因为他们是高度相关的,让训练很不稳定。Buffer用来存放experience pieces,从buffer中采样进行训练,因为buffer容量有限,老样本会被从replay buffer中删掉
- PrioReplayBufferNaive: Complexity of sampling is O(n)
- PrioritizedReplayBuffer: O(log(n)) sampling complexity.
基类
ExperienceSource
1 | class ExperienceSource |
ExperienceReplayBuffer
1 | class ExperienceReplayBuffer |
BatchPreprocessor
1 | class BatchPreprocessor |
子类
ExperienceSourceFirstLast
1 | # Q(st, at) = rt+1 + \gamma r_t+2 + ... \gamma^t+n-1 r_t+n + Q(s t+n, s t+n) |
PrioritizedReplayBuffer
1 | class PrioritizedReplayBuffer(ExperienceReplayBuffer) |
QLearningPreprocessor
1 | class QLearningPreprocessor(BatchPreprocessor) |
其他
ExperienceSourceRollouts
1 | class ExperienceSourceRollouts: |
ExperienceSourceBuffer
1 | class ExperienceSourceBuffer |
ExperienceReplayNaive
1 | class ExperienceReplayNaive |
代码解析
ExperienceSource
1 | class ExperienceSource |
ExperienceSourceFirstLast
1 | class ExperienceSourceFirstLast(ExperienceSource): |
ExperienceReplayBuffer
1 | class ExperienceReplayBuffer |
参考文献
1.https://github.com/Shmuma/ptan/blob/master/docs/intro.ipynb