Use raw Tensor data as BLOB in SQLiteDatabase #8054

Merged · 4 commits · Sep 18, 2023
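This PR adds an optional `schema` argument to `Database` and teaches `SQLiteDatabase` to store tensors as raw byte BLOBs instead of pickled objects. As a rough orientation before reading the diff, here is a minimal usage sketch based on the code below; the file name, table name, and values are made up, and the snippet itself is not part of the PR:

```python
import torch

from torch_geometric.data.database import SQLiteDatabase

# A single-column schema given as a dict with 'dtype'/'size' keys is cast to
# `TensorInfo`, so tensors are written as raw bytes and read back via
# `torch.frombuffer(...).view(...)` rather than going through pickle.
db = SQLiteDatabase('data.sqlite', name='emb',
                    schema=dict(dtype=torch.float, size=(-1, 16)))
db.insert(0, torch.randn(4, 16))
out = db.get(0)  # Tensor of shape [4, 16], rebuilt from the raw buffer
```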
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052))
- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054))
- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
- Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))
4 changes: 2 additions & 2 deletions test/data/test_database.py
@@ -113,11 +113,11 @@ def test_database_syntactic_sugar(tmp_path):
print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds')

def in_memory_get(data):
index = torch.randint(0, args.numel, (128, ))
index = torch.randint(0, args.numel, (args.batch_size, ))
return data[index]

def db_get(db):
index = torch.randint(0, args.numel, (128, ))
index = torch.randint(0, args.numel, (args.batch_size, ))
return db[index]

benchmark(
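The hunk above makes the benchmark's batch size configurable (previously hard-coded to 128). The access pattern it times is the indexing sugar on `Database`, where a tensor of indices dispatches to `multi_get`. A sketch of that pattern, with made-up sizes and file names, purely for illustration:

```python
import torch

from torch_geometric.data.database import SQLiteDatabase

# Illustrative sketch of the benchmarked access pattern (not part of the PR):
db = SQLiteDatabase('bench.sqlite', name='bench',
                    schema=dict(dtype=torch.float))
for i in range(1000):
    db.insert(i, torch.randn(64))
index = torch.randint(0, 1000, (256, ))
rows = db[index]  # batched read; returns a list of 256 tensors of shape [64]
```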
178 changes: 144 additions & 34 deletions torch_geometric/data/database.py
@@ -1,14 +1,51 @@
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Union
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import uuid4

import torch
from torch import Tensor
from tqdm import tqdm

from torch_geometric.utils.mixin import CastMixin


@dataclass
class TensorInfo(CastMixin):
dtype: torch.dtype
size: Tuple[int, ...] = field(default_factory=lambda: (-1, ))


def maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]:
if not isinstance(value, dict):
return value
if len(value) < 1 or len(value) > 2:
return value
if len(value) == 1 and 'dtype' not in value:
return value
if len(value) == 2 and 'dtype' not in value and 'size' not in value:
return value
return TensorInfo.cast(value)


Schema = Union[Any, Dict[str, Any], Tuple[Any], List[Any]]


class Database(ABC):
r"""Base class for database."""
r"""Base class for inserting and retrieving data from a database."""
def __init__(self, schema: Schema = object):
schema = maybe_cast_to_tensor_info(schema)
schema = self._to_dict(schema)
schema = {
key: maybe_cast_to_tensor_info(value)
for key, value in schema.items()
}

self.schema: Dict[Union[str, int], Any] = schema

def connect(self):
pass

@@ -83,18 +120,13 @@ def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
# Helper functions ########################################################

@staticmethod
def serialize(data: Any) -> bytes:
r"""Serializes :obj:`data` into bytes."""
# Ensure that data is not a view of a larger tensor:
if isinstance(data, Tensor):
data = data.clone()

return pickle.dumps(data)

@staticmethod
def deserialize(data: bytes) -> Any:
r"""Deserializes bytes into the original data."""
return pickle.loads(data)
def _to_dict(value) -> Dict[Union[str, int], Any]:
if isinstance(value, dict):
return value
if isinstance(value, (tuple, list)):
return {i: v for i, v in enumerate(value)}
else:
return {0: value}

def slice_to_range(self, indices: slice) -> range:
start = 0 if indices.start is None else indices.start
@@ -136,8 +168,10 @@ def __repr__(self) -> str:


class SQLiteDatabase(Database):
def __init__(self, path: str, name: str):
super().__init__()
def __init__(self, path: str, name: str, schema: Schema = object):
super().__init__(schema)

warnings.filterwarnings('ignore', '.*given buffer is not writable.*')

import sqlite3

@@ -149,9 +183,13 @@ def __init__(self, path: str, name: str):

self.connect()

sql_schema = ',\n'.join([
f' {col_name} {self._to_sql_type(type_info)} NOT NULL' for
col_name, type_info in zip(self._col_names, self.schema.values())
])
query = (f'CREATE TABLE IF NOT EXISTS {self.name} (\n'
f' id INTEGER PRIMARY KEY,\n'
f' data BLOB NOT NULL\n'
f'{sql_schema}\n'
f')')
self.cursor.execute(query)

@@ -174,8 +212,10 @@ def cursor(self) -> Any:
return self._cursor

def insert(self, index: int, data: Any):
query = f'INSERT INTO {self.name} (id, data) VALUES (?, ?)'
self.cursor.execute(query, (index, self.serialize(data)))
query = (f'INSERT INTO {self.name} '
f'(id, {self._joined_col_names}) '
f'VALUES (?, {self._dummies})')
self.cursor.execute(query, (index, self._serialize(data)))

def _multi_insert(
self,
@@ -185,15 +225,18 @@ def _multi_insert(
if isinstance(indices, Tensor):
indices = indices.tolist()

data_list = [self.serialize(data) for data in data_list]
data_list = [self._serialize(data) for data in data_list]

query = f'INSERT INTO {self.name} (id, data) VALUES (?, ?)'
query = (f'INSERT INTO {self.name} '
f'(id, {self._joined_col_names}) '
f'VALUES (?, {self._dummies})')
self.cursor.executemany(query, zip(indices, data_list))

def get(self, index: int) -> Any:
query = f'SELECT data FROM {self.name} WHERE id = ?'
query = (f'SELECT {self._joined_col_names} FROM {self.name} '
f'WHERE id = ?')
self.cursor.execute(query, (index, ))
return self.deserialize(self.cursor.fetchone()[0])
return self._deserialize(self.cursor.fetchone())

def multi_get(
self,
@@ -221,7 +264,7 @@ def multi_get(
query = f'SELECT * FROM {join_table_name}'
self.cursor.execute(query)

query = (f'SELECT {self.name}.data '
query = (f'SELECT {self._joined_col_names} '
f'FROM {self.name} INNER JOIN {join_table_name} '
f'ON {self.name}.id = {join_table_name}.id '
f'ORDER BY {join_table_name}.row_id')
@@ -240,17 +283,77 @@ def multi_get(
query = f'DROP TABLE {join_table_name}'
self.cursor.execute(query)

return [self.deserialize(data[0]) for data in data_list]
return [self._deserialize(data) for data in data_list]

def __len__(self) -> int:
query = f'SELECT COUNT(*) FROM {self.name}'
self.cursor.execute(query)
return self.cursor.fetchone()[0]

# Helper functions ########################################################

@cached_property
def _col_names(self) -> List[str]:
return [f'COL_{key}' for key in self.schema.keys()]

@cached_property
def _joined_col_names(self) -> str:
return ', '.join(self._col_names)

@cached_property
def _dummies(self) -> str:
return ', '.join(['?'] * len(self.schema.keys()))

def _to_sql_type(self, type_info: Any) -> str:
if type_info == int:
return 'INTEGER'
if type_info == float:
return 'FLOAT'
if type_info == str:
return 'TEXT'
else:
return 'BLOB'

def _serialize(self, row: Any) -> Union[Any, List[Any]]:
out_list: List[Any] = []
for key, col in self._to_dict(row).items():
if isinstance(self.schema[key], TensorInfo):
out = col.numpy().tobytes()
elif isinstance(col, Tensor):
self.schema[key] = TensorInfo(dtype=col.dtype)
out = col.numpy().tobytes()
elif self.schema[key] in {int, float, str}:
out = col
else:
out = pickle.dumps(col)

out_list.append(out)

return out_list if len(out_list) > 1 else out_list[0]

def _deserialize(self, row: Tuple[Any]) -> Any:
out_dict = {}
for i, (key, col_schema) in enumerate(self.schema.items()):
if isinstance(col_schema, TensorInfo):
out_dict[key] = torch.frombuffer(
row[i], dtype=col_schema.dtype).view(*col_schema.size)
elif col_schema in {int, float, str}:
out_dict[key] = row[i]
else:
out_dict[key] = pickle.loads(row[i])

if 0 in self.schema:
if len(self.schema) == 1:
return out_dict[0]
else:
return tuple(out_dict.values())
else:
return out_dict


class RocksDatabase(Database):
def __init__(self, path: str):
super().__init__()
def __init__(self, path: str, schema: Schema = object):
super().__init__(schema)

import rocksdict

@@ -283,18 +386,25 @@ def to_key(index: int) -> bytes:
return index.to_bytes(8, byteorder='big', signed=True)

def insert(self, index: int, data: Any):
# Ensure that data is not a view of a larger tensor:
if isinstance(data, Tensor):
data = data.clone()

self.db[self.to_key(index)] = self.serialize(data)
self.db[self.to_key(index)] = self._serialize(data)

def get(self, index: int) -> Any:
return self.deserialize(self.db[self.to_key(index)])
return self._deserialize(self.db[self.to_key(index)])

def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
if isinstance(indices, Tensor):
indices = indices.tolist()
indices = [self.to_key(index) for index in indices]
data_list = self.db[indices]
return [self.deserialize(data) for data in data_list]
return [self._deserialize(data) for data in data_list]

# Helper functions ########################################################

def _serialize(self, row: Any) -> bytes:
# Ensure that data is not a view of a larger tensor:
if isinstance(row, Tensor):
row = row.clone()
return pickle.dumps(row)

def _deserialize(self, row: bytes) -> Any:
return pickle.loads(row)
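Note that only `SQLiteDatabase` gains the raw-bytes path in this PR; `RocksDatabase._serialize` above still clones and pickles the input. Reading the new `SQLiteDatabase._serialize`, it also appears that even under the default `schema=object` the first tensor inserted upgrades the column schema to a `TensorInfo`, so later reads reconstruct a flat tensor from the raw BLOB. A hedged sketch of that inferred-schema behaviour (file and table names are made up, and this reflects my reading of the diff rather than documented behaviour):

```python
import torch

from torch_geometric.data.database import SQLiteDatabase

# Illustrative only: with the default `schema=object`, inserting a tensor
# switches the column schema to `TensorInfo(dtype=...)` on the fly, so the
# value is stored as raw bytes and read back as a flat tensor.
db = SQLiteDatabase('infer.sqlite', name='infer')
db.insert(0, torch.arange(8))
out = db.get(0)  # tensor([0, 1, ..., 7]) of dtype torch.int64, shape [8]
```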