Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subpartitioning Python Cosmos DB SDK #31121

Merged
merged 30 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
80cb361
sub partitioning
bambriz Jul 4, 2023
59d3d9d
Additional Sub Partitioning Updates
bambriz Jul 13, 2023
398bf88
Merge branch 'main' into subpartitioning
bambriz Jul 13, 2023
0bea4a7
remove uneeded line
bambriz Jul 13, 2023
4aa2da4
Merge branch 'subpartitioning' of https://github.com/bambriz/azure-sd…
bambriz Jul 13, 2023
fa7225c
update changelog
bambriz Jul 13, 2023
c61942e
pylint fixes
bambriz Jul 13, 2023
8529239
remove debug code on subpartition test
bambriz Jul 13, 2023
6b6fc2a
Merge remote-tracking branch 'upstream/main' into subpartitioning
bambriz Aug 22, 2023
542793e
Adding support for prefix partition queries
bambriz Sep 12, 2023
6ddc172
pylint and cspell fixes
bambriz Sep 13, 2023
1e7e8d4
Merge branch 'Azure:main' into subpartitioning
bambriz Sep 14, 2023
bf0b519
Additional Updates and fixes
bambriz Sep 14, 2023
9f8a930
removing uneeded lines from test config
bambriz Sep 14, 2023
f8c1346
Test fix
bambriz Sep 14, 2023
1349513
update test crud subpartition
bambriz Sep 14, 2023
5c0ce6d
Update test_config.py
bambriz Sep 15, 2023
e224bae
additional feedback fixes
bambriz Sep 15, 2023
373c171
Merge branch 'subpartitioning' of https://github.com/bambriz/azure-sd…
bambriz Sep 15, 2023
572f154
Fixed Python Version Compatibility
bambriz Sep 15, 2023
78d23e4
Fixed small issue causing tests to fail
bambriz Sep 15, 2023
50c9bbc
Testing fix for subpartitioning
bambriz Sep 15, 2023
209ab93
Update test_crud_subpartition_async.py
simorenoh Sep 15, 2023
ebd097f
Update test_crud_subpartition_async.py
simorenoh Sep 15, 2023
b9cc291
Update dev_requirements.txt
simorenoh Sep 18, 2023
710d48e
Update async test and samples
bambriz Oct 3, 2023
1990bed
Change public method to be private
bambriz Oct 3, 2023
763b021
Added support for prefix query involving multiple over lapping ranges
bambriz Oct 6, 2023
1de4181
Better over lapping support and new over lapping range tests
bambriz Oct 9, 2023
fef66b7
Clarified information in some comments
bambriz Oct 10, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 4.5.2 (Unreleased)

