2727
2828
def _parse_doc_from_row(
    content_columns: Iterable[str],
    metadata_columns: Iterable[str],
    row: Dict,
    metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Document:
    """Build a Document from a single row mapping.

    Args:
        content_columns: Names of the columns whose values form the page
            content. Each value is converted with ``str`` and the pieces are
            joined with a single space; names missing from ``row`` are skipped.
        metadata_columns: Names of the columns copied into the document
            metadata. Names missing from ``row`` are skipped.
        row: Mapping of column name to value for one row.
        metadata_json_column: Name of the column holding a mapping that is
            unnested into the metadata before the other columns are applied.

    Returns:
        A Document carrying the joined content and the merged metadata.
    """
    content_parts = [str(row[name]) for name in content_columns if name in row]

    # Seed the metadata with the key/value pairs nested in the JSON column;
    # an absent or falsy value contributes nothing.
    metadata: Dict[str, Any] = {}
    nested = row.get(metadata_json_column)
    if nested:
        metadata.update(nested.items())

    # Explicit metadata columns are applied afterwards, so they win over a
    # nested key of the same name. The JSON column itself is never copied as
    # a plain entry.
    for name in metadata_columns:
        if name == metadata_json_column or name not in row:
            continue
        metadata[name] = row[name]

    return Document(page_content=" ".join(content_parts), metadata=metadata)
4548
4649
def _parse_row_from_doc(
    column_names: Iterable[str],
    doc: Document,
    content_column: str = DEFAULT_CONTENT_COL,
    metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Dict:
    """Convert a Document into a row mapping keyed by column name.

    The page content is stored under ``content_column``. Each metadata entry
    whose key matches one of ``column_names`` is stored under that column;
    the remaining entries are collected into one dictionary stored under
    ``metadata_json_column``, provided that column is among ``column_names``
    (otherwise they are left out of the row).

    Args:
        column_names: Names of the columns available in the target table.
            Any iterable is accepted, including one-shot iterators.
        doc: The document to convert.
        content_column: Column that receives ``doc.page_content``.
        metadata_json_column: Column that receives the leftover metadata.

    Returns:
        The row as a dictionary of column name to value.
    """
    # Materialize once: repeated ``in`` tests would exhaust a one-shot
    # iterator after the first lookup, and a set gives O(1) membership.
    columns = set(column_names)
    # A metadata key that collides with the content or JSON column must not
    # be written directly, or it would clobber the page content / be
    # overwritten by the leftover dictionary. Such keys stay in the leftovers.
    reserved = {content_column, metadata_json_column}

    row: Dict[str, Any] = {content_column: doc.page_content}
    extra_metadata: Dict[str, Any] = {}
    for key, value in doc.metadata.items():
        if key in columns and key not in reserved:
            row[key] = value
        else:
            extra_metadata[key] = value

    # Store the leftover metadata in the JSON column when the table has one.
    if extra_metadata and metadata_json_column in columns:
        row[metadata_json_column] = extra_metadata
    return row
5866
5967
@@ -67,6 +75,7 @@ def __init__(
6775 query : str = "" ,
6876 content_columns : Optional [List [str ]] = None ,
6977 metadata_columns : Optional [List [str ]] = None ,
78+ metadata_json_column : Optional [str ] = None ,
7079 ):
7180 """
7281 Document page content defaults to the first column present in the query or table and
@@ -85,12 +94,15 @@ def __init__(
8594 of the document. Optional.
8695 metadata_columns (List[str]): The columns to write into the `metadata` of the document.
8796 Optional.
97+ metadata_json_column (str): The name of the JSON column to use as the metadata’s base
98+ dictionary. Default: `langchain_metadata`. Optional.
8899 """
89100 self .engine = engine
90101 self .table_name = table_name
91102 self .query = query
92103 self .content_columns = content_columns
93104 self .metadata_columns = metadata_columns
105+ self .metadata_json_column = metadata_json_column
94106 if not self .table_name and not self .query :
95107 raise ValueError ("One of 'table_name' or 'query' must be specified." )
96108 if self .table_name and self .query :
@@ -139,6 +151,25 @@ def lazy_load(self) -> Iterator[Document]:
139151 metadata_columns = self .metadata_columns or [
140152 col for col in column_names if col not in content_columns
141153 ]
154+ # check validity of metadata json column
155+ if (
156+ self .metadata_json_column
157+ and self .metadata_json_column not in column_names
158+ ):
159+ raise ValueError (
160+ f"Column { self .metadata_json_column } not found in query result { column_names } ."
161+ )
162+ # check validity of other column
163+ all_names = content_columns + metadata_columns
164+ for name in all_names :
165+ if name not in column_names :
166+ raise ValueError (
167+ f"Column { name } not found in query result { column_names } ."
168+ )
169+ # use default metadata json column if not specified
170+ metadata_json_column = self .metadata_json_column or DEFAULT_METADATA_COL
171+
172+ # load document one by one
142173 while True :
143174 row = result_proxy .fetchone ()
144175 if not row :
@@ -151,7 +182,12 @@ def lazy_load(self) -> Iterator[Document]:
151182 row_data [column ] = json .loads (value )
152183 else :
153184 row_data [column ] = value
154- yield _parse_doc_from_row (content_columns , metadata_columns , row_data )
185+ yield _parse_doc_from_row (
186+ content_columns ,
187+ metadata_columns ,
188+ row_data ,
189+ metadata_json_column ,
190+ )
155191
156192
157193class MySQLDocumentSaver :
@@ -161,6 +197,8 @@ def __init__(
161197 self ,
162198 engine : MySQLEngine ,
163199 table_name : str ,
200+ content_column : Optional [str ] = None ,
201+ metadata_json_column : Optional [str ] = None ,
164202 ):
165203 """
166204 MySQLDocumentSaver allows for saving of langchain documents in a database. If the table
@@ -169,17 +207,33 @@ def __init__(
169207 - langchain_metadata (type: JSON)
170208
171209 Args:
172- engine: MySQLEngine object to connect to the MySQL database.
173- table_name: The name of table for saving documents.
210+ engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
211+ table_name (str): The name of table for saving documents.
212+ content_column (str): The column to store document content.
213+ Default: `page_content`. Optional.
214+ metadata_json_column (str): The name of the JSON column to use as the metadata’s base
215+ dictionary. Default: `langchain_metadata`. Optional.
174216 """
175217 self .engine = engine
176218 self .table_name = table_name
177219 self ._table = self .engine ._load_document_table (table_name )
178- if DEFAULT_CONTENT_COL not in self ._table .columns .keys ():
220+
221+ self .content_column = content_column or DEFAULT_CONTENT_COL
222+ if self .content_column not in self ._table .columns .keys ():
179223 raise ValueError (
180- f"Missing '{ DEFAULT_CONTENT_COL } ' field in table { table_name } ."
224+ f"Missing '{ self . content_column } ' field in table { table_name } ."
181225 )
182226
227+ # check metadata_json_column existence if it's provided.
228+ if (
229+ metadata_json_column
230+ and metadata_json_column not in self ._table .columns .keys ()
231+ ):
232+ raise ValueError (
233+ f"Cannot find '{ metadata_json_column } ' column in table { table_name } ."
234+ )
235+ self .metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL
236+
183237 def add_documents (self , docs : List [Document ]) -> None :
184238 """
185239 Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
@@ -190,7 +244,12 @@ def add_documents(self, docs: List[Document]) -> None:
190244 """
191245 with self .engine .connect () as conn :
192246 for doc in docs :
193- row = _parse_row_from_doc (self ._table .columns .keys (), doc )
247+ row = _parse_row_from_doc (
248+ self ._table .columns .keys (),
249+ doc ,
250+ self .content_column ,
251+ self .metadata_json_column ,
252+ )
194253 conn .execute (sqlalchemy .insert (self ._table ).values (row ))
195254 conn .commit ()
196255
@@ -204,7 +263,12 @@ def delete(self, docs: List[Document]) -> None:
204263 """
205264 with self .engine .connect () as conn :
206265 for doc in docs :
207- row = _parse_row_from_doc (self ._table .columns .keys (), doc )
266+ row = _parse_row_from_doc (
267+ self ._table .columns .keys (),
268+ doc ,
269+ self .content_column ,
270+ self .metadata_json_column ,
271+ )
208272 # delete by matching all fields of document
209273 where_conditions = []
210274 for col in self ._table .columns :
0 commit comments