
Commit

Merge pull request #64 from carlosgalvez-tiendeo/master
Add mget
vrcmarcos authored Mar 2, 2021
2 parents c10cb2d + c8aa31d commit fe382af
Showing 4 changed files with 61 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -6,6 +6,9 @@ bin/
share/
pyvenv.cfg

### Visual Studio ###
.vscode/

### Intellij ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
66 changes: 49 additions & 17 deletions elasticmock/fake_elasticsearch.py
@@ -7,12 +7,15 @@
import dateutil.parser
from elasticsearch import Elasticsearch
from elasticsearch.client.utils import query_params
from elasticsearch.client import _normalize_hosts
from elasticsearch.transport import Transport
from elasticsearch.exceptions import NotFoundError, RequestError

from elasticmock.behaviour.server_failure import server_failure
from elasticmock.fake_cluster import FakeClusterClient
from elasticmock.fake_indices import FakeIndicesClient
from elasticmock.utilities import extract_ignore_as_iterable, get_random_id, get_random_scroll_id
from elasticmock.utilities import (extract_ignore_as_iterable, get_random_id,
get_random_scroll_id)
from elasticmock.utilities.decorator import for_all_methods

PY3 = sys.version_info[0] == 3
@@ -228,11 +231,15 @@ def _compare_value_for_field(self, doc_source, field, value, ignore_case):
value = value.lower()

doc_val = doc_source
# Remove boosting
field, *_ = field.split("*")
for k in field.split("."):
if hasattr(doc_val, k):
doc_val = getattr(doc_val, k)
break
elif k in doc_val:
doc_val = doc_val[k]
break
else:
return False

@@ -247,7 +254,7 @@ def _compare_value_for_field(self, doc_source, field, value, ignore_case):

if value == val:
return True
if isinstance(val, str) and value in val:
if isinstance(val, str) and str(value) in val:
return True

return False
@@ -260,6 +267,7 @@ class FakeElasticsearch(Elasticsearch):
def __init__(self, hosts=None, transport_class=None, **kwargs):
self.__documents_dict = {}
self.__scrolls = {}
self.transport = Transport(_normalize_hosts(hosts), **kwargs)

@property
def indices(self):
@@ -309,10 +317,10 @@ def index(self, index, body, doc_type='_doc', id=None, params=None, headers=None

if id is None:
id = get_random_id()
elif self.exists(index, doc_type, id, params=params):
doc = self.get(index, id, doc_type, params=params)
elif self.exists(index, id, doc_type=doc_type, params=params):
doc = self.get(index, id, doc_type=doc_type, params=params)
version = doc['_version'] + 1
self.delete(index, doc_type, id)
self.delete(index, id, doc_type=doc_type)

self.__documents_dict[index].append({
'_type': doc_type,
@@ -344,10 +352,10 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None):
action = next(iter(line.keys()))

version = 1
index = line[action]['_index']
index = line[action].get('_index') or index
doc_type = line[action].get('_type', "_doc") # _type is deprecated in 7.x

if action in ['delete', 'updated'] and not line[action].get("_id"):
if action in ['delete', 'update'] and not line[action].get("_id"):
raise RequestError(400, 'action_request_validation_exception', 'missing id')

document_id = line[action].get('_id', get_random_id())
@@ -367,7 +375,7 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None):
errors = True
item[action]["error"] = result
else:
self.delete(index, doc_type, document_id, params=params)
self.delete(index, document_id, doc_type=doc_type, params=params)
item[action]["result"] = result
items.append(item)

@@ -392,10 +400,10 @@ def bulk(self, body, index=None, doc_type=None, params=None, headers=None):
}
if not error:
item[action]["result"] = result
if self.exists(index, doc_type, document_id, params=params):
doc = self.get(index, document_id, doc_type, params=params)
if self.exists(index, document_id, doc_type=doc_type, params=params):
doc = self.get(index, document_id, doc_type=doc_type, params=params)
version = doc['_version'] + 1
self.delete(index, doc_type, document_id, params=params)
self.delete(index, document_id, doc_type=doc_type, params=params)

self.__documents_dict[index].append({
'_type': doc_type,
@@ -430,7 +438,7 @@ def _validate_action(self, action, index, document_id, doc_type, params=None):
raise NotImplementedError(f"{action} behaviour hasn't been implemented")

@query_params('parent', 'preference', 'realtime', 'refresh', 'routing')
def exists(self, index, doc_type, id, params=None, headers=None):
def exists(self, index, id, doc_type=None, params=None, headers=None):
result = False
if index in self.__documents_dict:
for document in self.__documents_dict[index]:
@@ -471,6 +479,26 @@ def get(self, index, id, doc_type='_all', params=None, headers=None):
}
raise NotFoundError(404, json.dumps(error_data))

@query_params('_source', '_source_exclude', '_source_include',
'preference', 'realtime', 'refresh', 'routing',
'stored_fields')
def mget(self, body, index, doc_type='_all', params=None, headers=None):
ids = body.get('ids')
results = []
for id in ids:
try:
results.append(self.get(index, id, doc_type=doc_type,
params=params, headers=headers))
except:
pass
if not results:
raise RequestError(
400,
'action_request_validation_exception',
'Validation Failed: 1: no documents to get;'
)
return {'docs': results}

@query_params('_source', '_source_exclude', '_source_include', 'parent',
'preference', 'realtime', 'refresh', 'routing', 'version',
'version_type')
@@ -646,17 +674,20 @@ def scroll(self, scroll_id, params=None, headers=None):

@query_params('consistency', 'parent', 'refresh', 'replication', 'routing',
'timeout', 'version', 'version_type')
def delete(self, index, doc_type, id, params=None, headers=None):
def delete(self, index, id, doc_type=None, params=None, headers=None):

found = False
ignore = extract_ignore_as_iterable(params)

if index in self.__documents_dict:
for document in self.__documents_dict[index]:
if document.get('_type') == doc_type and document.get('_id') == id:
if document.get('_id') == id:
found = True
self.__documents_dict[index].remove(document)
break
if doc_type and document.get('_type') != doc_type:
found = False
if found:
self.__documents_dict[index].remove(document)
break

result_dict = {
'found': found,
@@ -665,12 +696,13 @@ def delete(self, index, doc_type, id, params=None, headers=None):
'_id': id,
'_version': 1,
}

if found:
return result_dict
elif params and 404 in ignore:
return {'found': False}
else:
raise NotFoundError(404, json.dumps(result_dict, default=str))
raise NotFoundError(404, json.dumps(result_dict))

@query_params('allow_no_indices', 'expand_wildcards', 'ignore_unavailable',
'preference', 'routing')
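Before the test diffs, here is a minimal usage sketch of the new mget endpoint and the reordered delete/exists signatures. It assumes the elasticmock decorator exposed by this library; the index name and document bodies are illustrative only, not part of the commit.

from elasticmock import elasticmock
from elasticsearch import Elasticsearch


@elasticmock
def example():
    es = Elasticsearch(hosts=[{'host': 'localhost', 'port': 9200}])

    # Index a few documents and remember their generated ids.
    ids = [es.index(index='sample-index', doc_type='_doc',
                    body={'data': i})['_id'] for i in range(3)]

    # mget resolves each id with get() and returns the hits under 'docs'.
    docs = es.mget(index='sample-index', body={'ids': ids})['docs']
    assert len(docs) == 3

    # delete and exists now take doc_type as an optional keyword argument.
    es.delete(index='sample-index', id=ids[0])
    assert not es.exists(index='sample-index', id=ids[0])


example()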
8 changes: 8 additions & 0 deletions tests/fake_elasticsearch/test_get.py
@@ -60,3 +60,11 @@ def test_should_get_only_document_source_with_id(self):
target_doc_source = self.es.get_source(index=INDEX_NAME, doc_type=DOC_TYPE, id=document_id)

self.assertEqual(target_doc_source, BODY)

def test_mget_get_several_documents_by_id(self):
ids = []
for _ in range(0, 10):
data = self.es.index(index=INDEX_NAME, doc_type=DOC_TYPE, body=BODY)
ids.append(data.get('_id'))
results = self.es.mget(index=INDEX_NAME, body={'ids': ids})
self.assertEqual(len(results['docs']), 10)
2 changes: 1 addition & 1 deletion tests/fake_elasticsearch/test_search.py
@@ -205,7 +205,7 @@ def test_search_bool_should_match_query(self):
}
}
})
self.assertEqual(response['hits']['total'], 3)
self.assertEqual(response['hits']['total']['value'], 3)
hits = response['hits']['hits']
self.assertEqual(len(hits), 3)
self.assertEqual(hits[0]['_source'], {'data': 'test_0'})
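The test_search.py change reflects the Elasticsearch 7.x response convention in which hits.total is an object rather than a bare integer. A hedged sketch of the shape the mock is now expected to produce (the field values here are made up):

# Illustrative response fragment; under the 7.x convention the total hit
# count lives at hits.total.value alongside a 'relation' field.
response = {
    'hits': {
        'total': {'value': 3, 'relation': 'eq'},
        'hits': [
            {'_source': {'data': 'test_0'}},
            {'_source': {'data': 'test_1'}},
            {'_source': {'data': 'test_2'}},
        ],
    }
}
assert response['hits']['total']['value'] == 3
assert len(response['hits']['hits']) == 3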
