Root Metagraph fixes #1524

Merged · 1 commit · Sep 28, 2023
60 changes: 51 additions & 9 deletions bittensor/metagraph.py
@@ -340,6 +340,7 @@ def sync(
         block: Optional[int] = None,
         lite: bool = True,
         subtensor: Optional["bittensor.subtensor"] = None,
+        root: bool = False,
     ) -> "metagraph":
         """
         Initiates the synchronization process of the metagraph.
@@ -353,7 +354,7 @@ def sync(
             metagraph: Updated metagraph object.
         """
         # Initialize subtensor
-        subtensor = self._initialize_subtensor(subtensor)
+        self.subtensor = self._initialize_subtensor(subtensor)

         # Assign neurons based on 'lite' flag
         self._assign_neurons(block, lite, subtensor)
@@ -363,7 +364,7 @@ def sync(

         # If not a 'lite' version, compute and set weights and bonds for each neuron
         if not lite:
-            self._set_weights_and_bonds()
+            self._set_weights_and_bonds(root=root)

     def _initialize_subtensor(self, subtensor):
         """
@@ -473,20 +474,25 @@ def _create_tensor(self, data, dtype) -> torch.nn.Parameter:
         # TODO: Check and test the creation of tensor
         return torch.nn.Parameter(torch.tensor(data, dtype=dtype), requires_grad=False)

-    def _set_weights_and_bonds(self):
+    def _set_weights_and_bonds(self, root: bool = False):
         """
         Computes and sets weights and bonds for each neuron.

         Returns:
             None.
         """
         # TODO: Check and test the computation of weights and bonds
-        self.weights = self._process_weights_or_bonds(
-            [neuron.weights for neuron in self.neurons], "weights"
-        )
-        self.bonds = self._process_weights_or_bonds(
-            [neuron.bonds for neuron in self.neurons], "bonds"
-        )
+        if root:
+            self.weights = self._process_root_weights(
+                [neuron.weights for neuron in self.neurons], "weights"
+            )
+        else:
+            self.weights = self._process_weights_or_bonds(
+                [neuron.weights for neuron in self.neurons], "weights"
+            )
+        self.bonds = self._process_weights_or_bonds(
+            [neuron.bonds for neuron in self.neurons], "bonds"
+        )

     def _process_weights_or_bonds(self, data, attribute: str) -> torch.nn.Parameter:
         """
@@ -528,7 +534,43 @@ def _process_weights_or_bonds(self, data, attribute: str) -> torch.nn.Parameter:
                 f"Empty {attribute}_array on metagraph.sync(). The '{attribute}' tensor is empty."
             )
         return tensor_param
+
+    def _process_root_weights(self, data, attribute: str) -> torch.nn.Parameter:
+        """
+        Processes root weights based on the given attribute.
+
+        Args:
+            data: The weights or bonds data to be processed.
+            attribute: The attribute to decide the type of processing ('weights' or 'bonds').
+
+        Returns:
+            The processed tensor parameter.
+        """
+        data_array = []
+        n_subnets = self.subtensor.get_total_subnets()
+        for item in data:
+            if len(item) == 0:
+                data_array.append(torch.zeros(n_subnets))
+            else:
+                uids, values = zip(*item)
+                # TODO: Validate and test the conversion of uids and values to tensor
+                data_array.append(
+                    bittensor.utils.weight_utils.convert_weight_uids_and_vals_to_tensor(
+                        n_subnets, uids, values
+                    )
+                )
+
+        tensor_param = (
+            torch.nn.Parameter(torch.stack(data_array), requires_grad=False)
+            if len(data_array)
+            else torch.nn.Parameter()
+        )
+        if len(data_array) == 0:
+            bittensor.logging.warning(
+                f"Empty {attribute}_array on metagraph.sync(). The '{attribute}' tensor is empty."
+            )
+        return tensor_param

     def save(self) -> "metagraph":
         """
         Save the state of the metagraph object.
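For root syncs, the new `_process_root_weights` expands each neuron's sparse `(uid, value)` weight pairs into a dense row of length `n_subnets` (from `subtensor.get_total_subnets()`) and stacks the rows into an `(n_neurons, n_subnets)` parameter. Below is a minimal, self-contained sketch of that expansion; the normalization step is an assumption about what `convert_weight_uids_and_vals_to_tensor` does, not code from this PR:

```python
# Illustrative sketch only (not the library implementation): expand each neuron's
# sparse (uid, value) root-weight pairs into a dense per-subnet row, then stack.
from typing import List, Tuple

import torch


def expand_root_weights(
    rows: List[List[Tuple[int, float]]], n_subnets: int
) -> torch.nn.Parameter:
    dense_rows = []
    for row in rows:
        dense = torch.zeros(n_subnets)
        for uid, value in row:
            dense[uid] = float(value)
        # Assumption: values are normalized to sum to 1, mirroring what
        # convert_weight_uids_and_vals_to_tensor is expected to do.
        if dense.sum() > 0:
            dense = dense / dense.sum()
        dense_rows.append(dense)
    if not dense_rows:
        return torch.nn.Parameter()
    return torch.nn.Parameter(torch.stack(dense_rows), requires_grad=False)


# Example: two root validators setting weights over 3 subnets.
weights = expand_root_weights([[(0, 0.5), (2, 0.5)], []], n_subnets=3)
print(weights.shape)  # torch.Size([2, 3])
```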
4 changes: 2 additions & 2 deletions bittensor/subtensor.py
@@ -2306,7 +2306,7 @@ def neurons_lite(
         return NeuronInfoLite.list_from_vec_u8(bytes_result)

     def metagraph(
-        self, netuid: int, lite: bool = True, block: Optional[int] = None
+        self, netuid: int, lite: bool = True, block: Optional[int] = None, root: bool = False
     ) -> "bittensor.Metagraph":
         r"""Returns a synced metagraph for the subnet.
         Args:
@@ -2323,7 +2323,7 @@ def metagraph(
         metagraph_ = bittensor.metagraph(
             network=self.network, netuid=netuid, lite=lite, sync=False
         )
-        metagraph_.sync(block=block, lite=lite, subtensor=self)
+        metagraph_.sync(block=block, lite=lite, subtensor=self, root=root)

         return metagraph_
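With `root` plumbed through `subtensor.metagraph()` and `metagraph.sync()`, a root metagraph can be pulled with a non-lite sync (the flag only takes effect when `lite=False`, since `_set_weights_and_bonds` is skipped otherwise). A hedged usage sketch follows; the network name and the assumption that the root network sits at `netuid` 0 are illustrative, not stated by this PR:

```python
import bittensor

# Illustrative values: "finney" and netuid 0 are assumptions, not part of this PR.
subtensor = bittensor.subtensor(network="finney")

# lite=False so _set_weights_and_bonds runs; root=True routes weight
# processing through _process_root_weights (one column per subnet).
root_metagraph = subtensor.metagraph(netuid=0, lite=False, root=True)

print(root_metagraph.weights.shape)  # expected: (n_neurons, n_subnets)
```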