torch.distributions
这个库和gym.space库很相似,都是提供一些分布,然后从中采样。
常见的有ExponentialFamily,Bernoulli,Binomial,Categorical,Exponential,Gamma,Independent,Laplace,Multinomial,MultivariateNormal。这里不做过程陈述,可以看gym中。
Categorical
对应tensorflow中的tf.multinomial。
类原型:
1 | CLASS torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None) |
参数probs只能是$1$维或者$2$维,而且必须是非负,有限非零和的,然后将其归一化到和为$1$。
这个类和torch.multinormal是一样的,从${0,\cdots, K-1}$中按照probs的概率进行采样,$K$是probs.size(-1),即是size()矩阵的最后一列,$2$维时把第$1$维当成了batch。
举一个简单的例子,代码。
1 | import torch.distributions as diss |
输出结果如下:
tensor(2)
tensor(1)
tensor(1)
tensor(1)
tensor(1)
tensor([2, 2])
tensor([1, 2])
tensor([0, 1])
tensor([0, 2])
tensor([0, 0])
作为对比,gym.spaces.Discrete示例如下:
1 | from gym import spaces |
输出结果是:
3
0
1
0
4