From eb072dbb29dba602049c9134aba4e4c7ee5d0086 Mon Sep 17 00:00:00 2001 From: Teo Date: Thu, 9 Jan 2025 15:03:36 +0100 Subject: [PATCH 1/4] refactor: reorganize session management into dedicated components Split session logic into dedicated components for better separation of concerns and maintainability: - Move session code to dedicated session/ module - Split Session class into: - Session: Data container with minimal public API - SessionManager: Handles lifecycle and state management - SessionApi: Handles API communication - SessionTelemetry: Manages event recording and OTEL integration Key fixes: - Proper UUID and timestamp serialization in events - Consistent API key header handling - Correct token cost formatting in analytics - Proper session ID inheritance - Tags conversion and validation - Event counts type handling This refactor improves code organization while maintaining backward compatibility through the session/__init__.py module. Signed-off-by: Teo --- agentops/client.py | 48 +- agentops/event.py | 22 +- agentops/session.py | 660 ------------------------ agentops/session/__init__.py | 8 + agentops/session/api.py | 104 ++++ agentops/session/manager.py | 196 +++++++ agentops/session/registry.py | 24 + agentops/session/session.py | 105 ++++ agentops/session/telemetry.py | 88 ++++ agentops/telemetry/exporters/session.py | 106 ++++ 10 files changed, 664 insertions(+), 697 deletions(-) delete mode 100644 agentops/session.py create mode 100644 agentops/session/__init__.py create mode 100644 agentops/session/api.py create mode 100644 agentops/session/manager.py create mode 100644 agentops/session/registry.py create mode 100644 agentops/session/session.py create mode 100644 agentops/session/telemetry.py create mode 100644 agentops/telemetry/exporters/session.py diff --git a/agentops/client.py b/agentops/client.py index fb3e17937..cbb75a606 100644 --- a/agentops/client.py +++ b/agentops/client.py @@ -198,47 +198,31 @@ def start_session( self, tags: Optional[List[str]] = None, inherited_session_id: Optional[str] = None, - ) -> Union[Session, None]: - """ - Start a new session for recording events. - - Args: - tags (List[str], optional): Tags that can be used for grouping or sorting later. - e.g. ["test_run"]. - config: (Configuration, optional): Client configuration object - inherited_session_id (optional, str): assign session id to match existing Session - """ + ) -> Optional[Session]: + """Start a new session""" if not self.is_initialized: - return - - if inherited_session_id is not None: - try: - session_id = UUID(inherited_session_id) - except ValueError: - return logger.warning(f"Invalid session id: {inherited_session_id}") - else: - session_id = uuid4() + return None - session_tags = self._config.default_tags.copy() - if tags is not None: - session_tags.update(tags) + try: + session_id = UUID(inherited_session_id) if inherited_session_id else uuid4() + except ValueError: + return logger.warning(f"Invalid session id: {inherited_session_id}") + default_tags = list(self._config.default_tags) if self._config.default_tags else [] session = Session( session_id=session_id, - tags=list(session_tags), - host_env=self.host_env, config=self._config, + tags=tags or default_tags, + host_env=self.host_env, ) - if not session.is_running: - return logger.error("Failed to start session") - - if self._pre_init_queue["agents"] and len(self._pre_init_queue["agents"]) > 0: - for agent_args in self._pre_init_queue["agents"]: - session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"]) - self._pre_init_queue["agents"] = [] + if session.is_running: + # Process any queued agents + if self._pre_init_queue["agents"]: + for agent_args in self._pre_init_queue["agents"]: + session.create_agent(name=agent_args["name"], agent_id=agent_args["agent_id"]) + self._pre_init_queue["agents"].clear() - self._sessions.append(session) return session def end_session( diff --git a/agentops/event.py b/agentops/event.py index c6200aca1..50e669078 100644 --- a/agentops/event.py +++ b/agentops/event.py @@ -25,6 +25,7 @@ class Event: end_timestamp(str): A timestamp indicating when the event ended. Defaults to the time when this Event was instantiated. agent_id(UUID, optional): The unique identifier of the agent that triggered the event. id(UUID): A unique identifier for the event. Defaults to a new UUID. + session_id(UUID, optional): The unique identifier of the session that the event belongs to. foo(x=1) { ... @@ -43,6 +44,7 @@ class Event: end_timestamp: Optional[str] = None agent_id: Optional[UUID] = field(default_factory=check_call_stack_for_agent_id) id: UUID = field(default_factory=uuid4) + session_id: Optional[UUID] = None @dataclass @@ -105,7 +107,7 @@ class ToolEvent(Event): @dataclass -class ErrorEvent: +class ErrorEvent(Event): """ For recording any errors e.g. ones related to agent execution @@ -115,21 +117,31 @@ class ErrorEvent: code(str, optional): A code that can be used to identify the error e.g. 501. details(str, optional): Detailed information about the error. logs(str, optional): For detailed information/logging related to the error. - timestamp(str): A timestamp indicating when the error occurred. Defaults to the time when this ErrorEvent was instantiated. - """ + # Inherit common Event fields + event_type: str = field(default=EventType.ERROR.value) + + # Error-specific fields trigger_event: Optional[Event] = None exception: Optional[BaseException] = None error_type: Optional[str] = None code: Optional[str] = None details: Optional[Union[str, Dict[str, str]]] = None logs: Optional[str] = field(default_factory=traceback.format_exc) - timestamp: str = field(default_factory=get_ISO_time) def __post_init__(self): - self.event_type = EventType.ERROR.value + """Process exception if provided""" if self.exception: self.error_type = self.error_type or type(self.exception).__name__ self.details = self.details or str(self.exception) self.exception = None # removes exception from serialization + + # Ensure end timestamp is set + if not self.end_timestamp: + self.end_timestamp = get_ISO_time() + + @property + def timestamp(self) -> str: + """Maintain backward compatibility with old code expecting timestamp""" + return self.init_timestamp diff --git a/agentops/session.py b/agentops/session.py deleted file mode 100644 index b9f07d20b..000000000 --- a/agentops/session.py +++ /dev/null @@ -1,660 +0,0 @@ -from __future__ import annotations - -import asyncio -import functools -import json -import threading -from datetime import datetime, timezone -from decimal import ROUND_HALF_UP, Decimal -from typing import Any, Dict, List, Optional, Sequence, Union -from uuid import UUID, uuid4 - -from opentelemetry import trace -from opentelemetry.context import attach, detach, set_value -from opentelemetry.sdk.resources import SERVICE_NAME, Resource -from opentelemetry.sdk.trace import ReadableSpan, TracerProvider -from opentelemetry.sdk.trace.export import ( - BatchSpanProcessor, - ConsoleSpanExporter, - SpanExporter, - SpanExportResult, -) -from termcolor import colored - -from .config import Configuration -from .enums import EndState -from .event import ErrorEvent, Event -from .exceptions import ApiServerException -from .helpers import filter_unjsonable, get_ISO_time, safe_serialize -from .http_client import HttpClient, Response -from .log_config import logger - -""" -OTEL Guidelines: - - - -- Maintain a single TracerProvider for the application runtime - - Have one global TracerProvider in the Client class - -- According to the OpenTelemetry Python documentation, Resource should be initialized once per application and shared across all telemetry (traces, metrics, logs). -- Each Session gets its own Tracer (with session-specific context) -- Allow multiple sessions to share the provider while maintaining their own context - - - -:: Resource - - '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - Captures information about the entity producing telemetry as Attributes. - For example, a process producing telemetry that is running in a container - on Kubernetes has a process name, a pod name, a namespace, and possibly - a deployment name. All these attributes can be included in the Resource. - '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - - The key insight from the documentation is: - - - Resource represents the entity producing telemetry - in our case, that's the AgentOps SDK application itself - - Session-specific information should be attributes on the spans themselves - - A Resource is meant to identify the service/process/application1 - - Sessions are units of work within that application - - The documentation example about "process name, pod name, namespace" refers to where the code is running, not the work it's doing - -""" - - -class SessionExporter(SpanExporter): - """ - Manages publishing events for Session - """ - - def __init__(self, session: Session, **kwargs): - self.session = session - self._shutdown = threading.Event() - self._export_lock = threading.Lock() - super().__init__(**kwargs) - - @property - def endpoint(self): - return f"{self.session.config.endpoint}/v2/create_events" - - def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: - if self._shutdown.is_set(): - return SpanExportResult.SUCCESS - - with self._export_lock: - try: - # Skip if no spans to export - if not spans: - return SpanExportResult.SUCCESS - - events = [] - for span in spans: - event_data = json.loads(span.attributes.get("event.data", "{}")) - - # Format event data based on event type - if span.name == "actions": - formatted_data = { - "action_type": event_data.get("action_type", event_data.get("name", "unknown_action")), - "params": event_data.get("params", {}), - "returns": event_data.get("returns"), - } - elif span.name == "tools": - formatted_data = { - "name": event_data.get("name", event_data.get("tool_name", "unknown_tool")), - "params": event_data.get("params", {}), - "returns": event_data.get("returns"), - } - else: - formatted_data = event_data - - formatted_data = {**event_data, **formatted_data} - # Get timestamps, providing defaults if missing - current_time = datetime.now(timezone.utc).isoformat() - init_timestamp = span.attributes.get("event.timestamp") - end_timestamp = span.attributes.get("event.end_timestamp") - - # Handle missing timestamps - if init_timestamp is None: - init_timestamp = current_time - if end_timestamp is None: - end_timestamp = current_time - - # Get event ID, generate new one if missing - event_id = span.attributes.get("event.id") - if event_id is None: - event_id = str(uuid4()) - - events.append( - { - "id": event_id, - "event_type": span.name, - "init_timestamp": init_timestamp, - "end_timestamp": end_timestamp, - **formatted_data, - "session_id": str(self.session.session_id), - } - ) - - # Only make HTTP request if we have events and not shutdown - if events: - try: - res = HttpClient.post( - self.endpoint, - json.dumps({"events": events}).encode("utf-8"), - api_key=self.session.config.api_key, - jwt=self.session.jwt, - ) - return SpanExportResult.SUCCESS if res.code == 200 else SpanExportResult.FAILURE - except Exception as e: - logger.error(f"Failed to send events: {e}") - return SpanExportResult.FAILURE - - return SpanExportResult.SUCCESS - - except Exception as e: - logger.error(f"Failed to export spans: {e}") - return SpanExportResult.FAILURE - - def force_flush(self, timeout_millis: Optional[int] = None) -> bool: - return True - - def shutdown(self) -> None: - """Handle shutdown gracefully""" - self._shutdown.set() - # Don't call session.end_session() here to avoid circular dependencies - - -class Session: - """ - Represents a session of events, with a start and end state. - - Args: - session_id (UUID): The session id is used to record particular runs. - config (Configuration): The configuration object for the session. - tags (List[str], optional): Tags that can be used for grouping or sorting later. Examples could be ["GPT-4"]. - host_env (dict, optional): A dictionary containing host and environment data. - - Attributes: - init_timestamp (str): The ISO timestamp for when the session started. - end_timestamp (str, optional): The ISO timestamp for when the session ended. Only set after end_session is called. - end_state (str, optional): The final state of the session. Options: "Success", "Fail", "Indeterminate". Defaults to "Indeterminate". - end_state_reason (str, optional): The reason for ending the session. - session_id (UUID): Unique identifier for the session. - tags (List[str]): List of tags associated with the session for grouping and filtering. - video (str, optional): URL to a video recording of the session. - host_env (dict, optional): Dictionary containing host and environment data. - config (Configuration): Configuration object containing settings for the session. - jwt (str, optional): JSON Web Token for authentication with the AgentOps API. - token_cost (Decimal): Running total of token costs for the session. - event_counts (dict): Counter for different types of events: - - llms: Number of LLM calls - - tools: Number of tool calls - - actions: Number of actions - - errors: Number of errors - - apis: Number of API calls - session_url (str, optional): URL to view the session in the AgentOps dashboard. - is_running (bool): Flag indicating if the session is currently active. - """ - - def __init__( - self, - session_id: UUID, - config: Configuration, - tags: Optional[List[str]] = None, - host_env: Optional[dict] = None, - ): - self.end_timestamp = None - self.end_state: Optional[str] = "Indeterminate" - self.session_id = session_id - self.init_timestamp = get_ISO_time() - self.tags: List[str] = tags or [] - self.video: Optional[str] = None - self.end_state_reason: Optional[str] = None - self.host_env = host_env - self.config = config - self.jwt = None - self._lock = threading.Lock() - self._end_session_lock = threading.Lock() - self.token_cost: Decimal = Decimal(0) - self._session_url: str = "" - self.event_counts = { - "llms": 0, - "tools": 0, - "actions": 0, - "errors": 0, - "apis": 0, - } - # self.session_url: Optional[str] = None - - # Start session first to get JWT - self.is_running = self._start_session() - if not self.is_running: - return - - # Initialize OTEL components with a more controlled processor - self._tracer_provider = TracerProvider() - self._otel_tracer = self._tracer_provider.get_tracer( - f"agentops.session.{str(session_id)}", - ) - self._otel_exporter = SessionExporter(session=self) - - # Use smaller batch size and shorter delay to reduce buffering - self._span_processor = BatchSpanProcessor( - self._otel_exporter, - max_queue_size=self.config.max_queue_size, - schedule_delay_millis=self.config.max_wait_time, - max_export_batch_size=min( - max(self.config.max_queue_size // 20, 1), - min(self.config.max_queue_size, 32), - ), - export_timeout_millis=20000, - ) - - self._tracer_provider.add_span_processor(self._span_processor) - - def set_video(self, video: str) -> None: - """ - Sets a url to the video recording of the session. - - Args: - video (str): The url of the video recording - """ - self.video = video - - def _flush_spans(self) -> bool: - """ - Flush pending spans for this specific session with timeout. - Returns True if flush was successful, False otherwise. - """ - if not hasattr(self, "_span_processor"): - return True - - try: - success = self._span_processor.force_flush(timeout_millis=self.config.max_wait_time) - if not success: - logger.warning("Failed to flush all spans before session end") - return success - except Exception as e: - logger.warning(f"Error flushing spans: {e}") - return False - - def end_session( - self, - end_state: str = "Indeterminate", - end_state_reason: Optional[str] = None, - video: Optional[str] = None, - ) -> Union[Decimal, None]: - with self._end_session_lock: - if not self.is_running: - return None - - if not any(end_state == state.value for state in EndState): - logger.warning("Invalid end_state. Please use one of the EndState enums") - return None - - try: - # Force flush any pending spans before ending session - if hasattr(self, "_span_processor"): - self._span_processor.force_flush(timeout_millis=5000) - - # 1. Set shutdown flag on exporter first - if hasattr(self, "_otel_exporter"): - self._otel_exporter.shutdown() - - # 2. Set session end state - self.end_timestamp = get_ISO_time() - self.end_state = end_state - self.end_state_reason = end_state_reason - if video is not None: - self.video = video - - # 3. Mark session as not running before cleanup - self.is_running = False - - # 4. Clean up OTEL components - if hasattr(self, "_span_processor"): - try: - # Force flush any pending spans - self._span_processor.force_flush(timeout_millis=5000) - # Shutdown the processor - self._span_processor.shutdown() - except Exception as e: - logger.warning(f"Error during span processor cleanup: {e}") - finally: - del self._span_processor - - # 5. Final session update - if not (analytics_stats := self.get_analytics()): - return None - - analytics = ( - f"Session Stats - " - f"{colored('Duration:', attrs=['bold'])} {analytics_stats['Duration']} | " - f"{colored('Cost:', attrs=['bold'])} ${analytics_stats['Cost']} | " - f"{colored('LLMs:', attrs=['bold'])} {analytics_stats['LLM calls']} | " - f"{colored('Tools:', attrs=['bold'])} {analytics_stats['Tool calls']} | " - f"{colored('Actions:', attrs=['bold'])} {analytics_stats['Actions']} | " - f"{colored('Errors:', attrs=['bold'])} {analytics_stats['Errors']}" - ) - logger.info(analytics) - - except Exception as e: - logger.exception(f"Error during session end: {e}") - finally: - active_sessions.remove(self) # First thing, get rid of the session - - logger.info( - colored( - f"\x1b[34mSession Replay: {self.session_url}\x1b[0m", - "blue", - ) - ) - return self.token_cost - - def add_tags(self, tags: List[str]) -> None: - """ - Append to session tags at runtime. - """ - if not self.is_running: - return - - if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)): - if isinstance(tags, str): - tags = [tags] - - # Initialize tags if None - if self.tags is None: - self.tags = [] - - # Add new tags that don't exist - for tag in tags: - if tag not in self.tags: - self.tags.append(tag) - - # Update session state immediately - self._update_session() - - def set_tags(self, tags): - """Set session tags, replacing any existing tags""" - if not self.is_running: - return - - if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)): - if isinstance(tags, str): - tags = [tags] - - # Set tags directly - self.tags = tags.copy() # Make a copy to avoid reference issues - - # Update session state immediately - self._update_session() - - def record(self, event: Union[Event, ErrorEvent], flush_now=False): - """Record an event using OpenTelemetry spans""" - if not self.is_running: - return - - # Ensure event has all required base attributes - if not hasattr(event, "id"): - event.id = uuid4() - if not hasattr(event, "init_timestamp"): - event.init_timestamp = get_ISO_time() - if not hasattr(event, "end_timestamp") or event.end_timestamp is None: - event.end_timestamp = get_ISO_time() - - # Create session context - token = set_value("session.id", str(self.session_id)) - - try: - token = attach(token) - - # Create a copy of event data to modify - event_data = dict(filter_unjsonable(event.__dict__)) - - # Add required fields based on event type - if isinstance(event, ErrorEvent): - event_data["error_type"] = getattr(event, "error_type", event.event_type) - elif event.event_type == "actions": - # Ensure action events have action_type - if "action_type" not in event_data: - event_data["action_type"] = event_data.get("name", "unknown_action") - if "name" not in event_data: - event_data["name"] = event_data.get("action_type", "unknown_action") - elif event.event_type == "tools": - # Ensure tool events have name - if "name" not in event_data: - event_data["name"] = event_data.get("tool_name", "unknown_tool") - if "tool_name" not in event_data: - event_data["tool_name"] = event_data.get("name", "unknown_tool") - - with self._otel_tracer.start_as_current_span( - name=event.event_type, - attributes={ - "event.id": str(event.id), - "event.type": event.event_type, - "event.timestamp": event.init_timestamp or get_ISO_time(), - "event.end_timestamp": event.end_timestamp or get_ISO_time(), - "session.id": str(self.session_id), - "session.tags": ",".join(self.tags) if self.tags else "", - "event.data": json.dumps(event_data), - }, - ) as span: - if event.event_type in self.event_counts: - self.event_counts[event.event_type] += 1 - - if isinstance(event, ErrorEvent): - span.set_attribute("error", True) - if hasattr(event, "trigger_event") and event.trigger_event: - span.set_attribute("trigger_event.id", str(event.trigger_event.id)) - span.set_attribute("trigger_event.type", event.trigger_event.event_type) - - if flush_now and hasattr(self, "_span_processor"): - self._span_processor.force_flush() - finally: - detach(token) - - def _send_event(self, event): - """Direct event sending for testing""" - try: - payload = { - "events": [ - { - "id": str(event.id), - "event_type": event.event_type, - "init_timestamp": event.init_timestamp, - "end_timestamp": event.end_timestamp, - "data": filter_unjsonable(event.__dict__), - } - ] - } - - HttpClient.post( - f"{self.config.endpoint}/v2/create_events", - json.dumps(payload).encode("utf-8"), - jwt=self.jwt, - ) - except Exception as e: - logger.error(f"Failed to send event: {e}") - - def _reauthorize_jwt(self) -> Union[str, None]: - with self._lock: - payload = {"session_id": self.session_id} - serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8") - res = HttpClient.post( - f"{self.config.endpoint}/v2/reauthorize_jwt", - serialized_payload, - self.config.api_key, - ) - - logger.debug(res.body) - - if res.code != 200: - return None - - jwt = res.body.get("jwt", None) - self.jwt = jwt - return jwt - - def _start_session(self): - with self._lock: - payload = {"session": self.__dict__} - serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8") - - try: - res = HttpClient.post( - f"{self.config.endpoint}/v2/create_session", - serialized_payload, - api_key=self.config.api_key, - parent_key=self.config.parent_key, - ) - except ApiServerException as e: - return logger.error(f"Could not start session - {e}") - - logger.debug(res.body) - - if res.code != 200: - return False - - jwt = res.body.get("jwt", None) - self.jwt = jwt - if jwt is None: - return False - - logger.info( - colored( - f"\x1b[34mSession Replay: {self.session_url}\x1b[0m", - "blue", - ) - ) - - return True - - def _update_session(self) -> None: - """Update session state on the server""" - if not self.is_running: - return - - # TODO: Determine whether we really need to lock here: are incoming calls coming from other threads? - with self._lock: - payload = {"session": self.__dict__} - - try: - res = HttpClient.post( - f"{self.config.endpoint}/v2/update_session", - json.dumps(filter_unjsonable(payload)).encode("utf-8"), - # self.config.api_key, - jwt=self.jwt, - ) - except ApiServerException as e: - return logger.error(f"Could not update session - {e}") - - def create_agent(self, name, agent_id): - if not self.is_running: - return - if agent_id is None: - agent_id = str(uuid4()) - - payload = { - "id": agent_id, - "name": name, - } - - serialized_payload = safe_serialize(payload).encode("utf-8") - try: - HttpClient.post( - f"{self.config.endpoint}/v2/create_agent", - serialized_payload, - api_key=self.config.api_key, - jwt=self.jwt, - ) - except ApiServerException as e: - return logger.error(f"Could not create agent - {e}") - - return agent_id - - def patch(self, func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - kwargs["session"] = self - return func(*args, **kwargs) - - return wrapper - - def _get_response(self) -> Optional[Response]: - payload = {"session": self.__dict__} - try: - response = HttpClient.post( - f"{self.config.endpoint}/v2/update_session", - json.dumps(filter_unjsonable(payload)).encode("utf-8"), - api_key=self.config.api_key, - jwt=self.jwt, - ) - except ApiServerException as e: - return logger.error(f"Could not end session - {e}") - - logger.debug(response.body) - return response - - def _format_duration(self, start_time, end_time) -> str: - start = datetime.fromisoformat(start_time.replace("Z", "+00:00")) - end = datetime.fromisoformat(end_time.replace("Z", "+00:00")) - duration = end - start - - hours, remainder = divmod(duration.total_seconds(), 3600) - minutes, seconds = divmod(remainder, 60) - - parts = [] - if hours > 0: - parts.append(f"{int(hours)}h") - if minutes > 0: - parts.append(f"{int(minutes)}m") - parts.append(f"{seconds:.1f}s") - - return " ".join(parts) - - def _get_token_cost(self, response: Response) -> Decimal: - token_cost = response.body.get("token_cost", "unknown") - if token_cost == "unknown" or token_cost is None: - return Decimal(0) - return Decimal(token_cost) - - def _format_token_cost(self, token_cost: Decimal) -> str: - return ( - "{:.2f}".format(token_cost) - if token_cost == 0 - else "{:.6f}".format(token_cost.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP)) - ) - - def get_analytics(self) -> Optional[Dict[str, Any]]: - if not self.end_timestamp: - self.end_timestamp = get_ISO_time() - - formatted_duration = self._format_duration(self.init_timestamp, self.end_timestamp) - - if (response := self._get_response()) is None: - return None - - self.token_cost = self._get_token_cost(response) - - return { - "LLM calls": self.event_counts["llms"], - "Tool calls": self.event_counts["tools"], - "Actions": self.event_counts["actions"], - "Errors": self.event_counts["errors"], - "Duration": formatted_duration, - "Cost": self._format_token_cost(self.token_cost), - } - - @property - def session_url(self) -> str: - """Returns the URL for this session in the AgentOps dashboard.""" - assert self.session_id, "Session ID is required to generate a session URL" - return f"https://app.agentops.ai/drilldown?session_id={self.session_id}" - - # @session_url.setter - # def session_url(self, url: str): - # pass - - -active_sessions: List[Session] = [] diff --git a/agentops/session/__init__.py b/agentops/session/__init__.py new file mode 100644 index 000000000..18f3ff3fe --- /dev/null +++ b/agentops/session/__init__.py @@ -0,0 +1,8 @@ +"""Session management module""" +from .session import Session +from .registry import get_active_sessions, add_session, remove_session + +# For backward compatibility +active_sessions = get_active_sessions() + +__all__ = ["Session", "active_sessions", "add_session", "remove_session"] diff --git a/agentops/session/api.py b/agentops/session/api.py new file mode 100644 index 000000000..ba8cf1aed --- /dev/null +++ b/agentops/session/api.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any, Tuple +from uuid import UUID + +from termcolor import colored + +from agentops.event import Event +from agentops.exceptions import ApiServerException +from agentops.helpers import filter_unjsonable, safe_serialize +from agentops.http_client import HttpClient, HttpStatus, Response +from agentops.log_config import logger + +if TYPE_CHECKING: + from agentops.session import Session + + +class SessionApi: + """Handles all API communication for sessions""" + + def __init__(self, session: "Session"): + self.session = session + + @property + def config(self): + return self.session.config + + def create_session(self) -> Tuple[bool, Optional[str]]: + """Create a new session, returns (success, jwt)""" + payload = {"session": dict(self.session)} + try: + res = self._post("/v2/create_session", payload, needs_api_key=True, needs_parent_key=True) + + jwt = res.body.get("jwt") + if not jwt: + return False, None + + return True, jwt + + except ApiServerException as e: + logger.error(f"Could not create session - {e}") + return False, None + + def update_session(self) -> Optional[Dict[str, Any]]: + """Update session state, returns response data if successful""" + payload = {"session": dict(self.session)} + try: + res = self._post("/v2/update_session", payload, needs_api_key=True) + return res.body + except ApiServerException as e: + logger.error(f"Could not update session - {e}") + return None + + def create_agent(self, name: str, agent_id: str) -> bool: + """Create a new agent, returns success""" + payload = { + "id": agent_id, + "name": name, + } + try: + self._post("/v2/create_agent", payload, needs_api_key=True) + return True + except ApiServerException as e: + logger.error(f"Could not create agent - {e}") + return False + + def create_events(self, events: List[Union[Event, dict]]) -> bool: + """Sends events to API""" + try: + res = self._post("/v2/create_events", {"events": events}, needs_api_key=True) + return res.status == HttpStatus.SUCCESS + except ApiServerException as e: + logger.error(f"Could not create events - {e}") + return False + + def _post( + self, endpoint: str, payload: Dict[str, Any], needs_api_key: bool = False, needs_parent_key: bool = False + ) -> Response: + """Helper for making POST requests""" + url = f"{self.config.endpoint}{endpoint}" + serialized = safe_serialize(payload).encode("utf-8") + + kwargs = {} + header = {} + + if needs_api_key: + # Add API key to both kwargs and header + kwargs["api_key"] = self.config.api_key + header["X-Agentops-Api-Key"] = self.config.api_key + + if needs_parent_key: + kwargs["parent_key"] = self.config.parent_key + + if self.session.jwt: + kwargs["jwt"] = self.session.jwt + + if hasattr(self.session, "session_id"): + header["X-Session-ID"] = str(self.session.session_id) + + if header: + kwargs["header"] = header + + return HttpClient.post(url, serialized, **kwargs) diff --git a/agentops/session/manager.py b/agentops/session/manager.py new file mode 100644 index 000000000..0354e5f41 --- /dev/null +++ b/agentops/session/manager.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import threading +from datetime import datetime +from decimal import Decimal +from typing import TYPE_CHECKING, Optional, Union, Dict, List + +from termcolor import colored +from agentops.enums import EndState +from agentops.helpers import get_ISO_time +from agentops.log_config import logger + +if TYPE_CHECKING: + from agentops.event import Event, ErrorEvent + from .session import Session + from .registry import add_session, remove_session + from .api import SessionApi + from .telemetry import SessionTelemetry + + +class SessionManager: + """Handles session lifecycle and state management""" + + def __init__(self, session: "Session"): + self._state = session + self._lock = threading.Lock() + self._end_session_lock = threading.Lock() + + # Import at runtime to avoid circular imports + from .registry import add_session, remove_session + + self._add_session = add_session + self._remove_session = remove_session + + # Initialize components + from .api import SessionApi + from .telemetry import SessionTelemetry + + self._api = SessionApi(self._state) + self._telemetry = SessionTelemetry(self._state) + + # Store reference on session for backward compatibility + self._state._api = self._api + self._state._telemetry = self._telemetry + self._state._otel_exporter = self._telemetry._exporter + + def start_session(self) -> bool: + """Start and initialize session""" + with self._lock: + if not self._state._api: + return False + + success, jwt = self._state._api.create_session() + if success: + self._state.jwt = jwt + self._add_session(self._state) + return success + + def create_agent(self, name: str, agent_id: Optional[str] = None) -> Optional[str]: + """Create a new agent""" + with self._lock: + if agent_id is None: + from uuid import uuid4 + + agent_id = str(uuid4()) + + if not self._state._api: + return None + + success = self._state._api.create_agent(name=name, agent_id=agent_id) + return agent_id if success else None + + def add_tags(self, tags: Union[str, List[str]]) -> None: + """Add tags to session""" + with self._lock: + if isinstance(tags, str): + if tags not in self._state.tags: + self._state.tags.append(tags) + elif isinstance(tags, list): + self._state.tags.extend(t for t in tags if t not in self._state.tags) + + if self._state._api: + self._state._api.update_session() + + def set_tags(self, tags: Union[str, List[str]]) -> None: + """Set session tags""" + with self._lock: + if isinstance(tags, str): + self._state.tags = [tags] + elif isinstance(tags, list): + self._state.tags = list(tags) + + if self._state._api: + self._state._api.update_session() + + def record_event(self, event: Union["Event", "ErrorEvent"], flush_now: bool = False) -> None: + """Update event counts and record event""" + with self._lock: + # Update counts + if event.event_type in self._state.event_counts: + self._state.event_counts[event.event_type] += 1 + + # Record via telemetry + if self._telemetry: + self._telemetry.record_event(event, flush_now) + + def end_session( + self, end_state: str, end_state_reason: Optional[str], video: Optional[str] + ) -> Union[Decimal, None]: + """End session and cleanup""" + with self._end_session_lock: + if not self._state.is_running: + return None + + try: + # Flush any pending telemetry + if self._telemetry: + self._telemetry.flush(timeout_millis=5000) + + self._state.end_timestamp = get_ISO_time() + self._state.end_state = end_state + self._state.end_state_reason = end_state_reason + self._state.video = video if video else self._state.video + self._state.is_running = False + + if analytics := self._get_analytics(): + self._log_analytics(analytics) + self._remove_session(self._state) + return self._state.token_cost + return None + except Exception as e: + logger.exception(f"Error ending session: {e}") + return None + + def _get_analytics(self) -> Optional[Dict[str, Union[int, str]]]: + """Get session analytics""" + if not self._state.end_timestamp: + self._state.end_timestamp = get_ISO_time() + + formatted_duration = self._format_duration(self._state.init_timestamp, self._state.end_timestamp) + + if not self._state._api: + return None + + response = self._state._api.update_session() + if not response: + return None + + # Update token cost from API response + if "token_cost" in response: + self._state.token_cost = Decimal(str(response["token_cost"])) + + return { + "LLM calls": self._state.event_counts["llms"], + "Tool calls": self._state.event_counts["tools"], + "Actions": self._state.event_counts["actions"], + "Errors": self._state.event_counts["errors"], + "Duration": formatted_duration, + "Cost": self._format_token_cost(self._state.token_cost), + } + + def _format_duration(self, start_time: str, end_time: str) -> str: + """Format duration between two timestamps""" + start = datetime.fromisoformat(start_time.replace("Z", "+00:00")) + end = datetime.fromisoformat(end_time.replace("Z", "+00:00")) + duration = end - start + + hours, remainder = divmod(duration.total_seconds(), 3600) + minutes, seconds = divmod(remainder, 60) + + parts = [] + if hours > 0: + parts.append(f"{int(hours)}h") + if minutes > 0: + parts.append(f"{int(minutes)}m") + parts.append(f"{seconds:.1f}s") + + return " ".join(parts) + + def _format_token_cost(self, token_cost: Decimal) -> str: + """Format token cost for display""" + # Always format with 6 decimal places for consistency with tests + return "{:.6f}".format(token_cost) + + def _log_analytics(self, stats: Dict[str, Union[int, str]]) -> None: + """Log analytics in a consistent format""" + analytics = ( + f"Session Stats - " + f"{colored('Duration:', attrs=['bold'])} {stats['Duration']} | " + f"{colored('Cost:', attrs=['bold'])} ${stats['Cost']} | " + f"{colored('LLMs:', attrs=['bold'])} {str(stats['LLM calls'])} | " + f"{colored('Tools:', attrs=['bold'])} {str(stats['Tool calls'])} | " + f"{colored('Actions:', attrs=['bold'])} {str(stats['Actions'])} | " + f"{colored('Errors:', attrs=['bold'])} {str(stats['Errors'])}" + ) + logger.info(analytics) diff --git a/agentops/session/registry.py b/agentops/session/registry.py new file mode 100644 index 000000000..5b62a7453 --- /dev/null +++ b/agentops/session/registry.py @@ -0,0 +1,24 @@ +"""Registry for tracking active sessions""" +from typing import List, TYPE_CHECKING + +if TYPE_CHECKING: + from .session import Session + +_active_sessions = [] # type: List["Session"] + + +def add_session(session: "Session") -> None: + """Add session to active sessions list""" + if session not in _active_sessions: + _active_sessions.append(session) + + +def remove_session(session: "Session") -> None: + """Remove session from active sessions list""" + if session in _active_sessions: + _active_sessions.remove(session) + + +def get_active_sessions() -> List["Session"]: + """Get list of active sessions""" + return _active_sessions diff --git a/agentops/session/session.py b/agentops/session/session.py new file mode 100644 index 000000000..81fd9110a --- /dev/null +++ b/agentops/session/session.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from decimal import Decimal +from typing import TYPE_CHECKING, Dict, List, Optional, Union +from uuid import UUID + +from agentops.config import Configuration +from agentops.enums import EndState +from agentops.helpers import get_ISO_time + +if TYPE_CHECKING: + from agentops.event import Event, ErrorEvent + + +@dataclass +class Session: + """Data container for session state with minimal public API""" + + session_id: UUID + config: Configuration + tags: List[str] = field(default_factory=list) + host_env: Optional[dict] = None + token_cost: Decimal = field(default_factory=lambda: Decimal(0)) + end_state: str = field(default_factory=lambda: EndState.INDETERMINATE.value) + end_state_reason: Optional[str] = None + end_timestamp: Optional[str] = None + jwt: Optional[str] = None + video: Optional[str] = None + event_counts: Dict[str, int] = field( + default_factory=lambda: {"llms": 0, "tools": 0, "actions": 0, "errors": 0, "apis": 0} + ) + init_timestamp: str = field(default_factory=get_ISO_time) + is_running: bool = field(default=True) + + def __post_init__(self): + """Initialize session manager""" + # Convert tags to list first + if isinstance(self.tags, (str, set)): + self.tags = list(self.tags) + elif self.tags is None: + self.tags = [] + + # Then initialize manager + from .manager import SessionManager + + self._manager = SessionManager(self) + self.is_running = self._manager.start_session() + + # Public API - All delegate to manager + def add_tags(self, tags: Union[str, List[str]]) -> None: + """Add tags to session""" + if self.is_running and self._manager: + self._manager.add_tags(tags) + + def set_tags(self, tags: Union[str, List[str]]) -> None: + """Set session tags""" + if self.is_running and self._manager: + self._manager.set_tags(tags) + + def record(self, event: Union["Event", "ErrorEvent"], flush_now: bool = False) -> None: + """Record an event""" + if self._manager: + self._manager.record_event(event, flush_now) + + def end_session( + self, + end_state: str = EndState.INDETERMINATE.value, + end_state_reason: Optional[str] = None, + video: Optional[str] = None, + ) -> Union[Decimal, None]: + """End the session""" + if self._manager: + return self._manager.end_session(end_state, end_state_reason, video) + return None + + def create_agent(self, name: str, agent_id: Optional[str] = None) -> Optional[str]: + """Create a new agent for this session""" + if self.is_running and self._manager: + return self._manager.create_agent(name, agent_id) + return None + + def get_analytics(self) -> Optional[Dict[str, str]]: + """Get session analytics""" + if self._manager: + return self._manager._get_analytics() + return None + + # Serialization support + def __iter__(self): + return iter(self.__dict__().items()) + + def __dict__(self): + filtered_dict = {k: v for k, v in asdict(self).items() if not k.startswith("_") and not callable(v)} + filtered_dict["session_id"] = str(self.session_id) + return filtered_dict + + @property + def session_url(self) -> str: + return f"https://app.agentops.ai/drilldown?session_id={self.session_id}" + + @property + def _tracer_provider(self): + """For testing compatibility""" + return self._telemetry._tracer_provider if self._telemetry else None diff --git a/agentops/session/telemetry.py b/agentops/session/telemetry.py new file mode 100644 index 000000000..ae98ffe50 --- /dev/null +++ b/agentops/session/telemetry.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Optional, Union +from uuid import UUID +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.context import attach, detach, set_value + +from agentops.helpers import get_ISO_time, filter_unjsonable + +if TYPE_CHECKING: + from agentops.session import Session + from agentops.event import Event, ErrorEvent + + +class SessionTelemetry: + """Handles telemetry setup and event recording""" + + def __init__(self, session: "Session"): + self.session = session + self._setup_telemetry() + + def _setup_telemetry(self): + """Initialize OpenTelemetry components""" + self._tracer_provider = TracerProvider() + self._otel_tracer = self._tracer_provider.get_tracer( + f"agentops.session.{str(self.session.session_id)}", + ) + + from agentops.telemetry.exporters.session import SessionExporter + + self._exporter = SessionExporter(session=self.session) + + # Configure batch processor + self._span_processor = BatchSpanProcessor( + self._exporter, + max_queue_size=self.session.config.max_queue_size, + schedule_delay_millis=self.session.config.max_wait_time, + max_export_batch_size=min( + max(self.session.config.max_queue_size // 20, 1), + min(self.session.config.max_queue_size, 32), + ), + export_timeout_millis=20000, + ) + + self._tracer_provider.add_span_processor(self._span_processor) + + def record_event(self, event: Union[Event, ErrorEvent], flush_now: bool = False) -> None: + """Record telemetry for an event""" + if not hasattr(self, "_otel_tracer"): + return + + # Create session context + token = set_value("session.id", str(self.session.session_id)) + try: + token = attach(token) + + # Filter out non-serializable data + event_data = filter_unjsonable(event.__dict__) + + with self._otel_tracer.start_as_current_span( + name=event.event_type, + attributes={ + "event.id": str(event.id), + "event.type": event.event_type, + "event.timestamp": event.init_timestamp or get_ISO_time(), + "event.end_timestamp": event.end_timestamp or get_ISO_time(), + "session.id": str(self.session.session_id), + "session.tags": ",".join(self.session.tags) if self.session.tags else "", + "event.data": json.dumps(event_data), + }, + ) as span: + if hasattr(event, "error_type"): + span.set_attribute("error", True) + if hasattr(event, "trigger_event") and event.trigger_event: + span.set_attribute("trigger_event.id", str(event.trigger_event.id)) + span.set_attribute("trigger_event.type", event.trigger_event.event_type) + + if flush_now and hasattr(self, "_span_processor"): + self._span_processor.force_flush() + finally: + detach(token) + + def flush(self, timeout_millis: Optional[int] = None) -> None: + """Force flush pending spans""" + if hasattr(self, "_span_processor"): + self._span_processor.force_flush(timeout_millis=timeout_millis) diff --git a/agentops/telemetry/exporters/session.py b/agentops/telemetry/exporters/session.py new file mode 100644 index 000000000..c363e01a8 --- /dev/null +++ b/agentops/telemetry/exporters/session.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import json +import threading +from typing import TYPE_CHECKING, Optional, Sequence +from uuid import UUID, uuid4 + +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult + +from agentops.helpers import filter_unjsonable, get_ISO_time +from agentops.http_client import HttpClient +from agentops.log_config import logger + +if TYPE_CHECKING: + from agentops.session import Session + + +class SessionExporter(SpanExporter): + """Manages publishing events for Session""" + + def __init__(self, session: Session, **kwargs): + self.session = session + self._shutdown = threading.Event() + self._export_lock = threading.Lock() + super().__init__(**kwargs) + + @property + def endpoint(self): + return f"{self.session.config.endpoint}/v2/create_events" + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + if self._shutdown.is_set(): + return SpanExportResult.SUCCESS + + with self._export_lock: + try: + if not spans: + return SpanExportResult.SUCCESS + + events = [] + for span in spans: + event_data = json.loads(span.attributes.get("event.data", "{}")) + + # Format event data based on event type + if span.name == "actions": + formatted_data = { + "action_type": event_data.get("action_type", event_data.get("name", "unknown_action")), + "params": event_data.get("params", {}), + "returns": event_data.get("returns"), + } + elif span.name == "tools": + formatted_data = { + "name": event_data.get("name", event_data.get("tool_name", "unknown_tool")), + "params": event_data.get("params", {}), + "returns": event_data.get("returns"), + } + else: + formatted_data = event_data + + formatted_data = {**event_data, **formatted_data} + + # Get timestamps and ID, providing defaults + init_timestamp = span.attributes.get("event.timestamp") or get_ISO_time() + end_timestamp = span.attributes.get("event.end_timestamp") or get_ISO_time() + event_id = span.attributes.get("event.id") or str(uuid4()) + + events.append( + filter_unjsonable( + { + "id": event_id, + "event_type": span.name, + "init_timestamp": init_timestamp, + "end_timestamp": end_timestamp, + **formatted_data, + "session_id": str(self.session.session_id), + } + ) + ) + + # Only make HTTP request if we have events and not shutdown + if events: + try: + res = HttpClient.post( + self.endpoint, + json.dumps({"events": events}).encode("utf-8"), + api_key=self.session.config.api_key, + jwt=self.session.jwt, + ) + return SpanExportResult.SUCCESS if res.code == 200 else SpanExportResult.FAILURE + except Exception as e: + logger.error(f"Failed to send events: {e}") + return SpanExportResult.FAILURE + + return SpanExportResult.SUCCESS + + except Exception as e: + logger.error(f"Failed to export spans: {e}") + return SpanExportResult.FAILURE + + def force_flush(self, timeout_millis: Optional[int] = None) -> bool: + return True + + def shutdown(self) -> None: + """Handle shutdown gracefully""" + self._shutdown.set() From f9b141f6e77ccae304ac34bf44007474d6c75f7b Mon Sep 17 00:00:00 2001 From: Teo Date: Thu, 9 Jan 2025 16:21:33 +0100 Subject: [PATCH 2/4] feat(session): add README for session package documentation --- agentops/session/README.md | 93 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 agentops/session/README.md diff --git a/agentops/session/README.md b/agentops/session/README.md new file mode 100644 index 000000000..05e5a7cc1 --- /dev/null +++ b/agentops/session/README.md @@ -0,0 +1,93 @@ +# Session Package + +This package contains the core session management functionality for AgentOps. + +## Architecture + +```mermaid +graph TD + S[Session] --> |delegates to| M[SessionManager] + M --> |uses| A[SessionApi] + M --> |uses| T[SessionTelemetry] + T --> |uses| E[SessionExporter] + M --> |manages| R[Registry] + R --> |tracks| S +``` + +## Component Responsibilities + +### Session (`session.py`) +- Data container for session state +- Provides public API for session operations +- Delegates all operations to SessionManager + +### SessionManager (`manager.py`) +- Handles session lifecycle and state management +- Coordinates between API, telemetry, and registry +- Manages session analytics and event counts + +### SessionApi (`api.py`) +- Handles all HTTP communication with AgentOps API +- Manages authentication headers and JWT +- Serializes session state for API calls + +### SessionTelemetry (`telemetry.py`) +- Sets up OpenTelemetry infrastructure +- Records events with proper context +- Manages event batching and flushing + +### SessionExporter (`../telemetry/exporters/session.py`) +- Exports OpenTelemetry spans as AgentOps events +- Handles event formatting and delivery +- Manages export batching and retries + +### Registry (`registry.py`) +- Tracks active sessions +- Provides global session access +- Maintains backward compatibility with old code + +## Data Flow + +```mermaid +sequenceDiagram + participant C as Client + participant S as Session + participant M as SessionManager + participant A as SessionApi + participant T as SessionTelemetry + participant E as SessionExporter + + C->>S: start_session() + S->>M: create() + M->>A: create_session() + A-->>M: jwt + M->>T: setup() + T->>E: init() + + C->>S: record(event) + S->>M: record_event() + M->>T: record_event() + T->>E: export() + E->>A: create_events() +``` + +## Usage Example + +```python +from agentops import Client + +# Create client +client = Client(api_key="your-key") + +# Start session +session = client.start_session(tags=["test"]) + +# Record events +session.record(some_event) + +# Add tags +session.add_tags(["new_tag"]) + +# End session +session.end_session(end_state="Success") +``` From 1a0740f2a4111f1a1b0fd5545483528c04158fb9 Mon Sep 17 00:00:00 2001 From: Pratyush Shukla Date: Thu, 9 Jan 2025 20:53:51 +0530 Subject: [PATCH 3/4] remove redundant code in ai21 --- agentops/llms/providers/ai21.py | 75 +++-------------- examples/ai21_examples/ai21_examples.ipynb | 94 ++-------------------- 2 files changed, 15 insertions(+), 154 deletions(-) diff --git a/agentops/llms/providers/ai21.py b/agentops/llms/providers/ai21.py index 8c907d525..91795040a 100644 --- a/agentops/llms/providers/ai21.py +++ b/agentops/llms/providers/ai21.py @@ -29,7 +29,6 @@ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Se from ai21.stream.async_stream import AsyncStream from ai21.models.chat.chat_completion_chunk import ChatCompletionChunk from ai21.models.chat.chat_completion_response import ChatCompletionResponse - from ai21.models.responses.answer_response import AnswerResponse llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) action_event = ActionEvent(init_timestamp=init_timestamp, params=kwargs) @@ -108,27 +107,15 @@ async def async_generator(): # Handle object responses try: - if isinstance(response, ChatCompletionResponse): - llm_event.returns = response - llm_event.agent_id = check_call_stack_for_agent_id() - llm_event.model = kwargs["model"] - llm_event.prompt = [message.model_dump() for message in kwargs["messages"]] - llm_event.prompt_tokens = response.usage.prompt_tokens - llm_event.completion = response.choices[0].message.model_dump() - llm_event.completion_tokens = response.usage.completion_tokens - llm_event.end_timestamp = get_ISO_time() - self._safe_record(session, llm_event) - - elif isinstance(response, AnswerResponse): - action_event.returns = response - action_event.agent_id = check_call_stack_for_agent_id() - action_event.action_type = "Contextual Answers" - action_event.logs = [ - {"context": kwargs["context"], "question": kwargs["question"]}, - response.model_dump() if response.model_dump() else None, - ] - action_event.end_timestamp = get_ISO_time() - self._safe_record(session, action_event) + llm_event.returns = response + llm_event.agent_id = check_call_stack_for_agent_id() + llm_event.model = kwargs["model"] + llm_event.prompt = [message.model_dump() for message in kwargs["messages"]] + llm_event.prompt_tokens = response.usage.prompt_tokens + llm_event.completion = response.choices[0].message.model_dump() + llm_event.completion_tokens = response.usage.completion_tokens + llm_event.end_timestamp = get_ISO_time() + self._safe_record(session, llm_event) except Exception as e: self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) @@ -145,8 +132,6 @@ async def async_generator(): def override(self): self._override_completion() self._override_completion_async() - self._override_answer() - self._override_answer_async() def _override_completion(self): from ai21.clients.studio.resources.chat import ChatCompletions @@ -184,42 +169,6 @@ async def patched_function(*args, **kwargs): # Override the original method with the patched one AsyncChatCompletions.create = patched_function - def _override_answer(self): - from ai21.clients.studio.resources.studio_answer import StudioAnswer - - global original_answer - original_answer = StudioAnswer.create - - def patched_function(*args, **kwargs): - # Call the original function with its original arguments - init_timestamp = get_ISO_time() - - session = kwargs.get("session", None) - if "session" in kwargs.keys(): - del kwargs["session"] - result = original_answer(*args, **kwargs) - return self.handle_response(result, kwargs, init_timestamp, session=session) - - StudioAnswer.create = patched_function - - def _override_answer_async(self): - from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer - - global original_answer_async - original_answer_async = AsyncStudioAnswer.create - - async def patched_function(*args, **kwargs): - # Call the original function with its original arguments - init_timestamp = get_ISO_time() - - session = kwargs.get("session", None) - if "session" in kwargs.keys(): - del kwargs["session"] - result = await original_answer_async(*args, **kwargs) - return self.handle_response(result, kwargs, init_timestamp, session=session) - - AsyncStudioAnswer.create = patched_function - def undo_override(self): if ( self.original_create is not None @@ -231,12 +180,6 @@ def undo_override(self): ChatCompletions, AsyncChatCompletions, ) - from ai21.clients.studio.resources.studio_answer import ( - StudioAnswer, - AsyncStudioAnswer, - ) ChatCompletions.create = self.original_create AsyncChatCompletions.create = self.original_create_async - StudioAnswer.create = self.original_answer - AsyncStudioAnswer.create = self.original_answer_async diff --git a/examples/ai21_examples/ai21_examples.ipynb b/examples/ai21_examples/ai21_examples.ipynb index 16afb4ad7..3f889f9ad 100644 --- a/examples/ai21_examples/ai21_examples.ipynb +++ b/examples/ai21_examples/ai21_examples.ipynb @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -54,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -226,88 +226,6 @@ "await main()" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Task-Specific Models Examples" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Contextual Answers" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The following example demonstrates the answering capability of AI21 without streaming." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "CONTEXT = \"\"\"\n", - "In 2020 and 2021, enormous QE — approximately $4.4 trillion, or 18%, of 2021 gross\n", - "domestic product (GDP) — and enormous fiscal stimulus (which has been and\n", - "always will be inflationary) — approximately $5 trillion, or 21%, of 2021 GDP\n", - "— stabilized markets and allowed companies to raise enormous amounts of\n", - "capital. In addition, this infusion of capital saved many small businesses and\n", - "put more than $2.5 trillion in the hands of consumers and almost $1 trillion into\n", - "state and local coffers. These actions led to a rapid decline in unemployment, \n", - "dropping from 15% to under 4% in 20 months — the magnitude and speed of which were both\n", - "unprecedented. Additionally, the economy grew 7% in 2021 despite the arrival of\n", - "the Delta and Omicron variants and the global supply chain shortages, which were\n", - "largely fueled by the dramatic upswing in consumer spending and the shift in\n", - "that spend from services to goods.\n", - "\"\"\"\n", - "response = client.answer.create(\n", - " context=CONTEXT,\n", - " question=\"Did the economy shrink after the Omicron variant arrived?\",\n", - ")\n", - "print(response.answer)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Similarly, we can use streaming to get the answer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "CONTEXT = \"\"\"\n", - "In the rapidly evolving field of Artificial Intelligence (AI), mathematical \n", - "foundations such as calculus, linear algebra, and statistics play a crucial role. \n", - "For instance, linear algebra is essential for understanding and developing machine \n", - "learning algorithms. It involves the study of vectors, matrices, and tensor operations \n", - "which are critical for performing transformations and optimizations. Additionally, \n", - "concepts from calculus like derivatives and integrals are used to optimize the \n", - "performance of AI models through gradient descent and other optimization techniques. \n", - "Statistics and probability form the backbone for making inferences and predictions, \n", - "enabling AI systems to learn from data and make decisions under uncertainty. \n", - "Understanding these mathematical principles allows for the development of more robust \n", - "and effective AI systems.\n", - "\"\"\"\n", - "response = client.answer.create(\n", - " context=CONTEXT,\n", - " question=\"Why is linear algebra important for machine learning algorithms?\",\n", - " stream=True,\n", - ")\n", - "print(response.answer)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -334,7 +252,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.10.16" } }, "nbformat": 4, From e7a80350219f5923a8ec6029d88457a52b4e2612 Mon Sep 17 00:00:00 2001 From: Pratyush Shukla Date: Thu, 9 Jan 2025 22:05:35 +0530 Subject: [PATCH 4/4] Revert "remove redundant code in ai21" This reverts commit 1a0740f2a4111f1a1b0fd5545483528c04158fb9. --- agentops/llms/providers/ai21.py | 75 ++++++++++++++--- examples/ai21_examples/ai21_examples.ipynb | 94 ++++++++++++++++++++-- 2 files changed, 154 insertions(+), 15 deletions(-) diff --git a/agentops/llms/providers/ai21.py b/agentops/llms/providers/ai21.py index 91795040a..8c907d525 100644 --- a/agentops/llms/providers/ai21.py +++ b/agentops/llms/providers/ai21.py @@ -29,6 +29,7 @@ def handle_response(self, response, kwargs, init_timestamp, session: Optional[Se from ai21.stream.async_stream import AsyncStream from ai21.models.chat.chat_completion_chunk import ChatCompletionChunk from ai21.models.chat.chat_completion_response import ChatCompletionResponse + from ai21.models.responses.answer_response import AnswerResponse llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) action_event = ActionEvent(init_timestamp=init_timestamp, params=kwargs) @@ -107,15 +108,27 @@ async def async_generator(): # Handle object responses try: - llm_event.returns = response - llm_event.agent_id = check_call_stack_for_agent_id() - llm_event.model = kwargs["model"] - llm_event.prompt = [message.model_dump() for message in kwargs["messages"]] - llm_event.prompt_tokens = response.usage.prompt_tokens - llm_event.completion = response.choices[0].message.model_dump() - llm_event.completion_tokens = response.usage.completion_tokens - llm_event.end_timestamp = get_ISO_time() - self._safe_record(session, llm_event) + if isinstance(response, ChatCompletionResponse): + llm_event.returns = response + llm_event.agent_id = check_call_stack_for_agent_id() + llm_event.model = kwargs["model"] + llm_event.prompt = [message.model_dump() for message in kwargs["messages"]] + llm_event.prompt_tokens = response.usage.prompt_tokens + llm_event.completion = response.choices[0].message.model_dump() + llm_event.completion_tokens = response.usage.completion_tokens + llm_event.end_timestamp = get_ISO_time() + self._safe_record(session, llm_event) + + elif isinstance(response, AnswerResponse): + action_event.returns = response + action_event.agent_id = check_call_stack_for_agent_id() + action_event.action_type = "Contextual Answers" + action_event.logs = [ + {"context": kwargs["context"], "question": kwargs["question"]}, + response.model_dump() if response.model_dump() else None, + ] + action_event.end_timestamp = get_ISO_time() + self._safe_record(session, action_event) except Exception as e: self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) @@ -132,6 +145,8 @@ async def async_generator(): def override(self): self._override_completion() self._override_completion_async() + self._override_answer() + self._override_answer_async() def _override_completion(self): from ai21.clients.studio.resources.chat import ChatCompletions @@ -169,6 +184,42 @@ async def patched_function(*args, **kwargs): # Override the original method with the patched one AsyncChatCompletions.create = patched_function + def _override_answer(self): + from ai21.clients.studio.resources.studio_answer import StudioAnswer + + global original_answer + original_answer = StudioAnswer.create + + def patched_function(*args, **kwargs): + # Call the original function with its original arguments + init_timestamp = get_ISO_time() + + session = kwargs.get("session", None) + if "session" in kwargs.keys(): + del kwargs["session"] + result = original_answer(*args, **kwargs) + return self.handle_response(result, kwargs, init_timestamp, session=session) + + StudioAnswer.create = patched_function + + def _override_answer_async(self): + from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer + + global original_answer_async + original_answer_async = AsyncStudioAnswer.create + + async def patched_function(*args, **kwargs): + # Call the original function with its original arguments + init_timestamp = get_ISO_time() + + session = kwargs.get("session", None) + if "session" in kwargs.keys(): + del kwargs["session"] + result = await original_answer_async(*args, **kwargs) + return self.handle_response(result, kwargs, init_timestamp, session=session) + + AsyncStudioAnswer.create = patched_function + def undo_override(self): if ( self.original_create is not None @@ -180,6 +231,12 @@ def undo_override(self): ChatCompletions, AsyncChatCompletions, ) + from ai21.clients.studio.resources.studio_answer import ( + StudioAnswer, + AsyncStudioAnswer, + ) ChatCompletions.create = self.original_create AsyncChatCompletions.create = self.original_create_async + StudioAnswer.create = self.original_answer + AsyncStudioAnswer.create = self.original_answer_async diff --git a/examples/ai21_examples/ai21_examples.ipynb b/examples/ai21_examples/ai21_examples.ipynb index 3f889f9ad..16afb4ad7 100644 --- a/examples/ai21_examples/ai21_examples.ipynb +++ b/examples/ai21_examples/ai21_examples.ipynb @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -54,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -171,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -226,6 +226,88 @@ "await main()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Task-Specific Models Examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Contextual Answers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following example demonstrates the answering capability of AI21 without streaming." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONTEXT = \"\"\"\n", + "In 2020 and 2021, enormous QE — approximately $4.4 trillion, or 18%, of 2021 gross\n", + "domestic product (GDP) — and enormous fiscal stimulus (which has been and\n", + "always will be inflationary) — approximately $5 trillion, or 21%, of 2021 GDP\n", + "— stabilized markets and allowed companies to raise enormous amounts of\n", + "capital. In addition, this infusion of capital saved many small businesses and\n", + "put more than $2.5 trillion in the hands of consumers and almost $1 trillion into\n", + "state and local coffers. These actions led to a rapid decline in unemployment, \n", + "dropping from 15% to under 4% in 20 months — the magnitude and speed of which were both\n", + "unprecedented. Additionally, the economy grew 7% in 2021 despite the arrival of\n", + "the Delta and Omicron variants and the global supply chain shortages, which were\n", + "largely fueled by the dramatic upswing in consumer spending and the shift in\n", + "that spend from services to goods.\n", + "\"\"\"\n", + "response = client.answer.create(\n", + " context=CONTEXT,\n", + " question=\"Did the economy shrink after the Omicron variant arrived?\",\n", + ")\n", + "print(response.answer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similarly, we can use streaming to get the answer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONTEXT = \"\"\"\n", + "In the rapidly evolving field of Artificial Intelligence (AI), mathematical \n", + "foundations such as calculus, linear algebra, and statistics play a crucial role. \n", + "For instance, linear algebra is essential for understanding and developing machine \n", + "learning algorithms. It involves the study of vectors, matrices, and tensor operations \n", + "which are critical for performing transformations and optimizations. Additionally, \n", + "concepts from calculus like derivatives and integrals are used to optimize the \n", + "performance of AI models through gradient descent and other optimization techniques. \n", + "Statistics and probability form the backbone for making inferences and predictions, \n", + "enabling AI systems to learn from data and make decisions under uncertainty. \n", + "Understanding these mathematical principles allows for the development of more robust \n", + "and effective AI systems.\n", + "\"\"\"\n", + "response = client.answer.create(\n", + " context=CONTEXT,\n", + " question=\"Why is linear algebra important for machine learning algorithms?\",\n", + " stream=True,\n", + ")\n", + "print(response.answer)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -252,7 +334,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.9.19" } }, "nbformat": 4,