Hack for functools.partial and multiprocessing

Today I wrote some Python code that mapped a function onto a sequence of financial time series data. For the first version, I used itertools.imap to apply a partial function, created with functools.partial, to the sequence. I needed a partial function with some arguments fixed in advance ("curried", loosely speaking) because a couple of the inputs were created and fixed before the call to itertools.imap.
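For readers who haven't used functools.partial: it returns a callable with some leading arguments already filled in. Here is a minimal sketch that reuses the post's dummyMultiply. The names times10 and times10_closure are illustrative only, and the sketch is written to run on Python 2.7 and 3.

```python
import functools


def dummyMultiply(a, b):
    return a * b


# partial() freezes the leading positional argument(s): here a = 10.
times10 = functools.partial(dummyMultiply, 10)


# The same thing written by hand as a closure (hypothetical name).
def times10_closure(b):
    return dummyMultiply(10, b)


# The remaining argument is supplied at call time, so times10(3)
# makes the call dummyMultiply(10, 3).
assert times10(3) == times10_closure(3) == 30

# The partial object remembers the wrapped function and the frozen
# arguments; that bundle is what has to be pickled to reach a worker.
assert times10.func is dummyMultiply
assert times10.args == (10,)
```

The closure and the partial behave the same when called. The difference only shows up when the callable has to be sent to another process.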

The program took 20 minutes to run because the dataset is rather large. Since it’s an embarrassingly parallel problem, I decided to give the multiprocessing package a whirl. I’ve used multiprocessing’s parallel map functions before, so I naively created a worker Pool with 4 processes and changed itertools.imap to pool.imap. Unfortunately, I was smacked with this error: “TypeError: type ‘partial’ takes at least one argument.”
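Pool has to pickle whatever function it is given to ship it to the worker processes. A quick way to diagnose this kind of error is to try that pickle round trip directly, without a Pool. The sketch below uses only the standard pickle module. can_pickle is a hypothetical helper of mine, not part of multiprocessing.

```python
import functools
import pickle


def dummyMultiply(a, b):
    return a * b


def can_pickle(obj):
    """Hypothetical helper: True if obj survives a pickle round trip.

    Pool sends the mapped function to its workers by pickling it, so a
    callable that fails here will generally fail in pool.imap as well.
    """
    try:
        pickle.loads(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))
    except Exception:
        return False
    return True


if __name__ == "__main__":
    # Module-level functions are pickled by reference (module + name).
    print(can_pickle(dummyMultiply))
    # Lambdas have no importable name, so they cannot be pickled.
    print(can_pickle(lambda b: dummyMultiply(10, b)))
    # Version dependent: this is the case the TypeError above is about.
    print(can_pickle(functools.partial(dummyMultiply, 10)))
```

The plain function should always pass and the lambda should always fail. As far as I know, the partial check passes on Python 2.7 and later, where partial objects gained pickle support.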

It seems that functools.partial objects can’t be pickled, at least on the Python version I’m using, and multiprocessing has to pickle the function to hand it to the worker processes. I worked around this by using itertools.izip combined with itertools.repeat. The itertools.repeat function mimics currying by repeating the fixed argument alongside every element of the sequence. I’ve pasted a simple example of this below, and I’ve also put it here. If anybody has a better way of doing this, please let me know.

import functools
import itertools
import multiprocessing

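def uncurryDummyMultiply(t):
    """Unpack an (a, b) tuple and call dummyMultiply(a, b).

    A plain module-level function, so it can be pickled and sent to workers.
    """
    return dummyMultiply(t[0], t[1])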

def dummyMultiply(a, b):
    return a * b

def serialMap():
    # Fix the first argument of dummyMultiply at 10: mathFun(x) == 10 * x
    mathFun = functools.partial(dummyMultiply, 10)
    total = 0
    for x in itertools.imap(mathFun, xrange(1000)):
        total += x
    return total

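def parallelMap_Error():
    """Generates an error where functools.partial objects can't be pickled:
    TypeError: type 'partial' takes at least one argument
    """
    mathFun = functools.partial(dummyMultiply, 10)
    pool = multiprocessing.Pool(processes=4)
    total = 0
    # pool.imap must pickle mathFun to send it to the worker processes
    for x in pool.imap(mathFun, xrange(1000)):
        total += x
    pool.close()
    pool.join()
    return total

def parallelMap_NoError():
    """Parallel version of serialMap that doesn't generate a TypeError.

    Instead of a partial, each work item is a (10, x) tuple built with izip
    and repeat; uncurryDummyMultiply unpacks it inside the worker.
    """
    pool = multiprocessing.Pool(processes=4)
    total = 0
    # chunksize=50 ships work items to the workers in batches of 50
    for x in pool.imap(uncurryDummyMultiply,
                       itertools.izip(itertools.repeat(10),
                                      xrange(1000)),
                       chunksize=50):
        total += x
    # Let the worker processes exit once all tasks are done
    pool.close()
    pool.join()
    return total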

if __name__ == "__main__":
    print "serialMap result:   %d" % serialMap()
    print "parallelMap result: %d" % parallelMap_NoError()
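On the "better way" question, one option is a small callable class defined at module level that stores the fixed argument itself. As far as I know, this also works on Pythons where partial can't be pickled: instances of a module-level class pickle their attributes, and a module-level function pickles by reference. This is a sketch written to run on Python 2.7 and 3. FixedFirstArg and parallelMap_Callable are hypothetical names, not part of any library.

```python
import multiprocessing
import pickle


def dummyMultiply(a, b):
    return a * b


class FixedFirstArg(object):
    """Hypothetical, picklable stand-in for functools.partial(func, first).

    Only works if func is itself picklable (e.g. a module-level function).
    """

    def __init__(self, func, first):
        self.func = func
        self.first = first

    def __call__(self, b):
        return self.func(self.first, b)


def parallelMap_Callable(processes=4):
    """Parallel version of serialMap using FixedFirstArg instead of partial."""
    mathFun = FixedFirstArg(dummyMultiply, 10)
    pool = multiprocessing.Pool(processes=processes)
    try:
        return sum(pool.imap(mathFun, range(1000), chunksize=50))
    finally:
        # Stop accepting work and wait for the worker processes to exit
        pool.close()
        pool.join()


# Sanity check: the callable survives the pickle round trip Pool relies on.
assert pickle.loads(pickle.dumps(FixedFirstArg(dummyMultiply, 10)))(3) == 30
```

Called from under an `if __name__ == "__main__":` guard, parallelMap_Callable() should return the same total as serialMap(), which is 4995000. On Python 3.3 and later there is also Pool.starmap, which unpacks the argument tuples itself. There, sum(pool.starmap(dummyMultiply, zip(itertools.repeat(10), range(1000)))) would do the same job without the uncurry wrapper.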