From 415d6368590e25e31ab6f756a94788cb59e11c32 Mon Sep 17 00:00:00 2001 From: Kenny Stryker Date: Mon, 1 Dec 2025 22:51:42 +0530 Subject: [PATCH 1/2] feat: Added support for Vertex AI for LLM Extraction Strategy --- crawl4ai/async_configs.py | 36 ++++++++++++++++++++++++--------- crawl4ai/config.py | 6 ++++++ crawl4ai/extraction_strategy.py | 5 ++++- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/crawl4ai/async_configs.py b/crawl4ai/async_configs.py index bfa0d398b..2ab59ae5d 100644 --- a/crawl4ai/async_configs.py +++ b/crawl4ai/async_configs.py @@ -1,3 +1,4 @@ +import json import os from typing import Union import warnings @@ -1792,7 +1793,8 @@ def __init__( frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, stop: Optional[List[str]] = None, - n: Optional[int] = None, + n: Optional[int] = None, + **kwargs, ): """Configuaration class for LLM provider and API token.""" self.provider = provider @@ -1803,13 +1805,25 @@ def __init__( else: # Check if given provider starts with any of key in PROVIDER_MODELS_PREFIXES # If not, check if it is in PROVIDER_MODELS + prefixes = PROVIDER_MODELS_PREFIXES.keys() if any(provider.startswith(prefix) for prefix in prefixes): - selected_prefix = next( - (prefix for prefix in prefixes if provider.startswith(prefix)), - None, - ) - self.api_token = PROVIDER_MODELS_PREFIXES.get(selected_prefix) + + if provider.startswith("vertex_ai"): + credential_path = PROVIDER_MODELS_PREFIXES["vertex_ai"] + + with open(credential_path, "r") as file: + vertex_credentials = json.load(file) + # Convert to JSON string + + self.vertex_credentials = json.dumps(vertex_credentials) + self.api_token = None + else: + selected_prefix = next( + (prefix for prefix in prefixes if provider.startswith(prefix)), + None, + ) + self.api_token = PROVIDER_MODELS_PREFIXES.get(selected_prefix) else: self.provider = DEFAULT_PROVIDER self.api_token = os.getenv(DEFAULT_PROVIDER_API_KEY) @@ -1834,11 +1848,11 @@ def from_kwargs(kwargs: dict) -> "LLMConfig": frequency_penalty=kwargs.get("frequency_penalty"), presence_penalty=kwargs.get("presence_penalty"), stop=kwargs.get("stop"), - n=kwargs.get("n") + n=kwargs.get("n"), ) def to_dict(self): - return { + result = { "provider": self.provider, "api_token": self.api_token, "base_url": self.base_url, @@ -1848,8 +1862,11 @@ def to_dict(self): "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, "stop": self.stop, - "n": self.n + "n": self.n, } + if self.provider.startswith("vertex_ai"): + result["extra_args"] = {"vertex_credentials": self.vertex_credentials} + return result def clone(self, **kwargs): """Create a copy of this configuration with updated values. @@ -1864,6 +1881,7 @@ def clone(self, **kwargs): config_dict.update(kwargs) return LLMConfig.from_kwargs(config_dict) + class SeedingConfig: """ Configuration class for URL discovery and pre-validation via AsyncUrlSeeder. diff --git a/crawl4ai/config.py b/crawl4ai/config.py index 08f56b832..2e61ffee8 100644 --- a/crawl4ai/config.py +++ b/crawl4ai/config.py @@ -22,6 +22,11 @@ "anthropic/claude-3-opus-20240229": os.getenv("ANTHROPIC_API_KEY"), "anthropic/claude-3-sonnet-20240229": os.getenv("ANTHROPIC_API_KEY"), "anthropic/claude-3-5-sonnet-20240620": os.getenv("ANTHROPIC_API_KEY"), + "vertex_ai/gemini-2.0-flash-lite": os.getenv("GOOGLE_APPLICATION_CREDENTIALS"), + 'vertex_ai/gemini-2.0-flash': os.getenv("GOOGLE_APPLICATION_CREDENTIALS"), + 'vertex_ai/gemini-2.5-flash': os.getenv("GOOGLE_APPLICATION_CREDENTIALS"), + 'vertex_ai/gemini-2.5-pro': os.getenv("GOOGLE_APPLICATION_CREDENTIALS"), + 'vertex_ai/gemini-3-pro-preview': os.getenv("GOOGLE_APPLICATION_CREDENTIALS"), "gemini/gemini-pro": os.getenv("GEMINI_API_KEY"), 'gemini/gemini-1.5-pro': os.getenv("GEMINI_API_KEY"), 'gemini/gemini-2.0-flash': os.getenv("GEMINI_API_KEY"), @@ -35,6 +40,7 @@ "openai": os.getenv("OPENAI_API_KEY"), "anthropic": os.getenv("ANTHROPIC_API_KEY"), "gemini": os.getenv("GEMINI_API_KEY"), + "vertex_ai": os.getenv("GOOGLE_APPLICATION_CREDENTIALS"), "deepseek": os.getenv("DEEPSEEK_API_KEY"), } diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index 4a64e5d46..6e6c93907 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -574,7 +574,10 @@ def __init__( self.overlap_rate = overlap_rate self.word_token_rate = word_token_rate self.apply_chunking = apply_chunking - self.extra_args = kwargs.get("extra_args", {}) + # Merge both extra kwargs + self.extra_args = kwargs.get("extra_args", {}) | self.llm_config.to_dict().get( + "extra_args", {} + ) if not self.apply_chunking: self.chunk_token_threshold = 1e9 self.verbose = verbose From e9f596ff085e0cb6d1222b73308b8d594c5e128e Mon Sep 17 00:00:00 2001 From: Kenny Stryker Date: Mon, 1 Dec 2025 23:02:11 +0530 Subject: [PATCH 2/2] chore: doc and formatting --- crawl4ai/async_configs.py | 1 - deploy/docker/c4ai-code-context.md | 38 ++++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/crawl4ai/async_configs.py b/crawl4ai/async_configs.py index 2ab59ae5d..b93745846 100644 --- a/crawl4ai/async_configs.py +++ b/crawl4ai/async_configs.py @@ -1815,7 +1815,6 @@ def __init__( with open(credential_path, "r") as file: vertex_credentials = json.load(file) # Convert to JSON string - self.vertex_credentials = json.dumps(vertex_credentials) self.api_token = None else: diff --git a/deploy/docker/c4ai-code-context.md b/deploy/docker/c4ai-code-context.md index c18fbc784..6acc318c4 100644 --- a/deploy/docker/c4ai-code-context.md +++ b/deploy/docker/c4ai-code-context.md @@ -1269,7 +1269,8 @@ class LLMConfig: frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, stop: Optional[List[str]] = None, - n: Optional[int] = None, + n: Optional[int] = None, + **kwargs, ): """Configuaration class for LLM provider and API token.""" self.provider = provider @@ -1280,13 +1281,25 @@ class LLMConfig: else: # Check if given provider starts with any of key in PROVIDER_MODELS_PREFIXES # If not, check if it is in PROVIDER_MODELS + prefixes = PROVIDER_MODELS_PREFIXES.keys() if any(provider.startswith(prefix) for prefix in prefixes): - selected_prefix = next( - (prefix for prefix in prefixes if provider.startswith(prefix)), - None, - ) - self.api_token = PROVIDER_MODELS_PREFIXES.get(selected_prefix) + + if provider.startswith("vertex_ai"): + credential_path = PROVIDER_MODELS_PREFIXES["vertex_ai"] + + with open(credential_path, "r") as file: + vertex_credentials = json.load(file) + # Convert to JSON string + self.vertex_credentials = json.dumps(vertex_credentials) + + self.api_token = None + else: + selected_prefix = next( + (prefix for prefix in prefixes if provider.startswith(prefix)), + None, + ) + self.api_token = PROVIDER_MODELS_PREFIXES.get(selected_prefix) else: self.provider = DEFAULT_PROVIDER self.api_token = os.getenv(DEFAULT_PROVIDER_API_KEY) @@ -1311,11 +1324,11 @@ class LLMConfig: frequency_penalty=kwargs.get("frequency_penalty"), presence_penalty=kwargs.get("presence_penalty"), stop=kwargs.get("stop"), - n=kwargs.get("n") + n=kwargs.get("n"), ) def to_dict(self): - return { + result = { "provider": self.provider, "api_token": self.api_token, "base_url": self.base_url, @@ -1325,8 +1338,11 @@ class LLMConfig: "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, "stop": self.stop, - "n": self.n + "n": self.n, } + if self.provider.startswith("vertex_ai"): + result["extra_args"] = {"vertex_credentials": self.vertex_credentials} + return result def clone(self, **kwargs): """Create a copy of this configuration with updated values. @@ -4094,7 +4110,9 @@ class LLMExtractionStrategy(ExtractionStrategy): self.overlap_rate = overlap_rate self.word_token_rate = word_token_rate self.apply_chunking = apply_chunking - self.extra_args = kwargs.get("extra_args", {}) + self.extra_args = kwargs.get("extra_args", {}) | self.llm_config.to_dict().get( + "extra_args", {} + ) if not self.apply_chunking: self.chunk_token_threshold = 1e9 self.verbose = verbose