Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check the key type when the dictionary is indexed #46

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
123 changes: 68 additions & 55 deletions src/pysorteddict/pysorteddict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ struct SortedDictType
PyObject* key_type;

void deinit(void);
bool is_type_key_type(PyObject*, bool);
bool validate_key_type(PyObject*, bool);
int contains(PyObject*);
PyObject* getitem(PyObject*);
int setitem(PyObject*, PyObject*);
Expand All @@ -104,41 +104,56 @@ void SortedDictType::deinit(void)
}

/**
* Check whether a Python object has the correct type for use as a key.
* Validate that the key type member of this instance is non-null, and that it
* matches the type of the key provided. (I know, smartly named, right?)
*
* @param ob Python object.
* @param raise Whether to set a Python exception if the type is wrong.
* @param key Key.
* @param raise Whether to set a Python exception on failure.
*
* @return `true` if its type is the same as the key type, else `false`.
* @return `true` if the key is okay for use with this instance, else `false`.
*/
bool SortedDictType::is_type_key_type(PyObject* ob, bool raise = true)
bool SortedDictType::validate_key_type(PyObject* key, bool raise = true)
{
if (Py_IS_TYPE(ob, reinterpret_cast<PyTypeObject*>(this->key_type)) != 0)
// There is scope for optimisation here. This check will always pass when
// something is inserted for the first time, so it can be skipped.
if (this->key_type == nullptr)
{
return true;
if (raise)
{
PyErr_SetString(PyExc_RuntimeError, "key type cannot be validated; no items have been inserted");
}
return false;
}
if (raise)
if (Py_IS_TYPE(key, reinterpret_cast<PyTypeObject*>(this->key_type)) == 0)
{
PyObject* key_type_repr = PyObject_Repr(this->key_type); // New reference.
if (key_type_repr == nullptr)
if (raise)
{
return false;
PyObject* key_type_repr = PyObject_Repr(this->key_type); // New reference.
if (key_type_repr == nullptr)
{
return false;
}
PyErr_Format(PyExc_TypeError, "key type must be %s", PyUnicode_AsUTF8(key_type_repr));
Py_DECREF(key_type_repr);
}
PyErr_Format(PyExc_TypeError, "key must be of type %s", PyUnicode_AsUTF8(key_type_repr));
Py_DECREF(key_type_repr);
return false;
}
return false;
return true;
}