#### Features Added
* Added Support for Subpartitioning in Python SDK. See [PR 31121](https://github.com/Azure/azure-sdk-for-python/pull/31121)

#### Breaking Changes

Expand Down
100 changes: 51 additions & 49 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,12 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
headers[http_constants.HttpHeaders.PartitionKey] = []
# else serialize using json dumps method which apart from regular values will serialize None into null
else:
headers[http_constants.HttpHeaders.PartitionKey] = json.dumps([options["partitionKey"]])
# single partitioning uses a string and needs to be turned into a list
if isinstance(options["partitionKey"], list) and options["partitionKey"]:
bambriz marked this conversation as resolved.
Show resolved Hide resolved
pk_val = json.dumps(options["partitionKey"], separators=(',', ':'))
else:
pk_val = json.dumps([options["partitionKey"]])
headers[http_constants.HttpHeaders.PartitionKey] = pk_val

if options.get("enableCrossPartitionQuery"):
headers[http_constants.HttpHeaders.EnableCrossPartitionQuery] = options["enableCrossPartitionQuery"]
Expand All @@ -224,7 +229,7 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
headers[http_constants.HttpHeaders.PopulateQueryMetrics] = options["populateQueryMetrics"]

if options.get("responseContinuationTokenLimitInKb"):
headers[http_constants.HttpHeaders.ResponseContinuationTokenLimitInKb] = options["responseContinuationTokenLimitInKb"] # pylint: disable=line-too-long
headers[http_constants.HttpHeaders.ResponseContinuationTokenLimitInKb] = options["responseContinuationTokenLimitInKb"] # pylint: disable=line-too-long

if cosmos_client_connection.master_key:
#formatedate guarantees RFC 1123 date format regardless of current locale
Expand Down Expand Up @@ -606,56 +611,53 @@ def TrimBeginningAndEndingSlashes(path):

# Parses the paths into a list of tokens, each representing a property.
def ParsePaths(paths):
    """Parse one or more partition key paths (e.g. ``"/a/b"``) into property tokens.

    Each path must start with ``/`` and is split on ``/`` into individual
    property names.  Segments wrapped in single or double quotes may contain
    the ``/`` separator (and quotes escaped by a preceding backslash), in
    which case the quoted content is emitted as a single token.  Multiple
    paths are supported for hierarchical ("MultiHash") partition keys; their
    tokens are appended in path order.

    :param list[str] paths: the partition key paths to parse.
    :return: flat list of property-name tokens.
    :rtype: list[str]
    :raises ValueError: if a segment does not follow a separator or a quoted
        segment is unterminated.
    """
    segmentSeparator = "/"
    tokens = []
    for path in paths:
        currentIndex = 0

        while currentIndex < len(path):
            if path[currentIndex] != segmentSeparator:
                # BUGFIX: str(...) is required here — concatenating the int
                # index to a str raised TypeError instead of the intended
                # ValueError.
                raise ValueError("Invalid path character at index " + str(currentIndex))

            currentIndex += 1
            if currentIndex == len(path):
                break

            # " and ' are treated specially in the sense that they can have
            # the / (segment separator) between them, which is then
            # considered part of the token.
            if path[currentIndex] == '"' or path[currentIndex] == "'":
                quote = path[currentIndex]
                newIndex = currentIndex + 1

                while True:
                    newIndex = path.find(quote, newIndex)
                    if newIndex == -1:
                        # Unterminated quoted segment.  BUGFIX: str(...) as above.
                        raise ValueError("Invalid path character at index " + str(currentIndex))

                    # Check if the quote itself is escaped by a preceding \,
                    # in which case it's part of the token.
                    if path[newIndex - 1] != "\\":
                        break
                    newIndex += 1

                # This will extract the token excluding the quote chars.
                token = path[currentIndex + 1: newIndex]
                tokens.append(token)
                currentIndex = newIndex + 1
            else:
                newIndex = path.find(segmentSeparator, currentIndex)
                if newIndex == -1:
                    # Last segment: extract from currentIndex to end of string.
                    token = path[currentIndex:]
                    currentIndex = len(path)
                else:
                    # Extract from currentIndex up to (excluding) the next separator.
                    token = path[currentIndex:newIndex]
                    currentIndex = newIndex

                tokens.append(token.strip())

    return tokens

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@
from . import _request_object
from . import _synchronized_request as synchronized_request
from . import _global_endpoint_manager as global_endpoint_manager
from ._routing import routing_map_provider
from ._routing import routing_map_provider, routing_range
from ._retry_utility import ConnectionRetryPolicy
from . import _session
from . import _utils
from .partition_key import _Undefined, _Empty
from .partition_key import _Undefined, _Empty, PartitionKey
from ._auth_policy import CosmosBearerTokenCredentialPolicy
from ._cosmos_http_logging_policy import CosmosHttpLoggingPolicy

Expand Down Expand Up @@ -2539,6 +2539,46 @@ def __GetBodiesFromQueryResult(result):
# Query operations will use ReadEndpoint even though it uses POST(for regular query operations)
request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery)
req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, options, partition_key_range_id)

#check if query has prefix partition key
isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None)
if isPrefixPartitionQuery:
# here get the over lapping ranges
partition_key_definition = kwargs.pop("partitionKeyDefinition", None)
pk_properties = partition_key_definition
partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"])
partition_key_value = pk_properties["partition_key"]
feedrangeEPK = partition_key_definition._get_epk_range_for_prefix_partition_key(partition_key_value) # cspell:disable-line # pylint: disable=line-too-long
over_lapping_ranges = self._routing_map_provider.get_overlapping_ranges(id_, [feedrangeEPK])
# It is possible to get more than one over lapping range. We need to get the query results for each one
results = None
for over_lapping_range in over_lapping_ranges:
single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range)
if single_range.min == feedrangeEPK.min and single_range.max == feedrangeEPK.max:
# The EpkRange spans exactly one physical partition
# In this case we can route to the physical pk range id
req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"]
else:
# The EpkRange spans less than single physical partition
# In this case we route to the physical partition and
# pass the epk range headers to filter within partition
req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"]
req_headers[http_constants.HttpHeaders.StartEpkString] = feedrangeEPK.min
req_headers[http_constants.HttpHeaders.EndEpkString] = feedrangeEPK.max
req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange"
r, self.last_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs)
if results:
# add up all the query results from all over lapping ranges
results["Documents"].extend(r["Documents"])
results["_count"] += r["_count"]
else:
results = r
if response_hook:
response_hook(self.last_response_headers, results)
# if the prefix partition query has results lets return it
if results:
return __GetBodiesFromQueryResult(results)

