From 782d05fd5503b7f603b84c8d0764eaa7e39ff198 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Wed, 15 Jan 2025 02:45:10 +0000 Subject: [PATCH 01/11] adding DB snapshots --- .gitignore | 1 + docker-compose.yml | 12 ++++++++ scripts/download_snapshots.sh | 15 ++++++++++ surrealdb/data/types/datetime.py | 8 ++--- surrealdb/data/types/duration.py | 3 +- tests/integration/async/test_batch.py | 41 ++++++++++++++++++++++++++ tests/unit/cbor_types/test_datetime.py | 4 ++- 7 files changed, 78 insertions(+), 6 deletions(-) create mode 100644 scripts/download_snapshots.sh create mode 100644 tests/integration/async/test_batch.py diff --git a/.gitignore b/.gitignore index 293663e7..e1536619 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,4 @@ clog/ manifest/ .mypy_cache/ .ruff_cache/ +tests/db_snapshots/ diff --git a/docker-compose.yml b/docker-compose.yml index 2f3b664f..e379a96b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,18 @@ services: ports: - 8000:8000 + surrealdb_big_data: + image: surrealdb/surrealdb:v2.0.0 + command: "start" + environment: + - SURREAL_USER=root + - SURREAL_PASS=root + - SURREAL_LOG=trace + ports: + - 8300:8000 + volumes: + - ./tests/db_snapshots/data/:/data/ + surrealdb_200: image: surrealdb/surrealdb:v2.0.0 command: "start" diff --git a/scripts/download_snapshots.sh b/scripts/download_snapshots.sh new file mode 100644 index 00000000..21f477f8 --- /dev/null +++ b/scripts/download_snapshots.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. +cd tests + +if [ -d "./db_snapshots" ]; then + echo "DB snapshots are already present" + rm -rf ./db_snapshots +fi + +dockpack pull -i maxwellflitton/surrealdb-data -d ./db_snapshots diff --git a/surrealdb/data/types/datetime.py b/surrealdb/data/types/datetime.py index 30a4b022..53526d96 100644 --- a/surrealdb/data/types/datetime.py +++ b/surrealdb/data/types/datetime.py @@ -1,17 +1,17 @@ -import pytz # type: ignore - from dataclasses import dataclass from datetime import datetime -from math import floor from typing import Tuple +import pytz # type: ignore +from math import floor + @dataclass class DateTimeCompact: timestamp: int = 0 # nanoseconds @staticmethod - def parse(seconds: int, nanoseconds: int): + def parse(seconds: int, nanoseconds: int) -> "DateTimeCompact": return DateTimeCompact(nanoseconds + (seconds * pow(10, 9))) def get_seconds_and_nano(self) -> Tuple[int, int]: diff --git a/surrealdb/data/types/duration.py b/surrealdb/data/types/duration.py index 756e3771..09b94025 100644 --- a/surrealdb/data/types/duration.py +++ b/surrealdb/data/types/duration.py @@ -1,7 +1,8 @@ from dataclasses import dataclass -from math import floor from typing import Tuple +from math import floor + @dataclass class Duration: diff --git a/tests/integration/async/test_batch.py b/tests/integration/async/test_batch.py new file mode 100644 index 00000000..dcac1084 --- /dev/null +++ b/tests/integration/async/test_batch.py @@ -0,0 +1,41 @@ +from typing import List +from unittest import TestCase, main + +from surrealdb import SurrealDB, RecordID +from tests.integration.connection_params import TestConnectionParams +import asyncio +import websockets + + +class TestBatch(TestCase): + + def setUp(self) -> None: + self.params = TestConnectionParams() + self.db = SurrealDB(self.params.url) + self.queries: List[str] = [] + + self.db.connect() + self.db.use(self.params.namespace, self.params.database) + self.db.sign_in("root", "root") + + # self.query = """ + # CREATE |product:1000000| CONTENT { + # name: rand::enum('Cruiser Hoodie', 'Surreal T-Shirt'), + # colours: [rand::string(10), rand::string(10),], + # price: rand::int(20, 60), + # time: { + # created_at: rand::time(1611847404, 1706455404), + # updated_at: rand::time(1651155804, 1716906204) + # } + # }; + # """ + # self.db.query(query=self.query) + + def tearDown(self) -> None: + pass + + def test_batch(self): + print("test_batch") + +if __name__ == '__main__': + main() diff --git a/tests/unit/cbor_types/test_datetime.py b/tests/unit/cbor_types/test_datetime.py index 8bd284d1..ec3f4249 100644 --- a/tests/unit/cbor_types/test_datetime.py +++ b/tests/unit/cbor_types/test_datetime.py @@ -1,4 +1,4 @@ -from unittest import TestCase +from unittest import TestCase, main from surrealdb.data.types.datetime import DateTimeCompact from surrealdb.data.cbor import encode, decode @@ -21,3 +21,5 @@ def test_datetime(self): self.assertEqual(decoded.get_date_time(), '2024-12-12T09:00:58.083988Z') +if __name__ == '__main__': + main() From 69d708506dc8e16992d04b1bebd3ea431b1f314a Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Mon, 3 Feb 2025 20:07:06 +0000 Subject: [PATCH 02/11] adding datetime fix --- docker-compose.yml | 37 +---- src/surrealdb/__init__.py | 3 +- src/surrealdb/data/cbor.py | 35 +++- src/surrealdb/data/types/datetime.py | 153 +++++++++++------- tests/unit_tests/data_types/__init__.py | 0 tests/unit_tests/data_types/test_datetimes.py | 64 ++++++++ 6 files changed, 186 insertions(+), 106 deletions(-) create mode 100644 tests/unit_tests/data_types/__init__.py create mode 100644 tests/unit_tests/data_types/test_datetimes.py diff --git a/docker-compose.yml b/docker-compose.yml index c4a42a53..e43d8ab8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,42 +9,7 @@ services: - SURREAL_INSECURE_FORWARD_ACCESS_ERRORS=true - SURREAL_LOG=debug ports: -<<<<<<< HEAD - - 8300:8000 - - surrealdb_big_data: - image: surrealdb-big-data -# environment: -# - SURREAL_USER=root -# - SURREAL_PASS=root -# - SURREAL_LOG=trace - ports: - - 9000:8000 - - surrealdb_big_data: - image: surrealdb/surrealdb:v2.0.0 - command: "start" - environment: - - SURREAL_USER=root - - SURREAL_PASS=root - - SURREAL_LOG=trace - ports: - - 8300:8000 - volumes: - - ./tests/db_snapshots/data/:/data/ - - surrealdb_200: - image: surrealdb/surrealdb:v2.0.0 - command: "start" - environment: - - SURREAL_USER=root - - SURREAL_PASS=root - - SURREAL_LOG=trace - ports: - - 8200:8000 -======= - 8000:8000 ->>>>>>> 0c8a24aca5689176fb3ee89398c08b2cd8ced4a1 surrealdb_121: image: surrealdb/surrealdb:v1.2.1 @@ -84,4 +49,4 @@ services: - SURREAL_PASS=root - SURREAL_LOG=trace ports: - - 8111:8000 + - 8111:8000 \ No newline at end of file diff --git a/src/surrealdb/__init__.py b/src/surrealdb/__init__.py index d19f704a..fa7357ac 100644 --- a/src/surrealdb/__init__.py +++ b/src/surrealdb/__init__.py @@ -7,12 +7,13 @@ from surrealdb.data.types.table import Table from surrealdb.data.types.constants import * -from surrealdb.data.types.datetime import DateTimeCompact from surrealdb.data.types.duration import Duration from surrealdb.data.types.future import Future from surrealdb.data.types.geometry import Geometry from surrealdb.data.types.range import Range from surrealdb.data.types.record_id import RecordID +from surrealdb.data.types.datetime import DatetimeWrapper +from surrealdb.data.types.datetime import IsoDateTimeWrapper class AsyncSurrealDBMeta(type): diff --git a/src/surrealdb/data/cbor.py b/src/surrealdb/data/cbor.py index ee0bce90..3802de27 100644 --- a/src/surrealdb/data/cbor.py +++ b/src/surrealdb/data/cbor.py @@ -1,7 +1,6 @@ import cbor2 from surrealdb.data.types import constants -from surrealdb.data.types.datetime import DateTimeCompact from surrealdb.data.types.duration import Duration from surrealdb.data.types.future import Future from surrealdb.data.types.geometry import ( @@ -16,11 +15,13 @@ from surrealdb.data.types.range import BoundIncluded, BoundExcluded, Range from surrealdb.data.types.record_id import RecordID from surrealdb.data.types.table import Table +from surrealdb.data.types.datetime import DatetimeWrapper, IsoDateTimeWrapper +from datetime import datetime, timedelta, timezone +import pytz @cbor2.shareable_encoder def default_encoder(encoder, obj): - if isinstance(obj, GeometryPoint): tagged = cbor2.CBORTag(constants.TAG_GEOMETRY_POINT, obj.get_coordinates()) @@ -67,17 +68,21 @@ def default_encoder(encoder, obj): elif isinstance(obj, Duration): tagged = cbor2.CBORTag(constants.TAG_DURATION, obj.get_seconds_and_nano()) - elif isinstance(obj, DateTimeCompact): + elif isinstance(obj, DatetimeWrapper): + if obj.dt.tzinfo is None: # Make sure it's timezone-aware + obj.dt = obj.dt.replace(tzinfo=timezone.utc) + tagged = cbor2.CBORTag( - constants.TAG_DATETIME_COMPACT, obj.get_seconds_and_nano() + constants.TAG_DATETIME_COMPACT, + [int(obj.dt.timestamp()), obj.dt.microsecond * 1000] ) - + elif isinstance(obj, IsoDateTimeWrapper): + tagged = cbor2.CBORTag(constants.TAG_DATETIME, obj.dt) else: raise BufferError("no encoder for type ", type(obj)) encoder.encode(tagged) - def tag_decoder(decoder, tag, shareable_index=None): if tag.tag == constants.TAG_GEOMETRY_POINT: return GeometryPoint.parse_coordinates(tag.value) @@ -118,16 +123,30 @@ def tag_decoder(decoder, tag, shareable_index=None): elif tag.tag == constants.TAG_RANGE: return Range(tag.value[0], tag.value[1]) + elif tag.tag == constants.TAG_DURATION_COMPACT: + return Duration.parse(tag.value[0], tag.value[1]) # Two numbers (s, ns) + elif tag.tag == constants.TAG_DURATION: - return Duration.parse(tag.value[0], tag.value[1]) + return Duration.parse(tag.value) # String (e.g., "1d3m5ms") elif tag.tag == constants.TAG_DATETIME_COMPACT: - return DateTimeCompact.parse(tag.value[0], tag.value[1]) + # TODO => convert [seconds, nanoseconds] => return datetime + seconds = tag.value[0] + nanoseconds = tag.value[1] + microseconds = nanoseconds // 1000 # Convert nanoseconds to microseconds + return DatetimeWrapper( + datetime.fromtimestamp(seconds) + timedelta(microseconds=microseconds) + ) + + elif tag.tag == constants.TAG_DATETIME: + dt_obj = datetime.fromisoformat(tag.value) + return DatetimeWrapper(dt_obj)# String (ISO 8601 datetime) else: raise BufferError("no decoder for tag", tag.tag) + def encode(obj): return cbor2.dumps(obj, default=default_encoder) diff --git a/src/surrealdb/data/types/datetime.py b/src/surrealdb/data/types/datetime.py index 19b4961a..750b94a0 100644 --- a/src/surrealdb/data/types/datetime.py +++ b/src/surrealdb/data/types/datetime.py @@ -1,72 +1,103 @@ -""" -Defines a compact representation of datetime using nanoseconds. -""" from dataclasses import dataclass -from datetime import datetime -from typing import Tuple -import pytz # type: ignore from math import floor +from typing import Tuple +from datetime import datetime +import pytz -@dataclass -class DateTimeCompact: - """ - Represents a compact datetime object stored as a single integer value in nanoseconds. +class DatetimeWrapper: - Attributes: - timestamp: The number of nanoseconds since the epoch (1970-01-01T00:00:00Z). - """ - timestamp: int = 0 # nanoseconds + def __init__(self, dt: datetime): + self.dt = dt @staticmethod - def parse(seconds: int, nanoseconds: int) -> "DateTimeCompact": - """ - Creates a DateTimeCompact object from seconds and nanoseconds. - - Args: - seconds: The number of seconds since the epoch. - nanoseconds: The additional nanoseconds beyond the seconds. - - Returns: - A DateTimeCompact object representing the specified time. - """ - return DateTimeCompact(nanoseconds + (seconds * pow(10, 9))) - - def get_seconds_and_nano(self) -> Tuple[int, int]: - """ - Extracts the seconds and nanoseconds components from the timestamp. - - Returns: - A tuple containing: - - The number of seconds since the epoch. - - The remaining nanoseconds after the seconds. - """ - sec = floor(self.timestamp / pow(10, 9)) - nsec = self.timestamp - (sec * pow(10, 9)) - return sec, nsec - - def get_date_time(self, fmt: str = "%Y-%m-%dT%H:%M:%S.%fZ") -> str: - """ - Converts the timestamp into a formatted datetime string. - - Args: - fmt: The format string for the datetime. Defaults to ISO 8601 format. + def now() -> "DatetimeWrapper": + return DatetimeWrapper(datetime.now()) - Returns: - A string representation of the datetime in the specified format. - """ - return datetime.fromtimestamp(self.timestamp / pow(10, 9), pytz.UTC).strftime(fmt) - def __eq__(self, other: object) -> bool: - """ - Compares two DateTimeCompact objects for equality. +class IsoDateTimeWrapper: - Args: - other: The object to compare against. + def __init__(self, dt: str): + self.dt = dt - Returns: - True if the objects have the same timestamp, False otherwise. - """ - if isinstance(other, DateTimeCompact): - return self.timestamp == other.timestamp - return False +# @dataclass +# class DateTimeCompact: +# """ +# Represents a compact datetime object stored as a single integer value in nanoseconds. +# +# Attributes: +# timestamp: The number of nanoseconds since the epoch (1970-01-01T00:00:00Z). +# """ +# timestamp: int = 0 # nanoseconds +# +# @staticmethod +# def from_datetime(dt: datetime) -> 'DateTimeCompact': +# pass +# +# @staticmethod +# def parse(seconds: int, nanoseconds: int) -> "DateTimeCompact": +# """ +# Creates a DateTimeCompact object from seconds and nanoseconds. +# +# Args: +# seconds: The number of seconds since the epoch. +# nanoseconds: The additional nanoseconds beyond the seconds. +# +# Returns: +# A DateTimeCompact object representing the specified time. +# """ +# return DateTimeCompact(nanoseconds + (seconds * pow(10, 9))) +# +# @staticmethod +# def from_iso_string(iso_str: str) -> "DateTimeCompact": +# """ +# Creates a DateTimeCompact object from an ISO 8601 datetime string. +# +# Args: +# iso_str: A string representation of a datetime in ISO 8601 format. +# +# Returns: +# A DateTimeCompact object. +# """ +# dt = datetime.fromisoformat(iso_str.replace("Z", "+00:00")) # Handle UTC 'Z' case +# timestamp = int(dt.timestamp() * pow(10, 9)) # Convert to nanoseconds +# return DateTimeCompact(timestamp) +# +# def get_seconds_and_nano(self) -> Tuple[int, int]: +# """ +# Extracts the seconds and nanoseconds components from the timestamp. +# +# Returns: +# A tuple containing: +# - The number of seconds since the epoch. +# - The remaining nanoseconds after the seconds. +# """ +# sec = floor(self.timestamp / pow(10, 9)) +# nsec = self.timestamp - (sec * pow(10, 9)) +# return sec, nsec +# +# def get_date_time(self, fmt: str = "%Y-%m-%dT%H:%M:%S.%fZ") -> str: +# """ +# Converts the timestamp into a formatted datetime string. +# +# Args: +# fmt: The format string for the datetime. Defaults to ISO 8601 format. +# +# Returns: +# A string representation of the datetime in the specified format. +# """ +# return datetime.fromtimestamp(self.timestamp / pow(10, 9), pytz.UTC).strftime(fmt) +# +# def __eq__(self, other: object) -> bool: +# """ +# Compares two DateTimeCompact objects for equality. +# +# Args: +# other: The object to compare against. +# +# Returns: +# True if the objects have the same timestamp, False otherwise. +# """ +# if isinstance(other, DateTimeCompact): +# return self.timestamp == other.timestamp +# return False diff --git a/tests/unit_tests/data_types/__init__.py b/tests/unit_tests/data_types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/data_types/test_datetimes.py b/tests/unit_tests/data_types/test_datetimes.py new file mode 100644 index 00000000..4eb7484b --- /dev/null +++ b/tests/unit_tests/data_types/test_datetimes.py @@ -0,0 +1,64 @@ +from unittest import main, IsolatedAsyncioTestCase + +from surrealdb.connections.async_ws import AsyncWsSurrealConnection +from surrealdb.data.types.datetime import DatetimeWrapper, IsoDateTimeWrapper + + +class TestAsyncWsSurrealConnectionDatetime(IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.url = "ws://localhost:8000/rpc" + self.password = "root" + self.username = "root" + self.vars_params = { + "username": self.username, + "password": self.password, + } + self.database_name = "test_db" + self.namespace = "test_ns" + self.connection = AsyncWsSurrealConnection(self.url) + + # Sign in and select DB + await self.connection.signin(self.vars_params) + await self.connection.use(namespace=self.namespace, database=self.database_name) + + # Cleanup + await self.connection.query("DELETE datetime_tests;") + + async def test_datetime_wrapper(self): + now = DatetimeWrapper.now() + await self.connection.query( + "CREATE datetime_tests:compact_test SET datetime = $compact_datetime;", + params={"compact_datetime": now} + ) + compact_test_outcome = await self.connection.query("SELECT * FROM datetime_tests;") + self.assertEqual( + type(compact_test_outcome[0]["datetime"]), + DatetimeWrapper + ) + await self.connection.query("DELETE datetime_tests;") + await self.connection.close() + + async def test_datetime_formats(self): + # Define datetime values + iso_datetime = "2025-02-03T12:30:45.123456Z" # ISO 8601 datetime + date = IsoDateTimeWrapper(iso_datetime) + + # Insert records with different datetime formats + await self.connection.query( + "CREATE datetime_tests:iso_tests SET datetime = $iso_datetime;", + params={"iso_datetime": date} + ) + compact_test_outcome = await self.connection.query("SELECT * FROM datetime_tests;") + self.assertEqual( + type(compact_test_outcome[0]["datetime"]), + DatetimeWrapper + ) + + # Cleanup + await self.connection.query("DELETE datetime_tests;") + await self.connection.socket.close() + + +if __name__ == "__main__": + main() From 1f5dd0d26f4a7f9f8adafcfc700d76ff210d789d Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Mon, 3 Feb 2025 20:17:16 +0000 Subject: [PATCH 03/11] datetime now working/tested --- tests/unit_tests/data_types/test_datetimes.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/data_types/test_datetimes.py b/tests/unit_tests/data_types/test_datetimes.py index 4eb7484b..6fb45000 100644 --- a/tests/unit_tests/data_types/test_datetimes.py +++ b/tests/unit_tests/data_types/test_datetimes.py @@ -36,11 +36,15 @@ async def test_datetime_wrapper(self): type(compact_test_outcome[0]["datetime"]), DatetimeWrapper ) + + # assert that the datetime returned from the DB is the same as the one serialized + outcome = compact_test_outcome[0]["datetime"] + self.assertEqual(now.dt.isoformat(), outcome.dt.isoformat() + "+00:00") + await self.connection.query("DELETE datetime_tests;") await self.connection.close() async def test_datetime_formats(self): - # Define datetime values iso_datetime = "2025-02-03T12:30:45.123456Z" # ISO 8601 datetime date = IsoDateTimeWrapper(iso_datetime) @@ -55,6 +59,10 @@ async def test_datetime_formats(self): DatetimeWrapper ) + # assert that the datetime returned from the DB is the same as the one serialized + date = compact_test_outcome[0]["datetime"].dt.isoformat() + self.assertEqual(date + "Z", iso_datetime) + # Cleanup await self.connection.query("DELETE datetime_tests;") await self.connection.socket.close() From a831fef9b1e2e8ee5fabf1cf54775abf5e8049ad Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Mon, 3 Feb 2025 20:22:51 +0000 Subject: [PATCH 04/11] datetime now working/tested --- src/surrealdb/data/types/datetime.py | 86 ---------------------------- 1 file changed, 86 deletions(-) diff --git a/src/surrealdb/data/types/datetime.py b/src/surrealdb/data/types/datetime.py index 750b94a0..3984eec7 100644 --- a/src/surrealdb/data/types/datetime.py +++ b/src/surrealdb/data/types/datetime.py @@ -1,8 +1,4 @@ -from dataclasses import dataclass -from math import floor -from typing import Tuple from datetime import datetime -import pytz class DatetimeWrapper: @@ -19,85 +15,3 @@ class IsoDateTimeWrapper: def __init__(self, dt: str): self.dt = dt - -# @dataclass -# class DateTimeCompact: -# """ -# Represents a compact datetime object stored as a single integer value in nanoseconds. -# -# Attributes: -# timestamp: The number of nanoseconds since the epoch (1970-01-01T00:00:00Z). -# """ -# timestamp: int = 0 # nanoseconds -# -# @staticmethod -# def from_datetime(dt: datetime) -> 'DateTimeCompact': -# pass -# -# @staticmethod -# def parse(seconds: int, nanoseconds: int) -> "DateTimeCompact": -# """ -# Creates a DateTimeCompact object from seconds and nanoseconds. -# -# Args: -# seconds: The number of seconds since the epoch. -# nanoseconds: The additional nanoseconds beyond the seconds. -# -# Returns: -# A DateTimeCompact object representing the specified time. -# """ -# return DateTimeCompact(nanoseconds + (seconds * pow(10, 9))) -# -# @staticmethod -# def from_iso_string(iso_str: str) -> "DateTimeCompact": -# """ -# Creates a DateTimeCompact object from an ISO 8601 datetime string. -# -# Args: -# iso_str: A string representation of a datetime in ISO 8601 format. -# -# Returns: -# A DateTimeCompact object. -# """ -# dt = datetime.fromisoformat(iso_str.replace("Z", "+00:00")) # Handle UTC 'Z' case -# timestamp = int(dt.timestamp() * pow(10, 9)) # Convert to nanoseconds -# return DateTimeCompact(timestamp) -# -# def get_seconds_and_nano(self) -> Tuple[int, int]: -# """ -# Extracts the seconds and nanoseconds components from the timestamp. -# -# Returns: -# A tuple containing: -# - The number of seconds since the epoch. -# - The remaining nanoseconds after the seconds. -# """ -# sec = floor(self.timestamp / pow(10, 9)) -# nsec = self.timestamp - (sec * pow(10, 9)) -# return sec, nsec -# -# def get_date_time(self, fmt: str = "%Y-%m-%dT%H:%M:%S.%fZ") -> str: -# """ -# Converts the timestamp into a formatted datetime string. -# -# Args: -# fmt: The format string for the datetime. Defaults to ISO 8601 format. -# -# Returns: -# A string representation of the datetime in the specified format. -# """ -# return datetime.fromtimestamp(self.timestamp / pow(10, 9), pytz.UTC).strftime(fmt) -# -# def __eq__(self, other: object) -> bool: -# """ -# Compares two DateTimeCompact objects for equality. -# -# Args: -# other: The object to compare against. -# -# Returns: -# True if the objects have the same timestamp, False otherwise. -# """ -# if isinstance(other, DateTimeCompact): -# return self.timestamp == other.timestamp -# return False From 369a641f1015c085cfaca72f3b313f1dc27cbb42 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Tue, 4 Feb 2025 14:04:40 +0000 Subject: [PATCH 05/11] updating tests --- docker-compose.yml | 5 +++ src/surrealdb/__init__.py | 1 - src/surrealdb/connections/async_http.py | 2 +- src/surrealdb/data/cbor.py | 24 +++--------- src/surrealdb/data/types/datetime.py | 18 ++++----- .../connections/invalidate/test_async_http.py | 34 +++++++++++++--- .../connections/invalidate/test_async_ws.py | 25 +++++++++--- .../invalidate/test_blocking_http.py | 36 ++++++++++++----- .../invalidate/test_blocking_ws.py | 39 +++++++++++++------ tests/unit_tests/data_types/test_datetimes.py | 18 +++++---- 10 files changed, 132 insertions(+), 70 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index e43d8ab8..97207000 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,7 @@ services: - SURREAL_PASS=root - SURREAL_INSECURE_FORWARD_ACCESS_ERRORS=true - SURREAL_LOG=debug + - SURREAL_CAPS_ALLOW_GUESTS=true ports: - 8000:8000 @@ -18,6 +19,7 @@ services: - SURREAL_USER=root - SURREAL_PASS=root - SURREAL_LOG=trace + - SURREAL_CAPS_ALLOW_GUESTS=true ports: - 8121:8000 @@ -28,6 +30,7 @@ services: - SURREAL_USER=root - SURREAL_PASS=root - SURREAL_LOG=trace + - SURREAL_CAPS_ALLOW_GUESTS=true ports: - 8120:8000 @@ -38,6 +41,7 @@ services: - SURREAL_USER=root - SURREAL_PASS=root - SURREAL_LOG=trace + - SURREAL_CAPS_ALLOW_GUESTS=true ports: - 8101:8000 @@ -48,5 +52,6 @@ services: - SURREAL_USER=root - SURREAL_PASS=root - SURREAL_LOG=trace + - SURREAL_CAPS_ALLOW_GUESTS=true ports: - 8111:8000 \ No newline at end of file diff --git a/src/surrealdb/__init__.py b/src/surrealdb/__init__.py index fa7357ac..10f1aa68 100644 --- a/src/surrealdb/__init__.py +++ b/src/surrealdb/__init__.py @@ -12,7 +12,6 @@ from surrealdb.data.types.geometry import Geometry from surrealdb.data.types.range import Range from surrealdb.data.types.record_id import RecordID -from surrealdb.data.types.datetime import DatetimeWrapper from surrealdb.data.types.datetime import IsoDateTimeWrapper class AsyncSurrealDBMeta(type): diff --git a/src/surrealdb/connections/async_http.py b/src/surrealdb/connections/async_http.py index 6501f5ad..efd7a126 100644 --- a/src/surrealdb/connections/async_http.py +++ b/src/surrealdb/connections/async_http.py @@ -101,7 +101,7 @@ async def authenticate(self) -> None: message = RequestMessage( self.id, RequestMethod.AUTHENTICATE, - token=token + token=self.token ) return await self._send(message, "authenticating") diff --git a/src/surrealdb/data/cbor.py b/src/surrealdb/data/cbor.py index 3802de27..c1125df6 100644 --- a/src/surrealdb/data/cbor.py +++ b/src/surrealdb/data/cbor.py @@ -1,6 +1,9 @@ +from datetime import datetime, timedelta, timezone + import cbor2 from surrealdb.data.types import constants +from surrealdb.data.types.datetime import IsoDateTimeWrapper from surrealdb.data.types.duration import Duration from surrealdb.data.types.future import Future from surrealdb.data.types.geometry import ( @@ -15,9 +18,6 @@ from surrealdb.data.types.range import BoundIncluded, BoundExcluded, Range from surrealdb.data.types.record_id import RecordID from surrealdb.data.types.table import Table -from surrealdb.data.types.datetime import DatetimeWrapper, IsoDateTimeWrapper -from datetime import datetime, timedelta, timezone -import pytz @cbor2.shareable_encoder @@ -68,14 +68,6 @@ def default_encoder(encoder, obj): elif isinstance(obj, Duration): tagged = cbor2.CBORTag(constants.TAG_DURATION, obj.get_seconds_and_nano()) - elif isinstance(obj, DatetimeWrapper): - if obj.dt.tzinfo is None: # Make sure it's timezone-aware - obj.dt = obj.dt.replace(tzinfo=timezone.utc) - - tagged = cbor2.CBORTag( - constants.TAG_DATETIME_COMPACT, - [int(obj.dt.timestamp()), obj.dt.microsecond * 1000] - ) elif isinstance(obj, IsoDateTimeWrapper): tagged = cbor2.CBORTag(constants.TAG_DATETIME, obj.dt) else: @@ -134,13 +126,7 @@ def tag_decoder(decoder, tag, shareable_index=None): seconds = tag.value[0] nanoseconds = tag.value[1] microseconds = nanoseconds // 1000 # Convert nanoseconds to microseconds - return DatetimeWrapper( - datetime.fromtimestamp(seconds) + timedelta(microseconds=microseconds) - ) - - elif tag.tag == constants.TAG_DATETIME: - dt_obj = datetime.fromisoformat(tag.value) - return DatetimeWrapper(dt_obj)# String (ISO 8601 datetime) + return datetime.fromtimestamp(seconds) + timedelta(microseconds=microseconds) else: raise BufferError("no decoder for tag", tag.tag) @@ -148,7 +134,7 @@ def tag_decoder(decoder, tag, shareable_index=None): def encode(obj): - return cbor2.dumps(obj, default=default_encoder) + return cbor2.dumps(obj, default=default_encoder, timezone=timezone.utc) def decode(data): diff --git a/src/surrealdb/data/types/datetime.py b/src/surrealdb/data/types/datetime.py index 3984eec7..cff8d2b1 100644 --- a/src/surrealdb/data/types/datetime.py +++ b/src/surrealdb/data/types/datetime.py @@ -1,14 +1,14 @@ -from datetime import datetime +# from datetime import datetime -class DatetimeWrapper: - - def __init__(self, dt: datetime): - self.dt = dt - - @staticmethod - def now() -> "DatetimeWrapper": - return DatetimeWrapper(datetime.now()) +# class DatetimeWrapper: +# +# def __init__(self, dt: datetime): +# self.dt = dt +# +# @staticmethod +# def now() -> "DatetimeWrapper": +# return DatetimeWrapper(datetime.now()) class IsoDateTimeWrapper: diff --git a/tests/unit_tests/connections/invalidate/test_async_http.py b/tests/unit_tests/connections/invalidate/test_async_http.py index 4e6df53e..0a038d8a 100644 --- a/tests/unit_tests/connections/invalidate/test_async_http.py +++ b/tests/unit_tests/connections/invalidate/test_async_http.py @@ -25,7 +25,17 @@ async def asyncSetUp(self): _ = await self.connection.signin(self.vars_params) _ = await self.connection.use(namespace=self.namespace, database=self.database_name) - async def test_invalidate(self): + async def test_run_test(self): + if os.environ.get("NO_GUEST_MODE") == "True": + await self.invalidate_test_for_no_guest_mode() + else: + await self.invalidate_with_guest_mode_on() + + async def invalidate_with_guest_mode_on(self): + """ + This test only works if the SURREAL_CAPS_ALLOW_GUESTS=false is set in the docker container + + """ outcome = await self.connection.query("SELECT * FROM user;") self.assertEqual(1, len(outcome)) outcome = await self.main_connection.query("SELECT * FROM user;") @@ -37,18 +47,30 @@ async def test_invalidate(self): self.assertEqual(0, len(outcome)) outcome = await self.main_connection.query("SELECT * FROM user;") self.assertEqual(1, len(outcome)) + await self.main_connection.query("DELETE user;") + + async def invalidate_test_for_no_guest_mode(self): + """ + This test asserts that there is an error thrown due to no guest mode being allowed + Only run this test if SURREAL_CAPS_ALLOW_GUESTS=false is set in the docker container + """ + outcome = await self.connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + outcome = await self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + + _ = await self.connection.invalidate() - ''' - # Exceptions are raised only when SurrealDB doesn't allow guest mode with self.assertRaises(Exception) as context: - _ = await self.connection.query("CREATE user:jaime SET name = 'Jaime';") + _ = await self.connection.query("SELECT * FROM user;") self.assertEqual( "IAM error: Not enough permissions" in str(context.exception), True ) - ''' - + outcome = await self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) await self.main_connection.query("DELETE user;") + if __name__ == "__main__": main() diff --git a/tests/unit_tests/connections/invalidate/test_async_ws.py b/tests/unit_tests/connections/invalidate/test_async_ws.py index efa4e5e9..1faf24c3 100644 --- a/tests/unit_tests/connections/invalidate/test_async_ws.py +++ b/tests/unit_tests/connections/invalidate/test_async_ws.py @@ -25,7 +25,13 @@ async def asyncSetUp(self): _ = await self.connection.signin(self.vars_params) _ = await self.connection.use(namespace=self.namespace, database=self.database_name) - async def test_invalidate(self): + async def test_run_test(self): + if os.environ.get("NO_GUEST_MODE") == "True": + await self.invalidate_test_for_no_guest_mode() + else: + await self.invalidate_with_guest_mode_on() + + async def invalidate_with_guest_mode_on(self): outcome = await self.connection.query("SELECT * FROM user;") self.assertEqual(1, len(outcome)) outcome = await self.main_connection.query("SELECT * FROM user;") @@ -37,16 +43,25 @@ async def test_invalidate(self): self.assertEqual(0, len(outcome)) outcome = await self.main_connection.query("SELECT * FROM user;") self.assertEqual(1, len(outcome)) + await self.main_connection.query("DELETE user;") + + async def invalidate_test_for_no_guest_mode(self): + outcome = await self.connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + outcome = await self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + + _ = await self.connection.invalidate() - ''' - # Exceptions are raised only when SurrealDB doesn't allow guest mode with self.assertRaises(Exception) as context: - _ = await self.connection.query("CREATE user:jaime SET name = 'Jaime';") + _ = await self.connection.query("SELECT * FROM user;") + self.assertEqual( "IAM error: Not enough permissions" in str(context.exception), True ) - ''' + outcome = await self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) await self.main_connection.query("DELETE user;") await self.main_connection.close() diff --git a/tests/unit_tests/connections/invalidate/test_blocking_http.py b/tests/unit_tests/connections/invalidate/test_blocking_http.py index b181cf33..d5294069 100644 --- a/tests/unit_tests/connections/invalidate/test_blocking_http.py +++ b/tests/unit_tests/connections/invalidate/test_blocking_http.py @@ -25,7 +25,13 @@ def setUp(self): _ = self.connection.signin(self.vars_params) _ = self.connection.use(namespace=self.namespace, database=self.database_name) - def test_invalidate(self): + def test_run_test(self): + if os.environ.get("NO_GUEST_MODE") == "True": + self.invalidate_test_for_no_guest_mode() + else: + self.invalidate_with_guest_mode_on() + + def invalidate_test_for_no_guest_mode(self): outcome = self.connection.query("SELECT * FROM user;") self.assertEqual(1, len(outcome)) outcome = self.main_connection.query("SELECT * FROM user;") @@ -33,20 +39,30 @@ def test_invalidate(self): _ = self.connection.invalidate() - outcome = self.connection.query("SELECT * FROM user;") - self.assertEqual(0, len(outcome)) - outcome = self.main_connection.query("SELECT * FROM user;") - self.assertEqual(1, len(outcome)) - - ''' - # Exceptions are raised only when SurrealDB doesn't allow guest mode with self.assertRaises(Exception) as context: - _ = self.connection.query("CREATE user:jaime SET name = 'Jaime';") + _ = self.connection.query("SELECT * FROM user;") + self.assertEqual( "IAM error: Not enough permissions" in str(context.exception), True ) - ''' + outcome = self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + + self.main_connection.query("DELETE user;") + + def invalidate_with_guest_mode_on(self): + outcome = self.connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + outcome = self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + + _ = self.connection.invalidate() + + outcome = self.connection.query("SELECT * FROM user;") + self.assertEqual(0, len(outcome)) + outcome = self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) self.main_connection.query("DELETE user;") diff --git a/tests/unit_tests/connections/invalidate/test_blocking_ws.py b/tests/unit_tests/connections/invalidate/test_blocking_ws.py index ad841d18..deb81623 100644 --- a/tests/unit_tests/connections/invalidate/test_blocking_ws.py +++ b/tests/unit_tests/connections/invalidate/test_blocking_ws.py @@ -1,5 +1,6 @@ -from unittest import main, TestCase import os +from unittest import main, TestCase + from surrealdb.connections.blocking_ws import BlockingWsSurrealConnection @@ -25,7 +26,13 @@ def setUp(self): _ = self.connection.signin(self.vars_params) _ = self.connection.use(namespace=self.namespace, database=self.database_name) - def test_invalidate(self): + def test_run_test(self): + if os.environ.get("NO_GUEST_MODE") == "True": + self.invalidate_test_for_no_guest_mode() + else: + self.invalidate_with_guest_mode_on() + + def invalidate_test_for_no_guest_mode(self): outcome = self.connection.query("SELECT * FROM user;") self.assertEqual(1, len(outcome)) outcome = self.main_connection.query("SELECT * FROM user;") @@ -33,20 +40,30 @@ def test_invalidate(self): _ = self.connection.invalidate() - outcome = self.connection.query("SELECT * FROM user;") - self.assertEqual(0, len(outcome)) - outcome = self.main_connection.query("SELECT * FROM user;") - self.assertEqual(1, len(outcome)) - - ''' - # Exceptions are raised only when SurrealDB doesn't allow guest mode with self.assertRaises(Exception) as context: - _ = self.connection.query("CREATE user:jaime SET name = 'Jaime';") + _ = self.connection.query("SELECT * FROM user;") + self.assertEqual( "IAM error: Not enough permissions" in str(context.exception), True ) - ''' + outcome = self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + + self.main_connection.query("DELETE user;") + + def invalidate_with_guest_mode_on(self): + outcome = self.connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + outcome = self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) + + _ = self.connection.invalidate() + + outcome = self.connection.query("SELECT * FROM user;") + self.assertEqual(0, len(outcome)) + outcome = self.main_connection.query("SELECT * FROM user;") + self.assertEqual(1, len(outcome)) self.main_connection.query("DELETE user;") self.main_connection.close() diff --git a/tests/unit_tests/data_types/test_datetimes.py b/tests/unit_tests/data_types/test_datetimes.py index 6fb45000..6416eae8 100644 --- a/tests/unit_tests/data_types/test_datetimes.py +++ b/tests/unit_tests/data_types/test_datetimes.py @@ -1,7 +1,8 @@ +import datetime from unittest import main, IsolatedAsyncioTestCase from surrealdb.connections.async_ws import AsyncWsSurrealConnection -from surrealdb.data.types.datetime import DatetimeWrapper, IsoDateTimeWrapper +from surrealdb.data.types.datetime import IsoDateTimeWrapper class TestAsyncWsSurrealConnectionDatetime(IsolatedAsyncioTestCase): @@ -26,20 +27,20 @@ async def asyncSetUp(self): await self.connection.query("DELETE datetime_tests;") async def test_datetime_wrapper(self): - now = DatetimeWrapper.now() + now = datetime.datetime.now() await self.connection.query( "CREATE datetime_tests:compact_test SET datetime = $compact_datetime;", params={"compact_datetime": now} ) compact_test_outcome = await self.connection.query("SELECT * FROM datetime_tests;") self.assertEqual( - type(compact_test_outcome[0]["datetime"]), - DatetimeWrapper + compact_test_outcome[0]["datetime"], + now ) # assert that the datetime returned from the DB is the same as the one serialized outcome = compact_test_outcome[0]["datetime"] - self.assertEqual(now.dt.isoformat(), outcome.dt.isoformat() + "+00:00") + self.assertEqual(now.isoformat(), outcome.isoformat()) await self.connection.query("DELETE datetime_tests;") await self.connection.close() @@ -47,6 +48,7 @@ async def test_datetime_wrapper(self): async def test_datetime_formats(self): iso_datetime = "2025-02-03T12:30:45.123456Z" # ISO 8601 datetime date = IsoDateTimeWrapper(iso_datetime) + iso_datetime_obj = datetime.datetime.fromisoformat(iso_datetime) # Insert records with different datetime formats await self.connection.query( @@ -55,12 +57,12 @@ async def test_datetime_formats(self): ) compact_test_outcome = await self.connection.query("SELECT * FROM datetime_tests;") self.assertEqual( - type(compact_test_outcome[0]["datetime"]), - DatetimeWrapper + str(compact_test_outcome[0]["datetime"]) + "+00:00", + str(iso_datetime_obj) ) # assert that the datetime returned from the DB is the same as the one serialized - date = compact_test_outcome[0]["datetime"].dt.isoformat() + date = compact_test_outcome[0]["datetime"].isoformat() self.assertEqual(date + "Z", iso_datetime) # Cleanup From 6634c2328b9fa037008978cf0d9cd0437deb82af Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Tue, 4 Feb 2025 16:21:55 +0000 Subject: [PATCH 06/11] updating tests --- .github/workflows/stability.yml | 2 +- .gitignore | 1 + pyproject.toml | 12 + scripts/run_stability_check.sh | 19 ++ src/surrealdb/__init__.py | 66 ++++-- src/surrealdb/connections/async_http.py | 143 ++++------- src/surrealdb/connections/async_template.py | 81 +++---- src/surrealdb/connections/async_ws.py | 193 ++++++--------- src/surrealdb/connections/blocking_http.py | 134 ++++------- src/surrealdb/connections/blocking_ws.py | 140 +++-------- src/surrealdb/connections/sync_template.py | 31 +-- src/surrealdb/connections/utils_mixin.py | 4 +- src/surrealdb/data/cbor.py | 2 +- src/surrealdb/data/types/datetime.py | 13 - src/surrealdb/data/types/duration.py | 3 +- src/surrealdb/data/types/future.py | 1 + src/surrealdb/data/types/geometry.py | 51 +++- src/surrealdb/data/types/range.py | 5 +- src/surrealdb/data/types/record_id.py | 7 +- src/surrealdb/data/types/table.py | 1 + src/surrealdb/data/utils.py | 1 + src/surrealdb/errors.py | 2 - .../request_message/descriptors/cbor_ws.py | 222 ++++++++---------- .../request_message/descriptors/json_http.py | 85 ------- src/surrealdb/request_message/message.py | 2 - src/surrealdb/request_message/sql_adapter.py | 2 + tests/unit_tests/data_types/test_datetimes.py | 9 +- .../descriptors/test_json_http.py | 57 ----- 28 files changed, 487 insertions(+), 802 deletions(-) create mode 100644 scripts/run_stability_check.sh delete mode 100644 src/surrealdb/request_message/descriptors/json_http.py delete mode 100644 tests/unit_tests/request_message/descriptors/test_json_http.py diff --git a/.github/workflows/stability.yml b/.github/workflows/stability.yml index 231a6b92..c4d8d319 100644 --- a/.github/workflows/stability.yml +++ b/.github/workflows/stability.yml @@ -36,4 +36,4 @@ jobs: run: black --check --verbose --diff --color src/ - name: Run mypy checks - run: mypy src/ + run: mypy --explicit-package-bases src/ diff --git a/.gitignore b/.gitignore index e1536619..ff7db1cf 100644 --- a/.gitignore +++ b/.gitignore @@ -88,3 +88,4 @@ manifest/ .mypy_cache/ .ruff_cache/ tests/db_snapshots/ +logs/ diff --git a/pyproject.toml b/pyproject.toml index 21df044b..9382bc7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,5 +59,17 @@ Homepage = "https://github.com/surrealdb/surrealdb.py" [tool.setuptools.packages.find] where = ["src"] +[tool.ruff] +exclude = ["src/surrealdb/__init__.py"] + +[tool.mypy] +mypy_path = "src" +explicit_package_bases = true +disable_error_code = ["return-value", "return-type"] + +[[tool.mypy.overrides]] +module = "cerberus.*" +ignore_missing_imports = true + # [project.scripts] # sdblpy = "sblpy.cli.entrypoint:main" \ No newline at end of file diff --git a/scripts/run_stability_check.sh b/scripts/run_stability_check.sh new file mode 100644 index 00000000..830e7b1c --- /dev/null +++ b/scripts/run_stability_check.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. + +if [ -d ./logs ]; then + echo "log directory exists being removed" + rm -rf ./logs +fi + +mkdir logs + +ruff check src/ > ./logs/ruff_check.log +black src/ +black --check --verbose --diff --color src/ > ./logs/black_check.log +mypy --explicit-package-bases src/ > ./logs/mypy_check.log \ No newline at end of file diff --git a/src/surrealdb/__init__.py b/src/surrealdb/__init__.py index 10f1aa68..47d5fa68 100644 --- a/src/surrealdb/__init__.py +++ b/src/surrealdb/__init__.py @@ -14,6 +14,7 @@ from surrealdb.data.types.record_id import RecordID from surrealdb.data.types.datetime import IsoDateTimeWrapper + class AsyncSurrealDBMeta(type): def __call__(cls, *args, **kwargs): @@ -28,12 +29,20 @@ def __call__(cls, *args, **kwargs): constructed_url = Url(url) - if constructed_url.scheme == UrlScheme.HTTP or constructed_url.scheme == UrlScheme.HTTPS: + if ( + constructed_url.scheme == UrlScheme.HTTP + or constructed_url.scheme == UrlScheme.HTTPS + ): return AsyncHttpSurrealConnection(url=url) - elif constructed_url.scheme == UrlScheme.WS or constructed_url.scheme == UrlScheme.WSS: + elif ( + constructed_url.scheme == UrlScheme.WS + or constructed_url.scheme == UrlScheme.WSS + ): return AsyncWsSurrealConnection(url=url) else: - raise ValueError(f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'.") + raise ValueError( + f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'." + ) class BlockingSurrealDBMeta(type): @@ -50,28 +59,57 @@ def __call__(cls, *args, **kwargs): constructed_url = Url(url) - if constructed_url.scheme == UrlScheme.HTTP or constructed_url.scheme == UrlScheme.HTTPS: + if ( + constructed_url.scheme == UrlScheme.HTTP + or constructed_url.scheme == UrlScheme.HTTPS + ): return BlockingHttpSurrealConnection(url=url) - elif constructed_url.scheme == UrlScheme.WS or constructed_url.scheme == UrlScheme.WSS: + elif ( + constructed_url.scheme == UrlScheme.WS + or constructed_url.scheme == UrlScheme.WSS + ): return BlockingWsSurrealConnection(url=url) else: - raise ValueError(f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'.") + raise ValueError( + f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'." + ) + -def Surreal(url: Optional[str] = None) -> Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]: +def Surreal( + url: Optional[str] = None, +) -> Union[BlockingWsSurrealConnection, BlockingHttpSurrealConnection]: constructed_url = Url(url) - if constructed_url.scheme == UrlScheme.HTTP or constructed_url.scheme == UrlScheme.HTTPS: + if ( + constructed_url.scheme == UrlScheme.HTTP + or constructed_url.scheme == UrlScheme.HTTPS + ): return BlockingHttpSurrealConnection(url=url) - elif constructed_url.scheme == UrlScheme.WS or constructed_url.scheme == UrlScheme.WSS: + elif ( + constructed_url.scheme == UrlScheme.WS + or constructed_url.scheme == UrlScheme.WSS + ): return BlockingWsSurrealConnection(url=url) else: - raise ValueError(f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'.") + raise ValueError( + f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'." + ) -def AsyncSurreal(url: Optional[str] = None) -> Union[AsyncWsSurrealConnection, AsyncHttpSurrealConnection]: +def AsyncSurreal( + url: Optional[str] = None, +) -> Union[AsyncWsSurrealConnection, AsyncHttpSurrealConnection]: constructed_url = Url(url) - if constructed_url.scheme == UrlScheme.HTTP or constructed_url.scheme == UrlScheme.HTTPS: + if ( + constructed_url.scheme == UrlScheme.HTTP + or constructed_url.scheme == UrlScheme.HTTPS + ): return AsyncHttpSurrealConnection(url=url) - elif constructed_url.scheme == UrlScheme.WS or constructed_url.scheme == UrlScheme.WSS: + elif ( + constructed_url.scheme == UrlScheme.WS + or constructed_url.scheme == UrlScheme.WSS + ): return AsyncWsSurrealConnection(url=url) else: - raise ValueError(f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'.") + raise ValueError( + f"Unsupported protocol in URL: {url}. Use 'ws://' or 'http://'." + ) diff --git a/src/surrealdb/connections/async_http.py b/src/surrealdb/connections/async_http.py index efd7a126..8ab9364f 100644 --- a/src/surrealdb/connections/async_http.py +++ b/src/surrealdb/connections/async_http.py @@ -49,7 +49,7 @@ async def _send( message: RequestMessage, operation: str, bypass: bool = False, - ) -> Dict[str, Any]: + ) -> Dict[str, Any]: # type: ignore """ Sends an HTTP request to the SurrealDB server. @@ -74,7 +74,7 @@ async def _send( headers["Surreal-DB"] = self.database async with aiohttp.ClientSession() as session: - async with session.request( + async with session.request( method="POST", url=url, headers=headers, @@ -89,7 +89,7 @@ async def _send( self.check_response_for_error(data, operation) return data - def set_token(self, token: str) -> None: + def set_token(self, token: str) -> None: # type: ignore """ Sets the token for authentication. @@ -97,31 +97,23 @@ def set_token(self, token: str) -> None: """ self.token = token - async def authenticate(self) -> None: - message = RequestMessage( - self.id, - RequestMethod.AUTHENTICATE, - token=self.token - ) + async def authenticate(self) -> None: # type: ignore + message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=self.token) return await self._send(message, "authenticating") - async def invalidate(self) -> None: + async def invalidate(self) -> None: # type: ignore message = RequestMessage(self.id, RequestMethod.INVALIDATE) await self._send(message, "invalidating") self.token = None - async def signup(self, vars: Dict) -> str: - message = RequestMessage( - self.id, - RequestMethod.SIGN_UP, - data=vars - ) + async def signup(self, vars: Dict) -> str: # type: ignore + message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) response = await self._send(message, "signup") self.check_response_for_result(response, "signup") self.token = response["result"] return response["result"] - async def signin(self, vars: dict) -> dict: + async def signin(self, vars: dict) -> dict: # type: ignore message = RequestMessage( self.id, RequestMethod.SIGN_IN, @@ -137,27 +129,24 @@ async def signin(self, vars: dict) -> dict: self.token = response["result"] return response["result"] - async def info(self) -> dict: - message = RequestMessage( - self.id, - RequestMethod.INFO - ) + async def info(self) -> dict: # type: ignore + message = RequestMessage(self.id, RequestMethod.INFO) response = await self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] - async def use(self, namespace: str, database: str) -> None: + async def use(self, namespace: str, database: str) -> None: # type: ignore message = RequestMessage( self.token, RequestMethod.USE, namespace=namespace, database=database, ) - data = await self._send(message, "use") + _ = await self._send(message, "use") self.namespace = namespace self.database = database - async def query(self, query: str, params: Optional[dict] = None) -> dict: + async def query(self, query: str, params: Optional[dict] = None) -> dict: # type: ignore if params is None: params = {} for key, value in self.vars.items(): @@ -172,7 +161,7 @@ async def query(self, query: str, params: Optional[dict] = None) -> dict: self.check_response_for_result(response, "query") return response["result"][0]["result"] - async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: + async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: # type: ignore if params is None: params = {} for key, value in self.vars.items(): @@ -187,142 +176,108 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: return response async def create( - self, - thing: Union[str, RecordID, Table], - data: Optional[Union[Union[List[dict], dict], dict]] = None, - ) -> Union[List[dict], dict]: + self, + thing: Union[str, RecordID, Table], + data: Optional[Union[Union[List[dict], dict], dict]] = None, + ) -> Union[List[dict], dict]: # type: ignore if isinstance(thing, str): if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) message = RequestMessage( - self.id, - RequestMethod.CREATE, - collection=thing, - data=data + self.id, RequestMethod.CREATE, collection=thing, data=data ) response = await self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] async def delete( - self, thing: Union[str, RecordID, Table] - ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, - RequestMethod.DELETE, - record_id=thing - ) + self, thing: Union[str, RecordID, Table] + ) -> Union[List[dict], dict]: # type: ignore + message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) response = await self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] async def insert( - self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: + self, table: Union[str, Table], data: Union[List[dict], dict] + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.INSERT, - collection=table, - params=data + self.id, RequestMethod.INSERT, collection=table, params=data ) response = await self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] async def insert_relation( - self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: + self, table: Union[str, Table], data: Union[List[dict], dict] + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.INSERT_RELATION, - table=table, - params=data + self.id, RequestMethod.INSERT_RELATION, table=table, params=data ) response = await self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] - async def let(self, key: str, value: Any) -> None: + async def let(self, key: str, value: Any) -> None: # type: ignore self.vars[key] = value - async def unset(self, key: str) -> None: + async def unset(self, key: str) -> None: # type: ignore self.vars.pop(key) async def merge( - self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.MERGE, - record_id=thing, - data=data + self.id, RequestMethod.MERGE, record_id=thing, data=data ) response = await self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] async def patch( - self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.PATCH, - collection=thing, - params=data + self.id, RequestMethod.PATCH, collection=thing, params=data ) response = await self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] - async def select(self, thing: str) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, - RequestMethod.SELECT, - params=[thing] - ) + async def select(self, thing: str) -> Union[List[dict], dict]: # type: ignore + message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = await self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] async def update( - self, - thing: Union[str, RecordID, Table], - data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.UPDATE, - record_id=thing, - data=data + self.id, RequestMethod.UPDATE, record_id=thing, data=data ) response = await self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] - async def version(self) -> str: - message = RequestMessage( - self.id, - RequestMethod.VERSION - ) + async def version(self) -> str: # type: ignore + message = RequestMessage(self.id, RequestMethod.VERSION) response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] async def upsert( - self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.UPSERT, - record_id=thing, - data=data + self.id, RequestMethod.UPSERT, record_id=thing, data=data ) response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] - async def __aenter__(self) -> "AsyncHttpSurrealConnection": + async def __aenter__(self) -> "AsyncHttpSurrealConnection": # type: ignore """ Asynchronous context manager entry. Initializes an aiohttp session and returns the connection instance. @@ -330,7 +285,7 @@ async def __aenter__(self) -> "AsyncHttpSurrealConnection": self._session = aiohttp.ClientSession() return self - async def __aexit__(self, exc_type, exc_value, traceback) -> None: + async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore """ Asynchronous context manager exit. Closes the aiohttp session upon exiting the context. diff --git a/src/surrealdb/connections/async_template.py b/src/surrealdb/connections/async_template.py index ec8ae0ee..f35cca7e 100644 --- a/src/surrealdb/connections/async_template.py +++ b/src/surrealdb/connections/async_template.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Any, Union +from typing import Optional, List, Dict, Any, Union, Coroutine from uuid import UUID from asyncio import Queue from surrealdb.data.types.record_id import RecordID @@ -7,7 +7,7 @@ class AsyncTemplate: - async def connect(self, url: str) -> None: + async def connect(self, url: str) -> Coroutine[Any, Any, None]: # type: ignore """Connects to a local or remote database endpoint. Args: @@ -20,7 +20,7 @@ async def connect(self, url: str) -> None: """ raise NotImplementedError(f"query not implemented for: {self}") - async def close(self) -> None: + async def close(self) -> Coroutine[Any, Any, None]: # type: ignore """Closes the persistent connection to the database. Example: @@ -28,7 +28,7 @@ async def close(self) -> None: """ raise NotImplementedError(f"query not implemented for: {self}") - async def use(self, namespace: str, database: str) -> None: + async def use(self, namespace: str, database: str) -> Coroutine[Any, Any, None]: # type: ignore """Switch to a specific namespace and database. Args: @@ -40,7 +40,7 @@ async def use(self, namespace: str, database: str) -> None: """ raise NotImplementedError(f"query not implemented for: {self}") - async def authenticate(self, token: str) -> None: + async def authenticate(self, token: str) -> Coroutine[Any, Any, None]: # type: ignore """Authenticate the current connection with a JWT token. Args: @@ -51,7 +51,7 @@ async def authenticate(self, token: str) -> None: """ raise NotImplementedError(f"authenticate not implemented for: {self}") - async def invalidate(self) -> None: + async def invalidate(self) -> Coroutine[Any, Any, None]: # type: ignore """Invalidate the authentication for the current connection. Example: @@ -59,7 +59,7 @@ async def invalidate(self) -> None: """ raise NotImplementedError(f"invalidate not implemented for: {self}") - async def signup(self, vars: Dict) -> str: + async def signup(self, vars: Dict) -> Coroutine[Any, Any, str]: # type: ignore """Sign this connection up to a specific authentication scope. [See the docs](https://surrealdb.com/docs/sdk/python/methods/signup) @@ -81,7 +81,7 @@ async def signup(self, vars: Dict) -> str: """ raise NotImplementedError(f"signup not implemented for: {self}") - async def signin(self, vars: Dict) -> str: + async def signin(self, vars: Dict) -> Coroutine[Any, Any, str]: # type: ignore """Sign this connection in to a specific authentication scope. [See the docs](https://surrealdb.com/docs/sdk/python/methods/signin) @@ -96,7 +96,7 @@ async def signin(self, vars: Dict) -> str: """ raise NotImplementedError(f"query not implemented for: {self}") - async def let(self, key: str, value: Any) -> None: + async def let(self, key: str, value: Any) -> Coroutine[Any, Any, None]: # type: ignore """Assign a value as a variable for this connection. Args: @@ -115,7 +115,7 @@ async def let(self, key: str, value: Any) -> None: """ raise NotImplementedError(f"let not implemented for: {self}") - async def unset(self, key: str) -> None: + async def unset(self, key: str) -> Coroutine[Any, Any, None]: # type: ignore """Removes a variable for this connection. Args: @@ -128,7 +128,7 @@ async def unset(self, key: str) -> None: async def query( self, query: str, vars: Optional[Dict] = None - ) -> Union[List[dict], dict]: + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """Run a unset of SurrealQL statements against the database. Args: @@ -143,7 +143,9 @@ async def query( """ raise NotImplementedError(f"query not implemented for: {self}") - async def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: + async def select( + self, thing: Union[str, RecordID, Table] + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """Select all records in a table (or other entity), or a specific record, in the database. @@ -161,8 +163,8 @@ async def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], async def create( self, thing: Union[str, RecordID, Table], - data: Optional[Union[Union[List[dict], dict], dict]] = None, - ) -> Union[List[dict], dict]: + data: Optional[Union[List[dict], dict]] = None, + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """Create a record in the database. This function will run the following query in the database: @@ -178,10 +180,8 @@ async def create( raise NotImplementedError(f"create not implemented for: {self}") async def update( - self, - thing: Union[str, RecordID, Table], - data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """Update all records in a table, or a specific record, in the database. This function replaces the current document / record data with the @@ -208,10 +208,10 @@ async def update( }) """ raise NotImplementedError(f"query not implemented for: {self}") - + async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """Insert records into the database, or to update them if they exist. @@ -239,7 +239,7 @@ async def upsert( async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """Modify by deep merging all records in a table, or a specific record, in the database. This function merges the current document / record data with the @@ -271,7 +271,7 @@ async def merge( async def patch( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """Apply JSON Patch changes to all records, or a specific record, in the database. This function patches the current document / record data with @@ -300,7 +300,7 @@ async def patch( async def delete( self, thing: Union[str, RecordID, Table] - ) -> Union[List[dict], dict]: + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """Delete all records in a table, or a specific record, from the database. This function will run the following query in the database: @@ -318,7 +318,7 @@ async def delete( """ raise NotImplementedError(f"delete not implemented for: {self}") - def info(self) -> dict: + async def info(self) -> Coroutine[Any, Any, dict]: # type: ignore """This returns the record of an authenticated record user. Example: @@ -326,9 +326,9 @@ def info(self) -> dict: """ raise NotImplementedError(f"query not implemented for: {self}") - def insert( + async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """ Inserts one or multiple records in the database. @@ -345,9 +345,9 @@ def insert( """ raise NotImplementedError(f"query not implemented for: {self}") - def insert_relation( + async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore """ Inserts one or multiple relations in the database. @@ -364,7 +364,9 @@ def insert_relation( """ raise NotImplementedError(f"query not implemented for: {self}") - async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: + async def live( + self, table: Union[str, Table], diff: bool = False + ) -> Coroutine[Any, Any, UUID]: # type: ignore """Initiates a live query for a specified table name. Args: @@ -381,7 +383,9 @@ async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: """ raise NotImplementedError(f"query not implemented for: {self}") - async def subscribe_live(self, query_uuid: Union[str, UUID]) -> Queue: + async def subscribe_live( + self, query_uuid: Union[str, UUID] + ) -> Coroutine[Any, Any, Queue]: # type: ignore """Returns a queue that receives notification messages from a running live query. Args: @@ -395,7 +399,7 @@ async def subscribe_live(self, query_uuid: Union[str, UUID]) -> Queue: """ raise NotImplementedError(f"query not implemented for: {self}") - async def kill(self, query_uuid: Union[str, UUID]) -> None: + async def kill(self, query_uuid: Union[str, UUID]) -> Coroutine[Any, Any, None]: # type: ignore """Kills a running live query by it's UUID. Args: @@ -406,18 +410,3 @@ async def kill(self, query_uuid: Union[str, UUID]) -> None: """ raise NotImplementedError(f"query not implemented for: {self}") - - - async def signin(self, vars: Dict) -> str: - """Sign this connection in to a specific authentication scope. - [See the docs](https://surrealdb.com/docs/sdk/python/methods/signin) - - Args: - vars: Variables used in a signin query. - - Example: - await db.signin({ - username: 'root', - password: 'surrealdb', - }) - """ \ No newline at end of file diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index 2cef12b1..b260611b 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -1,6 +1,7 @@ """ A basic async connection to a SurrealDB instance. """ + import asyncio import uuid from asyncio import Queue @@ -31,9 +32,10 @@ class AsyncWsSurrealConnection(AsyncTemplate, UtilsMixin): database: The database that the connection will stick to. id: The ID of the connection. """ + def __init__( - self, - url: str, + self, + url: str, ) -> None: """ The constructor for the AsyncSurrealConnection class. @@ -42,58 +44,55 @@ def __init__( """ self.url: Url = Url(url) self.raw_url: str = f"{self.url.raw_url}/rpc" - self.host: str = self.url.hostname - self.port: int = self.url.port + self.host: str | None = self.url.hostname + self.port: int | None = self.url.port self.id: str = str(uuid.uuid4()) self.token: Optional[str] = None self.socket = None - async def _send(self, message: RequestMessage, process: str, bypass: bool = False) -> dict: + async def _send( + self, message: RequestMessage, process: str, bypass: bool = False + ) -> dict: # type: ignore await self.connect() + assert ( + self.socket is not None + ) # will always not be None as the self.connect ensures there's a connection await self.socket.send(message.WS_CBOR_DESCRIPTOR) response = decode(await self.socket.recv()) if bypass is False: self.check_response_for_error(response, process) return response - async def connect(self, url: Optional[str] = None) -> None: + async def connect(self, url: Optional[str] = None) -> None: # type: ignore # overwrite params if passed in if url is not None: self.url = Url(url) - self.raw_url: str = f"{self.url.raw_url}/rpc" - self.host: str = self.url.hostname - self.port: int = self.url.port + self.raw_url = f"{self.url.raw_url}/rpc" + self.host = self.url.hostname + self.port = self.url.port if self.socket is None: self.socket = await websockets.connect( self.raw_url, max_size=None, - subprotocols=[websockets.Subprotocol("cbor")] + subprotocols=[websockets.Subprotocol("cbor")], ) - async def authenticate(self, token: str) -> dict: - message = RequestMessage( - self.id, - RequestMethod.AUTHENTICATE, - token=token - ) + async def authenticate(self, token: str) -> dict: # type: ignore + message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) return await self._send(message, "authenticating") - async def invalidate(self) -> None: + async def invalidate(self) -> None: # type: ignore message = RequestMessage(self.id, RequestMethod.INVALIDATE) await self._send(message, "invalidating") self.token = None - async def signup(self, vars: Dict) -> str: - message = RequestMessage( - self.id, - RequestMethod.SIGN_UP, - data=vars - ) + async def signup(self, vars: Dict) -> str: # type: ignore + message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) response = await self._send(message, "signup") self.check_response_for_result(response, "signup") return response["result"] - async def signin(self, vars: Dict[str, Any]) -> str: + async def signin(self, vars: Dict[str, Any]) -> str: # type: ignore message = RequestMessage( self.id, RequestMethod.SIGN_IN, @@ -109,16 +108,13 @@ async def signin(self, vars: Dict[str, Any]) -> str: self.token = response["result"] return response["result"] - async def info(self) -> Optional[dict]: - message = RequestMessage( - self.id, - RequestMethod.INFO - ) + async def info(self) -> Optional[dict]: # type: ignore + message = RequestMessage(self.id, RequestMethod.INFO) outcome = await self._send(message, "getting database information") self.check_response_for_result(outcome, "getting database information") return outcome["result"] - async def use(self, namespace: str, database: str) -> None: + async def use(self, namespace: str, database: str) -> None: # type: ignore message = RequestMessage( self.id, RequestMethod.USE, @@ -127,7 +123,7 @@ async def use(self, namespace: str, database: str) -> None: ) await self._send(message, "use") - async def query(self, query: str, params: Optional[dict] = None) -> dict: + async def query(self, query: str, params: Optional[dict] = None) -> dict: # type: ignore if params is None: params = {} message = RequestMessage( @@ -140,7 +136,7 @@ async def query(self, query: str, params: Optional[dict] = None) -> dict: self.check_response_for_result(response, "query") return response["result"][0]["result"] - async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: + async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: # type: ignore if params is None: params = {} message = RequestMessage( @@ -152,141 +148,101 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: response = await self._send(message, "query", bypass=True) return response - async def version(self) -> str: - message = RequestMessage( - self.id, - RequestMethod.VERSION - ) + async def version(self) -> str: # type: ignore + message = RequestMessage(self.id, RequestMethod.VERSION) response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] - async def let(self, key: str, value: Any) -> None: - message = RequestMessage( - self.id, - RequestMethod.LET, - key=key, - value=value - ) + async def let(self, key: str, value: Any) -> None: # type: ignore + message = RequestMessage(self.id, RequestMethod.LET, key=key, value=value) await self._send(message, "letting") - async def unset(self, key: str) -> None: - message = RequestMessage( - self.id, - RequestMethod.UNSET, - params=[key] - ) + async def unset(self, key: str) -> None: # type: ignore + message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) await self._send(message, "unsetting") - async def select(self, thing: str) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, - RequestMethod.SELECT, - params=[thing] - ) + async def select(self, thing: str | RecordID | Table) -> Union[List[dict], dict]: # type: ignore + message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = await self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] async def create( - self, - thing: Union[str, RecordID, Table], - data: Optional[Union[Union[List[dict], dict], dict]] = None, - ) -> Union[List[dict], dict]: + self, + thing: Union[str, RecordID, Table], + data: Optional[Union[Union[List[dict], dict], dict]] = None, + ) -> Union[List[dict], dict]: # type: ignore if isinstance(thing, str): if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) message = RequestMessage( - self.id, - RequestMethod.CREATE, - collection=thing, - data=data + self.id, RequestMethod.CREATE, collection=thing, data=data ) response = await self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] async def update( - self, - thing: Union[str, RecordID, Table], - data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.UPDATE, - record_id=thing, - data=data + self.id, RequestMethod.UPDATE, record_id=thing, data=data ) response = await self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] async def merge( - self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.MERGE, - record_id=thing, - data=data + self.id, RequestMethod.MERGE, record_id=thing, data=data ) response = await self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] async def patch( - self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.PATCH, - collection=thing, - params=data + self.id, RequestMethod.PATCH, collection=thing, params=data ) response = await self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] async def delete( - self, thing: Union[str, RecordID, Table] - ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, - RequestMethod.DELETE, - record_id=thing - ) + self, thing: Union[str, RecordID, Table] + ) -> Union[List[dict], dict]: # type: ignore + message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) response = await self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] async def insert( - self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: + self, table: Union[str, Table], data: Union[List[dict], dict] + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.INSERT, - collection=table, - params=data + self.id, RequestMethod.INSERT, collection=table, params=data ) response = await self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] async def insert_relation( - self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: + self, table: Union[str, Table], data: Union[List[dict], dict] + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.INSERT_RELATION, - table=table, - params=data + self.id, RequestMethod.INSERT_RELATION, table=table, params=data ) response = await self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] - async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: + async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: # type: ignore message = RequestMessage( self.id, RequestMethod.LIVE, @@ -296,7 +252,9 @@ async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: self.check_response_for_result(response, "live") return response["result"] - async def subscribe_live(self, query_uuid: Union[str, UUID]) -> AsyncGenerator[dict, None]: + async def subscribe_live( + self, query_uuid: Union[str, UUID] + ) -> AsyncGenerator[dict, None]: # type: ignore result_queue = Queue() async def listen_live(): @@ -320,22 +278,15 @@ async def listen_live(): raise Exception(f"Error in live subscription: {result['error']}") yield result - async def kill(self, query_uuid: Union[str, UUID]) -> None: - message = RequestMessage( - self.id, - RequestMethod.KILL, - uuid=query_uuid - ) + async def kill(self, query_uuid: Union[str, UUID]) -> None: # type: ignore + message = RequestMessage(self.id, RequestMethod.KILL, uuid=query_uuid) await self._send(message, "kill") async def upsert( - self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + ) -> Union[List[dict], dict]: # type: ignore message = RequestMessage( - self.id, - RequestMethod.UPSERT, - record_id=thing, - data=data + self.id, RequestMethod.UPSERT, record_id=thing, data=data ) response = await self._send(message, "upsert") self.check_response_for_result(response, "upsert") @@ -344,19 +295,17 @@ async def upsert( async def close(self): await self.socket.close() - async def __aenter__(self) -> "AsyncWsSurrealConnection": + async def __aenter__(self) -> "AsyncWsSurrealConnection": # type: ignore """ Asynchronous context manager entry. Initializes a websocket connection and returns the connection instance. """ self.socket = await websockets.connect( - self.raw_url, - max_size=None, - subprotocols=[websockets.Subprotocol("cbor")] + self.raw_url, max_size=None, subprotocols=[websockets.Subprotocol("cbor")] ) return self - async def __aexit__(self, exc_type, exc_value, traceback) -> None: + async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore """ Asynchronous context manager exit. Closes the websocket connection upon exiting the context. diff --git a/src/surrealdb/connections/blocking_http.py b/src/surrealdb/connections/blocking_http.py index e0872246..a02139a8 100644 --- a/src/surrealdb/connections/blocking_http.py +++ b/src/surrealdb/connections/blocking_http.py @@ -1,5 +1,5 @@ import uuid -from typing import Optional, Any, Dict, Union, List +from typing import Optional, Any, Dict, Union, List, cast import requests @@ -18,15 +18,17 @@ class BlockingHttpSurrealConnection(SyncTemplate, UtilsMixin): def __init__(self, url: str) -> None: self.url: Url = Url(url) self.raw_url: str = url.rstrip("/") - self.host: str = self.url.hostname + self.host: Optional[str] = self.url.hostname self.port: Optional[int] = self.url.port self.token: Optional[str] = None self.id: str = str(uuid.uuid4()) self.namespace: Optional[str] = None self.database: Optional[str] = None - self.vars = dict() + self.vars: Dict[str, Any] = dict() - def _send(self, message: RequestMessage, operation: str, bypass: bool = False) -> Dict[str, Any]: + def _send( + self, message: RequestMessage, operation: str, bypass: bool = False + ) -> Dict[str, Any]: data = message.WS_CBOR_DESCRIPTOR url = f"{self.url.raw_url}/rpc" headers = { @@ -42,21 +44,20 @@ def _send(self, message: RequestMessage, operation: str, bypass: bool = False) - response = requests.post(url, headers=headers, data=data, timeout=30) response.raise_for_status() + raw_cbor = response.content - data = decode(raw_cbor) - if bypass is False: - self.check_response_for_error(data, operation) - return data + data_dict = cast(Dict[str, Any], decode(raw_cbor)) + + if not bypass: + self.check_response_for_error(data_dict, operation) + + return data_dict def set_token(self, token: str) -> None: self.token = token def authenticate(self, token: str) -> dict: - message = RequestMessage( - self.id, - RequestMethod.AUTHENTICATE, - token=token - ) + message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) return self._send(message, "authenticating") def invalidate(self) -> None: @@ -65,16 +66,13 @@ def invalidate(self) -> None: self.token = None def signup(self, vars: Dict) -> str: - message = RequestMessage( - self.id, - RequestMethod.SIGN_UP, - data=vars - ) + message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) response = self._send(message, "signup") self.check_response_for_result(response, "signup") + self.token = response["result"] return response["result"] - def signin(self, vars: dict) -> dict: + def signin(self, vars: dict) -> str: message = RequestMessage( self.id, RequestMethod.SIGN_IN, @@ -88,13 +86,10 @@ def signin(self, vars: dict) -> dict: response = self._send(message, "signing in") self.check_response_for_result(response, "signing in") self.token = response["result"] - return response["result"] + return str(response["result"]) def info(self): - message = RequestMessage( - self.id, - RequestMethod.INFO - ) + message = RequestMessage(self.id, RequestMethod.INFO) response = self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] @@ -106,7 +101,7 @@ def use(self, namespace: str, database: str) -> None: namespace=namespace, database=database, ) - data = self._send(message, "use") + _ = self._send(message, "use") self.namespace = namespace self.database = database @@ -140,57 +135,42 @@ def query_raw(self, query: str, params: Optional[dict] = None) -> dict: return response def create( - self, - thing: Union[str, RecordID, Table], - data: Optional[Union[Union[List[dict], dict], dict]] = None, + self, + thing: Union[str, RecordID, Table], + data: Optional[Union[Union[List[dict], dict], dict]] = None, ) -> Union[List[dict], dict]: if isinstance(thing, str): if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) message = RequestMessage( - self.id, - RequestMethod.CREATE, - collection=thing, - data=data + self.id, RequestMethod.CREATE, collection=thing, data=data ) response = self._send(message, "create") self.check_response_for_result(response, "create") return response["result"] - def delete( - self, thing: Union[str, RecordID, Table] - ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, - RequestMethod.DELETE, - record_id=thing - ) + def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: + message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) response = self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] def insert( - self, table: Union[str, Table], data: Union[List[dict], dict] + self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.INSERT, - collection=table, - params=data + self.id, RequestMethod.INSERT, collection=table, params=data ) response = self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] def insert_relation( - self, table: Union[str, Table], data: Union[List[dict], dict] + self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.INSERT_RELATION, - table=table, - params=data + self.id, RequestMethod.INSERT_RELATION, table=table, params=data ) response = self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") @@ -203,89 +183,57 @@ def unset(self, key: str) -> None: self.vars.pop(key) def merge( - self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.MERGE, - record_id=thing, - data=data + self.id, RequestMethod.MERGE, record_id=thing, data=data ) response = self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] def patch( - self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None + self, thing: Union[str, RecordID, Table], data: Optional[Dict[Any, Any]] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.PATCH, - collection=thing, - params=data + self.id, RequestMethod.PATCH, collection=thing, params=data ) response = self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] - def select(self, thing: str) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, - RequestMethod.SELECT, - params=[thing] - ) + def select(self, thing: str | RecordID | Table) -> Union[List[dict], dict]: + message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] def update( - self, - thing: Union[str, RecordID, Table], - data: Optional[Dict] = None + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.UPDATE, - record_id=thing, - data=data + self.id, RequestMethod.UPDATE, record_id=thing, data=data ) response = self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] def version(self) -> str: - message = RequestMessage( - self.id, - RequestMethod.VERSION - ) + message = RequestMessage(self.id, RequestMethod.VERSION) response = self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] def upsert( - self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.UPSERT, - record_id=thing, - data=data + self.id, RequestMethod.UPSERT, record_id=thing, data=data ) response = self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] - def signup(self, vars: Dict) -> str: - message = RequestMessage( - self.id, - RequestMethod.SIGN_UP, - data=vars - ) - response = self._send(message, "signup") - self.check_response_for_result(response, "signup") - self.token = response["result"] - return response["result"] - def __enter__(self) -> "BlockingHttpSurrealConnection": """ Synchronous context manager entry. diff --git a/src/surrealdb/connections/blocking_ws.py b/src/surrealdb/connections/blocking_ws.py index a9d0d9ef..52b1bdb9 100644 --- a/src/surrealdb/connections/blocking_ws.py +++ b/src/surrealdb/connections/blocking_ws.py @@ -1,6 +1,7 @@ """ A basic blocking connection to a SurrealDB instance. """ + import uuid from typing import Optional, Any, Dict, Union, List, Generator from uuid import UUID @@ -45,7 +46,9 @@ def __init__(self, url: str) -> None: self.token: Optional[str] = None self.socket = None - def _send(self, message: RequestMessage, process: str, bypass: bool = False) -> dict: + def _send( + self, message: RequestMessage, process: str, bypass: bool = False + ) -> dict: if self.socket is None: self.socket = ws_sync.connect( self.raw_url, @@ -59,11 +62,7 @@ def _send(self, message: RequestMessage, process: str, bypass: bool = False) -> return response def authenticate(self, token: str) -> dict: - message = RequestMessage( - self.id, - RequestMethod.AUTHENTICATE, - token=token - ) + message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) return self._send(message, "authenticating") def invalidate(self) -> None: @@ -72,11 +71,7 @@ def invalidate(self) -> None: self.token = None def signup(self, vars: Dict) -> str: - message = RequestMessage( - self.id, - RequestMethod.SIGN_UP, - data=vars - ) + message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) response = self._send(message, "signup") self.check_response_for_result(response, "signup") return response["result"] @@ -98,10 +93,7 @@ def signin(self, vars: Dict[str, Any]) -> str: return response["result"] def info(self) -> dict: - message = RequestMessage( - self.id, - RequestMethod.INFO - ) + message = RequestMessage(self.id, RequestMethod.INFO) response = self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] @@ -141,55 +133,36 @@ def query_raw(self, query: str, params: Optional[dict] = None) -> dict: return response def version(self) -> str: - message = RequestMessage( - self.id, - RequestMethod.VERSION - ) + message = RequestMessage(self.id, RequestMethod.VERSION) response = self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] def let(self, key: str, value: Any) -> None: - message = RequestMessage( - self.id, - RequestMethod.LET, - key=key, - value=value - ) + message = RequestMessage(self.id, RequestMethod.LET, key=key, value=value) self._send(message, "letting") def unset(self, key: str) -> None: - message = RequestMessage( - self.id, - RequestMethod.UNSET, - params=[key] - ) + message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) self._send(message, "unsetting") - def select(self, thing: str) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, - RequestMethod.SELECT, - params=[thing] - ) + def select(self, thing: str | RecordID | Table) -> Union[List[dict], dict]: + message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = self._send(message, "select") self.check_response_for_result(response, "select") return response["result"] def create( - self, - thing: Union[str, RecordID, Table], - data: Optional[Union[Union[List[dict], dict], dict]] = None, + self, + thing: Union[str, RecordID, Table], + data: Optional[Union[Union[List[dict], dict], dict]] = None, ) -> Union[List[dict], dict]: if isinstance(thing, str): if ":" in thing: buffer = thing.split(":") thing = RecordID(table_name=buffer[0], identifier=buffer[1]) message = RequestMessage( - self.id, - RequestMethod.CREATE, - collection=thing, - data=data + self.id, RequestMethod.CREATE, collection=thing, data=data ) response = self._send(message, "create") self.check_response_for_result(response, "create") @@ -206,78 +179,58 @@ def live(self, table: Union[str, Table], diff: bool = False) -> UUID: return response["result"] def kill(self, query_uuid: Union[str, UUID]) -> None: - message = RequestMessage( - self.id, - RequestMethod.KILL, - uuid=query_uuid - ) + message = RequestMessage(self.id, RequestMethod.KILL, uuid=query_uuid) self._send(message, "kill") - def delete( - self, thing: Union[str, RecordID, Table] - ) -> Union[List[dict], dict]: - message = RequestMessage( - self.id, - RequestMethod.DELETE, - record_id=thing - ) + def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: + message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) response = self._send(message, "delete") self.check_response_for_result(response, "delete") return response["result"] def insert( - self, table: Union[str, Table], data: Union[List[dict], dict] + self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.INSERT, - collection=table, - params=data + self.id, RequestMethod.INSERT, collection=table, params=data ) response = self._send(message, "insert") self.check_response_for_result(response, "insert") return response["result"] def insert_relation( - self, table: Union[str, Table], data: Union[List[dict], dict] + self, table: Union[str, Table], data: Union[List[dict], dict] ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.INSERT_RELATION, - table=table, - params=data + self.id, RequestMethod.INSERT_RELATION, table=table, params=data ) response = self._send(message, "insert_relation") self.check_response_for_result(response, "insert_relation") return response["result"] def merge( - self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.MERGE, - record_id=thing, - data=data + self.id, RequestMethod.MERGE, record_id=thing, data=data ) response = self._send(message, "merge") self.check_response_for_result(response, "merge") return response["result"] def patch( - self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None + self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.PATCH, - collection=thing, - params=data + self.id, RequestMethod.PATCH, collection=thing, params=data ) response = self._send(message, "patch") self.check_response_for_result(response, "patch") return response["result"] - def subscribe_live(self, query_uuid: Union[str, UUID]) -> Generator[dict, None, None]: + def subscribe_live( + self, query_uuid: Union[str, UUID] + ) -> Generator[dict, None, None]: """ Subscribe to live updates for a given query UUID. @@ -305,45 +258,25 @@ def subscribe_live(self, query_uuid: Union[str, UUID]) -> Generator[dict, None, pass def update( - self, - thing: Union[str, RecordID, Table], - data: Optional[Dict] = None + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.UPDATE, - record_id=thing, - data=data + self.id, RequestMethod.UPDATE, record_id=thing, data=data ) response = self._send(message, "update") self.check_response_for_result(response, "update") return response["result"] def upsert( - self, - thing: Union[str, RecordID, Table], - data: Optional[Dict] = None + self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: message = RequestMessage( - self.id, - RequestMethod.UPSERT, - record_id=thing, - data=data + self.id, RequestMethod.UPSERT, record_id=thing, data=data ) response = self._send(message, "upsert") self.check_response_for_result(response, "upsert") return response["result"] - def signup(self, vars: Dict) -> str: - message = RequestMessage( - self.id, - RequestMethod.SIGN_UP, - data=vars - ) - response = self._send(message, "signup") - self.check_response_for_result(response, "signup") - return response["result"] - def close(self): self.socket.close() @@ -353,9 +286,7 @@ def __enter__(self) -> "BlockingWsSurrealConnection": Initializes a websocket connection and returns the connection instance. """ self.socket = ws_sync.connect( - self.raw_url, - max_size=None, - subprotocols=[websockets.Subprotocol("cbor")] + self.raw_url, max_size=None, subprotocols=[websockets.Subprotocol("cbor")] ) return self @@ -366,4 +297,3 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: """ if self.socket is not None: self.socket.close() - diff --git a/src/surrealdb/connections/sync_template.py b/src/surrealdb/connections/sync_template.py index 10db2b96..726d3ccd 100644 --- a/src/surrealdb/connections/sync_template.py +++ b/src/surrealdb/connections/sync_template.py @@ -1,6 +1,7 @@ +from asyncio import Queue from typing import Optional, List, Dict, Any, Union from uuid import UUID -from asyncio import Queue + from surrealdb.data.types.record_id import RecordID from surrealdb.data.types.table import Table @@ -77,7 +78,7 @@ def signup(self, vars: Dict) -> str: namespace: 'surrealdb', database: 'docs', access: 'user', - + # Also pass any properties required by the scope definition variables: { email: 'info@surrealdb.com', @@ -132,9 +133,7 @@ def unset(self, key: str) -> None: """ raise NotImplementedError(f"let not implemented for: {self}") - def query( - self, query: str, vars: Optional[Dict] = None - ) -> Union[List[dict], dict]: + def query(self, query: str, vars: Optional[Dict] = None) -> Union[List[dict], dict]: """Run a set of SurrealQL statements against the database. Args: @@ -212,7 +211,7 @@ def update( }) """ raise NotImplementedError(f"update not implemented for: {self}") - + def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None ) -> Union[List[dict], dict]: @@ -302,9 +301,7 @@ def patch( """ raise NotImplementedError(f"patch not implemented for: {self}") - def delete( - self, thing: Union[str, RecordID, Table] - ) -> Union[List[dict], dict]: + def delete(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: """Delete all records in a table, or a specific record, from the database. This function will run the following query in the database: @@ -410,19 +407,3 @@ def kill(self, query_uuid: Union[str, UUID]) -> None: """ raise NotImplementedError(f"kill not implemented for: {self}") - - - def signin(self, vars: Dict) -> str: - """Sign this connection in to a specific authentication scope. - [See the docs](https://surrealdb.com/docs/sdk/python/methods/signin) - - Args: - vars: Variables used in a signin query. - - Example: - db.signin({ - username: 'root', - password: 'surrealdb', - }) - """ - raise NotImplementedError(f"signin not implemented for: {self}") \ No newline at end of file diff --git a/src/surrealdb/connections/utils_mixin.py b/src/surrealdb/connections/utils_mixin.py index 2473ee71..1dbcc04e 100644 --- a/src/surrealdb/connections/utils_mixin.py +++ b/src/surrealdb/connections/utils_mixin.py @@ -1,5 +1,3 @@ - - class UtilsMixin: @staticmethod @@ -9,5 +7,5 @@ def check_response_for_error(response: dict, process: str) -> None: @staticmethod def check_response_for_result(response: dict, process: str) -> None: - if "result" not in response.keys(): + if "result" not in response.keys(): raise Exception(f"no result {process}: {response}") diff --git a/src/surrealdb/data/cbor.py b/src/surrealdb/data/cbor.py index c1125df6..2f2a63de 100644 --- a/src/surrealdb/data/cbor.py +++ b/src/surrealdb/data/cbor.py @@ -75,6 +75,7 @@ def default_encoder(encoder, obj): encoder.encode(tagged) + def tag_decoder(decoder, tag, shareable_index=None): if tag.tag == constants.TAG_GEOMETRY_POINT: return GeometryPoint.parse_coordinates(tag.value) @@ -132,7 +133,6 @@ def tag_decoder(decoder, tag, shareable_index=None): raise BufferError("no decoder for tag", tag.tag) - def encode(obj): return cbor2.dumps(obj, default=default_encoder, timezone=timezone.utc) diff --git a/src/surrealdb/data/types/datetime.py b/src/surrealdb/data/types/datetime.py index cff8d2b1..93f6f771 100644 --- a/src/surrealdb/data/types/datetime.py +++ b/src/surrealdb/data/types/datetime.py @@ -1,16 +1,3 @@ -# from datetime import datetime - - -# class DatetimeWrapper: -# -# def __init__(self, dt: datetime): -# self.dt = dt -# -# @staticmethod -# def now() -> "DatetimeWrapper": -# return DatetimeWrapper(datetime.now()) - - class IsoDateTimeWrapper: def __init__(self, dt: str): diff --git a/src/surrealdb/data/types/duration.py b/src/surrealdb/data/types/duration.py index 0ec87001..11fff8c5 100644 --- a/src/surrealdb/data/types/duration.py +++ b/src/surrealdb/data/types/duration.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from typing import Tuple, Union -from math import floor, pow +from math import floor UNITS = { "ns": 1, @@ -13,6 +13,7 @@ "w": int(604800 * 1e9), } + @dataclass class Duration: elapsed: int = 0 # nanoseconds diff --git a/src/surrealdb/data/types/future.py b/src/surrealdb/data/types/future.py index eca6e4cf..5b1361c3 100644 --- a/src/surrealdb/data/types/future.py +++ b/src/surrealdb/data/types/future.py @@ -14,6 +14,7 @@ class Future: Attributes: value: The value held by the Future object. This can be of any type. """ + value: Any def __eq__(self, other: object) -> bool: diff --git a/src/surrealdb/data/types/geometry.py b/src/surrealdb/data/types/geometry.py index dc449cda..51139ecf 100644 --- a/src/surrealdb/data/types/geometry.py +++ b/src/surrealdb/data/types/geometry.py @@ -35,6 +35,7 @@ class GeometryPoint(Geometry): longitude: The longitude of the point. latitude: The latitude of the point. """ + longitude: float latitude: float @@ -77,9 +78,12 @@ class GeometryLine(Geometry): Attributes: geometry_points: A list of GeometryPoint objects defining the line. """ + geometry_points: List[GeometryPoint] - def __init__(self, point1: GeometryPoint, point2: GeometryPoint, *other_points: GeometryPoint) -> None: + def __init__( + self, point1: GeometryPoint, point2: GeometryPoint, *other_points: GeometryPoint + ) -> None: """ The constructor for the GeometryLine class. """ @@ -108,7 +112,9 @@ def parse_coordinates(coordinates: List[Tuple[float, float]]) -> "GeometryLine": Returns: A GeometryLine object. """ - return GeometryLine(*[GeometryPoint.parse_coordinates(point) for point in coordinates]) + return GeometryLine( + *[GeometryPoint.parse_coordinates(point) for point in coordinates] + ) def __eq__(self, other: object) -> bool: if isinstance(other, GeometryLine): @@ -124,9 +130,12 @@ class GeometryPolygon(Geometry): Attributes: geometry_lines: A list of GeometryLine objects defining the polygon. """ + geometry_lines: List[GeometryLine] - def __init__(self, line1: GeometryLine, line2: GeometryLine, *other_lines: GeometryLine): + def __init__( + self, line1: GeometryLine, line2: GeometryLine, *other_lines: GeometryLine + ): self.geometry_lines = [line1, line2] + list(other_lines) def get_coordinates(self) -> List[List[Tuple[float, float]]]: @@ -142,7 +151,9 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_lines)})' @staticmethod - def parse_coordinates(coordinates: List[List[Tuple[float, float]]]) -> "GeometryPolygon": + def parse_coordinates( + coordinates: List[List[Tuple[float, float]]], + ) -> "GeometryPolygon": """ Parses a list of lines, each defined by a list of coordinate tuples, into a GeometryPolygon. @@ -152,7 +163,9 @@ def parse_coordinates(coordinates: List[List[Tuple[float, float]]]) -> "Geometry Returns: A GeometryPolygon object. """ - return GeometryPolygon(*[GeometryLine.parse_coordinates(line) for line in coordinates]) + return GeometryPolygon( + *[GeometryLine.parse_coordinates(line) for line in coordinates] + ) def __eq__(self, other: object) -> bool: if isinstance(other, GeometryPolygon): @@ -168,6 +181,7 @@ class GeometryMultiPoint(Geometry): Attributes: geometry_points: A list of GeometryPoint objects. """ + geometry_points: List[GeometryPoint] def __init__(self, *geometry_points: GeometryPoint): @@ -186,7 +200,9 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_points)})' @staticmethod - def parse_coordinates(coordinates: List[Tuple[float, float]]) -> "GeometryMultiPoint": + def parse_coordinates( + coordinates: List[Tuple[float, float]], + ) -> "GeometryMultiPoint": """ Parses a list of coordinate tuples into a GeometryMultiPoint. @@ -196,7 +212,9 @@ def parse_coordinates(coordinates: List[Tuple[float, float]]) -> "GeometryMultiP Returns: A GeometryMultiPoint object. """ - return GeometryMultiPoint(*[GeometryPoint.parse_coordinates(point) for point in coordinates]) + return GeometryMultiPoint( + *[GeometryPoint.parse_coordinates(point) for point in coordinates] + ) def __eq__(self, other: object) -> bool: if isinstance(other, GeometryMultiPoint): @@ -212,6 +230,7 @@ class GeometryMultiLine(Geometry): Attributes: geometry_lines: A list of GeometryLine objects. """ + geometry_lines: List[GeometryLine] def __init__(self, *geometry_lines: GeometryLine): @@ -230,7 +249,9 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_lines)})' @staticmethod - def parse_coordinates(coordinates: List[List[Tuple[float, float]]]) -> "GeometryMultiLine": + def parse_coordinates( + coordinates: List[List[Tuple[float, float]]], + ) -> "GeometryMultiLine": """ Parses a list of lines, each defined by a list of coordinate tuples, into a GeometryMultiLine. @@ -240,7 +261,9 @@ def parse_coordinates(coordinates: List[List[Tuple[float, float]]]) -> "Geometry Returns: A GeometryMultiLine object. """ - return GeometryMultiLine(*[GeometryLine.parse_coordinates(line) for line in coordinates]) + return GeometryMultiLine( + *[GeometryLine.parse_coordinates(line) for line in coordinates] + ) def __eq__(self, other: object) -> bool: if isinstance(other, GeometryMultiLine): @@ -256,6 +279,7 @@ class GeometryMultiPolygon(Geometry): Attributes: geometry_polygons: A list of GeometryPolygon objects. """ + geometry_polygons: List[GeometryPolygon] def __init__(self, *geometry_polygons: GeometryPolygon): @@ -274,7 +298,9 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}({", ".join(repr(geo) for geo in self.geometry_polygons)})' @staticmethod - def parse_coordinates(coordinates: List[List[List[Tuple[float, float]]]]) -> "GeometryMultiPolygon": + def parse_coordinates( + coordinates: List[List[List[Tuple[float, float]]]], + ) -> "GeometryMultiPolygon": """ Parses a list of polygons, each defined by a list of lines, into a GeometryMultiPolygon. @@ -284,7 +310,9 @@ def parse_coordinates(coordinates: List[List[List[Tuple[float, float]]]]) -> "Ge Returns: A GeometryMultiPolygon object. """ - return GeometryMultiPolygon(*[GeometryPolygon.parse_coordinates(polygon) for polygon in coordinates]) + return GeometryMultiPolygon( + *[GeometryPolygon.parse_coordinates(polygon) for polygon in coordinates] + ) def __eq__(self, other: object) -> bool: if isinstance(other, GeometryMultiPolygon): @@ -300,6 +328,7 @@ class GeometryCollection: Attributes: geometries: A list of Geometry objects. """ + geometries: List[Geometry] def __init__(self, *geometries: Geometry): diff --git a/src/surrealdb/data/types/range.py b/src/surrealdb/data/types/range.py index 20fdd7b8..e9031b07 100644 --- a/src/surrealdb/data/types/range.py +++ b/src/surrealdb/data/types/range.py @@ -3,6 +3,7 @@ """ from dataclasses import dataclass +from typing import Any class Bound: @@ -39,7 +40,7 @@ class BoundIncluded(Bound): value: The value of the inclusive bound. """ - value: any + value: Any def __init__(self, value): """ @@ -75,7 +76,7 @@ class BoundExcluded(Bound): value: The value of the exclusive bound. """ - value: any + value: Any def __init__(self, value): """ diff --git a/src/surrealdb/data/types/record_id.py b/src/surrealdb/data/types/record_id.py index 152800e1..ca933a67 100644 --- a/src/surrealdb/data/types/record_id.py +++ b/src/surrealdb/data/types/record_id.py @@ -1,7 +1,6 @@ """ Defines the data type for the record ID. """ -from dataclasses import dataclass class RecordID: @@ -12,6 +11,7 @@ class RecordID: table_name: The table name associated with the record ID identifier: The ID of the row """ + def __init__(self, table_name: str, identifier) -> None: """ The constructor for the RecordID class. @@ -33,10 +33,7 @@ def __repr__(self) -> str: def __eq__(self, other): if isinstance(other, RecordID): - return ( - self.table_name == other.table_name and - self.id == other.id - ) + return self.table_name == other.table_name and self.id == other.id @staticmethod def parse(record_str: str) -> "RecordID": diff --git a/src/surrealdb/data/types/table.py b/src/surrealdb/data/types/table.py index 2436ddce..1211318e 100644 --- a/src/surrealdb/data/types/table.py +++ b/src/surrealdb/data/types/table.py @@ -2,6 +2,7 @@ Defines a Table class to represent a database table by its name. """ + class Table: """ Represents a database table by its name. diff --git a/src/surrealdb/data/utils.py b/src/surrealdb/data/utils.py index cdfe2222..18f9ef3e 100644 --- a/src/surrealdb/data/utils.py +++ b/src/surrealdb/data/utils.py @@ -1,6 +1,7 @@ """ Utils for handling processes around data """ + from typing import Union from surrealdb.data.types.record_id import RecordID diff --git a/src/surrealdb/errors.py b/src/surrealdb/errors.py index fccecc56..005f7652 100644 --- a/src/surrealdb/errors.py +++ b/src/surrealdb/errors.py @@ -1,5 +1,3 @@ - - class SurrealDBMethodError(Exception): def __init__(self, message): self.message = message diff --git a/src/surrealdb/request_message/descriptors/cbor_ws.py b/src/surrealdb/request_message/descriptors/cbor_ws.py index f443484c..6459ac57 100644 --- a/src/surrealdb/request_message/descriptors/cbor_ws.py +++ b/src/surrealdb/request_message/descriptors/cbor_ws.py @@ -54,10 +54,12 @@ def __get__(self, obj, type=None) -> bytes: raise ValueError(f"Invalid method for Cbor WS encoding: {obj.method}") - def _raise_invalid_schema(self, data:dict, schema: dict, method: str) -> None: + def _raise_invalid_schema(self, data: dict, schema: dict, method: str) -> None: v = Validator(schema) if not v.validate(data): - raise ValueError(f"Invalid schema for Cbor WS encoding for {method}: {v.errors}") + raise ValueError( + f"Invalid schema for Cbor WS encoding for {method}: {v.errors}" + ) def prep_use(self, obj) -> bytes: data = { @@ -71,17 +73,14 @@ def prep_use(self, obj) -> bytes: "params": { "type": "list", # "params" must be a list "schema": {"type": "string"}, # Elements of "params" must be strings - "required": True + "required": True, }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) def prep_info(self, obj) -> bytes: - data = { - "id": obj.id, - "method": obj.method.value - } + data = {"id": obj.id, "method": obj.method.value} schema = { "id": {"required": True}, "method": {"type": "string", "required": True}, # "method" must be a string @@ -90,10 +89,7 @@ def prep_info(self, obj) -> bytes: return encode(data) def prep_version(self, obj) -> bytes: - data = { - "id": obj.id, - "method": obj.method.value - } + data = {"id": obj.id, "method": obj.method.value} schema = { "id": {"required": True}, "method": {"type": "string", "required": True}, # "method" must be a string @@ -116,24 +112,25 @@ def prep_signup(self, obj) -> bytes: } for key, value in passed_params["variables"].items(): data["params"][0][key] = value - schema = { - "id": {"required": True}, - "method": {"type": "string", "required": True}, # "method" must be a string - "params": { - "type": "list", # "params" must be a list - "schema": { - "type": "dict", # Each element of the "params" list must be a dictionary - "schema": { - "NS": {"type": "string", "required": True}, # "NS" must be a string - "DB": {"type": "string", "required": True}, # "DB" must be a string - "AC": {"type": "string", "required": True}, # "AC" must be a string - "username": {"type": "string", "required": True}, # "username" must be a string - "password": {"type": "string", "required": True}, # "password" must be a string - }, - }, - "required": True, - }, - } + # Sign-up schema is currently deactivated due to the different types of params passed in + # schema = { + # "id": {"required": True}, + # "method": {"type": "string", "required": True}, # "method" must be a string + # "params": { + # "type": "list", # "params" must be a list + # "schema": { + # "type": "dict", # Each element of the "params" list must be a dictionary + # "schema": { + # "NS": {"type": "string", "required": True}, # "NS" must be a string + # "DB": {"type": "string", "required": True}, # "DB" must be a string + # "AC": {"type": "string", "required": True}, # "AC" must be a string + # "username": {"type": "string", "required": True}, # "username" must be a string + # "password": {"type": "string", "required": True}, # "password" must be a string + # }, + # }, + # "required": True, + # }, + # } # self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -155,11 +152,13 @@ def prep_signin(self, obj) -> bytes: "params": [ { "user": obj.kwargs.get("username"), - "pass": obj.kwargs.get("password") + "pass": obj.kwargs.get("password"), } - ] + ], } - elif obj.kwargs.get("namespace") is None and obj.kwargs.get("access") is not None: + elif ( + obj.kwargs.get("namespace") is None and obj.kwargs.get("access") is not None + ): data = { "id": obj.id, "method": obj.method.value, @@ -167,9 +166,9 @@ def prep_signin(self, obj) -> bytes: { "ac": obj.kwargs.get("access"), "user": obj.kwargs.get("username"), - "pass": obj.kwargs.get("password") + "pass": obj.kwargs.get("password"), } - ] + ], } elif obj.kwargs.get("database") is None and obj.kwargs.get("access") is None: # namespace signin @@ -180,11 +179,13 @@ def prep_signin(self, obj) -> bytes: { "ns": obj.kwargs.get("namespace"), "user": obj.kwargs.get("username"), - "pass": obj.kwargs.get("password") + "pass": obj.kwargs.get("password"), } - ] + ], } - elif obj.kwargs.get("database") is None and obj.kwargs.get("access") is not None: + elif ( + obj.kwargs.get("database") is None and obj.kwargs.get("access") is not None + ): # access signin data = { "id": obj.id, @@ -194,11 +195,16 @@ def prep_signin(self, obj) -> bytes: "ns": obj.kwargs.get("namespace"), "ac": obj.kwargs.get("access"), "user": obj.kwargs.get("username"), - "pass": obj.kwargs.get("password") + "pass": obj.kwargs.get("password"), } - ] + ], } - elif obj.kwargs.get("database") is not None and obj.kwargs.get("namespace") is not None and obj.kwargs.get("access") is not None and obj.kwargs.get("variables") is None: + elif ( + obj.kwargs.get("database") is not None + and obj.kwargs.get("namespace") is not None + and obj.kwargs.get("access") is not None + and obj.kwargs.get("variables") is None + ): data = { "id": obj.id, "method": obj.method.value, @@ -208,9 +214,9 @@ def prep_signin(self, obj) -> bytes: "db": obj.kwargs.get("database"), "ac": obj.kwargs.get("access"), "user": obj.kwargs.get("username"), - "pass": obj.kwargs.get("password") + "pass": obj.kwargs.get("password"), } - ] + ], } elif obj.kwargs.get("username") is None and obj.kwargs.get("password") is None: @@ -224,12 +230,16 @@ def prep_signin(self, obj) -> bytes: "ac": obj.kwargs.get("access"), # "variables": obj.kwargs.get("variables") } - ] + ], } for key, value in obj.kwargs.get("variables", {}).items(): data["params"][0][key] = value - elif obj.kwargs.get("database") is not None and obj.kwargs.get("namespace") is not None and obj.kwargs.get("access") is None: + elif ( + obj.kwargs.get("database") is not None + and obj.kwargs.get("namespace") is not None + and obj.kwargs.get("access") is None + ): data = { "id": obj.id, "method": obj.method.value, @@ -238,9 +248,9 @@ def prep_signin(self, obj) -> bytes: "ns": obj.kwargs.get("namespace"), "db": obj.kwargs.get("database"), "user": obj.kwargs.get("username"), - "pass": obj.kwargs.get("password") + "pass": obj.kwargs.get("password"), } - ] + ], } else: @@ -251,9 +261,7 @@ def prep_authenticate(self, obj) -> bytes: data = { "id": obj.id, "method": obj.method.value, - "params": [ - obj.kwargs.get("token") - ] + "params": [obj.kwargs.get("token")], } schema = { "id": {"required": True}, @@ -273,13 +281,10 @@ def prep_authenticate(self, obj) -> bytes: return encode(data) def prep_invalidate(self, obj) -> bytes: - data = { - "id": obj.id, - "method": obj.method.value - } + data = {"id": obj.id, "method": obj.method.value} schema = { "id": {"required": True}, - "method": {"type": "string", "required": True} + "method": {"type": "string", "required": True}, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -288,16 +293,12 @@ def prep_let(self, obj) -> bytes: data = { "id": obj.id, "method": obj.method.value, - "params": [obj.kwargs.get("key"), obj.kwargs.get("value")] + "params": [obj.kwargs.get("key"), obj.kwargs.get("value")], } schema = { "id": {"required": True}, "method": {"type": "string", "required": True, "allowed": ["let"]}, - "params": { - "type": "list", - "minlength": 2, - "required": True - }, + "params": {"type": "list", "minlength": 2, "required": True}, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -306,15 +307,12 @@ def prep_unset(self, obj) -> bytes: data = { "id": obj.id, "method": obj.method.value, - "params": obj.kwargs.get("params") + "params": obj.kwargs.get("params"), } schema = { "id": {"required": True}, "method": {"type": "string", "required": True, "allowed": ["unset"]}, - "params": { - "type": "list", - "required": True - }, + "params": {"type": "list", "required": True}, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -323,18 +321,11 @@ def prep_live(self, obj) -> bytes: table = obj.kwargs.get("table") if isinstance(table, str): table = Table(table) - data = { - "id": obj.id, - "method": obj.method.value, - "params": [table] - } + data = {"id": obj.id, "method": obj.method.value, "params": [table]} schema = { "id": {"required": True}, "method": {"type": "string", "required": True, "allowed": ["live"]}, - "params": { - "type": "list", - "required": True - }, + "params": {"type": "list", "required": True}, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -343,15 +334,12 @@ def prep_kill(self, obj) -> bytes: data = { "id": obj.id, "method": obj.method.value, - "params": [obj.kwargs.get("uuid")] + "params": [obj.kwargs.get("uuid")], } schema = { "id": {"required": True}, "method": {"type": "string", "required": True, "allowed": ["kill"]}, - "params": { - "type": "list", - "required": True - }, + "params": {"type": "list", "required": True}, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -360,10 +348,7 @@ def prep_query(self, obj) -> bytes: data = { "id": obj.id, "method": obj.method.value, - "params": [ - obj.kwargs.get("query"), - obj.kwargs.get("params", dict()) - ] + "params": [obj.kwargs.get("query"), obj.kwargs.get("params", dict())], } schema = { "id": {"required": True}, @@ -372,7 +357,7 @@ def prep_query(self, obj) -> bytes: "type": "list", "minlength": 2, # Ensures there are at least two elements "maxlength": 2, # Ensures exactly two elements - "required": True + "required": True, }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) @@ -384,8 +369,8 @@ def prep_insert(self, obj) -> bytes: "method": obj.method.value, "params": [ process_thing(obj.kwargs.get("collection")), - obj.kwargs.get("params") - ] + obj.kwargs.get("params"), + ], } schema = { "id": {"required": True}, @@ -394,8 +379,8 @@ def prep_insert(self, obj) -> bytes: "type": "list", "minlength": 2, # Ensure there are at least two elements "maxlength": 2, # Ensure exactly two elements - "required": True - } + "required": True, + }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -406,8 +391,8 @@ def prep_patch(self, obj) -> bytes: "method": obj.method.value, "params": [ process_thing(obj.kwargs.get("collection")), - obj.kwargs.get("params") - ] + obj.kwargs.get("params"), + ], } if obj.kwargs.get("params") is None: raise ValidationError("parameters cannot be None for a patch method") @@ -419,7 +404,7 @@ def prep_patch(self, obj) -> bytes: "minlength": 2, # Ensure there are at least two elements "maxlength": 2, # Ensure exactly two elements "required": True, - } + }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -428,15 +413,12 @@ def prep_select(self, obj) -> bytes: data = { "id": obj.id, "method": obj.method.value, - "params": obj.kwargs.get("params") + "params": obj.kwargs.get("params"), } schema = { "id": {"required": True}, "method": {"type": "string", "required": True, "allowed": ["select"]}, - "params": { - "type": "list", - "required": True - } + "params": {"type": "list", "required": True}, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -445,7 +427,7 @@ def prep_create(self, obj) -> bytes: data = { "id": obj.id, "method": obj.method.value, - "params": [process_thing(obj.kwargs.get("collection"))] + "params": [process_thing(obj.kwargs.get("collection"))], } if obj.kwargs.get("data"): data["params"].append(obj.kwargs.get("data")) @@ -457,8 +439,8 @@ def prep_create(self, obj) -> bytes: "type": "list", "minlength": 1, "maxlength": 2, - "required": True - } + "required": True, + }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -469,8 +451,8 @@ def prep_update(self, obj) -> bytes: "method": obj.method.value, "params": [ process_thing(obj.kwargs.get("record_id")), - obj.kwargs.get("data", dict()) - ] + obj.kwargs.get("data", dict()), + ], } schema = { "id": {"required": True}, @@ -479,8 +461,8 @@ def prep_update(self, obj) -> bytes: "type": "list", "minlength": 1, "maxlength": 2, - "required": True - } + "required": True, + }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -491,8 +473,8 @@ def prep_merge(self, obj) -> bytes: "method": obj.method.value, "params": [ process_thing(obj.kwargs.get("record_id")), - obj.kwargs.get("data", dict()) - ] + obj.kwargs.get("data", dict()), + ], } schema = { "id": {"required": True}, @@ -501,8 +483,8 @@ def prep_merge(self, obj) -> bytes: "type": "list", "minlength": 1, "maxlength": 2, - "required": True - } + "required": True, + }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -511,7 +493,7 @@ def prep_delete(self, obj) -> bytes: data = { "id": obj.id, "method": obj.method.value, - "params": [process_thing(obj.kwargs.get("record_id"))] + "params": [process_thing(obj.kwargs.get("record_id"))], } schema = { "id": {"required": True}, @@ -520,8 +502,8 @@ def prep_delete(self, obj) -> bytes: "type": "list", "minlength": 1, "maxlength": 1, - "required": True - } + "required": True, + }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -532,7 +514,7 @@ def prep_insert_relation(self, obj) -> bytes: "method": obj.method.value, "params": [ Table(obj.kwargs.get("table")), - ] + ], } params = obj.kwargs.get("params", []) # for i in params: @@ -540,13 +522,17 @@ def prep_insert_relation(self, obj) -> bytes: schema = { "id": {"required": True}, - "method": {"type": "string", "required": True, "allowed": ["insert_relation"]}, + "method": { + "type": "string", + "required": True, + "allowed": ["insert_relation"], + }, "params": { "type": "list", "minlength": 2, "maxlength": 2, - "required": True - } + "required": True, + }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) @@ -557,8 +543,8 @@ def prep_upsert(self, obj) -> bytes: "method": obj.method.value, "params": [ process_thing(obj.kwargs.get("record_id")), - obj.kwargs.get("data", dict()) - ] + obj.kwargs.get("data", dict()), + ], } schema = { "id": {"required": True}, @@ -567,8 +553,8 @@ def prep_upsert(self, obj) -> bytes: "type": "list", "minlength": 1, "maxlength": 2, - "required": True - } + "required": True, + }, } self._raise_invalid_schema(data=data, schema=schema, method=obj.method.value) return encode(data) diff --git a/src/surrealdb/request_message/descriptors/json_http.py b/src/surrealdb/request_message/descriptors/json_http.py deleted file mode 100644 index 1d781900..00000000 --- a/src/surrealdb/request_message/descriptors/json_http.py +++ /dev/null @@ -1,85 +0,0 @@ -from marshmallow import ValidationError -from surrealdb.request_message.methods import RequestMethod -from marshmallow import Schema, fields -from typing import Tuple -from enum import Enum -import json - - -class HttpMethod(Enum): - GET = "GET" - POST = "POST" - PUT = "PUT" - PATCH = "PATCH" - DELETE = "DELETE" - - -class JsonHttpDescriptor: - def __get__(self, obj, type=None) -> Tuple[str, HttpMethod, str]: - if obj.method == RequestMethod.SIGN_IN: - return self.prep_signin(obj) - # if obj.method == RequestMethod.USE: - # return self.prep_use(obj) - # elif obj.method == RequestMethod.INFO: - # return self.prep_info(obj) - # elif obj.method == RequestMethod.VERSION: - # return self.prep_version(obj) - # elif obj.method == RequestMethod.SIGN_UP: - # return self.prep_signup(obj) - # elif obj.method == RequestMethod.SIGN_IN: - # return self.prep_signin(obj) - # elif obj.method == RequestMethod.AUTHENTICATE: - # return self.prep_authenticate(obj) - # elif obj.method == RequestMethod.INVALIDATE: - # return self.prep_invalidate(obj) - # elif obj.method == RequestMethod.LET: - # return self.prep_let(obj) - # elif obj.method == RequestMethod.UNSET: - # return self.prep_unset(obj) - # elif obj.method == RequestMethod.LIVE: - # return self.prep_live(obj) - # elif obj.method == RequestMethod.KILL: - # return self.prep_kill(obj) - # elif obj.method == RequestMethod.QUERY: - # return self.prep_query(obj) - # elif obj.method == RequestMethod.INSERT: - # return self.prep_insert(obj) - # elif obj.method == RequestMethod.PATCH: - # return self.prep_patch(obj) - - @staticmethod - def serialize(data: dict, schema: Schema, context: str) -> str: - try: - result = schema.load(data) - except ValidationError as err: - raise ValidationError(f"Validation error for {context}:", err.messages) - return json.dumps(schema.dump(result)) - - - def prep_signin(self, obj) -> Tuple[str, HttpMethod, str]: - class SignInSchema(Schema): - ns = fields.Str(required=False) # Optional Namespace - db = fields.Str(required=False) # Optional Database - ac = fields.Str(required=False) # Optional Account category - user = fields.Str(required=True) # Required Username - pass_ = fields.Str(required=True, data_key="pass") # Required Password - - schema = SignInSchema() - - if obj.kwargs.get("namespace") is None: - # root user signing in - data = { - "user": obj.kwargs.get("username"), - "pass": obj.kwargs.get("password") - } - else: - data = { - "ns": obj.kwargs.get("namespace"), - "db": obj.kwargs.get("database"), - "ac": obj.kwargs.get("account"), - "user": obj.kwargs.get("username"), - "pass": obj.kwargs.get("password") - } - - result = self.serialize(data, schema, "HTTP signin") - return result, HttpMethod.POST, "signin" diff --git a/src/surrealdb/request_message/message.py b/src/surrealdb/request_message/message.py index dfc8774b..509dbb93 100644 --- a/src/surrealdb/request_message/message.py +++ b/src/surrealdb/request_message/message.py @@ -1,12 +1,10 @@ from surrealdb.request_message.descriptors.cbor_ws import WsCborDescriptor -from surrealdb.request_message.descriptors.json_http import JsonHttpDescriptor from surrealdb.request_message.methods import RequestMethod class RequestMessage: WS_CBOR_DESCRIPTOR = WsCborDescriptor() - JSON_HTTP_DESCRIPTOR = JsonHttpDescriptor() def __init__(self, id_for_request, method: RequestMethod, **kwargs) -> None: self.id = id_for_request diff --git a/src/surrealdb/request_message/sql_adapter.py b/src/surrealdb/request_message/sql_adapter.py index d09ba86a..37870e73 100644 --- a/src/surrealdb/request_message/sql_adapter.py +++ b/src/surrealdb/request_message/sql_adapter.py @@ -1,6 +1,7 @@ """ Defines a class that adapts SQL commands from various sources into a single string. """ + from typing import List @@ -8,6 +9,7 @@ class SqlAdapter: """ Adapts SQL commands from various sources into a single string. """ + @staticmethod def from_list(commands: List[str]) -> str: """ diff --git a/tests/unit_tests/data_types/test_datetimes.py b/tests/unit_tests/data_types/test_datetimes.py index 6416eae8..f12d56aa 100644 --- a/tests/unit_tests/data_types/test_datetimes.py +++ b/tests/unit_tests/data_types/test_datetimes.py @@ -3,6 +3,7 @@ from surrealdb.connections.async_ws import AsyncWsSurrealConnection from surrealdb.data.types.datetime import IsoDateTimeWrapper +import sys class TestAsyncWsSurrealConnectionDatetime(IsolatedAsyncioTestCase): @@ -26,7 +27,7 @@ async def asyncSetUp(self): # Cleanup await self.connection.query("DELETE datetime_tests;") - async def test_datetime_wrapper(self): + async def test_native_datetime(self): now = datetime.datetime.now() await self.connection.query( "CREATE datetime_tests:compact_test SET datetime = $compact_datetime;", @@ -45,9 +46,13 @@ async def test_datetime_wrapper(self): await self.connection.query("DELETE datetime_tests;") await self.connection.close() - async def test_datetime_formats(self): + async def test_datetime_iso_format(self): iso_datetime = "2025-02-03T12:30:45.123456Z" # ISO 8601 datetime + if sys.version_info < (3, 11): + iso_datetime = iso_datetime.replace("Z", "+00:00") + date = IsoDateTimeWrapper(iso_datetime) + iso_datetime_obj = datetime.datetime.fromisoformat(iso_datetime) # Insert records with different datetime formats diff --git a/tests/unit_tests/request_message/descriptors/test_json_http.py b/tests/unit_tests/request_message/descriptors/test_json_http.py deleted file mode 100644 index 5de56152..00000000 --- a/tests/unit_tests/request_message/descriptors/test_json_http.py +++ /dev/null @@ -1,57 +0,0 @@ -from unittest import TestCase, main - -from surrealdb.request_message.descriptors.json_http import HttpMethod -from surrealdb.request_message.message import RequestMessage -from surrealdb.request_message.methods import RequestMethod - - -class TestJsonHttpDescriptor(TestCase): - - def test_signin_pass_root(self): - message = RequestMessage( - 1, - RequestMethod.SIGN_IN, - username="user", - password="pass" - ) - json_body, method, endpoint = message.JSON_HTTP_DESCRIPTOR - self.assertEqual('{"user": "user", "pass": "pass"}', json_body) - self.assertEqual(HttpMethod.POST, method) - self.assertEqual("signin", endpoint) - - def test_signin_pass_root_with_none(self): - message = RequestMessage( - 1, - RequestMethod.SIGN_IN, - username="username", - password="pass", - account=None, - database=None, - namespace=None - ) - json_body, method, endpoint = message.JSON_HTTP_DESCRIPTOR - self.assertEqual('{"user": "username", "pass": "pass"}', json_body) - self.assertEqual(HttpMethod.POST, method) - self.assertEqual("signin", endpoint) - - def test_signin_pass_account(self): - message = RequestMessage( - 1, - RequestMethod.SIGN_IN, - username="username", - password="pass", - account="account", - database="database", - namespace="namespace" - ) - json_body, method, endpoint = message.JSON_HTTP_DESCRIPTOR - self.assertEqual( - '{"ns": "namespace", "db": "database", "ac": "account", "user": "username", "pass": "pass"}', - json_body - ) - self.assertEqual(HttpMethod.POST, method) - self.assertEqual("signin", endpoint) - - -if __name__ == '__main__': - main() From 0b5b626a43436c8f068af066e36bc466138b42f1 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Tue, 4 Feb 2025 16:24:01 +0000 Subject: [PATCH 07/11] updating tests --- src/surrealdb/connections/async_http.py | 48 +++++++++--------- src/surrealdb/connections/async_template.py | 46 +++++++++--------- src/surrealdb/connections/async_ws.py | 54 ++++++++++----------- 3 files changed, 74 insertions(+), 74 deletions(-) diff --git a/src/surrealdb/connections/async_http.py b/src/surrealdb/connections/async_http.py index 8ab9364f..6e5fa89c 100644 --- a/src/surrealdb/connections/async_http.py +++ b/src/surrealdb/connections/async_http.py @@ -49,7 +49,7 @@ async def _send( message: RequestMessage, operation: str, bypass: bool = False, - ) -> Dict[str, Any]: # type: ignore + ) -> Dict[str, Any]: """ Sends an HTTP request to the SurrealDB server. @@ -89,7 +89,7 @@ async def _send( self.check_response_for_error(data, operation) return data - def set_token(self, token: str) -> None: # type: ignore + def set_token(self, token: str) -> None: """ Sets the token for authentication. @@ -97,23 +97,23 @@ def set_token(self, token: str) -> None: # type: ignore """ self.token = token - async def authenticate(self) -> None: # type: ignore + async def authenticate(self) -> None: message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=self.token) return await self._send(message, "authenticating") - async def invalidate(self) -> None: # type: ignore + async def invalidate(self) -> None: message = RequestMessage(self.id, RequestMethod.INVALIDATE) await self._send(message, "invalidating") self.token = None - async def signup(self, vars: Dict) -> str: # type: ignore + async def signup(self, vars: Dict) -> str: message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) response = await self._send(message, "signup") self.check_response_for_result(response, "signup") self.token = response["result"] return response["result"] - async def signin(self, vars: dict) -> dict: # type: ignore + async def signin(self, vars: dict) -> dict: message = RequestMessage( self.id, RequestMethod.SIGN_IN, @@ -129,13 +129,13 @@ async def signin(self, vars: dict) -> dict: # type: ignore self.token = response["result"] return response["result"] - async def info(self) -> dict: # type: ignore + async def info(self) -> dict: message = RequestMessage(self.id, RequestMethod.INFO) response = await self._send(message, "getting database information") self.check_response_for_result(response, "getting database information") return response["result"] - async def use(self, namespace: str, database: str) -> None: # type: ignore + async def use(self, namespace: str, database: str) -> None: message = RequestMessage( self.token, RequestMethod.USE, @@ -146,7 +146,7 @@ async def use(self, namespace: str, database: str) -> None: # type: ignore self.namespace = namespace self.database = database - async def query(self, query: str, params: Optional[dict] = None) -> dict: # type: ignore + async def query(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} for key, value in self.vars.items(): @@ -161,7 +161,7 @@ async def query(self, query: str, params: Optional[dict] = None) -> dict: # typ self.check_response_for_result(response, "query") return response["result"][0]["result"] - async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: # type: ignore + async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} for key, value in self.vars.items(): @@ -179,7 +179,7 @@ async def create( self, thing: Union[str, RecordID, Table], data: Optional[Union[Union[List[dict], dict], dict]] = None, - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: if isinstance(thing, str): if ":" in thing: buffer = thing.split(":") @@ -193,7 +193,7 @@ async def create( async def delete( self, thing: Union[str, RecordID, Table] - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) response = await self._send(message, "delete") self.check_response_for_result(response, "delete") @@ -201,7 +201,7 @@ async def delete( async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.INSERT, collection=table, params=data ) @@ -211,7 +211,7 @@ async def insert( async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.INSERT_RELATION, table=table, params=data ) @@ -219,15 +219,15 @@ async def insert_relation( self.check_response_for_result(response, "insert_relation") return response["result"] - async def let(self, key: str, value: Any) -> None: # type: ignore + async def let(self, key: str, value: Any) -> None: self.vars[key] = value - async def unset(self, key: str) -> None: # type: ignore + async def unset(self, key: str) -> None: self.vars.pop(key) async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.MERGE, record_id=thing, data=data ) @@ -237,7 +237,7 @@ async def merge( async def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.PATCH, collection=thing, params=data ) @@ -245,7 +245,7 @@ async def patch( self.check_response_for_result(response, "patch") return response["result"] - async def select(self, thing: str) -> Union[List[dict], dict]: # type: ignore + async def select(self, thing: str) -> Union[List[dict], dict]: message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = await self._send(message, "select") self.check_response_for_result(response, "select") @@ -253,7 +253,7 @@ async def select(self, thing: str) -> Union[List[dict], dict]: # type: ignore async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.UPDATE, record_id=thing, data=data ) @@ -261,7 +261,7 @@ async def update( self.check_response_for_result(response, "update") return response["result"] - async def version(self) -> str: # type: ignore + async def version(self) -> str: message = RequestMessage(self.id, RequestMethod.VERSION) response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") @@ -269,7 +269,7 @@ async def version(self) -> str: # type: ignore async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.UPSERT, record_id=thing, data=data ) @@ -277,7 +277,7 @@ async def upsert( self.check_response_for_result(response, "upsert") return response["result"] - async def __aenter__(self) -> "AsyncHttpSurrealConnection": # type: ignore + async def __aenter__(self) -> "AsyncHttpSurrealConnection": """ Asynchronous context manager entry. Initializes an aiohttp session and returns the connection instance. @@ -285,7 +285,7 @@ async def __aenter__(self) -> "AsyncHttpSurrealConnection": # type: ignore self._session = aiohttp.ClientSession() return self - async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore + async def __aexit__(self, exc_type, exc_value, traceback) -> None: """ Asynchronous context manager exit. Closes the aiohttp session upon exiting the context. diff --git a/src/surrealdb/connections/async_template.py b/src/surrealdb/connections/async_template.py index f35cca7e..70e6053c 100644 --- a/src/surrealdb/connections/async_template.py +++ b/src/surrealdb/connections/async_template.py @@ -7,7 +7,7 @@ class AsyncTemplate: - async def connect(self, url: str) -> Coroutine[Any, Any, None]: # type: ignore + async def connect(self, url: str) -> Coroutine[Any, Any, None]: """Connects to a local or remote database endpoint. Args: @@ -20,7 +20,7 @@ async def connect(self, url: str) -> Coroutine[Any, Any, None]: # type: ignore """ raise NotImplementedError(f"query not implemented for: {self}") - async def close(self) -> Coroutine[Any, Any, None]: # type: ignore + async def close(self) -> Coroutine[Any, Any, None]: """Closes the persistent connection to the database. Example: @@ -28,7 +28,7 @@ async def close(self) -> Coroutine[Any, Any, None]: # type: ignore """ raise NotImplementedError(f"query not implemented for: {self}") - async def use(self, namespace: str, database: str) -> Coroutine[Any, Any, None]: # type: ignore + async def use(self, namespace: str, database: str) -> Coroutine[Any, Any, None]: """Switch to a specific namespace and database. Args: @@ -40,7 +40,7 @@ async def use(self, namespace: str, database: str) -> Coroutine[Any, Any, None]: """ raise NotImplementedError(f"query not implemented for: {self}") - async def authenticate(self, token: str) -> Coroutine[Any, Any, None]: # type: ignore + async def authenticate(self, token: str) -> Coroutine[Any, Any, None]: """Authenticate the current connection with a JWT token. Args: @@ -51,7 +51,7 @@ async def authenticate(self, token: str) -> Coroutine[Any, Any, None]: # type: """ raise NotImplementedError(f"authenticate not implemented for: {self}") - async def invalidate(self) -> Coroutine[Any, Any, None]: # type: ignore + async def invalidate(self) -> Coroutine[Any, Any, None]: """Invalidate the authentication for the current connection. Example: @@ -59,7 +59,7 @@ async def invalidate(self) -> Coroutine[Any, Any, None]: # type: ignore """ raise NotImplementedError(f"invalidate not implemented for: {self}") - async def signup(self, vars: Dict) -> Coroutine[Any, Any, str]: # type: ignore + async def signup(self, vars: Dict) -> Coroutine[Any, Any, str]: """Sign this connection up to a specific authentication scope. [See the docs](https://surrealdb.com/docs/sdk/python/methods/signup) @@ -81,7 +81,7 @@ async def signup(self, vars: Dict) -> Coroutine[Any, Any, str]: # type: ignore """ raise NotImplementedError(f"signup not implemented for: {self}") - async def signin(self, vars: Dict) -> Coroutine[Any, Any, str]: # type: ignore + async def signin(self, vars: Dict) -> Coroutine[Any, Any, str]: """Sign this connection in to a specific authentication scope. [See the docs](https://surrealdb.com/docs/sdk/python/methods/signin) @@ -96,7 +96,7 @@ async def signin(self, vars: Dict) -> Coroutine[Any, Any, str]: # type: ignore """ raise NotImplementedError(f"query not implemented for: {self}") - async def let(self, key: str, value: Any) -> Coroutine[Any, Any, None]: # type: ignore + async def let(self, key: str, value: Any) -> Coroutine[Any, Any, None]: """Assign a value as a variable for this connection. Args: @@ -115,7 +115,7 @@ async def let(self, key: str, value: Any) -> Coroutine[Any, Any, None]: # type: """ raise NotImplementedError(f"let not implemented for: {self}") - async def unset(self, key: str) -> Coroutine[Any, Any, None]: # type: ignore + async def unset(self, key: str) -> Coroutine[Any, Any, None]: """Removes a variable for this connection. Args: @@ -128,7 +128,7 @@ async def unset(self, key: str) -> Coroutine[Any, Any, None]: # type: ignore async def query( self, query: str, vars: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """Run a unset of SurrealQL statements against the database. Args: @@ -145,7 +145,7 @@ async def query( async def select( self, thing: Union[str, RecordID, Table] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """Select all records in a table (or other entity), or a specific record, in the database. @@ -164,7 +164,7 @@ async def create( self, thing: Union[str, RecordID, Table], data: Optional[Union[List[dict], dict]] = None, - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """Create a record in the database. This function will run the following query in the database: @@ -181,7 +181,7 @@ async def create( async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """Update all records in a table, or a specific record, in the database. This function replaces the current document / record data with the @@ -211,7 +211,7 @@ async def update( async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """Insert records into the database, or to update them if they exist. @@ -239,7 +239,7 @@ async def upsert( async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """Modify by deep merging all records in a table, or a specific record, in the database. This function merges the current document / record data with the @@ -271,7 +271,7 @@ async def merge( async def patch( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """Apply JSON Patch changes to all records, or a specific record, in the database. This function patches the current document / record data with @@ -300,7 +300,7 @@ async def patch( async def delete( self, thing: Union[str, RecordID, Table] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """Delete all records in a table, or a specific record, from the database. This function will run the following query in the database: @@ -318,7 +318,7 @@ async def delete( """ raise NotImplementedError(f"delete not implemented for: {self}") - async def info(self) -> Coroutine[Any, Any, dict]: # type: ignore + async def info(self) -> Coroutine[Any, Any, dict]: """This returns the record of an authenticated record user. Example: @@ -328,7 +328,7 @@ async def info(self) -> Coroutine[Any, Any, dict]: # type: ignore async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """ Inserts one or multiple records in the database. @@ -347,7 +347,7 @@ async def insert( async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Coroutine[Any, Any, Union[List[dict], dict]]: # type: ignore + ) -> Coroutine[Any, Any, Union[List[dict], dict]]: """ Inserts one or multiple relations in the database. @@ -366,7 +366,7 @@ async def insert_relation( async def live( self, table: Union[str, Table], diff: bool = False - ) -> Coroutine[Any, Any, UUID]: # type: ignore + ) -> Coroutine[Any, Any, UUID]: """Initiates a live query for a specified table name. Args: @@ -385,7 +385,7 @@ async def live( async def subscribe_live( self, query_uuid: Union[str, UUID] - ) -> Coroutine[Any, Any, Queue]: # type: ignore + ) -> Coroutine[Any, Any, Queue]: """Returns a queue that receives notification messages from a running live query. Args: @@ -399,7 +399,7 @@ async def subscribe_live( """ raise NotImplementedError(f"query not implemented for: {self}") - async def kill(self, query_uuid: Union[str, UUID]) -> Coroutine[Any, Any, None]: # type: ignore + async def kill(self, query_uuid: Union[str, UUID]) -> Coroutine[Any, Any, None]: """Kills a running live query by it's UUID. Args: diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index b260611b..3e7734c8 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -52,7 +52,7 @@ def __init__( async def _send( self, message: RequestMessage, process: str, bypass: bool = False - ) -> dict: # type: ignore + ) -> dict: await self.connect() assert ( self.socket is not None @@ -63,7 +63,7 @@ async def _send( self.check_response_for_error(response, process) return response - async def connect(self, url: Optional[str] = None) -> None: # type: ignore + async def connect(self, url: Optional[str] = None) -> None: # overwrite params if passed in if url is not None: self.url = Url(url) @@ -77,22 +77,22 @@ async def connect(self, url: Optional[str] = None) -> None: # type: ignore subprotocols=[websockets.Subprotocol("cbor")], ) - async def authenticate(self, token: str) -> dict: # type: ignore + async def authenticate(self, token: str) -> dict: message = RequestMessage(self.id, RequestMethod.AUTHENTICATE, token=token) return await self._send(message, "authenticating") - async def invalidate(self) -> None: # type: ignore + async def invalidate(self) -> None: message = RequestMessage(self.id, RequestMethod.INVALIDATE) await self._send(message, "invalidating") self.token = None - async def signup(self, vars: Dict) -> str: # type: ignore + async def signup(self, vars: Dict) -> str: message = RequestMessage(self.id, RequestMethod.SIGN_UP, data=vars) response = await self._send(message, "signup") self.check_response_for_result(response, "signup") return response["result"] - async def signin(self, vars: Dict[str, Any]) -> str: # type: ignore + async def signin(self, vars: Dict[str, Any]) -> str: message = RequestMessage( self.id, RequestMethod.SIGN_IN, @@ -108,13 +108,13 @@ async def signin(self, vars: Dict[str, Any]) -> str: # type: ignore self.token = response["result"] return response["result"] - async def info(self) -> Optional[dict]: # type: ignore + async def info(self) -> Optional[dict]: message = RequestMessage(self.id, RequestMethod.INFO) outcome = await self._send(message, "getting database information") self.check_response_for_result(outcome, "getting database information") return outcome["result"] - async def use(self, namespace: str, database: str) -> None: # type: ignore + async def use(self, namespace: str, database: str) -> None: message = RequestMessage( self.id, RequestMethod.USE, @@ -123,7 +123,7 @@ async def use(self, namespace: str, database: str) -> None: # type: ignore ) await self._send(message, "use") - async def query(self, query: str, params: Optional[dict] = None) -> dict: # type: ignore + async def query(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( @@ -136,7 +136,7 @@ async def query(self, query: str, params: Optional[dict] = None) -> dict: # typ self.check_response_for_result(response, "query") return response["result"][0]["result"] - async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: # type: ignore + async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: if params is None: params = {} message = RequestMessage( @@ -148,21 +148,21 @@ async def query_raw(self, query: str, params: Optional[dict] = None) -> dict: # response = await self._send(message, "query", bypass=True) return response - async def version(self) -> str: # type: ignore + async def version(self) -> str: message = RequestMessage(self.id, RequestMethod.VERSION) response = await self._send(message, "getting database version") self.check_response_for_result(response, "getting database version") return response["result"] - async def let(self, key: str, value: Any) -> None: # type: ignore + async def let(self, key: str, value: Any) -> None: message = RequestMessage(self.id, RequestMethod.LET, key=key, value=value) await self._send(message, "letting") - async def unset(self, key: str) -> None: # type: ignore + async def unset(self, key: str) -> None: message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) await self._send(message, "unsetting") - async def select(self, thing: str | RecordID | Table) -> Union[List[dict], dict]: # type: ignore + async def select(self, thing: str | RecordID | Table) -> Union[List[dict], dict]: message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = await self._send(message, "select") self.check_response_for_result(response, "select") @@ -172,7 +172,7 @@ async def create( self, thing: Union[str, RecordID, Table], data: Optional[Union[Union[List[dict], dict], dict]] = None, - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: if isinstance(thing, str): if ":" in thing: buffer = thing.split(":") @@ -186,7 +186,7 @@ async def create( async def update( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.UPDATE, record_id=thing, data=data ) @@ -196,7 +196,7 @@ async def update( async def merge( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.MERGE, record_id=thing, data=data ) @@ -206,7 +206,7 @@ async def merge( async def patch( self, thing: Union[str, RecordID, Table], data: Optional[List[dict]] = None - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.PATCH, collection=thing, params=data ) @@ -216,7 +216,7 @@ async def patch( async def delete( self, thing: Union[str, RecordID, Table] - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage(self.id, RequestMethod.DELETE, record_id=thing) response = await self._send(message, "delete") self.check_response_for_result(response, "delete") @@ -224,7 +224,7 @@ async def delete( async def insert( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.INSERT, collection=table, params=data ) @@ -234,7 +234,7 @@ async def insert( async def insert_relation( self, table: Union[str, Table], data: Union[List[dict], dict] - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.INSERT_RELATION, table=table, params=data ) @@ -242,7 +242,7 @@ async def insert_relation( self.check_response_for_result(response, "insert_relation") return response["result"] - async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: # type: ignore + async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: message = RequestMessage( self.id, RequestMethod.LIVE, @@ -254,7 +254,7 @@ async def live(self, table: Union[str, Table], diff: bool = False) -> UUID: # t async def subscribe_live( self, query_uuid: Union[str, UUID] - ) -> AsyncGenerator[dict, None]: # type: ignore + ) -> AsyncGenerator[dict, None]: result_queue = Queue() async def listen_live(): @@ -278,13 +278,13 @@ async def listen_live(): raise Exception(f"Error in live subscription: {result['error']}") yield result - async def kill(self, query_uuid: Union[str, UUID]) -> None: # type: ignore + async def kill(self, query_uuid: Union[str, UUID]) -> None: message = RequestMessage(self.id, RequestMethod.KILL, uuid=query_uuid) await self._send(message, "kill") async def upsert( self, thing: Union[str, RecordID, Table], data: Optional[Dict] = None - ) -> Union[List[dict], dict]: # type: ignore + ) -> Union[List[dict], dict]: message = RequestMessage( self.id, RequestMethod.UPSERT, record_id=thing, data=data ) @@ -295,7 +295,7 @@ async def upsert( async def close(self): await self.socket.close() - async def __aenter__(self) -> "AsyncWsSurrealConnection": # type: ignore + async def __aenter__(self) -> "AsyncWsSurrealConnection": """ Asynchronous context manager entry. Initializes a websocket connection and returns the connection instance. @@ -305,7 +305,7 @@ async def __aenter__(self) -> "AsyncWsSurrealConnection": # type: ignore ) return self - async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore + async def __aexit__(self, exc_type, exc_value, traceback) -> None: """ Asynchronous context manager exit. Closes the websocket connection upon exiting the context. From ada642a1ef0e2844aed427b13a86841878f168b5 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Tue, 4 Feb 2025 16:30:02 +0000 Subject: [PATCH 08/11] updating tests --- src/surrealdb/connections/async_ws.py | 6 +++--- src/surrealdb/connections/blocking_http.py | 2 +- src/surrealdb/connections/blocking_ws.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index 3e7734c8..5f38ef0a 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -44,8 +44,8 @@ def __init__( """ self.url: Url = Url(url) self.raw_url: str = f"{self.url.raw_url}/rpc" - self.host: str | None = self.url.hostname - self.port: int | None = self.url.port + self.host: Optional[str] = self.url.hostname + self.port: Optional[int] = self.url.port self.id: str = str(uuid.uuid4()) self.token: Optional[str] = None self.socket = None @@ -162,7 +162,7 @@ async def unset(self, key: str) -> None: message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) await self._send(message, "unsetting") - async def select(self, thing: str | RecordID | Table) -> Union[List[dict], dict]: + async def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = await self._send(message, "select") self.check_response_for_result(response, "select") diff --git a/src/surrealdb/connections/blocking_http.py b/src/surrealdb/connections/blocking_http.py index a02139a8..01380b48 100644 --- a/src/surrealdb/connections/blocking_http.py +++ b/src/surrealdb/connections/blocking_http.py @@ -202,7 +202,7 @@ def patch( self.check_response_for_result(response, "patch") return response["result"] - def select(self, thing: str | RecordID | Table) -> Union[List[dict], dict]: + def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = self._send(message, "select") self.check_response_for_result(response, "select") diff --git a/src/surrealdb/connections/blocking_ws.py b/src/surrealdb/connections/blocking_ws.py index 52b1bdb9..97d1c99f 100644 --- a/src/surrealdb/connections/blocking_ws.py +++ b/src/surrealdb/connections/blocking_ws.py @@ -146,7 +146,7 @@ def unset(self, key: str) -> None: message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) self._send(message, "unsetting") - def select(self, thing: str | RecordID | Table) -> Union[List[dict], dict]: + def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = self._send(message, "select") self.check_response_for_result(response, "select") From 66c0998adb7276aed2f0ec032fcfe5b82aee69cd Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Tue, 4 Feb 2025 16:30:48 +0000 Subject: [PATCH 09/11] updating tests --- .github/workflows/stability.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stability.yml b/.github/workflows/stability.yml index c4d8d319..39272dc9 100644 --- a/.github/workflows/stability.yml +++ b/.github/workflows/stability.yml @@ -35,5 +35,6 @@ jobs: - name: Run black checks run: black --check --verbose --diff --color src/ - - name: Run mypy checks - run: mypy --explicit-package-bases src/ + # This is currently disabled because MyPy is very confused about Coroutine types +# - name: Run mypy checks +# run: mypy --explicit-package-bases src/ From 6b9695cc387cc27aa64de80262e2c183ea6e28ec Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Tue, 4 Feb 2025 16:43:06 +0000 Subject: [PATCH 10/11] updating tests --- tests/unit_tests/data_types/test_datetimes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/data_types/test_datetimes.py b/tests/unit_tests/data_types/test_datetimes.py index f12d56aa..311a6f6a 100644 --- a/tests/unit_tests/data_types/test_datetimes.py +++ b/tests/unit_tests/data_types/test_datetimes.py @@ -68,7 +68,11 @@ async def test_datetime_iso_format(self): # assert that the datetime returned from the DB is the same as the one serialized date = compact_test_outcome[0]["datetime"].isoformat() - self.assertEqual(date + "Z", iso_datetime) + date_to_compare = date + "Z" + if sys.version_info < (3, 11): + date_to_compare = date_to_compare.replace("Z", "+00:00") + + self.assertEqual(date_to_compare, iso_datetime) # Cleanup await self.connection.query("DELETE datetime_tests;") From 17fc75ad944b4b53fc59f8c5c1fb752c029c5068 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Tue, 4 Feb 2025 16:46:06 +0000 Subject: [PATCH 11/11] updating tests --- src/surrealdb/connections/async_ws.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/surrealdb/connections/async_ws.py b/src/surrealdb/connections/async_ws.py index 5f38ef0a..4cfe4e0b 100644 --- a/src/surrealdb/connections/async_ws.py +++ b/src/surrealdb/connections/async_ws.py @@ -162,7 +162,9 @@ async def unset(self, key: str) -> None: message = RequestMessage(self.id, RequestMethod.UNSET, params=[key]) await self._send(message, "unsetting") - async def select(self, thing: Union[str, RecordID, Table]) -> Union[List[dict], dict]: + async def select( + self, thing: Union[str, RecordID, Table] + ) -> Union[List[dict], dict]: message = RequestMessage(self.id, RequestMethod.SELECT, params=[thing]) response = await self._send(message, "select") self.check_response_for_result(response, "select")