libtorch: Why do I get "c10::Error at memory location 0x000000C5539CC010" when calling torch::nll_loss?

Problem description

Update: I have added a "minimal reproducible example" below the original question.

I have been debugging this error for a while now. It happens when I call torch::nll_loss(). I thought it might be because my tensors don't match, but I have spent some time making sure they are the same size, and I am not sure what else to check. Unlike the errors Python PyTorch provides, the libtorch errors I have come across are very unhelpful. (So far every libtorch error I have hit reads Error at memory location 0x............)
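
Since the only text that reaches the debugger is "c10::Error at memory location ...", one way to see the underlying message is to catch the exception and print e.what(). A minimal sketch (computeLoss is a hypothetical helper, not code from the project below):

#include <torch/torch.h>
#include <iostream>

// Sketch: catch the c10::Error so its full message is printed instead of
// only the "c10::Error at memory location ..." text from the debugger.
torch::Tensor computeLoss(const torch::Tensor& prediction, const torch::Tensor& target)
{
    try {
        return torch::nll_loss(prediction, target);
    }
    catch (const c10::Error& e) {
        std::cerr << "nll_loss failed: " << e.what() << std::endl;
        throw;  // re-throw after logging
    }
}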

I have two tensors, prediction and target.

Most of the implementation comes from this example.

The input tensor is generated using the following method:

void NeuralNetwork::Load(Waveform waveform)
{
    // Convert waveform.samples into torch::Tensor
    c10::DeviceType deviceType;
    if (torch::cuda::is_available()) {
        deviceType = torch::kCUDA;
    }
    else {
        deviceType = torch::kCPU;
    }
    float newArr[1][WAVEFORM_SIZE] = { {0} };
    for(int i=0;i<WAVEFORM_SIZE;i++)
        newArr[0][i] = waveform.samples[i];
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(deviceType);
    torch::Tensor inputTensor = torch::from_blob(newArr,{ 1,WAVEFORM_SIZE },options);
    inputTensor.set_requires_grad(true);
    TrainingSample newData = TrainingSample(inputTensor,waveform);
    trainingData.push_back(newData);
}
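
One thing worth flagging about the method above: torch::from_blob does not copy the buffer it is handed, so inputTensor keeps pointing at the stack array newArr after Load() returns. (The repeated -1.0737e+08 values in the console output further down are the float reading of the bit pattern 0xCCCCCCCC, MSVC's uninitialized/freed-stack fill in debug builds, which is consistent with that.) A sketch of the same conversion with tensor-owned memory, under the same data layout (toOwnedTensor is a hypothetical helper):

#include <torch/torch.h>

// Sketch: copy a float buffer into a tensor that owns its memory.
// from_blob() only wraps the pointer; without clone() the tensor would
// dangle once the buffer goes out of scope.
torch::Tensor toOwnedTensor(float* samples, int64_t n, c10::DeviceType deviceType)
{
    auto cpuOpts = torch::TensorOptions().dtype(torch::kFloat32);
    torch::Tensor t = torch::from_blob(samples, { 1, n }, cpuOpts)
                          .clone()          // own the data
                          .to(deviceType);  // then move to CUDA if available
    t.set_requires_grad(true);
    return t;
}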

prediction is generated by the following network:

struct NeuralNetwork::Net : torch::nn::Module {
    int _inputSize;
    Net(int inputSize) {
        _inputSize = inputSize;
        // Construct and register two Linear submodules.
        fc1 = register_module("fc1",torch::nn::Linear(_inputSize,64));
        fc2 = register_module("fc2",torch::nn::Linear(64,32));
        fc3 = register_module("fc3",torch::nn::Linear(32,3));
    }

    // Implement the Net's algorithm.
    torch::Tensor forward(torch::Tensor x) {
        // If tensor is one dimensional,change to batch_size,seq_len (2D)
        int tensorDims = 0;
        for (int i : x.sizes())
            tensorDims++;
        if(tensorDims==1)
            x = torch::unsqueeze(x,0);
        DBG("x.sizes: " + NeuralNetwork::Item2String<c10::IntArrayRef>(x.sizes()));
        DBG("\nx.sizes size: " + NeuralNetwork::Item2String<c10::IntArrayRef>(sizeof(x.sizes()) / sizeof(x.sizes()[0])));
        // Use one of many tensor manipulation functions.
        DBG("\ninput tensor: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
        x = torch::relu(fc1->forward(x));
        DBG("\nafter fc1: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
        x = torch::dropout(x,/*p=*/0.5,/*train=*/is_training());
        DBG("\nafter dropout: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
        x = torch::relu(fc2->forward(x));
        DBG("\nafter fc2: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
        x = torch::log_softmax(fc3->forward(x),/*dim=*/1);
        DBG("\nafter fc3: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
        return x;
    }

    // Use one of many "standard library" modules.
    torch::nn::Linear fc1{ nullptr },fc2{ nullptr },fc3{ nullptr };
};
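
As a side note, the dimension count in forward() can be read directly from the tensor instead of looping over x.sizes(); a shorter equivalent using the standard Tensor::dim() accessor (ensureBatchDim is a hypothetical helper):

#include <torch/torch.h>

// Sketch: equivalent of the dimension-counting loop in forward().
torch::Tensor ensureBatchDim(torch::Tensor x)
{
    if (x.dim() == 1)        // number of dimensions
        x = x.unsqueeze(0);  // {seq_len} -> {1, seq_len}
    return x;
}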

target is generated using the following method:

torch::Tensor TrainingSample::getratingTensor()
{
    c10::DeviceType deviceType;
    if (torch::cuda::is_available()) {
        deviceType = torch::kCUDA;
    }
    else {
        deviceType = torch::kcpu;
    }
    float ratingArray[1][3] = { {0} };
    ratingArray[0][(int)waveform.rating] = 1;
    ostringstream os0;
    for (int i = 0;i<(sizeof(ratingArray[0])/sizeof(ratingArray[0][0]));i++) {
        os0 << ratingArray[0][i];
        os0 << ",";
    }
    DBG("ratingArray: \n" + os0.str());
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(deviceType);
    torch::Tensor ratingTensor = torch::from_blob(ratingArray,{ 1,3 },options);
    ostringstream os1;
    os1 << ratingTensor[0];
    DBG("ratingTensor: \n" + os1.str());
    return ratingTensor.clone();
}
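
For reference, torch::nll_loss (like its Python counterpart) expects the target to be a tensor of class indices with integer (kLong) dtype and shape {batch}, rather than a one-hot float row like the one built above. A sketch of an index-style target under that assumption (makeIndexTarget is a hypothetical helper; waveform.rating would be reused as the class index):

#include <torch/torch.h>

// Sketch: nll_loss-compatible target -- shape {1}, dtype kLong, holding the
// class index itself rather than a one-hot encoding.
torch::Tensor makeIndexTarget(int rating, c10::DeviceType deviceType)
{
    auto options = torch::TensorOptions().dtype(torch::kLong).device(deviceType);
    return torch::full({ 1 }, static_cast<int64_t>(rating), options);
}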

These are put to use as follows:

...
    for each(TrainingSample trainingSample in trainingData) {
        // Check trainingSample.sampleTensor's length to make sure it works
        if (trainingSample.sampleTensor.size(1) != WAVEFORM_SIZE) {
            throw std::logic_error("Input must match WAVEFORM_SIZE");            
        }
        // Reset gradients.
        optimizer.zero_grad();
        // Execute the model on the input data.
        torch::Tensor prediction = net->forward(trainingSample.sampleTensor);
        // Compute a loss value to judge the prediction of our model.
        torch::Tensor target = trainingSample.getratingTensor();

        std::ostringstream os_tensor0;
        os_tensor0 << target;
        DBG("target_val: \n" + os_tensor0.str());

        std::ostringstream os_tensor1;
        os_tensor1 << prediction;
        DBG("prediction_val: \n" + os_tensor1.str());

        torch::Tensor loss = torch::nll_loss(prediction,target);
...
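
A quick sanity print of both operands right before the loss call shows the dtypes as well as the sizes; a minimal sketch in the same DBG style as above (logLossOperands is a hypothetical helper):

#include <torch/torch.h>
#include <JuceHeader.h>  // for DBG
#include <sstream>

// Sketch: log shape and dtype of both operands before calling nll_loss.
static void logLossOperands(const torch::Tensor& prediction, const torch::Tensor& target)
{
    std::ostringstream os;
    os << "prediction: " << prediction.sizes() << " " << prediction.dtype()
       << " | target: " << target.sizes() << " " << target.dtype();
    DBG(os.str());
}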

When run, I get the following console output:

x.sizes: [1,450]

x.sizes size: [2]

input tensor: 
Columns 1 to 6-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 7 to 12-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 13 to 18-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 19 to 24-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 25 to 30-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 31 to 36-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 37 to 42-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 43 to 48-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 49 to 54-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 55 to 60-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 61 to 66-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08

Columns 67 to 72-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -6.1297e+27  1.4370e-41

Columns 73 to 78-1.0737e+08 -1.0737e+08  1.3479e+12  2.7606e-43  1.3483e+12  2.7606e-43

Columns 79 to 84 1.5097e+35  4.5908e-41  1.3479e+12  2.7606e-43 -1.0737e+08 -1.0737e+08

Columns 85 to 90-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08  1.3479e+12  2.7606e-43

Columns 91 to 96-5.2868e-34  4.5908e-41  1.4063e+24  2.3647e-41  1.3479e+12  2.7606e-43

Columns 97 to 102-1.0737e+08 -1.0737e+08  1.3483e+12  2.7606e-43  1.3483e+12  2.7606e-43

Columns 103 to 108-5.2840e-34  4.5908e-41 -2.5697e+36  6.8103e-43  1.3479e+12  2.7606e-43

Columns 109 to 114 1.3479e+12  2.7606e-43  7.8240e-02  7.2303e-02  5.2409e-02  2.3487e-02

Columns 115 to 120-9.2460e-03 -3.9394e-02  0.0000e+00  0.0000e+00 -9.9383e-02  6.3058e-43

Columns 121 to 126-1.4176e-01 -1.6761e-01 -1.9406e-01 -2.1916e-01 -2.4059e-01 -2.5746e-01

Columns 127 to 132-2.7041e-01 -2.8198e-01 -2.9581e-01 -3.1235e-01 -3.2742e-01 -3.3695e-01

Columns 133 to 138-3.4230e-01 -3.4594e-01 -2.5143e+36  6.8103e-43 -3.5119e-01 -3.5661e-01

Columns 139 to 144-3.6110e-01 -3.6576e-01 -2.5697e+36  6.8103e-43 -2.5697e+36  6.8103e-43

Columns 145 to 150 1.3480e+12  2.7606e-43  1.3480e+12  2.7606e-43  1.3480e+12  2.7606e-43

Columns 151 to 156 1.3480e+12  2.7606e-43  1.3479e+12  2.7606e-43  1.3479e+12  2.7606e-43

Columns 157 to 162 1.3479e+12  2.7606e-43  1.3479e+12  2.7606e-43 -5.4939e-01 -5.4665e-01

Columns 163 to 168-5.3675e-01 -5.1979e-01 -4.9963e-01 -4.8594e-01 -4.8443e-01 -4.9146e-01

Columns 169 to 174-4.9957e-01 -5.0272e-01 -4.9933e-01 -4.9043e-01 -4.7711e-01 -4.6370e-01

Columns 175 to 180-4.5665e-01 -4.5746e-01 -4.6186e-01 -4.6114e-01 -4.4190e-01 -3.9672e-01

Columns 181 to 186-3.4028e-01 -2.9929e-01 -2.8463e-01 -2.8733e-01  1.3480e+12  2.7606e-43

Columns 187 to 192-2.5944e+36  6.8103e-43  1.3479e+12  2.7606e-43 -2.3795e+36  6.8103e-43

Columns 193 to 198-2.9645e+36  6.8103e-43 -2.5944e+36  6.8103e-43  0.0000e+00  0.0000e+00

Columns 199 to 204-2.3795e+36  6.8103e-43  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00

Columns 205 to 210-2.3795e+36  6.8103e-43  0.0000e+00  0.0000e+00  1.3479e+12  2.7606e-43

Columns 211 to 216-2.3795e+36  6.8103e-43  1.4329e-02  1.9215e-02  2.6336e-02  3.6743e-02

Columns 217 to 222 4.7983e-02  5.8675e-02  6.9367e-02  7.8963e-02 -2.9645e+36  6.8103e-43

Columns 223 to 228 0.0000e+00  1.1853e-01 -1.0677e-23  6.8103e-43 -2.8907e-21  6.8103e-43

Columns 229 to 234-4.5850e-21  6.8103e-43 -4.5858e-21  6.8103e-43 -4.5858e-21  6.8103e-43

Columns 235 to 240 1.8583e+19  0.0000e+00 -1.8184e-36  4.5908e-41 -2.8923e-21  6.8103e-43

Columns 241 to 246-1.0742e-23  6.8103e-43 -1.0742e-23  6.8103e-43 -1.0742e-23  6.8103e-43

Columns 247 to 252 4.2970e-01  4.2393e-01 -1.1069e-23  6.8103e-43  0.0000e+00  0.0000e+00

Columns 253 to 258 5.0396e-01  5.0000e-01  0.0000e+00  0.0000e+00 -1.0742e-23  6.8103e-43

Columns 259 to 264 5.1114e-01  5.1451e-01  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00

Columns 265 to 270 0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00 -1.8184e-36  4.5908e-41

Columns 271 to 276 8.9129e+04  1.0350e+00  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00

Columns 277 to 282 0.0000e+00  0.0000e+00  4.6471e-01  4.5735e-01  4.5031e-01  4.4295e-01

Columns 283 to 288 4.3358e-01  4.2181e-01  4.0897e-01  3.9672e-01  3.8583e-01  3.7529e-01

Columns 289 to 294 3.6445e-01  3.5527e-01  3.4973e-01  3.4699e-01  3.4594e-01  3.4637e-01

Columns 295 to 300 3.4666e-01  3.4488e-01  3.4171e-01  3.3831e-01  3.3367e-01  3.2617e-01

Columns 301 to 306 3.1636e-01  3.0843e-01  3.0667e-01  3.1162e-01  3.2197e-01  3.3437e-01

Columns 307 to 312 3.4179e-01  3.3873e-01  3.2635e-01  3.1051e-01  2.9666e-01  2.8553e-01

Columns 313 to 318 2.7514e-01  2.6404e-01  2.5120e-01  2.3579e-01  2.1667e-01  1.9390e-01

Columns 319 to 324 1.7077e-01  1.5234e-01  1.4434e-01  1.4809e-01  1.5381e-01  1.5017e-01

Columns 325 to 330 1.3569e-01  1.1446e-01  9.1649e-02  7.1120e-02  5.3066e-02  3.6940e-02

Columns 331 to 336 2.2677e-02  1.1021e-02  2.9797e-03 -3.2208e-03 -9.0269e-03 -1.4329e-02

Columns 337 to 342-2.2239e-02 -3.3632e-02 -4.4148e-02 -5.2343e-02 -5.9880e-02 -6.9323e-02

Columns 343 to 348-8.2929e-02 -9.8573e-02 -1.1382e-01 -1.2942e-01 -1.4561e-01 -1.6303e-01

Columns 349 to 354-1.8139e-01 -1.9697e-01 -2.0819e-01 -2.1678e-01 -2.2688e-01 -2.4581e-01

Columns 355 to 360-2.7731e-01 -3.1671e-01 -3.5641e-01 -3.9151e-01 -4.2142e-01 -4.4712e-01

Columns 361 to 366-4.7063e-01 -4.9630e-01 -5.2582e-01 -5.5605e-01 -5.8340e-01 -6.0428e-01

Columns 367 to 372-6.1668e-01 -6.2303e-01 -6.2938e-01 -6.4352e-01 -6.6841e-01 -6.9415e-01

Columns 373 to 378-7.0000e-01 -6.7009e-01 -6.0787e-01 -5.2685e-01 -4.4050e-01 -3.7221e-01

Columns 379 to 384-3.4480e-01 -3.5365e-01 -3.7503e-01 -3.9094e-01 -3.9547e-01 -3.9074e-01

Columns 385 to 390-3.8215e-01 -3.7611e-01 -3.7830e-01 -3.9017e-01 -4.0926e-01 -4.2847e-01

Columns 391 to 396-4.3616e-01 -4.2378e-01 -3.9458e-01 -3.6263e-01 -3.4188e-01 -3.3779e-01

Columns 397 to 402-3.4927e-01 -3.6984e-01 -3.8943e-01 -4.0187e-01 -4.0720e-01 -4.0921e-01

Columns 403 to 408-4.1265e-01 -4.1879e-01 -4.2694e-01 -4.3649e-01 -4.4532e-01 -4.5174e-01

Columns 409 to 414-4.5421e-01 -4.5082e-01 -4.4352e-01 -4.3686e-01 -4.3415e-01 -4.3489e-01

Columns 415 to 420-4.3342e-01 -4.2635e-01 -4.1559e-01 -4.0253e-01 -3.8982e-01 -3.8171e-01

Columns 421 to 426-3.7727e-01 -3.7273e-01 -3.6684e-01 -3.6184e-01 -3.6070e-01 -3.6278e-01

Columns 427 to 432-3.6539e-01 -3.6697e-01 -3.6699e-01 -3.6660e-01 -3.6660e-01 -3.6614e-01

Columns 433 to 438-3.6438e-01 -3.6037e-01 -3.5455e-01 -3.4859e-01 -3.4298e-01 -3.3779e-01

Columns 439 to 444-3.3229e-01 -3.2536e-01 -3.1872e-01 -3.1430e-01 -3.1184e-01 -3.0891e-01

Columns 445 to 450-3.0017e-01 -2.8205e-01 -2.5545e-01 -2.2501e-01 -1.9967e-01  0.0000e+00
[ CPUFloatType{1,450} ]

after fc1: 
Columns 1 to 6 0.0000e+00  7.8770e+34  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00

Columns 7 to 12 0.0000e+00  0.0000e+00  1.8537e+35  0.0000e+00  0.0000e+00  0.0000e+00

Columns 13 to 18 1.3528e+35  0.0000e+00  5.2141e+35  0.0000e+00  3.3142e+34  1.8497e+35

Columns 19 to 24 0.0000e+00  0.0000e+00  2.7433e+35  0.0000e+00  2.7988e+35  4.7219e+33

Columns 25 to 30 0.0000e+00  6.9118e+34  1.2590e+35  0.0000e+00  0.0000e+00  0.0000e+00

Columns 31 to 36 1.8690e+35  0.0000e+00  5.8213e+34  0.0000e+00  0.0000e+00  0.0000e+00

Columns 37 to 42 1.7089e+35  0.0000e+00  4.6019e+35  4.6340e+34  7.8131e+34  0.0000e+00

Columns 43 to 48 0.0000e+00  0.0000e+00  5.5950e+35  2.5632e+35  1.1927e+35  0.0000e+00

Columns 49 to 54 0.0000e+00  0.0000e+00  3.2757e+34  0.0000e+00  2.1214e+35  0.0000e+00

Columns 55 to 60 4.9851e+35  0.0000e+00  6.9772e+34  4.5953e+35  1.6406e+34  0.0000e+00

Columns 61 to 64 3.4096e+35  1.2904e+35  0.0000e+00  1.3653e+35
[ CPUFloatType{1,64} ]

after dropout: 
Columns 1 to 6 0.0000e+00  1.5754e+35  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00

Columns 7 to 12 0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00

Columns 13 to 18 2.7056e+35  0.0000e+00  1.0428e+36  0.0000e+00  0.0000e+00  3.6993e+35

Columns 19 to 24 0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00  5.5975e+35  0.0000e+00

Columns 25 to 30 0.0000e+00  1.3824e+35  2.5180e+35  0.0000e+00  0.0000e+00  0.0000e+00

Columns 31 to 36 0.0000e+00  0.0000e+00  1.1643e+35  0.0000e+00  0.0000e+00  0.0000e+00

Columns 37 to 42 3.4178e+35  0.0000e+00  9.2039e+35  9.2679e+34  1.5626e+35  0.0000e+00

Columns 43 to 48 0.0000e+00  0.0000e+00  0.0000e+00  5.1264e+35  2.3855e+35  0.0000e+00

Columns 49 to 54 0.0000e+00  0.0000e+00  6.5513e+34  0.0000e+00  0.0000e+00  0.0000e+00

Columns 55 to 60 9.9702e+35  0.0000e+00  1.3954e+35  0.0000e+00  0.0000e+00  0.0000e+00

Columns 61 to 64 6.8193e+35  2.5809e+35  0.0000e+00  0.0000e+00
[ CPUFloatType{1,64} ]

after fc2: 
Columns 1 to 6 1.2524e+35  0.0000e+00  0.0000e+00  2.4940e+35  0.0000e+00  2.4364e+35

Columns 7 to 12 2.6166e+35  0.0000e+00  1.7298e+35  0.0000e+00  0.0000e+00  0.0000e+00

Columns 13 to 18 1.2573e+35  7.5406e+34  0.0000e+00  8.5736e+33  0.0000e+00  1.8340e+35

Columns 19 to 24 1.4026e+35  6.5115e+34  0.0000e+00  0.0000e+00  1.6587e+35  0.0000e+00

Columns 25 to 30 6.3401e+34  1.1294e+35  0.0000e+00  8.4040e+34  0.0000e+00  3.9660e+34

Columns 31 to 32 0.0000e+00  0.0000e+00
[ CPUFloatType{1,32} ]

after fc3: 
-6.2916e+34  0.0000e+00 -1.1594e+35
[ CPUFloatType{1,3} ]
ratingArray: 
1,0,0,
ratingTensor: 
 1
 0
 0
[ CPUFloatType{3} ]
target_val: 
 1  0  0
[ CPUFloatType{1,3} ]
prediction_val: 
-6.2916e+34  0.0000e+00 -1.1594e+35
[ CPUFloatType{1,3} ]
Exception thrown at 0x00007FFA0B7D3E49 in AudioPluginHost.exe: Microsoft C++ exception: c10::Error at memory location 0x000000C5539CC010.
Unhandled exception at 0x00007FFA0B7D3E49 in AudioPluginHost.exe: Microsoft C++ exception: c10::Error at memory location 0x000000C5539CC010.

So the tensor fed into the network is the correct size, and everything else appears to be the correct size as well. The output also appears to be the correct size (1,3).

Also, target_val (the tensor produced by getratingTensor) looks like the correct size, and its values are correct too.
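
For contrast, here is a self-contained call that torch::nll_loss does accept, with a {1,3} row of log-probabilities and a {1} kLong index target (the values are placeholders, not from my project):

#include <torch/torch.h>

// Sketch: shapes and dtypes that torch::nll_loss accepts.
void nllLossShapeDemo()
{
    torch::Tensor prediction = torch::log_softmax(torch::randn({ 1, 3 }), /*dim=*/1);
    torch::Tensor target = torch::tensor({ 0 }, torch::kLong);  // class index 0
    torch::Tensor loss = torch::nll_loss(prediction, target);   // 0-dim scalar
}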

I am building a JUCE audio plugin, so I am using the Projucer to include/link the libtorch libraries, and I am compiling it with Visual Studio 2019.

My Projucer settings are:

External libraries to link:

E:\Programming\Downloads\libtorch\lib\c10.lib
E:\Programming\Downloads\libtorch\lib\c10_cuda.lib
E:\Programming\Downloads\libtorch\lib\caffe2_nvrtc.lib
E:\Programming\Downloads\libtorch\lib\torch.lib
E:\Programming\Downloads\libtorch\lib\torch_cpu.lib
E:\Programming\Downloads\libtorch\lib\torch_cuda.lib

Header search paths:

E:\Programming\Downloads\libtorch\include\
E:\Programming\Downloads\libtorch\include\torch\csrc\api\include

Extra library search paths:

E:\Programming\Downloads\libtorch\lib

Minimal reproducible example

I am using JUCE 6.0.1 on Windows 10 with Visual Studio 2019.

  1. In the Projucer, create a new Audio Plug-In project and name it Debug.
  2. In Exporters > Visual Studio 2019, paste the following into the External Libraries to Link field:
path\to\libtorch\lib\c10.lib
path\to\libtorch\lib\c10_cuda.lib
path\to\libtorch\lib\caffe2_nvrtc.lib

path\to\libtorch\lib\torch.lib
path\to\libtorch\lib\torch_cpu.lib
path\to\libtorch\lib\torch_cuda.lib
  3. Click on "Debug". In the Header Search Paths field, enter:
path\to\libtorch\include\
path\to\libtorch\include\torch\csrc\api\include

In the Extra Library Search Paths field, enter:

path\to\libtorch\lib
  4. In the Projucer, click the File Explorer, right-click the Source folder, and select "Add New CPP & Header File"; name it "Debug".
  5. Your audio plugin folder now contains the folder Debug\Builds\VisualStudio2019\x64\Debug\Standalone Plugin. Create hard links to the .dll files from path\to\libtorch\lib in the Standalone Plugin folder.
  6. Open your project in Visual Studio 2019 using the button in the Projucer.
  7. In Debug.h, enter:
#pragma once
#include <torch/torch.h>
#include <JuceHeader.h>
#define WAVEFORM_SIZE 450
using namespace std;
using namespace juce;


class NeuralNetwork {
public:
    NeuralNetwork();
    struct Net;
    float trainingData[WAVEFORM_SIZE];
    int epoch = 0;
    void Train();
    torch::Tensor getratingTensor();
    template <class T>
    static string Item2String(T x);
};
  8. In Debug.cpp, enter:
#include "Debug.h"

// Define a new Module.
struct NeuralNetwork::Net : torch::nn::Module {
    int _inputSize;
    Net(int inputSize) {
        _inputSize = inputSize;
        // Construct and register two Linear submodules.
        fc1 = register_module("fc1",3));
    }

    // Implement the Net's algorithm.
    torch::Tensor forward(torch::Tensor x) {
        int tensorDims = 0;
        for (int i : x.sizes())
            tensorDims++;
        if (tensorDims == 1)
            x = torch::unsqueeze(x,0);
        x = torch::relu(fc1->forward(x));
        x = torch::dropout(x,/*p=*/0.5,/*train=*/is_training());
        x = torch::relu(fc2->forward(x));
        x = torch::log_softmax(fc3->forward(x),/*dim=*/1);
        return x;
    }

    // Use one of many "standard library" modules.
    torch::nn::Linear fc1{ nullptr },fc2{ nullptr },fc3{ nullptr };
};

NeuralNetwork::NeuralNetwork()
{
}

void NeuralNetwork::Train()
{
    // Create saw-tooth wave,WAVEFORM_SIZE samples
    for (int i = 0; i < WAVEFORM_SIZE; i++) {
        trainingData[i] = (i - floor((float)WAVEFORM_SIZE / 2)) / ((float)WAVEFORM_SIZE/2);
    }

    // Convert waveform.samples into torch::Tensor
    c10::DeviceType deviceType;
    if (torch::cuda::is_available()) {
        deviceType = torch::kCUDA;
    }
    else {
        deviceType = torch::kCPU;
    }
    // Adjust dimensions (unsqueeze array)
    float newArr[1][WAVEFORM_SIZE] = { {0} };
    for (int i = 0; i < WAVEFORM_SIZE; i++)
        newArr[0][i] = trainingData[i];
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(deviceType);
    torch::Tensor inputTensor = torch::from_blob(newArr,{ 1,WAVEFORM_SIZE },options);
    inputTensor.set_requires_grad(true);

    // Now we have inputTensor
    // Get target Tensor
    torch::Tensor targetTensor = getratingTensor();

    // Train model
    // Create a new Net.
    auto net = std::make_shared<Net>(WAVEFORM_SIZE);

    // Instantiate an SGD optimization algorithm to update our Net's parameters.
    torch::optim::SGD optimizer(net->parameters(),/*lr=*/0.01);


    size_t batch_index = 0;

    // Check trainingSample.sampleTensor's length to make sure it works
    if (inputTensor.size(1) != WAVEFORM_SIZE) {
        throw std::logic_error("Input must match WAVEFORM_SIZE");
    }
    // Reset gradients.
    optimizer.zero_grad();
    // Execute the model on the input data.
    torch::Tensor prediction = net->forward(inputTensor);
    // Compute a loss value to judge the prediction of our model.

    std::ostringstream os_tensor0;
    os_tensor0 << targetTensor;
    DBG("target_val: \n" + os_tensor0.str());

    std::ostringstream os_tensor1;
    os_tensor1 << prediction;
    DBG("prediction_val: \n" + os_tensor1.str());

    torch::Tensor loss = torch::nll_loss(prediction,targetTensor);
    // Compute gradients of the loss w.r.t. the parameters of our model.
    loss.backward();
    // Update the parameters based on the calculated gradients.
    optimizer.step();
    // Output the loss and checkpoint every 100 batches.
    if (++batch_index % 100 == 0) {
        std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
            << " | Loss: " << loss.item<float>() << std::endl;
        // Serialize your model periodically as a checkpoint.
        torch::save(net,"net.pt");
    }

}



torch::Tensor NeuralNetwork::getratingTensor()
{
    int rating = 0; 
    c10::DeviceType deviceType;
    if (torch::cuda::is_available()) {
        deviceType = torch::kCUDA;
    }
    else {
        deviceType = torch::kCPU;
    }
    float ratingArray[1][3] = { {0} };
    ratingArray[0][rating] = 1;
    ostringstream os0;
    for (int i = 0; i < (sizeof(ratingArray[0]) / sizeof(ratingArray[0][0])); i++) {
        os0 << ratingArray[0][i];
        os0 << ",options);
    ostringstream os1;
    os1 << ratingTensor[0];
    DBG("ratingTensor: \n" + os1.str());
    return ratingTensor.clone();
}

template <class T>
string NeuralNetwork::Item2String(T x) {
    std::ostringstream os;
    os << x;
    return os.str();
}
  1. PluginEditor.cpp的其他两个#include旁边,添加#include "Debug.h"
  10. In the same file, in the DebugAudioProcessorEditor::DebugAudioProcessorEditor constructor, add the following code (a try/catch variant is sketched after this list):
    NeuralNetwork neuralNet = NeuralNetwork();
    neuralNet.Train();
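
The same try/catch idea from the top of the question applies here to surface the real error text when the standalone plugin runs; a sketch of the constructor body under that assumption:

    // Sketch: log the underlying c10::Error message instead of crashing
    // with only "c10::Error at memory location ...".
    try {
        NeuralNetwork neuralNet = NeuralNetwork();
        neuralNet.Train();
    }
    catch (const c10::Error& e) {
        DBG(String("libtorch error: ") + e.what());
    }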

