
How to build TF tensor with ones in specified locations – batch compatible

I apologize for the poor question title, but I’m not sure how to phrase it. Here’s the problem I’m trying to solve: I have two NNs working off the same input dataset. One is a traditional network, while the other limits the acceptable range of the first. This works by using a tf.where() statement, which works fine in most cases, such as this toy example:

pcts = tf.constant([0.04,0.06,0.06,0.06,0.06,0.06,0.06,0.04,0.04,0.04])
legal_actions = tf.where(pcts >= 0.05, tf.ones_like(pcts), tf.zeros_like(pcts))

This gives the correct result: legal_actions = [0,1,1,1,1,1,1,0,0,0]

I can then multiply this by the output of my first network to limit its Q values to only those of the legal actions. In a case like the above this works great.

However, my original vector can also look like this, with low values between the high values:

pcts = [0.04,0.06,0.06,0.04,0.04,0.06,0.06,0.04,0.04,0.04]

Using the same code as above my legal_actions comes out as this: legal_actions = [0,1,1,0,0,1,1,0,0,0]

Given my code, this is correct; however, I’d like any zeros in the middle to count as legal actions as well. In other words, I’d like this second example to give the same result as the first. For a single (unbatched) vector this is easy to do in several ways, such as in this reproducible example (it’s also easy with sparse tensors):

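# TF 1.x API; under TF 2.x use tf.compat.v1 with eager execution disabled.
import tensorflow as tf

pcts = tf.placeholder(tf.float32, shape=(10,))
legal_actions = tf.where(pcts >= 0.05, tf.ones_like(pcts), tf.zeros_like(pcts))
# Indices of the legal (nonzero) positions, shape (k, 1).
mask = tf.where(tf.greater(legal_actions, 0))
# Every index from the first legal position to the last one.
legals = tf.cast(tf.range(tf.reduce_min(mask), tf.reduce_max(mask) + 1), tf.int64)

# One-hot each index and sum, giving ones over the whole span.
oh = tf.one_hot(legals, 10)
oh = tf.reduce_sum(oh, 0)

with tf.Session() as sess:
    print(sess.run(oh, feed_dict={pcts: [0.04,0.06,0.06,0.04,0.04,0.06,0.06,0.04,0.04,0.04]}))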

The problem I’m running into is applying this to my actual code, which reads batches from a file. I can’t find a way to fill in the “gaps” in my tensor without the range function, and I can’t figure out how to make the range function work with batches (as far as I can tell, it only builds one range at a time, not one per batch element). Any suggestions on how to make my approach work, or on how to solve the problem a completely different way, would be appreciated.
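One way the range idea might be made batch-compatible is to compute the first and last qualifying index of each row and compare both against a single broadcast tf.range, instead of building one range per batch element. This is a minimal sketch assuming TensorFlow 2.x in eager mode and input of shape (batch, n_actions); the helper name fill_gaps_range and its default threshold of 0.05 (taken from the toy example) are hypothetical.

```python
import tensorflow as tf


def fill_gaps_range(pcts, threshold=0.05):
    """Ones from the first to the last entry >= threshold in each row.

    pcts: float tensor of shape (batch, n_actions). Hypothetical helper
    for TF 2.x; rows with no entry >= threshold come out as all zeros.
    """
    hits = pcts >= threshold                         # (batch, n) bool
    n = tf.shape(pcts)[-1]
    idx = tf.range(n)                                # (n,) int32
    idx_b = tf.broadcast_to(idx, tf.shape(hits))     # (batch, n)
    # Replace non-hits with sentinels so they never win the min/max.
    first = tf.reduce_min(tf.where(hits, idx_b, n * tf.ones_like(idx_b)), axis=-1)
    last = tf.reduce_max(tf.where(hits, idx_b, -tf.ones_like(idx_b)), axis=-1)
    # One broadcast comparison replaces one tf.range per batch element.
    inside = (idx >= first[:, None]) & (idx <= last[:, None])
    return tf.cast(inside, pcts.dtype)


batch = tf.constant([
    [0.04, 0.06, 0.06, 0.04, 0.04, 0.06, 0.06, 0.04, 0.04, 0.04],
    [0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04],
])
print(fill_gaps_range(batch).numpy())
```

The sentinels (n for the minimum, -1 for the maximum) mean a row with no value above the threshold gets first > last, so it comes out as all zeros rather than failing the way tf.reduce_min on an empty set of indices would.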


Answer

Try this code:

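import tensorflow as tf

pcts = tf.random.uniform((2, 3, 4))
a = pcts >= 0.5
shape = tf.shape(pcts)[-1]
# Flatten all batch dimensions so each row is one vector of actions.
a = tf.reshape(a, (-1, shape))
a = tf.cast(a, dtype=tf.float32)

def rng(t):
  # Running maximum from the left: 1 from the first 1 onwards.
  # tf.maximum (not Python max) also works inside tf.function / graph mode.
  left = tf.scan(lambda acc, x: tf.maximum(acc, x), t)
  # Running maximum from the right: 1 up to the last 1.
  right = tf.scan(lambda acc, x: tf.maximum(acc, x), t, reverse=True)
  # Their minimum is 1 only between the first and last 1 (inclusive).
  return tf.minimum(left, right)

# Apply rng to every row, then restore the original batch shape.
a = tf.map_fn(rng, a)
a = tf.reshape(a, tf.shape(pcts))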
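Because the mask contains only 0s and 1s, the running maximums in rng can also be expressed as cumulative sums tested against zero. This swapped-in variant (not the answer’s code) avoids tf.scan and tf.map_fn and works directly on the last axis of a tensor of any rank, so no reshape is needed. It is a sketch assuming TensorFlow 2.x; the helper name fill_gaps_cumsum is hypothetical.

```python
import tensorflow as tf


def fill_gaps_cumsum(mask):
    """Fill the zeros between the first and last 1 along the last axis.

    mask: 0/1 float tensor of any rank (hypothetical helper, TF 2.x).
    Rows without any 1 stay all zeros.
    """
    # True at i if there is a 1 at or before position i.
    seen_left = tf.cumsum(mask, axis=-1) > 0
    # True at i if there is a 1 at or after position i.
    seen_right = tf.cumsum(mask, axis=-1, reverse=True) > 0
    # Both hold only between the first and last 1 (inclusive).
    return tf.cast(tf.logical_and(seen_left, seen_right), mask.dtype)


# Same setup as the answer: random data with arbitrary batch dimensions.
pcts = tf.random.uniform((2, 3, 4))
legal = fill_gaps_cumsum(tf.cast(pcts >= 0.5, tf.float32))
print(legal.shape)  # (2, 3, 4)
```

For 0/1 input, "cumulative sum > 0" is true exactly where the cumulative maximum is 1, so the result matches rng row by row; the difference is that two batched ops replace the per-row loop that tf.map_fn runs.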
User contributions licensed under: CC BY-SA