Problem description
I am trying to fine-tune a BERT-based QA model (PyTorch) on the TPU v3-8 provided by Kaggle. During validation I use a ParallelLoader to run prediction on all 8 cores at once, but after that I don't know how to collect all the results from every core (and in the correct order matching the dataset) so I can compute the overall EM & F1 scores. Can anyone help? Code:
import gc

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from transformers import AdamW

def _run():
    MAX_LEN = 192    # maximum text length in the batch (cannot be too high due to memory constraints)
    BATCH_SIZE = 16  # batch size (cannot be too high due to memory constraints)
    EPOCHS = 2       # number of epochs

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        tokenized_datasets['train'],
        num_replicas=xm.xrt_world_size(),  # tell PyTorch how many devices (TPU cores) we are using for training
        rank=xm.get_ordinal(),             # tell PyTorch which device (core) we are on currently
        shuffle=True,
    )
    train_data_loader = torch.utils.data.DataLoader(
        tokenized_datasets['train'],
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=0,
    )
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        tokenized_datasets['validation'],
        num_replicas=xm.xrt_world_size(),  # was missing; without it the sampler cannot shard correctly
        rank=xm.get_ordinal(),
        shuffle=False,  # keep dataset order for scoring
    )
    valid_data_loader = torch.utils.data.DataLoader(
        tokenized_datasets['validation'],
        batch_size=BATCH_SIZE,  # was missing; DataLoader would otherwise default to batch_size=1
        sampler=valid_sampler,
        drop_last=False,
        num_workers=0,
    )

    device = xm.xla_device()  # device (single TPU core)
    model = model.to(device)  # put model onto the TPU core
    xm.master_print('done loading model')
    xm.master_print(xm.xrt_world_size(), 'as size')

    lr = 0.5e-5 * xm.xrt_world_size()  # scale the learning rate with the number of cores
    optimizer = AdamW(model.parameters(), lr=lr)  # define our optimizer

    for epoch in range(EPOCHS):
        gc.collect()
        # use ParallelLoader (provided by PyTorch XLA) for TPU-core-specific dataloading:
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        xm.master_print('parallel loader created... training now')
        gc.collect()

        # call training loop:
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=None)
        del para_loader

        model.eval()
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        gc.collect()

        # call evaluation loop
        print("call evaluation loop")
        start_logits, end_logits = eval_loop_fn(para_loader.per_device_loader(device), device)
Solution
No working solution for this problem has been found yet; we are still searching and will update this page.
If you have found a good solution, you are welcome to send it to the editor along with a link to this page.
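For what it's worth, one common direction (a hedged sketch, not a verified fix for this exact notebook): torch_xla provides `xm.mesh_reduce(tag, data, reduce_fn)` to collect a per-core value on every process, so each core can move its logits to CPU and reduce them by concatenation. The ordering part then follows from how a non-shuffling `DistributedSampler` shards data: core `r` receives dataset indices `r, r + world_size, r + 2*world_size, …` (with the tail padded so every core gets the same count), so the gathered shards can be interleaved back into dataset order before computing EM/F1. The snippet below simulates only that index bookkeeping in plain Python, with no TPU required; the toy dataset length is chosen divisible by 8 (the core count of a v3-8) to sidestep the sampler's padding.

```python
# Sketch: reconstructing dataset order from per-core prediction shards.
# Assumption: the validation sampler is DistributedSampler(shuffle=False),
# which assigns core r the indices r, r + world_size, r + 2*world_size, ...

def shard_indices(dataset_len, world_size, rank):
    """Indices a non-shuffling DistributedSampler assigns to one core
    (ignoring the padding it adds when dataset_len % world_size != 0)."""
    return list(range(rank, dataset_len, world_size))

def merge_shards(shards):
    """Interleave per-core prediction lists back into dataset order."""
    world_size = len(shards)
    total = sum(len(s) for s in shards)
    merged = [None] * total
    for rank, shard in enumerate(shards):
        for i, pred in enumerate(shard):
            merged[rank + i * world_size] = pred
    return merged

if __name__ == "__main__":
    dataset = [f"pred_{i}" for i in range(16)]  # toy "predictions", 16 % 8 == 0
    world_size = 8  # a TPU v3-8 exposes 8 cores
    # Each simulated core "predicts" only its own shard:
    shards = [[dataset[i] for i in shard_indices(len(dataset), world_size, r)]
              for r in range(world_size)]
    print(merge_shards(shards) == dataset)  # prints: True
```

On a real TPU the shards themselves would come from the gather step (e.g. `xm.mesh_reduce` over CPU tensors keyed by ordinal); the hypothetical helper names above are for illustration only. If the dataset length is not divisible by the core count, the padded duplicate predictions at the end must be dropped before scoring.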