From 4061d81d4ea15a43f70a68a1d95b679617179582 Mon Sep 17 00:00:00 2001 From: zhongzhiwei Date: Thu, 4 Sep 2025 21:19:56 +0800 Subject: [PATCH 1/2] feat: [Coda] implement Execute and ExecuteStreaming interfaces with streaming support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I54cf6b1895a94ee1c82d6384be711eeaaed4875f Co-Authored-By: Coda refactor: [Coda] 重构execute_prompt和aexecute_prompt方法参数,统一底层传参方式 Change-Id: Ib480cc24e18ac2d68a0abf964aa29599b683d88f Co-Authored-By: Coda feat: [Coda] refactor execute interfaces and enhance prompt entities - Remove ExecuteParam exposure from entities directory - Refactor _build_execute_request method parameters - Restructure object definitions in internal/prompt/openapi.py - Align object fields with cozeloop-go implementation - Add base64_data field to ContentPart and improve converter.py - Enhance execute and streaming functionality Change-Id: I34d8f03f6b9de8b170f789562e71f92089ed6cfc Co-Authored-By: Coda refactor: [Coda] merge duplicate _convert_to_result methods into unified converter (LogID: 20250904221721010091104016861EFDC) Co-Authored-By: Coda gitignore Change-Id: I9042daa98841f3cd99036ccddeac04aa34695fde feat: [Coda] 修复cozeloop-python流处理逻辑与Go版本保持一致 - 重构StreamProcessor类,统一数据前缀处理格式(data:) - 实现完整的错误处理机制,包括错误前缀检测和累积器 - 添加对event:行的忽略处理,与Go版本保持一致 - 重构process_lines和aprocess_lines方法,消除重复代码 - 移除[DONE]标记处理,统一流结束检测逻辑 - 增强错误响应解析和RemoteServiceError异常处理 Change-Id: I910ea7182168eb55326c9e13aedf953f8718b180 Co-Authored-By: Coda feat: [Coda] 创建ptaas示例目录,提供execute_prompt和aexecute_prompt的完整使用示例 (LogID: 2025090500044019216800101072170CF) Co-Authored-By: Coda fix: [Coda] 修复流式调用中httpx上下文管理器的headers访问错误 (LogID: 202509050038421921680010102185489) Co-Authored-By: Coda feat: [Coda] remove StreamResponseWrapper and simplify post_stream to return httpx.Response directly (LogID: 2025090500503319216800101012637F9) Co-Authored-By: Coda feat: [Coda] implement context manager for StreamReader to fix 
post_stream context issue (LogID: 20250905013341192168001010942E193) Co-Authored-By: Coda refactor: [Coda] simplify ExecuteStreamReader to only accept stream_context (LogID: 20250905013341192168001010942E193) Co-Authored-By: Coda feat: [Coda] implement StreamReader based on Fornax Stream design (LogID: 20250905020453192168001010730B720) Co-Authored-By: Coda refactor: [Coda] remove Context classes, implement context manager directly in StreamReader (LogID: 20250905020453192168001010730B720) Co-Authored-By: Coda refactor: [Coda] replace execute_prompt with SSE-based implementation (LogID: 20250905020453192168001010730B720) Co-Authored-By: Coda ptaas Change-Id: I3be6b974680ea03ecb387263f851d2b510715625 fix: [Coda] handle non-200 HTTP status in execute_stream post_stream (LogID: 20250905140202010091104016024A92C) Co-Authored-By: Coda fix: [Coda] restore stream architecture with delayed status check (LogID: 20250905140202010091104016024A92C) Co-Authored-By: Coda fix: [Coda] Correctly parse streaming errors to raise RemoteServiceError (LogID: 20250905140202010091104016024A92C) Co-Authored-By: Coda fix: [Coda] read response content before json() in stream error handling (LogID: 20250905140202010091104016024A92C) Co-Authored-By: Coda fix: [Coda] 修复ExecuteStreamReader close()方法AttributeError (LogID: 202509051616010100911040169551803) Co-Authored-By: Coda ptaas example Change-Id: I46e0392a3f4e47b0f435b6ee34c6d15a46425fad feat: [Coda] 重写PTaaS Python示例,参考Go版本保持场景和注释一致性 (LogID: 20250908180534010091104016071B631) Co-Authored-By: Coda fix: [Coda] 添加同步流式示例并将所有PTaaS示例注释改为英文 (LogID: 20250908180534010091104016071B631) Co-Authored-By: Coda fix: [Coda] 修复PTaaS异步流处理中的上下文管理器协议错误 (LogID: 2025090819310001009110401618550CC) Co-Authored-By: Coda refactor: [Coda] 清理未使用的异步流处理方法 (LogID: 2025090819310001009110401618550CC) Co-Authored-By: Coda feat: [Coda] add ConfigDict(use_enum_values=True) to models with enum fields (LogID: 20250908201801010091104016415D94E) Co-Authored-By: Coda ptaas example 
Change-Id: If358626bdfa18dfc0e09bf5758e2950bef7fd800 feat: [Coda] add timeout parameter support for execute_prompt and aexecute_prompt methods (LogID: 202509082041490100911040169244AFC) Co-Authored-By: Coda feat: [Coda] set default timeout to 10min for execute_prompt and aexecute_prompt (LogID: 202509082041490100911040169244AFC) Co-Authored-By: Coda ptaas example Change-Id: Ie8552a762bc1804e840132449869dee68a5e4b7f delete useless example Change-Id: I3f0a22ced554f372537a68ce8cb411239311f71f delete useless example Change-Id: Ibfd4f38295ab6f0635d3c9ab96579587f6d35a15 --- .gitignore | 3 +- cozeloop/__init__.py | 4 +- cozeloop/_client.py | 115 +++++++- cozeloop/_noop.py | 33 ++- cozeloop/entities/prompt.py | 45 ++- cozeloop/entities/stream.py | 36 +++ cozeloop/internal/consts/__init__.py | 3 +- cozeloop/internal/httpclient/client.py | 61 +++- cozeloop/internal/prompt/converter.py | 204 ++++++++++++- .../internal/prompt/execute_stream_reader.py | 225 ++++++++++++++ cozeloop/internal/prompt/openapi.py | 106 ++++++- cozeloop/internal/prompt/prompt.py | 191 +++++++++++- cozeloop/internal/stream/__init__.py | 11 + .../internal/stream/base_stream_reader.py | 274 ++++++++++++++++++ cozeloop/internal/stream/sse.py | 185 ++++++++++++ cozeloop/prompt.py | 55 +++- examples/prompt/ptaas/__init__.py | 19 ++ examples/prompt/ptaas/ptaas.py | 165 +++++++++++ examples/prompt/ptaas/ptaas_jinja.py | 83 ++++++ examples/prompt/ptaas/ptaas_multimodal.py | 131 +++++++++ .../ptaas/ptaas_placeholder_variable.py | 86 ++++++ examples/prompt/ptaas/ptaas_timeout.py | 66 +++++ examples/prompt/ptaas/ptaas_with_label.py | 79 +++++ 23 files changed, 2139 insertions(+), 41 deletions(-) create mode 100755 cozeloop/entities/stream.py create mode 100755 cozeloop/internal/prompt/execute_stream_reader.py create mode 100755 cozeloop/internal/stream/__init__.py create mode 100755 cozeloop/internal/stream/base_stream_reader.py create mode 100755 cozeloop/internal/stream/sse.py create mode 100755 
examples/prompt/ptaas/__init__.py create mode 100755 examples/prompt/ptaas/ptaas.py create mode 100755 examples/prompt/ptaas/ptaas_jinja.py create mode 100755 examples/prompt/ptaas/ptaas_multimodal.py create mode 100755 examples/prompt/ptaas/ptaas_placeholder_variable.py create mode 100755 examples/prompt/ptaas/ptaas_timeout.py create mode 100755 examples/prompt/ptaas/ptaas_with_label.py diff --git a/.gitignore b/.gitignore index 625b923..74c712b 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,5 @@ venv.bak/ /.idea /.vscode /output -dist/ \ No newline at end of file +dist/ +.coda/ \ No newline at end of file diff --git a/cozeloop/__init__.py b/cozeloop/__init__.py index c66aa9d..3be38e2 100644 --- a/cozeloop/__init__.py +++ b/cozeloop/__init__.py @@ -18,6 +18,8 @@ close, get_prompt, prompt_format, + execute_prompt, + aexecute_prompt, start_span, get_span_from_context, get_span_from_header, @@ -30,4 +32,4 @@ ENV_JWT_OAUTH_PUBLIC_KEY_ID ) -from .span import SpanContext, Span +from .span import SpanContext, Span \ No newline at end of file diff --git a/cozeloop/_client.py b/cozeloop/_client.py index a511492..657cb9d 100644 --- a/cozeloop/_client.py +++ b/cozeloop/_client.py @@ -7,13 +7,14 @@ import os import threading from datetime import datetime -from typing import Dict, Any, List, Optional, Callable +from typing import Dict, Any, List, Optional, Callable, Union import httpx from cozeloop.client import Client from cozeloop._noop import NOOP_SPAN, _NoopClient -from cozeloop.entities.prompt import Prompt, Message, PromptVariable +from cozeloop.entities.prompt import Prompt, Message, PromptVariable, ExecuteResult +from cozeloop.entities.stream import StreamReader from cozeloop.internal import consts, httpclient from cozeloop.internal.consts import ClientClosedError from cozeloop.internal.httpclient import Auth @@ -269,6 +270,62 @@ def prompt_format(self, prompt: Prompt, variables: Dict[str, PromptVariable]) -> raise ClientClosedError() return 
self._prompt_provider.prompt_format(prompt, variables) + def execute_prompt( + self, + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False, + timeout: Optional[int] = None + ) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + """ + 执行Prompt请求 + + :param timeout: 请求超时时间(秒),可选,默认为600秒(10分钟) + """ + if self._closed: + raise ClientClosedError() + return self._prompt_provider.execute_prompt( + prompt_key, + version=version, + label=label, + variable_vals=variable_vals, + messages=messages, + stream=stream, + timeout=timeout + ) + + async def aexecute_prompt( + self, + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False, + timeout: Optional[int] = None + ) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + """ + 异步执行Prompt请求 + + :param timeout: 请求超时时间(秒),可选,默认为600秒(10分钟) + """ + if self._closed: + raise ClientClosedError() + return await self._prompt_provider.aexecute_prompt( + prompt_key, + version=version, + label=label, + variable_vals=variable_vals, + messages=messages, + stream=stream, + timeout=timeout + ) + def start_span( self, name: str, @@ -368,6 +425,58 @@ def prompt_format(prompt: Prompt, variables: Dict[str, Any]) -> List[Message]: return get_default_client().prompt_format(prompt, variables) +def execute_prompt( + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False, + timeout: Optional[int] = None +) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + """ + 执行Prompt请求 + + :param timeout: 请求超时时间(秒),可选,默认为600秒(10分钟) + """ + return get_default_client().execute_prompt( + prompt_key, + version=version, + 
label=label, + variable_vals=variable_vals, + messages=messages, + stream=stream, + timeout=timeout + ) + + +async def aexecute_prompt( + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False, + timeout: Optional[int] = None +) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + """ + 异步执行Prompt请求 + + :param timeout: 请求超时时间(秒),可选,默认为600秒(10分钟) + """ + return await get_default_client().aexecute_prompt( + prompt_key, + version=version, + label=label, + variable_vals=variable_vals, + messages=messages, + stream=stream, + timeout=timeout + ) + + def start_span(name: str, span_type: str, *, start_time: Optional[int] = None, child_of: Optional[SpanContext] = None) -> Span: return get_default_client().start_span(name, span_type, start_time=start_time, child_of=child_of) @@ -382,4 +491,4 @@ def get_span_from_header(header: Dict[str, str]) -> SpanContext: def flush() -> None: - return get_default_client().flush() + return get_default_client().flush() \ No newline at end of file diff --git a/cozeloop/_noop.py b/cozeloop/_noop.py index 56d508d..b36e7d3 100644 --- a/cozeloop/_noop.py +++ b/cozeloop/_noop.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: MIT import logging -from typing import Dict, Optional, List +from typing import Dict, Optional, List, Union, Any from cozeloop.client import Client -from cozeloop.entities.prompt import Prompt, Message, PromptVariable +from cozeloop.entities.prompt import Prompt, Message, PromptVariable, ExecuteResult +from cozeloop.entities.stream import StreamReader from cozeloop.internal.trace.noop_span import NoopSpan from cozeloop.span import SpanContext, Span @@ -35,6 +36,32 @@ def prompt_format(self, prompt: Prompt, variables: Dict[str, PromptVariable]) -> logger.warning(f"Noop client not supported. 
{self.new_exception}") raise self.new_exception + def execute_prompt( + self, + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False + ) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + logger.warning(f"Noop client not supported. {self.new_exception}") + raise self.new_exception + + async def aexecute_prompt( + self, + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False + ) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + logger.warning(f"Noop client not supported. {self.new_exception}") + raise self.new_exception + def start_span(self, name: str, span_type: str, *, start_time: Optional[int] = None, child_of: Optional[SpanContext] = None, start_new_trace: bool = False) -> Span: logger.warning(f"Noop client not supported. {self.new_exception}") @@ -49,4 +76,4 @@ def get_span_from_header(self, header: Dict[str, str]) -> SpanContext: return NOOP_SPAN def flush(self) -> None: - logger.warning(f"Noop client not supported. {self.new_exception}") + logger.warning(f"Noop client not supported. 
{self.new_exception}") \ No newline at end of file diff --git a/cozeloop/entities/prompt.py b/cozeloop/entities/prompt.py index 157f82a..b076ba7 100644 --- a/cozeloop/entities/prompt.py +++ b/cozeloop/entities/prompt.py @@ -3,8 +3,8 @@ from enum import Enum from typing import List, Optional, Union - -from pydantic import BaseModel, Field, ConfigDict +from typing import List, Optional, Union, Dict, Any +from pydantic import BaseModel class TemplateType(str, Enum): @@ -47,6 +47,7 @@ class ToolChoiceType(str, Enum): class ContentType(str, Enum): TEXT = "text" IMAGE_URL = "image_url" + BASE64_DATA = "base64_data" MULTI_PART_VARIABLE = "multi_part_variable" @@ -54,12 +55,28 @@ class ContentPart(BaseModel): type: ContentType text: Optional[str] = None image_url: Optional[str] = None + base64_data: Optional[str] = None + + +class FunctionCall(BaseModel): + name: str + arguments: Optional[str] = None + + +class ToolCall(BaseModel): + index: int + id: str + type: ToolType + function_call: Optional[FunctionCall] = None class Message(BaseModel): role: Role + reasoning_content: Optional[str] = None content: Optional[str] = None parts: Optional[List[ContentPart]] = None + tool_call_id: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = None class VariableDef(BaseModel): @@ -109,5 +126,27 @@ class Prompt(BaseModel): llm_config: Optional[LLMConfig] = None +class ExecuteParam(BaseModel): + """Execute参数""" + prompt_key: str + version: str = "" + label: str = "" + variable_vals: Optional[Dict[str, Any]] = None + messages: Optional[List[Message]] = None + + +class TokenUsage(BaseModel): + """Token使用统计""" + input_tokens: int = 0 + output_tokens: int = 0 + + +class ExecuteResult(BaseModel): + """Execute结果""" + message: Optional[Message] = None + finish_reason: Optional[str] = None + usage: Optional[TokenUsage] = None + + MessageLikeObject = Union[Message, List[Message]] -PromptVariable = Union[str, MessageLikeObject] +PromptVariable = Union[str, MessageLikeObject] \ No 
newline at end of file diff --git a/cozeloop/entities/stream.py b/cozeloop/entities/stream.py new file mode 100755 index 0000000..4f7d6a1 --- /dev/null +++ b/cozeloop/entities/stream.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +from abc import ABC, abstractmethod +from typing import TypeVar, Generic, AsyncIterator, Iterator + +T = TypeVar('T') + + +class StreamReader(ABC, Generic[T]): + """流式读取器接口""" + + @abstractmethod + def __iter__(self) -> Iterator[T]: + """支持同步迭代 - for循环直接读取""" + pass + + @abstractmethod + def __next__(self) -> T: + """支持next()函数调用""" + pass + + @abstractmethod + def __aiter__(self) -> AsyncIterator[T]: + """支持异步迭代 - async for循环直接读取""" + pass + + @abstractmethod + async def __anext__(self) -> T: + """支持async next()调用""" + pass + + @abstractmethod + def close(self): + """关闭流""" + pass \ No newline at end of file diff --git a/cozeloop/internal/consts/__init__.py b/cozeloop/internal/consts/__init__.py index 2fa0c50..0736e1d 100644 --- a/cozeloop/internal/consts/__init__.py +++ b/cozeloop/internal/consts/__init__.py @@ -15,6 +15,7 @@ DEFAULT_PROMPT_CACHE_REFRESH_INTERVAL = 60 DEFAULT_TIMEOUT = 3 DEFAULT_UPLOAD_TIMEOUT = 30 +DEFAULT_PROMPT_EXECUTE_TIMEOUT = 600 # 10分钟,专用于execute_prompt和aexecute_prompt方法 LOG_ID_HEADER = "x-tt-logid" AUTHORIZE_HEADER = "Authorization" @@ -118,4 +119,4 @@ OUTPUT: MAX_BYTES_OF_ONE_TAG_VALUE_OF_INPUT_OUTPUT, } -BAGGAGE_SPECIAL_CHARS = {"=", ","} +BAGGAGE_SPECIAL_CHARS = {"=", ","} \ No newline at end of file diff --git a/cozeloop/internal/httpclient/client.py b/cozeloop/internal/httpclient/client.py index 8ee77af..4e36d64 100644 --- a/cozeloop/internal/httpclient/client.py +++ b/cozeloop/internal/httpclient/client.py @@ -3,7 +3,7 @@ import logging import os -from typing import Optional, Dict, Union, IO, Type, Tuple +from typing import Optional, Dict, Union, IO, Type, Tuple, Any import httpx import pydantic @@ -20,6 +20,7 @@ FileType = Tuple[str, 
FileContent] + class Client: def __init__( self, @@ -122,3 +123,61 @@ def upload_file( ) -> T: _file = {"file": (file_name, file)} return self.request(path, "POST", response_model, form=form, files=_file, timeout=self.upload_timeout) + + def post_stream( + self, + path: str, + json: Union[BaseModel, Dict] = None, + timeout: Optional[int] = None, + ): + """发起流式POST请求,返回stream_context""" + url = self._build_url(path) + headers = self._set_headers({"Content-Type": "application/json"}) + + if isinstance(json, BaseModel): + json = json.model_dump(by_alias=True) + + _timeout = timeout if timeout is not None else self.timeout + + try: + # 返回stream_context,让StreamReader管理上下文 + stream_context = self.http_client.stream( + "POST", + url, + json=json, + headers=headers, + timeout=_timeout + ) + return stream_context + except httpx.HTTPError as e: + logger.error(f"Http client stream request failed, path: {path}, err: {e}.") + raise consts.NetworkError from e + + async def apost_stream( + self, + path: str, + json: Union[BaseModel, Dict] = None, + timeout: Optional[int] = None, + ): + """发起异步流式POST请求,返回stream_context""" + url = self._build_url(path) + headers = self._set_headers({"Content-Type": "application/json"}) + + if isinstance(json, BaseModel): + json = json.model_dump(by_alias=True) + + _timeout = timeout if timeout is not None else self.timeout + + try: + # 返回stream_context,让StreamReader管理上下文 + stream_context = self.http_client.stream( + "POST", + url, + json=json, + headers=headers, + timeout=_timeout + ) + return stream_context + except httpx.HTTPError as e: + logger.error(f"Http client async stream request failed, path: {path}, err: {e}.") + raise consts.NetworkError from e \ No newline at end of file diff --git a/cozeloop/internal/prompt/converter.py b/cozeloop/internal/prompt/converter.py index 29bb38a..45d95b2 100644 --- a/cozeloop/internal/prompt/converter.py +++ b/cozeloop/internal/prompt/converter.py @@ -1,7 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates # SPDX-License-Identifier: MIT -from typing import List, Dict +from typing import List, Dict, Optional from cozeloop.spec.tracespec import PromptInput, PromptOutput, ModelMessage, PromptArgument, ModelMessagePart, \ ModelMessagePartType, ModelImageURL, PromptArgumentValueType @@ -22,6 +22,9 @@ PromptVariable, ContentType as EntityContentType, ContentPart as EntityContentPart, + ToolCall as EntityToolCall, + FunctionCall as EntityFunctionCall, + TokenUsage as EntityTokenUsage, ) from cozeloop.internal.prompt.openapi import ( @@ -40,10 +43,14 @@ TemplateType as OpenAPITemplateType, ContentType as OpenAPIContentType, ContentPart as OpenAPIContentPart, + ToolCall as OpenAPIToolCall, + FunctionCall as OpenAPIFunctionCall, + TokenUsage as OpenAPITokenUsage, ) def _convert_role(openapi_role: OpenAPIRole) -> EntityRole: + """转换角色类型""" role_mapping = { OpenAPIRole.SYSTEM: EntityRole.SYSTEM, OpenAPIRole.USER: EntityRole.USER, @@ -51,33 +58,64 @@ def _convert_role(openapi_role: OpenAPIRole) -> EntityRole: OpenAPIRole.TOOL: EntityRole.TOOL, OpenAPIRole.PLACEHOLDER: EntityRole.PLACEHOLDER } - return role_mapping.get(openapi_role, EntityRole.USER) # Default to USER type + return role_mapping.get(openapi_role, EntityRole.USER) def _convert_content_type(openapi_type: OpenAPIContentType) -> EntityContentType: + """转换内容类型""" content_type_mapping = { OpenAPIContentType.TEXT: EntityContentType.TEXT, + OpenAPIContentType.IMAGE_URL: EntityContentType.IMAGE_URL, + OpenAPIContentType.BASE64_DATA: EntityContentType.BASE64_DATA, OpenAPIContentType.MULTI_PART_VARIABLE: EntityContentType.MULTI_PART_VARIABLE, } return content_type_mapping.get(openapi_type, EntityContentType.TEXT) -def to_content_part(openapi_part: OpenAPIContentPart) -> EntityContentPart: +def _convert_content_part(openapi_part: OpenAPIContentPart) -> EntityContentPart: + """转换内容部分,确保text、image_url、base64_data字段都被转换""" return EntityContentPart( type=_convert_content_type(openapi_part.type), - 
text=openapi_part.text + text=openapi_part.text, + image_url=openapi_part.image_url, + base64_data=openapi_part.base64_data + ) + + +def _convert_function_call(func_call: Optional[OpenAPIFunctionCall]) -> Optional[EntityFunctionCall]: + """转换函数调用,确保name、arguments字段都被转换""" + if func_call is None: + return None + return EntityFunctionCall( + name=func_call.name, + arguments=func_call.arguments + ) + + +def _convert_tool_call(tool_call: OpenAPIToolCall) -> EntityToolCall: + """转换工具调用,确保index、id、type、function_call字段都被转换""" + return EntityToolCall( + index=tool_call.index, + id=tool_call.id, + type=_convert_tool_type(tool_call.type), + function_call=_convert_function_call(tool_call.function_call) ) def _convert_message(msg: OpenAPIMessage) -> EntityMessage: + """转换消息,确保role、content、reasoning_content、tool_call_id、tool_calls字段都被转换""" return EntityMessage( role=_convert_role(msg.role), + reasoning_content=msg.reasoning_content, content=msg.content, - parts=[to_content_part(part) for part in msg.parts] if msg.parts else None + parts=[_convert_content_part(part) for part in msg.parts] if msg.parts else None, + tool_call_id=msg.tool_call_id, + tool_calls=[_convert_tool_call(tool_call) for tool_call in msg.tool_calls] if msg.tool_calls else None ) def _convert_variable_type(openapi_type: OpenAPIVariableType) -> EntityVariableType: + """转换变量类型""" type_mapping = { OpenAPIVariableType.STRING: EntityVariableType.STRING, OpenAPIVariableType.PLACEHOLDER: EntityVariableType.PLACEHOLDER, @@ -92,10 +130,11 @@ def _convert_variable_type(openapi_type: OpenAPIVariableType) -> EntityVariableT OpenAPIVariableType.ARRAY_OBJECT: EntityVariableType.ARRAY_OBJECT, OpenAPIVariableType.MULTI_PART: EntityVariableType.MULTI_PART, } - return type_mapping.get(openapi_type, EntityVariableType.STRING) # Default to STRING type + return type_mapping.get(openapi_type, EntityVariableType.STRING) def _convert_variable_def(var_def: OpenAPIVariableDef) -> EntityVariableDef: + """转换变量定义""" return 
EntityVariableDef( key=var_def.key, desc=var_def.desc, @@ -104,6 +143,7 @@ def _convert_variable_def(var_def: OpenAPIVariableDef) -> EntityVariableDef: def _convert_function(func: OpenAPIFunction) -> EntityFunction: + """转换函数定义""" return EntityFunction( name=func.name, description=func.description, @@ -112,13 +152,15 @@ def _convert_function(func: OpenAPIFunction) -> EntityFunction: def _convert_tool_type(openapi_tool_type: OpenAPIToolType) -> EntityToolType: + """转换工具类型""" type_mapping = { OpenAPIToolType.FUNCTION: EntityToolType.FUNCTION, } - return type_mapping.get(openapi_tool_type, EntityToolType.FUNCTION) # Default to FUNCTION type + return type_mapping.get(openapi_tool_type, EntityToolType.FUNCTION) def _convert_tool(tool: OpenAPITool) -> EntityTool: + """转换工具定义""" return EntityTool( type=_convert_tool_type(tool.type), function=_convert_function(tool.function) if tool.function else None @@ -126,20 +168,23 @@ def _convert_tool(tool: OpenAPITool) -> EntityTool: def _convert_tool_choice_type(openapi_tool_choice_type: OpenAPIChoiceType) -> EntityToolChoiceType: + """转换工具选择类型""" choice_mapping = { OpenAPIChoiceType.AUTO: EntityToolChoiceType.AUTO, OpenAPIChoiceType.NONE: EntityToolChoiceType.NONE } - return choice_mapping.get(openapi_tool_choice_type, EntityToolChoiceType.AUTO) # Default to AUTO type + return choice_mapping.get(openapi_tool_choice_type, EntityToolChoiceType.AUTO) def _convert_tool_call_config(config: OpenAPIToolCallConfig) -> EntityToolCallConfig: + """转换工具调用配置""" return EntityToolCallConfig( tool_choice=_convert_tool_choice_type(config.tool_choice) ) def _convert_llm_config(config: OpenAPIModelConfig) -> EntityModelConfig: + """转换LLM配置""" return EntityModelConfig( temperature=config.temperature, max_tokens=config.max_tokens, @@ -152,14 +197,16 @@ def _convert_llm_config(config: OpenAPIModelConfig) -> EntityModelConfig: def _convert_template_type(openapi_template_type: OpenAPITemplateType) -> EntityTemplateType: + """转换模板类型""" template_mapping = 
{ OpenAPITemplateType.NORMAL: EntityTemplateType.NORMAL, OpenAPITemplateType.JINJA2: EntityTemplateType.JINJA2 } - return template_mapping.get(openapi_template_type, EntityTemplateType.NORMAL) # Default to NORMAL type + return template_mapping.get(openapi_template_type, EntityTemplateType.NORMAL) def _convert_prompt_template(template: OpenAPIPromptTemplate) -> EntityPromptTemplate: + """转换提示模板""" return EntityPromptTemplate( template_type=_convert_template_type(template.template_type), messages=[_convert_message(msg) for msg in template.messages] if template.messages else None, @@ -168,8 +215,18 @@ def _convert_prompt_template(template: OpenAPIPromptTemplate) -> EntityPromptTem ) +def _convert_token_usage(usage: Optional[OpenAPITokenUsage]) -> Optional[EntityTokenUsage]: + """转换Token使用统计,确保input_tokens、output_tokens字段都被转换""" + if usage is None: + return None + return EntityTokenUsage( + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens + ) + + def _convert_prompt(prompt: OpenAPIPrompt) -> EntityPrompt: - """Convert OpenAPI Prompt object to entity Prompt object""" + """转换OpenAPI Prompt对象到entity Prompt对象""" return EntityPrompt( workspace_id=prompt.workspace_id, prompt_key=prompt.prompt_key, @@ -181,7 +238,125 @@ def _convert_prompt(prompt: OpenAPIPrompt) -> EntityPrompt: ) +# 公开的转换函数 +def to_content_part(openapi_part: OpenAPIContentPart) -> EntityContentPart: + """公开的内容部分转换函数""" + return _convert_content_part(openapi_part) + + +def to_prompt(openapi_prompt: OpenAPIPrompt) -> EntityPrompt: + """公开的提示转换函数""" + return _convert_prompt(openapi_prompt) + + +def to_message(openapi_message: OpenAPIMessage) -> EntityMessage: + """公开的消息转换函数""" + return _convert_message(openapi_message) + + +def to_token_usage(openapi_usage: Optional[OpenAPITokenUsage]) -> Optional[EntityTokenUsage]: + """公开的Token使用统计转换函数""" + return _convert_token_usage(openapi_usage) + + +def convert_execute_data_to_result(data) -> 'ExecuteResult': + """将ExecuteData转换为ExecuteResult + + 
统一的转换入口,复用现有转换逻辑 + 用于替代 prompt.py 和 reader.py 中的重复实现 + + Args: + data: ExecuteData对象,包含执行结果数据 + + Returns: + ExecuteResult: 转换后的执行结果对象 + """ + from cozeloop.entities.prompt import ExecuteResult + + return ExecuteResult( + message=to_message(data.message) if data.message else None, + finish_reason=data.finish_reason, + usage=to_token_usage(data.usage) + ) + + +def to_openapi_message(message: EntityMessage) -> OpenAPIMessage: + """将EntityMessage转换为OpenAPIMessage""" + return OpenAPIMessage( + role=_to_openapi_role(message.role), + reasoning_content=message.reasoning_content, + content=message.content, + parts=[_to_openapi_content_part(part) for part in message.parts] if message.parts else None, + tool_call_id=message.tool_call_id, + tool_calls=[_to_openapi_tool_call(tool_call) for tool_call in + message.tool_calls] if message.tool_calls else None + ) + + +def _to_openapi_role(role: EntityRole) -> OpenAPIRole: + """将EntityRole转换为OpenAPIRole""" + role_mapping = { + EntityRole.SYSTEM: OpenAPIRole.SYSTEM, + EntityRole.USER: OpenAPIRole.USER, + EntityRole.ASSISTANT: OpenAPIRole.ASSISTANT, + EntityRole.TOOL: OpenAPIRole.TOOL, + EntityRole.PLACEHOLDER: OpenAPIRole.PLACEHOLDER + } + return role_mapping.get(role, OpenAPIRole.USER) + + +def _to_openapi_content_part(part: EntityContentPart) -> OpenAPIContentPart: + """将EntityContentPart转换为OpenAPIContentPart""" + return OpenAPIContentPart( + type=_to_openapi_content_type(part.type), + text=part.text, + image_url=part.image_url, + base64_data=part.base64_data + ) + + +def _to_openapi_content_type(content_type: EntityContentType) -> OpenAPIContentType: + """将EntityContentType转换为OpenAPIContentType""" + type_mapping = { + EntityContentType.TEXT: OpenAPIContentType.TEXT, + EntityContentType.IMAGE_URL: OpenAPIContentType.IMAGE_URL, + EntityContentType.BASE64_DATA: OpenAPIContentType.BASE64_DATA, + EntityContentType.MULTI_PART_VARIABLE: OpenAPIContentType.MULTI_PART_VARIABLE + } + return type_mapping.get(content_type, 
OpenAPIContentType.TEXT) + + +def _to_openapi_tool_call(tool_call: EntityToolCall) -> OpenAPIToolCall: + """将EntityToolCall转换为OpenAPIToolCall""" + return OpenAPIToolCall( + index=tool_call.index, + id=tool_call.id, + type=_to_openapi_tool_type(tool_call.type), + function_call=_to_openapi_function_call(tool_call.function_call) + ) + + +def _to_openapi_function_call(func_call: Optional[EntityFunctionCall]) -> Optional[OpenAPIFunctionCall]: + """将EntityFunctionCall转换为OpenAPIFunctionCall""" + if func_call is None: + return None + return OpenAPIFunctionCall( + name=func_call.name, + arguments=func_call.arguments + ) + + +def _to_openapi_tool_type(tool_type: EntityToolType) -> OpenAPIToolType: + """将EntityToolType转换为OpenAPIToolType""" + type_mapping = { + EntityToolType.FUNCTION: OpenAPIToolType.FUNCTION, + } + return type_mapping.get(tool_type, OpenAPIToolType.FUNCTION) + + +# Span相关转换函数 def _to_span_prompt_input(messages: List[EntityMessage], variables: Dict[str, PromptVariable]) -> PromptInput: + """转换到Span的提示输入""" return PromptInput( templates=_to_span_messages(messages), arguments=_to_span_arguments(variables), @@ -189,12 +364,14 @@ def _to_span_prompt_input(messages: List[EntityMessage], variables: Dict[str, Pr def _to_span_prompt_output(messages: List[EntityMessage]) -> PromptOutput: + """转换到Span的提示输出""" return PromptOutput( prompts=_to_span_messages(messages) ) def _to_span_messages(messages: List[EntityMessage]) -> List[ModelMessage]: + """转换消息列表到Span格式""" return [ ModelMessage( role=msg.role, @@ -205,14 +382,17 @@ def _to_span_messages(messages: List[EntityMessage]) -> List[ModelMessage]: def _to_span_arguments(arguments: Dict[str, PromptVariable]) -> List[PromptArgument]: + """转换参数字典到Span格式""" return [ to_span_argument(key, value) for key, value in arguments.items() ] def to_span_argument(key: str, value: any) -> PromptArgument: + """转换单个参数到Span格式""" converted_value = str(value) value_type = PromptArgumentValueType.TEXT + # 判断是否是多模态变量 if isinstance(value, 
list) and all(isinstance(part, EntityContentPart) for part in value): value_type = PromptArgumentValueType.MODEL_MESSAGE_PART @@ -232,20 +412,24 @@ def to_span_argument(key: str, value: any) -> PromptArgument: def _to_span_content_type(entity_type: EntityContentType) -> ModelMessagePartType: + """转换内容类型到Span格式""" span_content_type_mapping = { EntityContentType.TEXT: ModelMessagePartType.TEXT, EntityContentType.IMAGE_URL: ModelMessagePartType.IMAGE, + EntityContentType.BASE64_DATA: ModelMessagePartType.IMAGE, EntityContentType.MULTI_PART_VARIABLE: ModelMessagePartType.MULTI_PART_VARIABLE, } return span_content_type_mapping.get(entity_type, ModelMessagePartType.TEXT) def _to_span_content_part(entity_part: EntityContentPart) -> ModelMessagePart: + """转换内容部分到Span格式""" image_url = None if entity_part.image_url is not None: image_url = ModelImageURL( url=entity_part.image_url ) + return ModelMessagePart( type=_to_span_content_type(entity_part.type), text=entity_part.text, diff --git a/cozeloop/internal/prompt/execute_stream_reader.py b/cozeloop/internal/prompt/execute_stream_reader.py new file mode 100755 index 0000000..4227d90 --- /dev/null +++ b/cozeloop/internal/prompt/execute_stream_reader.py @@ -0,0 +1,225 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from cozeloop.entities.prompt import ExecuteResult +from cozeloop.internal.consts.error import RemoteServiceError +from cozeloop.internal.prompt.converter import convert_execute_data_to_result +from cozeloop.internal.prompt.openapi import ExecuteData +from cozeloop.internal.stream.base_stream_reader import BaseStreamReader +from cozeloop.internal.stream.sse import ServerSentEvent + +logger = logging.getLogger(__name__) + + +class ExecuteStreamReader(BaseStreamReader[ExecuteResult]): + """ + Prompt执行结果的StreamReader实现 + + 继承自BaseStreamReader,实现具体的SSE数据解析逻辑 + 将SSE事件中的数据解析为ExecuteResult对象 + 支持同步和异步迭代器模式,提供完整的流式处理能力 + 直接实现上下文管理器,无需单独的Context类 + """ + + def __init__(self, stream_context, log_id: str = ""): + """ + 初始化ExecuteStreamReader + + Args: + stream_context: 流上下文管理器 + log_id: 日志ID,用于错误追踪 + """ + self._stream_context = stream_context + self._response = None + self._context_entered = False + self.log_id = log_id + self._closed = False + # 不调用super().__init__,因为还没有response对象 + + def _parse_sse_data(self, sse: ServerSentEvent) -> Optional[ExecuteResult]: + """ + 解析SSE数据为ExecuteResult对象 + + Args: + sse: ServerSentEvent对象 + + Returns: + Optional[ExecuteResult]: 解析后的ExecuteResult对象,如果不需要返回则为None + """ + # 跳过空数据 + if not sse.data or sse.data.strip() == "": + return None + + # 跳过非data事件 + if sse.event and sse.event != "data": + logger.debug(f"Skipping non-data event: {sse.event}") + return None + + try: + # 解析JSON数据 + data_dict = sse.json() + + # 验证数据结构 + if not isinstance(data_dict, dict): + logger.warning(f"Invalid SSE data format, expected dict, got {type(data_dict)}") + return None + + # 将字典转换为ExecuteData对象 + execute_data = ExecuteData.model_validate(data_dict) + + # 转换为ExecuteResult + result = convert_execute_data_to_result(execute_data) + + logger.debug(f"Successfully parsed SSE data to ExecuteResult: {result}") + 
return result + + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse SSE data as JSON: {e}, data: {sse.data}") + return None + except ValueError as e: + logger.warning(f"Failed to validate ExecuteData: {e}, data: {sse.data}") + return None + except Exception as e: + logger.error(f"Unexpected error parsing SSE data: {e}, data: {sse.data}") + return None + + def __enter__(self): + """同步上下文管理器入口""" + if not self._context_entered: + self._response = self._stream_context.__enter__() # 检查HTTP状态码 + if self._response.status_code != 200: + try: + # 先读取完整响应内容 + self._response.read() + + # 现在可以安全调用json() + error_data = self._response.json() + log_id = self._response.headers.get("x-tt-logid", "") + error_code = error_data.get('code', 0) + error_msg = error_data.get('msg', 'Unknown error') + # 确保关闭stream_context + self._stream_context.__exit__(None, None, None) + raise RemoteServiceError(self._response.status_code, error_code, error_msg, log_id) + except Exception as e: + self._stream_context.__exit__(None, None, None) + if isinstance(e, RemoteServiceError): + raise + from cozeloop.internal.consts.error import InternalError + raise InternalError(f"Failed to parse error response: {e}") + + # 初始化BaseStreamReader的属性 + super().__init__(self._response, self.log_id) + self._context_entered = True + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """同步上下文管理器出口""" + self.close() + if self._context_entered: + return self._stream_context.__exit__(exc_type, exc_val, exc_tb) + + async def __aenter__(self): + """异步上下文管理器入口""" + if not self._context_entered: + self._response = self._stream_context.__enter__() # 检查HTTP状态码(同步版本逻辑) + if self._response.status_code != 200: + try: + # 先读取完整响应内容 + await self._response.aread() + + # 现在可以安全调用json() + error_data = self._response.json() + log_id = self._response.headers.get("x-tt-logid", "") + error_code = error_data.get('code', 0) + error_msg = error_data.get('msg', 'Unknown error') + 
self._stream_context.__exit__(None, None, None) + raise RemoteServiceError(self._response.status_code, error_code, error_msg, log_id) + except Exception as e: + self._stream_context.__exit__(None, None, None) + if isinstance(e, RemoteServiceError): + raise + from cozeloop.internal.consts.error import InternalError + raise InternalError(f"Failed to parse error response: {e}") + + # 初始化BaseStreamReader的属性 + super().__init__(self._response, self.log_id) + self._context_entered = True + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """异步上下文管理器出口""" + await self.aclose() + if self._context_entered: + return self._stream_context.__exit__(exc_type, exc_val, exc_tb) + + def __iter__(self): + """支持for循环直接读取""" + if not self._context_entered: + self.__enter__() + return super().__iter__() + + def __aiter__(self): + """支持async for循环直接读取""" + # 注意:异步版本需要特殊处理 + return self._aiter_impl() + + async def _aiter_impl(self): + """异步迭代器实现""" + if not self._context_entered: + await self.__aenter__() + async for item in super().__aiter__(): + yield item + + def close(self) -> None: + """关闭流""" + self._closed = True + # 如果还没有进入上下文,直接关闭stream_context + if not self._context_entered: + if hasattr(self._stream_context, '__exit__'): + try: + self._stream_context.__exit__(None, None, None) + except Exception: + pass + return + + # 如果已经进入上下文,调用父类的close方法 + if hasattr(self, 'response'): + super().close() + else: + # 如果response属性不存在,只关闭stream_context + if hasattr(self._stream_context, '__exit__'): + try: + self._stream_context.__exit__(None, None, None) + except Exception: + pass + + async def aclose(self) -> None: + """异步关闭流""" + self._closed = True + # 如果还没有进入上下文,直接关闭stream_context + if not self._context_entered: + if hasattr(self._stream_context, '__exit__'): + try: + self._stream_context.__exit__(None, None, None) + except Exception: + pass + return + + # 如果已经进入上下文,调用父类的aclose方法 + if hasattr(self, 'response'): + await super().aclose() + else: + # 
如果response属性不存在,只关闭stream_context + if hasattr(self._stream_context, '__exit__'): + try: + self._stream_context.__exit__(None, None, None) + except Exception: + pass \ No newline at end of file diff --git a/cozeloop/internal/prompt/openapi.py b/cozeloop/internal/prompt/openapi.py index 12b64ea..17da6ab 100644 --- a/cozeloop/internal/prompt/openapi.py +++ b/cozeloop/internal/prompt/openapi.py @@ -5,11 +5,13 @@ from typing import List, Optional import pydantic -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from cozeloop.internal.httpclient import Client, BaseResponse MPULL_PROMPT_PATH = "/v1/loop/prompts/mget" +EXECUTE_PROMPT_PATH = "/v1/loop/prompts/execute" +EXECUTE_STREAMING_PROMPT_PATH = "/v1/loop/prompts/execute_streaming" MAX_PROMPT_QUERY_BATCH_SIZE = 25 @@ -52,21 +54,48 @@ class ToolChoiceType(str, Enum): class ContentType(str, Enum): TEXT = "text" + IMAGE_URL = "image_url" + BASE64_DATA = "base64_data" MULTI_PART_VARIABLE = "multi_part_variable" class ContentPart(BaseModel): + model_config = ConfigDict(use_enum_values=True) + type: ContentType text: Optional[str] = None + image_url: Optional[str] = None + base64_data: Optional[str] = None + + +class FunctionCall(BaseModel): + name: str + arguments: Optional[str] = None + + +class ToolCall(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + index: int + id: str + type: ToolType + function_call: Optional[FunctionCall] = None class Message(BaseModel): + model_config = ConfigDict(use_enum_values=True) + role: Role + reasoning_content: Optional[str] = None content: Optional[str] = None parts: Optional[List[ContentPart]] = None + tool_call_id: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = None class VariableDef(BaseModel): + model_config = ConfigDict(use_enum_values=True) + key: str desc: str type: VariableType @@ -79,10 +108,14 @@ class Function(BaseModel): class Tool(BaseModel): + model_config = ConfigDict(use_enum_values=True) + type: ToolType function: 
Optional[Function] = None - - +class ToolCallConfig(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + tool_choice: ToolChoiceType class ToolCallConfig(BaseModel): tool_choice: ToolChoiceType @@ -98,6 +131,8 @@ class LLMConfig(BaseModel): class PromptTemplate(BaseModel): + model_config = ConfigDict(use_enum_values=True) + template_type: TemplateType messages: Optional[List[Message]] = None variable_defs: Optional[List[VariableDef]] = None @@ -137,6 +172,37 @@ class MPullPromptResponse(BaseResponse): data: Optional[PromptResultData] = None +# Execute相关数据结构 +class VariableVal(BaseModel): + key: str + value: Optional[str] = None + placeholder_messages: Optional[List[Message]] = None + multi_part_values: Optional[List[ContentPart]] = None + + +class ExecuteRequest(BaseModel): + workspace_id: str + prompt_identifier: Optional[PromptQuery] = None + variable_vals: Optional[List[VariableVal]] = None + messages: Optional[List[Message]] = None + + +# HTTP接口专用的Token使用统计(字段名对齐HTTP接口) +class TokenUsage(BaseModel): + input_tokens: int = 0 + output_tokens: int = 0 + + +class ExecuteData(BaseModel): + message: Optional[Message] = None + finish_reason: Optional[str] = None + usage: Optional[TokenUsage] = None + + +class ExecuteResponse(BaseResponse): + data: Optional[ExecuteData] = None + + class OpenAPIClient: def __init__(self, http_client: Client): self.http_client = http_client @@ -173,3 +239,37 @@ def _do_mpull_prompt(self, workspace_id: str, queries: List[PromptQuery]) -> Opt real_resp = MPullPromptResponse.model_validate(response) if real_resp.data is not None: return real_resp.data.items + + def execute(self, request: ExecuteRequest, timeout: Optional[int] = None) -> ExecuteData: + """执行Prompt请求""" + response = self.http_client.request( + EXECUTE_PROMPT_PATH, + "POST", + ExecuteResponse, + json=request, + timeout=timeout + ) + if response.data is None: + raise ValueError("Execute response data is None") + return response.data + + def execute_streaming(self, 
request: ExecuteRequest, timeout: Optional[int] = None): + """流式执行Prompt请求""" + return self.http_client.post_stream(EXECUTE_STREAMING_PROMPT_PATH, request, timeout=timeout) + + async def aexecute(self, request: ExecuteRequest, timeout: Optional[int] = None) -> ExecuteData: + """异步执行Prompt请求""" + response = self.http_client.request( + EXECUTE_PROMPT_PATH, + "POST", + ExecuteResponse, + json=request, + timeout=timeout + ) + if response.data is None: + raise ValueError("Execute response data is None") + return response.data + + async def aexecute_streaming(self, request: ExecuteRequest, timeout: Optional[int] = None): + """异步流式执行Prompt请求""" + return await self.http_client.apost_stream(EXECUTE_STREAMING_PROMPT_PATH, request, timeout=timeout) \ No newline at end of file diff --git a/cozeloop/internal/prompt/prompt.py b/cozeloop/internal/prompt/prompt.py index b88f697..2699ed4 100644 --- a/cozeloop/internal/prompt/prompt.py +++ b/cozeloop/internal/prompt/prompt.py @@ -2,23 +2,27 @@ # SPDX-License-Identifier: MIT import json -from typing import Dict, Any, List, Optional -import pydantic +from typing import Dict, Any, List, Optional, Union +import pydantic from jinja2 import BaseLoader, Undefined from jinja2.sandbox import SandboxedEnvironment from jinja2.utils import missing, object_type_repr from cozeloop.entities.prompt import (Prompt, Message, VariableDef, VariableType, TemplateType, Role, - PromptVariable, ContentPart, ContentType) + PromptVariable, ContentPart, ContentType, ExecuteResult) +from cozeloop.entities.stream import StreamReader from cozeloop.internal import consts from cozeloop.internal.consts.error import RemoteServiceError from cozeloop.internal.httpclient.client import Client from cozeloop.internal.prompt.cache import PromptCache -from cozeloop.internal.prompt.converter import _convert_prompt, _to_span_prompt_input, _to_span_prompt_output -from cozeloop.internal.prompt.openapi import OpenAPIClient, PromptQuery +from cozeloop.internal.prompt.converter 
import _convert_prompt, _to_span_prompt_input, _to_span_prompt_output, \ + convert_execute_data_to_result, to_openapi_message +from cozeloop.internal.prompt.execute_stream_reader import ExecuteStreamReader +from cozeloop.internal.prompt.openapi import OpenAPIClient, PromptQuery, ExecuteRequest, VariableVal from cozeloop.internal.trace.trace import TraceProvider -from cozeloop.spec.tracespec import PROMPT_KEY, INPUT, PROMPT_VERSION, V_SCENE_PROMPT_TEMPLATE, V_SCENE_PROMPT_HUB, PROMPT_LABEL +from cozeloop.spec.tracespec import PROMPT_KEY, INPUT, PROMPT_VERSION, V_SCENE_PROMPT_TEMPLATE, V_SCENE_PROMPT_HUB, \ + PROMPT_LABEL class PromptProvider: @@ -82,7 +86,8 @@ def _get_prompt(self, prompt_key: str, version: str, label: str = '') -> Optiona prompt = self.cache.get(prompt_key, version, label) # If not in cache, fetch from server and cache it if prompt is None: - result = self.openapi_client.mpull_prompt(self.workspace_id, [PromptQuery(prompt_key=prompt_key, version=version, label=label)]) + result = self.openapi_client.mpull_prompt(self.workspace_id, [ + PromptQuery(prompt_key=prompt_key, version=version, label=label)]) if result: prompt = _convert_prompt(result[0].prompt) self.cache.set(prompt_key, version, label, prompt) @@ -169,7 +174,8 @@ def _validate_variable_values_type(self, variable_defs: List[VariableDef], varia if not isinstance(val, str): raise ValueError(f"type of variable '{var_def.key}' should be string") elif var_def.type == VariableType.PLACEHOLDER: - if not (isinstance(val, Message) or (isinstance(val, List) and all(isinstance(item, Message) for item in val))): + if not (isinstance(val, Message) or ( + isinstance(val, List) and all(isinstance(item, Message) for item in val))): raise ValueError(f"type of variable '{var_def.key}' should be Message like object") elif var_def.type == VariableType.BOOLEAN: if not isinstance(val, bool): @@ -265,7 +271,7 @@ def format_multi_part( vardef = def_map[multi_part_key] value = val_map[multi_part_key] if vardef is 
not None and value is not None and vardef.type == VariableType.MULTI_PART: - formatted_parts.extend(value) + formatted_parts.extend(value) else: formatted_parts.append(part) @@ -298,7 +304,6 @@ def _format_placeholder_messages( return expanded_messages - def _render_text_content( self, template_type: TemplateType, @@ -326,7 +331,6 @@ def _render_text_content( else: raise ValueError(f"text render unsupported template type: {template_type}") - def _render_jinja2_template(self, template_str: str, variable_def_map: Dict[str, VariableDef], variables: Dict[str, Any]) -> str: """渲染 Jinja2 模板""" @@ -335,6 +339,168 @@ def _render_jinja2_template(self, template_str: str, variable_def_map: Dict[str, render_vars = {k: variables[k] for k in variable_def_map.keys() if variables is not None and k in variables} return template.render(**render_vars) + def execute_prompt( + self, + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False, + timeout: Optional[int] = None + ) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + """ + 执行Prompt请求 + + 使用基于SSE解码器的PromptStreamReader提供更好的流式处理性能和错误处理能力 + + Args: + prompt_key: Prompt标识符 + version: Prompt版本,可选 + label: Prompt标签,可选 + variable_vals: 变量值字典,可选 + messages: 消息列表,可选 + stream: 是否使用流式处理 + timeout: 请求超时时间(秒),可选,默认为600秒(10分钟) + + Returns: + Union[ExecuteResult, StreamReader[ExecuteResult]]: + 如果stream=False,返回ExecuteResult + 如果stream=True,返回PromptStreamReader实例(支持上下文管理器) + """ + # 设置默认超时时间为600秒(10分钟) + actual_timeout = timeout if timeout is not None else consts.DEFAULT_PROMPT_EXECUTE_TIMEOUT + + # 验证timeout参数 + self._validate_timeout(actual_timeout) + + request = self._build_execute_request( + prompt_key=prompt_key, + version=version or "", + label=label or "", + variable_vals=variable_vals, + messages=messages + ) + + if stream: + stream_context = 
self.openapi_client.execute_streaming(request, timeout=actual_timeout) + reader = ExecuteStreamReader(stream_context) + return reader + else: + data = self.openapi_client.execute(request, timeout=actual_timeout) + return convert_execute_data_to_result(data) + + async def aexecute_prompt( + self, + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False, + timeout: Optional[int] = None + ) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + """ + 异步执行Prompt请求 + + 使用基于SSE解码器的PromptStreamReader提供更好的流式处理性能和错误处理能力 + + Args: + prompt_key: Prompt标识符 + version: Prompt版本,可选 + label: Prompt标签,可选 + variable_vals: 变量值字典,可选 + messages: 消息列表,可选 + stream: 是否使用流式处理 + timeout: 请求超时时间(秒),可选,默认为600秒(10分钟) + + Returns: + Union[ExecuteResult, StreamReader[ExecuteResult]]: + 如果stream=False,返回ExecuteResult + 如果stream=True,返回PromptStreamReader实例(支持异步上下文管理器) + """ + # 设置默认超时时间为600秒(10分钟) + actual_timeout = timeout if timeout is not None else consts.DEFAULT_PROMPT_EXECUTE_TIMEOUT + + # 验证timeout参数 + self._validate_timeout(actual_timeout) + + request = self._build_execute_request( + prompt_key=prompt_key, + version=version or "", + label=label or "", + variable_vals=variable_vals, + messages=messages + ) + + if stream: + stream_context = await self.openapi_client.aexecute_streaming(request, timeout=actual_timeout) + reader = ExecuteStreamReader(stream_context) + return reader + else: + data = await self.openapi_client.aexecute(request, timeout=actual_timeout) + return convert_execute_data_to_result(data) + + def _build_execute_request( + self, + prompt_key: str, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None + ) -> ExecuteRequest: + """构建执行请求""" + # 构建prompt_identifier + prompt_identifier = PromptQuery( + prompt_key=prompt_key, + 
version=version if version else None, + label=label if label else None + ) + + # 构建variable_vals + variable_vals_list = None + if variable_vals: + variable_vals_list = [] + for key, value in variable_vals.items(): + var_val = VariableVal(key=key) + + if isinstance(value, str): + var_val.value = value + elif isinstance(value, Message): + var_val.placeholder_messages = [value] + elif isinstance(value, ContentPart): + var_val.multi_part_values = [value] + elif isinstance(value, list): + if all(isinstance(item, Message) for item in value): + var_val.placeholder_messages = value + elif all(isinstance(item, ContentPart) for item in value): + var_val.multi_part_values = value + else: + # 对于其他类型的list,转换为JSON字符串 + var_val.value = json.dumps(value) + else: + # 对于其他类型,转换为JSON字符串 + var_val.value = json.dumps(value) + + variable_vals_list.append(var_val) + + return ExecuteRequest( + workspace_id=self.workspace_id, + prompt_identifier=prompt_identifier, + variable_vals=variable_vals_list, + messages=[to_openapi_message(msg) for msg in messages] if messages else None, + ) + + def _validate_timeout(self, timeout: int) -> None: + """验证超时参数""" + if not isinstance(timeout, int): + raise ValueError("timeout must be an integer") + if timeout <= 0: + raise ValueError("timeout must be greater than 0") + + class CustomUndefined(Undefined): __slots__ = () @@ -351,5 +517,4 @@ def __str__(self) -> str: f"[{self._undefined_name!r}]" ) - return f"{{{{{message}}}}}" - + return f"{{{{{message}}}}}" \ No newline at end of file diff --git a/cozeloop/internal/stream/__init__.py b/cozeloop/internal/stream/__init__.py new file mode 100755 index 0000000..b8d42a0 --- /dev/null +++ b/cozeloop/internal/stream/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: MIT + +from .base_stream_reader import BaseStreamReader +from .sse import SSEDecoder, ServerSentEvent + +__all__ = [ + "BaseStreamReader", + "SSEDecoder", + "ServerSentEvent", +] \ No newline at end of file diff --git a/cozeloop/internal/stream/base_stream_reader.py b/cozeloop/internal/stream/base_stream_reader.py new file mode 100755 index 0000000..c71dbf4 --- /dev/null +++ b/cozeloop/internal/stream/base_stream_reader.py @@ -0,0 +1,274 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import TypeVar, Generic, Iterator, AsyncIterator, Optional, Any +import json + +import httpx + +from cozeloop.entities.stream import StreamReader +from cozeloop.internal.stream.sse import SSEDecoder, ServerSentEvent +from cozeloop.internal.consts.error import RemoteServiceError, InternalError + +T = TypeVar('T') + +logger = logging.getLogger(__name__) + + +class BaseStreamReader(StreamReader[T], ABC, Generic[T]): + """ + 通用StreamReader基类 + + 基于Fornax的Stream设计模式,集成SSEDecoder进行SSE数据解码 + 支持同步和异步迭代器模式,实现上下文管理器 + 提供统一的错误处理机制和资源管理 + """ + + def __init__(self, response: httpx.Response, log_id: str = ""): + """ + 初始化BaseStreamReader + + Args: + response: httpx响应对象 + log_id: 日志ID,用于错误追踪 + """ + self.response = response + self.log_id = log_id + self._decoder = SSEDecoder() + self._closed = False + self._sync_iterator: Optional[Iterator[T]] = None + self._async_iterator: Optional[AsyncIterator[T]] = None + + @abstractmethod + def _parse_sse_data(self, sse: ServerSentEvent) -> Optional[T]: + """ + 解析SSE数据为业务对象,子类必须实现 + + Args: + sse: ServerSentEvent对象 + + Returns: + Optional[T]: 解析后的业务对象,如果不需要返回则为None + """ + pass + + def _iter_events(self) -> Iterator[ServerSentEvent]: + """ + 迭代SSE事件 + + Yields: + ServerSentEvent: 解码后的SSE事件 + """ + try: + for sse in 
self._decoder.iter_bytes(self.response.iter_bytes()): + yield sse + except Exception as e: + logger.error(f"Error iterating SSE events: {e}") + raise InternalError(f"Failed to decode SSE stream: {e}") + + async def _aiter_events(self) -> AsyncIterator[ServerSentEvent]: + """ + 异步迭代SSE事件 + + Yields: + ServerSentEvent: 解码后的SSE事件 + """ + try: + # 由于httpx.stream()返回的是同步流,即使在异步上下文中也需要使用同步迭代 + # 将同步迭代包装成异步生成器 + for sse in self._decoder.iter_bytes(self.response.iter_bytes()): + yield sse + except Exception as e: + logger.error(f"Error async iterating SSE events: {e}") + raise InternalError(f"Failed to decode SSE stream: {e}") + + def _handle_sse_error(self, sse: ServerSentEvent) -> None: + """ + 处理SSE事件中的错误 + + Args: + sse: ServerSentEvent对象 + + Raises: + RemoteServiceError: 当检测到错误事件时 + """ + if not sse.data: + return + + try: + data = sse.json() + + # 检查是否包含错误信息 + if isinstance(data, dict): + # 检查错误码字段 + if 'code' in data and data['code'] != 0: + error_code = data.get('code', 0) + error_msg = data.get('msg', 'Unknown error') + raise RemoteServiceError(200, error_code, error_msg, self.log_id) + + # 检查error字段 + if 'error' in data: + error_info = data['error'] + if isinstance(error_info, dict): + error_code = error_info.get('code', 0) + error_msg = error_info.get('message', 'Unknown error') + else: + error_code = 0 + error_msg = str(error_info) + raise RemoteServiceError(200, error_code, error_msg, self.log_id) + + except json.JSONDecodeError: + # 如果不是JSON格式,忽略错误检查 + pass + except RemoteServiceError: + # 重新抛出RemoteServiceError + raise + except Exception as e: + logger.warning(f"Error checking SSE error: {e}") + + def __stream__(self) -> Iterator[T]: + """ + 核心流处理逻辑 + + Yields: + T: 解析后的业务对象 + """ + if self._closed: + return + + try: + for sse in self._iter_events(): + if self._closed: + break + + # 检查错误 + self._handle_sse_error(sse) + + # 解析数据 + result = self._parse_sse_data(sse) + if result is not None: + yield result + + except RemoteServiceError: + raise + except 
Exception as e: + logger.error(f"Error in stream processing: {e}") + raise InternalError(f"Stream processing failed: {e}") + finally: + self._closed = True + + async def __astream__(self) -> AsyncIterator[T]: + """ + 异步核心流处理逻辑 + + Yields: + T: 解析后的业务对象 + """ + if self._closed: + return + + try: + async for sse in self._aiter_events(): + if self._closed: + break + + # 检查错误 + self._handle_sse_error(sse) + + # 解析数据 + result = self._parse_sse_data(sse) + if result is not None: + yield result + + except RemoteServiceError: + raise + except Exception as e: + logger.error(f"Error in async stream processing: {e}") + raise InternalError(f"Async stream processing failed: {e}") + finally: + self._closed = True + + # 同步迭代器接口 + def __iter__(self) -> Iterator[T]: + """支持同步迭代 - for循环直接读取""" + if self._sync_iterator is None: + self._sync_iterator = self.__stream__() + return self._sync_iterator + + def __next__(self) -> T: + """支持next()函数调用""" + if self._closed: + raise StopIteration("Stream is closed") + + try: + if self._sync_iterator is None: + self._sync_iterator = self.__stream__() + return next(self._sync_iterator) + except StopIteration: + self._closed = True + raise + except Exception as e: + self._closed = True + raise StopIteration from e + + # 异步迭代器接口 + def __aiter__(self) -> AsyncIterator[T]: + """支持异步迭代 - async for循环直接读取""" + if self._async_iterator is None: + self._async_iterator = self.__astream__() + return self._async_iterator + + async def __anext__(self) -> T: + """支持async next()调用""" + if self._closed: + raise StopAsyncIteration("Stream is closed") + + try: + if self._async_iterator is None: + self._async_iterator = self.__astream__() + return await self._async_iterator.__anext__() + except StopAsyncIteration: + self._closed = True + raise + except Exception as e: + self._closed = True + raise StopAsyncIteration from e + + # 上下文管理器接口 + def __enter__(self) -> BaseStreamReader[T]: + """同步上下文管理器入口""" + return self + + def __exit__(self, exc_type: Any, exc_val: 
Any, exc_tb: Any) -> None: + """同步上下文管理器出口""" + self.close() + + async def __aenter__(self) -> BaseStreamReader[T]: + """异步上下文管理器入口""" + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """异步上下文管理器出口""" + await self.aclose() + + # 资源管理 + def close(self) -> None: + """关闭流""" + self._closed = True + if hasattr(self.response, 'close'): + self.response.close() + + async def aclose(self) -> None: + """异步关闭流""" + self._closed = True + if hasattr(self.response, 'aclose'): + await self.response.aclose() + + @property + def closed(self) -> bool: + """检查流是否已关闭""" + return self._closed \ No newline at end of file diff --git a/cozeloop/internal/stream/sse.py b/cozeloop/internal/stream/sse.py new file mode 100755 index 0000000..babb74f --- /dev/null +++ b/cozeloop/internal/stream/sse.py @@ -0,0 +1,185 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import json +from typing import Any, Iterator, Optional + + +class ServerSentEvent: + """ + Server-Sent Event (SSE) 数据结构 + + 封装SSE事件的各个字段:event, data, id, retry + 提供JSON解析功能 + """ + + def __init__( + self, + *, + event: str | None = None, + data: str | None = None, + id: str | None = None, + retry: int | None = None, + ) -> None: + """ + 初始化ServerSentEvent + + Args: + event: 事件类型 + data: 事件数据 + id: 事件ID + retry: 重试间隔(毫秒) + """ + if data is None: + data = "" + + self._id = id + self._data = data + self._event = event or None + self._retry = retry + + @property + def event(self) -> str | None: + """获取事件类型""" + return self._event + + @property + def id(self) -> str | None: + """获取事件ID""" + return self._id + + @property + def retry(self) -> int | None: + """获取重试间隔""" + return self._retry + + @property + def data(self) -> str: + """获取事件数据""" + return self._data + + def json(self) -> Any: + """ + 将data字段解析为JSON对象 + + Returns: + 解析后的JSON对象 + + Raises: + json.JSONDecodeError: 当data不是有效的JSON时 + """ + return 
json.loads(self.data) + + def __repr__(self) -> str: + return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})" + + +class SSEDecoder: + """ + Server-Sent Event (SSE) 解码器 + + 负责将字节流解码为ServerSentEvent对象 + 支持SSE协议的完整规范,包括多行数据累积和各种字段处理 + """ + + def __init__(self) -> None: + """初始化SSE解码器""" + self._event: Optional[str] = None + self._data: list[str] = [] + self._last_event_id: Optional[str] = None + self._retry: Optional[int] = None + + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """ + 同步解码字节流为SSE事件 + + Args: + iterator: 字节流迭代器 + + Yields: + ServerSentEvent: 解码后的SSE事件 + """ + for chunk in self._iter_chunks(iterator): + # 先分割再解码,确保splitlines()只使用\r和\n + for raw_line in chunk.splitlines(): + line = raw_line.decode("utf-8") + sse = self.decode(line) + if sse: + yield sse + + def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]: + """ + 同步处理字节块,确保完整的SSE消息 + + Args: + iterator: 字节流迭代器 + + Yields: + bytes: 完整的SSE消息块 + """ + data = b"" + for chunk in iterator: + for line in chunk.splitlines(keepends=True): + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data + + def decode(self, line: str) -> Optional[ServerSentEvent]: + """ + 解码单行SSE数据 + + Args: + line: SSE数据行 + + Returns: + Optional[ServerSentEvent]: 解码后的SSE事件,如果未完成则返回None + """ + if not line: + # 空行表示事件结束,构造SSE事件 + if not self._event and not self._data and not self._last_event_id and self._retry is None: + return None + + sse = ServerSentEvent( + event=self._event, + data="\n".join(self._data), + id=self._last_event_id, + retry=self._retry, + ) + + # 重置状态,准备下一个事件 + self._event = None + self._data = [] + self._retry = None + + return sse + + # 解析字段 + fieldname, _, value = line.partition(":") + + # 去掉值前面的空格 + if value.startswith(" "): + value = value[1:] + + # 处理各种字段 + if fieldname == "event": + self._event = value + elif fieldname == "data": + self._data.append(value) 
+ elif fieldname == "id": + # 根据SSE规范,id字段不能包含null字符 + if "\0" not in value: + self._last_event_id = value + elif fieldname == "retry": + try: + self._retry = int(value) + except (TypeError, ValueError): + # 忽略无效的retry值 + pass + # 其他字段被忽略 + + return None diff --git a/cozeloop/prompt.py b/cozeloop/prompt.py index 085b6f2..74b3ba7 100644 --- a/cozeloop/prompt.py +++ b/cozeloop/prompt.py @@ -2,9 +2,10 @@ # SPDX-License-Identifier: MIT from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union, Any -from cozeloop.entities.prompt import Prompt, Message, PromptVariable +from cozeloop.entities.prompt import Prompt, Message, PromptVariable, ExecuteResult +from cozeloop.entities.stream import StreamReader class PromptClient(ABC): @@ -36,3 +37,53 @@ def prompt_format( :param variables: A dictionary of variables to use when formatting the prompt. :return: A list of formatted messages (`entity.Message`) if successful, or None. """ + + @abstractmethod + def execute_prompt( + self, + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False, + timeout: Optional[int] = None + ) -> Union[ExecuteResult, StreamReader[ExecuteResult]]: + """ + 执行Prompt请求 + + :param prompt_key: prompt的唯一标识 + :param version: prompt版本,可选 + :param label: prompt标签,可选 + :param variable_vals: 变量值字典,可选 + :param messages: 消息列表,可选 + :param stream: 是否流式返回,默认False + :param timeout: 请求超时时间(秒),可选,默认为600秒(10分钟) + :return: stream=False时返回ExecuteResult,stream=True时返回StreamReader[ExecuteResult] + """ + + @abstractmethod + async def aexecute_prompt( + self, + prompt_key: str, + *, + version: Optional[str] = None, + label: Optional[str] = None, + variable_vals: Optional[Dict[str, Any]] = None, + messages: Optional[List[Message]] = None, + stream: bool = False, + timeout: Optional[int] = None + ) -> 
Union[ExecuteResult, StreamReader[ExecuteResult]]: + """ + 异步执行Prompt请求 + + :param prompt_key: prompt的唯一标识 + :param version: prompt版本,可选 + :param label: prompt标签,可选 + :param variable_vals: 变量值字典,可选 + :param messages: 消息列表,可选 + :param stream: 是否流式返回,默认False + :param timeout: 请求超时时间(秒),可选,默认为600秒(10分钟) + :return: stream=False时返回ExecuteResult,stream=True时返回StreamReader[ExecuteResult] + """ \ No newline at end of file diff --git a/examples/prompt/ptaas/__init__.py b/examples/prompt/ptaas/__init__.py new file mode 100755 index 0000000..22759db --- /dev/null +++ b/examples/prompt/ptaas/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +PTaaS (Prompt Template as a Service) 示例 + +本包包含了PTaaS功能的各种使用示例,包括: +- 基础示例:同步非流式、异步非流式、异步流式调用 +- 高级示例:占位符变量、标签使用、Jinja2模板、超时控制、多模态处理 +""" + +__all__ = [ + "ptaas", + "ptaas_placeholder_variable", + "ptaas_with_label", + "ptaas_jinja", + "ptaas_timeout", + "ptaas_multimodal" +] \ No newline at end of file diff --git a/examples/prompt/ptaas/ptaas.py b/examples/prompt/ptaas/ptaas.py new file mode 100755 index 0000000..8e8b166 --- /dev/null +++ b/examples/prompt/ptaas/ptaas.py @@ -0,0 +1,165 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +PTaaS Basic Example - Sync non-stream, sync stream, async non-stream, async stream calls + +Demonstrates: +- Sync non-stream call +- Sync stream call +- Async non-stream call +- Async stream call +""" + +import asyncio +import os + +from cozeloop import new_client, Client +from cozeloop.entities.prompt import Message, Role, ExecuteResult + + +def setup_client() -> Client: + """ + Unified client setup function + + Environment variables: + - COZELOOP_WORKSPACE_ID: workspace ID + - COZELOOP_API_TOKEN: API token + """ + # Set the following environment variables first. 
+ # COZELOOP_WORKSPACE_ID=your workspace id + # COZELOOP_API_TOKEN=your token + client = new_client( + api_base_url=os.getenv("COZELOOP_API_BASE_URL"), + workspace_id=os.getenv("COZELOOP_WORKSPACE_ID"), + api_token=os.getenv("COZELOOP_API_TOKEN"), + ) + return client + + +def print_execute_result(result: ExecuteResult) -> None: + """Unified result printing function, consistent with Go version format""" + if result.message: + print(f"Message: {result.message}") + if result.finish_reason: + print(f"FinishReason: {result.finish_reason}") + if result.usage: + print(f"Usage: {result.usage}") + + +def sync_non_stream_example(client: Client) -> None: + """Sync non-stream call example""" + print("=== Sync Non-Stream Example ===") + + # 1. Create a prompt on the platform + # Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + # add the following messages to the template, submit a version. + # System: You are a helpful assistant for {{topic}}. + # User: Please help me with {{user_request}} + + result = client.execute_prompt( + prompt_key="ptaas_demo", + version="0.0.1", + variable_vals={ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning" + }, + # You can also append messages to the prompt. 
+ messages=[ + Message(role=Role.USER, content="Keep the answer brief.") + ], + stream=False + ) + print_execute_result(result) + + +def sync_stream_example(client: Client) -> None: + """Sync stream call example""" + print("=== Sync Stream Example ===") + + stream_reader = client.execute_prompt( + prompt_key="ptaas_demo", + version="0.0.1", + variable_vals={ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning" + }, + messages=[ + Message(role=Role.USER, content="Keep the answer brief.") + ], + stream=True + ) + + for result in stream_reader: + print_execute_result(result) + + print("\nStream finished.") + + +async def async_non_stream_example(client: Client) -> None: + """Async non-stream call example""" + print("=== Async Non-Stream Example ===") + + result = await client.aexecute_prompt( + prompt_key="ptaas_demo", + version="0.0.1", + variable_vals={ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning" + }, + messages=[ + Message(role=Role.USER, content="Keep the answer brief.") + ], + stream=False + ) + print_execute_result(result) + + +async def async_stream_example(client: Client) -> None: + """Async stream call example""" + print("=== Async Stream Example ===") + + stream_reader = await client.aexecute_prompt( + prompt_key="ptaas_demo", + version="0.0.1", + variable_vals={ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning" + }, + messages=[ + Message(role=Role.USER, content="Keep the answer brief.") + ], + stream=True + ) + + async for result in stream_reader: + print_execute_result(result) + + print("\nStream finished.") + + +async def main(): + """Main function""" + client = setup_client() + + try: + # Sync non-stream call + sync_non_stream_example(client) + + # Sync stream call + sync_stream_example(client) + + # Async non-stream call + await async_non_stream_example(client) + + # Async stream call + await async_stream_example(client) + + 
finally: + # Close client + if hasattr(client, 'close'): + client.close() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/prompt/ptaas/ptaas_jinja.py b/examples/prompt/ptaas/ptaas_jinja.py new file mode 100755 index 0000000..7755dca --- /dev/null +++ b/examples/prompt/ptaas/ptaas_jinja.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +PTaaS Jinja2 Template Example - Demonstrates the use of complex variable structures + +Demonstrates: +- Jinja2 template syntax +- Usage of complex object variables +""" + +import os +from typing import Dict, Any + +from cozeloop import new_client, Client +from cozeloop.entities.prompt import ExecuteResult + + +def setup_client() -> Client: + """ + Unified client setup function + """ + # Set the following environment variables first. + # COZELOOP_WORKSPACE_ID=your workspace id + # COZELOOP_API_TOKEN=your token + client = new_client( + api_base_url=os.getenv("COZELOOP_API_BASE_URL"), + workspace_id=os.getenv("COZELOOP_WORKSPACE_ID"), + api_token=os.getenv("COZELOOP_API_TOKEN"), + ) + return client + + +def print_execute_result(result: ExecuteResult) -> None: + """Unified result printing function, consistent with Go version format""" + if result.message: + print(f"Message: {result.message}") + if result.finish_reason: + print(f"FinishReason: {result.finish_reason}") + if result.usage: + print(f"Usage: {result.usage}") + + +def jinja_template_example(client: Client) -> None: + """Jinja2 template example""" + print("=== Jinja2 Template Example ===") + + # 1. Create a prompt using jinja2 template on the platform + # Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + # add the following messages to the template, submit a version, and set a label (e.g., 'production') for that version. + # System: You are a helpful assistant for {{param.topic}}. 
Your audience is {{param.age}} years old. + # User: Please help me with {{param.user_request}} + + result = client.execute_prompt( + prompt_key="ptaas_demo", + version="0.0.2", + variable_vals={ + "param": { + "topic": "artificial intelligence", + "age": 10, + "user_request": "explain what is machine learning" + } + }, + stream=False + ) + print_execute_result(result) + + +def main(): + """Main function""" + # The explanation of jinja2 template is based on non-streaming execution, and it also applies to streaming execution. + client = setup_client() + + try: + jinja_template_example(client) + finally: + # Close client + if hasattr(client, 'close'): + client.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/prompt/ptaas/ptaas_multimodal.py b/examples/prompt/ptaas/ptaas_multimodal.py new file mode 100755 index 0000000..998882b --- /dev/null +++ b/examples/prompt/ptaas/ptaas_multimodal.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +PTaaS Multimodal Example - Demonstrates image processing and multimodal input + +Demonstrates: +- Image URL processing +- Base64 image data processing +- Multimodal message construction +""" + +import os +import base64 + +from cozeloop import new_client, Client +from cozeloop.entities.prompt import Message, Role, ExecuteResult, ContentPart, ContentType + + +def setup_client() -> Client: + """ + Unified client setup function + """ + # Set the following environment variables first. 
+ # COZELOOP_WORKSPACE_ID=your workspace id + # COZELOOP_API_TOKEN=your token + client = new_client( + api_base_url=os.getenv("COZELOOP_API_BASE_URL"), + workspace_id=os.getenv("COZELOOP_WORKSPACE_ID"), + api_token=os.getenv("COZELOOP_API_TOKEN"), + ) + return client + + +def print_execute_result(result: ExecuteResult) -> None: + """Unified result printing function, consistent with Go version format""" + if result.message: + print(f"Message: {result.message}") + if result.finish_reason: + print(f"FinishReason: {result.finish_reason}") + if result.usage: + print(f"Usage: {result.usage}") + + +def multimodal_example(client: Client) -> None: + """Multimodal example""" + print("=== Multimodal Example ===") + + # 1. Create a prompt on the platform + # Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + # add the following messages to the template, submit a version. example1 and example2 are the multi modal variables. + # System: You can quickly identify the location where a photo was taken. 
+ # User: 例如:{{example1}} + # Assistant: {{city1}} + # User: 例如:{{example2}} + # Assistant: {{city2}} + + # Prepare Base64 image data (example) + # In actual use, you need to provide a real image path + image_path = "/Users/bytedance/Downloads/shanghai.jpeg" + base64_data = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAYEBQYFBAYGBQYHBwYIChAKCgkJChQODwwQFxQYGBcUFhYaHSUfGhsjHBYWICwgIyYnKSopGR8tMC0oMCUoKSj/2wBDAQcHBwoIChMKChMoGhYaKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCj/wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAv/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k=" + # If image file exists, read and encode + # 如果图片文件存在,读取并编码 + if os.path.exists(image_path): + try: + with open(image_path, "rb") as f: + image_bytes = f.read() + base64_image = base64.b64encode(image_bytes).decode() + base64_data = f"data:image/jpeg;base64,{base64_image}" + except Exception as e: + print(f"Warning: Could not read image file {image_path}: {e}") + print("Using placeholder base64 data instead.") + + result = client.execute_prompt( + prompt_key="ptaas_demo", + version="0.0.8", + # multi modal variable can be List[ContentPart]/ContentPart + # Images can be provided via URL or in base64 encoded format. + # Image URL needs to be publicly accessible. + # Base64-formatted data should follow the standard data URI format, like "data:[<mediatype>][;base64],<data>".
+ variable_vals={ + "example1": [ + ContentPart( + type=ContentType.IMAGE_URL, + image_url="https://p8.itc.cn/q_70/images03/20221219/61785c89cd17421ca0d007c7a87d09fb.jpeg" + ) + ], + "city1": "Beijing", + "example2": [ + ContentPart( + type=ContentType.BASE64_DATA, + base64_data=base64_data + ) + ], + "city2": "Shanghai" + }, + messages=[ + Message( + role=Role.USER, + parts=[ + ContentPart( + type=ContentType.IMAGE_URL, + image_url="https://img0.baidu.com/it/u=1402951118,1660594928&fm=253&app=138&f=JPEG?w=800&h=1200" + ), + ContentPart( + type=ContentType.TEXT, + text="Where is this photo taken?" + ) + ] + ) + ], + stream=False + ) + print_execute_result(result) + + +def main(): + """Main function""" + # The explanation of multi modal is based on non-streaming execution, and it also applies to streaming execution. + client = setup_client() + + try: + multimodal_example(client) + finally: + # Close client + if hasattr(client, 'close'): + client.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/prompt/ptaas/ptaas_placeholder_variable.py b/examples/prompt/ptaas/ptaas_placeholder_variable.py new file mode 100755 index 0000000..aac1da0 --- /dev/null +++ b/examples/prompt/ptaas/ptaas_placeholder_variable.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +PTaaS Placeholder Variable Example - Demonstrates the use of chat_history placeholder variables + +Demonstrates: +- How to use placeholder variables +- Processing of chat_history variables +""" + +import os +from typing import List + +from cozeloop import new_client, Client +from cozeloop.entities.prompt import Message, Role, ExecuteResult + + +def setup_client() -> Client: + """ + Unified client setup function + """ + # Set the following environment variables first. 
+ # COZELOOP_WORKSPACE_ID=your workspace id + # COZELOOP_API_TOKEN=your token + client = new_client( + api_base_url=os.getenv("COZELOOP_API_BASE_URL"), + workspace_id=os.getenv("COZELOOP_WORKSPACE_ID"), + api_token=os.getenv("COZELOOP_API_TOKEN"), + ) + return client + + +def print_execute_result(result: ExecuteResult) -> None: + """Unified result printing function, consistent with Go version format""" + if result.message: + print(f"Message: {result.message}") + if result.finish_reason: + print(f"FinishReason: {result.finish_reason}") + if result.usage: + print(f"Usage: {result.usage}") + + +def placeholder_variable_example(client: Client) -> None: + """Placeholder variable example""" + print("=== Placeholder Variable Example ===") + + # 1. Create a prompt on the platform + # Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + # add the following messages to the template, submit a version. + # System: You are a helpful assistant for {{topic}}. + # Placeholder: {{chat_history}} + # User: Please help me with {{user_request}} + + result = client.execute_prompt( + prompt_key="ptaas_demo", + version="0.0.5", + variable_vals={ + "topic": "artificial intelligence", + # chat_history is a placeholder variable, and it can be List[Message]/Message. + "chat_history": [ + Message(role=Role.USER, content="hello"), + Message(role=Role.ASSISTANT, content="hello") + ], + "user_request": "explain what is machine learning" + }, + stream=False + ) + print_execute_result(result) + + +def main(): + """Main function""" + # The explanation of placeholder variable is based on non-streaming execution, and it also applies to streaming execution. 
+ client = setup_client() + + try: + placeholder_variable_example(client) + finally: + # Close client + if hasattr(client, 'close'): + client.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/prompt/ptaas/ptaas_timeout.py b/examples/prompt/ptaas/ptaas_timeout.py new file mode 100755 index 0000000..2fe4dd3 --- /dev/null +++ b/examples/prompt/ptaas/ptaas_timeout.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +PTaaS Timeout Control Example - Demonstrates client timeout and context timeout control + +Demonstrates: +- Client-level timeout settings +- Request-level timeout control +""" + +import os + +from cozeloop import new_client +from cozeloop.entities.prompt import ExecuteResult + + +def print_execute_result(result: ExecuteResult) -> None: + """Unified result printing function, consistent with Go version format""" + if result.message: + print(f"Message: {result.message}") + if result.finish_reason: + print(f"FinishReason: {result.finish_reason}") + if result.usage: + print(f"Usage: {result.usage}") + + +def set_request_timeout(): + print("=== Request Timeout Example ===") + + # 1. Create a prompt on the platform + # Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + # add the following messages to the template, submit a version, and set a label (e.g., 'production') for that version. + # System: You are a helpful assistant for {{topic}}. + # User: Please help me with {{user_request}} + + # Set the following environment variables first. 
+ # COZELOOP_WORKSPACE_ID=your workspace id + # COZELOOP_API_TOKEN=your token + client = new_client( + api_base_url=os.getenv("COZELOOP_API_BASE_URL"), + workspace_id=os.getenv("COZELOOP_WORKSPACE_ID"), + api_token=os.getenv("COZELOOP_API_TOKEN"), + ) + + result = client.execute_prompt( + prompt_key="ptaas_demo", + version="0.0.1", + variable_vals={ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning" + }, + stream=False, + timeout=1 # Set request timeout, default is 600s, max is 600s. + ) + print_execute_result(result) + + +def main(): + """Main function""" + # The explanation of timeout settings is based on non-streaming execution, and it also applies to streaming execution. + set_request_timeout() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/prompt/ptaas/ptaas_with_label.py b/examples/prompt/ptaas/ptaas_with_label.py new file mode 100755 index 0000000..80a92fa --- /dev/null +++ b/examples/prompt/ptaas/ptaas_with_label.py @@ -0,0 +1,79 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +PTaaS Label Usage Example - Demonstrates the use of label parameter + +Demonstrates: +- How to use label parameter +- Version label management +""" + +import os + +from cozeloop import new_client, Client +from cozeloop.entities.prompt import ExecuteResult + + +def setup_client() -> Client: + """ + Unified client setup function + """ + # Set the following environment variables first. 
+ # COZELOOP_WORKSPACE_ID=your workspace id + # COZELOOP_API_TOKEN=your token + client = new_client( + api_base_url=os.getenv("COZELOOP_API_BASE_URL"), + workspace_id=os.getenv("COZELOOP_WORKSPACE_ID"), + api_token=os.getenv("COZELOOP_API_TOKEN"), + ) + return client + + +def print_execute_result(result: ExecuteResult) -> None: + """Unified result printing function, consistent with Go version format""" + if result.message: + print(f"Message: {result.message}") + if result.finish_reason: + print(f"FinishReason: {result.finish_reason}") + if result.usage: + print(f"Usage: {result.usage}") + + +def label_example(client: Client) -> None: + """Label usage example""" + print("=== Label Example ===") + + # 1. Create a prompt on the platform + # Create a Prompt on the platform's Prompt development page (set Prompt Key to 'ptaas_demo'), + # add the following messages to the template, submit a version, and set a label (e.g., 'production') for that version. + # System: You are a helpful assistant for {{topic}}. + # User: Please help me with {{user_request}} + + result = client.execute_prompt( + prompt_key="ptaas_demo", + label="production", # Note: When version is specified, label field will be ignored + variable_vals={ + "topic": "artificial intelligence", + "user_request": "explain what is machine learning" + }, + stream=False + ) + print_execute_result(result) + + +def main(): + """Main function""" + # The explanation of label is based on non-streaming execution, and it also applies to streaming execution. 
+ client = setup_client() + + try: + label_example(client) + finally: + # Close client + if hasattr(client, 'close'): + client.close() + + +if __name__ == "__main__": + main() \ No newline at end of file From 1e5f519ac52fb25f702de17c3093543d5a66f629 Mon Sep 17 00:00:00 2001 From: zhongzhiwei Date: Tue, 9 Sep 2025 19:14:11 +0800 Subject: [PATCH 2/2] delete useless example Change-Id: I885e86a70f3ef146163e234ffd1d8593cfe4e9f3 --- examples/prompt/ptaas/ptaas_multimodal.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/prompt/ptaas/ptaas_multimodal.py b/examples/prompt/ptaas/ptaas_multimodal.py index 998882b..a4703f5 100755 --- a/examples/prompt/ptaas/ptaas_multimodal.py +++ b/examples/prompt/ptaas/ptaas_multimodal.py @@ -55,10 +55,7 @@ def multimodal_example(client: Client) -> None: # User: 例如:{{example2}} # Assistant: {{city2}} - # Prepare Base64 image data (example) - # In actual use, you need to provide a real image path - image_path = "/Users/bytedance/Downloads/shanghai.jpeg" - base64_data = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAYEBQYFBAYGBQYHBwYIChAKCgkJChQODwwQFxQYGBcUFhYaHSUfGhsjHBYWICwgIyYnKSopGR8tMC0oMCUoKSj/2wBDAQcHBwoIChMKChMoGhYaKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCj/wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAv/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k=" + image_path = "your_image_path" # If image file exists, read and encode # 如果图片文件存在,读取并编码 if os.path.exists(image_path):