 

Python AST - finding particular named function calls

I'm trying to analyse some Python code to identify where specific functions are being called and which arguments are being passed.

For instance, suppose I have an ML script that contains the call model.fit(X_train, y_train). I want to find this line in the script, identify the object that fit is called on (i.e., model), and identify X_train and y_train as the arguments (along with any others).

I'm new to AST, so I don't know how to do this in an efficient way.

So far, I've been able to locate the line in question by iterating through the child nodes (using ast.iter_child_nodes) until I reach the ast.Call object, and then reading its func.attr, which returns "fit". I can also get "X_train" and "y_train" from its args.

The problem is that this only works if I already know where the call is in the tree, so it isn't particularly useful. I'd like the information to be found automatically.

I also haven't found a way to determine that model is the object fit is called on.
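For reference, the receiver is stored on the call's func node: model.fit is an ast.Attribute whose .value is the object and whose .attr is the method name. A minimal sketch that parses that single line and reads the fields directly (indexing tree.body[0] only works because the script is one statement, so this shows the structure rather than solving the search problem):

```python
import ast

# Parse a one-line script; the call sits inside an expression statement (ast.Expr).
tree = ast.parse("model.fit(X_train, y_train)")
call = tree.body[0].value            # the ast.Call node

print(ast.dump(call, indent=2))      # full node structure (indent= needs Python 3.9+)

# call.func is ast.Attribute: .value is the receiver, .attr is the method name.
receiver = call.func.value           # ast.Name for `model`
print(receiver.id)                   # model
print(call.func.attr)                # fit
print([arg.id for arg in call.args]) # ['X_train', 'y_train']
```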

radishapollo asked Sep 07 '25 18:09



1 Answer

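You can traverse the AST with ast.walk and look for ast.Call nodes whose func is an ast.Attribute named fit. The object the method is called on is that attribute's value. Note that this is a purely syntactic match (any method named fit is reported), and that ast.unparse, used below to turn nodes back into source text, requires Python 3.9 or later: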

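import ast

def fit_calls(tree):
    """Yield details of every <obj>.fit(...) call found anywhere in `tree`."""
    for node in ast.walk(tree):
        # Only calls whose callee is an attribute access named 'fit', e.g. model.fit(...)
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr == 'fit'):
            yield {'model_obj_str': ast.unparse(node.func.value),  # receiver as source text
                   'model_obj_ast': node.func.value,               # receiver node itself
                   'args': [ast.unparse(j) for j in node.args],
                   'kwargs': {j.arg: ast.unparse(j.value) for j in node.keywords}}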
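If you also need to know which line each call is on, every expression node carries lineno and col_offset. A sketch of a variant, using a hypothetical function name fit_call_locations, that records the position alongside the receiver and positional arguments; since ast.walk is breadth-first, the results are sorted back into source order:

```python
import ast

def fit_call_locations(source):
    """Return (line, column, receiver, args) for every `<obj>.fit(...)` call in `source`."""
    found = []
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr == 'fit'):
            found.append((node.lineno,                           # 1-based line where the call starts
                          node.col_offset,                       # 0-based column where the call starts
                          ast.unparse(node.func.value),          # receiver as source text (Python 3.9+)
                          [ast.unparse(a) for a in node.args]))  # positional arguments as source text
    # ast.walk visits nodes breadth-first, so sort by (line, column) to get source order
    return sorted(found, key=lambda t: (t[0], t[1]))

# `make_model` is a placeholder name; the script is only parsed, never executed.
script = """
model = make_model()
model.fit(X_train, y_train)
_ = model.fit(x, y, verbose=0)
"""
for location in fit_call_locations(script):
    print(location)
# (3, 0, 'model', ['X_train', 'y_train'])
# (4, 4, 'model', ['x', 'y'])
```

The line numbers count the leading newline of the triple-quoted string, which is why the first call is reported on line 3.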

Test samples:

#https://www.tensorflow.org/api_docs/python/tf/keras/Model
sample_1 = """
model = tf.keras.models.Model(
   inputs=inputs, outputs=[output_1, output_2])
model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
model.fit(x, (y, y))
model.metrics_names
"""
sample_2 = """
optimizer = tf.keras.optimizers.SGD()
model.compile(optimizer, loss='mse', steps_per_execution=10)
model.fit(dataset, epochs=2, steps_per_epoch=10)
"""
sample_3 = """
x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2))
_ = model.fit(x, y, verbose=0)
"""
#https://scikit-learn.org/stable/developers/develop.html
sample_4 = """
estimator = estimator.fit(data, targets)
"""
sample_5 = """
y_predicted = SVC(C=100).fit(X_train, y_train).predict(X_test)
"""

print([*fit_calls(ast.parse(sample_1))])
print([*fit_calls(ast.parse(sample_2))])
print([*fit_calls(ast.parse(sample_3))])
print([*fit_calls(ast.parse(sample_4))])
print([*fit_calls(ast.parse(sample_5))])

Output:

[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x1007737c0>, 
  'args': ['x', '(y, y)'], 'kwargs': {}}]
[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x1007731f0>, 
  'args': ['dataset'], 'kwargs': {'epochs': '2', 'steps_per_epoch': '10'}}]
[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x100773d00>, 
  'args': ['x', 'y'], 'kwargs': {'verbose': '0'}}]
[{'model_obj_str': 'estimator', 'model_obj_ast': <ast.Name object at 0x100773ca0>, 
  'args': ['data', 'targets'], 'kwargs': {}}]
[{'model_obj_str': 'SVC(C=100)', 'model_obj_ast': <ast.Call object at 0x100773130>, 
  'args': ['X_train', 'y_train'], 'kwargs': {}}]
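In the last sample the receiver is an ast.Call (SVC(C=100)) rather than a plain name. If what you want is the name the receiver expression starts from, you can follow the chain down through attributes, calls and subscripts. The helper below, root_receiver, is hypothetical and only one interpretation of "the object": for self.model.fit(...) it returns 'self', and it does not track assignments, so it cannot tell what model was bound to earlier in the script.

```python
import ast

def root_receiver(node):
    """Follow a receiver expression to the name it starts from, or return None.

    Hypothetical helper: for `self.model` it returns 'self', not 'model'.
    """
    while True:
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):     # self.model  -> self
            node = node.value
        elif isinstance(node, ast.Call):        # SVC(C=100)  -> SVC
            node = node.func
        elif isinstance(node, ast.Subscript):   # models[0]   -> models
            node = node.value
        else:                                   # e.g. a literal: no root name
            return None

samples = [
    "model.fit(X, y)",
    "self.model.fit(X, y)",
    "SVC(C=100).fit(X_train, y_train).predict(X_test)",
    "models[0].fit(X, y)",
]
for src in samples:
    for node in ast.walk(ast.parse(src)):
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr == 'fit'):
            print(src, '->', root_receiver(node.func.value))
# model.fit(X, y) -> model
# self.model.fit(X, y) -> self
# SVC(C=100).fit(X_train, y_train).predict(X_test) -> SVC
# models[0].fit(X, y) -> models
```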
 
Ajax1234 answered Sep 10 '25 04:09
