@@ -98,17 +98,20 @@ def agg(
9898 ValueError: when the instruction refers to a non-existing column, or when
9999 more than one columns are referred to.
100100 """
101- self ._validate_model (model )
101+ import bigframes .bigquery as bbq
102+ import bigframes .dataframe
103+ import bigframes .series
102104
105+ self ._validate_model (model )
103106 columns = self ._parse_columns (instruction )
107+
108+ df : bigframes .dataframe .DataFrame = self ._df .copy ()
104109 for column in columns :
105110 if column not in self ._df .columns :
106111 raise ValueError (f"Column { column } not found." )
107- if self ._df [column ].dtype != dtypes .STRING_DTYPE :
108- raise TypeError (
109- "Semantics aggregated column must be a string type, not "
110- f"{ type (self ._df [column ])} "
111- )
112+
113+ if df [column ].dtype != dtypes .STRING_DTYPE :
114+ df [column ] = df [column ].astype (dtypes .STRING_DTYPE )
112115
113116 if len (columns ) > 1 :
114117 raise NotImplementedError (
@@ -122,11 +125,6 @@ def agg(
122125 "It must be greater than 1."
123126 )
124127
125- import bigframes .bigquery as bbq
126- import bigframes .dataframe
127- import bigframes .series
128-
129- df : bigframes .dataframe .DataFrame = self ._df .copy ()
130128 user_instruction = self ._format_instruction (instruction , columns )
131129
132130 num_cluster = 1
@@ -325,26 +323,27 @@ def filter(self, instruction: str, model):
325323 ValueError: when the instruction refers to a non-existing column, or when no
326324 columns are referred to.
327325 """
326+ import bigframes .dataframe
327+ import bigframes .series
328+
328329 self ._validate_model (model )
329330 columns = self ._parse_columns (instruction )
330331 for column in columns :
331332 if column not in self ._df .columns :
332333 raise ValueError (f"Column { column } not found." )
333- if self . _df [ column ]. dtype != dtypes . STRING_DTYPE :
334- raise TypeError (
335- "Semantics aggregated column must be a string type, not "
336- f" { type ( self . _df [column ]) } "
337- )
334+
335+ df : bigframes . dataframe . DataFrame = self . _df [ columns ]. copy ()
336+ for column in columns :
337+ if df [column ]. dtype != dtypes . STRING_DTYPE :
338+ df [ column ] = df [ column ]. astype ( dtypes . STRING_DTYPE )
338339
339340 user_instruction = self ._format_instruction (instruction , columns )
340341 output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
341342
342- from bigframes .dataframe import DataFrame
343-
344343 results = typing .cast (
345- DataFrame ,
344+ bigframes . dataframe . DataFrame ,
346345 model .predict (
347- self ._make_prompt (columns , user_instruction , output_instruction ),
346+ self ._make_prompt (df , columns , user_instruction , output_instruction ),
348347 temperature = 0.0 ,
349348 ),
350349 )
@@ -398,28 +397,29 @@ def map(self, instruction: str, output_column: str, model):
398397 ValueError: when the instruction refers to a non-existing column, or when no
399398 columns are referred to.
400399 """
400+ import bigframes .dataframe
401+ import bigframes .series
402+
401403 self ._validate_model (model )
402404 columns = self ._parse_columns (instruction )
403405 for column in columns :
404406 if column not in self ._df .columns :
405407 raise ValueError (f"Column { column } not found." )
406- if self . _df [ column ]. dtype != dtypes . STRING_DTYPE :
407- raise TypeError (
408- "Semantics aggregated column must be a string type, not "
409- f" { type ( self . _df [column ]) } "
410- )
408+
409+ df : bigframes . dataframe . DataFrame = self . _df [ columns ]. copy ()
410+ for column in columns :
411+ if df [column ]. dtype != dtypes . STRING_DTYPE :
412+ df [ column ] = df [ column ]. astype ( dtypes . STRING_DTYPE )
411413
412414 user_instruction = self ._format_instruction (instruction , columns )
413415 output_instruction = (
414416 "Based on the provided contenxt, answer the following instruction:"
415417 )
416418
417- from bigframes .series import Series
418-
419419 results = typing .cast (
420- Series ,
420+ bigframes . series . Series ,
421421 model .predict (
422- self ._make_prompt (columns , user_instruction , output_instruction ),
422+ self ._make_prompt (df , columns , user_instruction , output_instruction ),
423423 temperature = 0.0 ,
424424 )["ml_generate_text_llm_result" ],
425425 )
@@ -683,6 +683,9 @@ def top_k(self, instruction: str, model, k=10):
683683 ValueError: when the instruction refers to a non-existing column, or when no
684684 columns are referred to.
685685 """
686+ import bigframes .dataframe
687+ import bigframes .series
688+
686689 self ._validate_model (model )
687690 columns = self ._parse_columns (instruction )
688691 for column in columns :
@@ -692,12 +695,12 @@ def top_k(self, instruction: str, model, k=10):
692695 raise NotImplementedError (
693696 "Semantic aggregations are limited to a single column."
694697 )
698+
699+ df : bigframes .dataframe .DataFrame = self ._df [columns ].copy ()
695700 column = columns [0 ]
696- if self ._df [column ].dtype != dtypes .STRING_DTYPE :
697- raise TypeError (
698- "Referred column must be a string type, not "
699- f"{ type (self ._df [column ])} "
700- )
701+ if df [column ].dtype != dtypes .STRING_DTYPE :
702+ df [column ] = df [column ].astype (dtypes .STRING_DTYPE )
703+
701704 # `index` is reserved for the `reset_index` below.
702705 if column == "index" :
703706 raise ValueError (
@@ -709,12 +712,7 @@ def top_k(self, instruction: str, model, k=10):
709712
710713 user_instruction = self ._format_instruction (instruction , columns )
711714
712- import bigframes .dataframe
713- import bigframes .series
714-
715- df : bigframes .dataframe .DataFrame = self ._df [columns ].copy ()
716715 n = df .shape [0 ]
717-
718716 if k >= n :
719717 return df
720718
@@ -762,17 +760,17 @@ def _topk_partition(
762760
763761 # Random pivot selection for improved average quickselect performance.
764762 pending_df = df [df [status_column ].isna ()]
765- pivot_iloc = np .random .randint (0 , pending_df .shape [0 ] - 1 )
763+ pivot_iloc = np .random .randint (0 , pending_df .shape [0 ])
766764 pivot_index = pending_df .iloc [pivot_iloc ]["index" ]
767765 pivot_df = pending_df [pending_df ["index" ] == pivot_index ]
768766
769767 # Build a prompt to compare the pivot item's relevance to other pending items.
770768 prompt_s = pending_df [pending_df ["index" ] != pivot_index ][column ]
771769 prompt_s = (
772770 f"{ output_instruction } \n \n Question: { user_instruction } \n "
773- + "\n Document 1: "
771+ + f "\n Document 1: { column } "
774772 + pivot_df .iloc [0 ][column ]
775- + "\n Document 2: "
773+ + f "\n Document 2: { column } "
776774 + prompt_s # type:ignore
777775 )
778776
@@ -920,9 +918,8 @@ def _attach_embedding(dataframe, source_column: str, embedding_column: str, mode
920918 return result_df
921919
922920 def _make_prompt (
923- self , columns : List [ str ] , user_instruction : str , output_instruction : str
921+ self , prompt_df , columns , user_instruction : str , output_instruction : str
924922 ):
925- prompt_df = self ._df [columns ].copy ()
926923 prompt_df ["prompt" ] = f"{ output_instruction } \n { user_instruction } \n Context: "
927924
928925 # Combine context from multiple columns.
0 commit comments