TF-agents - Replay buffer add trajectory to batch shape mismatch

I'm posting a question that was posted by another user and then deleted. I had the same question, and I found an answer. The original question:

I am currently trying to implement a categorical DQN following this tutorial: https://www.tensorflow.org/agents/tutorials/9_c51_tutorial

The following part is giving me a bit of a headache though:

random_policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                                env.action_spec())

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=1,
    max_length=replay_buffer_capacity)  # this is 100

# ...

def collect_step(environment, policy):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)
  print(traj)

  # Add trajectory to the replay buffer
  replay_buffer.add_batch(traj)

for _ in range(initial_collect_steps):
  collect_step(env, random_policy)

For context, agent.collect_data_spec looks like this:

Trajectory(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), observation=BoundedTensorSpec(shape=(4, 84, 84), dtype=tf.float32, name='screen', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), action=BoundedTensorSpec(shape=(), dtype=tf.int32, name='play', minimum=array(0), maximum=array(6)), policy_info=(), next_step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))

And here is what a sample traj looks like:

Trajectory(step_type=<tf.Tensor: shape=(), dtype=int32, numpy=0>, observation=<tf.Tensor: shape=(4, 84, 84), dtype=float32, numpy=array([tensor contents omitted], dtype=float32)>, action=<tf.Tensor: shape=(), dtype=int32, numpy=1>, policy_info=(), next_step_type=<tf.Tensor: shape=(), dtype=int32, numpy=1>, reward=<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, discount=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)

So, everything should check out, right? The environment outputs a tensor of shape [4, 84, 84], same as the replay buffer expects. Except I'm getting the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [4,84,84], indices.shape [1], params.shape [100,4,84,84] [Op:ResourceScatterUpdate]

This suggests it is actually expecting a tensor of shape [1, 4, 84, 84]. The thing is, if I make my environment output a tensor of that shape, I get another error telling me the output shape doesn't match the spec shape (duh). And if I then adjust the spec shape to [1, 4, 84, 84], the replay buffer suddenly expects a shape of [1, 1, 4, 84, 84], and so on...
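Spelling out the shape arithmetic from the error message (just my reading of it, not something from the tutorial):

# ResourceScatterUpdate requires:
#   updates.shape == indices.shape + params.shape[1:]
indices_shape = [1]                # one row index per batch entry (batch_size=1)
params_shape = [100, 4, 84, 84]    # the buffer's internal storage (capacity 100)
expected_updates = indices_shape + params_shape[1:]  # -> [1, 4, 84, 84]
# ...but add_batch was handed an observation of shape [4, 84, 84]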

Finally, for completeness, here are the time_step_spec and action_spec of my environment, respectively:

TimeStep(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), observation=BoundedTensorSpec(shape=(4, 84, 84), dtype=tf.float32, name='screen', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
---
BoundedTensorSpec(shape=(), dtype=tf.int32, name='play', minimum=array(0), maximum=array(6))

I've spent the better half of today trying to get the tensor to fit, but you can't simply reshape it since it's an attribute of the trajectory. So, in a last-ditch effort, I'm hoping some kind stranger out there can tell me what the heck is going on here.

Thank you in advance!

1 Answer

It seems that in the collect_step function, traj is a single trajectory, not a batch. Therefore you need to expand its dimensions into a batch before adding it. Note that you can't just call tf.expand_dims(traj, 0), because a Trajectory is a nested structure; there's a helper for applying the expansion to every field of a nested structure.

def collect_step(environment, policy):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)

    # Give every field of the trajectory a leading batch dimension of 1,
    # so e.g. the observation goes from (4, 84, 84) to (1, 4, 84, 84).
    batch = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0), traj)

    # Add the batched trajectory to the replay buffer
    replay_buffer.add_batch(batch)
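
As a quick sanity check (a sketch reusing the env, random_policy, and replay_buffer from the question), the batched trajectory now carries the leading batch dimension that the buffer's batch_size=1 expects:

time_step = env.current_time_step()
action_step = random_policy.action(time_step)
next_time_step = env.step(action_step.action)
traj = trajectory.from_transition(time_step, action_step, next_time_step)

batch = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0), traj)
print(batch.observation.shape)  # (1, 4, 84, 84)
print(batch.action.shape)       # (1,)
replay_buffer.add_batch(batch)  # no more ResourceScatterUpdate error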