/**
* Check whether a key is present without checking the type of the key.
* Check whether a key is present.
*
* @param ob Python object.
*
* @return 1 if it is present, else 0.
*/
int SortedDictType::contains(PyObject* key)
{
if (!this->validate_key_type(key, false))
{
return 0;
}
if (this->map->find(key) == this->map->end())
{
return 0;
Expand All @@ -147,15 +162,18 @@ int SortedDictType::contains(PyObject* key)
}

/**
* Find the value mapped to a key without checking the type of the key. If not
* found, set a Python exception.
* Find the value mapped to a key. If not found, set a Python exception.
*
* @param key Key.
*
* @return Value if found, else `nullptr`.
*/
PyObject* SortedDictType::getitem(PyObject* key)
{
if (!this->validate_key_type(key))
{
return nullptr;
}
auto it = this->map->find(key);
if (it == this->map->end())
{
Expand All @@ -166,8 +184,8 @@ PyObject* SortedDictType::getitem(PyObject* key)
}

/**
* Map a value to a key or remove a key-value pair without checking the type of
* the key. If not removed when removal was requested, set a Python exception.
* Map a value to a key or remove a key-value pair. If not removed when removal
* was requested, set a Python exception.
*
* @param key Key.
* @param value Value.
Expand All @@ -176,14 +194,42 @@ PyObject* SortedDictType::getitem(PyObject* key)
*/
int SortedDictType::setitem(PyObject* key, PyObject* value)
{
if (this->key_type == nullptr)
{
// It is the first time something is being inserted. Set the key type
// member.
static PyTypeObject* allowed_key_types[] = {
&PyLong_Type,
};
PyObject* key_type = reinterpret_cast<PyObject*>(Py_TYPE(key));
for (auto allowed_key_type : allowed_key_types)
{
if (PyObject_RichCompareBool(reinterpret_cast<PyObject*>(allowed_key_type), key_type, Py_EQ) == 1)
{
this->key_type = Py_NewRef(key_type);
break;
}
}
if (this->key_type == nullptr)
{
PyErr_SetString(PyExc_TypeError, "key type is unsupported");
return -1;
}
}

if (!this->validate_key_type(key))
{
return -1;
}

// Insertion will be faster if the approximate location is known. Hence,
// look for the nearest match.
auto it = this->map->lower_bound(key);
bool found = it != this->map->end() && !this->map->key_comp()(key, it->first);

// Remove the key-value pair.
if (value == nullptr)
{
// Remove the key-value pair.
if (!found)
{
PyErr_SetObject(PyExc_KeyError, key);
Expand Down Expand Up @@ -337,28 +383,7 @@ int SortedDictType::init(PyObject* args, PyObject* kwargs)
// explicitly initialise them.
this->map = new std::map<PyObject*, PyObject*, PyObject_CustomCompare>;
this->key_type = nullptr;

// Up to Python 3.12, the argument parser below took an array of pointers
// (with each pointer pointing to a C string) as its fourth argument.
// However, C++ does not allow converting a string constant to a pointer.
// Hence, I use a character array to construct the C string, and then place
// it in an array of pointers.
char arg_name[] = "key_type";
char* args_names[] = { arg_name, nullptr };
PyObject* key_type;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|", args_names, &key_type))
{
return -1;
}

// Check the type to use for keys.
if (PyObject_RichCompareBool(key_type, reinterpret_cast<PyObject*>(&PyLong_Type), Py_EQ) != 1)
{
PyErr_SetString(PyExc_TypeError, "constructor argument must be a supported type");
return -1;
}

this->key_type = Py_NewRef(key_type);
// This is a stub. Will implement it later.
return 0;
}

Expand Down Expand Up @@ -393,10 +418,6 @@ static void sorted_dict_type_dealloc(PyObject* self)
static int sorted_dict_type_contains(PyObject* self, PyObject* key)
{
SortedDictType* sd = reinterpret_cast<SortedDictType*>(self);
if (!sd->is_type_key_type(key, false))
{
return 0;
}
return sd->contains(key);
}

Expand Down Expand Up @@ -430,10 +451,6 @@ static Py_ssize_t sorted_dict_type_len(PyObject* self)
static PyObject* sorted_dict_type_getitem(PyObject* self, PyObject* key)
{
SortedDictType* sd = reinterpret_cast<SortedDictType*>(self);
if (!sd->is_type_key_type(key))
{
return nullptr;
}
return sd->getitem(key);
}

Expand All @@ -443,10 +460,6 @@ static PyObject* sorted_dict_type_getitem(PyObject* self, PyObject* key)
static int sorted_dict_type_setitem(PyObject* self, PyObject* key, PyObject* value)
{
SortedDictType* sd = reinterpret_cast<SortedDictType*>(self);
if (!sd->is_type_key_type(key))
{
return -1;
}
return sd->setitem(key, value);
}

Expand Down
6 changes: 6 additions & 0 deletions tests/test_invalid_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,25 @@

from pysorteddict import SortedDict

# Update and enable these tests once the constructor is updated to function the
# same way the constructor of Python's native dictionary does.


@pytest.mark.skip
def test_no_arguments():
with pytest.raises(TypeError) as ctx:
SortedDict()
assert ctx.value.args[0] == "function missing required argument 'key_type' (pos 1)"


@pytest.mark.skip
def test_superfluous_arguments():
with pytest.raises(TypeError) as ctx:
SortedDict(object, object)
assert ctx.value.args[0] == "function takes at most 1 argument (2 given)"


@pytest.mark.skip
@pytest.mark.parametrize("key_type", [object, object(), 63, 5.31, "effort", b"salt", ["hear", 0x5EE], (1.61, "taste")])
def test_wrong_type(key_type):
with pytest.raises(TypeError) as ctx:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, key_type: type):
self.values = [self.gen() for _ in self.keys]
self.normal_dict = dict(zip(self.keys, self.values, strict=True))

sorted_dict = SortedDict(self.key_type)
sorted_dict = SortedDict()
for key, value in zip(self.keys, self.values, strict=True):
sorted_dict[key] = value
self.sorted_dicts = [sorted_dict, sorted_dict.copy()]
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_contains_yes(resources, sorted_dict):
def test_getitem_wrong_type(resources, sorted_dict):
with pytest.raises(TypeError) as ctx:
sorted_dict[resources.key_subtype()]
assert ctx.value.args[0] == f"key must be of type {resources.key_type!r}"
assert ctx.value.args[0] == f"key type must be {resources.key_type!r}"


def test_getitem_missing(resources, sorted_dict):
Expand All @@ -158,7 +158,7 @@ def test_getitem_found(resources, sorted_dict):
def test_delitem_wrong_type(resources, sorted_dict):
with pytest.raises(TypeError) as ctx:
del sorted_dict[resources.key_subtype()]
assert ctx.value.args[0] == f"key must be of type {resources.key_type!r}"
assert ctx.value.args[0] == f"key type must be {resources.key_type!r}"


def test_delitem_missing(resources, sorted_dict):
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_setitem_wrong_type(resources, sorted_dict):
value = resources.gen()
with pytest.raises(TypeError) as ctx:
sorted_dict[resources.key_subtype()] = value
assert ctx.value.args[0] == f"key must be of type {resources.key_type!r}"
assert ctx.value.args[0] == f"key type must be {resources.key_type!r}"

if cpython:
assert sys.getrefcount(value) == 2
Expand Down
Loading