diff --git a/Dockerfile b/Dockerfile index 78855f3a8a..8043c09f78 100644 --- a/Dockerfile +++ b/Dockerfile @@ -34,9 +34,10 @@ RUN bash -c "source $HOME/.nvm/nvm.sh && \ # install pm2 npm install --location=global pm2" -RUN mkdir -p /root/.bittensor/bittensor -RUN cd ~/.bittensor/bittensor && \ - python3 -m pip install bittensor +RUN mkdir -p /root/.bittensor/ +RUN cd /root/.bittensor/ && \ + git clone https://github.com/opentensor/bittensor.git bittensor && \ + python3 -m pip install -e bittensor # Increase ulimit to 1,000,000 RUN prlimit --pid=$PPID --nofile=1000000 diff --git a/bittensor/_axon/__init__.py b/bittensor/_axon/__init__.py index ea422d0a0a..b2f2d30a22 100644 --- a/bittensor/_axon/__init__.py +++ b/bittensor/_axon/__init__.py @@ -24,7 +24,7 @@ import inspect import time from concurrent import futures -from typing import List, Callable, Optional +from typing import Dict, List, Callable, Optional, Tuple, Union from bittensor._threadpool import prioritythreadpool import torch @@ -339,101 +339,114 @@ def check_forward_callback( forward_callback:Callable, synapses:list = []): forward_callback([sample_input], synapses, hotkey='') class AuthInterceptor(grpc.ServerInterceptor): - """ Creates a new server interceptor that authenticates incoming messages from passed arguments. - """ - def __init__(self, key:str = 'Bittensor',blacklist:List = []): - r""" Creates a new server interceptor that authenticates incoming messages from passed arguments. + """Creates a new server interceptor that authenticates incoming messages from passed arguments.""" + + def __init__(self, key: str = "Bittensor", blacklist: List = []): + r"""Creates a new server interceptor that authenticates incoming messages from passed arguments. Args: key (str, `optional`): - key for authentication header in the metadata (default= Bittensor) - black_list (Fucntion, `optional`): + key for authentication header in the metadata (default = Bittensor) + black_list (Function, `optional`): black list function that prevents certain pubkeys from sending messages """ super().__init__() - self._valid_metadata = ('rpc-auth-header', key) - self.nounce_dic = {} - self.message = 'Invalid key' + self.auth_header_value = key + self.nonces = {} self.blacklist = blacklist - def deny(_, context): - context.abort(grpc.StatusCode.UNAUTHENTICATED, self.message) - self._deny = grpc.unary_unary_rpc_method_handler(deny) - - def intercept_service(self, continuation, handler_call_details): - r""" Authentication between bittensor nodes. Intercepts messages and checks them - """ - meta = handler_call_details.invocation_metadata + def parse_legacy_signature( + self, signature: str + ) -> Union[Tuple[int, str, str, str], None]: + r"""Attempts to parse a signature using the legacy format, using `bitxx` as a separator""" + parts = signature.split("bitxx") + if len(parts) < 4: + return None + try: + nonce = int(parts[0]) + parts = parts[1:] + except ValueError: + return None + receptor_uuid, parts = parts[-1], parts[:-1] + message, parts = parts[-1], parts[:-1] + pubkey = "".join(parts) + return (nonce, pubkey, message, receptor_uuid) + + def parse_signature(self, metadata: Dict[str, str]) -> Tuple[int, str, str, str]: + r"""Attempts to parse a signature from the metadata""" + signature = metadata.get("bittensor-signature") + if signature is None: + raise Exception("Request signature missing") + parts = self.parse_legacy_signature(signature) + if parts is not None: + return parts + raise Exception("Unknown signature format") + + def check_signature( + self, nonce: int, pubkey: str, signature: str, receptor_uuid: str + ): + r"""verification of signature in metadata. Uses the pubkey and nonce""" + keypair = Keypair(ss58_address=pubkey) + # Build the expected message which was used to build the signature. + message = f"{nonce}{pubkey}{receptor_uuid}" + # Build the key which uniquely identifies the endpoint that has signed + # the message. + endpoint_key = f"{pubkey}:{receptor_uuid}" + + if endpoint_key in self.nonces.keys(): + previous_nonce = self.nonces[endpoint_key] + # Nonces must be strictly monotonic over time. + if nonce - previous_nonce <= -10: + raise Exception("Nonce is too small") + if not keypair.verify(message, signature): + raise Exception("Signature mismatch") + self.nonces[endpoint_key] = nonce + return + + if not keypair.verify(message, signature): + raise Exception("Signature mismatch") + self.nonces[endpoint_key] = nonce + + def version_checking(self, metadata: Dict[str, str]): + r"""Checks the header and version in the metadata""" + provided_value = metadata.get("rpc-auth-header") + if provided_value is None or provided_value != self.auth_header_value: + raise Exception("Unexpected caller metadata") + + def black_list_checking(self, pubkey: str, method: str): + r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey""" + if self.blacklist == None: + return - try: - #version checking - self.version_checking(meta) + request_type = { + "/Bittensor/Forward": bittensor.proto.RequestType.FORWARD, + "/Bittensor/Backward": bittensor.proto.RequestType.BACKWARD, + }.get(method) + if request_type is None: + raise Exception("Unknown request type") - #signature checking - self.signature_checking(meta) + if self.blacklist(pubkey, request_type): + raise Exception("Request type is blacklisted") - #blacklist checking - self.black_list_checking(meta) + def intercept_service(self, continuation, handler_call_details): + r"""Authentication between bittensor nodes. Intercepts messages and checks them""" + method = handler_call_details.method + metadata = dict(handler_call_details.invocation_metadata) - return continuation(handler_call_details) + try: + # version checking + self.version_checking(metadata) - except Exception as e: - self.message = str(e) - return self._deny + (nonce, pubkey, signature, receptor_uuid) = self.parse_signature(metadata) - def vertification(self,meta): - r"""vertification of signature in metadata. Uses the pubkey and nounce - """ - variable_length_messages = meta[1].value.split('bitxx') - nounce = int(variable_length_messages[0]) - pubkey = variable_length_messages[1] - message = variable_length_messages[2] - unique_receptor_uid = variable_length_messages[3] - _keypair = Keypair(ss58_address=pubkey) - - # Unique key that specifies the endpoint. - endpoint_key = str(pubkey) + str(unique_receptor_uid) - - #checking the time of creation, compared to previous messages - if endpoint_key in self.nounce_dic.keys(): - prev_data_time = self.nounce_dic[ endpoint_key ] - if nounce - prev_data_time > -10: - self.nounce_dic[ endpoint_key ] = nounce - - #decrypting the message and verify that message is correct - verification = _keypair.verify( str(nounce) + str(pubkey) + str(unique_receptor_uid), message) - else: - verification = False - else: - self.nounce_dic[ endpoint_key ] = nounce - verification = _keypair.verify( str( nounce ) + str(pubkey) + str(unique_receptor_uid), message) + # signature checking + self.check_signature(nonce, pubkey, signature, receptor_uuid) - return verification + # blacklist checking + self.black_list_checking(pubkey, method) - def signature_checking(self,meta): - r""" Calls the vertification of the signature and raises an error if failed - """ - if self.vertification(meta): - pass - else: - raise Exception('Incorrect Signature') - - def version_checking(self,meta): - r""" Checks the header and version in the metadata - """ - if meta[0] == self._valid_metadata: - pass - else: - raise Exception('Incorrect Metadata format') + return continuation(handler_call_details) - def black_list_checking(self,meta): - r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey - """ - variable_length_messages = meta[1].value.split('bitxx') - pubkey = variable_length_messages[1] - - if self.blacklist == None: - pass - elif self.blacklist(pubkey,int(meta[3].value)): - raise Exception('Black listed') - else: - pass + except Exception as e: + message = str(e) + abort = lambda _, ctx: ctx.abort(grpc.StatusCode.UNAUTHENTICATED, message) + return grpc.unary_unary_rpc_method_handler(abort) diff --git a/bittensor/_receptor/receptor_impl.py b/bittensor/_receptor/receptor_impl.py index d064fb8ef7..821a3f2bdf 100644 --- a/bittensor/_receptor/receptor_impl.py +++ b/bittensor/_receptor/receptor_impl.py @@ -664,7 +664,7 @@ def finalize_stats_and_logs(): ('rpc-auth-header','Bittensor'), ('bittensor-signature',self.sign()), ('bittensor-version',str(bittensor.__version_as_int__)), - ('request_type', str(bittensor.proto.RequestType.FORWARD)), + ('request_type', str(bittensor.proto.RequestType.BACKWARD)), )) asyncio_future.cancel()