Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions datajoint/autopopulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import signal
import multiprocessing as mp
import contextlib
import deepdiff

# noinspection PyExceptionInherit,PyCallingNonCallable

Expand Down Expand Up @@ -309,17 +310,46 @@ def _populate1(
):
return False

self.connection.start_transaction()
# if make is a generator, the transaction can be delayed until the final stage
is_generator = inspect.isgeneratorfunction(make)
if not is_generator:
self.connection.start_transaction()

if key in self.target: # already populated
self.connection.cancel_transaction()
if not is_generator:
self.connection.cancel_transaction()
if jobs is not None:
jobs.complete(self.target.table_name, self._job_key(key))
return False

logger.debug(f"Making {key} -> {self.target.full_table_name}")
self.__class__._allow_insert = True

try:
make(dict(key), **(make_kwargs or {}))
if not is_generator:
make(dict(key), **(make_kwargs or {}))
else:
# tripartite make - transaction is delayed until the final stage
gen = make(dict(key), **(make_kwargs or {}))
fetched_data = next(gen)
fetch_hash = deepdiff.DeepHash(
fetched_data, ignore_iterable_order=False
)[fetched_data]
computed_result = next(gen) # perform the computation
# fetch and insert inside a transaction
self.connection.start_transaction()
gen = make(dict(key), **(make_kwargs or {})) # restart make
fetched_data = next(gen)
if (
fetch_hash
!= deepdiff.DeepHash(fetched_data, ignore_iterable_order=False)[
fetched_data
]
): # rollback due to referential integrity fail
self.connection.cancel_transaction()
return False
gen.send(computed_result) # insert

except (KeyboardInterrupt, SystemExit, Exception) as error:
try:
self.connection.cancel_transaction()
Expand Down
2 changes: 1 addition & 1 deletion datajoint/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def pack_blob(self, obj):
return self.pack_dict(obj)
if isinstance(obj, str):
return self.pack_string(obj)
if isinstance(obj, collections.abc.ByteString):
if isinstance(obj, (bytes, bytearray)):
return self.pack_bytes(obj)
if isinstance(obj, collections.abc.MutableSequence):
return self.pack_list(obj)
Expand Down
15 changes: 10 additions & 5 deletions datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pathlib

