tf.gather_nd
一句话介绍
按照索引将输入tensor的某些维度拼凑成一个新的tenosr
API
1 | tf.gather_nd( |
indices是一个K维的整形tensor。
indices的最后一维至多和params的rank一样大,如果indices.shape==params.rank,那么对应的是elements,如果indices.shape $\lt$ params.rank,那么对应的是slices。输出的tensor shape是:
indices.shape[:-1] + params.shape[indices.shape[-1]:]
原文如下:
The last dimension of indices corresponds to elements (if indices.shape[-1] == params.rank) or slices (if indices.shape[-1] < params.rank) along dimension indices.shape[-1] of params. The output tensor has shape
indices.shape[:-1] + params.shape[indices.shape[-1]:]
如果indices是两维的,那么就相当于用第二维的indices去访问params,然后indices的第一维度相当于把第二维的tensor放入一个列表。
indices是高维(大于两维)的话,反正就是找最后一维的维度,然后到params中找对应的数。
1 | indices = [[[1]], [[0]]] |
代码示例1
1 | import tensorflow as tf |
代码示例2
1 | import tensorflow as tf |