Combine arbitrarily shaped tensors

I'd like to combine two variable-length tensors.

Since they don't match in shape, I can't use tf.concat or tf.stack.

So I thought I'd flatten one and then append it to each element of the other, but I don't see how to do that.

For example,

a = [ [1,2], [3,4] ]
flat_b = [5, 6]

combine(a, flat_b) would be [ [ [1,5,6], [2,5,6] ],
                              [ [3,5,6], [4,5,6] ] ]

Is there a method like this?
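To pin down the intended semantics, here is a plain-Python sketch on nested lists. It does not use TensorFlow at all. `flatten` and `combine` are hypothetical helpers written for illustration: every scalar `x` in `a` is replaced by `[x, *flatten(b)]`, whatever the nesting depth of `a`.

```python
def flatten(x):
    """Return the scalar leaves of a nested list in row-major order."""
    if isinstance(x, list):
        leaves = []
        for item in x:
            leaves.extend(flatten(item))
        return leaves
    return [x]


def combine(a, b):
    """Replace every scalar x in nested list `a` with [x, *flatten(b)]."""
    flat_b = flatten(b)

    def attach(node):
        # Recurse through lists; a scalar leaf becomes [leaf, *flat_b]
        if isinstance(node, list):
            return [attach(item) for item in node]
        return [node, *flat_b]

    return attach(a)


print(combine([[1, 2], [3, 4]], [5, 6]))
# [[[1, 5, 6], [2, 5, 6]], [[3, 5, 6], [4, 5, 6]]]
```

The result always has one more level of nesting than `a`, and the innermost lists have length 1 + (number of scalars in `b`).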

Jamie Hlusko asked Nov 19 '25 04:11


2 Answers

Use tf.map_fn with tf.concat. Example code:

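import tensorflow as tf

a = tf.constant([ [1,2], [3,4] ])
flat_b = [5, 6]
# Flatten a into a column of shape (4, 1): one row per scalar element
flat_a = tf.reshape(a, [-1, 1])
print(flat_a)
# Append flat_b to every row: each (1,) row becomes a (3,) row, giving shape (4, 3)
c = tf.map_fn(fn=lambda t: tf.concat([t, flat_b], axis=0), elems=flat_a)
# Restore the 2-D layout of a, with the combined values on a new last axis: (2, 2, 3)
c = tf.reshape(c, (-1, a.shape[1], c.shape[1]))
print(c)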

Outputs:

tf.Tensor(
[[1]
 [2]
 [3]
 [4]], shape=(4, 1), dtype=int32)
tf.Tensor(
[[[1 5 6]
  [2 5 6]]

 [[3 5 6]
  [4 5 6]]], shape=(2, 2, 3), dtype=int32)
Mr. For Example answered Nov 21 '25 16:11


Here's a somewhat simpler version of the previous answer. Rather than reshaping several times, I prefer tf.expand_dims and tf.stack. tf.stack adds a new dimension, which saves one call to tf.reshape (and reshapes can be confusing).

import tensorflow as tf

a = tf.constant([[1,2], [3,4]])
b = [5, 6]

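flat_a = tf.reshape(a, [-1])  # flatten a to shape (4,)

# Turn each scalar x into a 1-element vector with tf.expand_dims, then append b: [x, 5, 6]
c = tf.map_fn(lambda x: tf.concat([tf.expand_dims(x, 0), b], axis=0), flat_a)
# Split the (4, 3) result into len(a) chunks of shape (2, 3) and stack them into (2, 2, 3)
c = tf.stack(tf.split(c, num_or_size_splits=len(a)), axis=0)

Evaluating c in an interactive session displays: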
<tf.Tensor: shape=(2, 2, 3), dtype=int32, numpy=
array([[[1, 5, 6],
        [2, 5, 6]],
       [[3, 5, 6],
        [4, 5, 6]]])>
Nicolas Gervais answered Nov 21 '25 17:11
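Both answers loop over the elements with tf.map_fn. As an alternative that is not taken from either answer, the same result can be built without a per-element loop by broadcasting. The approach adds a trailing axis to a with tf.expand_dims, broadcasts the flattened b to a's shape with tf.broadcast_to, and concatenates along the last axis. This is a minimal sketch: the helper name `combine` is hypothetical, and it assumes b can be converted to a's dtype.

```python
import tensorflow as tf


def combine(a, b):
    """Hypothetical helper: append the flattened b to every scalar of a.

    Assumes b is convertible to a's dtype. For a of shape (d1, ..., dk),
    the result has shape (d1, ..., dk, 1 + size(b)).
    """
    a = tf.convert_to_tensor(a)
    # Flatten b to 1-D, using a's dtype so the final concat is valid
    flat_b = tf.reshape(tf.convert_to_tensor(b, dtype=a.dtype), [-1])
    # Give a a trailing axis of size 1: shape a.shape + [1]
    a_exp = tf.expand_dims(a, -1)
    # Broadcast flat_b to shape a.shape + [size(b)] (copies it for every element)
    target_shape = tf.concat([tf.shape(a), tf.shape(flat_b)], axis=0)
    b_tiled = tf.broadcast_to(flat_b, target_shape)
    # Join along the last axis: [x, b_0, b_1, ...] for every element x of a
    return tf.concat([a_exp, b_tiled], axis=-1)


c = combine(tf.constant([[1, 2], [3, 4]]), [5, 6])
print(c.numpy().tolist())
# [[[1, 5, 6], [2, 5, 6]], [[3, 5, 6], [4, 5, 6]]]
```

The sketch reads the target shape from tf.shape(a) rather than a.shape[1], so it is not tied to a 2-D a. The broadcast copy of b is materialized, but the output has to contain those values anyway.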


