Training a GPT-2 Model on a Custom Dataset Using the HuggingFace 🤗 Transformers API

from datasets import load_dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
dataset = load_dataset("text", data_files={"train": "train_c_code.txt", "test": "test_c_code.txt"})
Found cached dataset text (/home/moneebullah25/.cache/huggingface/datasets/text/default-6f7851ccb47a0532/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)
dataset
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 3561
    })
    test: Dataset({
        features: ['text'],
        num_rows: 2677
    })
})
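Each row of the "text" dataset is one line of the corresponding file. As a quick sanity check (purely illustrative, assuming the files loaded as shown above), you can index into a split:

# Illustrative only: peek at the first training row and the split sizes
print(dataset["train"][0]["text"])                          # first line of train_c_code.txt
print(dataset["train"].num_rows, dataset["test"].num_rows)  # 3561, 2677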
# Load pre-trained GPT-2 model and tokenizer
= "gpt2" # or any other GPT-2 variant you prefer
model_name = GPT2Tokenizer.from_pretrained(model_name)
tokenizer = GPT2LMHeadModel.from_pretrained(model_name) model
# Tokenize your dataset
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="train_c_code.txt",
    block_size=128  # adjust block_size as needed
)
/home/moneebullah25/anaconda3/lib/python3.11/site-packages/transformers/data/datasets/language_modeling.py:53: FutureWarning: This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py
warnings.warn(
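The FutureWarning above notes that TextDataset is deprecated and that preprocessing should move to the 🤗 Datasets library. A minimal sketch of that route, reusing the dataset loaded earlier (tokenizing line by line with truncation is an assumption made here for brevity; the linked example script instead concatenates and chunks the text):

# Sketch only: preprocess with 🤗 Datasets instead of the deprecated TextDataset
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS so batches can be padded

def tokenize_function(examples):
    # Tokenize each line of C code, capping it at the same 128-token block size used above
    return tokenizer(examples["text"], truncation=True, max_length=128)

tokenized_train = dataset["train"].map(tokenize_function, batched=True, remove_columns=["text"])
# tokenized_train could then replace train_dataset when building the Trainer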
# Create a data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # set to True if you want to use masked language modeling
)
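With mlm=False the collator builds causal-language-modeling batches: it stacks the examples and copies input_ids into labels (any padding positions are set to -100 so the loss ignores them). A purely illustrative check:

# Illustrative only: collate two blocks from the TextDataset into one batch
batch = data_collator([train_dataset[0], train_dataset[1]])
print(batch["input_ids"].shape)  # torch.Size([2, 128])
print(batch["labels"].shape)     # labels mirror input_ids for causal LM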
# Define training arguments
training_args = TrainingArguments(
    output_dir="./code_completion_model",
    overwrite_output_dir=True,
    num_train_epochs=1,  # adjust as needed
    per_device_train_batch_size=2,  # adjust as needed
    save_steps=10_000,
    save_total_limit=2,
)
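Note that Trainer only logs the training loss every logging_steps optimizer steps (500 by default), so a short run like this one only reports the final average loss in the TrainOutput below. If you want intermediate loss values, a possible variant (the value 50 is just an example) is:

# Sketch: same arguments as above, plus periodic loss logging
training_args = TrainingArguments(
    output_dir="./code_completion_model",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    save_steps=10_000,
    save_total_limit=2,
    logging_steps=50,  # example value; the default of 500 exceeds this run's 180 steps
)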
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
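The test split loaded at the start is otherwise unused in this run. If you also want a held-out loss, one option (sketched here with the same deprecated TextDataset helper for consistency) is to build an eval set and hand it to the Trainer:

# Sketch: build an evaluation set from the held-out file
test_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="test_c_code.txt",
    block_size=128
)
# Pass eval_dataset=test_dataset to the Trainer(...) call above; after training,
# trainer.evaluate() then returns metrics such as eval_loss on this split.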
# Fine-tune the model
trainer.train()
[180/180 05:16, Epoch 1/1]
TrainOutput(global_step=180, training_loss=1.6714536878797743, metrics={'train_runtime': 318.2669, 'train_samples_per_second': 1.128, 'train_steps_per_second': 0.566, 'total_flos': 23450959872000.0, 'train_loss': 1.6714536878797743, 'epoch': 1.0})
# Save the fine-tuned model
"./fine_tuned_code_completion_model")
model.save_pretrained("./fine_tuned_code_completion_model") tokenizer.save_pretrained(
('./fine_tuned_code_completion_model/tokenizer_config.json',
'./fine_tuned_code_completion_model/special_tokens_map.json',
'./fine_tuned_code_completion_model/vocab.json',
'./fine_tuned_code_completion_model/merges.txt',
'./fine_tuned_code_completion_model/added_tokens.json')
# Load the fine-tuned model and tokenizer
= "./fine_tuned_code_completion_model" # Replace with the path to your fine-tuned model
model_name = GPT2Tokenizer.from_pretrained(model_name)
tokenizer = GPT2LMHeadModel.from_pretrained(model_name) model
# Set the model to evaluation mode
model.eval()
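As an alternative to calling generate() by hand as below, the text-generation pipeline wraps tokenization, generation, and decoding in one call; a minimal sketch using the same fine-tuned checkpoint:

# Sketch: higher-level generation via the text-generation pipeline
from transformers import pipeline

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
completions = generator("void main() {", max_length=100, num_return_sequences=1)
print(completions[0]["generated_text"])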
# Generate sample text
= "void main() {"
prompt = tokenizer.encode(prompt, return_tensors="pt")
input_ids
# Adjust the max length of the generated sequence as needed
= 100
max_length
# Generate the sample
= model.generate(input_ids, max_length=max_length, num_return_sequences=1, no_repeat_ngram_size=2)
output
# Decode the generated output
= tokenizer.decode(output[0], skip_special_tokens=True)
generated_text
print("Generated Code:")
print(generated_text)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Generated Code:
void main() {
int i;
#ifdef DEBUG_HASH
int hash = 0;
hash += 1; /* hash is the hash of the current file */
}
/*
* Name: hash_hash
*/
int hash; /*
1) Set the default hash value
2) Configure the system clock
3) Reset the clock frequency
4) Load the kernel
5) Start up the program
6)
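The warning above about the attention mask and pad token appears because generate() was called with only input_ids, and GPT-2 defines no pad token. A sketch of a generate() call that should silence it, reusing the prompt and max_length from above:

# Sketch: pass an explicit attention mask and pad token id to generate()
encoded = tokenizer(prompt, return_tensors="pt")
output = model.generate(
    encoded["input_ids"],
    attention_mask=encoded["attention_mask"],
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token, so reuse EOS
    max_length=max_length,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
)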