训练 TimeSFormer 进行视频分类

问题描述

我的输入数据是特征图,而不是原始图像。并具有以下形式:(4,50,1,256) mini_batch=4 / frames=50 / channels=1 / H=1 / W= 256 TimeSformer 的参数是:

dim = 128,image_size = 256,patch_size = 16,num_frames = 50,num_classes = 2,depth = 12,heads = 8,dim_head = 32,attn_dropout = 0.,ff_dropout = 0.
)

为了检查我的网络是否正常工作,我试图通过仅使用 6 个训练数据和 2 个与 (4,256) 之前形状相同的验证数据来使其过拟合。 但是我得到的训练准确度是振荡的,永远不会达到 > 80% 的值,而且我的训练损失并没有减少,它总是在 0.6900 - 06950

左右

我的训练函数和参数是:


    epochs = 300
    lr = 1e-3
    device = "cuda" if torch.cuda.is_available() else "cpu" 
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(),lr=lr)
    def accuracy(y_pred,y_test):
       y_pred_softmax = torch.log_softmax(y_pred,dim = 1)
       _,y_pred_tags = torch.max(y_pred_softmax,dim = 1)    
       correct_pred = (y_pred_tags == y_test).float()
       acc = correct_pred.sum() / len(correct_pred)
       acc = torch.round(acc * 100)
       return acc
    history = defaultdict(list)
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_accuracy = 0
        model=model.train()
        for data,label in tqdm(train_loader):
            data = data
            label = label
            data=data.reshape(4,256)
            output = model(data)
            label=label.reshape(4,).to(torch.long)
            output = output / output.sum(0).expand_as(output)
            loss = criterion(output,label)
            acc=accuracy(output,label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()        
            epoch_accuracy += acc / len(train_loader)
            epoch_loss += loss / len(train_loader)
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            model=model.eval()
            for data,label in val_loader:
                data = data
                label=label.reshape(4,).to(torch.long)
                data=data.reshape(4,256)
                val_output = model(data)
                val_output = val_output / val_output.sum(0).expand_as(val_output)
                val_loss = criterion(val_output,label)
                val_acc=accuracy(val_output,label)
                optimizer.zero_grad()            
                epoch_val_accuracy += acc / len(val_loader)
                epoch_val_loss += val_loss / len(val_loader)

我将不胜感激任何建议。 谢谢

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...