diff --git a/bittensor/axon.py b/bittensor/axon.py index ec1c0d976f..fd2a1b4b9d 100644 --- a/bittensor/axon.py +++ b/bittensor/axon.py @@ -22,6 +22,7 @@ import os import uuid import copy +import json import time import asyncio import inspect @@ -35,10 +36,11 @@ from inspect import signature, Signature, Parameter from fastapi.responses import JSONResponse from substrateinterface import Keypair -from fastapi import FastAPI, APIRouter, Request, Response -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from fastapi import FastAPI, APIRouter, Request, Response, Depends +from starlette.types import Scope, Message from starlette.responses import Response from starlette.requests import Request +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from typing import Dict, Optional, Tuple, Union, List, Callable, Any @@ -318,7 +320,10 @@ def attach( # Add the endpoint to the router, making it available on both GET and POST methods self.router.add_api_route( - f"/{request_name}", forward_fn, methods=["GET", "POST"] + f"/{request_name}", + forward_fn, + methods=["GET", "POST"], + dependencies=[Depends(self.verify_body_integrity)], ) self.app.include_router(self.router) @@ -475,6 +480,61 @@ def add_args(cls, parser: argparse.ArgumentParser, prefix: str = None): # Exception handling for re-parsing arguments pass + @staticmethod + async def verify_body_integrity(request: Request): + """ + Asynchronously verifies the integrity of the body of a request by comparing the hash of required fields + with the corresponding hashes provided in the request headers. This method is critical for ensuring + that the incoming request payload has not been altered or tampered with during transmission, establishing + a level of trust and security between the sender and receiver in the network. + + Args: + request (Request): The incoming FastAPI request object containing both headers and the request body. + + Returns: + dict: Returns the parsed body of the request as a dictionary if all the hash comparisons match, + indicating that the body is intact and has not been tampered with. + + Raises: + JSONResponse: Raises a JSONResponse with a 400 status code if any of the hash comparisons fail, + indicating a potential integrity issue with the incoming request payload. + The response includes the detailed error message specifying which field has a hash mismatch. + + Example: + Assuming this method is set as a dependency in a route: + + @app.post("/some_endpoint") + async def some_endpoint(body_dict: dict = Depends(verify_body_integrity)): + # body_dict is the parsed body of the request and is available for use in the route function. + # The function only executes if the body integrity verification is successful. + ... + """ + # Extract keys and values that end with '_hash' + hash_headers = {k: v for k, v in request.headers.items() if "_hash_" in k} + hash_headers = {k.split("_")[-1]: v for k, v in hash_headers.items()} + + # Await and load the request body so we can inspect it + body = await request.body() + request_body = body.decode() if isinstance(body, bytes) else body + + # Load the body dict and check if all required field hashes match + body_dict = json.loads(request_body) + for required_field in list(hash_headers): + # Hash the field in the body to compare against the header hashes + field_hash = bittensor.utils.hash(str(body_dict[required_field])) + + # If any hashes fail to match up, return a 400 error as the body is invalid + if field_hash != hash_headers[required_field]: + return JSONResponse( + content={ + "error": f"Hash mismatch with {field_hash} and {getattr(synapse, required_field + '_hash')}" + }, + status_code=400, + ) + + # If body is good, return the parsed body so that it can be injected into the route function + return body_dict + @classmethod def check_config(cls, config: "bittensor.config"): """ @@ -561,7 +621,7 @@ def serve( subtensor.serve_axon(netuid=netuid, axon=self) return self - def default_verify(self, synapse: bittensor.Synapse) -> Request: + def default_verify(self, synapse: bittensor.Synapse): """ This method is used to verify the authenticity of a received message using a digital signature. It ensures that the message was not tampered with and was sent by the expected sender. @@ -584,8 +644,13 @@ def default_verify(self, synapse: bittensor.Synapse) -> Request: # Build the keypair from the dendrite_hotkey keypair = Keypair(ss58_address=synapse.dendrite.hotkey) + # Pull body hashes from synapse recieved with request. + body_hashes = [ + getattr(synapse, field + "_hash") for field in synapse.required_hash_fields + ] + # Build the signature messages. - message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{self.wallet.hotkey.ss58_address}.{synapse.dendrite.uuid}.{synapse.body_hash}" + message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{self.wallet.hotkey.ss58_address}.{synapse.dendrite.uuid}.{body_hashes}" # Build the unique endpoint key. endpoint_key = f"{synapse.dendrite.hotkey}:{synapse.dendrite.uuid}" @@ -649,12 +714,12 @@ async def dispatch( f"axon | <-- | {request.headers.get('content-length', -1)} B | {synapse.name} | {synapse.dendrite.hotkey} | {synapse.dendrite.ip}:{synapse.dendrite.port} | 200 | Success " ) - # Call the verify function - await self.verify(synapse) - # Call the blacklist function await self.blacklist(synapse) + # Call verify and return the verified request + await self.verify(synapse) + # Call the priority function await self.priority(synapse) @@ -802,7 +867,7 @@ async def blacklist(self, synapse: bittensor.Synapse): synapse.axon.status_code = "403" # We raise an exception to halt the process and return the error message to the requester. - raise Exception("Forbidden. Key is blacklisted.") + raise Exception(f"Forbidden. Key is blacklisted: {reason}.") async def priority(self, synapse: bittensor.Synapse): """ diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index c7692f02c0..eb9c020824 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -347,12 +347,41 @@ def preprocess_synapse_for_request( } ) - # Sign the request using the dendrite and axon information - message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{synapse.axon.hotkey}.{synapse.dendrite.uuid}.{synapse.body_hash}" + # Sign the request using the dendrite, axon info, and the synapse body hashes + body_hashes = self.hash_synapse_body(synapse) + + message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{synapse.axon.hotkey}.{synapse.dendrite.uuid}.{body_hashes}" synapse.dendrite.signature = f"0x{self.keypair.sign(message).hex()}" return synapse + @staticmethod + def hash_synapse_body(synapse: bittensor.Synapse) -> str: + """ + Compute a SHA-256 hash of the serialized body of the Synapse instance. + + The body of the Synapse instance comprises its serialized and encoded + non-optional fields. This property retrieves these fields using the + `get_body` method, then concatenates their string representations, and + finally computes a SHA-256 hash of the resulting string. + + Returns: + str: The hexadecimal representation of the SHA-256 hash of the instance's body. + """ + # Hash the body for verification + hashes = [] + + # Getting the fields of the instance + instance_fields = synapse.__dict__ + + # Iterating over the fields of the instance + for field, value in instance_fields.items(): + # If the field is required in the subclass schema, add it. + if field in synapse.required_hash_fields: + hashes.append(bittensor.utils.hash(str(value))) + + return hashes + def process_server_response( self, server_response: Response, diff --git a/bittensor/synapse.py b/bittensor/synapse.py index 0256af38d8..20a9095e2e 100644 --- a/bittensor/synapse.py +++ b/bittensor/synapse.py @@ -202,37 +202,7 @@ class Config: ) -class ProtectOverride(type): - """ - Metaclass to prevent subclasses from overriding specified methods or attributes. - - When a subclass attempts to override a protected attribute or method, a `TypeError` is raised. - The current implementation specifically checks for overriding the 'body_hash' attribute. - - Overriding `protected_method` in a subclass of `MyClass` will raise a TypeError. - """ - - def __new__(cls, name, bases, class_dict): - # Check if the derived class tries to override the 'body_hash' method or attribute. - if ( - any(base for base in bases if hasattr(base, "body_hash")) - and "body_hash" in class_dict - ): - raise TypeError("You can't override the body_hash attribute!") - return super(ProtectOverride, cls).__new__(cls, name, bases, class_dict) - - -class CombinedMeta(ProtectOverride, type(pydantic.BaseModel)): - """ - Metaclass combining functionality of ProtectOverride and BaseModel's metaclass. - - Inherits the attributes and methods from both parent metaclasses to provide combined behavior. - """ - - pass - - -class Synapse(pydantic.BaseModel, metaclass=CombinedMeta): +class Synapse(pydantic.BaseModel): class Config: validate_assignment = True @@ -311,7 +281,7 @@ def set_name_type(cls, values) -> dict: dendrite: Optional[TerminalInfo] = pydantic.Field( title="dendrite", description="Dendrite Terminal Information", - examples="bt.TerminalInfo", + examples="bittensor.TerminalInfo", default=TerminalInfo(), allow_mutation=True, repr=False, @@ -321,7 +291,7 @@ def set_name_type(cls, values) -> dict: axon: Optional[TerminalInfo] = pydantic.Field( title="axon", description="Axon Terminal Information", - examples="bt.TerminalInfo", + examples="bittensor.TerminalInfo", default=TerminalInfo(), allow_mutation=True, repr=False, @@ -329,12 +299,8 @@ def set_name_type(cls, values) -> dict: def __setattr__(self, name: str, value: Any): """ - Override the __setattr__ method to make the body_hash property read-only. + Override the __setattr__ method to make the `required_hash_fields` property read-only. """ - if name == "body_hash": - raise AttributeError( - "body_hash property is read-only and cannot be overridden." - ) if name == "required_hash_fields": raise AttributeError( "required_hash_fields property is read-only and cannot be overridden." @@ -370,6 +336,7 @@ def required_hash_fields(self) -> List[str]: ) if ( required + and field in required and value != None and field not in [ @@ -380,6 +347,7 @@ def required_hash_fields(self) -> List[str]: "dendrite", "axon", ] + and "_hash" not in field ): fields.append(field) return fields @@ -397,68 +365,6 @@ def get_total_size(self) -> int: self.total_size = get_size(self) return self.total_size - def get_body(self) -> List[Any]: - """ - Retrieve the serialized and encoded non-optional fields of the Synapse instance. - - This method filters through the fields of the Synapse instance and identifies - non-optional attributes that have non-null values, excluding specific attributes - such as `name`, `timeout`, `total_size`, `header_size`, `dendrite`, and `axon`. - It returns a list containing these selected field values. - - Returns: - List[Any]: A list of values from the non-optional fields of the Synapse instance. - - Note: - The determination of whether a field is optional or not is based on the - schema definition for the Synapse class. - """ - fields = [] - - # Getting the fields of the instance - instance_fields = self.__dict__ - - # Iterating over the fields of the instance - for field, value in instance_fields.items(): - # If the field is required in the subclass schema, add it. - if field in self.required_hash_fields: - fields.append(value) - - return fields - - @property - def body_hash(self) -> str: - """ - Compute a SHA-256 hash of the serialized body of the Synapse instance. - - The body of the Synapse instance comprises its serialized and encoded - non-optional fields. This property retrieves these fields using the - `get_body` method, then concatenates their string representations, and - finally computes a SHA-256 hash of the resulting string. - - Note: - This property is intended to be read-only. Any attempts to override - or set its value will raise an AttributeError due to the protections - set in the __setattr__ method. - - Returns: - str: The hexadecimal representation of the SHA-256 hash of the instance's body. - """ - # Hash the body for verification - body = self.get_body() - - # Convert elements to string and concatenate - concat = "".join(map(str, body)) - - # Create a SHA-256 hash object - sha256 = hashlib.sha256() - - # Update the hash object with the concatenated string - sha256.update(concat.encode("utf-8")) - - # Produce the hash - return sha256.hexdigest() - @property def is_success(self) -> bool: """ @@ -599,11 +505,15 @@ def to_headers(self) -> dict: elif required and field in required: try: - serialized_value = json.dumps(value) + # Create an empty (dummy) instance of type(value) to pass pydantic validation on the axon side + serialized_value = json.dumps(value.__class__.__call__()) + # Create a hash of the original data so we can verify on the axon side + hash_value = bittensor.utils.hash(str(value)) encoded_value = base64.b64encode(serialized_value.encode()).decode( "utf-8" ) headers[f"bt_header_input_obj_{field}"] = encoded_value + headers[f"bt_header_input_hash_{field}"] = hash_value except TypeError as e: raise ValueError( f"Error serializing {field} with value {value}. Objects must be json serializable." @@ -720,6 +630,19 @@ def parse_headers_to_inputs(cls, headers: dict) -> dict: f"Error while parsing 'input_obj' header {key}: {e}" ) continue + elif "bt_header_input_hash" in key: + try: + new_key = key.split("bt_header_input_hash_")[1] + "_hash" + # Skip if the key already exists in the dictionary + if new_key in inputs_dict: + continue + # Decode and load the serialized object + inputs_dict[new_key] = value + except Exception as e: + bittensor.logging.error( + f"Error while parsing 'input_hash' header {key}: {e}" + ) + continue else: pass # log unexpected keys diff --git a/bittensor/utils/__init__.py b/bittensor/utils/__init__.py index 65f45ad32f..7b44150ac5 100644 --- a/bittensor/utils/__init__.py +++ b/bittensor/utils/__init__.py @@ -19,6 +19,7 @@ from typing import Callable, Union, List, Optional, Dict, Literal, Type, Any import bittensor +import hashlib import requests import torch import scalecodec @@ -197,3 +198,13 @@ def u8_key_to_ss58(u8_key: List[int]) -> str: """ # First byte is length, then 32 bytes of key. return scalecodec.ss58_encode(bytes(u8_key).hex(), bittensor.__ss58_format__) + + +def hash(content, hash_type="md5", encoding="utf-8"): + algo = hashlib.md5() if hash_type == "md5" else hashlib.sha256() + + # Update the hash object with the concatenated string + algo.update(content.encode(encoding)) + + # Produce the hash + return algo.hexdigest() diff --git a/examples/streaming/prompting.ipynb b/examples/streaming/prompting.ipynb deleted file mode 100644 index 61fa69f6d9..0000000000 --- a/examples/streaming/prompting.ipynb +++ /dev/null @@ -1,203 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import bittensor as bt\n", - "import pydantic\n", - "from starlette.types import Send\n", - "from functools import partial\n", - "from typing import Callable, Awaitable, List\n", - "import asyncio\n", - "\n", - "bt.trace()\n", - "\n", - "\n", - "# This is a subclass of StreamingSynapse for prompting network functionality\n", - "class StreamPrompting(bt.StreamingSynapse):\n", - " \"\"\"\n", - " StreamPrompting is a subclass of StreamingSynapse that is specifically designed for prompting network functionality.\n", - " It overrides abstract methods from the parent class to provide concrete implementations for processing streaming responses,\n", - " deserializing the response, and extracting JSON data.\n", - "\n", - " Attributes:\n", - " roles: List of roles associated with the prompt.\n", - " messages: List of messages to be processed.\n", - " completion: A string to store the completion result.\n", - " \"\"\"\n", - "\n", - " roles: List[str] = pydantic.Field(..., allow_mutation=False)\n", - " messages: List[str] = pydantic.Field(..., allow_mutation=False)\n", - " completion: str = None\n", - "\n", - " async def process_streaming_response(self, response):\n", - " \"\"\"\n", - " Processes the streaming response by iterating through the content and decoding tokens.\n", - " Concatenates the decoded tokens into the completion attribute.\n", - "\n", - " Args:\n", - " response: The response object containing the content to be processed.\n", - " \"\"\"\n", - " if self.completion is None:\n", - " self.completion = \"\"\n", - " async for chunk in response.content.iter_any():\n", - " tokens = chunk.decode('utf-8').split('\\n')\n", - " for token in tokens:\n", - " if token:\n", - " self.completion += token\n", - "\n", - " def deserialize(self):\n", - " \"\"\"\n", - " Deserializes the response by returning the completion attribute.\n", - "\n", - " Returns:\n", - " str: The completion result.\n", - " \"\"\"\n", - " return self.completion\n", - "\n", - " def extract_response_json(self, response):\n", - " \"\"\"\n", - " Extracts JSON data from the response, including headers and specific information related to dendrite and axon.\n", - "\n", - " Args:\n", - " response: The response object from which to extract JSON data.\n", - "\n", - " Returns:\n", - " dict: A dictionary containing extracted JSON data.\n", - " \"\"\"\n", - " headers = {k.decode('utf-8'): v.decode('utf-8') for k, v in response.__dict__[\"_raw_headers\"]}\n", - "\n", - " def extract_info(prefix):\n", - " return {key.split('_')[-1]: value for key, value in headers.items() if key.startswith(prefix)}\n", - "\n", - " return {\n", - " \"name\": headers.get('name', ''),\n", - " \"timeout\": float(headers.get('timeout', 0)),\n", - " \"total_size\": int(headers.get('total_size', 0)),\n", - " \"header_size\": int(headers.get('header_size', 0)),\n", - " \"dendrite\": extract_info('bt_header_dendrite'),\n", - " \"axon\": extract_info('bt_header_axon'),\n", - " \"roles\": self.roles,\n", - " \"messages\": self.messages,\n", - " \"completion\": self.completion,\n", - " }\n", - "\n", - "# This should encapsulate all the logic for generating a streaming response\n", - "def sforward(synapse: StreamPrompting) -> StreamPrompting:\n", - " \"\"\"\n", - " Encapsulates the logic for generating a streaming response. It defines the tokenizer, model, and prompt functions,\n", - " and creates a streaming response using the provided synapse.\n", - "\n", - " Args:\n", - " synapse: A StreamPrompting instance containing the messages to be processed.\n", - "\n", - " Returns:\n", - " StreamPrompting: The streaming response object.\n", - " \"\"\"\n", - " def tokenizer(text):\n", - " return (ord(char) for char in text)\n", - "\n", - " def model(ids):\n", - " return (chr(id) for id in ids)\n", - "\n", - " async def prompt(text: str, send: Send):\n", - " # Simulate model inference\n", - " input_ids = tokenizer(text)\n", - " for token in model(input_ids):\n", - " await send({\"type\": \"http.response.body\", \"body\": (token + '\\n').encode('utf-8'), \"more_body\": True})\n", - " bt.logging.trace(f\"Streamed token: {token}\")\n", - " # Sleep to show the streaming effect\n", - " await asyncio.sleep(0.5)\n", - "\n", - " message = synapse.messages[0]\n", - " token_streamer = partial(prompt, message)\n", - " return synapse.create_streaming_response(token_streamer)\n", - "\n", - "def blacklist(synapse: StreamPrompting) -> bool:\n", - " \"\"\"\n", - " Determines whether the synapse should be blacklisted.\n", - "\n", - " Args:\n", - " synapse: A StreamPrompting instance.\n", - "\n", - " Returns:\n", - " bool: Always returns False, indicating that the synapse should not be blacklisted.\n", - " \"\"\"\n", - " return False\n", - "\n", - "def priority(synapse: StreamPrompting) -> float:\n", - " \"\"\"\n", - " Determines the priority of the synapse.\n", - "\n", - " Args:\n", - " synapse: A StreamPrompting instance.\n", - "\n", - " Returns:\n", - " float: Always returns 0.0, indicating the default priority.\n", - " \"\"\"\n", - " return 0.0\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create an Axon instance on port 8099.\n", - "axon = bt.axon(port=8099)\n", - "\n", - "# Attach the forward, blacklist, and priority functions to the Axon.\n", - "# forward_fn: The function to handle forwarding logic.\n", - "# blacklist_fn: The function to determine if a request should be blacklisted.\n", - "# priority_fn: The function to determine the priority of the request.\n", - "axon.attach(\n", - " forward_fn=sforward,\n", - " blacklist_fn=blacklist,\n", - " priority_fn=priority\n", - ")\n", - "\n", - "# Start the Axon to begin listening for requests.\n", - "axon.start()\n", - "\n", - "# Create a Dendrite instance to handle client-side communication.\n", - "d = bt.dendrite()\n", - "\n", - "# Send a request to the Axon using the Dendrite, passing in a StreamPrompting instance with roles and messages.\n", - "# The response is awaited, as the Dendrite communicates asynchronously with the Axon.\n", - "resp = await d(\n", - " [axon],\n", - " StreamPrompting(roles=[\"user\"], messages=[\"hello this is a test of streaming.\"])\n", - ")\n", - "\n", - "# The response object contains the result of the streaming operation.\n", - "resp\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rev", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/unit_tests/test_synapse.py b/tests/unit_tests/test_synapse.py index 68f2d9541b..27a81412fe 100644 --- a/tests/unit_tests/test_synapse.py +++ b/tests/unit_tests/test_synapse.py @@ -165,11 +165,11 @@ class Test(bittensor.Synapse): # Create a new Test from the headers and check its properties next_synapse = synapse.from_headers(synapse.to_headers()) - assert next_synapse.a == 1 + assert next_synapse.a == 0 # Default value is 0 assert next_synapse.b == None assert next_synapse.c == None assert next_synapse.d == None - assert next_synapse.e == [1, 2, 3, 4] + assert next_synapse.e == [] # Empty list is default for list types assert next_synapse.f.shape == [10] # Shape is passed through assert next_synapse.f.dtype == "torch.float32" # Type is passed through assert next_synapse.f.buffer == None # Buffer is not passed through @@ -237,15 +237,6 @@ class Test(bittensor.Synapse): assert next_synapse.a["dog"].shape == [11] -def test_override_protection(): - with pytest.raises(TypeError, match="You can't override the body_hash attribute!"): - - class DerivedModel(bittensor.Synapse): - @property - def body_hash(self): - return "new_value" - - def test_body_hash_override(): # Create a Synapse instance synapse_instance = bittensor.Synapse() @@ -253,6 +244,6 @@ def test_body_hash_override(): # Try to set the body_hash property and expect an AttributeError with pytest.raises( AttributeError, - match="body_hash property is read-only and cannot be overridden.", + match="required_hash_fields property is read-only and cannot be overridden.", ): - synapse_instance.body_hash = "some_value" + synapse_instance.required_hash_fields = []