Skip to content

Commit

Permalink
Header fix (#1516)
Browse files Browse the repository at this point in the history
* (WIP) Refactor synapse and dendrite to verify body hash in signature

* EUREKA!

* cleanup print statements

* remove (now) unnecessary metaclass override protection

* remove metaclass from Synapse def

* fix commit

* remove print statements

* run black

* add typehints for request

* (WIP) update synapse tests

* move hash to utils

* remove extra logging

* fix repackaging request, inspect body and hash to compare against header hashes

* remove streaming example, moved to subnet-1 template

* put blacklist first so we don't load the body if we're gonna blacklist anyway

* fix axon tests

* fix bt. import

* fix typehint bt.

* fix test relative to header changes

* run black

* add reason to error logging

* create dependency for all forward_fn routes to verify body before passing to synapse function

* fix test (again)

* add md5 hash option

* remove extra trace
  • Loading branch information
ifrit98 authored Sep 21, 2023
1 parent bd3f253 commit 226c27d
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 328 deletions.
83 changes: 74 additions & 9 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import uuid
import copy
import json
import time
import asyncio
import inspect
Expand All @@ -35,10 +36,11 @@
from inspect import signature, Signature, Parameter
from fastapi.responses import JSONResponse
from substrateinterface import Keypair
from fastapi import FastAPI, APIRouter, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from fastapi import FastAPI, APIRouter, Request, Response, Depends
from starlette.types import Scope, Message
from starlette.responses import Response
from starlette.requests import Request
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from typing import Dict, Optional, Tuple, Union, List, Callable, Any


Expand Down Expand Up @@ -318,7 +320,10 @@ def attach(

# Add the endpoint to the router, making it available on both GET and POST methods
self.router.add_api_route(
f"/{request_name}", forward_fn, methods=["GET", "POST"]
f"/{request_name}",
forward_fn,
methods=["GET", "POST"],
dependencies=[Depends(self.verify_body_integrity)],
)
self.app.include_router(self.router)

Expand Down Expand Up @@ -475,6 +480,61 @@ def add_args(cls, parser: argparse.ArgumentParser, prefix: str = None):
# Exception handling for re-parsing arguments
pass

@staticmethod
async def verify_body_integrity(request: Request):
    """
    Verify that the hashed fields of a request body match the hashes sent in its headers.

    Every header whose name starts with ``bt_header_input_hash_`` is treated as the
    expected hash of the body field named by the remainder of the header name. Each
    such field is read from the JSON body, hashed with ``bittensor.utils.hash`` and
    compared against the header value.

    Args:
        request (Request): The incoming request; its headers and JSON body are read.

    Returns:
        dict: The parsed JSON body, when every hashed field matches its header hash.
            An empty body is treated as an empty dict.

    Raises:
        HTTPException: With status code 400 when the body is not a JSON object, when
            a field announced by a hash header is missing from the body, or when a
            field's hash differs from the header value.

    Example:
        Assuming this method is set as a dependency in a route:

            @app.post("/some_endpoint")
            async def some_endpoint(body_dict: dict = Depends(verify_body_integrity)):
                # Only runs if the body integrity verification succeeded.
                ...
    """
    # Local import: HTTPException is not among the names this module imports from fastapi.
    from fastapi import HTTPException

    # Collect the expected hashes, keyed by body field name. Stripping the full prefix
    # (rather than taking the last "_"-separated token) keeps field names that
    # themselves contain underscores intact.
    hash_header_prefix = "bt_header_input_hash_"
    expected_hashes = {
        name[len(hash_header_prefix) :]: value
        for name, value in request.headers.items()
        if name.startswith(hash_header_prefix)
    }

    # Await and parse the request body so the hashed fields can be inspected.
    raw_body = await request.body()
    try:
        body_dict = json.loads(raw_body) if raw_body else {}
    except ValueError as e:
        # Covers both malformed JSON and undecodable bytes.
        raise HTTPException(
            status_code=400, detail=f"Request body is not valid JSON: {e}"
        ) from e
    if not isinstance(body_dict, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object."
        )

    for required_field, expected_hash in expected_hashes.items():
        if required_field not in body_dict:
            raise HTTPException(
                status_code=400,
                detail=f"Hashed field '{required_field}' is missing from the request body.",
            )

        # Hash the field in the body to compare against the header hash.
        field_hash = bittensor.utils.hash(str(body_dict[required_field]))

        # A dependency must raise to reject the request: a response object that is
        # merely returned from a dependency is not sent to the client.
        if field_hash != expected_hash:
            raise HTTPException(
                status_code=400,
                detail=f"Hash mismatch for '{required_field}' with {field_hash} and {expected_hash}",
            )

    # Body is good: return the parsed body so it can be injected into the route function.
    return body_dict

@classmethod
def check_config(cls, config: "bittensor.config"):
"""
Expand Down Expand Up @@ -561,7 +621,7 @@ def serve(
subtensor.serve_axon(netuid=netuid, axon=self)
return self

def default_verify(self, synapse: bittensor.Synapse) -> Request:
def default_verify(self, synapse: bittensor.Synapse):
"""
This method is used to verify the authenticity of a received message using a digital signature.
It ensures that the message was not tampered with and was sent by the expected sender.
Expand All @@ -584,8 +644,13 @@ def default_verify(self, synapse: bittensor.Synapse) -> Request:
# Build the keypair from the dendrite_hotkey
keypair = Keypair(ss58_address=synapse.dendrite.hotkey)

# Pull body hashes from the synapse received with the request.
body_hashes = [
getattr(synapse, field + "_hash") for field in synapse.required_hash_fields
]

# Build the signature messages.
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{self.wallet.hotkey.ss58_address}.{synapse.dendrite.uuid}.{synapse.body_hash}"
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{self.wallet.hotkey.ss58_address}.{synapse.dendrite.uuid}.{body_hashes}"

# Build the unique endpoint key.
endpoint_key = f"{synapse.dendrite.hotkey}:{synapse.dendrite.uuid}"
Expand Down Expand Up @@ -649,12 +714,12 @@ async def dispatch(
f"axon | <-- | {request.headers.get('content-length', -1)} B | {synapse.name} | {synapse.dendrite.hotkey} | {synapse.dendrite.ip}:{synapse.dendrite.port} | 200 | Success "
)

# Call the verify function
await self.verify(synapse)

# Call the blacklist function
await self.blacklist(synapse)

# Call verify and return the verified request
await self.verify(synapse)

# Call the priority function
await self.priority(synapse)

Expand Down Expand Up @@ -802,7 +867,7 @@ async def blacklist(self, synapse: bittensor.Synapse):
synapse.axon.status_code = "403"

# We raise an exception to halt the process and return the error message to the requester.
raise Exception("Forbidden. Key is blacklisted.")
raise Exception(f"Forbidden. Key is blacklisted: {reason}.")

async def priority(self, synapse: bittensor.Synapse):
"""
Expand Down
33 changes: 31 additions & 2 deletions bittensor/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,41 @@ def preprocess_synapse_for_request(
}
)

# Sign the request using the dendrite and axon information
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{synapse.axon.hotkey}.{synapse.dendrite.uuid}.{synapse.body_hash}"
# Sign the request using the dendrite, axon info, and the synapse body hashes
body_hashes = self.hash_synapse_body(synapse)

message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{synapse.axon.hotkey}.{synapse.dendrite.uuid}.{body_hashes}"
synapse.dendrite.signature = f"0x{self.keypair.sign(message).hex()}"

return synapse

@staticmethod
def hash_synapse_body(synapse: bittensor.Synapse) -> list:
    """
    Compute one hash per hashed body field of the given Synapse instance.

    The instance's fields are walked in ``synapse.__dict__`` order; every field
    named in ``synapse.required_hash_fields`` has the string form of its value
    hashed with ``bittensor.utils.hash``.

    Args:
        synapse (bittensor.Synapse): The synapse whose body fields are hashed.

    Returns:
        list: The per-field hashes, in instance-field order. Note this is a list
            of hashes, not a single combined digest.
    """
    # Read the property once instead of re-evaluating it for every field.
    required_fields = synapse.required_hash_fields

    return [
        bittensor.utils.hash(str(value))
        for field, value in synapse.__dict__.items()
        if field in required_fields
    ]

def process_server_response(
self,
server_response: Response,
Expand Down
125 changes: 24 additions & 101 deletions bittensor/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,37 +202,7 @@ class Config:
)


class ProtectOverride(type):
    """
    Metaclass that prevents subclasses from overriding the 'body_hash' attribute.

    When a class is created whose body defines 'body_hash' while at least one of
    its bases already exposes 'body_hash', a `TypeError` is raised.
    """

    def __new__(cls, name, bases, class_dict):
        # Test the predicate itself, not the truthiness of the matching base
        # classes, so a base that evaluates as falsy cannot bypass the check.
        inherits_body_hash = any(hasattr(base, "body_hash") for base in bases)
        if inherits_body_hash and "body_hash" in class_dict:
            raise TypeError("You can't override the body_hash attribute!")
        return super().__new__(cls, name, bases, class_dict)


class CombinedMeta(ProtectOverride, type(pydantic.BaseModel)):
    """
    Metaclass deriving from both ProtectOverride and the metaclass of pydantic.BaseModel.

    It adds nothing of its own; it exists so that a pydantic model can also use
    ProtectOverride without a metaclass conflict.
    """


class Synapse(pydantic.BaseModel, metaclass=CombinedMeta):
class Synapse(pydantic.BaseModel):
class Config:
validate_assignment = True

Expand Down Expand Up @@ -311,7 +281,7 @@ def set_name_type(cls, values) -> dict:
dendrite: Optional[TerminalInfo] = pydantic.Field(
title="dendrite",
description="Dendrite Terminal Information",
examples="bt.TerminalInfo",
examples="bittensor.TerminalInfo",
default=TerminalInfo(),
allow_mutation=True,
repr=False,
Expand All @@ -321,20 +291,16 @@ def set_name_type(cls, values) -> dict:
axon: Optional[TerminalInfo] = pydantic.Field(
title="axon",
description="Axon Terminal Information",
examples="bt.TerminalInfo",
examples="bittensor.TerminalInfo",
default=TerminalInfo(),
allow_mutation=True,
repr=False,
)

def __setattr__(self, name: str, value: Any):
"""
Override the __setattr__ method to make the body_hash property read-only.
Override the __setattr__ method to make the `required_hash_fields` property read-only.
"""
if name == "body_hash":
raise AttributeError(
"body_hash property is read-only and cannot be overridden."
)
if name == "required_hash_fields":
raise AttributeError(
"required_hash_fields property is read-only and cannot be overridden."
Expand Down Expand Up @@ -370,6 +336,7 @@ def required_hash_fields(self) -> List[str]:
)
if (
required
and field in required
and value != None
and field
not in [
Expand All @@ -380,6 +347,7 @@ def required_hash_fields(self) -> List[str]:
"dendrite",
"axon",
]
and "_hash" not in field
):
fields.append(field)
return fields
Expand All @@ -397,68 +365,6 @@ def get_total_size(self) -> int:
self.total_size = get_size(self)
return self.total_size

def get_body(self) -> List[Any]:
    """
    Retrieve the values of the instance fields that take part in body hashing.

    The set of participating fields is delegated to the `required_hash_fields`
    property; this method only collects the matching values.

    Returns:
        List[Any]: The values of the fields named in `required_hash_fields`,
            in the order the fields appear in the instance's `__dict__`.
    """
    # Read the property once instead of re-evaluating it for every field.
    hashed_fields = self.required_hash_fields

    return [
        value for field, value in self.__dict__.items() if field in hashed_fields
    ]

@property
def body_hash(self) -> str:
    """
    SHA-256 hex digest of the instance's body.

    The body is the list returned by `get_body`; the string forms of its
    elements are concatenated with no separator, encoded as UTF-8 and hashed.

    Returns:
        str: The hexadecimal representation of the SHA-256 hash of the body.
    """
    # Stringify each body element and join them into one payload.
    payload = "".join(str(item) for item in self.get_body())

    # Hash the UTF-8 encoded payload in a single step.
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

@property
def is_success(self) -> bool:
"""
Expand Down Expand Up @@ -599,11 +505,15 @@ def to_headers(self) -> dict:

elif required and field in required:
try:
serialized_value = json.dumps(value)
# Create an empty (dummy) instance of type(value) to pass pydantic validation on the axon side
serialized_value = json.dumps(value.__class__.__call__())
# Create a hash of the original data so we can verify on the axon side
hash_value = bittensor.utils.hash(str(value))
encoded_value = base64.b64encode(serialized_value.encode()).decode(
"utf-8"
)
headers[f"bt_header_input_obj_{field}"] = encoded_value
headers[f"bt_header_input_hash_{field}"] = hash_value
except TypeError as e:
raise ValueError(
f"Error serializing {field} with value {value}. Objects must be json serializable."
Expand Down Expand Up @@ -720,6 +630,19 @@ def parse_headers_to_inputs(cls, headers: dict) -> dict:
f"Error while parsing 'input_obj' header {key}: {e}"
)
continue
elif "bt_header_input_hash" in key:
try:
new_key = key.split("bt_header_input_hash_")[1] + "_hash"
# Skip if the key already exists in the dictionary
if new_key in inputs_dict:
continue
# Store the hash string as-is; unlike 'input_obj' values it is not base64-encoded
inputs_dict[new_key] = value
except Exception as e:
bittensor.logging.error(
f"Error while parsing 'input_hash' header {key}: {e}"
)
continue
else:
pass # log unexpected keys

Expand Down
Loading

0 comments on commit 226c27d

Please sign in to comment.