Skip to content

Commit

Permalink
Add lock around protocol detection (#1816)
Browse files Browse the repository at this point in the history
  • Loading branch information
narrieta authored Mar 17, 2020
1 parent 2eec112 commit f050ec9
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 67 deletions.
115 changes: 63 additions & 52 deletions azurelinuxagent/common/protocol/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class ProtocolUtil(SingletonPerThread):
"""

def __init__(self):
self._lock = threading.RLock() # protects the files on disk created during protocol detection
self._protocol = None
self.endpoint = None
self.osutil = get_osutil()
Expand Down Expand Up @@ -162,29 +163,32 @@ def _get_tag_file_path(self):
TAG_FILE_NAME)

def get_wireserver_endpoint(self):
self._lock.acquire()
try:
if self.endpoint:
return self.endpoint

if self.endpoint:
return self.endpoint

file_path = self._get_wireserver_endpoint_file_path()
if os.path.isfile(file_path):
try:
self.endpoint = fileutil.read_file(file_path)
file_path = self._get_wireserver_endpoint_file_path()
if os.path.isfile(file_path):
try:
self.endpoint = fileutil.read_file(file_path)

if self.endpoint:
logger.info("WireServer endpoint {0} read from file", self.endpoint)
return self.endpoint
if self.endpoint:
logger.info("WireServer endpoint {0} read from file", self.endpoint)
return self.endpoint

logger.error("[GetWireserverEndpoint] Unexpected empty file {0}", file_path)
except (IOError, OSError) as e:
logger.error("[GetWireserverEndpoint] Error reading file {0}: {1}", file_path, str(e))
else:
logger.error("[GetWireserverEndpoint] Missing file {0}", file_path)
logger.error("[GetWireserverEndpoint] Unexpected empty file {0}", file_path)
except (IOError, OSError) as e:
logger.error("[GetWireserverEndpoint] Error reading file {0}: {1}", file_path, str(e))
else:
logger.error("[GetWireserverEndpoint] Missing file {0}", file_path)

self.endpoint = KNOWN_WIRESERVER_IP
logger.info("Using hardcoded Wireserver endpoint {0}", self.endpoint)
self.endpoint = KNOWN_WIRESERVER_IP
logger.info("Using hardcoded Wireserver endpoint {0}", self.endpoint)

return self.endpoint
return self.endpoint
finally:
self._lock.release()

def _set_wireserver_endpoint(self, endpoint):
try:
Expand Down Expand Up @@ -302,49 +306,56 @@ def clear_protocol(self):
"""
Cleanup previous saved protocol endpoint.
"""
logger.info("Clean protocol and wireserver endpoint")
self._clear_wireserver_endpoint()
self._protocol = None
protocol_file_path = self._get_protocol_file_path()
if not os.path.isfile(protocol_file_path):
return

self._lock.acquire()
try:
os.remove(protocol_file_path)
except (IOError, OSError) as e:
# Ignore file-not-found errors (since the file is being removed)
if e.errno == errno.ENOENT:
logger.info("Clean protocol and wireserver endpoint")
self._clear_wireserver_endpoint()
self._protocol = None
protocol_file_path = self._get_protocol_file_path()
if not os.path.isfile(protocol_file_path):
return
logger.error("Failed to clear protocol endpoint: {0}", e)

try:
os.remove(protocol_file_path)
except (IOError, OSError) as e:
# Ignore file-not-found errors (since the file is being removed)
if e.errno == errno.ENOENT:
return
logger.error("Failed to clear protocol endpoint: {0}", e)
finally:
self._lock.release()

def get_protocol(self, by_file=False):
"""
Detect protocol by endpoints, if by_file is True,
detect MetadataProtocol in priority.
:returns: protocol instance
"""

if self._protocol is not None:
return self._protocol

self._lock.acquire()
try:
self._protocol = self._get_protocol()
return self._protocol
except ProtocolNotFoundError:
pass
logger.info("Detect protocol endpoints")
protocols = [prots.WireProtocol]

if by_file:
tag_file_path = self._get_tag_file_path()
if os.path.isfile(tag_file_path):
protocols.insert(0, prots.MetadataProtocol)
else:
protocols.append(prots.MetadataProtocol)
protocol_name, protocol = self._detect_protocol(protocols)
if self._protocol is not None:
return self._protocol

IOErrorCounter.set_protocol_endpoint(endpoint=protocol.get_endpoint())
self._save_protocol(protocol_name)
try:
self._protocol = self._get_protocol()
return self._protocol
except ProtocolNotFoundError:
pass
logger.info("Detect protocol endpoints")
protocols = [prots.WireProtocol]

if by_file:
tag_file_path = self._get_tag_file_path()
if os.path.isfile(tag_file_path):
protocols.insert(0, prots.MetadataProtocol)
else:
protocols.append(prots.MetadataProtocol)
protocol_name, protocol = self._detect_protocol(protocols)

IOErrorCounter.set_protocol_endpoint(endpoint=protocol.get_endpoint())
self._save_protocol(protocol_name)

self._protocol = protocol
return self._protocol
self._protocol = protocol
return self._protocol
finally:
self._lock.release()
25 changes: 10 additions & 15 deletions tests/protocol/test_protocol_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,24 @@ def test_get_protocol_util_should_return_same_object_for_same_thread(self, _):
self.assertEqual(protocol_util1, protocol_util2)

def test_get_protocol_util_should_return_different_object_for_different_thread(self, _):
def get_util_obj(q, err):
protocol_util_instances = []
errors = []

def get_protocol_util_instance():
try:
q.put(get_protocol_util())
protocol_util_instances.append(get_protocol_util())
except Exception as e:
err.put(str(e))
errors.append(e)

queue = Queue()
errors = Queue()
t1 = Thread(target=get_util_obj, args=(queue, errors))
t2 = Thread(target=get_util_obj, args=(queue, errors))
t1 = Thread(target=get_protocol_util_instance)
t2 = Thread(target=get_protocol_util_instance)
t1.start()
t2.start()
t1.join()
t2.join()

errs = []
while not errors.empty():
errs.append(errors.get())
if len(errs) > 0:
raise Exception("Unable to fetch protocol_util. Errors: %s" % ' , '.join(errs))

self.assertEqual(2, queue.qsize()) # Assert that there are 2 objects in the queue
self.assertNotEqual(queue.get(), queue.get())
self.assertEqual(len(protocol_util_instances), 2, "Could not create the expected number of protocols. Errors: [{0}]".format(errors))
self.assertNotEqual(protocol_util_instances[0], protocol_util_instances[1], "The instances created by different threads should be different")

@patch("azurelinuxagent.common.protocol.util.MetadataProtocol")
@patch("azurelinuxagent.common.protocol.util.WireProtocol")
Expand Down

0 comments on commit f050ec9

Please sign in to comment.