Training GPT2 and Reformer from scratch

Problem description

I am looking for scripts/notebooks to train GPT2 and Reformer models from scratch on German text, for example something like:

https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb

I am trying to adapt that same notebook, but GPT2 does not seem to accept the LineByLineTextDataset, or the padding it requires.

The error I get is:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<timed eval> in <module>

~/anaconda3/envs/thesis_p1/lib/python3.6/site-packages/transformers/trainer.py in train(self, model_path)
    490                 self._past = None
    491 
--> 492             for step, inputs in enumerate(epoch_iterator):
    493 
    494                 # Skip past any already trained steps if resuming training

~/anaconda3/envs/thesis_p1/lib/python3.6/site-packages/tqdm/notebook.py in __iter__(self, *args, **kwargs)
    226     def __iter__(self, *args, **kwargs):
    227         try:
--> 228             for obj in super(tqdm_notebook, self).__iter__(*args, **kwargs):
    229                 # return super(tqdm...) will not catch exception
    230                 yield obj

~/anaconda3/envs/thesis_p1/lib/python3.6/site-packages/tqdm/std.py in __iter__(self)
   1128 
   1129         try:
-> 1130             for obj in iterable:
   1131                 yield obj
   1132                 # Update and possibly print the progressbar.

~/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    344     def __next__(self):
    345         index = self._next_index()  # may raise StopIteration
--> 346         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    347         if self.pin_memory:
    348             data = _utils.pin_memory.pin_memory(data)

~/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     45         else:
     46             data = self.dataset[possibly_batched_index]
---> 47         return self.collate_fn(data)

~/anaconda3/envs/thesis_p1/lib/python3.6/site-packages/transformers/data/data_collator.py in __call__(self, examples)
     79 
     80     def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
---> 81         batch = self._tensorize_batch(examples)
     82         if self.mlm:
     83             inputs, labels = self.mask_tokens(batch)

~/anaconda3/envs/thesis_p1/lib/python3.6/site-packages/transformers/data/data_collator.py in _tensorize_batch(self, examples)
     96             if self.tokenizer._pad_token is None:
     97                 raise ValueError(
---> 98                     "You are attempting to pad samples but the tokenizer you are using"
     99                     f" ({self.tokenizer.__class__.__name__}) does not have one."
    100                 )

ValueError: You are attempting to pad samples but the tokenizer you are using (GPT2Tokenizer) does not have one.
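For context: LineByLineTextDataset yields variable-length examples, so DataCollatorForLanguageModeling has to pad each batch, and GPT2Tokenizer ships without a pad token; passing "<pad>" through additional_special_tokens (as in the code below) registers an extra token but does not set tokenizer.pad_token. A minimal sketch of the usual ways to register one, offered as an assumption about the fix rather than a verified solution:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # illustrative checkpoint
print(tokenizer.pad_token)  # None: this is what triggers the ValueError

# Option 1: add a dedicated pad token; the embedding matrix must then be
# resized so the new id has a row: model.resize_token_embeddings(len(tokenizer))
tokenizer.add_special_tokens({"pad_token": "<pad>"})

# Option 2: reuse the end-of-text token for padding instead
# tokenizer.pad_token = tokenizer.eos_token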

Here is my current implementation:

The dataset looks like this (about a million lines):

1   "09.05.2019,Flyer: Zeit für Perspektiven - Unterstützung im Haushalt durch professionelle Dienstleistungen"
2   %0A%0ADie Burg Werle (ca. 10 km von hier entfernt) war ein schwer einnehmbarer Schlupfwinkel.
3   %0A%0AHier,abseits der verkehrsreichen Straßen,liegt das idyllische Quellental,ein Naturdenkmal der besonderen Art.
4   ½ bis 1 Tasse (75–150 ml) HEITMANN Reine Citronensäure in ½ Liter Wasser geben und in den Wassertank der Maschine füllen.
5   %0% der anfallenden Kosten ergeben sich aus der Straßenbeleuchtung.
6   ¾ Parken während der Ladezeit in Fußgängerzonen,in denen das Be- oder Entladen für bestimmte Zeiten freigegeben ist.

First, I train a SentencePiece tokenizer:

from pathlib import Path
import sentencepiece as spm

# Collect corpus files (not used below; training reads the file named in `arg`)
paths = [str(x) for x in Path(".").glob("**/*.txt")]

# Train a SentencePiece model with a 52k vocabulary on the German corpus
arg = '--input=deu-de_web-public_2019_1M-sentences.txt --model_prefix=m_test --vocab_size=52000'
spm.SentencePieceTrainer.train(arg)
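A quick sanity check of the trained model (a sketch; m_test.model is the file the call above writes):

import sentencepiece as spm

# Load the model produced by the training call above
sp = spm.SentencePieceProcessor()
sp.Load("m_test.model")

# Round-trip a sample sentence to inspect the learned pieces
pieces = sp.EncodeAsPieces("Hier liegt das idyllische Quellental.")
print(pieces)
print(sp.DecodePieces(pieces))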

Then I load my GPT2 tokenizer as follows:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained(
    "./German",
    additional_special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
    max_len=512,
)
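One possible mismatch, flagged as an assumption rather than a confirmed diagnosis: GPT2Tokenizer.from_pretrained expects byte-level BPE files (vocab.json and merges.txt) in ./German, whereas SentencePiece writes m_test.model and m_test.vocab. The linked notebook produces the expected files with the tokenizers library; a sketch:

from tokenizers import ByteLevelBPETokenizer

# Train a byte-level BPE tokenizer, as in the linked notebook
tokenizer = ByteLevelBPETokenizer()
tokenizer.train(
    files=["deu-de_web-public_2019_1M-sentences.txt"],
    vocab_size=52_000,
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
)

# Writes vocab.json and merges.txt, which GPT2Tokenizer can load
# (older tokenizers releases use tokenizer.save("./German") instead)
tokenizer.save_model("./German")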

Here are my GPT2 configuration and language model:

from transformers import GPT2LMHeadModel, GPT2Config

# Initializing a GPT2 configuration
configuration = GPT2Config(vocab_size=52_000)
model = GPT2LMHeadModel(config=configuration)
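One thing worth noting: vocab_size=52_000 covers only the base vocabulary, while the five additional_special_tokens above extend the tokenizer beyond it. Sizing the embeddings from the tokenizer keeps all ids in range; a sketch (tokenizer is the GPT2Tokenizer loaded above):

from transformers import GPT2Config, GPT2LMHeadModel

# len(tokenizer) counts the added special tokens, unlike the hard-coded 52_000;
# n_positions=512 matches the max_len passed to the tokenizer above
configuration = GPT2Config(vocab_size=len(tokenizer), n_positions=512)
model = GPT2LMHeadModel(config=configuration)

# Alternatively, resize an already-built model:
# model.resize_token_embeddings(len(tokenizer))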

The dataset preparation logic:

from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="./deu-de_web-public_2019_1M-sentences.txt",
    block_size=128,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
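Since GPT2 trains with a causal LM objective here (mlm=False), padding can also be sidestepped entirely: the block-based TextDataset used by the transformers language-modeling example concatenates the corpus and slices it into fixed-size blocks, so every example already has length block_size and no pad token is needed. A sketch under that assumption:

from transformers import TextDataset, DataCollatorForLanguageModeling

# Fixed-length blocks: no padding is ever required
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="./deu-de_web-public_2019_1M-sentences.txt",
    block_size=128,
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)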

The training logic:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_gpu_train_batch_size=64,
    save_steps=10_000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
    prediction_loss_only=True,
)
trainer.train()
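After training, it can help to save the model and spot-check generation; a minimal sketch (the prompt is illustrative):

# Persist weights and config alongside the checkpoints
trainer.save_model("./output")

# Spot-check: greedy generation from a short German prompt
input_ids = tokenizer.encode("Die Burg Werle", return_tensors="pt")
input_ids = input_ids.to(model.device)  # in case training ran on GPU
output = model.generate(input_ids, max_length=40)
print(tokenizer.decode(output[0]))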
