Lokasi ngalangkungan proxy:   [ UP ]  
[Ngawartoskeun bug]   [Panyetelan cookie]                
Skip to content

Commit 2e30b48

Browse files
authored
feat: add index types for vector search (#55)
* feat: adding index operations and tests
1 parent 9cc52c1 commit 2e30b48

6 files changed

Lines changed: 349 additions & 8 deletions

File tree

integration.cloudbuild.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ availableSecrets:
4343
env: "DB_PASSWORD"
4444

4545
substitutions:
46-
_INSTANCE_ID: mysql-vector
46+
_INSTANCE_ID: test-instance
4747
_REGION: us-central1
4848
_DB_NAME: test
4949
_VERSION: "3.8"

src/langchain_google_cloud_sql_mysql/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,22 @@
1414

1515
from .chat_message_history import MySQLChatMessageHistory
1616
from .engine import Column, MySQLEngine
17+
from .indexes import DistanceMeasure, IndexType, QueryOptions, SearchType, VectorIndex
1718
from .loader import MySQLDocumentSaver, MySQLLoader
1819
from .vectorstore import MySQLVectorStore
1920
from .version import __version__
2021

2122
__all__ = [
2223
"Column",
24+
"DistanceMeasure",
25+
"IndexType",
2326
"MySQLChatMessageHistory",
2427
"MySQLDocumentSaver",
2528
"MySQLEngine",
2629
"MySQLLoader",
2730
"MySQLVectorStore",
31+
"QueryOptions",
32+
"SearchType",
33+
"VectorIndex",
2834
"__version__",
2935
]

src/langchain_google_cloud_sql_mysql/engine.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,17 @@ def connect(self) -> sqlalchemy.engine.Connection:
222222
return self.engine.connect()
223223

224224
def _execute(self, query: str, params: Optional[dict] = None) -> None:
    """Run a SQL statement on a fresh connection and commit it.

    Args:
        query: SQL text; it is wrapped in ``sqlalchemy.text`` before being
            executed.
        params: Optional bind parameters handed to ``execute`` together
            with the statement.
    """
    with self.engine.connect() as connection:
        statement = sqlalchemy.text(query)
        connection.execute(statement, params)
        # Explicit commit: the statement's effects are persisted here.
        connection.commit()
229229

230+
def _execute_outside_tx(self, query: str, params: Optional[dict] = None) -> None:
    """Run a SQL statement on a connection switched to AUTOCOMMIT.

    Args:
        query: SQL text; it is wrapped in ``sqlalchemy.text`` before being
            executed.
        params: Optional bind parameters handed to ``execute`` together
            with the statement.
    """
    with self.engine.connect() as base_connection:
        # No explicit commit follows: the AUTOCOMMIT isolation level is
        # requested on the connection before the statement is sent.
        autocommit_connection = base_connection.execution_options(
            isolation_level="AUTOCOMMIT"
        )
        autocommit_connection.execute(sqlalchemy.text(query), params)
235+
230236
def _fetch(self, query: str, params: Optional[dict] = None):
231237
"""Fetch results from a SQL query."""
232238
with self.engine.connect() as conn:

src/langchain_google_cloud_sql_mysql/indexes.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,91 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from abc import ABC
1615
from dataclasses import dataclass
16+
from enum import Enum
17+
from typing import Optional
18+
19+
20+
class SearchType(Enum):
    """Enumerates the search algorithms a query can use.

    Each member's value is the same upper-case string as its name.

    Attributes:
        KNN: K-Nearest Neighbors search.
        ANN: Approximate Nearest Neighbors search.
    """

    KNN = "KNN"
    ANN = "ANN"
1730

1831

1932
@dataclass
class QueryOptions:
    """Configuration options applied when executing a search query.

    Attributes:
        num_partitions (Optional[int]): Number of partitions to divide the
            search space into. ``None`` means default partitioning.
        num_neighbors (Optional[int]): Number of nearest neighbors to
            retrieve. ``None`` means use the default.
        search_type (SearchType): Search algorithm to use. Defaults to
            ``SearchType.KNN``.
    """

    num_partitions: Optional[int] = None
    num_neighbors: Optional[int] = None
    search_type: SearchType = SearchType.KNN
45+
46+
47+
# Module-level default options. This is a single shared instance of a
# non-frozen dataclass, so an in-place attribute change on it is visible to
# every holder of this default.
DEFAULT_QUERY_OPTIONS = QueryOptions()
48+
49+
50+
class IndexType(Enum):
    """Enumerates the index types that can be requested for vector storage.

    The member values are the strings sent as the ``index_type`` option.

    Attributes:
        BRUTE_FORCE_SCAN: Plain brute-force scan; note that the value is
            ``"BRUTE_FORCE"``, which differs from the member name.
        TREE_AH: Tree-based index (presumably "asymmetric hashing" --
            TODO confirm against the Cloud SQL for MySQL documentation).
        TREE_SQ: Tree-based index (presumably "scalar quantization" --
            TODO confirm against the Cloud SQL for MySQL documentation).
    """

    BRUTE_FORCE_SCAN = "BRUTE_FORCE"
    TREE_AH = "TREE_AH"
    TREE_SQ = "TREE_SQ"
62+
63+
64+
class DistanceMeasure(Enum):
    """Enumerates the distance measures that can be used in searches.

    The member values are the lower-case strings sent as the
    ``distance_measure`` option.

    Attributes:
        COSINE: Cosine similarity measure.
        SQUARED_L2: Squared L2 norm (Euclidean) distance.
        DOT_PRODUCT: Dot product similarity.
    """

    COSINE = "cosine"
    SQUARED_L2 = "squared_l2"
    DOT_PRODUCT = "dot_product"
76+
77+
78+
class VectorIndex:
    """Plain settings holder describing a vector index.

    Every setting is optional; an attribute left as ``None`` means the
    caller did not specify that setting.

    Attributes:
        name (Optional[str]): The name of the index.
        index_type (Optional[IndexType]): The type of index.
        distance_measure (Optional[DistanceMeasure]): The distance measure
            to use for the index.
        num_partitions (Optional[int]): The number of partitions for the
            index. ``None`` for default.
        num_neighbors (Optional[int]): The default number of neighbors to
            return for queries.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        index_type: Optional[IndexType] = None,
        distance_measure: Optional[DistanceMeasure] = None,
        num_partitions: Optional[int] = None,
        num_neighbors: Optional[int] = None,
    ):
        """Store the given settings on the instance unchanged."""
        # Identity of the index.
        self.name = name

        # How the index is built and how distances are computed.
        self.index_type = index_type
        self.distance_measure = distance_measure

        # Tuning knobs.
        self.num_partitions = num_partitions
        self.num_neighbors = num_neighbors

src/langchain_google_cloud_sql_mysql/vectorstore.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from langchain_core.vectorstores import VectorStore
2424

2525
from .engine import MySQLEngine
26-
from .indexes import QueryOptions
26+
from .indexes import DEFAULT_QUERY_OPTIONS, QueryOptions, SearchType, VectorIndex
27+
28+
DEFAULT_INDEX_NAME_SUFFIX = "langchainvectorindex"
2729

2830

2931
class MySQLVectorStore(VectorStore):
@@ -38,7 +40,7 @@ def __init__(
3840
ignore_metadata_columns: Optional[List[str]] = None,
3941
id_column: str = "langchain_id",
4042
metadata_json_column: Optional[str] = "langchain_metadata",
41-
query_options: Optional[QueryOptions] = None,
43+
query_options: QueryOptions = DEFAULT_QUERY_OPTIONS,
4244
):
4345
"""Constructor for MySQLVectorStore.
4446
Args:
@@ -118,11 +120,16 @@ def __init__(
118120
self.id_column = id_column
119121
self.metadata_json_column = metadata_json_column
120122
self.query_options = query_options
123+
self.db_name = self.__get_db_name()
121124

122125
@property
def embeddings(self) -> Embeddings:
    """Return the embedding service stored on this instance."""
    service = self.embedding_service
    return service
125128

129+
def __get_db_name(self) -> str:
    """Look up the database name by running ``SELECT DATABASE();``.

    Returns:
        The ``DATABASE()`` column of the first row returned by the engine.
    """
    rows = self.engine._fetch("SELECT DATABASE();")
    first_row = rows[0]
    return first_row["DATABASE()"]
132+
126133
def _add_embeddings(
127134
self,
128135
texts: Iterable[str],
@@ -210,6 +217,64 @@ def delete(
210217
self.engine._execute(query)
211218
return True
212219

220+
def apply_vector_index(self, vector_index: VectorIndex) -> None:
    """Create a vector index on this store's embedding column.

    When ``vector_index.name`` is empty it is filled in (on the passed
    object) with ``<table_name>_<DEFAULT_INDEX_NAME_SUFFIX>``. After the
    index is created, this store's query options are switched to ANN.

    Args:
        vector_index: Settings of the index to create.
    """
    # Local import: only this method needs it.
    from dataclasses import replace

    # Construct the default index name
    if not vector_index.name:
        vector_index.name = f"{self.table_name}_{DEFAULT_INDEX_NAME_SUFFIX}"
    # The trailing '{}' placeholder is filled with the option string later.
    query_template = f"CALL mysql.create_vector_index('{vector_index.name}', '{self.db_name}.{self.table_name}', '{self.embedding_column}', '{{}}');"
    self.__exec_apply_vector_index(query_template, vector_index)
    # After applying an index to the table, set the query option search type
    # to be ANN. Rebind to a copy instead of mutating in place: the current
    # options object may be the shared module-level default (or an object
    # the caller also handed to other stores), and an in-place change would
    # silently switch those stores to ANN as well.
    self.query_options = replace(self.query_options, search_type=SearchType.ANN)
228+
229+
def alter_vector_index(self, vector_index: VectorIndex):
    """Alter the vector index that already exists on this store's table.

    When ``vector_index.name`` is empty it is filled in (on the passed
    object) with the existing index's name.

    Args:
        vector_index: New settings for the existing index.

    Raises:
        ValueError: If no index exists for the table, or if
            ``vector_index.name`` differs from the existing index's name.
    """
    current_index_name = self._get_vector_index_name()
    if not current_index_name:
        raise ValueError("No existing vector index found.")

    # The name to compare against is the second dot-separated component of
    # the stored index name.
    current_short_name = current_index_name.split(".")[1]
    if not vector_index.name:
        vector_index.name = current_short_name
    if current_short_name != vector_index.name:
        raise ValueError(
            f"Existing index name {current_index_name} does not match the new index name {vector_index.name}."
        )

    # The trailing '{}' placeholder is filled with the option string later.
    query_template = f"CALL mysql.alter_vector_index('{current_index_name}', '{{}}');"
    self.__exec_apply_vector_index(query_template, vector_index)
243+
244+
def __exec_apply_vector_index(self, query_template: str, vector_index: VectorIndex):
    """Fill the option string into ``query_template`` and execute it.

    Only settings that are truthy on ``vector_index`` are included; they are
    rendered as ``key=value`` pairs joined by commas.

    Args:
        query_template: Statement text containing one ``{}`` placeholder
            for the option string.
        vector_index: Source of the index settings.
    """
    option_parts = []

    # Enum-typed settings are rendered through their ``.value``.
    chosen_type = vector_index.index_type
    if chosen_type:
        option_parts.append(f"index_type={chosen_type.value}")
    chosen_measure = vector_index.distance_measure
    if chosen_measure:
        option_parts.append(f"distance_measure={chosen_measure.value}")

    # Numeric settings are rendered as-is.
    for option_name in ("num_partitions", "num_neighbors"):
        option_value = getattr(vector_index, option_name)
        if option_value:
            option_parts.append(f"{option_name}={option_value}")

    stmt = query_template.format(",".join(option_parts))
    self.engine._execute_outside_tx(stmt)
260+
261+
def _get_vector_index_name(self):
    """Return the name of the vector index recorded for this store's table.

    Queries ``mysql.vector_indexes`` for rows whose ``table_name`` equals
    ``<db_name>.<table_name>``.

    Returns:
        The ``index_name`` of the first matching row, or ``None`` when the
        query returns nothing.
    """
    lookup_sql = f"SELECT index_name FROM mysql.vector_indexes WHERE table_name='{self.db_name}.{self.table_name}';"
    rows = self.engine._fetch(lookup_sql)
    if not rows:
        return None
    return rows[0]["index_name"]
268+
269+
def drop_vector_index(self):
    """Drop the table's vector index, if any, and switch back to KNN search.

    Returns:
        The name of the index that was dropped, or ``None`` when no index
        was found (in which case no drop statement is issued).
    """
    # Local import: only this method needs it.
    from dataclasses import replace

    existing_index_name = self._get_vector_index_name()
    if existing_index_name:
        self.engine._execute_outside_tx(
            f"CALL mysql.drop_vector_index('{existing_index_name}');"
        )
    # Rebind to a copy instead of mutating in place: the current options
    # object may be shared with other stores (for example the module-level
    # default), and they must not be switched to KNN as a side effect.
    self.query_options = replace(self.query_options, search_type=SearchType.KNN)
    return existing_index_name
277+
213278
@classmethod
214279
def from_texts( # type: ignore[override]
215280
cls: Type[MySQLVectorStore],
@@ -225,6 +290,7 @@ def from_texts( # type: ignore[override]
225290
ignore_metadata_columns: Optional[List[str]] = None,
226291
id_column: str = "langchain_id",
227292
metadata_json_column: str = "langchain_metadata",
293+
query_options: QueryOptions = DEFAULT_QUERY_OPTIONS,
228294
**kwargs: Any,
229295
):
230296
vs = cls(
@@ -237,6 +303,7 @@ def from_texts( # type: ignore[override]
237303
ignore_metadata_columns=ignore_metadata_columns,
238304
id_column=id_column,
239305
metadata_json_column=metadata_json_column,
306+
query_options=query_options,
240307
)
241308
vs.add_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
242309
return vs
@@ -255,6 +322,7 @@ def from_documents( # type: ignore[override]
255322
ignore_metadata_columns: Optional[List[str]] = None,
256323
id_column: str = "langchain_id",
257324
metadata_json_column: str = "langchain_metadata",
325+
query_options: QueryOptions = DEFAULT_QUERY_OPTIONS,
258326
**kwargs: Any,
259327
) -> MySQLVectorStore:
260328
vs = cls(
@@ -267,6 +335,7 @@ def from_documents( # type: ignore[override]
267335
ignore_metadata_columns=ignore_metadata_columns,
268336
id_column=id_column,
269337
metadata_json_column=metadata_json_column,
338+
query_options=query_options,
270339
)
271340
texts = [doc.page_content for doc in documents]
272341
metadatas = [doc.metadata for doc in documents]

0 commit comments

Comments
 (0)