2323from langchain_core .vectorstores import VectorStore
2424
2525from .engine import MySQLEngine
26- from .indexes import QueryOptions
26+ from .indexes import DEFAULT_QUERY_OPTIONS , QueryOptions , SearchType , VectorIndex
27+
28+ DEFAULT_INDEX_NAME_SUFFIX = "langchainvectorindex"
2729
2830
2931class MySQLVectorStore (VectorStore ):
@@ -38,7 +40,7 @@ def __init__(
3840 ignore_metadata_columns : Optional [List [str ]] = None ,
3941 id_column : str = "langchain_id" ,
4042 metadata_json_column : Optional [str ] = "langchain_metadata" ,
41- query_options : Optional [ QueryOptions ] = None ,
43+ query_options : QueryOptions = DEFAULT_QUERY_OPTIONS ,
4244 ):
4345 """Constructor for MySQLVectorStore.
4446 Args:
@@ -118,11 +120,16 @@ def __init__(
118120 self .id_column = id_column
119121 self .metadata_json_column = metadata_json_column
120122 self .query_options = query_options
123+ self .db_name = self .__get_db_name ()
121124
122125 @property
123126 def embeddings (self ) -> Embeddings :
124127 return self .embedding_service
125128
129+ def __get_db_name (self ) -> str :
130+ result = self .engine ._fetch ("SELECT DATABASE();" )
131+ return result [0 ]["DATABASE()" ]
132+
126133 def _add_embeddings (
127134 self ,
128135 texts : Iterable [str ],
@@ -210,6 +217,64 @@ def delete(
210217 self .engine ._execute (query )
211218 return True
212219
220+ def apply_vector_index (self , vector_index : VectorIndex ):
221+ # Construct the default index name
222+ if not vector_index .name :
223+ vector_index .name = f"{ self .table_name } _{ DEFAULT_INDEX_NAME_SUFFIX } "
224+ query_template = f"CALL mysql.create_vector_index('{ vector_index .name } ', '{ self .db_name } .{ self .table_name } ', '{ self .embedding_column } ', '{{}}');"
225+ self .__exec_apply_vector_index (query_template , vector_index )
226+ # After applying an index to the table, set the query option search type to be ANN
227+ self .query_options .search_type = SearchType .ANN
228+
229+ def alter_vector_index (self , vector_index : VectorIndex ):
230+ existing_index_name = self ._get_vector_index_name ()
231+ if not existing_index_name :
232+ raise ValueError ("No existing vector index found." )
233+ if not vector_index .name :
234+ vector_index .name = existing_index_name .split ("." )[1 ]
235+ if existing_index_name .split ("." )[1 ] != vector_index .name :
236+ raise ValueError (
237+ f"Existing index name { existing_index_name } does not match the new index name { vector_index .name } ."
238+ )
239+ query_template = (
240+ f"CALL mysql.alter_vector_index('{ existing_index_name } ', '{{}}');"
241+ )
242+ self .__exec_apply_vector_index (query_template , vector_index )
243+
244+ def __exec_apply_vector_index (self , query_template : str , vector_index : VectorIndex ):
245+ index_options = []
246+ if vector_index .index_type :
247+ index_options .append (f"index_type={ vector_index .index_type .value } " )
248+ if vector_index .distance_measure :
249+ index_options .append (
250+ f"distance_measure={ vector_index .distance_measure .value } "
251+ )
252+ if vector_index .num_partitions :
253+ index_options .append (f"num_partitions={ vector_index .num_partitions } " )
254+ if vector_index .num_neighbors :
255+ index_options .append (f"num_neighbors={ vector_index .num_neighbors } " )
256+ index_options_query = "," .join (index_options )
257+
258+ stmt = query_template .format (index_options_query )
259+ self .engine ._execute_outside_tx (stmt )
260+
261+ def _get_vector_index_name (self ):
262+ query = f"SELECT index_name FROM mysql.vector_indexes WHERE table_name='{ self .db_name } .{ self .table_name } ';"
263+ result = self .engine ._fetch (query )
264+ if result :
265+ return result [0 ]["index_name" ]
266+ else :
267+ return None
268+
269+ def drop_vector_index (self ):
270+ existing_index_name = self ._get_vector_index_name ()
271+ if existing_index_name :
272+ self .engine ._execute_outside_tx (
273+ f"CALL mysql.drop_vector_index('{ existing_index_name } ');"
274+ )
275+ self .query_options .search_type = SearchType .KNN
276+ return existing_index_name
277+
213278 @classmethod
214279 def from_texts ( # type: ignore[override]
215280 cls : Type [MySQLVectorStore ],
@@ -225,6 +290,7 @@ def from_texts( # type: ignore[override]
225290 ignore_metadata_columns : Optional [List [str ]] = None ,
226291 id_column : str = "langchain_id" ,
227292 metadata_json_column : str = "langchain_metadata" ,
293+ query_options : QueryOptions = DEFAULT_QUERY_OPTIONS ,
228294 ** kwargs : Any ,
229295 ):
230296 vs = cls (
@@ -237,6 +303,7 @@ def from_texts( # type: ignore[override]
237303 ignore_metadata_columns = ignore_metadata_columns ,
238304 id_column = id_column ,
239305 metadata_json_column = metadata_json_column ,
306+ query_options = query_options ,
240307 )
241308 vs .add_texts (texts , metadatas = metadatas , ids = ids , ** kwargs )
242309 return vs
@@ -255,6 +322,7 @@ def from_documents( # type: ignore[override]
255322 ignore_metadata_columns : Optional [List [str ]] = None ,
256323 id_column : str = "langchain_id" ,
257324 metadata_json_column : str = "langchain_metadata" ,
325+ query_options : QueryOptions = DEFAULT_QUERY_OPTIONS ,
258326 ** kwargs : Any ,
259327 ) -> MySQLVectorStore :
260328 vs = cls (
@@ -267,6 +335,7 @@ def from_documents( # type: ignore[override]
267335 ignore_metadata_columns = ignore_metadata_columns ,
268336 id_column = id_column ,
269337 metadata_json_column = metadata_json_column ,
338+ query_options = query_options ,
270339 )
271340 texts = [doc .page_content for doc in documents ]
272341 metadatas = [doc .metadata for doc in documents ]
0 commit comments