Today I wrote some Python code that mapped a function over a sequence of financial time series data. For the first version, I used itertools.imap to apply a partial function, created with functools.partial, to the sequence. I needed a partial function with curried arguments because a couple of the inputs were created and fixed before the call to itertools.imap.
The program took 20 minutes to run because the dataset is rather large. Since it’s an embarrassingly parallel problem, I decided to give the multiprocessing package a whirl. I’ve used multiprocessing’s parallel map functions before, so I naively created a worker Pool with 4 processes and changed itertools.imap to pool.imap. Unfortunately, I was smacked with this error: “TypeError: type ‘partial’ takes at least one argument.”
It seems that the functools.partial object can’t be pickled, which multiprocessing needs to do in order to send the function to its worker processes. I worked around this by using itertools.izip combined with itertools.repeat, and by passing the pool a plain module-level function that unpacks each tuple. The itertools.repeat function mimics currying by repeating the fixed argument alongside every element of the sequence. I’ve pasted a simple example of this below. If anybody has a better way of doing this, please let me know.
import functools
import itertools
import multiprocessing


def uncurryDummyMultiply(t):
    return dummyMultiply(t[0], t[1])


def dummyMultiply(a, b):
    return a * b


def serialMap():
    mathFun = functools.partial(dummyMultiply, 10)
    total = 0
    for x in itertools.imap(mathFun, xrange(1000)):
        total += x
    return total


def parallelMap_Error():
    """Generates an Error!
    TypeError: type 'partial' takes at least one argument
    """
    mathFun = functools.partial(dummyMultiply, 10)
    pool = multiprocessing.Pool(processes=4)
    total = 0
    for x in pool.imap(mathFun, xrange(1000)):
        total += x
    return total


def parallelMap_NoError():
    """Parallel version of serialMap that doesn't generate a TypeError
    """
    pool = multiprocessing.Pool(processes=4)
    total = 0
    for x in pool.imap(uncurryDummyMultiply,
                       itertools.izip(itertools.repeat(10),
                                      xrange(1000)),
                       chunksize=50):
        total += x
    return total


if __name__ == "__main__":
    print "serialMap result: %d" % serialMap()
    print "parallelMap result: %d" % parallelMap_NoError()
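
The listing above uses Python 2 idioms (itertools.imap, itertools.izip, xrange, the print statement). For readers on Python 3, here is a minimal sketch of the same repeat-and-unpack workaround, assuming the built-in zip and range stand in for izip and xrange. The snake_case names and the is_picklable helper are hypothetical and not part of the original program. The helper only checks whether pickle.dumps accepts an object, which is a rough way to see in advance whether a callable can be handed to a Pool. On recent Python 3 releases a functools.partial wrapping a module-level function can generally be pickled, so the workaround may not be needed there, but that is worth checking on your own interpreter rather than taking on faith.

```python
import functools
import itertools
import multiprocessing
import pickle


def dummy_multiply(a, b):
    return a * b


def uncurry_dummy_multiply(args):
    """Unpack an (a, b) tuple and forward it to dummy_multiply.

    Being a plain module-level function, it can be sent to worker processes.
    """
    a, b = args
    return dummy_multiply(a, b)


def is_picklable(obj):
    """Hypothetical helper: True if pickle.dumps accepts obj, False otherwise."""
    try:
        pickle.dumps(obj)
    except Exception:
        return False
    return True


def parallel_sum(fixed, n, processes=4, chunksize=50):
    """Sum dummy_multiply(fixed, i) for i in range(n) using a worker pool."""
    # repeat() supplies the fixed argument; zip stops when range(n) is exhausted.
    pairs = zip(itertools.repeat(fixed), range(n))
    with multiprocessing.Pool(processes=processes) as pool:
        # Consume the results inside the with block, before the pool is torn down.
        return sum(pool.imap(uncurry_dummy_multiply, pairs, chunksize=chunksize))


if __name__ == "__main__":
    math_fun = functools.partial(dummy_multiply, 10)
    print("partial picklable on this interpreter:", is_picklable(math_fun))
    print("parallel result: %d" % parallel_sum(10, 1000))
```

If is_picklable reports True for the partial on your interpreter, passing the partial straight to pool.imap is probably the simpler option. The tuple-unpacking version remains useful when the fixed arguments differ per task or when the callable cannot be pickled.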