Skip to content

Commit

Permalink
fix improper auto conversion by ChatGPT
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-wangxu authored and yuzhi.wx committed Feb 8, 2024
1 parent 29c28a0 commit 3e9e896
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 40 deletions.
2 changes: 1 addition & 1 deletion persistqueue/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
class Empty(Exception):
    """Raised when an operation is attempted on an empty container."""


class Full(Exception):
"""Exception raised when an attempt is made to add an item to a full container."""
Expand Down
25 changes: 17 additions & 8 deletions persistqueue/mysqlqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .sqlbase import SQLBase
from typing import Any, Optional


class MySQLQueue(SQLBase):
"""Mysql(or future standard dbms) based FIFO queue."""
_TABLE_NAME = 'queue'
Expand All @@ -31,13 +32,13 @@ class MySQLQueue(SQLBase):
)
_SQL_UPDATE = 'UPDATE {table_name} SET data = %s WHERE {key_column} = %s'
_SQL_DELETE = 'DELETE FROM {table_name} WHERE {key_column} {op} %s'

def __init__(
self,
host: str,
user: str,
passwd: str,
db_name: str,
self,
host: str,
user: str,
passwd: str,
db_name: str,
name: Optional[str] = None,
port: int = 3306,
charset: str = 'utf8mb4',
Expand Down Expand Up @@ -84,6 +85,7 @@ def _new_db_connection(self) -> None:
self._getter = self._putter

def put(self, item: Any, block: bool = True) -> int:
# block kwarg is noop and only here to align with python's queue
obj = self._serializer.dumps(item)
_id = self._insert_into(obj, _time.time())
self.total += 1
Expand All @@ -106,14 +108,19 @@ def _init(self) -> None:
def get_pooled_conn(self) -> Any:
    """Return a connection checked out from this queue's connection pool."""
    return self._connection_pool.connection()


class MySQLConn:
"""MySqlConn defines a common structure for
both mysql and sqlite3 connections.
used to mitigate the interface differences between drivers/db
"""

def __init__(self, queue: Optional[MySQLQueue] = None, conn: Optional[Any] = None) -> None:
    """Wrap either a queue's pooled connection or an explicit connection.

    Args:
        queue: When given, a fresh connection is pulled from its pool and
            *conn* is ignored.
        conn: A pre-established connection; used only when *queue* is None.
    """
    self._queue = queue
    # The queue takes precedence: its pool hands out the connection.
    # (Removed a leftover pre-refactor one-liner that assigned self._conn
    # and was immediately overwritten by this if/else.)
    if queue is not None:
        self._conn = queue.get_pooled_conn()
    else:
        self._conn = conn
    self._cursor = None
    self.closed = False

Expand All @@ -122,6 +129,8 @@ def __enter__(self) -> Any:
return self._conn

def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None:
# do not commit() but to close() , keep same behavior
# with dbutils
self._cursor.close()

def execute(self, *args: Any, **kwargs: Any) -> Any:
Expand All @@ -133,7 +142,7 @@ def execute(self, *args: Any, **kwargs: Any) -> Any:
def close(self) -> None:
    """Close the underlying connection; safe to call more than once.

    Removed a duplicated `self.closed = True` statement (a merge/paste
    artifact that set the flag twice).
    """
    if not self.closed:
        self._conn.close()
    self.closed = True

def commit(self) -> None:
if not self.closed:
Expand Down
4 changes: 2 additions & 2 deletions persistqueue/pdict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# coding=utf-8
import logging
import sqlite3
from persistqueue import sqlbase
from typing import Any, Iterator

log = logging.getLogger(__name__)


class PDict(sqlbase.SQLiteBase, dict):
_TABLE_NAME = 'dict'
_KEY_COLUMN = 'key'
Expand All @@ -16,7 +16,7 @@ class PDict(sqlbase.SQLiteBase, dict):
'WHERE {key_column} = ?')
_SQL_UPDATE = 'UPDATE {table_name} SET data = ? WHERE {key_column} = ?'
_SQL_DELETE = 'DELETE FROM {table_name} WHERE {key_column} {op} ?'

def __init__(self, path: str, name: str, multithreading: bool = False) -> None:
    """Create a SQLite-backed persistent dict.

    Args:
        path: Location of the backing SQLite store.
        name: Table name used for this dict.
        multithreading: Whether access from multiple threads is expected.
    """
    # PDict is always auto_commit=True
    super().__init__(path, name=name, multithreading=multithreading, auto_commit=True)
Expand Down
21 changes: 15 additions & 6 deletions persistqueue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,26 @@

log = logging.getLogger(__name__)


