Overriding module functions in Python

Overview

Today I will describe a way of overriding functions from imported modules in Python. This can be useful for profiling function calls, or for changing the behavior of a function you’re using without modifying its original source code.

We will be using a real-world example for this tutorial. Let’s say you want to count the number of queries run by SQLObject. This is simple if you always call the query function directly. But if the function is called deep within the module (as is the case with SQLObject), you have to get craftier.

The problem

Let’s say you have the following bit of code:

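import sqlobject

# Assumed setup so the snippet runs: an in-memory SQLite database.
sqlobject.sqlhub.processConnection = sqlobject.connectionForURI('sqlite:/:memory:')

class SomeTable(sqlobject.SQLObject):
    some_column = sqlobject.StringCol()

SomeTable.createTable()

for i in xrange(0, 5):
    s = SomeTable(some_column="Test %d" % i)
    print s.some_column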

We know that somewhere behind the scenes SQLObject is running real SQL queries, but we’re not calling that code directly, nor is it exposed to us. It would be ideal if SQLObject kept an internal query counter, but it doesn’t, so we have to be enterprising and find a way to do it ourselves.

A little digging through the SQLObject source leads us to an interesting method in sqlobject.dbconnection:

class DBAPI:
    def _runWithConnection(self, meth, *args):
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

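This method gets called each time SQLObject executes a query: it obtains a connection, runs meth on it, and releases the connection afterwards, even if meth raises an exception.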

The solution

So we’ve found the code, but now we need a way to count the calls to that method. One way to do this is to replace the method on the class with a wrapper that increments a counter and then calls the original. Here’s what that looks like:

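import sqlobject

QUERY_COUNT = 0

# Keep a reference to the original method so the wrapper can delegate to it.
original_function = sqlobject.dbconnection.DBAPI._runWithConnection

def wrapper(self, meth, *args):
    global QUERY_COUNT
    QUERY_COUNT += 1  # count the query, then run it as normal
    return original_function(self, meth, *args)

# Install the wrapper on the class; instances that inherit this method now use it.
sqlobject.dbconnection.DBAPI._runWithConnection = wrapper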

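This works for most cases, but what if subclasses of DBAPI define their own _runWithConnection methods? Attribute lookup finds a subclass’s own method before the one on DBAPI, so those overrides bypass our wrapper entirely. We need a way to replace them as well.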
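To see the problem concretely, here is a small sketch with two hypothetical classes, Base and Child (Python 3). Only Base is patched, and calling the method through a Child instance never reaches the wrapper, because Python resolves the name on Child first.

```python
calls = []

class Base:
    def _runWithConnection(self, meth, *args):
        return meth(None, *args)

class Child(Base):
    # Overrides the base method, e.g. to add driver-specific behaviour.
    def _runWithConnection(self, meth, *args):
        return meth(None, *args)

original = Base._runWithConnection

def wrapper(self, meth, *args):
    calls.append(type(self).__name__)  # record which instance went through us
    return original(self, meth, *args)

# Patch only the base class.
Base._runWithConnection = wrapper

Base()._runWithConnection(lambda conn: None)
Child()._runWithConnection(lambda conn: None)  # resolves to Child's own method

print(calls)  # → ['Base']
```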

You can probably already see where this is going. We need to find every subclass of DBAPI, and recursively all of their subclasses, then replace the _runWithConnection method in each of them with our wrapper function.
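For classes that inherit from ‘object’ (in Python 3, that is every class), type.__subclasses__() returns the direct subclasses, so the walk can be a simple recursion. The sketch below is one way to do it, not the wraptools implementation; all_subclasses, patch_everywhere, counting_wrapper and the Fake* classes are hypothetical names. It only replaces the method on classes that define it in their own __dict__, so a method that is merely inherited isn’t wrapped twice.

```python
# Hypothetical helpers and stand-in classes; not the wraptools implementation.

def all_subclasses(cls, seen=None):
    """Return every subclass of cls, recursively (new-style classes only)."""
    if seen is None:
        seen = []
    for sub in cls.__subclasses__():
        if sub not in seen:  # guards against diamond-shaped hierarchies
            seen.append(sub)
            all_subclasses(sub, seen)
    return seen

def patch_everywhere(base, name, make_wrapper):
    """Replace `name` on base and on every subclass that defines its own copy."""
    for cls in [base] + all_subclasses(base):
        if name in cls.__dict__:  # skip classes that merely inherit it
            setattr(cls, name, make_wrapper(cls.__dict__[name]))

QUERY_COUNT = 0

def counting_wrapper(original):
    # Each class gets its own closure around its own original method.
    def wrapper(self, meth, *args):
        global QUERY_COUNT
        QUERY_COUNT += 1
        return original(self, meth, *args)
    return wrapper

class FakeDBAPI:
    def _runWithConnection(self, meth, *args):
        return meth(None, *args)

class FakeDriver(FakeDBAPI):
    # Overrides the base method, so patching FakeDBAPI alone would miss it.
    def _runWithConnection(self, meth, *args):
        return meth(None, *args)

class FakeSubDriver(FakeDriver):
    pass  # inherits FakeDriver's (wrapped) method

patch_everywhere(FakeDBAPI, "_runWithConnection", counting_wrapper)

for cls in (FakeDBAPI, FakeDriver, FakeSubDriver):
    cls()._runWithConnection(lambda conn: None)

print(QUERY_COUNT)  # → 3
```

One caveat of this approach: if an override calls the base method through super(), a single query passes through two wrappers and is counted twice.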

The roadblock

This is relatively trivial with classes that inherit from ‘object’ (so-called ‘new-style’ classes, introduced in Python 2.2), because they provide a __subclasses__() method. Unfortunately, DBAPI is not one of them! It is an old-style class, which has no such method. To make a long story short, things quickly get more complicated than they need to be. Wouldn’t it be nice if someone had written a module that could do this for us?
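Without __subclasses__() there is no registry to walk. One possible workaround, sketched here as an assumption and not necessarily what wraptools does, is to search the namespaces where classes live and test each class with issubclass(), which also works for old-style classes. The hypothetical find_subclasses below takes the namespaces to search as an argument; in real code you might pass something like [vars(m) for m in sys.modules.values() if m is not None]. The sketch runs on Python 3.

```python
import inspect

def find_subclasses(base, namespaces):
    """Collect classes in the given namespaces that subclass base.

    Unlike type.__subclasses__(), this doesn't rely on a registry, so the
    same idea works for old-style classes. It only finds classes that are
    reachable as top-level names in one of the namespaces.
    """
    found = []
    for namespace in namespaces:
        for obj in list(namespace.values()):
            if (inspect.isclass(obj) and obj is not base
                    and issubclass(obj, base) and obj not in found):
                found.append(obj)
    return found

class Base:
    pass

class Child(Base):
    pass

class GrandChild(Child):
    pass

class Unrelated:
    pass

# Here we search one explicit namespace; real code would scan loaded modules.
demo_namespace = {"Base": Base, "Child": Child,
                  "GrandChild": GrandChild, "Unrelated": Unrelated}
print(sorted(c.__name__ for c in find_subclasses(Base, [demo_namespace])))
# → ['Child', 'GrandChild']
```

The trade-off is coverage: classes created inside functions, nested in other classes, or not bound to any module-level name are invisible to this search.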

Introducing wraptools

You may be pleased to know that someone has in fact created such a module. That person happens to be me, which is why it’s being plugged here 🙂

Let’s go back to the original solution and modify it slightly to use wraptools.

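import sqlobject
from wraptools import wraps

QUERY_COUNT = 0

# Pass the method to replace; wraptools hands the original to our wrapper
# as its first argument.
@wraps(sqlobject.dbconnection.DBAPI._runWithConnection)
def wrapper(original_function, self, meth, *args):
    global QUERY_COUNT
    QUERY_COUNT += 1
    return original_function(self, meth, *args)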

So in this example, @wraps does all of the dirty work of replacing the function everywhere it is defined, including in subclasses, via the sweet syntactic sugar of a decorator. The original function and the arguments it was called with get passed along to your wrapper, allowing you to munge the input or output to your liking.

Putting it all together

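import sqlobject
from wraptools import wraps

# Assumed setup so the example runs: an in-memory SQLite database.
sqlobject.sqlhub.processConnection = sqlobject.connectionForURI('sqlite:/:memory:')

class SomeTable(sqlobject.SQLObject):
    some_column = sqlobject.StringCol()

# Create the table before wrapping, so its query isn't counted.
SomeTable.createTable()

QUERY_COUNT = 0

@wraps(sqlobject.dbconnection.DBAPI._runWithConnection)
def wrapper(original_function, self, meth, *args):
    global QUERY_COUNT
    QUERY_COUNT += 1
    return original_function(self, meth, *args)

for i in xrange(0, 5):
    s = SomeTable(some_column="Test %d" % i)

print "Total queries:", QUERY_COUNT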

And there you have it: an elegant, simple way of profiling SQLObject queries! The same idea can be extended to profile and override functions in SQLAlchemy and many other Python modules.

You can access the module documentation here.

The damn code

If you’re psyched, you can grab the source code from my code repository.

One Reply to “Overriding module functions in Python”

  1. Hi Phillip,

    I am totally Psyched by your module.

    I am testing a Django application, and it’s quite complex.

    Your module has saved me an immense amount of time as I can replace functions inside of my views with different return values to test the code that handles the return values.

    Thank you so much!

    Regards

    Mark
    http://twitter.com/mark_ellul
