diff --git a/python/dgl/runtime/ir/program.py b/python/dgl/runtime/ir/program.py index fa17937fb050..0682f4d4ba06 100644 --- a/python/dgl/runtime/ir/program.py +++ b/python/dgl/runtime/ir/program.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from contextlib import contextmanager +import threading from .registry import IR_REGISTRY @@ -44,18 +45,30 @@ def pprint(self): for exe in self.execs: self.pprint_exe(exe) +class CurrentProgram(threading.local): + """Thread local storage to keep the reference of current thread's program""" + def __init__(self): + super(CurrentProgram, self).__init__() + self.prog = None + + def get_prog(self): + """Get program""" + return self.prog + + def set_prog(self, program): + """Set program""" + self.prog = program + # current program -CURRENT_PROG = None +CURRENT_PROG = CurrentProgram() def get_current_prog(): """Get the current program.""" - global CURRENT_PROG - return CURRENT_PROG + return CURRENT_PROG.get_prog() def set_current_prog(program): """Set the current program.""" - global CURRENT_PROG - CURRENT_PROG = program + CURRENT_PROG.set_prog(program) @contextmanager def prog():