def _truncate(fn: str, length: int) -> None:
"""Truncate the file to a specified length."""
with open(fn, 'ab+') as f:
f.truncate(length)


def atomic_rename(src: str, dst: str) -> None:
    """Atomically rename a file from src to dst, replacing any existing dst."""
    os.replace(src, dst)


class Queue:
"""Thread-safe, persistent queue."""

def __init__(
self,
path: str,
maxsize: int = 0,
chunksize: int = 100,
self,
path: str,
maxsize: int = 0,
chunksize: int = 100,
tempdir: Optional[str] = None,
serializer: Any = persistqueue.serializers.pickle,
autosave: bool = False
Expand Down Expand Up @@ -66,7 +69,8 @@ def __init__(
self._init(maxsize)
if self.tempdir:
if os.stat(self.path).st_dev != os.stat(self.tempdir).st_dev:
raise ValueError("tempdir has to be located on same path filesystem")
raise ValueError(
"tempdir has to be located on same path filesystem")
else:
fd, tempdir = tempfile.mkstemp()
if os.stat(self.path).st_dev != os.stat(tempdir).st_dev:
Expand All @@ -83,18 +87,21 @@ def __init__(
if os.path.exists(headfn):
if hoffset < os.path.getsize(headfn):
_truncate(headfn, hoffset)
# let the head file open
self.headf: BinaryIO = self._openchunk(hnum, 'ab+')
tnum, _, toffset = self.info['tail']
self.tailf: BinaryIO = self._openchunk(tnum)
self.tailf.seek(toffset)
# update unfinished tasks with the current number of enqueued tasks
self.unfinished_tasks: int = self.info['size']
self.update_info: bool = True

def _init(self, maxsize: int) -> None:
    """Create the queue's synchronization primitives and backing directory.

    Args:
        maxsize: Capacity limit; not referenced in the visible portion of
            this initializer — NOTE(review): confirm how maxsize is consumed
            elsewhere.

    Removed a duplicated `self.all_tasks_done` assignment (a merge/paste
    artifact that constructed the Condition twice).
    """
    self.mutex: threading.Lock = threading.Lock()
    self.not_empty: threading.Condition = threading.Condition(self.mutex)
    self.not_full: threading.Condition = threading.Condition(self.mutex)
    self.all_tasks_done: threading.Condition = threading.Condition(
        self.mutex)
    # Ensure the on-disk directory for the queue exists.
    if not os.path.exists(self.path):
        os.makedirs(self.path)
Expand Down Expand Up @@ -256,6 +263,7 @@ def _saveinfo(self) -> None:
self._clear_tail_file()

def _clear_tail_file(self) -> None:
"""Remove the tail files whose items were already get."""
tnum, _, _ = self.info['tail']
while tnum >= 1:
tnum -= 1
Expand All @@ -272,6 +280,7 @@ def _infopath(self) -> str:
return os.path.join(self.path, 'info')

def __del__(self) -> None:
    """Close any still-open head/tail chunk files when the queue is destroyed."""
    for handle in (self.headf, self.tailf):
        if handle and not handle.closed:
            handle.close()
4 changes: 4 additions & 0 deletions persistqueue/serializers/cbor2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Define the Struct for prefixing serialized objects with their byte length
length_struct = Struct("<L")


def dump(value: Any, fp: BinaryIO, sort_keys: bool = False) -> None:
"""
Serialize value as cbor2 to a byte-mode file object with a length prefix.
Expand All @@ -29,6 +30,7 @@ def dump(value: Any, fp: BinaryIO, sort_keys: bool = False) -> None:
fp.write(length)
fp.write(packed)


def dumps(value: Any, sort_keys: bool = False) -> bytes:
"""
Serialize value as cbor2 to bytes without length prefix.
Expand All @@ -45,6 +47,7 @@ def dumps(value: Any, sort_keys: bool = False) -> bytes:
value = {key: value[key] for key in sorted(value)}
return cbor2.dumps(value)


def load(fp: BinaryIO) -> Any:
"""
Deserialize one cbor2 value from a byte-mode file object using length prefix.
Expand All @@ -60,6 +63,7 @@ def load(fp: BinaryIO) -> Any:
# Read the serialized object using the determined length and deserialize it
return cbor2.loads(fp.read(length))


def loads(bytes_value: bytes) -> Any:
"""
Deserialize one cbor2 value from bytes.
Expand Down
5 changes: 4 additions & 1 deletion persistqueue/serializers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
from typing import Any, BinaryIO


def dump(value: Any, fp: BinaryIO, sort_keys: bool = False) -> None:
    """Write *value* as one UTF-8 JSON line to a byte-mode file object.

    Args:
        value: The Python object to serialize.
        fp: A file-like object supporting binary write operations.
        sort_keys: If True, dictionary keys are emitted in sorted order.
    """
    encoded = json.dumps(value, sort_keys=sort_keys).encode('utf-8')
    fp.write(encoded)
    fp.write(b"\n")


def dumps(value: Any, sort_keys: bool = False) -> bytes:
    """Serialize *value* to UTF-8 JSON bytes (no trailing newline).

    Args:
        value: The Python object to serialize.
        sort_keys: If True, dictionary keys are emitted in sorted order.

    Returns:
        The JSON encoding of value as a bytes object.
    """
    text = json.dumps(value, sort_keys=sort_keys)
    return text.encode('utf-8')


def load(fp: BinaryIO) -> Any:
    """Read the next line from a byte-mode file object and parse it as JSON.

    Args:
        fp: A file-like object supporting binary read operations.

    Returns:
        The deserialized Python object.
    """
    raw_line = fp.readline()
    return json.loads(raw_line.decode('utf-8'))


def loads(bytes_value: bytes) -> Any:
    """Deserialize a single JSON value from UTF-8 encoded bytes.

    Args:
        bytes_value: UTF-8 encoded JSON data.

    Returns:
        The deserialized Python object.
    """
    text = bytes_value.decode('utf-8')
    return json.loads(text)

20 changes: 12 additions & 8 deletions persistqueue/serializers/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
import struct
from typing import Any, BinaryIO, Dict


def dump(value: Any, fp: BinaryIO, sort_keys: bool = False) -> None:
"""
Serialize value as msgpack to a byte-mode file object with a length prefix.
Args:
value: The Python object to serialize.
fp: A file-like object supporting binary write operations.
sort_keys: If True, the output of dictionaries will be sorted by key.
Returns:
None
"""
Expand All @@ -25,41 +26,44 @@ def dump(value: Any, fp: BinaryIO, sort_keys: bool = False) -> None:
fp.write(length)
fp.write(packed)


def dumps(value: Any, sort_keys: bool = False) -> bytes:
    """
    Serialize value as msgpack to bytes.

    Args:
        value: The Python object to serialize.
        sort_keys: If True, the output of dictionaries will be sorted by key.

    Returns:
        A bytes object containing the serialized representation of value.
    """
    if sort_keys and isinstance(value, Dict):
        ordered = ((key, value[key]) for key in sorted(value))
        value = dict(ordered)
    return msgpack.packb(value, use_bin_type=True)


def load(fp: BinaryIO) -> Any:
    """
    Deserialize one msgpack value from a byte-mode file object using its
    4-byte little-endian length prefix.

    Args:
        fp: A file-like object supporting binary read operations.

    Returns:
        The deserialized Python object.
    """
    (length,) = struct.unpack("<L", fp.read(4))
    payload = fp.read(length)
    return msgpack.unpackb(payload, use_list=False, raw=False)


def loads(bytes_value: bytes) -> Any:
"""
Deserialize one msgpack value from bytes.
Args:
bytes_value: A bytes object containing the serialized msgpack data.
Returns:
The deserialized Python object.
"""
Expand Down
6 changes: 5 additions & 1 deletion persistqueue/serializers/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
log = logging.getLogger(__name__)

# Retrieve the selected pickle protocol from a common utility module
# Pickle protocol used for all (de)serialization in this module.
# (Removed a duplicated assignment line left over from a diff/merge artifact.)
protocol: int = 4  # Python 3 uses protocol version 4 or higher
log.info(f"Selected pickle protocol: '{protocol}'")


def dump(value: Any, fp: BinaryIO, sort_keys: bool = False) -> None:
"""
Serialize value as pickle to a byte-mode file object.
Expand All @@ -29,6 +30,7 @@ def dump(value: Any, fp: BinaryIO, sort_keys: bool = False) -> None:
value = {key: value[key] for key in sorted(value)}
pickle.dump(value, fp, protocol=protocol)


def dumps(value: Any, sort_keys: bool = False) -> bytes:
"""
Serialize value as pickle to bytes.
Expand All @@ -46,6 +48,7 @@ def dumps(value: Any, sort_keys: bool = False) -> bytes:
value = {key: value[key] for key in sorted(value)}
return pickle.dumps(value, protocol=protocol)


def load(fp: BinaryIO) -> Any:
    """
    Deserialize one pickle value from a byte-mode file object.

    Args:
        fp: A file-like object supporting binary read operations.

    Returns:
        The deserialized Python object.
    """
    return pickle.load(fp)


def loads(bytes_value: bytes) -> Any:
"""
Deserialize one pickle value from bytes.
Expand Down
Loading

0 comments on commit 3e9e896

Please sign in to comment.