TypeError: Tensor is unhashable. Instead, use tensor.ref()

I'm getting "TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key".

I made a slight change to a public Kaggle kernel.

I defined a function that checks whether a certain value is in a set:

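import pandas as pd

li = pd.read_csv('../input/ch-stock-market-companies/censored_list')
li_set = set(li['investment_id'])  # build the lookup set once, not on every call

def map_id(id):
    # Return the ID unchanged if it appears in the list, otherwise -5
    if id in li_set:
        return id
    return -5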

This function is called while preprocessing a TensorFlow dataset:

def preprocess(item):
  return (map_id(item["investment_id"]), item["features"]), item["target"] #this is the offending line

def make_dataset(file_paths, batch_size=4096, mode="train"):
  ds = tf.data.TFRecordDataset(file_paths)
  ds = ds.map(decode_function)
  ds = ds.map(preprocess)
  if mode == "train":
      ds = ds.shuffle(batch_size * 4)
  ds = ds.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)
  return ds

For reference, this is what the function looked like before my change:

    def preprocess(item):
      return (item["investment_id"], item["features"]), item["target"] #this was the line before I changed it

The error message tells me that I cannot use the function map_id as defined.
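Why the error occurs can be illustrated in plain Python. `id in li_set` must hash `id` before it can look anything up. A class that overloads `==` to compare element by element (as TensorFlow's `Tensor` does) cannot also offer a hash consistent with that `==`, and Python sets `__hash__` to `None` for a class that defines `__eq__` without `__hash__`. TensorFlow instead makes `Tensor.__hash__` raise the error above. Following the message's suggestion would not help here either: `tensor.ref()` is hashable, but it refers to the tensor object rather than its value, so it would never equal the integer IDs in `li_set`. Also, inside `ds.map` the function is traced, so `item["investment_id"]` is a symbolic tensor with no concrete value at that point. The class below is a hypothetical toy stand-in, not TensorFlow's implementation:

```python
class ElementwiseArray:
    """Hypothetical toy stand-in for a tensor: == compares element by element."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        # Like tensor == tensor, return one bool per element instead of a single bool.
        return [a == b for a, b in zip(self.values, other.values)]

    # No __hash__ is defined, so Python sets ElementwiseArray.__hash__ to None.


def is_hashable(obj):
    """Return True if obj can be used as a set member or dict key."""
    try:
        hash(obj)
        return True
    except TypeError:
        return False


print(is_hashable(3))                      # True
print(is_hashable(ElementwiseArray([3])))  # False
```

Consequently, `ElementwiseArray([3]) in {1, 2, 3}` raises `TypeError`, just as `id in li_set` does when `id` is a tensor.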

So how do I properly achieve what I'm trying to do? Namely, I want to "censor" some of the values in a pandas dataframe by replacing them with a default value of -5, and ideally I want to do this as part of creating a TensorFlow dataset.
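If the values to censor are available in a pandas DataFrame, the same "keep if listed, else -5" rule can be applied in pandas, vectorized, before the data reaches TensorFlow, using `Series.isin` and `Series.where`. This is a sketch with hypothetical data; `allowed` stands in for `li['investment_id']`:

```python
import pandas as pd

CENSOR_VALUE = -5


def censor_ids(ids: pd.Series, allowed: pd.Series) -> pd.Series:
    # Keep an ID where it appears in `allowed`; replace it with CENSOR_VALUE elsewhere.
    return ids.where(ids.isin(allowed), CENSOR_VALUE)


df = pd.DataFrame({"investment_id": [10, 11, 12, 13]})  # hypothetical data
allowed = pd.Series([11, 13])                            # stand-in for li["investment_id"]
df["investment_id"] = censor_ids(df["investment_id"], allowed)
print(df["investment_id"].tolist())  # [-5, 11, -5, 13]
```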

Nick asked Jan 20 '26 17:01


1 Answer

As the error message says, you cannot look up a tensor in a Python set, because tensors are not hashable. Try using a tf.lookup.StaticHashTable instead:

import tensorflow as tf

keys_tensor = tf.constant([1, 2, 3])
vals_tensor = tf.constant([1, 2, 3])
# IDs found in keys_tensor map to themselves; any other ID maps to -5
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
    default_value=-5)

print(table.lookup(tf.constant(1)))  # tf.Tensor(1, shape=(), dtype=int32)
print(table.lookup(tf.constant(5)))  # tf.Tensor(-5, shape=(), dtype=int32)
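Because `table.lookup` operates on tensors, the table can be built once from the censored list and called inside `preprocess` in place of `map_id`. The following is a minimal sketch, not the kernel's code. It assumes `investment_id` is decoded as `int64` (the table's key dtype must match the lookup dtype; otherwise cast with `tf.cast` first). A toy `from_tensor_slices` dataset replaces the kernel's `TFRecordDataset`/`decode_function`, and `allowed_ids` is a hypothetical stand-in for the CSV column:

```python
import tensorflow as tf

CENSOR_VALUE = -5

# Hypothetical stand-in for li["investment_id"]; in the kernel the keys could
# come from the CSV column instead, e.g. li["investment_id"].to_numpy().
allowed_ids = tf.constant([11, 13], dtype=tf.int64)

# Every allowed ID maps to itself; any other ID falls back to CENSOR_VALUE.
id_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(allowed_ids, allowed_ids),
    default_value=tf.constant(CENSOR_VALUE, dtype=tf.int64))


def preprocess(item):
    # table.lookup accepts symbolic tensors, so it works inside Dataset.map.
    censored_id = id_table.lookup(item["investment_id"])
    return (censored_id, item["features"]), item["target"]


# Toy records standing in for the decoded TFRecords (field names from the
# question; int64 IDs and the feature shape are assumptions).
ds = tf.data.Dataset.from_tensor_slices({
    "investment_id": tf.constant([10, 11, 12, 13], dtype=tf.int64),
    "features": tf.zeros([4, 3]),
    "target": tf.ones([4]),
})
ds = ds.map(preprocess).batch(2)

censored = [id_batch.numpy().tolist() for (id_batch, _features), _target in ds]
print(censored)  # [[-5, 11], [-5, 13]]
```

The rest of `make_dataset` (shuffle, batch, cache, prefetch) should not need to change, since `preprocess` keeps the same output structure.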

Alternatively, you can use tf.where:

import tensorflow as tf

def check_value(value):
    frozen_set = tf.constant([1, 2, 3])
    # True if value equals any element of frozen_set (keepdims=True gives shape (1,))
    is_member = tf.reduce_any(tf.equal(value, frozen_set), axis=0, keepdims=True)
    return tf.where(is_member, value, tf.constant(-5))

print(check_value(tf.constant(1)))  # tf.Tensor([1], shape=(1,), dtype=int32)
print(check_value(tf.constant(2)))  # tf.Tensor([2], shape=(1,), dtype=int32)
print(check_value(tf.constant(4)))  # tf.Tensor([-5], shape=(1,), dtype=int32)
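The `tf.where` version above checks one scalar at a time. If the check runs on a whole batch of IDs (for example after `ds.batch(...)`), one possible vectorized variant compares every ID against every allowed value through broadcasting. This is a sketch; `censor_batch` is a hypothetical helper name:

```python
import tensorflow as tf

CENSOR_VALUE = -5


def censor_batch(ids, allowed):
    # ids: shape [batch]; allowed: shape [n]. Broadcasting gives a [batch, n]
    # equality matrix; reduce_any over axis 1 asks "does this ID match anything?".
    matches = tf.reduce_any(tf.equal(ids[:, tf.newaxis], allowed[tf.newaxis, :]), axis=1)
    # Keep matching IDs, replace the rest with CENSOR_VALUE (same dtype as ids).
    return tf.where(matches, ids, tf.constant(CENSOR_VALUE, dtype=ids.dtype))


ids = tf.constant([1, 2, 4, 3, 7])
allowed = tf.constant([1, 2, 3])
print(censor_batch(ids, allowed).numpy().tolist())  # [1, 2, -5, 3, -5]
```

The intermediate matrix has `batch * n` entries, so for a long censored list the `StaticHashTable` approach is likely to scale better.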
AloneTogether answered Jan 22 '26 05:01