diff --git a/recipes/tool_calls/dataset_clean.py b/recipes/tool_calls/dataset_clean.py
new file mode 100644
index 0000000..7a37eae
--- /dev/null
+++ b/recipes/tool_calls/dataset_clean.py
@@ -0,0 +1,127 @@
+import ast
+import re
+import json
+import jsonschema
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+from datasets import Dataset, load_dataset, DatasetDict
+from vllm import LLM, SamplingParams
+
+ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn")
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct"  # or Qwen2.5-3B, Qwen2.5-1.5B
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=True):
+    try:
+        pattern = r'<tools>(.*?)</tools>'
+        tools = [s['function'] for s in ast.literal_eval(re.findall(pattern, x['conversations'][0]['value'], re.DOTALL)[-1].strip())]
+        text = tokenizer.apply_chat_template(
+            [
+                # {"role": "system", "content": "You are an assistant capable of tool calls"},
+                {"role": "user", "content": x['conversations'][-2]['value']},
+                # {"role": "assistant", "content": x['conversations'][-1]['value']}
+            ],
+            tools=tools,
+            add_generation_prompt=add_generation_prompt,
+            tokenize=False
+        )
+        return {
+            "text": text,
+            # "tools": tools,
+            "gt_generation": x['conversations'][-1]['value'],
+        }
+    except Exception as e:
+        print(e)
+        # Return a sentinel value to filter out later (datasets.map requires a dict)
+        return {"text": None, "gt_generation": None}
+
+formatted_dataset = ds.map(format_example).filter(lambda x: x['text'] is not None)
+
+llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")
+sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=4096)
+prompts = [sample['text'] for sample in formatted_dataset['train']]
+generations = llm.generate(prompts, sampling_params)
+
+all_generations = []
+for idx, output in enumerate(generations):
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    all_generations.append({
+        "prompt": prompt,
+        "generated_text": generated_text,
+        "gt_generation": formatted_dataset['train'][idx]['gt_generation'],
+        "conversations": formatted_dataset['train'][idx]['conversations'],
+        "idx": formatted_dataset['train'][idx]['id'],
+    })
+
+## Sometimes the GT generation is garbage; make sure it is approximately correct by validating it against the tools' JSON schemas.
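+# For reference, a Hermes/Qwen-style tool call is a JSON object wrapped in <tool_call> tags,
+# e.g. <tool_call>{"name": "get_weather", "arguments": {"location": "London"}}</tool_call>
+# (illustrative, not taken from the dataset). jsonschema.validate() accepts a call whose
+# arguments satisfy the tool's parameter schema and raises jsonschema.ValidationError otherwise:
+#   schema = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
+#   jsonschema.validate(instance={"location": "London"}, schema=schema)  # passes
+#   jsonschema.validate(instance={}, schema=schema)                      # raises ValidationError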
+def get_all_available_tools_dict(prompt):
+    pattern = r'<tools>(.*?)</tools>'
+    try:
+        data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip()
+        json_objects = []
+        for line in data_string.strip().split('\n'):
+            if line:  # Skip empty lines
+                json_objects.append(json.loads(line))
+        tool_dict = {}
+        for obj in json_objects:
+            tool_dict[obj['name']] = obj
+        return tool_dict
+    except Exception as e:
+        print(f"Exception {e}")
+        return {}
+
+def get_selected_tools_dict(generated_text):
+    pattern = r'<tool_call>(.*?)</tool_call>'
+    tools = re.findall(pattern, generated_text, re.DOTALL)
+    tool_dict = {}
+    for tool in tools:
+        try:
+            tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'"))  # handle escape characters and other noise
+            if 'name' not in tool:
+                print("No name in tool; skipping")
+                continue
+            tool_dict[tool['name']] = tool_dict.get(tool['name'], []) + [tool]
+        except Exception as e:
+            print(f"Exception in get_selected_tools_dict for {tool}")
+    return tool_dict
+
+def tools_validity(avail_tools, tools):
+    if not tools:
+        return False
+    for tool_name, tool_calls in tools.items():
+        for tool in tool_calls:
+            if tool["name"] not in avail_tools:
+                print("Generated tool is not among the available tools (hallucinated call)")
+                return False
+            tool_name = tool["name"]
+            schema = avail_tools[tool_name].get("parameters", {})  # TODO - might depend on dataset
+            gen = tool["arguments"]
+            try:
+                jsonschema.validate(instance=gen, schema=schema)
+            except jsonschema.ValidationError as e:
+                print(f"Schema validation error: {e.message}")
+                return False
+    # all tools validated
+    return True
+
+
+success = 0
+success_idxs = []
+for idx, gen in enumerate(all_generations):
+    avail_tools = get_all_available_tools_dict(gen['prompt'])
+    # gen_tools = get_selected_tools_dict(gen['generated_text'])
+    ref_tools = get_selected_tools_dict(gen['gt_generation'])
+    validity = tools_validity(avail_tools, ref_tools)
+    if validity:
+        success += 1
+        success_idxs.append(idx)
+
+filtered_generations = [all_generations[idx] for idx in success_idxs]
+
+new_ds = Dataset.from_list(filtered_generations)
+dataset_dict = DatasetDict({"train": new_ds})
+dataset_dict.push_to_hub("baseten-admin/CleanNousResearch_simple")
\ No newline at end of file
diff --git a/recipes/tool_calls/filter_ds.py b/recipes/tool_calls/filter_ds.py
new file mode 100644
index 0000000..bb5abeb
--- /dev/null
+++ b/recipes/tool_calls/filter_ds.py
@@ -0,0 +1,133 @@
+import ast
+import re
+import json
+import jsonschema
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+from datasets import Dataset, load_dataset, DatasetDict
+from vllm import LLM, SamplingParams
+
+ds = load_dataset("Salesforce/xlam-function-calling-60k")
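+# Each xlam row stores "query", "tools" and "answers" fields; "tools" is a JSON list of
+# available tool signatures and "answers" is a JSON list of ground-truth calls, roughly
+# (schematic, not an actual row):
+#   {"id": 0,
+#    "query": "What's the weather in London?",
+#    "tools": '[{"name": "get_weather", "parameters": {...}}]',
+#    "answers": '[{"name": "get_weather", "arguments": {"location": "London"}}]'}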
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct"  # or Qwen2.5-3B, Qwen2.5-1.5B
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=True):
+    try:
+        tools = json.loads(x['tools'])
+        text = tokenizer.apply_chat_template(
+            [
+                {"role": "user", "content": x['query']},
+            ],
+            tools=tools,
+            add_generation_prompt=add_generation_prompt,
+            tokenize=False
+        )
+        return {
+            "text": text,
+            # "tools": tools,
+            "gt_generation": x['answers'],
+        }
+    except Exception as e:
+        print(e)
+        # Return a sentinel value to filter out later (datasets.map requires a dict)
+        return {"text": None, "gt_generation": None}
+
+LIMIT = 60000
+formatted_dataset = ds.map(format_example).filter(lambda x: x['text'] is not None)
+llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")
+sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=1024)
+prompts = [sample['text'] for sample in formatted_dataset['train']][:LIMIT]
+generations = llm.generate(prompts, sampling_params)
+
+all_generations = []
+for idx, output in enumerate(generations):
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    all_generations.append({
+        "prompt": prompt,
+        "generated_text": generated_text,  # needs to be parsed
+        "gt_generation": formatted_dataset['train'][idx]['gt_generation'],  # already a list of tools to be called
+        "tools": formatted_dataset['train'][idx]['tools'],
+        "query": formatted_dataset['train'][idx]['query'],
+        "answers": formatted_dataset['train'][idx]['answers'],
+        "idx": formatted_dataset['train'][idx]['id'],
+    })
+
+def get_selected_tools_dict(generated_text):
+    """Returns the list of tool calls parsed from the generated text, in order."""
+    pattern = r'<tool_call>(.*?)</tool_call>'
+    tools = re.findall(pattern, generated_text, re.DOTALL)
+    parsed_tools = []
+    for tool in tools:
+        try:
+            tool = json.loads(tool)
+            parsed_tools.append(tool)
+        except Exception:
+            try:
+                tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'"))  # handle escape characters and other noise
+                parsed_tools.append(tool)
+            except Exception as e:
+                print(f"Exception in parsing generated tools for {tool}")
+                print("Generated text\n", generated_text)
+    return parsed_tools
+
+def get_correctness_of_tool_calls(gen_text: str, gt_tools: str):
+    gen_tools = get_selected_tools_dict(gen_text)
+    gt_tools = json.loads(gt_tools)
+    return gen_tools == gt_tools
+
+def get_parsed_tool_calls(gen_text: str, gt_tools: str):
+    gen_tools = get_selected_tools_dict(gen_text)
+    gt_tools = json.loads(gt_tools)
+    return gen_tools, gt_tools
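+# Note that correctness below is exact-match equality on the parsed structures, so it is
+# order- and type-sensitive: [{"name": "f", "arguments": {"x": 1}}] does not equal
+# [{"name": "f", "arguments": {"x": "1"}}] (illustrative values).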
{len(all_generations)}") +# print(f"Correct tool calls: {correct_count}") +# print(f"Correct length but incorrect tool calls: {correct_len_incorrect}") +# print(f"Too many tool calls: {too_many_count}") +# print(f"Too few tool calls: {too_few_count}") + +incorrect_set = set(incorrect_idxs) +correct_idxs = [i for i in range(len(all_generations)) if i not in incorrect_set] + +import random +sampled_correct_idxs = random.sample(correct_idxs, 7000) +filtered_generations = [all_generations[i] for i in sampled_correct_idxs] + [all_generations[i] for i in incorrect_idxs] +random.shuffle(filtered_generations) + + +new_ds = Dataset.from_list(filtered_generations) +dataset_dict = DatasetDict({"train": new_ds}) +dataset_dict.push_to_hub("baseten-admin/xlam-function-calling-sampled") \ No newline at end of file diff --git a/recipes/tool_calls/grpo_pt2.py b/recipes/tool_calls/grpo_pt2.py new file mode 100644 index 0000000..acbb1e4 --- /dev/null +++ b/recipes/tool_calls/grpo_pt2.py @@ -0,0 +1,296 @@ +## pip install -U bitsandbytes trl transformers datasets peft accelerate + +import ast +import re +import os +import json +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from trl import GRPOConfig, GRPOTrainer +from datasets import load_dataset +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + +os.environ["WANDB_PROJECT"] = "qwen-tool-calling-grpo-t" +# Load dataset +# ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn") +# ds = load_dataset("baseten-admin/CleanNousResearch_simple") +ds = load_dataset("Salesforce/xlam-function-calling-60k") +pattern = r'(.*?)' + +# Load model and tokenizer +model_name = "Qwen/Qwen2.5-7B-Instruct" +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +deflearn(x, add_generation_prompt=False): + try: + pattern = r'(.*?)' + tools = json.loads(x['tools']) + text = tokenizer.apply_chat_template( + [ + {"role": "user", "content": x['query']}, + ], + tools=tools, + add_generation_prompt=add_generation_prompt, + tokenize=False + ) + return { + "prompt": text, + # "text": text, + # "tools": tools, + "gt_generation": x['answers'], + } + except Exception as e: + print(e) + # Return None or a sentinel value to filter out later + return None + +formatted_dataset = ds.map(format_example).filter(lambda x: x is not None) +# formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"]) + +# train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42) +train_test_split = formatted_dataset['train'].train_test_split(test_size=300, seed=42) +train_dataset = train_test_split['train'] +eval_dataset = train_test_split['test'] + +print(f"Training samples: {len(train_dataset)}") +print(f"Validation samples: {len(eval_dataset)}") + +# Load model with quantization +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + # load_in_8bit=True +) + +# Prepare model for LoRA fine-tuning +model = prepare_model_for_kbit_training(model) + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "all_linear"], + lora_dropout=0.05, + bias="none", + # task_type="CAUSAL_LM" +) + +model = get_peft_model(model, lora_config) +model.print_trainable_parameters() + +# Define reward function for tool calling + +def get_all_available_tools_dict(prompt): + 
+
+def get_all_available_tools_dict(prompt):
+    pattern = r'<tools>(.*?)</tools>'
+    try:
+        data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip()
+        json_objects = []
+        for line in data_string.strip().split('\n'):
+            if line:  # Skip empty lines
+                json_objects.append(json.loads(line))
+        tool_dict = {}
+        for obj in json_objects:
+            tool_dict[obj['name']] = obj
+        return tool_dict
+    except Exception as e:
+        print(f"Exception {e}")
+        return {}
+
+def get_selected_tools_list(generated_text):
+    """Returns the list of tool calls parsed from the generated text, in order."""
+    pattern = r'<tool_call>(.*?)</tool_call>'
+    tools = re.findall(pattern, generated_text, re.DOTALL)
+    parsed_tools = []
+    for tool in tools:
+        try:
+            tool = json.loads(tool)
+            parsed_tools.append(tool)
+        except Exception:
+            try:
+                tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'"))  # handle escape characters and other noise
+                print("Parsed with ast.literal_eval")
+                parsed_tools.append(tool)
+            except Exception as e:
+                print(f"Exception in parsing generated tools for {tool}")
+                print("Generated text\n", generated_text)
+    return parsed_tools
+
+def reward_function(prompts, completions, **kwargs):
+    """
+    Reward function that evaluates the quality of tool calls.
+    Returns a list of reward scores (one per completion).
+    """
+    rewards = []
+    references = kwargs.get("gt_generation", [])
+    for prompt, completion, reference in zip(prompts, completions, references):
+        available_tools = get_all_available_tools_dict(prompt)  # dict of tool_name: tool_def
+        reward = 0.0
+        try:
+            gen_tools = get_selected_tools_list(completion)
+        except Exception:
+            # not parsable or not the correct structure
+            print("Could not parse")
+            reward = -3.0
+            rewards.append(reward)
+            continue
+        gt_tools = json.loads(reference)
+        if gt_tools == gen_tools:  # perfect match
+            rewards.append(4.0)
+            continue
+        # lengths are unequal or something else is wrong
+        PENALTY_NAME = 1.0
+        PENALTY_MISSING_PARAM = 0.5
+        PENALTY_WRONG_PARAM = 0.5
+        PENALTY_EXTRA_PARAM = 0.25
+        REWARD_CORRECT_CALL = 1.0
+        reward = 0.0
+        # Account for extra/missing tool calls
+        if len(gen_tools) < len(gt_tools):
+            reward -= PENALTY_NAME * (len(gt_tools) - len(gen_tools))
+        elif len(gen_tools) > len(gt_tools):
+            reward -= PENALTY_NAME * (len(gen_tools) - len(gt_tools))
+        # Compare in order
+        for gt_tool, gen_tool in zip(gt_tools, gen_tools):
+            # tool name
+            if 'name' not in gt_tool:
+                print("GT tool missing name:", gt_tool)
+                continue
+            if 'name' not in gen_tool:
+                print("Generated tool missing name:", gen_tool)
+                gen_tool['name'] = ''
+            if gt_tool['name'] != gen_tool['name']:
+                reward -= PENALTY_NAME
+                continue
+            # params (xlam answers and Qwen-style tool calls keep call arguments under "arguments")
+            gt_params = gt_tool.get('arguments', {})
+            gen_params = gen_tool.get('arguments', {})
+            # missing or wrong params
+            for key, value in gt_params.items():
+                if key not in gen_params:
+                    reward -= PENALTY_MISSING_PARAM
+                elif gen_params[key] != value:
+                    reward -= PENALTY_WRONG_PARAM
+            # extra params
+            for key in gen_params:
+                if key not in gt_params:
+                    reward -= PENALTY_EXTRA_PARAM
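+            # Worked example (hypothetical values): name matches, gt_params = {"a": 1, "b": 2},
+            # gen_params = {"a": 1, "b": 3} -> one wrong param (-0.5), then the partial-match
+            # branch below adds +0.5, so this call nets 0.0.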
+            # reward for correct call (name + params)
+            # you might choose to grant this only if there were no param penalties
+            if gt_params and not any(key not in gen_params or gen_params[key] != gt_params[key] for key in gt_params):
+                # all params matched
+                reward += REWARD_CORRECT_CALL
+            elif not gt_params:
+                # no params expected, name matches
+                reward += REWARD_CORRECT_CALL
+            else:
+                # partial match: smaller reward
+                reward += REWARD_CORRECT_CALL * 0.5
+        rewards.append(reward)
+    print("Rewards:", rewards)
+    return rewards
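+# Quick sanity check of the reward shaping on toy strings (not dataset samples):
+# an exact match between the generated call and the ground truth earns the full 4.0.
+_toy_prompt = '<tools>\n{"name": "f", "parameters": {}}\n</tools>'
+_toy_completion = '<tool_call>{"name": "f", "arguments": {}}</tool_call>'
+assert reward_function([_toy_prompt], [_toy_completion],
+                       gt_generation=['[{"name": "f", "arguments": {}}]']) == [4.0]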
+
+# GRPO Configuration
+grpo_config = GRPOConfig(
+    output_dir="./qwen-tool-calling-1",
+    num_train_epochs=3,
+    per_device_train_batch_size=2,
+    per_device_eval_batch_size=4,
+    gradient_accumulation_steps=1,
+    learning_rate=5e-4,
+    # learning_rate=1e-5,
+    lr_scheduler_type="cosine",
+    warmup_steps=50,
+    logging_steps=10,
+    save_steps=200,
+    eval_strategy="steps",
+    eval_steps=200,
+    bf16=True,
+    max_grad_norm=1.0,
+    report_to=["wandb"],
+
+    # GRPO-specific parameters
+    num_generations=8,  # Number of generations per prompt for group comparison
+    temperature=0.9,  # Sampling temperature
+    max_completion_length=512,
+    beta=0.05,  # KL divergence coefficient to prevent drift from the reference model
+)
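+# GRPO compares the num_generations completions sampled for the same prompt and normalizes
+# their rewards within that group to form advantages (roughly (r - mean(r)) / std(r)), so the
+# absolute scale of reward_function matters less than the ranking it induces within a group.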
+
+# Initialize GRPO Trainer
+trainer = GRPOTrainer(
+    model=model,
+    args=grpo_config,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    # processing_class=tokenizer,
+    reward_funcs=[reward_function],
+)
+
+# Train the model
+print("Starting GRPO training...")
+trainer.train()
+
+# Save the fine-tuned model
+trainer.save_model("./qwen-tool-calling-grpo-1")
+tokenizer.save_pretrained("./qwen-tool-calling-grpo-1")
+
+print("Training complete!")
+
+# Example inference
+def test_tool_calling(prompt, tools):
+    model.eval()
+
+    formatted_prompt = tokenizer.apply_chat_template(
+        [
+            {"role": "system", "content": "You are an assistant capable of tool calls"},
+            {"role": "user", "content": prompt}
+        ],
+        tools=tools,
+        add_generation_prompt=True,
+        tokenize=False
+    )
+
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            do_sample=True
+        )
+
+    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+    print(response)
+
+# Test the model
+test_tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get the current weather for a location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {"type": "string", "description": "City name"}
+                },
+                "required": ["location"]
+            }
+        }
+    }
+]
+
+test_tool_calling("What's the weather in London?", test_tools)
\ No newline at end of file
diff --git a/recipes/tool_calls/new_ds_test.py b/recipes/tool_calls/new_ds_test.py
new file mode 100644
index 0000000..f5bbd85
--- /dev/null
+++ b/recipes/tool_calls/new_ds_test.py
@@ -0,0 +1,109 @@
+import ast
+import re
+import json
+import jsonschema
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+from datasets import Dataset, load_dataset, DatasetDict
+from vllm import LLM, SamplingParams
+
+ds = load_dataset("Salesforce/xlam-function-calling-60k")
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct"  # or Qwen2.5-3B, Qwen2.5-1.5B
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=True):
+    try:
+        tools = json.loads(x['tools'])
+        text = tokenizer.apply_chat_template(
+            [
+                {"role": "user", "content": x['query']},
+            ],
+            tools=tools,
+            add_generation_prompt=add_generation_prompt,
+            tokenize=False
+        )
+        return {
+            "text": text,
+            # "tools": tools,
+            "gt_generation": x['answers'],
+        }
+    except Exception as e:
+        print(e)
+        # Return a sentinel value to filter out later (datasets.map requires a dict)
+        return {"text": None, "gt_generation": None}
+
+LIMIT = 500
+formatted_dataset = ds.map(format_example).filter(lambda x: x['text'] is not None)
+llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")
+sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=4096)
+prompts = [sample['text'] for sample in formatted_dataset['train']][:LIMIT]
+generations = llm.generate(prompts, sampling_params)
+
+all_generations = []
+for idx, output in enumerate(generations):
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    all_generations.append({
+        "prompt": prompt,
+        "generated_text": generated_text,  # needs to be parsed
+        "gt_generation": formatted_dataset['train'][idx]['gt_generation'],  # already a list of tools to be called
+        "tools": formatted_dataset['train'][idx]['tools'],
+        "idx": formatted_dataset['train'][idx]['id'],
+    })
+
+def get_selected_tools_dict(generated_text):
+    """Returns the list of tool calls parsed from the generated text, in order."""
+    pattern = r'<tool_call>(.*?)</tool_call>'
+    tools = re.findall(pattern, generated_text, re.DOTALL)
+    parsed_tools = []
+    for tool in tools:
+        try:
+            tool = json.loads(tool.encode().decode('unicode_escape').strip())
+            parsed_tools.append(tool)
+        except Exception as e:
+            print(f"Exception in parsing generated tools for {tool}")
+    return parsed_tools
+
+def get_correctness_of_tool_calls(gen_text: str, gt_tools: str):
+    gen_tools = get_selected_tools_dict(gen_text)
+    gt_tools = json.loads(gt_tools)
+    return gen_tools == gt_tools
+
+def get_parsed_tool_calls(gen_text: str, gt_tools: str):
+    gen_tools = get_selected_tools_dict(gen_text)
+    gt_tools = json.loads(gt_tools)
+    return gen_tools, gt_tools
+
+correct_count = 0
+correct_len_incorrect = 0
+too_many_count = 0
+too_few_count = 0
+for idx, sample in enumerate(all_generations):
+    gen_tools, gt_tools = get_parsed_tool_calls(sample['generated_text'], sample['gt_generation'])
+    if len(gen_tools) > len(gt_tools):
+        too_many_count += 1
+    elif len(gen_tools) < len(gt_tools):
+        too_few_count += 1
+    else:
+        if gen_tools == gt_tools:
+            correct_count += 1
+        else:
+            correct_len_incorrect += 1
+
+print(f"Total samples: {len(all_generations)}")
+print(f"Correct tool calls: {correct_count}")
+print(f"Correct length but incorrect tool calls: {correct_len_incorrect}")
+print(f"Too many tool calls: {too_many_count}")
+print(f"Too few tool calls: {too_few_count}")
diff --git a/recipes/tool_calls/tc_grpo.py b/recipes/tool_calls/tc_grpo.py
new file mode 100644
index 0000000..772a240
--- /dev/null
+++ b/recipes/tool_calls/tc_grpo.py
@@ -0,0 +1,287 @@
+## pip install -U bitsandbytes trl transformers datasets peft accelerate
+
+import ast
+import re
+import os
+import json
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import GRPOConfig, GRPOTrainer
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
+os.environ["WANDB_PROJECT"] = "qwen-tool-calling-grpo-t"
+# Load dataset
+# ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn")
+# ds = load_dataset("baseten-admin/CleanNousResearch_simple")
+ds = load_dataset("Salesforce/xlam-function-calling-60k")
+
+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x, add_generation_prompt=False):
+    try:
+        tools = json.loads(x['tools'])
+        text = tokenizer.apply_chat_template(
+            [
+                {"role": "user", "content": x['query']},
+            ],
+            tools=tools,
+            add_generation_prompt=add_generation_prompt,
+            tokenize=False
+        )
+        return {
+            "prompt": text,
+            # "tools": tools,
+            "gt_generation": x['answers'],
+        }
+    except Exception as e:
+        print(e)
+        # Return a sentinel value to filter out later (datasets.map requires a dict)
+        return {"prompt": None, "gt_generation": None}
+
+formatted_dataset = ds.map(format_example).filter(lambda x: x['prompt'] is not None)
+
+# train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42)
+train_test_split = formatted_dataset['train'].train_test_split(test_size=300, seed=42)
+train_dataset = train_test_split['train']
+eval_dataset = train_test_split['test']
+
+print(f"Training samples: {len(train_dataset)}")
+print(f"Validation samples: {len(eval_dataset)}")
+
+# Load model with quantization
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    trust_remote_code=True,
+    # load_in_8bit=True
+)
+
+# Prepare model for LoRA fine-tuning
+model = prepare_model_for_kbit_training(model)
+
+lora_config = LoraConfig(
+    r=16,  # higher rank = more trainable weights in the LoRA adapters
+    lora_alpha=32,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # or target_modules="all-linear" to wrap every linear layer
+    lora_dropout=0.05,
+    bias="none",
+    # task_type="CAUSAL_LM"
+)
+
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()
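+# Rough parameter accounting for LoRA (general formula, not Qwen-specific): wrapping a
+# linear layer W of shape (m, n) at rank r adds two matrices A (r, n) and B (m, r),
+# i.e. r * (m + n) trainable parameters per wrapped projection, which is why r=16 keeps
+# the trainable fraction printed above small.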
+
+# Define reward function for tool calling
+
+def get_all_available_tools_dict(prompt):
+    pattern = r'<tools>(.*?)</tools>'
+    try:
+        data_string = re.findall(pattern, prompt, re.DOTALL)[-1].strip()
+        json_objects = []
+        for line in data_string.strip().split('\n'):
+            if line:  # Skip empty lines
+                json_objects.append(json.loads(line))
+        tool_dict = {}
+        for obj in json_objects:
+            tool_dict[obj['name']] = obj
+        return tool_dict
+    except Exception as e:
+        print(f"Exception {e}")
+        return {}
+
+def get_selected_tools_list(generated_text):
+    """Returns the list of tool calls parsed from the generated text, in order."""
+    pattern = r'<tool_call>(.*?)</tool_call>'
+    tools = re.findall(pattern, generated_text, re.DOTALL)
+    parsed_tools = []
+    for tool in tools:
+        try:
+            stripped_tool = tool.encode().decode('unicode_escape').strip()
+            tool = json.loads(stripped_tool)
+            parsed_tools.append(tool)
+        except Exception:
+            try:
+                tool = ast.literal_eval(tool.encode().decode('unicode_escape').strip().replace('"', "'"))
+                print("Parsed with ast.literal_eval")
+                parsed_tools.append(tool)
+            except Exception as e:
+                print(f"Exception in parsing generated tools for {tool}")
+                print("Generated text\n", generated_text)
+    return parsed_tools
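+# The ast.literal_eval fallback above exists because models sometimes emit single-quoted,
+# Python-style dicts that are not valid JSON: json.loads("{'name': 'f'}") raises, while
+# ast.literal_eval("{'name': 'f'}") parses it (illustrative strings).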
+
+def reward_function(prompts, completions, **kwargs):
+    """
+    Reward function that evaluates the quality of tool calls.
+    Returns a list of reward scores (one per completion).
+    """
+    rewards = []
+    references = kwargs.get("gt_generation", [])
+    for prompt, completion, reference in zip(prompts, completions, references):
+        available_tools = get_all_available_tools_dict(prompt)  # dict of tool_name: tool_def
+        reward = 0.0
+        try:
+            gen_tools = get_selected_tools_list(completion)
+        except Exception:
+            # not parsable or not the correct structure
+            print("Could not parse")
+            reward = -3.0
+            rewards.append(reward)
+            continue
+        gt_tools = json.loads(reference)
+        if gt_tools == gen_tools:  # perfect match
+            rewards.append(4.0)
+            continue
+        # lengths are unequal or something else is wrong
+        PENALTY_NAME = 1.0
+        PENALTY_MISSING_PARAM = 0.5
+        PENALTY_WRONG_PARAM = 0.5
+        PENALTY_EXTRA_PARAM = 0.25
+        REWARD_CORRECT_CALL = 1.0
+        reward = 0.0
+        # Account for extra/missing tool calls
+        if len(gen_tools) < len(gt_tools):
+            reward -= PENALTY_NAME * (len(gt_tools) - len(gen_tools))
+        elif len(gen_tools) > len(gt_tools):
+            reward -= PENALTY_NAME * (len(gen_tools) - len(gt_tools))
+        # Compare in order
+        for gt_tool, gen_tool in zip(gt_tools, gen_tools):
+            # tool name
+            if gt_tool['name'] != gen_tool.get('name'):
+                reward -= PENALTY_NAME
+                continue
+            # params (xlam answers and Qwen-style tool calls keep call arguments under "arguments")
+            gt_params = gt_tool.get('arguments', {})
+            gen_params = gen_tool.get('arguments', {})
+            # missing or wrong params
+            for key, value in gt_params.items():
+                if key not in gen_params:
+                    reward -= PENALTY_MISSING_PARAM
+                elif gen_params[key] != value:
+                    reward -= PENALTY_WRONG_PARAM
+            # extra params
+            for key in gen_params:
+                if key not in gt_params:
+                    reward -= PENALTY_EXTRA_PARAM
+            # reward for correct call (name + params)
+            # you might choose to grant this only if there were no param penalties
+            if gt_params and not any(key not in gen_params or gen_params[key] != gt_params[key] for key in gt_params):
+                # all params matched
+                reward += REWARD_CORRECT_CALL
+            elif not gt_params:
+                # no params expected, name matches
+                reward += REWARD_CORRECT_CALL
+            else:
+                # partial match: smaller reward
+                reward += REWARD_CORRECT_CALL * 0.5
+        rewards.append(reward)
+    print("Rewards:", rewards)
+    return rewards
+
+# GRPO Configuration
+grpo_config = GRPOConfig(
+    output_dir="./qwen-tool-calling-1",
+    num_train_epochs=3,
+    per_device_train_batch_size=2,
+    per_device_eval_batch_size=4,
+    gradient_accumulation_steps=2,
+    # learning_rate=5e-4,
+    learning_rate=1e-5,
+    warmup_steps=100,
+    logging_steps=10,
+    save_steps=100,
+    eval_strategy="steps",
+    eval_steps=10,
+    bf16=True,
+    max_grad_norm=1.0,
+    report_to=["wandb"],
+
+    # GRPO-specific parameters
+    num_generations=4,  # Number of generations per prompt for group comparison
+    temperature=0.9,  # Sampling temperature
+    # max_completion_length=512,
+    # beta=0.05,  # KL divergence coefficient to prevent drift from the reference model
+)
+
+# Initialize GRPO Trainer
+trainer = GRPOTrainer(
+    model=model,
+    args=grpo_config,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    # processing_class=tokenizer,
+    reward_funcs=[reward_function],
+)
+
+# Train the model
+print("Starting GRPO training...")
+trainer.train()
+
+# Save the fine-tuned model
+trainer.save_model("./qwen-tool-calling-grpo-1")
+tokenizer.save_pretrained("./qwen-tool-calling-grpo-1")
+
+print("Training complete!")
+
+# Example inference
+def test_tool_calling(prompt, tools):
+    model.eval()
+
+    formatted_prompt = tokenizer.apply_chat_template(
+        [
+            {"role": "system", "content": "You are an assistant capable of tool calls"},
+            {"role": "user", "content": prompt}
+        ],
+        tools=tools,
+        add_generation_prompt=True,
+        tokenize=False
+    )
+
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=1024,
+            temperature=0.7,
+            do_sample=True
+        )
+
+    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+    print(response)
+
+# Test the model
+test_tools = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get the current weather for a location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {"type": "string", "description": "City name"}
+                },
+                "required": ["location"]
+            }
+        }
+    }
+]
+
+test_tool_calling("What's the weather in London?", test_tools)
\ No newline at end of file
diff --git a/recipes/tool_calls/train_sft_lora.py b/recipes/tool_calls/train_sft_lora.py
new file mode 100644
index 0000000..840bbc6
--- /dev/null
+++ b/recipes/tool_calls/train_sft_lora.py
@@ -0,0 +1,152 @@
+## pip install -U bitsandbytes
+
+import ast
+import re
+import json
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
+from trl import SFTTrainer, SFTConfig
+from datasets import Dataset, load_dataset
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+
+
+# dataset = load_dataset("younissk/tool-calling-mix")
+ds = load_dataset("NousResearch/hermes-function-calling-v1", "func_calling_singleturn")
+pattern = r'<tools>(.*?)</tools>'
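+
+# Note: by default SFTTrainer computes the loss over the entire "text" sequence, prompt
+# included; masking the prompt so that only the assistant turn is trained is possible with
+# a completion-only data collator, but is not done here.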
+
+# 4. Load model and tokenizer
+model_name = "Qwen/Qwen2.5-7B-Instruct"  # or Qwen2.5-3B, Qwen2.5-1.5B
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+def format_example(x):
+    try:
+        tools = [s['function'] for s in ast.literal_eval(re.findall(pattern, x['conversations'][0]['value'], re.DOTALL)[-1].strip())]
+        text = tokenizer.apply_chat_template(
+            [
+                {"role": "system", "content": "You are an assistant capable of tool calls"},
+                {"role": "user", "content": x['conversations'][-2]['value']},
+                {"role": "assistant", "content": x['conversations'][-1]['value']}
+            ],
+            tools=tools,
+            add_generation_prompt=False,  # the assistant turn is already included for SFT
+            tokenize=False
+        )
+        return {"text": text}
+    except Exception:
+        # Return a sentinel value to filter out later (datasets.map requires a dict)
+        return {"text": None}
+
+formatted_dataset = ds.map(format_example).filter(lambda x: x['text'] is not None)
+formatted_dataset = formatted_dataset.remove_columns(["conversations", "category", "subcategory", "task"])
+
+train_test_split = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42)
+train_dataset = train_test_split['train']
+eval_dataset = train_test_split['test']
+
+print(f"Training samples: {len(train_dataset)}")
+print(f"Validation samples: {len(eval_dataset)}")
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    trust_remote_code=True,
+    load_in_8bit=True  # Use 8-bit quantization for memory efficiency
+)
+
+# 5. Prepare model for LoRA fine-tuning
+model = prepare_model_for_kbit_training(model)
+
+lora_config = LoraConfig(
+    r=16,  # LoRA rank
+    lora_alpha=32,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM"
+)
+
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()
+
+sft_config = SFTConfig(
+    output_dir="./qwen-tool-calling3",
+    num_train_epochs=3,
+    per_device_train_batch_size=2,
+    gradient_accumulation_steps=8,
+    learning_rate=2e-4,
+    warmup_steps=100,
+    logging_steps=10,
+    save_steps=100,
+    eval_strategy="steps",  # Enable evaluation
+    eval_steps=30,  # Evaluate every 30 steps
+    optim="paged_adamw_8bit",
+    fp16=False,
+    bf16=True,
+    max_grad_norm=0.3,
+    # SFT-specific parameters
+    dataset_text_field='text',
+    max_length=2048,
+)
+
+trainer = SFTTrainer(
+    model=model,
+    args=sft_config,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,  # Add validation dataset
+)
+
+# 8. Train the model
+print("Starting training...")
+trainer.train()
+
+# 9. Save the fine-tuned model
+trainer.save_model("./qwen-tool-calling-final")
+tokenizer.save_pretrained("./qwen-tool-calling-final")
+
+print("Training complete!")
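+
+# For deployment you would typically merge the LoRA adapters back into the base weights,
+# e.g. (sketch, using the peft API):
+#   merged = model.merge_and_unload()
+#   merged.save_pretrained("./qwen-tool-calling-merged")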
+
+# 10. Example inference
+def test_tool_calling(prompt):
+    model.eval()
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            do_sample=True
+        )
+
+    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+    print(response)
+
+# Test the model (example tool definitions for the smoke test)
+tools = [
+    {
+        "name": "get_weather",
+        "description": "Get the current weather for a location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {"type": "string", "description": "City name"}
+            },
+            "required": ["location"]
+        }
+    }
+]
+
+test_prompt = f"""<|im_start|>system
+You are a helpful assistant with access to tools: {json.dumps(tools)}<|im_end|>
+<|im_start|>user
+What's the weather in London?<|im_end|>
+<|im_start|>assistant
+"""
+
+test_tool_calling(test_prompt)
\ No newline at end of file