Function(class torch.autograd.Funtion)
用法
Function一般只定义一个操作,并且它无法保存参数,一般适用于激活函数,pooling等,它需要定义三个方法,init(),forward(),backward()(这个需要自己定义怎么求导)
Model保存了参数,适合定义一层,如线性层(Linear layer),卷积层(conv layer),也适合定义一个网络。
和Model的区别,model只需要定义__init()__,foward()方法,backward()不需要我们定义,它可以由自动求导机制计算。
Function定义只是一个函数,forward和backward都只与这个Function的输入和输出有关
functions
1 | import torch |