@@ -609,6 +609,120 @@ def search(
609609
610610 return typing .cast (bigframes .dataframe .DataFrame , search_result )
611611
612+ def sim_join (
613+ self ,
614+ other ,
615+ left_on : str ,
616+ right_on : str ,
617+ model ,
618+ top_k : int = 3 ,
619+ score_column : Optional [str ] = None ,
620+ max_rows : int = 1000 ,
621+ ):
622+ """
623+ Joins two dataframes based on the similarity of the specified columns.
624+
625+ This method uses BigQuery's VECTOR_SEARCH function to match rows on the left side with the rows that have
626+ nearest embedding vectors on the right. In the worst case scenario, the complexity is around O(M * N * log K).
627+ Therefore, this is a potentially expensive operation.
628+
629+ ** Examples: **
630+
631+ >>> import bigframes.pandas as bpd
632+ >>> bpd.options.display.progress_bar = None
633+
634+ >>> import bigframes
635+ >>> bigframes.options.experiments.semantic_operators = True
636+
637+ >>> import bigframes.ml.llm as llm
638+ >>> model = llm.TextEmbeddingGenerator(model_name="text-embedding-004")
639+
640+ >>> df1 = bpd.DataFrame({'animal': ['monkey', 'spider']})
641+ >>> df2 = bpd.DataFrame({'animal': ['scorpion', 'baboon']})
642+
643+ >>> df1.semantics.sim_join(df2, left_on='animal', right_on='animal', model=model, top_k=1)
644+ animal animal_1
645+ 0 monkey baboon
646+ 1 spider scorpion
647+ <BLANKLINE>
648+ [2 rows x 2 columns]
649+
650+ Args:
651+ other (DataFrame):
652+ The other data frame to join with.
653+ left_on (str):
654+ The name of the column on left side for the join.
655+ right_on (str):
656+ The name of the column on the right side for the join.
657+ top_k (int, default 3):
658+ The number of nearest neighbors to return.
659+ model (TextEmbeddingGenerator):
660+ A TextEmbeddingGenerator provided by Bigframes ML package.
661+ score_column (Optional[str], default None):
662+ The name of the the additional column containning the similarity scores. If None,
663+ this column won't be attached to the result.
664+ max_rows:
665+ The maximum number of rows allowed to be processed per call. If the result is too large, the method
666+ call will end early with an error.
667+
668+ Returns:
669+ DataFrame: the data frame with the join result.
670+
671+ Raises:
672+ ValueError: when the amount of data to be processed exceeds the specified max_rows.
673+ """
674+
675+ if left_on not in self ._df .columns :
676+ raise ValueError (f"Left column { left_on } not found" )
677+ if right_on not in self ._df .columns :
678+ raise ValueError (f"Right column { right_on } not found" )
679+
680+ import bigframes .ml .llm as llm
681+
682+ if not isinstance (model , llm .TextEmbeddingGenerator ):
683+ raise TypeError (f"Expect a text embedding model, but got: { type (model )} " )
684+
685+ joined_table_rows = len (self ._df ) * len (other )
686+ if joined_table_rows > max_rows :
687+ raise ValueError (
688+ f"Number of rows that need processing is { joined_table_rows } , which exceeds row limit { max_rows } ."
689+ )
690+
691+ base_table_embedding_column = bigframes .core .guid .generate_guid ()
692+ base_table = self ._attach_embedding (
693+ other , right_on , base_table_embedding_column , model
694+ ).to_gbq ()
695+ query_table = self ._attach_embedding (self ._df , left_on , "embedding" , model )
696+
697+ import bigframes .bigquery as bbq
698+
699+ join_result = bbq .vector_search (
700+ base_table = base_table ,
701+ column_to_search = base_table_embedding_column ,
702+ query = query_table ,
703+ top_k = top_k ,
704+ )
705+
706+ join_result = join_result .drop (
707+ ["embedding" , base_table_embedding_column ], axis = 1
708+ )
709+
710+ if score_column is not None :
711+ join_result = join_result .rename (columns = {"distance" : score_column })
712+ else :
713+ del join_result ["distance" ]
714+
715+ return join_result
716+
717+ @staticmethod
718+ def _attach_embedding (dataframe , source_column : str , embedding_column : str , model ):
719+ result_df = dataframe .copy ()
720+ embeddings = model .predict (dataframe [source_column ])[
721+ "ml_generate_embedding_result"
722+ ]
723+ result_df [embedding_column ] = embeddings
724+ return result_df
725+
612726 def _make_prompt (
613727 self , columns : List [str ], user_instruction : str , output_instruction : str
614728 ):
0 commit comments