result, self.last_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs)

if response_hook:
Expand Down Expand Up @@ -2576,6 +2616,8 @@ def _GetQueryPlanThroughGateway(self, query, resource_link, **kwargs):
is_query_plan=True,
**kwargs)



def __CheckAndUnifyQueryFormat(self, query_body):
"""Checks and unifies the format of the query body.

Expand Down Expand Up @@ -2650,21 +2692,36 @@ def _AddPartitionKey(self, collection_link, document, options):

# Extracts the partition key from the document using the partitionKey definition
def _ExtractPartitionKey(self, partitionKeyDefinition, document):
    """Extract the partition key value(s) for *document*.

    :param dict partitionKeyDefinition: the container's partition key
        definition, with at least "kind" and "paths"; may carry "systemKey".
    :param dict document: the document to read the partition key from.
    :return: a list of values (one per path level) for hierarchical
        ("MultiHash") definitions, otherwise the single extracted value.
    """
    # Whether the partition key is system generated.  This is invariant for
    # the whole definition, so compute it once instead of per path level.
    is_system_key = partitionKeyDefinition.get("systemKey", False)

    if partitionKeyDefinition["kind"] == "MultiHash":
        ret = []
        for partition_key_level in partitionKeyDefinition.get("paths"):
            # Parses the path into a list of tokens, each representing a property
            partition_key_parts = base.ParsePaths([partition_key_level])
            # Navigates the document to retrieve this level's value
            val = self._retrieve_partition_key(partition_key_parts, document, is_system_key)
            # NOTE(review): this identity check only fires when
            # _retrieve_partition_key returns the _Undefined class itself, not
            # an _Undefined() instance — confirm which one it returns.
            if val is _Undefined:
                # Stop at the first undefined level; lower levels are omitted.
                break
            ret.append(val)
        return ret

    # Single (non-hierarchical) partition key:
    # parses the paths into a list of tokens, each representing a property
    partition_key_parts = base.ParsePaths(partitionKeyDefinition.get("paths"))
    # Navigates the document to retrieve the partitionKey specified in the paths
    return self._retrieve_partition_key(partition_key_parts, document, is_system_key)

# Navigates the document to retrieve the partitionKey specified in the partition key parts
def _retrieve_partition_key(self, partition_key_parts, document, is_system_key):
expected_matchCount = len(partition_key_parts)
matchCount = 0
partitionKey = document

for part in partition_key_parts:
# At any point if we don't find the value of a sub-property in the document, we return as Undefined
if part not in partitionKey:
Expand Down
3 changes: 2 additions & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from ..http_constants import StatusCodes
from ..offer import ThroughputProperties
from ._scripts import ScriptsProxy
from ..partition_key import NonePartitionKeyValue
from ..partition_key import NonePartitionKeyValue, PartitionKey

__all__ = ("ContainerProxy",)

Expand Down Expand Up @@ -361,6 +361,7 @@ def query_items(
partition_key = kwargs.pop('partition_key', None)
if partition_key is not None:
feed_options["partitionKey"] = self._set_partition_key(partition_key)
kwargs["containerProperties"] = self._get_properties
annatisch marked this conversation as resolved.
Show resolved Hide resolved
else:
feed_options["enableCrossPartitionQuery"] = True
max_integrated_cache_staleness_in_ms = kwargs.pop('max_integrated_cache_staleness_in_ms', None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

"""Document client class for the Azure Cosmos database service.
"""
import json
# https://github.com/PyCQA/pylint/issues/3112
# Currently pylint is locked to 2.3.3 and this is fixed in 2.4.4
from typing import Dict, Any, Optional, TypeVar # pylint: disable=unused-import
Expand All @@ -44,6 +45,7 @@

from .. import _base as base
from .. import documents
from .._routing import routing_range
from ..documents import ConnectionPolicy
from .. import _constants as constants
from .. import http_constants
Expand All @@ -56,7 +58,7 @@
from ._retry_utility_async import _ConnectionRetryPolicy
from .. import _session
from .. import _utils
from ..partition_key import _Undefined, _Empty
from ..partition_key import _Undefined, _Empty, PartitionKey
from ._auth_policy_async import AsyncCosmosBearerTokenCredentialPolicy
from .._cosmos_http_logging_policy import CosmosHttpLoggingPolicy

Expand Down Expand Up @@ -2349,6 +2351,55 @@ def __GetBodiesFromQueryResult(result):
# Query operations will use ReadEndpoint even though it uses POST(for regular query operations)
request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery)
req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, options, partition_key_range_id)

# check if query has prefix partition key
cont_prop = kwargs.pop("containerProperties", None)
partition_key = options.get("partitionKey", None)
isPrefixPartitionQuery = False
partition_key_definition = None
if cont_prop:
cont_prop = await cont_prop()
pk_properties = cont_prop["partitionKey"]
partition_key_definition = PartitionKey(path=pk_properties["paths"], kind=pk_properties["kind"])
if partition_key_definition.kind == "MultiHash" and\
(type(partition_key) == list and len(partition_key_definition['paths']) != len(partition_key)):
isPrefixPartitionQuery = True

if isPrefixPartitionQuery:
# here get the overlapping ranges
req_headers.pop(http_constants.HttpHeaders.PartitionKey, None)
feedrangeEPK = partition_key_definition._get_epk_range_for_prefix_partition_key(partition_key) # cspell:disable-line # pylint: disable=line-too-long
over_lapping_ranges = await self._routing_map_provider.get_overlapping_ranges(id_, [feedrangeEPK])
results = None
for over_lapping_range in over_lapping_ranges:
# It is possible for the over lapping range to include multiple physical partitions
# we should return query results for all the partitions that are overlapped.
single_range = routing_range.Range.PartitionKeyRangeToRange(over_lapping_range)
if single_range.min == feedrangeEPK.min and single_range.max == feedrangeEPK.max:
# The EpkRange spans exactly one physical partition
# In this case we can route to the physical pk range id
req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"]
else:
# The EpkRange spans less than single physical partition
# In this case we route to the physical partition and
# pass the epk range headers to filter within partition
req_headers[http_constants.HttpHeaders.PartitionKeyRangeID] = over_lapping_range["id"]
req_headers[http_constants.HttpHeaders.StartEpkString] = feedrangeEPK.min
ealsur marked this conversation as resolved.
Show resolved Hide resolved
req_headers[http_constants.HttpHeaders.EndEpkString] = feedrangeEPK.max
req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange"
r, self.last_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs)
if results:
# add up all the query results from all over lapping ranges
results["Documents"].extend(r["Documents"])
results["_count"] += r["_count"]
else:
results = r
if response_hook:
response_hook(self.last_response_headers, results)
# if the prefix partition query has results lets return it
if results:
return __GetBodiesFromQueryResult(results)

result, self.last_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs)

if response_hook:
Expand Down Expand Up @@ -2516,13 +2567,29 @@ async def _AddPartitionKey(self, collection_link, document, options):

# Extracts the partition key from the document using the partitionKey definition
def _ExtractPartitionKey(self, partitionKeyDefinition, document):
    """Extract the partition key value(s) for *document*.

    :param dict partitionKeyDefinition: the container's partition key
        definition, with at least "kind" and "paths"; may carry "systemKey".
    :param dict document: the document to read the partition key from.
    :return: a list of values (one per path level) for hierarchical
        ("MultiHash") definitions, otherwise the single extracted value.
    """
    # Whether the partition key is system generated.  This is invariant for
    # the whole definition, so compute it once instead of per path level.
    is_system_key = partitionKeyDefinition.get("systemKey", False)

    if partitionKeyDefinition["kind"] == "MultiHash":
        ret = []
        for partition_key_level in partitionKeyDefinition.get("paths"):
            # Parses the path into a list of tokens, each representing a property
            partition_key_parts = base.ParsePaths([partition_key_level])
            # Navigates the document to retrieve this level's value
            val = self._retrieve_partition_key(partition_key_parts, document, is_system_key)
            # NOTE(review): this identity check only fires when
            # _retrieve_partition_key returns the _Undefined class itself, not
            # an _Undefined() instance — confirm which one it returns.
            if val is _Undefined:
                # Stop at the first undefined level; lower levels are omitted.
                break
            ret.append(val)
        return ret

    # Single (non-hierarchical) partition key:
    # parses the paths into a list of tokens, each representing a property
    partition_key_parts = base.ParsePaths(partitionKeyDefinition.get("paths"))
    # Navigates the document to retrieve the partitionKey specified in the paths
    return self._retrieve_partition_key(partition_key_parts, document, is_system_key)

# Navigates the document to retrieve the partitionKey specified in the partition key parts
Expand Down
Loading