From 4b8c3c3e71bff29252997ddc306f78222b195667 Mon Sep 17 00:00:00 2001 From: Nidhi Hiremath Date: Sun, 19 Oct 2025 22:26:18 -0700 Subject: [PATCH 1/5] initial commit --- recipes/tool_calls/tc_grpo.py | 232 +++++++++++++++++++++++++++ recipes/tool_calls/train_sft_lora.py | 152 ++++++++++++++++++ 2 files changed, 384 insertions(+) create mode 100644 recipes/tool_calls/tc_grpo.py create mode 100644 recipes/tool_calls/train_sft_lora.py diff --git a/recipes/tool_calls/tc_grpo.py b/recipes/tool_calls/tc_grpo.py new file mode 100644 index 0000000..95e2cc8 --- /dev/null +++ b/recipes/tool_calls/tc_grpo.py @@ -0,0 +1,232 @@ +## pip install -U bitsandbytes trl transformers datasets peft accelerate + +import ast +import re +import json +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from trl import GRPOConfig, GRPOTrainer +from datasets import load_dataset +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + + +# Load dataset +ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn") +pattern = r'(.*?)' + +# Load model and tokenizer +model_name = "Qwen/Qwen2.5-7B-Instruct" +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +def format_example_grpo(x, tokenizer=tokenizer): + """Format examples for GRPO - we need prompts and completions separately""" + try: + tools = [s['function'] for s in ast.literal_eval(re.findall(pattern, x['conversations'][0]['value'], re.DOTALL)[-1].strip())] + # Create the prompt (without assistant response) + prompt = tokenizer.apply_chat_template( + [ + {"role": "system", "content": "You are an assistant capable of tool calls"}, + {"role": "user", "content": x['conversations'][-2]['value']}, + ], + tools=tools, + add_generation_prompt=True, + tokenize=False + ) + # Get the ground truth completion (reference) + completion = x['conversations'][-1]['value'] + return { + "prompt": prompt, + # "completion": completion, + "references": completion, + "tools": json.dumps(tools) + } + except Exception as e: + return None + +formatted_dataset = ds.map(format_example_grpo).filter(lambda x: x is not None) +formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"]) + +train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42) +train_dataset = train_test_split['train'] +eval_dataset = train_test_split['test'] + +print(f"Training samples: {len(train_dataset)}") +print(f"Validation samples: {len(eval_dataset)}") + +# Load model with quantization +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + load_in_8bit=True +) + +# Prepare model for LoRA fine-tuning +model = prepare_model_for_kbit_training(model) + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM" +) + +model = get_peft_model(model, lora_config) +model.print_trainable_parameters() + +# Define reward function for tool calling +# def reward_function(prompts, completions, references, tools_list): +def reward_function(prompts, **kwargs): + """ + Reward function that evaluates the quality of tool calls. + Returns a list of reward scores (one per completion). 
+ """ + # print("All kwargs:") + # for key in kwargs: + # print(f" {key}: {type(kwargs[key])}") + rewards = [] + + completions = kwargs.get("completions", []) + references = kwargs.get("references", []) + tools = kwargs.get("tools", []) + + for completion, reference in zip(completions, references): + reward = 0.0 + + # 1. Check if completion contains valid JSON tool call + try: + # Look for tool call patterns + if '' in completion and '' in completion: + reward += 2.0 + + # Extract and validate JSON + tool_call_match = re.search(r'(.*?)', completion, re.DOTALL) + if tool_call_match: + tool_call_json = json.loads(tool_call_match.group(1)) + reward += 2.0 + + # Check if tool name matches reference + if 'name' in tool_call_json: + ref_tool_match = re.search(r'(.*?)', reference, re.DOTALL) + if ref_tool_match: + ref_tool = json.loads(ref_tool_match.group(1)) + if tool_call_json.get('name') == ref_tool.get('name'): + reward += 3.0 + + # Check if arguments overlap + if 'arguments' in tool_call_json and 'arguments' in ref_tool: + arg_overlap = len(set(tool_call_json['arguments'].keys()) & + set(ref_tool['arguments'].keys())) + reward += arg_overlap * 1.0 + except: + reward -= 1.0 + + # 2. Penalize if no tool call when reference has one + if '' in reference and '' not in completion: + reward -= 5.0 + + # 3. Penalize if tool call when reference doesn't have one + if '' not in reference and '' in completion: + reward -= 2.0 + + # 4. Length penalty (avoid overly verbose responses) + if len(completion) > len(reference) * 2: + reward -= 1.0 + + rewards.append(reward) + + return rewards + +# GRPO Configuration +grpo_config = GRPOConfig( + output_dir="./qwen-tool-calling-grpo", + num_train_epochs=3, + per_device_train_batch_size=2, + per_device_eval_batch_size=4, + gradient_accumulation_steps=8, + learning_rate=5e-6, + warmup_steps=100, + logging_steps=10, + save_steps=100, + eval_strategy="steps", + eval_steps=10, + bf16=True, + max_grad_norm=0.3, + + # GRPO-specific parameters + num_generations=4, # Number of generations per prompt for group comparison + temperature=0.9, # Sampling temperature + # max_new_tokens=512, + # kl=0.05, # KL divergence coefficient to prevent drift from reference +) + +# Initialize GRPO Trainer +trainer = GRPOTrainer( + model=model, + args=grpo_config, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + # tokenizer=tokenizer, + reward_funcs=[reward_function], +) + +# Train the model +print("Starting GRPO training...") +trainer.train() + +# Save the fine-tuned model +trainer.save_model("./qwen-tool-calling-grpo-final") +tokenizer.save_pretrained("./qwen-tool-calling-grpo-final") + +print("Training complete!") + +# Example inference +def test_tool_calling(prompt, tools): + model.eval() + + formatted_prompt = tokenizer.apply_chat_template( + [ + {"role": "system", "content": "You are an assistant capable of tool calls"}, + {"role": "user", "content": prompt} + ], + tools=tools, + add_generation_prompt=True, + tokenize=False + ) + + inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=512, + temperature=0.7, + do_sample=True + ) + + response = tokenizer.decode(outputs[0], skip_special_tokens=False) + print(response) + +# Test the model +test_tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", 
"description": "City name"} + }, + "required": ["location"] + } + } + } +] + +test_tool_calling("What's the weather in London?", test_tools) diff --git a/recipes/tool_calls/train_sft_lora.py b/recipes/tool_calls/train_sft_lora.py new file mode 100644 index 0000000..840bbc6 --- /dev/null +++ b/recipes/tool_calls/train_sft_lora.py @@ -0,0 +1,152 @@ +## pip install -U bitsandbytes + +import ast +import re +import json +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments +from trl import SFTTrainer, SFTConfig +from datasets import Dataset, load_dataset +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + + +# dataset = load_dataset("younissk/tool-calling-mix") +ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn") +pattern = r'(.*?)' + +# 4. Load model and tokenizer +model_name = "Qwen/Qwen2.5-7B-Instruct" # or Qwen2.5-3B, Qwen2.5-1.5B +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +def format_example(x): + try: + tools = [s['function'] for s in ast.literal_eval(re.findall(pattern, x['conversations'][0]['value'], re.DOTALL)[-1].strip())] + text = tokenizer.apply_chat_template( + [ + {"role": "system", "content": "You are an assistant capable of tool calls"}, + {"role": "user", "content": x['conversations'][-2]['value']}, + {"role": "assistant", "content": x['conversations'][-1]['value']} + ], + tools=tools, + add_generation_prompt=True, + tokenize=False + ) + return {"text": text} + except Exception as e: + # Return None or a sentinel value to filter out later + return None + +formatted_dataset = ds.map(format_example).filter(lambda x: x is not None) +formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"]) + +train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42) +train_dataset = train_test_split['train'] +eval_dataset = train_test_split['test'] + +print(f"Training samples: {len(train_dataset)}") +print(f"Validation samples: {len(eval_dataset)}") + +# Convert to dataset +# formatted_data = [format_training_example(ex) for ex in training_examples] +# dataset = Dataset.from_list(formatted_data) +# dataset = formatted_dataset +# print(" Dataset --- ", dataset) + + +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + load_in_8bit=True # Use 8-bit quantization for memory efficiency +) + +# 5. 
Prepare model for LoRA fine-tuning +model = prepare_model_for_kbit_training(model) + +""" +text = tokenizer.apply_chat_template( + messages, + tools=tools, + add_generation_prompt=True, + tokenize=False +) +""" + +lora_config = LoraConfig( + r=16, # LoRA rank + lora_alpha=32, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM" +) + +model = get_peft_model(model, lora_config) +model.print_trainable_parameters() + +sft_config = SFTConfig( + output_dir="./qwen-tool-calling3", + num_train_epochs=3, + per_device_train_batch_size=2, + gradient_accumulation_steps=8, + learning_rate=2e-4, + warmup_steps=100, + logging_steps=10, + save_steps=100, + eval_strategy="steps", # Enable evaluation + eval_steps=30, # Evaluate every 100 steps + optim="paged_adamw_8bit", + fp16=False, + bf16=True, + max_grad_norm=0.3, + # SFT-specific parameters + dataset_text_field='text', + max_length=2048, +) +trainer = SFTTrainer( + model=model, + args=sft_config, + train_dataset=formatted_dataset['train'], + eval_dataset=eval_dataset, # Add validation dataset +) + +trainer.train() + + +# 8. Train the model +print("Starting training...") +trainer.train() + +# 9. Save the fine-tuned model +trainer.save_model("./qwen-tool-calling-final") +tokenizer.save_pretrained("./qwen-tool-calling-final") + +print("Training complete!") + +# 10. Example inference +def test_tool_calling(prompt): + model.eval() + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=512, + temperature=0.7, + do_sample=True + ) + + response = tokenizer.decode(outputs[0], skip_special_tokens=False) + print(response) + +# Test the model +test_prompt = f"""<|im_start|>system +You are a helpful assistant with access to tools: {json.dumps(tools)}<|im_end|> +<|im_start|>user +What's the weather in London?<|im_end|> +<|im_start|>assistant +""" + +test_tool_calling(test_prompt) \ No newline at end of file From a96f1dec25c11b24d9e438fa046d4b2ac2bb6320 Mon Sep 17 00:00:00 2001 From: Nidhi Hiremath Date: Sun, 19 Oct 2025 23:11:26 -0700 Subject: [PATCH 2/5] change reward --- recipes/tool_calls/tc_grpo.py | 138 ++++++++++++++++++++++++---------- 1 file changed, 98 insertions(+), 40 deletions(-) diff --git a/recipes/tool_calls/tc_grpo.py b/recipes/tool_calls/tc_grpo.py index 95e2cc8..21ebd4a 100644 --- a/recipes/tool_calls/tc_grpo.py +++ b/recipes/tool_calls/tc_grpo.py @@ -80,7 +80,7 @@ def format_example_grpo(x, tokenizer=tokenizer): # Define reward function for tool calling # def reward_function(prompts, completions, references, tools_list): -def reward_function(prompts, **kwargs): +def reward_function(prompts, completions, **kwargs): """ Reward function that evaluates the quality of tool calls. Returns a list of reward scores (one per completion). 
@@ -90,60 +90,118 @@ def reward_function(prompts, **kwargs): # print(f" {key}: {type(kwargs[key])}") rewards = [] - completions = kwargs.get("completions", []) + # completions = kwargs.get("completions", []) references = kwargs.get("references", []) + # references = kwargs.get("completion", []) tools = kwargs.get("tools", []) for completion, reference in zip(completions, references): reward = 0.0 + ref_has_tool = '' in reference + comp_has_tool = '' in completion + # Case 1: Reference has tool call + if ref_has_tool: + if comp_has_tool: + # Both have tool calls - compare quality + try: + tool_call_match = re.search(r'(.*?)', completion, re.DOTALL) + ref_tool_match = re.search(r'(.*?)', reference, re.DOTALL) + if tool_call_match and ref_tool_match: + comp_tool = json.loads(tool_call_match.group(1)) + ref_tool = json.loads(ref_tool_match.group(1)) + + # Reward for valid JSON structure + reward += 3.0 + + # Reward for correct tool name + if comp_tool.get('name') == ref_tool.get('name'): + reward += 5.0 + + # Reward for matching arguments + if 'arguments' in comp_tool and 'arguments' in ref_tool: + comp_args = set(comp_tool['arguments'].keys()) + ref_args = set(ref_tool['arguments'].keys()) + + # Jaccard similarity for arguments + if len(ref_args) > 0: + overlap = len(comp_args & ref_args) + union = len(comp_args | ref_args) + reward += (overlap / union) * 4.0 + else: + # Wrong tool name + reward -= 3.0 + else: + # Has tool call tags but invalid JSON + reward += 1.0 + + except json.JSONDecodeError: + # Malformed JSON in tool call + reward -= 2.0 + else: + # Missing tool call when one is needed + reward -= 5.0 + # Case 2: Reference has no tool call + else: + if comp_has_tool: + # False positive - made a tool call when not needed + reward -= 3.0 + else: + # Correct - no tool call needed + reward += 2.0 + # Small penalty for excessive length (normalized) + len_ratio = len(completion) / max(len(reference), 1) + if len_ratio > 2.0: + reward -= 1.0 + rewards.append(reward) + return rewards # 1. 
Check if completion contains valid JSON tool call - try: - # Look for tool call patterns - if '' in completion and '' in completion: - reward += 2.0 + # try: + # # Look for tool call patterns + # if '' in completion and '' in completion: + # reward += 2.0 - # Extract and validate JSON - tool_call_match = re.search(r'(.*?)', completion, re.DOTALL) - if tool_call_match: - tool_call_json = json.loads(tool_call_match.group(1)) - reward += 2.0 + # # Extract and validate JSON + # tool_call_match = re.search(r'(.*?)', completion, re.DOTALL) + # if tool_call_match: + # tool_call_json = json.loads(tool_call_match.group(1)) + # reward += 2.0 - # Check if tool name matches reference - if 'name' in tool_call_json: - ref_tool_match = re.search(r'(.*?)', reference, re.DOTALL) - if ref_tool_match: - ref_tool = json.loads(ref_tool_match.group(1)) - if tool_call_json.get('name') == ref_tool.get('name'): - reward += 3.0 + # # Check if tool name matches reference + # if 'name' in tool_call_json: + # ref_tool_match = re.search(r'(.*?)', reference, re.DOTALL) + # if ref_tool_match: + # ref_tool = json.loads(ref_tool_match.group(1)) + # if tool_call_json.get('name') == ref_tool.get('name'): + # reward += 3.0 - # Check if arguments overlap - if 'arguments' in tool_call_json and 'arguments' in ref_tool: - arg_overlap = len(set(tool_call_json['arguments'].keys()) & - set(ref_tool['arguments'].keys())) - reward += arg_overlap * 1.0 - except: - reward -= 1.0 + # # Check if arguments overlap + # if 'arguments' in tool_call_json and 'arguments' in ref_tool: + # arg_overlap = len(set(tool_call_json['arguments'].keys()) & + # set(ref_tool['arguments'].keys())) + # reward += arg_overlap * 1.0 + # except: + # reward -= 1.0 - # 2. Penalize if no tool call when reference has one - if '' in reference and '' not in completion: - reward -= 5.0 + # # 2. Penalize if no tool call when reference has one + # if '' in reference and '' not in completion: + # reward -= 5.0 - # 3. Penalize if tool call when reference doesn't have one - if '' not in reference and '' in completion: - reward -= 2.0 + # # 3. Penalize if tool call when reference doesn't have one + # if '' not in reference and '' in completion: + # reward -= 2.0 - # 4. Length penalty (avoid overly verbose responses) - if len(completion) > len(reference) * 2: - reward -= 1.0 + # # 4. 
Length penalty (avoid overly verbose responses) + # if len(completion) > len(reference) * 2: + # reward -= 1.0 - rewards.append(reward) + # rewards.append(reward) - return rewards + # return rewards # GRPO Configuration grpo_config = GRPOConfig( - output_dir="./qwen-tool-calling-grpo", + output_dir="./qwen-tool-calling-1", num_train_epochs=3, per_device_train_batch_size=2, per_device_eval_batch_size=4, @@ -179,8 +237,8 @@ def reward_function(prompts, **kwargs): trainer.train() # Save the fine-tuned model -trainer.save_model("./qwen-tool-calling-grpo-final") -tokenizer.save_pretrained("./qwen-tool-calling-grpo-final") +trainer.save_model("./qwen-tool-calling-grpo-1") +tokenizer.save_pretrained("./qwen-tool-calling-grpo-1") print("Training complete!") @@ -229,4 +287,4 @@ def test_tool_calling(prompt, tools): } ] -test_tool_calling("What's the weather in London?", test_tools) +test_tool_calling("What's the weather in London?", test_tools) \ No newline at end of file From b52129d03fa0687962718a76e1e4b5b668b73a44 Mon Sep 17 00:00:00 2001 From: Nidhi Hiremath Date: Wed, 22 Oct 2025 20:28:27 -0700 Subject: [PATCH 3/5] more exps --- recipes/tool_calls/dataset_clean.py | 127 +++++++++++++++ recipes/tool_calls/new_ds_test.py | 109 +++++++++++++ recipes/tool_calls/tc_grpo.py | 236 +++++++++++++--------------- 3 files changed, 347 insertions(+), 125 deletions(-) create mode 100644 recipes/tool_calls/dataset_clean.py create mode 100644 recipes/tool_calls/new_ds_test.py diff --git a/recipes/tool_calls/dataset_clean.py b/recipes/tool_calls/dataset_clean.py new file mode 100644 index 0000000..7a37eae --- /dev/null +++ b/recipes/tool_calls/dataset_clean.py @@ -0,0 +1,127 @@ +import ast +import re +import json +import jsonschema +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments +from datasets import Dataset, load_dataset, DatasetDict +from vllm import LLM, SamplingParams + +ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn") + +# 4. 
Load model and tokenizer +model_name = "Qwen/Qwen2.5-7B-Instruct" # or Qwen2.5-3B, Qwen2.5-1.5B +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +def format_example(x, add_generation_prompt=True): + try: + pattern = r'(.*?)' + tools = [s['function'] for s in ast.literal_eval(re.findall(pattern, x['conversations'][0]['value'], re.DOTALL)[-1].strip())] + text = tokenizer.apply_chat_template( + [ + # {"role": "system", "content": "You are an assistant capable of tool calls"}, + {"role": "user", "content": x['conversations'][-2]['value']}, + # {"role": "assistant", "content": x['conversations'][-1]['value']} + ], + tools=tools, + add_generation_prompt=add_generation_prompt, + tokenize=False + ) + return { + "text": text, + # "tools": tools, + "gt_generation": x['conversations'][-1]['value'], + } + except Exception as e: + print(e) + # Return None or a sentinel value to filter out later + return None + +formatted_dataset = ds.map(format_example).filter(lambda x: x is not None) + +llm = LLM(model="Qwen/Qwen2.5-7B-Instruct") +sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=4096) +prompts = [sample['text'] for sample in formatted_dataset['train']] +generations = llm.generate(prompts, sampling_params) + +all_generations = [] +for idx, output in enumerate(generations): + prompt = output.prompt + generated_text = output.outputs[0].text + all_generations.append({ + "prompt": prompt, + "generated_text": generated_text, + "gt_generation": formatted_dataset['train'][idx]['gt_generation'], + "conversations": formatted_dataset['train'][idx]['conversations'], + "idx": formatted_dataset['train'][idx]['id'], + }) + +## Sometimes GT generation is garbage, make sure they are approximately correct with. 
json schema +def get_all_available_tools_dict(prompt): + pattern = r'(.*?)' + try: + data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip() + json_objects = [] + for line in data_string.strip().split('\n'): + if line: # Skip empty lines + json_objects.append(json.loads(line)) + tool_dict = {} + for obj in json_objects: + tool_dict[obj['name']] = obj + return tool_dict + except exception as e: + print(f"Exception {e}") + return {} + +def get_selected_tools_dict(generated_text): + pattern =r'(.*?)' + tools = re.findall(pattern, generated_text, re.DOTALL) + tool_dict = {} + for tool in tools: + try: + tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense + if 'name' not in tool: + print(" No name in tool???") + continue + tool_dict[tool['name']] = tool_dict.get(tool['name'], []) + [tool] + except Exception as e: + print(f"Exception in get_selected_tools_dict for {tool}") + return tool_dict + +def tools_validity(avail_tools, tools): + if not tools: + return False + for tool_name, tools in tools.items(): + for tool in tools: + if tool["name"] not in avail_tools: + print(" Hallucinating b**ch??") + return False + tool_name = tool["name"] + schema = avail_tools[tool_name].get("parameters", {}) # TODO - might depend on dataset + gen = tool["arguments"] + try: + jsonschema.validate(instance=gen, schema=schema) + except jsonschema.ValidationError as e: + print(f"Valid data validation error: {e.message}") + return False + # all tools validated + return True + + +success = 0 +success_idxs = [] +for idx, gen in enumerate(all_generations): + avail_tools = get_all_available_tools_dict(gen['prompt']) + # gen_tools = get_selected_tools_dict(gen['generated_text']) + ref_tools = get_selected_tools_dict(gen['gt_generation']) + validity = tools_validity(avail_tools, ref_tools) + if validity: + success += 1 + success_idxs.append(idx) + +filtered_generations = [all_generations[idx] for idx in success_idxs] + +new_ds = Dataset.from_list(filtered_generations) +dataset_dict = DatasetDict({"train": new_ds}) +dataset_dict.push_to_hub("baseten-admin/CleanNousResearch_simple") \ No newline at end of file diff --git a/recipes/tool_calls/new_ds_test.py b/recipes/tool_calls/new_ds_test.py new file mode 100644 index 0000000..f5bbd85 --- /dev/null +++ b/recipes/tool_calls/new_ds_test.py @@ -0,0 +1,109 @@ +import ast +import re +import json +import jsonschema +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments +from datasets import Dataset, load_dataset, DatasetDict +from vllm import LLM, SamplingParams + +ds = load_dataset("Salesforce/xlam-function-calling-60k") + +# 4. 
Load model and tokenizer +model_name = "Qwen/Qwen2.5-7B-Instruct" # or Qwen2.5-3B, Qwen2.5-1.5B +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +def format_example(x, add_generation_prompt=True): + try: + pattern = r'(.*?)' + tools = json.loads(x['tools']) + text = tokenizer.apply_chat_template( + [ + {"role": "user", "content": x['query']}, + ], + tools=tools, + add_generation_prompt=add_generation_prompt, + tokenize=False + ) + return { + "text": text, + # "tools": tools, + "gt_generation": x['answers'], + } + except Exception as e: + print(e) + # Return None or a sentinel value to filter out later + return None + +LIMIT = 500 +formatted_dataset = ds.map(format_example).filter(lambda x: x is not None) +llm = LLM(model="Qwen/Qwen2.5-7B-Instruct") +sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=4096) +prompts = [sample['text'] for sample in formatted_dataset['train']][:LIMIT] +generations = llm.generate(prompts, sampling_params) + +all_generations = [] +for idx, output in enumerate(generations): + prompt = output.prompt + generated_text = output.outputs[0].text + all_generations.append({ + "prompt": prompt, + "generated_text": generated_text, # needs to be parsed + "gt_generation": formatted_dataset['train'][idx]['gt_generation'], # this is already a list of tools to be called + "tools": formatted_dataset['train'][idx]['tools'], + "idx": formatted_dataset['train'][idx]['id'], + }) + +from typing import List, dict + +def get_selected_tools_dict(generated_text): + """Returns a dict of tool_name: list of tool_calls from the generated text in order""" + pattern =r'(.*?)' + tools = re.findall(pattern, generated_text, re.DOTALL) + parsed_tools = [] + for tool in tools: + try: + # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense + tool = json.loads(tool.encode().decode('unicode_escape').strip()) + parsed_tools.append(tool) + except Exception as e: + print(f"Exception in parsing generated tools for {tool}") + return parsed_tools + +def get_correctness_of_tool_calls(gen_text: str, gt_tools: str): + gen_tools = get_selected_tools_dict(gen_text) + gt_tools = json.loads(gt_tools) + if gen_tools == gt_tools: + return True + return False + +def get_parsed_tool_calls(gen_text: str, gt_tools: str): + gen_tools = get_selected_tools_dict(gen_text) + gt_tools = json.loads(gt_tools) + return gen_tools, gt_tools + +correct_count = 0 +correct_len_incorrect = 0 +too_many_count = 0 +too_few_count = 0 +for idx, sample in enumerate(all_generations): + # print(idx) + # if get_correctness_of_tool_calls(sample['generated_text'], sample['gt_generation']): + # correct_count += 1 + gen_tools, gt_tools = get_parsed_tool_calls(sample['generated_text'], sample['gt_generation']) + if len(gen_tools) > len(gt_tools): + too_many_count += 1 + elif len(gen_tools) < len(gt_tools): + too_few_count += 1 + else: + if gen_tools == gt_tools: + correct_count += 1 + else: + correct_len_incorrect += 1 + +print(f"Total samples: {len(all_generations)}") +print(f"Correct tool calls: {correct_count}") +print(f"Correct length but incorrect tool calls: {correct_len_incorrect}") +print(f"Too many tool calls: {too_many_count}") +print(f"Too few tool calls: {too_few_count}") diff --git a/recipes/tool_calls/tc_grpo.py b/recipes/tool_calls/tc_grpo.py index 21ebd4a..f75cd65 100644 --- a/recipes/tool_calls/tc_grpo.py +++ b/recipes/tool_calls/tc_grpo.py @@ -11,7 +11,9 @@ # 
Load dataset -ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn") +# ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn") +# ds = load_dataset("baseten-admin/CleanNousResearch_simple") +ds = load_dataset("Salesforce/xlam-function-calling-60k") pattern = r'(.*?)' # Load model and tokenizer @@ -19,33 +21,31 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token -def format_example_grpo(x, tokenizer=tokenizer): - """Format examples for GRPO - we need prompts and completions separately""" +def format_example(x, add_generation_prompt=False): try: - tools = [s['function'] for s in ast.literal_eval(re.findall(pattern, x['conversations'][0]['value'], re.DOTALL)[-1].strip())] - # Create the prompt (without assistant response) - prompt = tokenizer.apply_chat_template( + pattern = r'(.*?)' + tools = json.loads(x['tools']) + text = tokenizer.apply_chat_template( [ - {"role": "system", "content": "You are an assistant capable of tool calls"}, - {"role": "user", "content": x['conversations'][-2]['value']}, + {"role": "user", "content": x['query']}, ], tools=tools, - add_generation_prompt=True, + add_generation_prompt=add_generation_prompt, tokenize=False ) - # Get the ground truth completion (reference) - completion = x['conversations'][-1]['value'] return { - "prompt": prompt, - # "completion": completion, - "references": completion, - "tools": json.dumps(tools) - } + "prompt": text, + # "text": text, + # "tools": tools, + "gt_generation": x['answers'], + } except Exception as e: + print(e) + # Return None or a sentinel value to filter out later return None -formatted_dataset = ds.map(format_example_grpo).filter(lambda x: x is not None) -formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"]) +formatted_dataset = ds.map(format_example).filter(lambda x: x is not None) +# formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"]) train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42) train_dataset = train_test_split['train'] @@ -60,7 +60,7 @@ def format_example_grpo(x, tokenizer=tokenizer): torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, - load_in_8bit=True + # load_in_8bit=True ) # Prepare model for LoRA fine-tuning @@ -69,16 +69,47 @@ def format_example_grpo(x, tokenizer=tokenizer): lora_config = LoraConfig( r=16, lora_alpha=32, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "all_linear"], lora_dropout=0.05, bias="none", - task_type="CAUSAL_LM" + # task_type="CAUSAL_LM" ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() # Define reward function for tool calling + +def get_all_available_tools_dict(prompt): + pattern = r'(.*?)' + try: + data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip() + json_objects = [] + for line in data_string.strip().split('\n'): + if line: # Skip empty lines + json_objects.append(json.loads(line)) + tool_dict = {} + for obj in json_objects: + tool_dict[obj['name']] = obj + return tool_dict + except exception as e: + print(f"Exception {e}") + return {} + +def get_selected_tools_list(generated_text): + """Returns a dict of tool_name: list of tool_calls from the generated text in order""" + pattern =r'(.*?)' + tools = re.findall(pattern, generated_text, re.DOTALL) + parsed_tools = [] + 
for tool in tools: + try: + # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense + tool = json.loads(tool.encode().decode('unicode_escape').strip()) + parsed_tools.append(tool) + except Exception as e: + print(f"Exception in parsing generated tools for {tool}") + return parsed_tools + # def reward_function(prompts, completions, references, tools_list): def reward_function(prompts, completions, **kwargs): """ @@ -89,115 +120,69 @@ def reward_function(prompts, completions, **kwargs): # for key in kwargs: # print(f" {key}: {type(kwargs[key])}") rewards = [] - # completions = kwargs.get("completions", []) - references = kwargs.get("references", []) + references = kwargs.get("gt_generation", []) # references = kwargs.get("completion", []) tools = kwargs.get("tools", []) - - for completion, reference in zip(completions, references): + for prompts, completion, reference in zip(prompts, completions, references): + available_tools = get_all_available_tools_dict(prompts) # dict of tool_name: tool_def reward = 0.0 - ref_has_tool = '' in reference - comp_has_tool = '' in completion - # Case 1: Reference has tool call - if ref_has_tool: - if comp_has_tool: - # Both have tool calls - compare quality - try: - tool_call_match = re.search(r'(.*?)', completion, re.DOTALL) - ref_tool_match = re.search(r'(.*?)', reference, re.DOTALL) - if tool_call_match and ref_tool_match: - comp_tool = json.loads(tool_call_match.group(1)) - ref_tool = json.loads(ref_tool_match.group(1)) - - # Reward for valid JSON structure - reward += 3.0 - - # Reward for correct tool name - if comp_tool.get('name') == ref_tool.get('name'): - reward += 5.0 - - # Reward for matching arguments - if 'arguments' in comp_tool and 'arguments' in ref_tool: - comp_args = set(comp_tool['arguments'].keys()) - ref_args = set(ref_tool['arguments'].keys()) - - # Jaccard similarity for arguments - if len(ref_args) > 0: - overlap = len(comp_args & ref_args) - union = len(comp_args | ref_args) - reward += (overlap / union) * 4.0 - else: - # Wrong tool name - reward -= 3.0 - else: - # Has tool call tags but invalid JSON - reward += 1.0 - - except json.JSONDecodeError: - # Malformed JSON in tool call - reward -= 2.0 - else: - # Missing tool call when one is needed - reward -= 5.0 - # Case 2: Reference has no tool call - else: - if comp_has_tool: - # False positive - made a tool call when not needed - reward -= 3.0 + try: + gen_tools = get_selected_tools_list(completion) + except: + # not parsable or correct structure + print("Count not parse") + reward = -3.0 + rewards.append(reward) + continue + gt_tools = json.loads(reference) + if gt_tools == gen_tools: # perfect match + rewards.append(4.0) + continue + # lengths are unequal or something else is wrong + PENALTY_NAME = 1.0 + PENALTY_MISSING_PARAM = 0.5 + PENALTY_WRONG_PARAM = 0.5 + PENALTY_EXTRA_PARAM = 0.25 + REWARD_CORRECT_CALL = 1.0 + reward = 0.0 + # Account for extra/missing tool calls + if len(gen_tools) < len(gt_tools): + reward -= PENALTY_NAME * (len(gt_tools) - len(gen_tools)) + elif len(gen_tools) > len(gt_tools): + reward -= PENALTY_NAME * (len(gen_tools) - len(gt_tools)) + # Compare in order + for gt_tool, gen_tool in zip(gt_tools, gen_tools): + # tool name + if gt_tool['name'] != gen_tool['name']: + reward -= PENALTY_NAME + continue + # params + gt_params = gt_tool.get('parameters', {}) + gen_params = gen_tool.get('parameters', {}) + # missing or wrong params + for key, value in gt_params.items(): + if key not in 
gen_params: + reward -= PENALTY_MISSING_PARAM + elif gen_params[key] != value: + reward -= PENALTY_WRONG_PARAM + # extra params + for key in gen_params: + if key not in gt_params: + reward -= PENALTY_EXTRA_PARAM + # reward for correct call (name + maybe params) + # But you might choose: only if no param penalties + if gt_params and not any(key not in gen_params or gen_params[key] != gt_params[key] for key in gt_params): + # all params matched + reward += REWARD_CORRECT_CALL + elif not gt_params: + # no params expected, name matches + reward += REWARD_CORRECT_CALL else: - # Correct - no tool call needed - reward += 2.0 - # Small penalty for excessive length (normalized) - len_ratio = len(completion) / max(len(reference), 1) - if len_ratio > 2.0: - reward -= 1.0 + # partial match: maybe +0 or some smaller reward + reward += REWARD_CORRECT_CALL * 0.5 rewards.append(reward) return rewards - - # 1. Check if completion contains valid JSON tool call - # try: - # # Look for tool call patterns - # if '' in completion and '' in completion: - # reward += 2.0 - - # # Extract and validate JSON - # tool_call_match = re.search(r'(.*?)', completion, re.DOTALL) - # if tool_call_match: - # tool_call_json = json.loads(tool_call_match.group(1)) - # reward += 2.0 - - # # Check if tool name matches reference - # if 'name' in tool_call_json: - # ref_tool_match = re.search(r'(.*?)', reference, re.DOTALL) - # if ref_tool_match: - # ref_tool = json.loads(ref_tool_match.group(1)) - # if tool_call_json.get('name') == ref_tool.get('name'): - # reward += 3.0 - - # # Check if arguments overlap - # if 'arguments' in tool_call_json and 'arguments' in ref_tool: - # arg_overlap = len(set(tool_call_json['arguments'].keys()) & - # set(ref_tool['arguments'].keys())) - # reward += arg_overlap * 1.0 - # except: - # reward -= 1.0 - - # # 2. Penalize if no tool call when reference has one - # if '' in reference and '' not in completion: - # reward -= 5.0 - - # # 3. Penalize if tool call when reference doesn't have one - # if '' not in reference and '' in completion: - # reward -= 2.0 - - # # 4. 
Length penalty (avoid overly verbose responses) - # if len(completion) > len(reference) * 2: - # reward -= 1.0 - - # rewards.append(reward) - - # return rewards # GRPO Configuration grpo_config = GRPOConfig( @@ -205,8 +190,9 @@ def reward_function(prompts, completions, **kwargs): num_train_epochs=3, per_device_train_batch_size=2, per_device_eval_batch_size=4, - gradient_accumulation_steps=8, - learning_rate=5e-6, + gradient_accumulation_steps=2, + # learning_rate=5e-4, + learning_rate=1e-5, warmup_steps=100, logging_steps=10, save_steps=100, From f068a584078de7201d77e16c906b39ff5ab02923 Mon Sep 17 00:00:00 2001 From: Nidhi Hiremath Date: Thu, 23 Oct 2025 10:23:45 -0700 Subject: [PATCH 4/5] add wandb project + make test set smaller for faster iter --- recipes/tool_calls/tc_grpo.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/recipes/tool_calls/tc_grpo.py b/recipes/tool_calls/tc_grpo.py index f75cd65..3acde99 100644 --- a/recipes/tool_calls/tc_grpo.py +++ b/recipes/tool_calls/tc_grpo.py @@ -2,6 +2,7 @@ import ast import re +import os import json import torch from transformers import AutoTokenizer, AutoModelForCausalLM @@ -9,7 +10,7 @@ from datasets import load_dataset from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training - +os.environ["WANDB_PROJECT"] = "qwen-tool-calling-grpo-t" # Load dataset # ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn") # ds = load_dataset("baseten-admin/CleanNousResearch_simple") @@ -47,7 +48,8 @@ def format_example(x, add_generation_prompt=False): formatted_dataset = ds.map(format_example).filter(lambda x: x is not None) # formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"]) -train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42) +# train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42) +train_test_split = formatted_dataset['train'].train_test_split(test_size=300, seed=42) train_dataset = train_test_split['train'] eval_dataset = train_test_split['test'] @@ -103,11 +105,18 @@ def get_selected_tools_list(generated_text): parsed_tools = [] for tool in tools: try: + stripped_tool = tool.encode().decode('unicode_escape').strip() # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense - tool = json.loads(tool.encode().decode('unicode_escape').strip()) + tool = json.loads(stripped_tool) parsed_tools.append(tool) - except Exception as e: - print(f"Exception in parsing generated tools for {tool}") + except: + try: + tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) + print("Prsed with ast.literal_eval") + parsed_tools.append(tool) + except Exception as e: + print(f"Exception in parsing generated tools for {tool}") + print("Generated text\n", generated_text) return parsed_tools # def reward_function(prompts, completions, references, tools_list): @@ -182,6 +191,7 @@ def reward_function(prompts, completions, **kwargs): # partial match: maybe +0 or some smaller reward reward += REWARD_CORRECT_CALL * 0.5 rewards.append(reward) + print("Rewards:", rewards) return rewards # GRPO Configuration @@ -199,7 +209,8 @@ def reward_function(prompts, completions, **kwargs): eval_strategy="steps", eval_steps=10, bf16=True, - max_grad_norm=0.3, + max_grad_norm=1.0, + report_to=["wandb"], # GRPO-specific parameters num_generations=4, # Number of generations 
per prompt for group comparison @@ -247,7 +258,7 @@ def test_tool_calling(prompt, tools): with torch.no_grad(): outputs = model.generate( **inputs, - max_new_tokens=512, + max_new_tokens=1024, temperature=0.7, do_sample=True ) From f8f898b01b1bc3e637e399b409615df660f210bb Mon Sep 17 00:00:00 2001 From: Nidhi Hiremath Date: Mon, 27 Oct 2025 15:26:09 -0700 Subject: [PATCH 5/5] more bkup --- recipes/tool_calls/filter_ds.py | 133 ++++++++++++++ recipes/tool_calls/grpo_pt2.py | 296 ++++++++++++++++++++++++++++++++ recipes/tool_calls/tc_grpo.py | 4 +- 3 files changed, 431 insertions(+), 2 deletions(-) create mode 100644 recipes/tool_calls/filter_ds.py create mode 100644 recipes/tool_calls/grpo_pt2.py diff --git a/recipes/tool_calls/filter_ds.py b/recipes/tool_calls/filter_ds.py new file mode 100644 index 0000000..bb5abeb --- /dev/null +++ b/recipes/tool_calls/filter_ds.py @@ -0,0 +1,133 @@ +import ast +import re +import json +import jsonschema +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments +from datasets import Dataset, load_dataset, DatasetDict +from vllm import LLM, SamplingParams + +ds = load_dataset("Salesforce/xlam-function-calling-60k") + +# 4. Load model and tokenizer +model_name = "Qwen/Qwen2.5-7B-Instruct" # or Qwen2.5-3B, Qwen2.5-1.5B +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +def format_example(x, add_generation_prompt=True): + try: + pattern = r'(.*?)' + tools = json.loads(x['tools']) + text = tokenizer.apply_chat_template( + [ + {"role": "user", "content": x['query']}, + ], + tools=tools, + add_generation_prompt=add_generation_prompt, + tokenize=False + ) + return { + "text": text, + # "tools": tools, + "gt_generation": x['answers'], + } + except Exception as e: + print(e) + # Return None or a sentinel value to filter out later + return None + +LIMIT = 60000 +formatted_dataset = ds.map(format_example).filter(lambda x: x is not None) +llm = LLM(model="Qwen/Qwen2.5-7B-Instruct") +sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=1024) +prompts = [sample['text'] for sample in formatted_dataset['train']][:LIMIT] +generations = llm.generate(prompts, sampling_params) + +all_generations = [] +for idx, output in enumerate(generations): + prompt = output.prompt + generated_text = output.outputs[0].text + all_generations.append({ + "prompt": prompt, + "generated_text": generated_text, # needs to be parsed + "gt_generation": formatted_dataset['train'][idx]['gt_generation'], # this is already a list of tools to be called + "tools": formatted_dataset['train'][idx]['tools'], + "query": formatted_dataset['train'][idx]['query'], + "answers": formatted_dataset['train'][idx]['answers'], + "idx": formatted_dataset['train'][idx]['id'], + }) + +from typing import List, dict + +def get_selected_tools_dict(generated_text): + """Returns a dict of tool_name: list of tool_calls from the generated text in order""" + pattern =r'(.*?)' + tools = re.findall(pattern, generated_text, re.DOTALL) + parsed_tools = [] + for tool in tools: + try: + # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense + # tool = json.loads(tool.encode().decode('unicode_escape').strip()) + tool = json.loads(tool) + parsed_tools.append(tool) + except: + try: + tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) + parsed_tools.append(tool) + except Exception as e: + 
print(f"Exception in parsing generated tools for {tool}") + print("Generated text\n", generated_text) + return parsed_tools + +def get_correctness_of_tool_calls(gen_text: str, gt_tools: str): + gen_tools = get_selected_tools_dict(gen_text) + gt_tools = json.loads(gt_tools) + if gen_tools == gt_tools: + return True + return False + +def get_parsed_tool_calls(gen_text: str, gt_tools: str): + gen_tools = get_selected_tools_dict(gen_text) + gt_tools = json.loads(gt_tools) + return gen_tools, gt_tools + +correct_count = 0 +correct_len_incorrect = 0 +too_many_count = 0 +too_few_count = 0 +incorrect_idxs = [] +for idx, sample in enumerate(all_generations): + # print(idx) + # if get_correctness_of_tool_calls(sample['generated_text'], sample['gt_generation']): + # correct_count += 1 + gen_tools, gt_tools = get_parsed_tool_calls(sample['generated_text'], sample['gt_generation']) + if gen_tools != gt_tools: + incorrect_idxs.append(sample['idx']) + # if len(gen_tools) > len(gt_tools): + # too_many_count += 1 + # elif len(gen_tools) < len(gt_tools): + # too_few_count += 1 + # else: + # if gen_tools == gt_tools: + # correct_count += 1 + # else: + # correct_len_incorrect += 1 + +# print(f"Total samples: {len(all_generations)}") +# print(f"Correct tool calls: {correct_count}") +# print(f"Correct length but incorrect tool calls: {correct_len_incorrect}") +# print(f"Too many tool calls: {too_many_count}") +# print(f"Too few tool calls: {too_few_count}") + +incorrect_set = set(incorrect_idxs) +correct_idxs = [i for i in range(len(all_generations)) if i not in incorrect_set] + +import random +sampled_correct_idxs = random.sample(correct_idxs, 7000) +filtered_generations = [all_generations[i] for i in sampled_correct_idxs] + [all_generations[i] for i in incorrect_idxs] +random.shuffle(filtered_generations) + + +new_ds = Dataset.from_list(filtered_generations) +dataset_dict = DatasetDict({"train": new_ds}) +dataset_dict.push_to_hub("baseten-admin/xlam-function-calling-sampled") \ No newline at end of file diff --git a/recipes/tool_calls/grpo_pt2.py b/recipes/tool_calls/grpo_pt2.py new file mode 100644 index 0000000..acbb1e4 --- /dev/null +++ b/recipes/tool_calls/grpo_pt2.py @@ -0,0 +1,296 @@ +## pip install -U bitsandbytes trl transformers datasets peft accelerate + +import ast +import re +import os +import json +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from trl import GRPOConfig, GRPOTrainer +from datasets import load_dataset +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + +os.environ["WANDB_PROJECT"] = "qwen-tool-calling-grpo-t" +# Load dataset +# ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn") +# ds = load_dataset("baseten-admin/CleanNousResearch_simple") +ds = load_dataset("Salesforce/xlam-function-calling-60k") +pattern = r'(.*?)' + +# Load model and tokenizer +model_name = "Qwen/Qwen2.5-7B-Instruct" +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +deflearn(x, add_generation_prompt=False): + try: + pattern = r'(.*?)' + tools = json.loads(x['tools']) + text = tokenizer.apply_chat_template( + [ + {"role": "user", "content": x['query']}, + ], + tools=tools, + add_generation_prompt=add_generation_prompt, + tokenize=False + ) + return { + "prompt": text, + # "text": text, + # "tools": tools, + "gt_generation": x['answers'], + } + except Exception as e: + print(e) + # Return None or a sentinel value to filter out later + return None 
+ +formatted_dataset = ds.map(format_example).filter(lambda x: x is not None) +# formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"]) + +# train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42) +train_test_split = formatted_dataset['train'].train_test_split(test_size=300, seed=42) +train_dataset = train_test_split['train'] +eval_dataset = train_test_split['test'] + +print(f"Training samples: {len(train_dataset)}") +print(f"Validation samples: {len(eval_dataset)}") + +# Load model with quantization +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + # load_in_8bit=True +) + +# Prepare model for LoRA fine-tuning +model = prepare_model_for_kbit_training(model) + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "all_linear"], + lora_dropout=0.05, + bias="none", + # task_type="CAUSAL_LM" +) + +model = get_peft_model(model, lora_config) +model.print_trainable_parameters() + +# Define reward function for tool calling + +def get_all_available_tools_dict(prompt): + pattern = r'(.*?)' + try: + data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip() + json_objects = [] + for line in data_string.strip().split('\n'): + if line: # Skip empty lines + json_objects.append(json.loads(line)) + tool_dict = {} + for obj in json_objects: + tool_dict[obj['name']] = obj + return tool_dict + except exception as e: + print(f"Exception {e}") + return {} + +def get_selected_tools_list(generated_text): + """Returns a dict of tool_name: list of tool_calls from the generated text in order""" + pattern =r'(.*?)' + tools = re.findall(pattern, generated_text, re.DOTALL) + parsed_tools = [] + for tool in tools: + try: + # stripped_tool = tool.encode().decode('unicode_escape').strip() + # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense + tool = json.loads(tool) + parsed_tools.append(tool) + except: + try: + tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) + print("Prsed with ast.literal_eval") + parsed_tools.append(tool) + except Exception as e: + print(f"Exception in parsing generated tools for {tool}") + print("Generated text\n", generated_text) + return parsed_tools + +# def reward_function(prompts, completions, references, tools_list): +def reward_function(prompts, completions, **kwargs): + """ + Reward function that evaluates the quality of tool calls. + Returns a list of reward scores (one per completion). 
+ """ + # print("All kwargs:") + # for key in kwargs: + # print(f" {key}: {type(kwargs[key])}") + rewards = [] + # completions = kwargs.get("completions", []) + references = kwargs.get("gt_generation", []) + # references = kwargs.get("completion", []) + tools = kwargs.get("tools", []) + for prompts, completion, reference in zip(prompts, completions, references): + available_tools = get_all_available_tools_dict(prompts) # dict of tool_name: tool_def + reward = 0.0 + try: + gen_tools = get_selected_tools_list(completion) + except: + # not parsable or correct structure + print("Count not parse") + reward = -3.0 + rewards.append(reward) + continue + gt_tools = json.loads(reference) + if gt_tools == gen_tools: # perfect match + rewards.append(4.0) + continue + # lengths are unequal or something else is wrong + PENALTY_NAME = 1.0 + PENALTY_MISSING_PARAM = 0.5 + PENALTY_WRONG_PARAM = 0.5 + PENALTY_EXTRA_PARAM = 0.25 + REWARD_CORRECT_CALL = 1.0 + reward = 0.0 + # Account for extra/missing tool calls + if len(gen_tools) < len(gt_tools): + reward -= PENALTY_NAME * (len(gt_tools) - len(gen_tools)) + elif len(gen_tools) > len(gt_tools): + reward -= PENALTY_NAME * (len(gen_tools) - len(gt_tools)) + # Compare in order + for gt_tool, gen_tool in zip(gt_tools, gen_tools): + # tool name + if 'name' not in gt_tool: + print("GT tool missing name:", gt_tool) + # reward = 0.0 + continue + if 'name' not in gen_tool: + print("Generated tool missing name:", gen_tool) + gen_tool['name'] = '' + if gt_tool['name'] != gen_tool['name']: + reward -= PENALTY_NAME + continue + # params + gt_params = gt_tool.get('parameters', {}) + gen_params = gen_tool.get('parameters', {}) + # missing or wrong params + for key, value in gt_params.items(): + if key not in gen_params: + reward -= PENALTY_MISSING_PARAM + elif gen_params[key] != value: + reward -= PENALTY_WRONG_PARAM + # extra params + for key in gen_params: + if key not in gt_params: + reward -= PENALTY_EXTRA_PARAM + # reward for correct call (name + maybe params) + # But you might choose: only if no param penalties + if gt_params and not any(key not in gen_params or gen_params[key] != gt_params[key] for key in gt_params): + # all params matched + reward += REWARD_CORRECT_CALL + elif not gt_params: + # no params expected, name matches + reward += REWARD_CORRECT_CALL + else: + # partial match: maybe +0 or some smaller reward + reward += REWARD_CORRECT_CALL * 0.5 + rewards.append(reward) + print("Rewards:", rewards) + return rewards + +# GRPO Configuration +grpo_config = GRPOConfig( + output_dir="./qwen-tool-calling-1", + num_train_epochs=3, + per_device_train_batch_size=2, + per_device_eval_batch_size=4, + # gradient_accumulation_steps=2, + gradient_accumulation_steps=1, + learning_rate=5e-4, + # learning_rate=1e-5, + lr_scheduler="cosine", + warmup_steps=50, + logging_steps=10, + save_steps=200, + eval_strategy="steps", + eval_steps=200, + bf16=True, + max_grad_norm=1.0, + report_to=["wandb"], + + # GRPO-specific parameters + num_generations=8, # Number of generations per prompt for group comparison + temperature=0.9, # Sampling temperature + max_new_tokens=512, + kl=0.05, # KL divergence coefficient to prevent drift from reference +) + +# Initialize GRPO Trainer +trainer = GRPOTrainer( + model=model, + args=grpo_config, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + # tokenizer=tokenizer, + reward_funcs=[reward_function], +) + +# Train the model +print("Starting GRPO training...") +trainer.train() + +# Save the fine-tuned model 
+trainer.save_model("./qwen-tool-calling-grpo-1") +tokenizer.save_pretrained("./qwen-tool-calling-grpo-1") + +print("Training complete!") + +# Example inference +def test_tool_calling(prompt, tools): + model.eval() + + formatted_prompt = tokenizer.apply_chat_template( + [ + {"role": "system", "content": "You are an assistant capable of tool calls"}, + {"role": "user", "content": prompt} + ], + tools=tools, + add_generation_prompt=True, + tokenize=False + ) + + inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=512, + temperature=0.7, + do_sample=True + ) + + response = tokenizer.decode(outputs[0], skip_special_tokens=False) + print(response) + +# Test the model +test_tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"} + }, + "required": ["location"] + } + } + } +] + +test_tool_calling("What's the weather in London?", test_tools) \ No newline at end of file diff --git a/recipes/tool_calls/tc_grpo.py b/recipes/tool_calls/tc_grpo.py index 3acde99..772a240 100644 --- a/recipes/tool_calls/tc_grpo.py +++ b/recipes/tool_calls/tc_grpo.py @@ -69,9 +69,9 @@ def format_example(x, add_generation_prompt=False): model = prepare_model_for_kbit_training(model) lora_config = LoraConfig( - r=16, + r=16, # higher rank = more weights in lora lora_alpha=32, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "all_linear"], + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "all_linear"], lora_dropout=0.05, bias="none", # task_type="CAUSAL_LM"
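A possible follow-up once either script finishes: because the model is wrapped with PEFT before training, trainer.save_model(...) typically writes just the LoRA adapter weights (plus the tokenizer saved alongside it), so reusing the result means re-attaching that adapter to the base model. The sketch below is not part of the patches above; it assumes the GRPO script's output directory ./qwen-tool-calling-grpo-1, and the merge_and_unload() step and generation settings are illustrative choices rather than anything these commits prescribe.

# Sketch: reload the saved LoRA adapter on top of the base model for inference.
# Assumes the adapter and tokenizer were written by trainer.save_model("./qwen-tool-calling-grpo-1")
# and tokenizer.save_pretrained("./qwen-tool-calling-grpo-1") as in tc_grpo.py / grpo_pt2.py.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_name = "Qwen/Qwen2.5-7B-Instruct"
adapter_dir = "./qwen-tool-calling-grpo-1"

tokenizer = AutoTokenizer.from_pretrained(adapter_dir, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    base_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
model = PeftModel.from_pretrained(base_model, adapter_dir)
model = model.merge_and_unload()  # optional: fold the LoRA deltas into the base weights
model.eval()

# Same toy weather tool used at the bottom of tc_grpo.py.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string", "description": "City name"}},
            "required": ["location"],
        },
    },
}]
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What's the weather in London?"}],
    tools=tools,
    add_generation_prompt=True,
    tokenize=False,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
# Decode only the newly generated tokens so the tool-call block is easy to inspect.
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False))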
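Along the same lines, the exact-match counting loop from new_ds_test.py / filter_ds.py can be rerun against a fine-tuned checkpoint to see whether GRPO actually moved the numbers. This is a sketch, not part of the patches: the checkpoint path is hypothetical (a merged adapter directory), the tag parsing assumes the <tool_call>...</tool_call> wrapping emitted by the Qwen2.5 chat template that the regexes in these scripts appear to target, and xlam's answers field is treated, as in the scripts, as a JSON list of {"name", "arguments"} objects.

# Sketch: exact-match tool-call accuracy on a small slice of xlam-function-calling-60k,
# mirroring the comparison in new_ds_test.py (generated calls vs. json.loads(answers)).
import json
import re
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

MODEL_DIR = "./qwen-tool-calling-grpo-1-merged"  # hypothetical merged checkpoint directory
N_EVAL = 300

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
ds = load_dataset("Salesforce/xlam-function-calling-60k", split="train").select(range(N_EVAL))

def parse_tool_calls(text):
    """Pull JSON objects out of <tool_call>...</tool_call> blocks, in order."""
    calls = []
    for block in re.findall(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL):
        try:
            calls.append(json.loads(block.strip()))
        except json.JSONDecodeError:
            pass  # unparseable block counts as a miss
    return calls

# Build prompts the same way format_example does: user query plus the row's tool list.
prompts = [
    tokenizer.apply_chat_template(
        [{"role": "user", "content": row["query"]}],
        tools=json.loads(row["tools"]),
        add_generation_prompt=True,
        tokenize=False,
    )
    for row in ds
]

llm = LLM(model=MODEL_DIR)
outputs = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=1024))

correct = sum(
    parse_tool_calls(out.outputs[0].text) == json.loads(row["answers"])
    for out, row in zip(outputs, ds)
)
print(f"exact-match tool-call accuracy: {correct}/{N_EVAL}")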