I am using TensorFlow 2.3.0.
I have a Python data generator:
```python
import tensorflow as tf
import numpy as np

vocab = [1, 2, 3, 4, 5]

def create_generator():
    'yields 4 random integers, each from 0 to len(vocab)-1'
    count = 0
    while count < 4:
        x = np.random.randint(0, len(vocab))
        yield x
        count += 1
```
I turn it into a tf.data.Dataset object:
```python
gen = tf.data.Dataset.from_generator(create_generator,
                                     args=[],
                                     output_types=tf.int32,
                                     output_shapes=())
```
Now I want to subsample the items using the map method, so that the dataset never outputs an even number.
```python
def subsample(x):
    'remove the item if it is an even number, e.g. 2 or 4'
    '''
    #TODO
    '''
    return x

gen = gen.map(subsample)
```
How can I achieve this using the map method?
In short, no: you cannot filter data using map. map applies a transformation to every element of the dataset and always produces exactly one output element per input element, so it has no way to drop an element. What you want is to check every element against a predicate and keep only the elements that satisfy it.
That is exactly what filter() does.
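The same distinction exists in plain Python, which makes it easy to see without TensorFlow. The sketch below uses the built-in map and filter only as an analogy (not tf.data itself); they follow the same contract as Dataset.map and Dataset.filter: one output per input versus keep-if-predicate-is-true. The list `values` is arbitrary illustration data.

```python
# Plain-Python analogy (built-ins, not tf.data).
values = [3, 0, 4, 1, 2]

# map: exactly one output per input, so nothing can be dropped.
mapped = list(map(lambda x: x * 10, values))

# filter: keeps only the inputs for which the predicate is True.
odd_only = list(filter(lambda x: x % 2 != 0, values))

print(mapped)    # [30, 0, 40, 10, 20]
print(odd_only)  # [3, 1]
```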
So you can do:
```python
gen = gen.filter(lambda x: x % 2 != 0)
```
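Putting the question's generator and the filter together gives a complete pipeline. This is a minimal sketch assuming TensorFlow 2.x with eager execution (so the dataset can be iterated with a plain for loop); the generator is rewritten with a for loop but yields the same 4 random integers as the question's version.

```python
import numpy as np
import tensorflow as tf

vocab = [1, 2, 3, 4, 5]

def create_generator():
    """Yields 4 random integers from 0 to len(vocab) - 1."""
    for _ in range(4):
        yield np.random.randint(0, len(vocab))

gen = tf.data.Dataset.from_generator(
    create_generator,
    output_types=tf.int32,
    output_shapes=(),
)

# filter keeps an element only when the predicate evaluates to True.
gen = gen.filter(lambda x: x % 2 != 0)

# Iterate eagerly and convert each scalar tensor to a Python int.
result = [int(x) for x in gen]
print(result)  # only odd values (1 or 3); the count varies between 0 and 4
```

Because the generator is random, the number of surviving elements changes from run to run; what is guaranteed is that no even number reaches the output.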
Update:
If you want to use a custom function instead of a lambda, you can do something like:
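```python
def filter_func(x):
    # filter expects the predicate to return a scalar boolean tensor;
    # the comparison below produces exactly that.
    return x**2 < 500

gen = gen.filter(filter_func)
```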
When this function is passed to filter, only the elements whose square is less than 500 are kept; all others are dropped.
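To check the predicate on a known input rather than on random data, the sketch below runs it over `tf.data.Dataset.range(30)`, an arbitrary input chosen only for illustration, assuming TensorFlow 2.x eager execution.

```python
import tensorflow as tf

def filter_func(x):
    # Returns a scalar boolean tensor: True when x squared is below 500.
    return x ** 2 < 500

# Illustrative input: the integers 0..29 (int64 elements).
ds = tf.data.Dataset.range(30)
kept = [int(x) for x in ds.filter(filter_func)]
print(kept)  # 0 through 22, since 22**2 = 484 < 500 and 23**2 = 529
```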