diff --git a/recipes/tool_calls/dataset_clean.py b/recipes/tool_calls/dataset_clean.py
new file mode 100644
index 0000000..7a37eae
--- /dev/null
+++ b/recipes/tool_calls/dataset_clean.py
@@ -0,0 +1,127 @@
+import ast
+import re
+import json
+import jsonschema
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+from datasets import Dataset, load_dataset, DatasetDict
+from vllm import LLM, SamplingParams
+
+ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn")
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct" # or Qwen2.5-3B, Qwen2.5-1.5B
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=True):
+ try:
+        pattern = r'<tools>(.*?)</tools>'
+        tools_str = re.findall(pattern, x['conversations'][0]['value'], re.DOTALL)[-1].strip()
+        tools = [s['function'] for s in ast.literal_eval(tools_str)]
+ text = tokenizer.apply_chat_template(
+ [
+ # {"role": "system", "content": "You are an assistant capable of tool calls"},
+ {"role": "user", "content": x['conversations'][-2]['value']},
+ # {"role": "assistant", "content": x['conversations'][-1]['value']}
+ ],
+ tools=tools,
+ add_generation_prompt=add_generation_prompt,
+ tokenize=False
+ )
+ return {
+ "text": text,
+ # "tools": tools,
+ "gt_generation": x['conversations'][-1]['value'],
+ }
+    except Exception as e:
+        print(e)
+        # Return a sentinel row so it can be filtered out below
+        # (datasets.map raises if the mapped function returns None)
+        return {"text": None, "gt_generation": None}
+
+formatted_dataset = ds.map(format_example).filter(lambda x: x['text'] is not None)
+
+llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")
+sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=4096)
+prompts = [sample['text'] for sample in formatted_dataset['train']]
+generations = llm.generate(prompts, sampling_params)
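+# vLLM returns one RequestOutput per prompt, in the same order as the input
+# list, so enumerate() below stays aligned with formatted_dataset['train'].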
+
+all_generations = []
+for idx, output in enumerate(generations):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ all_generations.append({
+ "prompt": prompt,
+ "generated_text": generated_text,
+ "gt_generation": formatted_dataset['train'][idx]['gt_generation'],
+ "conversations": formatted_dataset['train'][idx]['conversations'],
+ "idx": formatted_dataset['train'][idx]['id'],
+ })
+
+## Sometimes the GT generation is garbage; make sure each ground-truth call is at least approximately correct by validating it against the tool's JSON schema.
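+## The validation below is two-part: every tool name referenced in the ground
+## truth must exist among the tools advertised in the prompt, and its arguments
+## must validate against that tool's declared "parameters" JSON schema.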
+def get_all_available_tools_dict(prompt):
+    pattern = r'<tools>(.*?)</tools>'
+ try:
+ data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip()
+ json_objects = []
+ for line in data_string.strip().split('\n'):
+ if line: # Skip empty lines
+ json_objects.append(json.loads(line))
+ tool_dict = {}
+ for obj in json_objects:
+ tool_dict[obj['name']] = obj
+ return tool_dict
+    except Exception as e:
+ print(f"Exception {e}")
+ return {}
+
+def get_selected_tools_dict(generated_text):
+    pattern = r'<tool_call>(.*?)</tool_call>'
+ tools = re.findall(pattern, generated_text, re.DOTALL)
+ tool_dict = {}
+ for tool in tools:
+ try:
+            tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'"))  # unescape, then coerce to single quotes so literal_eval accepts it
+            if 'name' not in tool:
+                print("Parsed tool call has no 'name' field; skipping")
+                continue
+ tool_dict[tool['name']] = tool_dict.get(tool['name'], []) + [tool]
+ except Exception as e:
+ print(f"Exception in get_selected_tools_dict for {tool}")
+ return tool_dict
+
+def tools_validity(avail_tools, tools):
+ if not tools:
+ return False
+    for tool_name, tool_calls in tools.items():
+        for tool in tool_calls:
+            if tool["name"] not in avail_tools:
+                print("Generated tool name is not among the available tools")
+                return False
+            schema = avail_tools[tool_name].get("parameters", {})  # TODO - might depend on dataset
+ gen = tool["arguments"]
+ try:
+ jsonschema.validate(instance=gen, schema=schema)
+ except jsonschema.ValidationError as e:
+ print(f"Valid data validation error: {e.message}")
+ return False
+ # all tools validated
+ return True
+
+
+success = 0
+success_idxs = []
+for idx, gen in enumerate(all_generations):
+ avail_tools = get_all_available_tools_dict(gen['prompt'])
+ # gen_tools = get_selected_tools_dict(gen['generated_text'])
+ ref_tools = get_selected_tools_dict(gen['gt_generation'])
+ validity = tools_validity(avail_tools, ref_tools)
+ if validity:
+ success += 1
+ success_idxs.append(idx)
+
+filtered_generations = [all_generations[idx] for idx in success_idxs]
+
+new_ds = Dataset.from_list(filtered_generations)
+dataset_dict = DatasetDict({"train": new_ds})
+dataset_dict.push_to_hub("baseten-admin/CleanNousResearch_simple")
\ No newline at end of file
diff --git a/recipes/tool_calls/filter_ds.py b/recipes/tool_calls/filter_ds.py
new file mode 100644
index 0000000..bb5abeb
--- /dev/null
+++ b/recipes/tool_calls/filter_ds.py
@@ -0,0 +1,133 @@
+import ast
+import re
+import json
+import jsonschema
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+from datasets import Dataset, load_dataset, DatasetDict
+from vllm import LLM, SamplingParams
+
+ds = load_dataset("Salesforce/xlam-function-calling-60k")
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct" # or Qwen2.5-3B, Qwen2.5-1.5B
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=True):
+ try:
+ tools = json.loads(x['tools'])
+ text = tokenizer.apply_chat_template(
+ [
+ {"role": "user", "content": x['query']},
+ ],
+ tools=tools,
+ add_generation_prompt=add_generation_prompt,
+ tokenize=False
+ )
+ return {
+ "text": text,
+ # "tools": tools,
+ "gt_generation": x['answers'],
+ }
+    except Exception as e:
+        print(e)
+        # Return a sentinel row so it can be filtered out below
+        # (datasets.map raises if the mapped function returns None)
+        return {"text": None, "gt_generation": None}
+
+LIMIT = 60000
+formatted_dataset = ds.map(format_example).filter(lambda x: x['text'] is not None)
+llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")
+sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=1024)
+prompts = [sample['text'] for sample in formatted_dataset['train']][:LIMIT]
+generations = llm.generate(prompts, sampling_params)
+
+all_generations = []
+for idx, output in enumerate(generations):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ all_generations.append({
+ "prompt": prompt,
+ "generated_text": generated_text, # needs to be parsed
+ "gt_generation": formatted_dataset['train'][idx]['gt_generation'], # this is already a list of tools to be called
+ "tools": formatted_dataset['train'][idx]['tools'],
+ "query": formatted_dataset['train'][idx]['query'],
+ "answers": formatted_dataset['train'][idx]['answers'],
+ "idx": formatted_dataset['train'][idx]['id'],
+ })
+
+from typing import List, Dict
+
+def get_selected_tools_dict(generated_text):
+ """Returns a dict of tool_name: list of tool_calls from the generated text in order"""
+ pattern =r'(.*?)'
+ tools = re.findall(pattern, generated_text, re.DOTALL)
+ parsed_tools = []
+ for tool in tools:
+ try:
+ # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense
+ # tool = json.loads(tool.encode().decode('unicode_escape').strip())
+ tool = json.loads(tool)
+ parsed_tools.append(tool)
+        except Exception:
+ try:
+ tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'"))
+ parsed_tools.append(tool)
+ except Exception as e:
+ print(f"Exception in parsing generated tools for {tool}")
+ print("Generated text\n", generated_text)
+ return parsed_tools
+
+def get_correctness_of_tool_calls(gen_text: str, gt_tools: str):
+    gen_tools = get_selected_tools_dict(gen_text)
+    gt_tools = json.loads(gt_tools)
+    return gen_tools == gt_tools
+
+def get_parsed_tool_calls(gen_text: str, gt_tools: str):
+ gen_tools = get_selected_tools_dict(gen_text)
+ gt_tools = json.loads(gt_tools)
+ return gen_tools, gt_tools
+
+correct_count = 0
+correct_len_incorrect = 0
+too_many_count = 0
+too_few_count = 0
+incorrect_idxs = []
+for idx, sample in enumerate(all_generations):
+ # print(idx)
+ # if get_correctness_of_tool_calls(sample['generated_text'], sample['gt_generation']):
+ # correct_count += 1
+ gen_tools, gt_tools = get_parsed_tool_calls(sample['generated_text'], sample['gt_generation'])
+ if gen_tools != gt_tools:
+        incorrect_idxs.append(idx)  # use the position in all_generations, not the dataset 'id', so list indexing below stays valid
+ # if len(gen_tools) > len(gt_tools):
+ # too_many_count += 1
+ # elif len(gen_tools) < len(gt_tools):
+ # too_few_count += 1
+ # else:
+ # if gen_tools == gt_tools:
+ # correct_count += 1
+ # else:
+ # correct_len_incorrect += 1
+
+# print(f"Total samples: {len(all_generations)}")
+# print(f"Correct tool calls: {correct_count}")
+# print(f"Correct length but incorrect tool calls: {correct_len_incorrect}")
+# print(f"Too many tool calls: {too_many_count}")
+# print(f"Too few tool calls: {too_few_count}")
+
+incorrect_set = set(incorrect_idxs)
+correct_idxs = [i for i in range(len(all_generations)) if i not in incorrect_set]
+
+import random
+sampled_correct_idxs = random.sample(correct_idxs, min(7000, len(correct_idxs)))  # guard against fewer than 7k correct rows
+filtered_generations = [all_generations[i] for i in sampled_correct_idxs] + [all_generations[i] for i in incorrect_idxs]
+random.shuffle(filtered_generations)
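+# Keep every sample the base model got wrong plus a 7k subsample of the ones it
+# already gets right: presumably this biases the pushed dataset toward hard
+# cases without discarding the easy ones entirely.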
+
+
+new_ds = Dataset.from_list(filtered_generations)
+dataset_dict = DatasetDict({"train": new_ds})
+dataset_dict.push_to_hub("baseten-admin/xlam-function-calling-sampled")
\ No newline at end of file
diff --git a/recipes/tool_calls/grpo_pt2.py b/recipes/tool_calls/grpo_pt2.py
new file mode 100644
index 0000000..acbb1e4
--- /dev/null
+++ b/recipes/tool_calls/grpo_pt2.py
@@ -0,0 +1,296 @@
+## pip install -U bitsandbytes trl transformers datasets peft accelerate
+
+import ast
+import re
+import os
+import json
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import GRPOConfig, GRPOTrainer
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
+os.environ["WANDB_PROJECT"] = "qwen-tool-calling-grpo-t"
+# Load dataset
+# ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn")
+# ds = load_dataset("baseten-admin/CleanNousResearch_simple")
+ds = load_dataset("Salesforce/xlam-function-calling-60k")
+pattern = r'<tools>(.*?)</tools>'
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=False):
+    try:
+ tools = json.loads(x['tools'])
+ text = tokenizer.apply_chat_template(
+ [
+ {"role": "user", "content": x['query']},
+ ],
+ tools=tools,
+ add_generation_prompt=add_generation_prompt,
+ tokenize=False
+ )
+ return {
+ "prompt": text,
+ # "text": text,
+ # "tools": tools,
+ "gt_generation": x['answers'],
+ }
+    except Exception as e:
+        print(e)
+        # Return a sentinel row so it can be filtered out below
+        # (datasets.map raises if the mapped function returns None)
+        return {"prompt": None, "gt_generation": None}
+
+formatted_dataset = ds.map(format_example).filter(lambda x: x['prompt'] is not None)
+# formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"])
+
+# train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42)
+train_test_split = formatted_dataset['train'].train_test_split(test_size=300, seed=42)
+train_dataset = train_test_split['train']
+eval_dataset = train_test_split['test']
+
+print(f"Training samples: {len(train_dataset)}")
+print(f"Validation samples: {len(eval_dataset)}")
+
+# Load model with quantization
+model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ trust_remote_code=True,
+ # load_in_8bit=True
+)
+
+# Prepare model for LoRA fine-tuning
+model = prepare_model_for_kbit_training(model)
+
+lora_config = LoraConfig(
+ r=16,
+ lora_alpha=32,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # or target_modules="all-linear" to cover every linear layer
+ lora_dropout=0.05,
+ bias="none",
+ # task_type="CAUSAL_LM"
+)
+
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()
+
+# Define reward function for tool calling
+
+def get_all_available_tools_dict(prompt):
+    pattern = r'<tools>(.*?)</tools>'
+ try:
+ data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip()
+ json_objects = []
+ for line in data_string.strip().split('\n'):
+ if line: # Skip empty lines
+ json_objects.append(json.loads(line))
+ tool_dict = {}
+ for obj in json_objects:
+ tool_dict[obj['name']] = obj
+ return tool_dict
+    except Exception as e:
+ print(f"Exception {e}")
+ return {}
+
+def get_selected_tools_list(generated_text):
+ """Returns a dict of tool_name: list of tool_calls from the generated text in order"""
+ pattern =r'(.*?)'
+ tools = re.findall(pattern, generated_text, re.DOTALL)
+ parsed_tools = []
+ for tool in tools:
+ try:
+ # stripped_tool = tool.encode().decode('unicode_escape').strip()
+ # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense
+ tool = json.loads(tool)
+ parsed_tools.append(tool)
+        except Exception:
+ try:
+ tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'"))
+ print("Prsed with ast.literal_eval")
+ parsed_tools.append(tool)
+ except Exception as e:
+ print(f"Exception in parsing generated tools for {tool}")
+ print("Generated text\n", generated_text)
+ return parsed_tools
+
+# def reward_function(prompts, completions, references, tools_list):
+def reward_function(prompts, completions, **kwargs):
+ """
+ Reward function that evaluates the quality of tool calls.
+ Returns a list of reward scores (one per completion).
+ """
+ # print("All kwargs:")
+ # for key in kwargs:
+ # print(f" {key}: {type(kwargs[key])}")
+ rewards = []
+ # completions = kwargs.get("completions", [])
+ references = kwargs.get("gt_generation", [])
+ # references = kwargs.get("completion", [])
+ tools = kwargs.get("tools", [])
+    for prompt, completion, reference in zip(prompts, completions, references):
+        available_tools = get_all_available_tools_dict(prompt)  # dict of tool_name: tool_def
+ reward = 0.0
+ try:
+ gen_tools = get_selected_tools_list(completion)
+        except Exception:
+            # not parsable or correct structure
+            print("Could not parse")
+ reward = -3.0
+ rewards.append(reward)
+ continue
+ gt_tools = json.loads(reference)
+ if gt_tools == gen_tools: # perfect match
+ rewards.append(4.0)
+ continue
+ # lengths are unequal or something else is wrong
+ PENALTY_NAME = 1.0
+ PENALTY_MISSING_PARAM = 0.5
+ PENALTY_WRONG_PARAM = 0.5
+ PENALTY_EXTRA_PARAM = 0.25
+ REWARD_CORRECT_CALL = 1.0
+ reward = 0.0
+        # Account for extra/missing tool calls
+        reward -= PENALTY_NAME * abs(len(gen_tools) - len(gt_tools))
+ # Compare in order
+ for gt_tool, gen_tool in zip(gt_tools, gen_tools):
+ # tool name
+ if 'name' not in gt_tool:
+ print("GT tool missing name:", gt_tool)
+ # reward = 0.0
+ continue
+ if 'name' not in gen_tool:
+ print("Generated tool missing name:", gen_tool)
+ gen_tool['name'] = ''
+ if gt_tool['name'] != gen_tool['name']:
+ reward -= PENALTY_NAME
+ continue
+            # params: xlam answers and Qwen tool calls both put them under the "arguments" key
+            gt_params = gt_tool.get('arguments', {})
+            gen_params = gen_tool.get('arguments', {})
+ # missing or wrong params
+ for key, value in gt_params.items():
+ if key not in gen_params:
+ reward -= PENALTY_MISSING_PARAM
+ elif gen_params[key] != value:
+ reward -= PENALTY_WRONG_PARAM
+ # extra params
+ for key in gen_params:
+ if key not in gt_params:
+ reward -= PENALTY_EXTRA_PARAM
+ # reward for correct call (name + maybe params)
+ # But you might choose: only if no param penalties
+ if gt_params and not any(key not in gen_params or gen_params[key] != gt_params[key] for key in gt_params):
+ # all params matched
+ reward += REWARD_CORRECT_CALL
+ elif not gt_params:
+ # no params expected, name matches
+ reward += REWARD_CORRECT_CALL
+ else:
+ # partial match: maybe +0 or some smaller reward
+ reward += REWARD_CORRECT_CALL * 0.5
+ rewards.append(reward)
+ print("Rewards:", rewards)
+ return rewards
+
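+# Quick smoke test of the reward shaping on a hypothetical toy example (not
+# from the dataset): an exact match between the generated <tool_call> block
+# and the ground-truth call list should earn the maximum reward of 4.0.
+_demo_prompt = '<tools>\n{"name": "get_weather", "parameters": {}}\n</tools>'
+_demo_call = '<tool_call>\n{"name": "get_weather", "arguments": {"location": "London"}}\n</tool_call>'
+_demo_ref = json.dumps([{"name": "get_weather", "arguments": {"location": "London"}}])
+assert reward_function([_demo_prompt], [_demo_call], gt_generation=[_demo_ref]) == [4.0]
+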
+# GRPO Configuration
+grpo_config = GRPOConfig(
+ output_dir="./qwen-tool-calling-1",
+ num_train_epochs=3,
+ per_device_train_batch_size=2,
+ per_device_eval_batch_size=4,
+ # gradient_accumulation_steps=2,
+ gradient_accumulation_steps=1,
+ learning_rate=5e-4,
+ # learning_rate=1e-5,
+    lr_scheduler_type="cosine",
+ warmup_steps=50,
+ logging_steps=10,
+ save_steps=200,
+ eval_strategy="steps",
+ eval_steps=200,
+ bf16=True,
+ max_grad_norm=1.0,
+ report_to=["wandb"],
+
+ # GRPO-specific parameters
+ num_generations=8, # Number of generations per prompt for group comparison
+ temperature=0.9, # Sampling temperature
+    max_completion_length=512,  # max tokens generated per completion
+    beta=0.05,  # KL divergence coefficient to prevent drift from the reference policy
+)
+
+# Initialize GRPO Trainer
+trainer = GRPOTrainer(
+ model=model,
+ args=grpo_config,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ # tokenizer=tokenizer,
+ reward_funcs=[reward_function],
+)
+
+# Train the model
+print("Starting GRPO training...")
+trainer.train()
+
+# Save the fine-tuned model
+trainer.save_model("./qwen-tool-calling-grpo-1")
+tokenizer.save_pretrained("./qwen-tool-calling-grpo-1")
+
+print("Training complete!")
+
+# Example inference
+def test_tool_calling(prompt, tools):
+ model.eval()
+
+ formatted_prompt = tokenizer.apply_chat_template(
+ [
+ {"role": "system", "content": "You are an assistant capable of tool calls"},
+ {"role": "user", "content": prompt}
+ ],
+ tools=tools,
+ add_generation_prompt=True,
+ tokenize=False
+ )
+
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+ with torch.no_grad():
+ outputs = model.generate(
+ **inputs,
+ max_new_tokens=512,
+ temperature=0.7,
+ do_sample=True
+ )
+
+ response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+ print(response)
+
+# Test the model
+test_tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get the current weather for a location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {"type": "string", "description": "City name"}
+ },
+ "required": ["location"]
+ }
+ }
+ }
+]
+
+test_tool_calling("What's the weather in London?", test_tools)
\ No newline at end of file
diff --git a/recipes/tool_calls/new_ds_test.py b/recipes/tool_calls/new_ds_test.py
new file mode 100644
index 0000000..f5bbd85
--- /dev/null
+++ b/recipes/tool_calls/new_ds_test.py
@@ -0,0 +1,109 @@
+import ast
+import re
+import json
+import jsonschema
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+from datasets import Dataset, load_dataset, DatasetDict
+from vllm import LLM, SamplingParams
+
+ds = load_dataset("Salesforce/xlam-function-calling-60k")
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct" # or Qwen2.5-3B, Qwen2.5-1.5B
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=True):
+ try:
+ tools = json.loads(x['tools'])
+ text = tokenizer.apply_chat_template(
+ [
+ {"role": "user", "content": x['query']},
+ ],
+ tools=tools,
+ add_generation_prompt=add_generation_prompt,
+ tokenize=False
+ )
+ return {
+ "text": text,
+ # "tools": tools,
+ "gt_generation": x['answers'],
+ }
+    except Exception as e:
+        print(e)
+        # Return a sentinel row so it can be filtered out below
+        # (datasets.map raises if the mapped function returns None)
+        return {"text": None, "gt_generation": None}
+
+LIMIT = 500
+formatted_dataset = ds.map(format_example).filter(lambda x: x['text'] is not None)
+llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")
+sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=4096)
+prompts = [sample['text'] for sample in formatted_dataset['train']][:LIMIT]
+generations = llm.generate(prompts, sampling_params)
+
+all_generations = []
+for idx, output in enumerate(generations):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ all_generations.append({
+ "prompt": prompt,
+ "generated_text": generated_text, # needs to be parsed
+ "gt_generation": formatted_dataset['train'][idx]['gt_generation'], # this is already a list of tools to be called
+ "tools": formatted_dataset['train'][idx]['tools'],
+ "idx": formatted_dataset['train'][idx]['id'],
+ })
+
+from typing import List, Dict
+
+def get_selected_tools_dict(generated_text):
+ """Returns a dict of tool_name: list of tool_calls from the generated text in order"""
+ pattern =r'(.*?)'
+ tools = re.findall(pattern, generated_text, re.DOTALL)
+ parsed_tools = []
+ for tool in tools:
+ try:
+ # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense
+ tool = json.loads(tool.encode().decode('unicode_escape').strip())
+ parsed_tools.append(tool)
+ except Exception as e:
+ print(f"Exception in parsing generated tools for {tool}")
+ return parsed_tools
+
+def get_correctness_of_tool_calls(gen_text: str, gt_tools: str):
+    gen_tools = get_selected_tools_dict(gen_text)
+    gt_tools = json.loads(gt_tools)
+    return gen_tools == gt_tools
+
+def get_parsed_tool_calls(gen_text: str, gt_tools: str):
+ gen_tools = get_selected_tools_dict(gen_text)
+ gt_tools = json.loads(gt_tools)
+ return gen_tools, gt_tools
+
+correct_count = 0
+correct_len_incorrect = 0
+too_many_count = 0
+too_few_count = 0
+for idx, sample in enumerate(all_generations):
+ # print(idx)
+ # if get_correctness_of_tool_calls(sample['generated_text'], sample['gt_generation']):
+ # correct_count += 1
+ gen_tools, gt_tools = get_parsed_tool_calls(sample['generated_text'], sample['gt_generation'])
+ if len(gen_tools) > len(gt_tools):
+ too_many_count += 1
+ elif len(gen_tools) < len(gt_tools):
+ too_few_count += 1
+ else:
+ if gen_tools == gt_tools:
+ correct_count += 1
+ else:
+ correct_len_incorrect += 1
+
+print(f"Total samples: {len(all_generations)}")
+print(f"Correct tool calls: {correct_count}")
+print(f"Correct length but incorrect tool calls: {correct_len_incorrect}")
+print(f"Too many tool calls: {too_many_count}")
+print(f"Too few tool calls: {too_few_count}")
diff --git a/recipes/tool_calls/tc_grpo.py b/recipes/tool_calls/tc_grpo.py
new file mode 100644
index 0000000..772a240
--- /dev/null
+++ b/recipes/tool_calls/tc_grpo.py
@@ -0,0 +1,287 @@
+## pip install -U bitsandbytes trl transformers datasets peft accelerate
+
+import ast
+import re
+import os
+import json
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import GRPOConfig, GRPOTrainer
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
+os.environ["WANDB_PROJECT"] = "qwen-tool-calling-grpo-t"
+# Load dataset
+# ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn")
+# ds = load_dataset("baseten-admin/CleanNousResearch_simple")
+ds = load_dataset("Salesforce/xlam-function-calling-60k")
+pattern = r'<tools>(.*?)</tools>'
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=False):
+ try:
+ tools = json.loads(x['tools'])
+ text = tokenizer.apply_chat_template(
+ [
+ {"role": "user", "content": x['query']},
+ ],
+ tools=tools,
+ add_generation_prompt=add_generation_prompt,
+ tokenize=False
+ )
+ return {
+ "prompt": text,
+ # "text": text,
+ # "tools": tools,
+ "gt_generation": x['answers'],
+ }
+    except Exception as e:
+        print(e)
+        # Return a sentinel row so it can be filtered out below
+        # (datasets.map raises if the mapped function returns None)
+        return {"prompt": None, "gt_generation": None}
+
+formatted_dataset = ds.map(format_example).filter(lambda x: x['prompt'] is not None)
+# formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"])
+
+# train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42)
+train_test_split = formatted_dataset['train'].train_test_split(test_size=300, seed=42)
+train_dataset = train_test_split['train']
+eval_dataset = train_test_split['test']
+
+print(f"Training samples: {len(train_dataset)}")
+print(f"Validation samples: {len(eval_dataset)}")
+
+# Load model with quantization
+model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ trust_remote_code=True,
+ # load_in_8bit=True
+)
+
+# Prepare model for LoRA fine-tuning
+model = prepare_model_for_kbit_training(model)
+
+lora_config = LoraConfig(
+ r=16, # higher rank = more weights in lora
+ lora_alpha=32,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # or target_modules="all-linear" to cover every linear layer
+ lora_dropout=0.05,
+ bias="none",
+ # task_type="CAUSAL_LM"
+)
+
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()
+
+# Define reward function for tool calling
+
+def get_all_available_tools_dict(prompt):
+    pattern = r'<tools>(.*?)</tools>'
+ try:
+ data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip()
+ json_objects = []
+ for line in data_string.strip().split('\n'):
+ if line: # Skip empty lines
+ json_objects.append(json.loads(line))
+ tool_dict = {}
+ for obj in json_objects:
+ tool_dict[obj['name']] = obj
+ return tool_dict
+    except Exception as e:
+ print(f"Exception {e}")
+ return {}
+
+def get_selected_tools_list(generated_text):
+ """Returns a dict of tool_name: list of tool_calls from the generated text in order"""
+ pattern =r'(.*?)'
+ tools = re.findall(pattern, generated_text, re.DOTALL)
+ parsed_tools = []
+ for tool in tools:
+ try:
+ stripped_tool = tool.encode().decode('unicode_escape').strip()
+ # tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'")) # escape chars and some other nonsense
+ tool = json.loads(stripped_tool)
+ parsed_tools.append(tool)
+        except Exception:
+ try:
+ tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'"))
+ print("Prsed with ast.literal_eval")
+ parsed_tools.append(tool)
+ except Exception as e:
+ print(f"Exception in parsing generated tools for {tool}")
+ print("Generated text\n", generated_text)
+ return parsed_tools
+
+# def reward_function(prompts, completions, references, tools_list):
+def reward_function(prompts, completions, **kwargs):
+ """
+ Reward function that evaluates the quality of tool calls.
+ Returns a list of reward scores (one per completion).
+ """
+ # print("All kwargs:")
+ # for key in kwargs:
+ # print(f" {key}: {type(kwargs[key])}")
+ rewards = []
+ # completions = kwargs.get("completions", [])
+ references = kwargs.get("gt_generation", [])
+ # references = kwargs.get("completion", [])
+ tools = kwargs.get("tools", [])
+    for prompt, completion, reference in zip(prompts, completions, references):
+        available_tools = get_all_available_tools_dict(prompt)  # dict of tool_name: tool_def
+ reward = 0.0
+ try:
+ gen_tools = get_selected_tools_list(completion)
+        except Exception:
+            # not parsable or correct structure
+            print("Could not parse")
+ reward = -3.0
+ rewards.append(reward)
+ continue
+ gt_tools = json.loads(reference)
+ if gt_tools == gen_tools: # perfect match
+ rewards.append(4.0)
+ continue
+ # lengths are unequal or something else is wrong
+ PENALTY_NAME = 1.0
+ PENALTY_MISSING_PARAM = 0.5
+ PENALTY_WRONG_PARAM = 0.5
+ PENALTY_EXTRA_PARAM = 0.25
+ REWARD_CORRECT_CALL = 1.0
+ reward = 0.0
+        # Account for extra/missing tool calls
+        reward -= PENALTY_NAME * abs(len(gen_tools) - len(gt_tools))
+ # Compare in order
+ for gt_tool, gen_tool in zip(gt_tools, gen_tools):
+ # tool name
+            if gt_tool.get('name') != gen_tool.get('name'):  # tolerate calls missing a name
+ reward -= PENALTY_NAME
+ continue
+            # params: xlam answers and Qwen tool calls both put them under the "arguments" key
+            gt_params = gt_tool.get('arguments', {})
+            gen_params = gen_tool.get('arguments', {})
+ # missing or wrong params
+ for key, value in gt_params.items():
+ if key not in gen_params:
+ reward -= PENALTY_MISSING_PARAM
+ elif gen_params[key] != value:
+ reward -= PENALTY_WRONG_PARAM
+ # extra params
+ for key in gen_params:
+ if key not in gt_params:
+ reward -= PENALTY_EXTRA_PARAM
+ # reward for correct call (name + maybe params)
+ # But you might choose: only if no param penalties
+ if gt_params and not any(key not in gen_params or gen_params[key] != gt_params[key] for key in gt_params):
+ # all params matched
+ reward += REWARD_CORRECT_CALL
+ elif not gt_params:
+ # no params expected, name matches
+ reward += REWARD_CORRECT_CALL
+ else:
+ # partial match: maybe +0 or some smaller reward
+ reward += REWARD_CORRECT_CALL * 0.5
+ rewards.append(reward)
+ print("Rewards:", rewards)
+ return rewards
+
+# GRPO Configuration
+grpo_config = GRPOConfig(
+ output_dir="./qwen-tool-calling-1",
+ num_train_epochs=3,
+ per_device_train_batch_size=2,
+ per_device_eval_batch_size=4,
+ gradient_accumulation_steps=2,
+ # learning_rate=5e-4,
+ learning_rate=1e-5,
+ warmup_steps=100,
+ logging_steps=10,
+ save_steps=100,
+ eval_strategy="steps",
+ eval_steps=10,
+ bf16=True,
+ max_grad_norm=1.0,
+ report_to=["wandb"],
+
+ # GRPO-specific parameters
+ num_generations=4, # Number of generations per prompt for group comparison
+ temperature=0.9, # Sampling temperature
+ # max_new_tokens=512,
+ # kl=0.05, # KL divergence coefficient to prevent drift from reference
+)
+
+# Initialize GRPO Trainer
+trainer = GRPOTrainer(
+ model=model,
+ args=grpo_config,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ # tokenizer=tokenizer,
+ reward_funcs=[reward_function],
+)
+
+# Train the model
+print("Starting GRPO training...")
+trainer.train()
+
+# Save the fine-tuned model
+trainer.save_model("./qwen-tool-calling-grpo-1")
+tokenizer.save_pretrained("./qwen-tool-calling-grpo-1")
+
+print("Training complete!")
+
+# Example inference
+def test_tool_calling(prompt, tools):
+ model.eval()
+
+ formatted_prompt = tokenizer.apply_chat_template(
+ [
+ {"role": "system", "content": "You are an assistant capable of tool calls"},
+ {"role": "user", "content": prompt}
+ ],
+ tools=tools,
+ add_generation_prompt=True,
+ tokenize=False
+ )
+
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+ with torch.no_grad():
+ outputs = model.generate(
+ **inputs,
+ max_new_tokens=1024,
+ temperature=0.7,
+ do_sample=True
+ )
+
+ response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+ print(response)
+
+# Test the model
+test_tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get the current weather for a location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {"type": "string", "description": "City name"}
+ },
+ "required": ["location"]
+ }
+ }
+ }
+]
+
+test_tool_calling("What's the weather in London?", test_tools)
\ No newline at end of file
diff --git a/recipes/tool_calls/train_sft_lora.py b/recipes/tool_calls/train_sft_lora.py
new file mode 100644
index 0000000..840bbc6
--- /dev/null
+++ b/recipes/tool_calls/train_sft_lora.py
@@ -0,0 +1,152 @@
+## pip install -U bitsandbytes
+
+import ast
+import re
+import json
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+from trl import SFTTrainer, SFTConfig
+from datasets import Dataset, load_dataset
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
+
+# dataset = load_dataset("younissk/tool-calling-mix")
+ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn")
+pattern = r'<tools>(.*?)</tools>'
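+# The <tools>...</tools> block in each sample's system turn is a Python-literal
+# list of {"type": "function", "function": {...}} entries; format_example pulls
+# out just the inner "function" definitions for the chat template.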
+
+# 4. Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct" # or Qwen2.5-3B, Qwen2.5-1.5B
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x):
+ try:
+ tools = [s['function'] for s in ast.literal_eval(re.findall(pattern, x['conversations'][0]['value'], re.DOTALL)[-1].strip())]
+ text = tokenizer.apply_chat_template(
+ [
+ {"role": "system", "content": "You are an assistant capable of tool calls"},
+ {"role": "user", "content": x['conversations'][-2]['value']},
+ {"role": "assistant", "content": x['conversations'][-1]['value']}
+ ],
+ tools=tools,
+        add_generation_prompt=False,  # the conversation already ends with the assistant turn for SFT
+ tokenize=False
+ )
+ return {"text": text}
+    except Exception as e:
+        # Return a sentinel row so it can be filtered out below
+        # (datasets.map raises if the mapped function returns None)
+        return {"text": None}
+
+formatted_dataset = ds.map(format_example).filter(lambda x: x['text'] is not None)
+formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"])
+
+train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42)
+train_dataset = train_test_split['train']
+eval_dataset = train_test_split['test']
+
+print(f"Training samples: {len(train_dataset)}")
+print(f"Validation samples: {len(eval_dataset)}")
+
+# Convert to dataset
+# formatted_data = [format_training_example(ex) for ex in training_examples]
+# dataset = Dataset.from_list(formatted_data)
+# dataset = formatted_dataset
+# print(" Dataset --- ", dataset)
+
+
+model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ trust_remote_code=True,
+ load_in_8bit=True # Use 8-bit quantization for memory efficiency
+)
+
+# 5. Prepare model for LoRA fine-tuning
+model = prepare_model_for_kbit_training(model)
+
+"""
+text = tokenizer.apply_chat_template(
+ messages,
+ tools=tools,
+ add_generation_prompt=True,
+ tokenize=False
+)
+"""
+
+lora_config = LoraConfig(
+ r=16, # LoRA rank
+ lora_alpha=32,
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+ lora_dropout=0.05,
+ bias="none",
+ task_type="CAUSAL_LM"
+)
+
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()
+
+sft_config = SFTConfig(
+ output_dir="./qwen-tool-calling3",
+ num_train_epochs=3,
+ per_device_train_batch_size=2,
+ gradient_accumulation_steps=8,
+ learning_rate=2e-4,
+ warmup_steps=100,
+ logging_steps=10,
+ save_steps=100,
+ eval_strategy="steps", # Enable evaluation
+    eval_steps=30,  # Evaluate every 30 steps
+ optim="paged_adamw_8bit",
+ fp16=False,
+ bf16=True,
+ max_grad_norm=0.3,
+ # SFT-specific parameters
+ dataset_text_field='text',
+ max_length=2048,
+)
+trainer = SFTTrainer(
+ model=model,
+ args=sft_config,
+    train_dataset=train_dataset,  # use the 90/10 split created above
+ eval_dataset=eval_dataset, # Add validation dataset
+)
+
+# 8. Train the model
+print("Starting training...")
+trainer.train()
+
+# 9. Save the fine-tuned model
+trainer.save_model("./qwen-tool-calling-final")
+tokenizer.save_pretrained("./qwen-tool-calling-final")
+
+print("Training complete!")
+
+# 10. Example inference
+def test_tool_calling(prompt):
+ model.eval()
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+ with torch.no_grad():
+ outputs = model.generate(
+ **inputs,
+ max_new_tokens=512,
+ temperature=0.7,
+ do_sample=True
+ )
+
+ response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+ print(response)
+
+# Test the model (define a minimal hypothetical tool here, since no `tools`
+# variable exists at module scope in this script)
+tools = [{"name": "get_weather",
+          "description": "Get the current weather for a location",
+          "parameters": {"type": "object",
+                         "properties": {"location": {"type": "string"}},
+                         "required": ["location"]}}]
+test_prompt = f"""<|im_start|>system
+You are a helpful assistant with access to tools: {json.dumps(tools)}<|im_end|>
+<|im_start|>user
+What's the weather in London?<|im_end|>
+<|im_start|>assistant
+"""
+
+test_tool_calling(test_prompt)
\ No newline at end of file