代码
这个DQN的Replay Buffer实现只用到了numpy库,可以很容易的进行扩展。主要有五个函数。接下来分函数进行解析。
1 | import numpy as np |
init函数
ReplayBuffer的init的输入参数为一个config文件,包含了创建ReplayBuffer的参数,memory_size是Buffer大小,batch_size为训练和测试的batch大小,screens, actions, rewards, terminals分别存放的是每次采样得到的screen, action, reward和terminal(当前episode是否结束)。history_length是原文中提到的连续处理四张图片的四,而不仅仅是一张。state_format指的是’NHWC’还是’NCHW’,即depth通道在第$1$维还是第$3$维,states存放的是一个tensor,shape为$(batch_size, screen_height, screen_width, history_length)$,count记录当前Buffer的大小,current记录当前experience插入的地方。
add方法
该方法实现了向ReplayBuffer中添加experience。
__len__方法
放回Buffer当前的大小
clear方法
清空Buffer
sample方法
从buffer中进行采样,返回一个元组,(states, actions, rewards, next_states, terminals)
getState方法
给定一个index,寻找它的前history_length - 1 个screens。