-
Notifications
You must be signed in to change notification settings - Fork 35
/
plugin_manager.py
163 lines (148 loc) · 6.09 KB
/
plugin_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import asyncio
import importlib.machinery
import importlib.util
import inspect
import logging
import pathlib
import sys
from types import ModuleType

from base_plugin import BasePlugin
from configuration_manager import ConfigurationManager
from pparser import PacketParser
from utilities import detect_overrides
class PluginManager:
    """
    Discovers, loads, orders and drives StarryPy plugins.

    Plugin classes are found by scanning modules for subclasses of ``base``,
    instantiated in dependency order by :meth:`resolve_dependencies`, and
    packet actions are dispatched to the instances by :meth:`do`.
    """

    def __init__(self, config: ConfigurationManager, *, base=BasePlugin,
                 factory=None):
        """
        :param config: Configuration manager handed to every plugin class
            and to the packet parser.
        :param base: Class that plugin classes must subclass to be picked up.
        :param factory: Object assigned as the ``factory`` attribute of every
            discovered plugin class.
        """
        self.base = base
        self.config = config
        # Plugin file/package stem -> error message for plugins that failed
        # to load.
        self.failed = {}
        # Every plugin class discovered so far (not yet instantiated).
        self._seen_classes = set()
        # Plugin name -> plugin instance, filled by resolve_dependencies().
        self._plugins = {}
        self._activated_plugins = set()
        self._deactivated_plugins = set()
        self._resolved = False
        # Handler method names (``on_*``) overridden by at least one
        # activated plugin; lets do() skip parsing packets nobody handles.
        self._overrides = set()
        # Frozen snapshot of the activated plugins that _overrides was
        # computed for; None until get_overrides() first runs.
        self._override_cache = None
        self._packet_parser = PacketParser(self.config)
        self._factory = factory
        self.logger = logging.getLogger("starrypy.plugin_manager")

    def list_plugins(self):
        """Return the live mapping of plugin name -> plugin instance."""
        return self._plugins

    async def do(self, connection, action: str, packet: dict):
        """
        Calls an action on all loaded plugins.

        If ``on_<action>`` is not among the cached overrides (see
        :meth:`get_overrides`), the packet is not parsed and True is returned
        immediately. Otherwise the packet is parsed once. Then the
        ``on_<action>`` handler of every loaded plugin is awaited with
        ``(parsed_packet, connection)``.

        :param connection: Connection the packet belongs to; passed through
            to the handlers unchanged.
        :param action: Action name; the handler invoked is ``on_<action>``.
        :param packet: Raw packet, run through the packet parser before
            dispatch.
        :return: False if any handler returned a falsy value, otherwise True.
            If anything raises, the exception is logged and True is returned.
        """
        try:
            handler_name = "on_%s" % action
            if handler_name not in self._overrides:
                return True
            packet = await self._packet_parser.parse(packet)
            send_flag = True
            for plugin in self._plugins.values():
                handler = getattr(plugin, handler_name)
                # Every handler still runs after one of them vetoes.
                if not await handler(packet, connection):
                    send_flag = False
            return send_flag
        except Exception:
            # logger.exception already attaches the traceback.
            self.logger.exception("Exception encountered in plugin on action: "
                                  "%s", action)
            return True

    def load_from_path(self, plugin_path: pathlib.Path):
        """
        Load every plugin module (``*.py``) or package (directory) found
        directly inside ``plugin_path``.

        A plugin that raises SyntaxError or ImportError while loading is
        recorded in ``self.failed`` under its file stem and logged. Loading
        then continues with the remaining entries.
        """
        blacklist = {"__init__", "__pycache__"}
        for file in plugin_path.iterdir():
            if file.stem in blacklist:
                continue
            if not (file.suffix == ".py" or file.is_dir()):
                continue
            try:
                self.load_plugin(file)
            except (SyntaxError, ImportError) as e:
                self.failed[file.stem] = str(e)
                self.logger.error("Failed to load plugin %s: %s",
                                  file.stem, e)
            except FileNotFoundError:
                # e.g. a directory that has no __init__.py.
                self.logger.warning("File not found in plugin loader: %s",
                                    file)

    @staticmethod
    def _load_module(file_path: pathlib.Path):
        """
        Attempts to load a module, either from a straight python file or from
        a python package, by appending __init__.py to the end of the path if it
        is a directory.

        The module is registered in ``sys.modules`` as ``plugins.<stem>``.
        ``<stem>`` is the file name without ``.py``, or the package
        directory's name.

        :raises FileNotFoundError: if the file, or the package's
            ``__init__.py``, does not exist.
        """
        # Derive the name before descending into the package. Using the stem
        # of __init__.py would name every package plugin "plugins.__init__",
        # so they would overwrite each other in sys.modules.
        name = "plugins.%s" % file_path.stem
        if file_path.is_dir():
            file_path /= '__init__.py'
        if not file_path.exists():
            raise FileNotFoundError("{0} doesn't exist.".format(file_path))
        loader = importlib.machinery.SourceFileLoader(name, str(file_path))
        # Spec + exec_module replaces the deprecated Loader.load_module().
        # For __init__.py, the loader reports a package, so __path__ is set.
        spec = importlib.util.spec_from_file_location(name, str(file_path),
                                                      loader=loader)
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as load_module() did, so relative
        # imports inside a package plugin can find their parent package.
        sys.modules[name] = module
        try:
            loader.exec_module(module)
        except BaseException:
            sys.modules.pop(name, None)
            raise
        return module

    def load_plugin(self, plugin_path: pathlib.Path):
        """
        Load the module at ``plugin_path`` and record every plugin class it
        defines or imports in ``self._seen_classes``.

        Each discovered class gets this manager's factory as its ``factory``
        attribute. The configuration is saved afterwards.
        """
        module = self._load_module(plugin_path)
        for candidate in self.get_classes(module):
            candidate.factory = self._factory
            self._seen_classes.add(candidate)
        self.config.save_config()

    def get_classes(self, module: ModuleType):
        """
        Uses the inspect module to find all classes in a given module that
        are subclassed from `self.base`, but are not actually `self.base`.

        Side effect: each matching class has ``config`` (this manager's
        configuration) and ``logger`` (``starrypy.plugin.<name>``) set as
        class attributes. The class must therefore define ``name``.

        :return: List of matching classes, in ``inspect.getmembers``
            (name-sorted) order.
        """
        class_list = []
        for _, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, self.base) and obj is not self.base:
                obj.config = self.config
                obj.logger = logging.getLogger("starrypy.plugin.%s" %
                                               obj.name)
                class_list.append(obj)
        return class_list

    def load_plugins(self, plugins: list):
        """Load each plugin path in ``plugins`` via :meth:`load_plugin`."""
        for plugin in plugins:
            self.load_plugin(plugin)

    def resolve_dependencies(self):
        """
        Resolves dependencies from self._seen_classes through a very simple
        topological sort. Raises ImportError if there is an unresolvable
        dependency, otherwise it instantiates the class and puts it in
        self._plugins.

        Each class's ``depends`` entries are matched against other classes'
        ``name``. As soon as a dependency is instantiated, it is stored in
        the dependent class's ``plugins`` mapping. Dependencies are
        therefore in place before the dependent class is instantiated.

        :raises ImportError: If a dependency is missing or the dependencies
            form a cycle.
        """
        # NOTE(review): two classes sharing a ``name`` collapse into one
        # entry here; whichever set iteration visits last wins.
        deps = {cls.name: set(cls.depends) for cls in self._seen_classes}
        classes = {cls.name: cls for cls in self._seen_classes}
        while deps:
            ready = [name for name, pending in deps.items() if not pending]
            if not ready:
                raise ImportError("Unresolved dependencies found in: "
                                  "{}".format(deps))
            for name in ready:
                self._plugins[name] = classes[name]()
                del deps[name]
            loaded = set(self._plugins)
            for name, pending in deps.items():
                # Hand each newly available dependency to the class waiting
                # on it, then stop waiting on it.
                for dep in pending & loaded:
                    classes[name].plugins[dep] = self._plugins[dep]
                deps[name] = pending - loaded
        self._resolved = True

    async def get_overrides(self):
        """
        Return the set of handler names overridden by activated plugins.

        The result is cached against a frozen snapshot of the activated
        plugins. It is recomputed whenever that set has changed since the
        last call. The previous identity check never invalidated, because
        the set is mutated in place.
        """
        snapshot = frozenset(self._activated_plugins)
        if snapshot == self._override_cache:
            return self._overrides
        overrides = set()
        for plugin in snapshot:
            # Always compared against BasePlugin, whose handlers are the
            # defaults, even when a different ``base`` is configured.
            overrides.update(await detect_overrides(BasePlugin, plugin))
        self._overrides = overrides
        self._override_cache = snapshot
        return overrides

    async def activate_all(self):
        """Activate every instantiated plugin, then refresh the overrides."""
        self.logger.info("Activating plugins:")
        for plugin in self._plugins.values():
            self.logger.info(plugin.name)
            await plugin.activate()
            self._activated_plugins.add(plugin)
        await self.get_overrides()

    async def deactivate_all(self):
        """Deactivate every instantiated plugin, one at a time."""
        # NOTE(review): plugins stay in _activated_plugins, and
        # _deactivated_plugins is never filled. Cached overrides and do()
        # dispatch are therefore unaffected by deactivation; confirm this
        # is intended.
        for plugin in self._plugins.values():
            self.logger.info("Deactivating %s", plugin.name)
            await plugin.deactivate()