from .settings import config
from . import errors
from . import errors, __version__
from .dependencies import Dependencies
from .blob import pack, unpack
from .hash import uuid_from_buffer
Expand Down Expand Up @@ -190,15 +190,20 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
self.conn_info["ssl_input"] = use_tls
self.conn_info["host_input"] = host_input
self.init_fun = init_fun
logger.info("Connecting {user}@{host}:{port}".format(**self.conn_info))
self._conn = None
self._query_cache = None
connect_host_hook(self)
if self.is_connected:
logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
logger.info(
"DataJoint {version} connected to {user}@{host}:{port}".format(
version=__version__, **self.conn_info
)
)
self.connection_id = self.query("SELECT connection_id()").fetchone()[0]
else:
raise errors.LostConnectionError("Connection failed.")
raise errors.LostConnectionError(
"Connection failed {user}@{host}:{port}".format(**self.conn_info)
)
self._in_transaction = False
self.schemas = dict()
self.dependencies = Dependencies(self)
Expand Down Expand Up @@ -344,7 +349,7 @@ def query(
except errors.LostConnectionError:
if not reconnect:
raise
logger.warning("MySQL server has gone away. Reconnecting to the server.")
logger.warning("Reconnecting to MySQL server.")
connect_host_hook(self)
if self._in_transaction:
self.cancel_transaction()
Expand Down
4 changes: 2 additions & 2 deletions datajoint/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

def subfold(name, folds):
"""
subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde']
subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde']
"""
return (
(name[: folds[0]].lower(),) + subfold(name[folds[0] :], folds[1:])
Expand Down Expand Up @@ -278,7 +278,7 @@ def upload_filepath(self, local_filepath):

# check if the remote file already exists and verify that it matches
check_hash = (self & {"hash": uuid}).fetch("contents_hash")
if check_hash:
if check_hash.size:
# the tracking entry exists, check that it's the same file as before
if contents_hash != check_hash[0]:
raise DataJointError(
Expand Down
4 changes: 2 additions & 2 deletions datajoint/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ def list_tables(self):
return [
t
for d, t in (
full_t.replace("`", "").split(".")
for full_t in self.connection.dependencies.topo_sort()
table_name.replace("`", "").split(".")
for table_name in self.connection.dependencies.topo_sort()
)
if d == self.database
]
Expand Down
23 changes: 14 additions & 9 deletions datajoint/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Settings for DataJoint.
Settings for DataJoint
"""

from contextlib import contextmanager
Expand Down Expand Up @@ -48,7 +48,8 @@
"database.use_tls": None,
"enable_python_native_blobs": True, # python-native/dj0 encoding support
"add_hidden_timestamp": False,
"filepath_checksum_size_limit": None, # file size limit for when to disable checksums
# file size limit for when to disable checksums
"filepath_checksum_size_limit": None,
}
)

Expand Down Expand Up @@ -117,6 +118,7 @@ def load(self, filename):
if filename is None:
filename = LOCALCONFIG
with open(filename, "r") as fid:
logger.info(f"DataJoint is configured from {os.path.abspath(filename)}")
self._conf.update(json.load(fid))

def save_local(self, verbose=False):
Expand Down Expand Up @@ -236,7 +238,8 @@ class __Config:

def __init__(self, *args, **kwargs):
self._conf = dict(default)
self._conf.update(dict(*args, **kwargs)) # use the free update to set keys
# use the free update to set keys
self._conf.update(dict(*args, **kwargs))

def __getitem__(self, key):
return self._conf[key]
Expand All @@ -250,7 +253,9 @@ def __setitem__(self, key, value):
valid_logging_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
if key == "loglevel":
if value not in valid_logging_levels:
raise ValueError(f"{'value'} is not a valid logging value")
raise ValueError(
f"'{value}' is not a valid logging value {tuple(valid_logging_levels)}"
)
logger.setLevel(value)


Expand All @@ -260,11 +265,9 @@ def __setitem__(self, key, value):
os.path.expanduser(n) for n in (LOCALCONFIG, os.path.join("~", GLOBALCONFIG))
)
try:
config_file = next(n for n in config_files if os.path.exists(n))
config.load(next(n for n in config_files if os.path.exists(n)))
except StopIteration:
pass
else:
config.load(config_file)
logger.info("No config file was found.")

# override login credentials with environment variables
mapping = {
Expand Down Expand Up @@ -292,6 +295,8 @@ def __setitem__(self, key, value):
)
if v is not None
}
config.update(mapping)
if mapping:
logger.info(f"Overloaded settings {tuple(mapping)} from environment variables.")
config.update(mapping)

logger.setLevel(log_levels[config["loglevel"]])
2 changes: 1 addition & 1 deletion datajoint/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.14.3"
__version__ = "0.14.4"

assert len(__version__) <= 10 # The log table limits version to the 10 characters
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.14.3"
dependencies = [
"numpy",
"pymysql>=0.7.2",
"deepdiff",
"pyparsing",
"ipython",
"pandas",
Expand Down
1 change: 0 additions & 1 deletion tests/test_declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ class WithSuchALongPartNameThatItCrashesMySQL(dj.Part):


def test_regex_mismatch(schema_any):

class IndexAttribute(dj.Manual):
definition = """
index: int
Expand Down
1 change: 0 additions & 1 deletion tests/test_relational_operand.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,6 @@ def test_union_multiple(schema_simp_pop):


class TestDjTop:

def test_restrictions_by_top(self, schema_simp_pop):
a = L() & dj.Top()
b = L() & dj.Top(order_by=["cond_in_l", "KEY"])
Expand Down