I'm trying to analyse some Python code to identify where specific functions are being called and which arguments are being passed.
For instance, suppose I have an ML script that contains a call such as model.fit(X_train,y_train). I want to find this line in the script, identify what object is being fit (i.e., model), and identify X_train and y_train as the arguments (as well as any others).
I'm new to AST, so I don't know how to do this in an efficient way.
So far, I've been able to locate the line in question by iterating through a list of child nodes (using ast.iter_child_nodes) until I arrive at the ast.Call object, and then reading its func.attr, which returns "fit". I can also get "X_train" and "y_train" with args.
The problem is that I have to know where it is in advance in order to do it this way, so it's not particularly useful. The idea would be for it to obtain the information I'm looking for automatically.
Additionally, I have not been able to find a way to determine that model is what is calling fit.
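You can traverse the tree with ast.walk and search for ast.Call nodes whose func is an ast.Attribute with attr equal to fit; the object the method is called on is then func.value: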
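import ast

def fit_calls(tree):
    # ast.walk visits every node in the tree, so nested and chained calls are found too
    for i in ast.walk(tree):
        # a method call such as model.fit(...) is a Call whose func is an Attribute
        if isinstance(i, ast.Call) and isinstance(i.func, ast.Attribute) and i.func.attr == 'fit':
            yield {'model_obj_str': ast.unparse(i.func.value),  # ast.unparse needs Python 3.9+
                   'model_obj_ast': i.func.value,
                   'args': [ast.unparse(j) for j in i.args],
                   'kwargs': {j.arg: ast.unparse(j.value) for j in i.keywords}}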
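To see why func.value is the object being fit, it can help to dump the node for a single call. The snippet below is only an illustrative sketch using the one-line example from the question; the variable names (call, receiver, method, arg_names) are mine, and the indent argument of ast.dump assumes Python 3.9 or later.

```python
import ast

# Parse the single statement from the question.
tree = ast.parse("model.fit(X_train, y_train)")

# The module body holds one Expr statement whose value is the Call node.
call = tree.body[0].value

# Pretty-print the node structure (indent= assumes Python 3.9+).
print(ast.dump(call, indent=2))

receiver = call.func.value              # Name node for `model`
method = call.func.attr                 # the string 'fit'
arg_names = [a.id for a in call.args]   # works here because both args are plain names
```

In the dump, the Call has a func field that is an Attribute; its value is the Name node for model and its attr is the string fit. The positional arguments sit in args and keyword arguments in keywords, which is the layout the generator above relies on.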
Test samples:
#https://www.tensorflow.org/api_docs/python/tf/keras/Model
sample_1 = """
model = tf.keras.models.Model(
inputs=inputs, outputs=[output_1, output_2])
model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
model.fit(x, (y, y))
model.metrics_names
"""
sample_2 = """
optimizer = tf.keras.optimizers.SGD()
model.compile(optimizer, loss='mse', steps_per_execution=10)
model.fit(dataset, epochs=2, steps_per_epoch=10)
"""
sample_3 = """
x = np.random.random((2, 3))
y = np.random.randint(0, 2, (2, 2))
_ = model.fit(x, y, verbose=0)
"""
#https://scikit-learn.org/stable/developers/develop.html
sample_4 = """
estimator = estimator.fit(data, targets)
"""
sample_5 = """
y_predicted = SVC(C=100).fit(X_train, y_train).predict(X_test)
"""
print([*fit_calls(ast.parse(sample_1))])
print([*fit_calls(ast.parse(sample_2))])
print([*fit_calls(ast.parse(sample_3))])
print([*fit_calls(ast.parse(sample_4))])
print([*fit_calls(ast.parse(sample_5))])
Output:
[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x1007737c0>,
'args': ['x', '(y, y)'], 'kwargs': {}}]
[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x1007731f0>,
'args': ['dataset'], 'kwargs': {'epochs': '2', 'steps_per_epoch': '10'}}]
[{'model_obj_str': 'model', 'model_obj_ast': <ast.Name object at 0x100773d00>,
'args': ['x', 'y'], 'kwargs': {'verbose': '0'}}]
[{'model_obj_str': 'estimator', 'model_obj_ast': <ast.Name object at 0x100773ca0>,
'args': ['data', 'targets'], 'kwargs': {}}]
[{'model_obj_str': 'SVC(C=100)', 'model_obj_ast': <ast.Call object at 0x100773130>,
'args': ['X_train', 'y_train'], 'kwargs': {}}]
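Since the question also asks for the line where the call occurs, the same idea can be extended to report line numbers and to accept any method name. This is a sketch rather than part of the answer above: method_calls and the sample script are hypothetical names introduced for illustration, and ast.unparse again assumes Python 3.9+.

```python
import ast

def method_calls(source, method_name):
    """Yield (lineno, receiver, args, kwargs) for each `<expr>.<method_name>(...)` call."""
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and node.func.attr == method_name):
            yield (node.lineno,                          # 1-based line where the call starts
                   ast.unparse(node.func.value),         # object the method is called on
                   [ast.unparse(a) for a in node.args],  # positional arguments as source text
                   # keyword arguments; a **mapping argument shows up with key None
                   {k.arg: ast.unparse(k.value) for k in node.keywords})

# Hypothetical script used only to exercise the function.
script = """import numpy as np
x = np.random.random((2, 3))
model.fit(x, y, verbose=0)
"""

results = list(method_calls(script, "fit"))
print(results)
```

Note that ast.walk makes no promise about visiting nodes in source order, so sort the results by line number if the order matters.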