How to decode a tfrecord based on the available meta data

I have three tfrecords (train, test, valid) from a DeepMind GitHub repo (https://github.com/google-deepmind/deepmind-research/tree/master/meshgraphnets) that I am trying to decode (at the moment just to inspect the data, but long term to train a GNN). However, I can't work out how to decode the tfrecord.

The JSON below is the metadata that comes with the data. It describes each feature's shape, dtype, etc., but I can't work out how to use it to retrieve the data I need.

{
  "simulator": "comsol",
  "dt": 0.01,
  "collision_radius": null,
  "features": {
    "cells": {
      "type": "static",
      "shape": [
        1,
        -1,
        3
      ],
      "dtype": "int32"
    },
    "mesh_pos": {
      "type": "static",
      "shape": [
        1,
        -1,
        2
      ],
      "dtype": "float32"
    },
    "node_type": {
      "type": "static",
      "shape": [
        1,
        -1,
        1
      ],
      "dtype": "int32"
    },
    "velocity": {
      "type": "dynamic",
      "shape": [
        600,
        -1,
        2
      ],
      "dtype": "float32"
    },
    "pressure": {
      "type": "dynamic",
      "shape": [
        600,
        -1,
        1
      ],
      "dtype": "float32"
    }
  },
  "field_names": [
    "cells",
    "mesh_pos",
    "node_type",
    "velocity",
    "pressure"
  ],
  "trajectory_length": 600
}

I can get as far as retrieving binary strings for each of the features using the code below, but I can't figure out how to decode the binary into something usable.

import tensorflow as tf
from tensorflow.io import FixedLenFeature

def extract_fn(data_record):
    features = {
                "cells": FixedLenFeature([], tf.string),
                "mesh_pos": FixedLenFeature([], tf.string),
                "node_type":FixedLenFeature([], tf.string),
                "velocity":FixedLenFeature([], tf.string),
                "pressure":FixedLenFeature([], tf.string),

              }
    sample = tf.io.parse_single_example(data_record, features)
    #sample = tf.cast(sample["image/encoded"], tf.float32)
    return sample

filename = "/content/CylinderData/cylinder_flow/train.tfrecord"
dataset = tf.data.TFRecordDataset(filename)
#print(f'Number of records: {sum(1 for _ in dataset)}')
#dataset = dataset.map(extract_fn)
#iterator = dataset.make_one_shot_iterator()
#next_element = iterator.get_next()
for i in dataset:
  example = extract_fn(i)
  print(example)
  break
Andrew Russell asked Dec 14 '25 03:12



1 Answer

To convert the raw binary strings from the TFRecord into usable numerical tensors, use tf.io.decode_raw with the dtype given in your metadata, followed by tf.reshape to restore the tensor's dimensions. The extract_fn below does this for each of the five features. Note that it reshapes the static fields without the leading dimension of 1 that the metadata lists for them.
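As a rough illustration of what those two steps do to the bytes, here is a small NumPy sketch. The array contents are made up for the example, and np.frombuffer is only an analogue of tf.io.decode_raw (it assumes the bytes are in the machine's native byte order), not what TensorFlow runs internally.

```python
import numpy as np

# Hypothetical stand-in for one serialized "mesh_pos" feature:
# 4 nodes x 2 coordinates, stored as float32.
original = np.array(
    [[0.0, 0.5], [1.0, 1.5], [2.0, 2.5], [3.0, 3.5]], dtype=np.float32
)
raw_bytes = original.tobytes()  # 4 * 2 values * 4 bytes each = 32 bytes

# Step 1: reinterpret the bytes as a flat array of the dtype named in the metadata.
flat = np.frombuffer(raw_bytes, dtype=np.float32)

# Step 2: restore the layout; -1 lets the node count be inferred from the length.
decoded = flat.reshape(-1, 2)

# Picking the wrong dtype does not raise an error here;
# it silently produces different numbers.
wrong = np.frombuffer(raw_bytes, dtype=np.int32).reshape(-1, 2)

print(decoded.shape)
```

This is why the dtype has to come from the metadata: the byte string itself carries no type or shape information. The TensorFlow version of the same two steps, applied to every feature, follows.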

import tensorflow as tf

def extract_fn(data_record):
    features = {
        "cells": tf.io.FixedLenFeature([], tf.string),
        "mesh_pos": tf.io.FixedLenFeature([], tf.string),
        "node_type": tf.io.FixedLenFeature([], tf.string),
        "velocity": tf.io.FixedLenFeature([], tf.string),
        "pressure": tf.io.FixedLenFeature([], tf.string),
    }
    sample = tf.io.parse_single_example(data_record, features)
    sample["cells"] = tf.reshape(tf.io.decode_raw(sample["cells"], tf.int32), [-1, 3])
    sample["mesh_pos"] = tf.reshape(tf.io.decode_raw(sample["mesh_pos"], tf.float32), [-1, 2])
    sample["node_type"] = tf.reshape(tf.io.decode_raw(sample["node_type"], tf.int32), [-1, 1])
    sample["velocity"] = tf.reshape(tf.io.decode_raw(sample["velocity"], tf.float32), [600, -1, 2])
    sample["pressure"] = tf.reshape(tf.io.decode_raw(sample["pressure"], tf.float32), [600, -1, 1])
    return sample

filename = "/content/drive/MyDrive/CylinderData/cylinder_flow/train.tfrecord"
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(extract_fn)

for example in dataset.take(1):
  print("Decoded Cells:", example["cells"])
  print("Decoded Mesh Position:", example["mesh_pos"])
  print("Decoded Node Type:", example["node_type"])
  print("Decoded Velocity:", example["velocity"])
  print("Decoded Pressure:", example["pressure"])

Output

Decoded Cells: tf.Tensor(
[[   0    1    2]
 [   3    0    4]
 [   5    3    6]
 ...
 [1474 1868 1870]
 [1475 1474 1870]
 [1475 1870 1872]], shape=(3518, 3), dtype=int32)
Decoded Mesh Position: tf.Tensor(
[[0.         0.01577128]
 [0.         0.00782139]
 [0.01531635 0.01575096]
 ...
 [1.5816092  0.        ]
 [1.6        0.00360123]
 [1.6        0.        ]], shape=(1876, 2), dtype=float32)
Decoded Node Type: tf.Tensor(
[[4]
 [4]
 [0]
 ...
 [6]
 [5]
 [6]], shape=(1876, 1), dtype=int32)
Decoded Velocity: tf.Tensor(
[[[ 7.7918753e-02  0.0000000e+00]
  [ 3.9421193e-02  0.0000000e+00]
  [ 1.0014776e-01 -8.0902927e-02]
  ...
  [ 0.0000000e+00  0.0000000e+00]
  [ 2.7801219e-01 -3.8294352e-03]
  [ 0.0000000e+00  0.0000000e+00]]

 [[ 7.7918753e-02  0.0000000e+00]
  [ 3.9421193e-02  0.0000000e+00]
  [ 1.0051872e-01 -5.5038825e-02]
  ...
  [ 0.0000000e+00  0.0000000e+00]
  [ 1.8797542e-01 -1.1275966e-02]
  [ 0.0000000e+00  0.0000000e+00]]

 [[ 7.7918753e-02  0.0000000e+00]
  [ 3.9421193e-02  0.0000000e+00]
  [ 9.9102437e-02 -3.9522175e-02]
  ...
  [ 0.0000000e+00  0.0000000e+00]
  [ 1.5538059e-01 -1.2076479e-02]
  [ 0.0000000e+00  0.0000000e+00]]

 ...
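
Since the question asks how to use the metadata itself, the same decoding can also be driven by the JSON rather than hard-coded per field. The following is a minimal sketch under a few assumptions: META is a trimmed copy of the metadata from the question (in practice it would be loaded from the meta.json file shipped with the dataset), make_parse_fn is a name introduced here for illustration, and each feature is assumed to be stored as a single byte string, as in the code above.

```python
import json
import tensorflow as tf

# Trimmed copy of the metadata from the question (assumption: normally
# this would be read from the dataset's meta.json with json.load).
META = json.loads("""
{
  "features": {
    "cells":     {"type": "static",  "shape": [1, -1, 3],   "dtype": "int32"},
    "mesh_pos":  {"type": "static",  "shape": [1, -1, 2],   "dtype": "float32"},
    "node_type": {"type": "static",  "shape": [1, -1, 1],   "dtype": "int32"},
    "velocity":  {"type": "dynamic", "shape": [600, -1, 2], "dtype": "float32"},
    "pressure":  {"type": "dynamic", "shape": [600, -1, 1], "dtype": "float32"}
  }
}
""")


def make_parse_fn(meta):
    """Build a parse function from the metadata (hypothetical helper)."""
    # Every feature is read as one scalar byte string.
    spec = {name: tf.io.FixedLenFeature([], tf.string) for name in meta["features"]}

    def parse(record):
        raw = tf.io.parse_single_example(record, spec)
        out = {}
        for name, field in meta["features"].items():
            dtype = tf.as_dtype(field["dtype"])           # "int32" -> tf.int32
            flat = tf.io.decode_raw(raw[name], dtype)      # bytes -> 1-D tensor
            out[name] = tf.reshape(flat, field["shape"])   # -1 is inferred
        return out

    return parse
```

It would be used as dataset = tf.data.TFRecordDataset(filename).map(make_parse_fn(META)). Unlike the hard-coded version, this keeps the leading dimension of 1 from the metadata on the static fields, so cells would come out as (1, 3518, 3) rather than (3518, 3) for the record shown above.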
Sagar answered Dec 15 '25 16:12
