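I'm not sure this is the answer, and I'm not sure you're still looking for one, but...
So you have 100,000 Python objects. If those objects are regular data (datasets) rather than instances of some class, pass the data as JSON strings. Something like this: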
#!/usr/bin/env python
import json
from mpi4py import MPI
comm = MPI.COMM_WORLD
if comm.rank == 0:
    # One JSON-encoded parameter set per rank; scatter needs exactly comm.size items
    tasks = [
        json.dumps({'a': 1, 'x': 2, 'b': 3}),
        json.dumps({'a': 3, 'x': 1, 'b': 2}),
        json.dumps({'a': 2, 'x': 3, 'b': 1}),
    ]
else:
    tasks = None
# Scatter the parameter sets, one per rank
unit = comm.scatter(tasks, root=0)
p = json.loads(unit)
print("-" * 18)
print("-- I'm rank %d in %d size task" % (comm.rank, comm.size))
print("-- My parameters are: {}".format(p))
print("-" * 18)
comm.Barrier()
calc = p['a'] * p['x']**2 + p['b']
# Gather the results on the root rank
result = comm.gather(calc, root=0)
# Do something with the result (gather already returns None on non-root ranks)
if comm.rank == 0:
    print("the result is", result)
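To make "regular data" a bit more concrete, here is a small sketch of what survives a JSON round trip and what does not. It uses only the standard json module; the names params and mixed are made up for illustration and are not part of the script above.

```python
import json

# Plain data (dicts with string keys, lists, numbers, strings) comes back unchanged.
params = {'a': 1, 'x': 2, 'b': 3}
restored = json.loads(json.dumps(params))

# Tuples come back as lists, and non-string dict keys come back as strings,
# so such data does not compare equal after the round trip.
mixed = {'point': (1, 2), 7: 'seven'}
restored_mixed = json.loads(json.dumps(mixed))

print(restored)
print(restored_mixed)
```

If your datasets depend on tuples, integer keys, or class instances, you would have to convert them back by hand after json.loads.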
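Note that if you only have 8 nodes/cores, the tasks list must contain exactly 8 entries for each scatter, so you have to scatter and gather the 100,000 datasets batch by batch. If all your datasets are in a list ALLDATA, the code might look like this: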
def calc(a=0, x=0, b=0):
    return a * x**2 + b
if comm.rank == 0:
    collector = []
# Every rank runs this loop, so ALLDATA must be defined on all ranks.
# zip() yields batches of comm.size datasets and drops a trailing partial batch.
for xset in zip(*(iter(ALLDATA),) * comm.size):
    task = [json.dumps(s) for s in xset]
    comm.Barrier()
    unit = comm.scatter(task if comm.rank == 0 else None, root=0)
    p = json.loads(unit)
    res = json.dumps(calc(**p))
    totres = comm.gather(res, root=0)
    if comm.rank == 0:
        collector += [json.loads(x) for x in totres]
if comm.rank == 0:
    print("the result is", collector)
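One thing to watch with the zip(*(iter(ALLDATA),) * comm.size) idiom: it silently drops the last datasets when len(ALLDATA) is not a multiple of comm.size. A possible workaround, sketched here without MPI so it can be run on its own, is to pad the last batch with itertools.zip_longest. The helper name batches, the sample data, and the batch size of 4 are my own assumptions for illustration.

```python
from itertools import zip_longest

def batches(data, size, fill=None):
    """Group data into tuples of exactly `size` items, padding the last tuple with `fill`."""
    # The same iterator is repeated `size` times, so each tuple takes `size` consecutive items.
    return zip_longest(*(iter(data),) * size, fillvalue=fill)

# Made-up sample: 10 datasets, batch size 4 (standing in for comm.size)
ALLDATA = [{'a': i, 'x': 2, 'b': 1} for i in range(10)]

truncated = list(zip(*(iter(ALLDATA),) * 4))  # plain zip: the last 2 datasets are lost
padded = list(batches(ALLDATA, 4))            # zip_longest: last batch is padded with None

print(len(truncated), len(padded))
print(padded[-1])
```

In the MPI loop, a rank that receives a padded entry would get None back from json.loads and should skip calc, and rank 0 would then need to filter the None results out of totres before adding them to collector.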