Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Header fix #1516

Merged
merged 27 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
afe279d
(WIP) Refactor synapse and dendrite to verify body hash in signature
ifrit98 Sep 14, 2023
a968387
EUREKA!
ifrit98 Sep 14, 2023
f5107a5
cleanup print statements
ifrit98 Sep 14, 2023
4cc7140
remove (now) unnecessary metaclass override protection
ifrit98 Sep 14, 2023
4a46d74
remove metaclass from Synapse def
ifrit98 Sep 14, 2023
68df89b
remove metaclass from Synapse def
ifrit98 Sep 14, 2023
c886926
fix commit
ifrit98 Sep 14, 2023
c125bc9
remove print statements
ifrit98 Sep 14, 2023
b2134fb
run black
ifrit98 Sep 14, 2023
3cd4ae3
add typehints for request
ifrit98 Sep 14, 2023
7e74974
(WIP) update synapse tests
ifrit98 Sep 14, 2023
c063e91
move hash to utils
ifrit98 Sep 14, 2023
4235b47
remove extra logging
ifrit98 Sep 14, 2023
c88d560
fix repackaging request, inspect body and hash to compare against hea…
ifrit98 Sep 15, 2023
e85b8b2
remove streaming example, moved to subnet-1 template
ifrit98 Sep 18, 2023
199c2e4
put blacklist first so we don't load the body if we're gonna blacklis…
ifrit98 Sep 18, 2023
09a3823
fix axon tests
ifrit98 Sep 19, 2023
5d59788
merge revolution in
ifrit98 Sep 19, 2023
8d1fc5f
fix bt. import
ifrit98 Sep 20, 2023
850052e
fix typehint bt.
ifrit98 Sep 20, 2023
99e2179
fix test relative to header changes
ifrit98 Sep 20, 2023
827da9a
run black
ifrit98 Sep 20, 2023
5e25f2e
add reason to error logging
ifrit98 Sep 20, 2023
fa88ea1
create dependency for all forward_fn routes to verify body before pas…
ifrit98 Sep 21, 2023
059203c
fix test (again)
ifrit98 Sep 21, 2023
cd4af02
add md5 hash option
ifrit98 Sep 21, 2023
bce6ad6
remove extra trace
ifrit98 Sep 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 75 additions & 9 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import uuid
import copy
import json
import time
import asyncio
import inspect
Expand All @@ -35,10 +36,11 @@
from inspect import signature, Signature, Parameter
from fastapi.responses import JSONResponse
from substrateinterface import Keypair
from fastapi import FastAPI, APIRouter, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from fastapi import FastAPI, APIRouter, Request, Response, Depends
from starlette.types import Scope, Message
from starlette.responses import Response
from starlette.requests import Request
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from typing import Dict, Optional, Tuple, Union, List, Callable, Any


