问题描述
使用 SageMaker 中的 GPT2-medium 库对预训练的 Huggingface 模型进行微调时出现运行时错误 - ml.p3.8xlarge 实例。
finetuning_gpt2_script.py
包含以下内容,
图书馆:
from transformers import Trainer,TrainingArguments
from transformers import EarlyStoppingCallback
from transformers import GPT2LMHeadModel,GPT2Tokenizer
from transformers import TextDataset,DataCollatorForLanguageModeling
预训练模型:
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-medium")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
训练和测试数据构建:
train_dataset = TextDataset(
tokenizer=gpt2_tokenizer,file_path=train_path,block_size=128)
test_dataset = TextDataset(
tokenizer=gpt2_tokenizer,file_path=test_path,block_size=128)
data_collator = DataCollatorForLanguageModeling(
tokenizer=gpt2_tokenizer,mlm=False,)
train_path
& test_path
是大小为 145 万和 20 万行数据的非结构化文本数据文件
训练参数:
training_args = TrainingArguments(
output_dir="./gpt2-finetuned-models",#The output directory
overwrite_output_dir=True,#overwrite the content of the output directory
num_train_epochs=1,# number of training epochs
per_device_train_batch_size=8,# batch size for training #32
per_device_eval_batch_size=8,# batch size for evaluation #64
save_steps=100,# after # steps model is saved
warmup_steps=500,# number of warmup steps for learning rate scheduler
prediction_loss_only=True,metric_for_best_model = "eval_loss",load_best_model_at_end = True,evaluation_strategy="epoch",)
training_args
是为训练模型而构建的训练参数。
培训师:
trainer = Trainer(
model=gpt2_model,args=training_args,data_collator=data_collator,train_dataset=train_dataset,eval_dataset=test_dataset,callbacks = [early_stop_callback],)
early_stop_callback = EarlyStoppingCallback(early_stopping_patience = 3)
培训:
trainer.train()
trainer.save_model(model_path)
在这里,使用 ml.p3.8xlarge 实例在 4 个 GPU 中仅完成了 1 个时期的训练。
训练是通过像下面这样的火炬分配来完成的,
python -m torch.distributed.launch finetuning_gpt2_script.py
在 epoch 结束时进行训练,观察到以下错误,
RuntimeError: Input tensor at index 3 has invalid shape [2,2,16,128,64] but expected [2,4,64]
-
RuntimeError
是否是因为train_dataset
和test_dataset
使用TextData
构造的方式? - 我在
torch-distribution
中做错了吗?
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)