Is there a way to find all variables that a given operation (usually a loss) depends upon?
I would like to use this to then pass this collection into optimizer.minimize() or tf.gradients() using various set().intersection() combinations.
So far I have found op.op.inputs and tried a simple BFS on that, but I never chance upon Variable objects as returned by tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) or slim.get_variables().
There does seem to be a correspondence between the Tensor.op._id and Variables.op._id fields, but I'm not sure that's something I should rely upon.
Or maybe I shouldn't want to do this in the first place? I could of course construct my disjoint sets of variables meticulously while building my graph, but then it would be easy to miss something if I change the model.
The documentation for tf.Variable.op is not particularly clear, but it does refer to the crucial tf.Operation used in the implementation of a tf.Variable: any op that depends on a tf.Variable will be on a path from that operation. Since the tf.Operation object is hashable, you can use it as the key of a dict that maps tf.Operation objects to the corresponding tf.Variable object, and then perform the BFS as before:
import collections
import tensorflow as tf

# Map each variable's underlying op to the variable itself.
op_to_var = {var.op: var for var in tf.trainable_variables()}

starting_op = ...
dependent_vars = []

queue = collections.deque()
queue.append(starting_op)

visited = set([starting_op])

while queue:
    op = queue.popleft()
    try:
        dependent_vars.append(op_to_var[op])
    except KeyError:
        # `op` is not a variable, so search its inputs (if any).
        for op_input in op.inputs:
            if op_input.op not in visited:
                queue.append(op_input.op)
                visited.add(op_input.op)
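To see the traversal in isolation, here is a minimal sketch of the same BFS wrapped in a function and run on a toy graph, so it can be tried without TensorFlow installed. The names find_dependent_vars, FakeOp and FakeTensor are hypothetical stand-ins invented for this illustration. They only mimic the two attributes the search relies on: an op's inputs, and each input tensor's op.

```python
import collections


def find_dependent_vars(starting_op, op_to_var):
    """Search backwards from starting_op and collect the variables it reaches."""
    dependent_vars = []
    queue = collections.deque([starting_op])
    visited = {starting_op}
    while queue:
        op = queue.popleft()
        if op in op_to_var:
            # Reached a variable's op: record it and do not descend further.
            dependent_vars.append(op_to_var[op])
            continue
        # `op` is not a variable, so search its inputs (if any).
        for op_input in op.inputs:
            if op_input.op not in visited:
                visited.add(op_input.op)
                queue.append(op_input.op)
    return dependent_vars


# Hypothetical stand-ins for tf.Operation and tf.Tensor (illustration only).
class FakeOp:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)


class FakeTensor:
    def __init__(self, op):
        self.op = op


# Toy graph: loss depends on w (by two paths) and b; unused is never reached.
w, b, unused = FakeOp("w"), FakeOp("b"), FakeOp("unused")
matmul = FakeOp("matmul", [FakeTensor(w)])
add = FakeOp("add", [FakeTensor(matmul), FakeTensor(b)])
loss = FakeOp("loss", [FakeTensor(add), FakeTensor(w)])
op_to_var = {w: "var_w", b: "var_b", unused: "var_unused"}

found = find_dependent_vars(loss, op_to_var)
print(sorted(found))
```

In a TensorFlow 1.x graph, the same function would presumably be called as find_dependent_vars(loss.op, {v.op: v for v in tf.trainable_variables()}), and the resulting list could then be passed as the var_list argument of optimizer.minimize() or as the xs argument of tf.gradients(). The visited set ensures that a variable reachable by several paths is reported only once.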