@@ -370,20 +370,20 @@ def if_(
370370 provides optimization such that not all rows are evaluated with the LLM.
371371
372372 **Examples:**
373- >>> import bigframes.pandas as bpd
374- >>> import bigframes.bigquery as bbq
375- >>> bpd.options.display.progress_bar = None
376- >>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
377- >>> bbq.ai.if_((us_state, " has a city called Springfield"))
378- 0 True
379- 1 True
380- 2 False
381- dtype: boolean
382-
383- >>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
384- 0 Massachusetts
385- 1 Illinois
386- dtype: string
373+ >>> import bigframes.pandas as bpd
374+ >>> import bigframes.bigquery as bbq
375+ >>> bpd.options.display.progress_bar = None
376+ >>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
377+ >>> bbq.ai.if_((us_state, " has a city called Springfield"))
378+ 0 True
379+ 1 True
380+ 2 False
381+ dtype: boolean
382+
383+ >>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
384+ 0 Massachusetts
385+ 1 Illinois
386+ dtype: string
387387
388388 Args:
389389 prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
@@ -408,6 +408,56 @@ def if_(
408408 return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
409409
410410
411+ @log_adapter .method_logger (custom_base_name = "bigquery_ai" )
412+ def classify (
413+ input : PROMPT_TYPE ,
414+ categories : tuple [str , ...] | list [str ],
415+ * ,
416+ connection_id : str | None = None ,
417+ ) -> series .Series :
418+ """
419+ Classifies a given input into one of the specified categories. It will always return one of the provided categories best fit the prompt input.
420+
421+ **Examples:**
422+
423+ >>> import bigframes.pandas as bpd
424+ >>> import bigframes.bigquery as bbq
425+ >>> bpd.options.display.progress_bar = None
426+ >>> df = bpd.DataFrame({'creature': ['Cat', 'Salmon']})
427+ >>> df['type'] = bbq.ai.classify(df['creature'], ['Mammal', 'Fish'])
428+ >>> df
429+ creature type
430+ 0 Cat Mammal
431+ 1 Salmon Fish
432+ <BLANKLINE>
433+ [2 rows x 2 columns]
434+
435+ Args:
436+ input (Series | List[str|Series] | Tuple[str|Series, ...]):
437+ A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
438+ or pandas Series.
439+ categories (tuple[str, ...] | list[str]):
440+ Categories to classify the input into.
441+ connection_id (str, optional):
442+ Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
443+ If not provided, the connection from the current session will be used.
444+
445+ Returns:
446+ bigframes.series.Series: A new series of strings.
447+ """
448+
449+ prompt_context , series_list = _separate_context_and_series (input )
450+ assert len (series_list ) > 0
451+
452+ operator = ai_ops .AIClassify (
453+ prompt_context = tuple (prompt_context ),
454+ categories = tuple (categories ),
455+ connection_id = _resolve_connection_id (series_list [0 ], connection_id ),
456+ )
457+
458+ return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
459+
460+
411461@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
412462def score (
413463 prompt : PROMPT_TYPE ,
@@ -420,15 +470,16 @@ def score(
420470 rubric with examples in the prompt.
421471
422472 **Examples:**
423- >>> import bigframes.pandas as bpd
424- >>> import bigframes.bigquery as bbq
425- >>> bpd.options.display.progress_bar = None
426- >>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
427- >>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
428- 0 2.0
429- 1 1.0
430- 2 3.0
431- dtype: Float64
473+
474+ >>> import bigframes.pandas as bpd
475+ >>> import bigframes.bigquery as bbq
476+ >>> bpd.options.display.progress_bar = None
477+ >>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
478+ >>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
479+ 0 2.0
480+ 1 1.0
481+ 2 3.0
482+ dtype: Float64
432483
433484 Args:
434485 prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
0 commit comments