eeb9de3f84
[SVN r47490]
132 lines
3.7 KiB
Python
132 lines
3.7 KiB
Python
# (C) Copyright 2007
|
|
# Andreas Kloeckner <inform -at- tiker.net>
|
|
#
|
|
# Use, modification and distribution is subject to the Boost Software
|
|
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
|
|
# http://www.boost.org/LICENSE_1_0.txt)
|
|
#
|
|
# Authors: Andreas Kloeckner
|
|
|
|
|
|
|
|
|
|
import boost.mpi as mpi
|
|
import random
|
|
import sys
|
|
|
|
# Number of ranks a history visits before it is routed back to rank 0.
MAX_GENERATIONS = 20

# Message tags used on the wire; numbered consecutively from 0.
(
    TAG_DEBUG,            # 0: free-form debug text printed by rank 0
    TAG_DATA,             # 1: a history (list of visited ranks)
    TAG_TERMINATE,        # 2: shutdown request / acknowledgement
    TAG_PROGRESS_REPORT,  # 3: history snapshot sent to rank 0 on every hop
) = range(4)
|
|
|
|
|
|
|
|
|
|
class TagGroupListener:
    """Listen for messages carrying any one of a fixed set of tags.

    One non-blocking receive is kept outstanding per tag; ``wait`` blocks
    until any one of them completes.

    This is contrived: typically you could just listen for
    mpi.any_tag and filter.
    """

    def __init__(self, comm, tags):
        """Store the communicator and the tags to listen for.

        comm -- communicator object; ``wait`` calls ``comm.irecv(tag=...)``.
        tags -- iterable of message tags; it is iterated on every call to
                ``wait``, so it must be re-iterable (e.g. a list).
        """
        self.tags = tags
        self.comm = comm
        # Maps tag -> the receive request currently outstanding for that tag.
        self.active_requests = {}

    def wait(self):
        """Block until a message with one of the tags has been received.

        Returns a ``(status, data)`` tuple for the receive that completed.
        """
        # Post a receive for every tag that does not have one pending; on the
        # first call that is all of them, afterwards only the consumed one.
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)

        requests = mpi.RequestList(self.active_requests.values())
        data, status, _index = mpi.wait_any(requests)

        # This tag's receive has been consumed; the next wait() re-posts it.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget about it.

        The cancelled requests are not waited on.
        """
        # values() instead of the Python-2-only itervalues(): behaves the
        # same here and also works on Python 3.
        for request in self.active_requests.values():
            request.cancel()
        self.active_requests = {}
|
|
|
|
|
|
|
|
def rank0():
    """Run the master side of the test on rank 0.

    Sends ``(mpi.size - 1) * 15`` empty histories to randomly chosen
    non-zero ranks, then processes incoming messages by tag until every
    tallied progress count equals the number of histories sent and every
    other rank has acknowledged termination. Prints "OK" at the end.
    """
    worker_count = mpi.size - 1
    sent_histories = worker_count * 15
    print("sending %d packets on their way" % sent_histories)

    # Kick off all histories with non-blocking sends, then wait for them.
    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        target = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(target, TAG_DATA, []))
    mpi.wait_all(send_reqs)

    completed_histories = []
    progress_reports = {}  # history length -> number of reports received
    dead_kids = []         # ranks that have acknowledged TAG_TERMINATE

    tgl = TagGroupListener(
        mpi.world,
        [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        # Every tallied history length must have been reported once per
        # history, and every other rank must have signed off.
        if any(count != sent_histories
               for count in progress_reports.values()):
            return False
        return len(dead_kids) == worker_count

    while True:
        status, data = tgl.wait()
        tag = status.tag

        if tag == TAG_DATA:
            # A history came back to rank 0.
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif tag == TAG_PROGRESS_REPORT:
            generation = len(data)
            progress_reports[generation] = \
                progress_reports.get(generation, 0) + 1
        elif tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (tag, status.source))

        if is_complete():
            break

    print("OK")
|
|
|
|
def comm_rank():
    """Run the relay loop on a non-zero rank.

    Blocks on incoming messages. For each history (TAG_DATA) it sends a
    progress report to rank 0, appends its own rank to the history and
    forwards it: to rank 0 once the history holds MAX_GENERATIONS entries,
    otherwise to a randomly chosen non-zero rank. On TAG_TERMINATE it
    acknowledges to rank 0 and returns. Any other tag is reported on stdout.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)

        if status.tag == TAG_DATA:
            # Report before appending, so the reported length is the number
            # of hops completed before this rank handled the history.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)

            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            # Acknowledge so rank 0 can count this rank as finished.
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d"
                  % (mpi.rank, status.tag, status.source))
|
|
|
|
|
|
def main():
    """Run the master loop on rank 0 and the relay loop on all other ranks.

    The program passes messages consisting of lists of visited ranks between
    randomly chosen ranks; after MAX_GENERATIONS hops each list is returned
    to rank 0.
    """
    entry_point = rank0 if mpi.rank == 0 else comm_rank
    entry_point()


if __name__ == "__main__":
    main()
|