Ich versuche, Scheiben eines Tensors in Bezug auf die letzte Dimension für die teilweise Verbindung zwischen Schichten zu erfassen. Da die Form des Ausgabe-Tensors [batch_size, h, w, depth]
ist, möchte ich die Slices basierend auf der letzten Dimension auswählen, z
# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
Allerdings scheint tf.gather(L, [0, 2,3,8])
nur für die erste Dimension zu funktionieren (richtig?) Kann mir jemand sagen, wie es geht?
Es gibt einen Tracking-Fehler zur Unterstützung dieses Anwendungsfalls hier: https://github.com/tensorflow/tensorflow/issues/206
Für jetzt können Sie:
transponiere deine Matrix so, dass die zu erfassende Dimension zuerst ist (transponieren ist teuer)
umformen Sie Ihren Tensor in 1d (Umformen ist billig), und verwandeln Sie Ihre Spaltenindizes in eine Liste der einzelnen Elementindizes bei linearer Indizierung
gather_nd
. Müssen Sie Ihre Spaltenindizes dennoch in eine Liste einzelner Elementindizes umwandeln.Ab TensorFlow 1.3 verfügt tf.gather
über einen axis
-Parameter, so dass die verschiedenen Problemumgehungen hier nicht mehr erforderlich sind.
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gatherhttps://github.com/tensorflow/tensorflow/issues/11223
Mit gather_nd können Sie dies nun folgendermaßen tun:
cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0)
result = tf.gather_nd(matrix, cat_idx)
Wie vom Benutzer Nova in einem von @Yaroslav Bulatov referenzierten Thread berichtet:
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
with tf.Session(''):
print y.eval() # [2 4 9]
Der Gist glättet den Tensor und verwendet die gestufte 1D-Adressierung mit tf.gather (...).
Noch eine andere Lösung mit tf.unstack (...), tf.gather (...) und tf.stack (..)
Code:
import tensorflow as tf
import numpy as np
shape = [2, 2, 2, 10]
L = np.arange(np.prod(shape))
L = np.reshape(L, shape)
indices = [0, 2, 3, 8]
axis = -1 # last dimension
def gather_axis(params, indices, axis=0):
return tf.stack(tf.unstack(tf.gather(tf.unstack(params, axis=axis), indices)), axis=axis)
print(L)
with tf.Session() as sess:
partL = sess.run(gather_axis(L, indices, axis))
print(partL)
Ergebnis:
L =
[[[[ 0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]]
[[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]]]
[[[40 41 42 43 44 45 46 47 48 49]
[50 51 52 53 54 55 56 57 58 59]]
[[60 61 62 63 64 65 66 67 68 69]
[70 71 72 73 74 75 76 77 78 79]]]]
partL =
[[[[ 0 2 3 8]
[10 12 13 18]]
[[20 22 23 28]
[30 32 33 38]]]
[[[40 42 43 48]
[50 52 53 58]]
[[60 62 63 68]
[70 72 73 78]]]]
Eine korrekte Version von @ Andrei's Antwort würde lesen
cat_idx = tf.stack([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=1)
result = tf.gather_nd(matrix, cat_idx)
Sie können auf diese Weise beispielsweise versuchen (in den meisten Fällen zumindest in NLP),
Der Parameter hat die Form [batch_size, depth]
und die Indizes sind [i, j, k, n, m], deren Länge batch_size ist. Dann kann gather_nd
hilfreich sein.
parameters = tf.constant([
[11, 12, 13],
[21, 22, 23],
[31, 32, 33],
[41, 42, 43]])
targets = tf.constant([2, 1, 0, 1])
batch_nums = tf.range(0, limit=parameters.get_shape().as_list()[0])
indices = tf.stack((batch_nums, targets), axis=1) # the axis is the dimension number
items = tf.gather_nd(parameters, indices)
# which is what we want: [13, 22, 31, 42]
Dieses Snippet sucht zuerst die erste Dimension durch batch_num und holt dann das Element entlang dieser Dimension anhand der Zielnummer.
Tensor hat keine Attributform, sondern die Methode get_shape (). Nachfolgend können Sie Python 2.7 ausführen
import tensorflow as tf
import numpy as np
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
with tf.Session(''):
print y.eval() # [2 4 9]
Implementierung von @Yaroslav Bulatov:
#Your indices
indices = [0, 2, 3, 8]
#Remember for final reshaping
n_indices = tf.shape(indices)[0]
flattened_L = tf.reshape(L, [-1])
#Walk strided over the flattened array
offset = tf.expand_dims(tf.range(0, tf.reduce_prod(tf.shape(L)), tf.shape(L)[-1]), 1)
flattened_indices = tf.reshape(tf.reshape(indices, [-1])+offset, [-1])
selected_rows = tf.gather(flattened_L, flattened_indices)
#Final reshape
partL = tf.reshape(selected_rows, tf.concat(0, [tf.shape(L)[:-1], [n_indices]]))
Gutschrift an Wie wähle ich Zeilen aus einem 3-D-Tensor in TensorFlow aus?