Skip to content

Commit 1a4822c

Browse files
complete CRUD example
1 parent aa63aed commit 1a4822c

File tree

11 files changed

+248
-55
lines changed

11 files changed

+248
-55
lines changed

elasticsearch/dsl/pydantic.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,28 @@ class ESMeta(BaseModel):
2929
primary_term: int = 0
3030
seq_no: int = 0
3131
version: int = 0
32+
score: float = 0
3233

3334

3435
class _BaseModel(BaseModel):
3536
meta: Annotated[ESMeta, dsl.mapped_field(exclude=True)] = Field(
36-
default=ESMeta(), init=False
37+
default=ESMeta(),
38+
init=False,
3739
)
3840

3941

4042
class _BaseESModelMetaclass(type(BaseModel)): # type: ignore[misc]
4143
@staticmethod
42-
def process_annotations(metacls: Type["_BaseESModelMetaclass"], annotations: Dict[str, Any]) -> Dict[str, Any]:
44+
def process_annotations(
45+
metacls: Type["_BaseESModelMetaclass"], annotations: Dict[str, Any]
46+
) -> Dict[str, Any]:
4347
updated_annotations = {}
4448
for var, ann in annotations.items():
4549
if isinstance(ann, type(BaseModel)):
4650
# an inner Pydantic model is transformed into an Object field
47-
updated_annotations[var] = metacls.make_dsl_class(metacls, dsl.InnerDoc, ann)
51+
updated_annotations[var] = metacls.make_dsl_class(
52+
metacls, dsl.InnerDoc, ann
53+
)
4854
elif (
4955
hasattr(ann, "__origin__")
5056
and ann.__origin__ in [list, List]
@@ -59,7 +65,12 @@ def process_annotations(metacls: Type["_BaseESModelMetaclass"], annotations: Dic
5965
return updated_annotations
6066

6167
@staticmethod
62-
def make_dsl_class(metacls: Type["_BaseESModelMetaclass"], dsl_class: type, pydantic_model: type, pydantic_attrs: Optional[Dict[str, Any]] = None) -> type:
68+
def make_dsl_class(
69+
metacls: Type["_BaseESModelMetaclass"],
70+
dsl_class: type,
71+
pydantic_model: type,
72+
pydantic_attrs: Optional[Dict[str, Any]] = None,
73+
) -> type:
6374
dsl_attrs = {
6475
attr: value
6576
for attr, value in dsl_class.__dict__.items()
@@ -95,7 +106,7 @@ class BaseESModel(_BaseModel, metaclass=BaseESModelMetaclass):
95106

96107
def to_doc(self) -> dsl.Document:
97108
data = self.model_dump()
98-
meta = {f"_{k}": v for k, v in data.pop("meta", {}).items()}
109+
meta = {f"_{k}": v for k, v in data.pop("meta", {}).items() if v}
99110
return self._doc(**meta, **data)
100111

101112
@classmethod
@@ -116,7 +127,7 @@ class AsyncBaseESModel(_BaseModel, metaclass=AsyncBaseESModelMetaclass):
116127

117128
def to_doc(self) -> dsl.AsyncDocument:
118129
data = self.model_dump()
119-
meta = {f"_{k}": v for k, v in data.pop("meta", {}).items()}
130+
meta = {f"_{k}": v for k, v in data.pop("meta", {}).items() if v}
120131
return self._doc(**meta, **data)
121132

122133
@classmethod

examples/quotes/backend/quotes.py

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Annotated
66

77
from fastapi import FastAPI, HTTPException
8-
from pydantic import BaseModel, Field
8+
from pydantic import BaseModel, Field, ValidationError
99
from sentence_transformers import SentenceTransformer
1010

1111
from elasticsearch import NotFoundError
@@ -50,6 +50,7 @@ class SearchResponse(BaseModel):
5050
version="1.0.0",
5151
)
5252

53+
5354
@app.get("/api/quotes/{id}")
5455
async def get_quote(id: str) -> Quote:
5556
doc = None
@@ -64,31 +65,64 @@ async def get_quote(id: str) -> Quote:
6465

6566
@app.post("/api/quotes", status_code=201)
6667
async def create_quote(req: Quote) -> Quote:
68+
embed_quotes([req])
6769
doc = req.to_doc()
6870
doc.meta.id = ""
6971
await doc.save(refresh=True)
7072
return Quote.from_doc(doc)
7173

7274

7375
@app.put("/api/quotes/{id}")
74-
async def update_quote(id: str, req: Quote) -> Quote:
75-
doc = req.to_doc()
76-
doc.meta.id = id
76+
async def update_quote(id: str, quote: Quote) -> Quote:
77+
doc = None
78+
try:
79+
doc = await Quote._doc.get(id)
80+
except NotFoundError:
81+
pass
82+
if not doc:
83+
raise HTTPException(status_code=404, detail="Item not found")
84+
if quote.quote:
85+
embed_quotes([quote])
86+
doc.quote = quote.quote
87+
doc.embedding = quote.embedding
88+
if quote.author:
89+
doc.author = quote.author
90+
if quote.tags:
91+
doc.tags = quote.tags
7792
await doc.save(refresh=True)
7893
return Quote.from_doc(doc)
7994

8095

8196
@app.delete("/api/quotes/{id}", status_code=204)
82-
async def delete_quote(id: str, req: Quote) -> None:
83-
doc = await Quote._doc.get(id)
97+
async def delete_quote(id: str) -> None:
98+
doc = None
99+
try:
100+
doc = await Quote._doc.get(id)
101+
except NotFoundError:
102+
pass
84103
if not doc:
85104
raise HTTPException(status_code=404, detail="Item not found")
86105
await doc.delete(refresh=True)
87106

88107

89108
@app.post('/api/search')
90109
async def search_quotes(req: SearchRequest) -> SearchResponse:
91-
quotes, tags, total = await search_quotes(req.query, req.filters, use_knn=req.knn, start=req.start)
110+
s = Quote._doc.search()
111+
if req.query == '':
112+
s = s.query(dsl.query.MatchAll())
113+
elif req.knn:
114+
s = s.query(dsl.query.Knn(field=Quote._doc.embedding, query_vector=model.encode(req.query).tolist()))
115+
else:
116+
s = s.query(dsl.query.Match(quote=req.query))
117+
for tag in req.filters:
118+
s = s.filter(dsl.query.Terms(tags=[tag]))
119+
s.aggs.bucket('tags', dsl.aggs.Terms(field=Quote._doc.tags, size=100))
120+
121+
r = await s[req.start:req.start + 25].execute()
122+
tags = [(tag.key, tag.doc_count) for tag in r.aggs.tags.buckets]
123+
quotes = [Quote.from_doc(hit) for hit in r.hits]
124+
total = r['hits'].total.value
125+
92126
return SearchResponse(
93127
quotes=quotes,
94128
tags=[Tag(tag=t[0], count=t[1]) for t in tags],
@@ -97,6 +131,12 @@ async def search_quotes(req: SearchRequest) -> SearchResponse:
97131
)
98132

99133

134+
def embed_quotes(quotes):
135+
embeddings = model.encode([q.quote for q in quotes])
136+
for q, e in zip(quotes, embeddings):
137+
q.embedding = e.tolist()
138+
139+
100140
async def ingest_quotes():
101141
if await Quote._doc._index.exists():
102142
await Quote._doc._index.delete()
@@ -106,11 +146,6 @@ def ingest_progress(count, start):
106146
elapsed = time() - start
107147
print(f'\rIngested {count} quotes. ({count / elapsed:.0f}/sec)', end='')
108148

109-
def embed_quotes(quotes):
110-
embeddings = model.encode([q.quote for q in quotes])
111-
for q, e in zip(quotes, embeddings):
112-
q.embedding = e.tolist()
113-
114149
async def get_next_quote():
115150
quotes: list[Quote] = []
116151
with open('quotes.csv') as f:
@@ -137,21 +172,5 @@ async def get_next_quote():
137172
await Quote._doc.bulk(get_next_quote())
138173

139174

140-
async def search_quotes(q, tags, use_knn=True, start=0, size=25):
141-
s = Quote._doc.search()
142-
if q == '':
143-
s = s.query(dsl.query.MatchAll())
144-
elif use_knn:
145-
s = s.query(dsl.query.Knn(field=Quote._doc.embedding, query_vector=model.encode(q).tolist()))
146-
else:
147-
s = s.query(dsl.query.Match(quote=q))
148-
for tag in tags:
149-
s = s.filter(dsl.query.Terms(tags=[tag]))
150-
s.aggs.bucket('tags', dsl.aggs.Terms(field=Quote._doc.tags, size=100))
151-
r = await s[start:start + size].execute()
152-
tags = [(tag.key, tag.doc_count) for tag in r.aggs.tags.buckets]
153-
return [Quote.from_doc(hit) for hit in r.hits], tags, r['hits'].total.value
154-
155-
156175
if __name__ == "__main__":
157176
asyncio.run(ingest_quotes())

examples/quotes/package-lock.json

Lines changed: 39 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/quotes/package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
"bootstrap": "^5.3.8",
1818
"react": "^19.1.1",
1919
"react-bootstrap": "^2.10.10",
20-
"react-dom": "^19.1.1"
20+
"react-dom": "^19.1.1",
21+
"react-router": "^7.9.2"
2122
},
2223
"devDependencies": {
2324
"@eslint/js": "^9.36.0",

examples/quotes/src/App.tsx

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import React, { useRef, useState, useEffect } from 'react';
22
import Container from 'react-bootstrap/Container';
3+
import { NavLink } from 'react-router';
34
import Row from 'react-bootstrap/Row';
45
import Col from 'react-bootstrap/Col';
56
import Form from 'react-bootstrap/Form';
@@ -9,23 +10,7 @@ import Stack from 'react-bootstrap/Stack';
910
import CloseButton from 'react-bootstrap/CloseButton';
1011
import ToggleButton from 'react-bootstrap/ToggleButton';
1112
import Pagination from 'react-bootstrap/Pagination';
12-
13-
interface Meta {
14-
id: string;
15-
score: number;
16-
}
17-
18-
interface Quote {
19-
meta: Meta;
20-
quote: string;
21-
author: string;
22-
tags: string[];
23-
};
24-
25-
interface Tag {
26-
tag: string;
27-
count: number;
28-
};
13+
import type { Quote, Tag } from './models';
2914

3015
export default function App() {
3116
const inputRef = useRef<HTMLInputElement>(null);
@@ -96,9 +81,10 @@ export default function App() {
9681
<Stack direction="horizontal" gap={2}>
9782
<Form.Control type="text" placeholder="Search for... ?" ref={inputRef} autoFocus={true} />
9883
<CloseButton onClick={onResetQuery} disabled={query === ''} />
99-
<ToggleButton id="knn" type="checkbox" variant="outline-primary" checked={knn} value="1" title="Use hybrid search" onChange={e => setKnn(e.currentTarget.checked)}>
84+
<ToggleButton id="knn" type="checkbox" variant="outline-secondary" checked={knn} value="1" title={knn ? "Using hybrid search" : "Using full-text search"} onChange={e => setKnn(e.currentTarget.checked)}>
10085
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="currentColor" viewBox="0 0 256 256"><path d="M192.5,171.47A88.34,88.34,0,0,0,224,101.93c-1-45.71-37.61-83.4-83.24-85.8A88,88,0,0,0,48,102L25.55,145.18c-.09.18-.18.36-.26.54a16,16,0,0,0,7.55,20.62l.25.11L56,176.94V208a16,16,0,0,0,16,16h48a8,8,0,0,0,0-16H72V171.81a8,8,0,0,0-4.67-7.28L40,152l23.07-44.34A7.9,7.9,0,0,0,64,104a72,72,0,0,1,56-70.21V49.38a24,24,0,1,0,16,0V32c1.3,0,2.6,0,3.9.1A72.26,72.26,0,0,1,203.84,80H184a8,8,0,0,0-6.15,2.88L152.34,113.5a24.06,24.06,0,1,0,12.28,10.25L187.75,96h19.79q.36,3.12.44,6.3a72.26,72.26,0,0,1-28.78,59.3,8,8,0,0,0-3.14,7.39l8,64a8,8,0,0,0,7.93,7,8.39,8.39,0,0,0,1-.06,8,8,0,0,0,6.95-8.93ZM128,80a8,8,0,1,1,8-8A8,8,0,0,1,128,80Zm16,64a8,8,0,1,1,8-8A8,8,0,0,1,144,144Z"></path></svg>
10186
</ToggleButton>
87+
<NavLink to="/quotes/new" className="btn btn-primary">New&nbsp;Quote</NavLink>
10288
</Stack>
10389
</Form>
10490
</Row>
@@ -159,6 +145,8 @@ export default function App() {
159145
<span className="ResultQuote">{quote}</span><span className="ResultAuthor">{author}</span>
160146
<br />
161147
<span className="ResultScore">[Score: {meta.score}]</span> <span className="ResultTags">{tags.map(tag => `#${tag}`).join(', ')}</span>
148+
<br />
149+
<small><NavLink to={`/quotes/${meta.id}`}>Edit</NavLink></small>
162150
</p>
163151
</div>
164152
))}

0 commit comments

Comments
 (0)