Expand Down Expand Up @@ -318,7 +320,10 @@ def attach(

# Add the endpoint to the router, making it available on both GET and POST methods
self.router.add_api_route(
f"/{request_name}", forward_fn, methods=["GET", "POST"]
f"/{request_name}",
forward_fn,
methods=["GET", "POST"],
dependencies=[Depends(self.verify_body_integrity)],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really clean design! The api route now depends on the verify_body_integrity?

Copy link
Contributor Author

@ifrit98 ifrit98 Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yessir! Seems to be pretty fast in my initial testing as well. We await there but don't need to repackage, just simply passes the body dict on to the function.

)
self.app.include_router(self.router)

Expand Down Expand Up @@ -475,6 +480,61 @@ def add_args(cls, parser: argparse.ArgumentParser, prefix: str = None):
# Exception handling for re-parsing arguments
pass

@staticmethod
async def verify_body_integrity(request: Request):
"""
Asynchronously verifies the integrity of the body of a request by comparing the hash of required fields
with the corresponding hashes provided in the request headers. This method is critical for ensuring
that the incoming request payload has not been altered or tampered with during transmission, establishing
a level of trust and security between the sender and receiver in the network.

Args:
request (Request): The incoming FastAPI request object containing both headers and the request body.

Returns:
dict: Returns the parsed body of the request as a dictionary if all the hash comparisons match,
indicating that the body is intact and has not been tampered with.

Raises:
JSONResponse: Raises a JSONResponse with a 400 status code if any of the hash comparisons fail,
indicating a potential integrity issue with the incoming request payload.
The response includes the detailed error message specifying which field has a hash mismatch.

Example:
Assuming this method is set as a dependency in a route:

@app.post("/some_endpoint")
async def some_endpoint(body_dict: dict = Depends(verify_body_integrity)):
# body_dict is the parsed body of the request and is available for use in the route function.
# The function only executes if the body integrity verification is successful.
...
"""
# Extract keys and values that end with '_hash'
hash_headers = {k: v for k, v in request.headers.items() if "_hash_" in k}
hash_headers = {k.split("_")[-1]: v for k, v in hash_headers.items()}

# Await and load the request body so we can inspect it
body = await request.body()
request_body = body.decode() if isinstance(body, bytes) else body

# Load the body dict and check if all required field hashes match
body_dict = json.loads(request_body)
for required_field in list(hash_headers):
# Hash the field in the body to compare against the header hashes
field_hash = bittensor.utils.hash(str(body_dict[required_field]))

# If any hashes fail to match up, return a 400 error as the body is invalid
if field_hash != hash_headers[required_field]:
return JSONResponse(
content={
"error": f"Hash mismatch with {field_hash} and {getattr(synapse, required_field + '_hash')}"
},
status_code=400,
)

# If body is good, return the parsed body so that it can be injected into the route function
return body_dict

@classmethod
def check_config(cls, config: "bittensor.config"):
"""
Expand Down Expand Up @@ -561,7 +621,7 @@ def serve(
subtensor.serve_axon(netuid=netuid, axon=self)
return self

def default_verify(self, synapse: bittensor.Synapse) -> Request:
def default_verify(self, synapse: bittensor.Synapse):
"""
This method is used to verify the authenticity of a received message using a digital signature.
It ensures that the message was not tampered with and was sent by the expected sender.
Expand All @@ -584,8 +644,13 @@ def default_verify(self, synapse: bittensor.Synapse) -> Request:
# Build the keypair from the dendrite_hotkey
keypair = Keypair(ss58_address=synapse.dendrite.hotkey)

# Pull body hashes from synapse recieved with request.
body_hashes = [
getattr(synapse, field + "_hash") for field in synapse.required_hash_fields
]

# Build the signature messages.
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{self.wallet.hotkey.ss58_address}.{synapse.dendrite.uuid}.{synapse.body_hash}"
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{self.wallet.hotkey.ss58_address}.{synapse.dendrite.uuid}.{body_hashes}"

# Build the unique endpoint key.
endpoint_key = f"{synapse.dendrite.hotkey}:{synapse.dendrite.uuid}"
Expand Down Expand Up @@ -649,12 +714,12 @@ async def dispatch(
f"axon | <-- | {request.headers.get('content-length', -1)} B | {synapse.name} | {synapse.dendrite.hotkey} | {synapse.dendrite.ip}:{synapse.dendrite.port} | 200 | Success "
)

# Call the verify function
await self.verify(synapse)

# Call the blacklist function
await self.blacklist(synapse)

# Call verify and return the verified request
await self.verify(synapse)

# Call the priority function
await self.priority(synapse)

Expand Down Expand Up @@ -710,6 +775,7 @@ async def preprocess(self, request: Request) -> bittensor.Synapse:
synapse = self.axon.forward_class_types[request_name].from_headers(
request.headers
)
bittensor.logging.trace(f"request.headers {request.headers}")
synapse.name = request_name

# Fills the local axon information into the synapse.
Expand Down Expand Up @@ -809,7 +875,7 @@ async def blacklist(self, synapse: bittensor.Synapse):
synapse.axon.status_code = "403"

# We raise an exception to halt the process and return the error message to the requester.
raise Exception("Forbidden. Key is blacklisted.")
raise Exception(f"Forbidden. Key is blacklisted: {reason}.")

async def priority(self, synapse: bittensor.Synapse):
"""
Expand Down
33 changes: 31 additions & 2 deletions bittensor/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,41 @@ def preprocess_synapse_for_request(
}
)

# Sign the request using the dendrite and axon information
message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{synapse.axon.hotkey}.{synapse.dendrite.uuid}.{synapse.body_hash}"
# Sign the request using the dendrite, axon info, and the synapse body hashes
body_hashes = self.hash_synapse_body(synapse)

message = f"{synapse.dendrite.nonce}.{synapse.dendrite.hotkey}.{synapse.axon.hotkey}.{synapse.dendrite.uuid}.{body_hashes}"
synapse.dendrite.signature = f"0x{self.keypair.sign(message).hex()}"

return synapse

@staticmethod
def hash_synapse_body(synapse: bittensor.Synapse) -> str:
"""
Compute a SHA-256 hash of the serialized body of the Synapse instance.

The body of the Synapse instance comprises its serialized and encoded
non-optional fields. This property retrieves these fields using the
`get_body` method, then concatenates their string representations, and
finally computes a SHA-256 hash of the resulting string.

Returns:
str: The hexadecimal representation of the SHA-256 hash of the instance's body.
"""
# Hash the body for verification
hashes = []

# Getting the fields of the instance
instance_fields = synapse.__dict__

# Iterating over the fields of the instance
for field, value in instance_fields.items():
# If the field is required in the subclass schema, add it.
if field in synapse.required_hash_fields:
hashes.append(bittensor.utils.hash(str(value)))

return hashes

def process_server_response(
self,
server_response: Response,
Expand Down
125 changes: 24 additions & 101 deletions bittensor/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,37 +202,7 @@ class Config:
)


class ProtectOverride(type):
"""
Metaclass to prevent subclasses from overriding specified methods or attributes.
When a subclass attempts to override a protected attribute or method, a `TypeError` is raised.
The current implementation specifically checks for overriding the 'body_hash' attribute.
Overriding `protected_method` in a subclass of `MyClass` will raise a TypeError.
"""

def __new__(cls, name, bases, class_dict):
# Check if the derived class tries to override the 'body_hash' method or attribute.
if (
any(base for base in bases if hasattr(base, "body_hash"))
and "body_hash" in class_dict
):
raise TypeError("You can't override the body_hash attribute!")
return super(ProtectOverride, cls).__new__(cls, name, bases, class_dict)


class CombinedMeta(ProtectOverride, type(pydantic.BaseModel)):
"""
Metaclass combining functionality of ProtectOverride and BaseModel's metaclass.
Inherits the attributes and methods from both parent metaclasses to provide combined behavior.
"""

pass


class Synapse(pydantic.BaseModel, metaclass=CombinedMeta):
class Synapse(pydantic.BaseModel):
class Config:
validate_assignment = True

Expand Down Expand Up @@ -311,7 +281,7 @@ def set_name_type(cls, values) -> dict:
dendrite: Optional[TerminalInfo] = pydantic.Field(
title="dendrite",
description="Dendrite Terminal Information",
examples="bt.TerminalInfo",
examples="bittensor.TerminalInfo",
default=TerminalInfo(),
allow_mutation=True,
repr=False,
Expand All @@ -321,20 +291,16 @@ def set_name_type(cls, values) -> dict:
axon: Optional[TerminalInfo] = pydantic.Field(
title="axon",
description="Axon Terminal Information",
examples="bt.TerminalInfo",
examples="bittensor.TerminalInfo",
default=TerminalInfo(),
allow_mutation=True,
repr=False,
)

def __setattr__(self, name: str, value: Any):
"""
Override the __setattr__ method to make the body_hash property read-only.
Override the __setattr__ method to make the `required_hash_fields` property read-only.
"""
if name == "body_hash":
raise AttributeError(
"body_hash property is read-only and cannot be overridden."
)
if name == "required_hash_fields":
raise AttributeError(
"required_hash_fields property is read-only and cannot be overridden."
Expand Down Expand Up @@ -370,6 +336,7 @@ def required_hash_fields(self) -> List[str]:
)
if (
required
and field in required
and value != None
and field
not in [
Expand All @@ -380,6 +347,7 @@ def required_hash_fields(self) -> List[str]:
"dendrite",
"axon",
]
and "_hash" not in field
):
fields.append(field)
return fields
Expand All @@ -397,68 +365,6 @@ def get_total_size(self) -> int:
self.total_size = get_size(self)
return self.total_size

def get_body(self) -> List[Any]:
"""
Retrieve the serialized and encoded non-optional fields of the Synapse instance.
This method filters through the fields of the Synapse instance and identifies
non-optional attributes that have non-null values, excluding specific attributes
such as `name`, `timeout`, `total_size`, `header_size`, `dendrite`, and `axon`.
It returns a list containing these selected field values.
Returns:
List[Any]: A list of values from the non-optional fields of the Synapse instance.
Note:
The determination of whether a field is optional or not is based on the
schema definition for the Synapse class.
"""
fields = []

# Getting the fields of the instance
instance_fields = self.__dict__

# Iterating over the fields of the instance
for field, value in instance_fields.items():
# If the field is required in the subclass schema, add it.
if field in self.required_hash_fields:
fields.append(value)

return fields

@property
def body_hash(self) -> str:
"""
Compute a SHA-256 hash of the serialized body of the Synapse instance.
The body of the Synapse instance comprises its serialized and encoded
non-optional fields. This property retrieves these fields using the
`get_body` method, then concatenates their string representations, and
finally computes a SHA-256 hash of the resulting string.
Note:
This property is intended to be read-only. Any attempts to override
or set its value will raise an AttributeError due to the protections
set in the __setattr__ method.
Returns:
str: The hexadecimal representation of the SHA-256 hash of the instance's body.
"""
# Hash the body for verification
body = self.get_body()

# Convert elements to string and concatenate
concat = "".join(map(str, body))

# Create a SHA-256 hash object
sha256 = hashlib.sha256()

# Update the hash object with the concatenated string
sha256.update(concat.encode("utf-8"))

# Produce the hash
return sha256.hexdigest()

@property
def is_success(self) -> bool:
"""
Expand Down Expand Up @@ -599,11 +505,15 @@ def to_headers(self) -> dict:

elif required and field in required:
try:
serialized_value = json.dumps(value)
# Create an empty (dummy) instance of type(value) to pass pydantic validation on the axon side
serialized_value = json.dumps(value.__class__.__call__())
# Create a hash of the original data so we can verify on the axon side
hash_value = bittensor.utils.hash(str(value))
encoded_value = base64.b64encode(serialized_value.encode()).decode(
"utf-8"
)
headers[f"bt_header_input_obj_{field}"] = encoded_value
headers[f"bt_header_input_hash_{field}"] = hash_value
except TypeError as e:
raise ValueError(
f"Error serializing {field} with value {value}. Objects must be json serializable."
Expand Down Expand Up @@ -720,6 +630,19 @@ def parse_headers_to_inputs(cls, headers: dict) -> dict:
f"Error while parsing 'input_obj' header {key}: {e}"
)
continue
elif "bt_header_input_hash" in key:
try:
new_key = key.split("bt_header_input_hash_")[1] + "_hash"
# Skip if the key already exists in the dictionary
if new_key in inputs_dict:
continue
# Decode and load the serialized object
inputs_dict[new_key] = value
except Exception as e:
bittensor.logging.error(
f"Error while parsing 'input_hash' header {key}: {e}"
)
continue
else:
pass # log unexpected keys

Expand Down
Loading