-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathoci_utils.py
111 lines (86 loc) · 2.43 KB
/
oci_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
File name: oci_utils.py
Author: Luigi Saetta
Date created: 2023-12-17
Date last modified: 2024-03-16
Python Version: 3.11
Description:
This module provides some utilities
Usage:
Import this module into other scripts to use its functions.
Example:
...
License:
This code is released under the MIT License.
Notes:
This is part of a set of demos showing how to use Oracle Vector DB,
OCI GenAI service, Oracle GenAI Embeddings, to build a RAG solution,
where all the data (text + embeddings) are stored in Oracle DB 23c
Warnings:
This module is in development, may change in future versions.
"""
import logging
import oci
from config import (
EMBED_MODEL_TYPE,
EMBED_MODEL,
GEN_MODEL,
OCI_GEN_MODEL,
TOP_K,
ADD_RERANKER,
RERANKER_MODEL,
TOP_N,
ADD_PHX_TRACING,
)
# Module-level logger used by print_configuration below. It is looked up by
# the fixed name "ConsoleLogger" (not __name__).
# NOTE(review): handlers/level are presumably configured elsewhere under the
# same name - confirm against the application entry point.
logger = logging.getLogger("ConsoleLogger")
def load_oci_config(config_file="~/.oci/config", profile="DEFAULT"):
    """
    Read the OCI configuration used to connect to OCI with an API key.

    Args:
        config_file: location of the OCI config file, handed to the SDK
            unchanged. Defaults to "~/.oci/config".
        profile: name of the profile to read from that file.
            Defaults to "DEFAULT".

    Returns:
        the configuration object produced by oci.config.from_file for the
        requested file and profile.
    """
    # Defaults reproduce the previously hard-coded values, so existing
    # zero-argument callers keep reading the DEFAULT profile.
    return oci.config.from_file(config_file, profile)
def print_configuration():
    """
    Log a summary of the configuration in use, one INFO record per line.

    The summary covers the embeddings model, the LLM (plus the OCI model
    name when GEN_MODEL is "OCI"), TOP_K, the reranker settings when
    ADD_RERANKER is set, and a note about Phoenix tracing when
    ADD_PHX_TRACING is set. It is framed by dashed separator lines and
    followed by an empty record.
    """
    ruler = "------------------------"

    # each entry is (message, *args), passed to logger.info exactly as listed
    report = [
        (ruler,),
        ("Configuration used:",),
        (f" {EMBED_MODEL_TYPE} {EMBED_MODEL} for embeddings...",),
        (" Using Oracle AI Vector Search...",),
        (" Using %s as LLM...", GEN_MODEL),
    ]

    # the OCI model name is only relevant when the OCI backend is selected
    if GEN_MODEL == "OCI":
        report.append((" Using %s as OCI model..", OCI_GEN_MODEL))

    report.append((" Retrieval parameters:",))
    report.append((" TOP_K: %s", TOP_K))

    if ADD_RERANKER:
        report.append((" TOP_N: %s", TOP_N))
        report.append((" Using %s as reranker...", RERANKER_MODEL))

    if ADD_PHX_TRACING:
        report.append((" Enabled observability with Phoenix tracing...",))

    report.append((ruler,))
    report.append(("",))

    for message, *values in report:
        logger.info(message, *values)
def pretty_print_docs(docs):
    """
    Print the documents to stdout, numbered from 1.

    Each entry is rendered as "Document <n>:", a blank line and the
    document's page_content; consecutive entries are separated by a line
    of 100 dashes.

    Args:
        docs: iterable of objects exposing a string page_content attribute.
    """
    divider = "\n" + "-" * 100 + "\n"

    rendered = []
    for number, doc in enumerate(docs, start=1):
        rendered.append(f"Document {number}:\n\n" + doc.page_content)

    print(divider.join(rendered))
def format_docs(docs):
    """
    Concatenate the page_content of every document into a single string.

    Args:
        docs: iterable of objects exposing a string page_content attribute.

    Returns:
        the contents joined in order, separated by one blank line; an
        empty string when docs is empty.
    """
    blank_line = "\n\n"
    contents = [item.page_content for item in docs]
    return blank_line.join(contents)
def check_value_in_list(value, values_list):
    """
    Guard against unsupported values.

    Args:
        value: the value to validate.
        values_list: the collection of accepted values.

    Raises:
        ValueError: when value is not contained in values_list.
    """
    # supported value: nothing to do
    if value in values_list:
        return

    message = f"Value {value} is not valid: value must be in list {values_list}"
    raise ValueError(message)