Find all variables that a tensorflow op depends upon

Tags:

tensorflow

Is there a way to find all variables that a given operation (usually a loss) depends upon? I would like to use this to then pass this collection into optimizer.minimize() or tf.gradients() using various set().intersection() combinations.
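A minimal sketch of the kind of set combinations meant here, assuming a dependency search has already produced one set of variables per loss. The variable names and the two losses are hypothetical; plain strings stand in for the variables themselves. In TF 1.x, a list built this way from real variables could be passed as the var_list argument of optimizer.minimize() or as the xs argument of tf.gradients().

```python
# Hypothetical stand-ins: strings play the role of the model's variables.
trainable_vars = {"body/w", "head_a/w", "head_b/w", "stale/w"}

# Hypothetical results of a dependency search for two different losses.
loss_a_deps = {"body/w", "head_a/w"}
loss_b_deps = {"body/w", "head_b/w"}

# Trainable variables that only loss A depends upon.
only_a = (loss_a_deps - loss_b_deps) & trainable_vars
# Trainable variables that both losses depend upon.
shared = loss_a_deps & loss_b_deps & trainable_vars
# Trainable variables that neither loss reaches (often a model-building bug).
unused = trainable_vars - (loss_a_deps | loss_b_deps)

# sorted() gives a deterministic order, e.g. for a var_list argument.
print(sorted(only_a))
print(sorted(shared))
print(sorted(unused))
```

Deriving these sets from the graph, rather than maintaining them by hand, means they stay correct when the model changes.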

So far I have found op.op.inputs and tried a simple BFS on it, but I never come across the Variable objects returned by tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) or slim.get_variables().

There does seem to be a correspondence between the `Tensor.op._id` and `Variable.op._id` fields, but I'm not sure that's something I should rely upon.

Or maybe I shouldn't want to do this in the first place? I could, of course, meticulously construct my disjoint sets of variables while building the graph, but then it would be easy to miss something when I change the model.

black_puppydog asked Dec 05 '25 15:12



1 Answer

The documentation for tf.Variable.op is not particularly clear, but it does refer to the crucial tf.Operation used in the implementation of a tf.Variable: any op that depends on a tf.Variable has that operation somewhere among its transitive inputs, so a backward search from the op will reach it. Since tf.Operation objects are hashable, you can use them as the keys of a dict that maps each tf.Operation to the corresponding tf.Variable, and then perform the BFS as before:

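import collections

import tensorflow as tf  # TensorFlow 1.x graph-mode API

# Map each variable's underlying tf.Operation to the tf.Variable itself.
op_to_var = {var.op: var for var in tf.trainable_variables()}

starting_op = ...  # e.g. `loss.op` for a loss tensor `loss`
dependent_vars = []

# Breadth-first search backwards through the graph from `starting_op`.
queue = collections.deque([starting_op])
visited = {starting_op}

while queue:
    op = queue.popleft()
    try:
        dependent_vars.append(op_to_var[op])
    except KeyError:
        # `op` is not a variable, so search its inputs (if any).
        for op_input in op.inputs:
            if op_input.op not in visited:
                queue.append(op_input.op)
                visited.add(op_input.op)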
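The traversal logic can be checked without building a TensorFlow graph. Below is a minimal, framework-free sketch in which FakeOp, FakeTensor and find_dependent_vars are hypothetical names that only mimic the .inputs and .op attributes the answer relies on; strings stand in for tf.Variable objects.

```python
import collections


class FakeTensor:
    """Hypothetical stand-in for tf.Tensor: only exposes `.op`."""

    def __init__(self, op):
        self.op = op


class FakeOp:
    """Hypothetical stand-in for tf.Operation: a name plus input tensors."""

    def __init__(self, name, *input_ops):
        self.name = name
        self.inputs = [FakeTensor(op) for op in input_ops]


def find_dependent_vars(starting_op, op_to_var):
    """BFS backwards from starting_op, collecting ops that map to variables."""
    dependent_vars = []
    queue = collections.deque([starting_op])
    visited = {starting_op}
    while queue:
        op = queue.popleft()
        if op in op_to_var:
            dependent_vars.append(op_to_var[op])
            continue  # a variable op is a leaf for this search
        for op_input in op.inputs:
            if op_input.op not in visited:
                visited.add(op_input.op)
                queue.append(op_input.op)
    return dependent_vars


# Toy graph: loss = f(matmul(x, w) + b); `unused` is never reached.
w, b, unused, x = FakeOp("w"), FakeOp("b"), FakeOp("unused"), FakeOp("x")
matmul = FakeOp("matmul", x, w)
add = FakeOp("add", matmul, b)
loss = FakeOp("loss", add)

op_to_var = {w: "var_w", b: "var_b", unused: "var_unused"}
print(find_dependent_vars(loss, op_to_var))  # ['var_b', 'var_w']
```

The sketch uses a membership test instead of try/except KeyError; the behavior is the same. Like the answer's code, it follows only data inputs and ignores op.control_inputs.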
mrry answered Dec 07 '25 16:12
