Say I've added a few event listeners to a model, and I want to get a list of all of these added events for the model so I can verify their existence with assertions during testing. Is there a way to do so?
I'm aware of SQLAlchemy's `inspect`, which I currently use to assert the presence of columns and relationships. But is there a way to obtain the list of custom event listeners through `inspect` as well? If not, is there another way of doing so? If possible, I'd like to obtain only the listeners that have been added to the models explicitly, not those that are present by default.
Example of how I expect to retrieve event listeners:
```python
def test_schema(self):
    # sanity checks
    # this will raise a flag in the event the schema is modified, so we know to update the appropriate tests
    assert tuple(inspect(MyModel).columns.keys()) == (
        "id", "module", "slug", "display_name"
    )
    assert tuple(inspect(MyModel).relationships.keys()) == ("accounts", "reports", "jobs")
    assert tuple(inspect(MyModel).events) == (
        "{event_function_name}_{trigger_action}",
        "{notify_manager_of_billing_changes}_{after_update}"
    )
```
Given a listener registered like this:

```python
from sqlalchemy import event


def notify_manager_of_billing_changes(mapper, connection, model_instance):
    print(model_instance.billing_address)


event.listen(MyModel, "after_update", notify_manager_of_billing_changes, retval=False)
```
The public API for this kind of check is `event.contains()`:

```python
assert event.contains(MyModel, "after_update", notify_manager_of_billing_changes)
```
However, SQLAlchemy doesn't track the function's name, only its `id`¹ and a `wrap` function².

¹ As in `id(notify_manager_of_billing_changes)`.
² Defined without using `functools.wraps`, so it doesn't carry the original function's `__name__`.
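To make the footnotes concrete, here is a minimal stdlib-only sketch of a registry that, like the behaviour described above, keys listeners on `id(fn)` and stores a closure that wraps `fn` without `functools.wraps`. The names `listen`, `contains`, and `_registry` are hypothetical and only mimic the idea; they are not SQLAlchemy's internals.

```python
import functools

# Hypothetical mini-registry: (event name, id(fn)) -> wrapper function.
_registry = {}


def listen(event_name, fn):
    # The wrapper is defined *without* functools.wraps, so it keeps its own name.
    def wrap(*args, **kwargs):
        return fn(*args, **kwargs)

    _registry[(event_name, id(fn))] = wrap


def contains(event_name, fn):
    # Lookup is by identity only; the function's name is never stored.
    return (event_name, id(fn)) in _registry


def notify_manager_of_billing_changes(mapper, connection, target):
    print(target)


listen("after_update", notify_manager_of_billing_changes)

wrapper = _registry[("after_update", id(notify_manager_of_billing_changes))]
print(contains("after_update", notify_manager_of_billing_changes))  # True
print(wrapper.__name__)  # wrap

# Had functools.wraps been used, the original name would have been copied over.
named = functools.wraps(notify_manager_of_billing_changes)(lambda *a, **k: None)
print(named.__name__)  # notify_manager_of_billing_changes
```

So a membership check by identity works, but listing names from the stored wrappers only ever yields `wrap`.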
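With the help of `call_function_get_frame` from the answer to "How can I get the values of the locals of a function", plus an added `except IndexError:` clause, we can get a reference to `fn` from the `wrap` function. Calling `wrap()` with no arguments makes it raise `IndexError` when it tries to look up its `target` argument, so the listener itself never runs, but by then the trace hook has already captured the frame, whose locals include `fn`.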
```python
import sys

from sqlalchemy.orm import Mapper


def call_function_get_frame(func, *args, **kwargs):
    """
    Calls the function *func* with the specified arguments and keyword
    arguments and snatches its local frame before it actually executes.
    """
    frame = None
    trace = sys.gettrace()

    def snatch_locals(_frame, name, arg):
        nonlocal frame
        # Grab the first new frame (the call to func), then restore the old tracer.
        if frame is None and name == 'call':
            frame = _frame
            sys.settrace(trace)
        return trace

    sys.settrace(snatch_locals)
    try:
        result = func(*args, **kwargs)
    except IndexError:  # Added: wrap() called without arguments fails here
        result = None  # Added
    finally:
        sys.settrace(trace)
    return frame, result


def get_events(mapper):
    events = []
    dispatch = mapper.dispatch
    # Note: _event_names is a private attribute and may change between releases.
    for event_name in dispatch._event_names:
        listeners = getattr(dispatch, event_name).listeners
        for wrap in listeners:
            frame, result = call_function_get_frame(wrap)
            # The wrapper's frame exposes the original listener as the local `fn`.
            events.append(f"{{{frame.f_locals['fn'].__name__}}}_{{{event_name}}}")
    return events


# Patch Mapper so that inspect(MyModel).events works as in the question.
Mapper.events = property(get_events)
```
Usage, as desired in the question:
```python
assert tuple(inspect(MyModel).events) == (
    # "{event_function_name}_{trigger_action}",
    "{notify_manager_of_billing_changes}_{after_update}",
)
```
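A possibly simpler variant, sketched here, reads `fn` straight out of the wrapper's closure with the standard library's `inspect.getclosurevars`, so nothing has to be called or traced. It still relies on an implementation detail, namely that the wrapper closes over a variable named `fn`. `unwrap_listener`, `describe_listeners`, and `make_sqlalchemy_like_wrapper` are hypothetical helpers; the last one is only a stand-in shaped like the wrapper described above so the sketch runs without SQLAlchemy.

```python
import inspect


def unwrap_listener(listener):
    """Return the user function behind a listener, assuming the wrapper
    closes over a free variable named ``fn`` (an implementation detail)."""
    if inspect.isfunction(listener):
        nonlocals = inspect.getclosurevars(listener).nonlocals
        if "fn" in nonlocals:
            return nonlocals["fn"]
    # Not wrapped (or wrapped differently): treat it as the function itself.
    return listener


def describe_listeners(event_name, listeners):
    # Format entries the way the question asks for: "{name}_{event}".
    return [
        f"{{{unwrap_listener(listener).__name__}}}_{{{event_name}}}"
        for listener in listeners
    ]


# Hypothetical stand-in with the same closure shape as the wrapper above.
def make_sqlalchemy_like_wrapper(fn):
    def wrap(*arg, **kw):
        fn(*arg, **kw)

    return wrap


def notify_manager_of_billing_changes(mapper, connection, target):
    print(target)


listeners = [make_sqlalchemy_like_wrapper(notify_manager_of_billing_changes)]
print(describe_listeners("after_update", listeners))
# ['{notify_manager_of_billing_changes}_{after_update}']
```

Against a real mapper, `listeners` would be `getattr(inspect(MyModel).dispatch, "after_update").listeners`, which goes through the same SQLAlchemy internals as `get_events`. Unlike the tracing approach, this doesn't depend on the wrapper failing early with `IndexError`.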