计算每个测试集类别的熵,以测量火炬上的不确定性

问题描述

我正在尝试使用MC Dropout方法和此链接中提出的解决方案来计算图像分类任务的每个数据集的熵,以测量火炬上的不确定性
Measuring uncertainty using MC Dropout on pytorch

首先,我已经计算出每个批次在不同前向传递中的平均值(class_mean_batch),然后对所有测试加载器(classes_mean)进行了计算,然后进行了一些转换以获取(total_mean)以将其用于计算熵,如图所示。下面的代码

def mcdropout_test(batch_size,n_classes,model,T):

    #set non-dropout layers to eval mode
    model.eval()

    #set dropout layers to train mode
    enable_dropout(model)
    
    softmax = nn.Softmax(dim=1)
    classes_mean = []
       
    for images,labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        classes_mean_batch = []
            
        with torch.no_grad():
          output_list = []
          
          #getting outputs for T forward passes
          for i in range(T):
            output = model(images)
            output = softmax(output)
            output_list.append(torch.unsqueeze(output,0))
            
        
        concat_output = torch.cat(output_list,0)
        
        # getting mean of each class per batch across multiple MCD forward passes
        for i in range (n_classes):
          mean = torch.mean(concat_output[:,:,i])
          classes_mean_batch.append(mean)
        
        # getting mean of each class for the testloader
        classes_mean.append(torch.stack(classes_mean_batch))
        

    total_mean = []
    concat_classes_mean = torch.stack(classes_mean)

    for i in range (n_classes):
      concat_classes = concat_classes_mean[:,i]
      total_mean.append(concat_classes)


    total_mean = torch.stack(total_mean)
    total_mean = np.asarray(total_mean.cpu())
 
    epsilon = sys.float_info.min
    # Calculating entropy across multiple MCD forward passes 
    entropy = (- np.sum(total_mean*np.log(total_mean + epsilon),axis=-1)).tolist()
    for i in range(n_classes):
      print(f'The uncertainty of class {i+1} is {entropy[i]:.4f}')
    
    

任何人都可以请更正或确认我用来计算每个类的熵的实现。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...