问题描述
我已经调试此错误已有一段时间了。我打电话给torch::nll_loss()
时会发生这种情况。我以为这可能是因为我的张量不匹配,但是我花了一些时间来确保它们的大小相同,而且我不确定还要检查什么。与python pytorch中提供的错误不同,我遇到的libtorch错误非常有用。 (到目前为止,我所有的libtorch错误都是Error at memory location 0x...........
。)
我有两个张量prediction
和target
。
大部分实现来自this example。
void NeuralNetwork::Load(Waveform waveform)
{
// Convert waveform.samples into torch::Tensor
c10::DeviceType deviceType;
if (torch::cuda::is_available()) {
deviceType = torch::kCUDA;
}
else {
deviceType = torch::kcpu;
}
float newArr[1][WAVEFORM_SIZE] = { {0} };
for(int i=0;i<WAVEFORM_SIZE;i++)
newArr[0][i] = waveform.samples[i];
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(deviceType);
torch::Tensor inputTensor = torch::from_blob(newArr,{ 1,WAVEFORM_SIZE },options);
inputTensor.set_requires_grad(true);
TrainingSample newData = TrainingSample(inputTensor,waveform);
trainingData.push_back(newData);
}
prediction
是通过以下网络生成的:
struct NeuralNetwork::Net : torch::nn::Module {
int _inputSize;
Net(int inputSize) {
_inputSize = inputSize;
// Construct and register two Linear submodules.
fc1 = register_module("fc1",torch::nn::Linear(_inputSize,64));
fc2 = register_module("fc2",torch::nn::Linear(64,32));
fc3 = register_module("fc3",torch::nn::Linear(32,3));
}
// Implement the Net's algorithm.
torch::Tensor forward(torch::Tensor x) {
// If tensor is one dimensional,change to batch_size,seq_len (2D)
int tensorDims = 0;
for (int i : x.sizes())
tensorDims++;
if(tensorDims==1)
x = torch::unsqueeze(x,0);
DBG("x.sizes: " + NeuralNetwork::Item2String<c10::IntArrayRef>(x.sizes()));
DBG("\nx.sizes size: " + NeuralNetwork::Item2String<c10::IntArrayRef>(sizeof(x.sizes()) / sizeof(x.sizes()[0])));
// Use one of many tensor manipulation functions.
DBG("\ninput tensor: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
x = torch::relu(fc1->forward(x));
DBG("\nafter fc1: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
x = torch::dropout(x,/*p=*/0.5,/*train=*/is_training());
DBG("\nafter dropout: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
x = torch::relu(fc2->forward(x));
DBG("\nafter fc2: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
x = torch::log_softmax(fc3->forward(x),/*dim=*/1);
DBG("\nafter fc3: \n" + NeuralNetwork::Item2String<torch::Tensor>(x));
return x;
}
// Use one of many "standard library" modules.
torch::nn::Linear fc1{ nullptr },fc2{ nullptr },fc3{ nullptr };
};
torch::Tensor TrainingSample::getratingTensor()
{
c10::DeviceType deviceType;
if (torch::cuda::is_available()) {
deviceType = torch::kCUDA;
}
else {
deviceType = torch::kcpu;
}
float ratingArray[1][3] = { {0} };
ratingArray[0][(int)waveform.rating] = 1;
ostringstream os0;
for (int i = 0;i<(sizeof(ratingArray[0])/sizeof(ratingArray[0][0]));i++) {
os0 << ratingArray[0][i];
os0 << ",";
}
DBG("ratingArray: \n" + os0.str());
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(deviceType);
torch::Tensor ratingTensor = torch::from_blob(ratingArray,3 },options);
ostringstream os1;
os1 << ratingTensor[0];
DBG("ratingTensor: \n" + os1.str());
return ratingTensor.clone();
}
其实现方式如下:
...
for each(TrainingSample trainingSample in trainingData) {
// Check trainingSample.sampleTensor's length to make sure it works
if (trainingSample.sampleTensor.size(1) != WAVEFORM_SIZE) {
throw std::logic_error("Input must match WAVEFORM_SIZE");
}
// Reset gradients.
optimizer.zero_grad();
// Execute the model on the input data.
torch::Tensor prediction = net->forward(trainingSample.sampleTensor);
// Compute a loss value to judge the prediction of our model.
torch::Tensor target = trainingSample.getratingTensor();
std::ostringstream os_tensor0;
os_tensor0 << target;
DBG("target_val: \n" + os_tensor0.str());
std::ostringstream os_tensor1;
os_tensor1 << prediction;
DBG("prediction_val: \n" + os_tensor1.str());
torch::Tensor loss = torch::nll_loss(prediction,target);
...
运行时,我得到以下控制台输出:
x.sizes: [1,450]
x.sizes size: [2]
input tensor:
Columns 1 to 6-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 7 to 12-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 13 to 18-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 19 to 24-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 25 to 30-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 31 to 36-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 37 to 42-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 43 to 48-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 49 to 54-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 55 to 60-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 61 to 66-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08
Columns 67 to 72-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 -6.1297e+27 1.4370e-41
Columns 73 to 78-1.0737e+08 -1.0737e+08 1.3479e+12 2.7606e-43 1.3483e+12 2.7606e-43
Columns 79 to 84 1.5097e+35 4.5908e-41 1.3479e+12 2.7606e-43 -1.0737e+08 -1.0737e+08
Columns 85 to 90-1.0737e+08 -1.0737e+08 -1.0737e+08 -1.0737e+08 1.3479e+12 2.7606e-43
Columns 91 to 96-5.2868e-34 4.5908e-41 1.4063e+24 2.3647e-41 1.3479e+12 2.7606e-43
Columns 97 to 102-1.0737e+08 -1.0737e+08 1.3483e+12 2.7606e-43 1.3483e+12 2.7606e-43
Columns 103 to 108-5.2840e-34 4.5908e-41 -2.5697e+36 6.8103e-43 1.3479e+12 2.7606e-43
Columns 109 to 114 1.3479e+12 2.7606e-43 7.8240e-02 7.2303e-02 5.2409e-02 2.3487e-02
Columns 115 to 120-9.2460e-03 -3.9394e-02 0.0000e+00 0.0000e+00 -9.9383e-02 6.3058e-43
Columns 121 to 126-1.4176e-01 -1.6761e-01 -1.9406e-01 -2.1916e-01 -2.4059e-01 -2.5746e-01
Columns 127 to 132-2.7041e-01 -2.8198e-01 -2.9581e-01 -3.1235e-01 -3.2742e-01 -3.3695e-01
Columns 133 to 138-3.4230e-01 -3.4594e-01 -2.5143e+36 6.8103e-43 -3.5119e-01 -3.5661e-01
Columns 139 to 144-3.6110e-01 -3.6576e-01 -2.5697e+36 6.8103e-43 -2.5697e+36 6.8103e-43
Columns 145 to 150 1.3480e+12 2.7606e-43 1.3480e+12 2.7606e-43 1.3480e+12 2.7606e-43
Columns 151 to 156 1.3480e+12 2.7606e-43 1.3479e+12 2.7606e-43 1.3479e+12 2.7606e-43
Columns 157 to 162 1.3479e+12 2.7606e-43 1.3479e+12 2.7606e-43 -5.4939e-01 -5.4665e-01
Columns 163 to 168-5.3675e-01 -5.1979e-01 -4.9963e-01 -4.8594e-01 -4.8443e-01 -4.9146e-01
Columns 169 to 174-4.9957e-01 -5.0272e-01 -4.9933e-01 -4.9043e-01 -4.7711e-01 -4.6370e-01
Columns 175 to 180-4.5665e-01 -4.5746e-01 -4.6186e-01 -4.6114e-01 -4.4190e-01 -3.9672e-01
Columns 181 to 186-3.4028e-01 -2.9929e-01 -2.8463e-01 -2.8733e-01 1.3480e+12 2.7606e-43
Columns 187 to 192-2.5944e+36 6.8103e-43 1.3479e+12 2.7606e-43 -2.3795e+36 6.8103e-43
Columns 193 to 198-2.9645e+36 6.8103e-43 -2.5944e+36 6.8103e-43 0.0000e+00 0.0000e+00
Columns 199 to 204-2.3795e+36 6.8103e-43 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
Columns 205 to 210-2.3795e+36 6.8103e-43 0.0000e+00 0.0000e+00 1.3479e+12 2.7606e-43
Columns 211 to 216-2.3795e+36 6.8103e-43 1.4329e-02 1.9215e-02 2.6336e-02 3.6743e-02
Columns 217 to 222 4.7983e-02 5.8675e-02 6.9367e-02 7.8963e-02 -2.9645e+36 6.8103e-43
Columns 223 to 228 0.0000e+00 1.1853e-01 -1.0677e-23 6.8103e-43 -2.8907e-21 6.8103e-43
Columns 229 to 234-4.5850e-21 6.8103e-43 -4.5858e-21 6.8103e-43 -4.5858e-21 6.8103e-43
Columns 235 to 240 1.8583e+19 0.0000e+00 -1.8184e-36 4.5908e-41 -2.8923e-21 6.8103e-43
Columns 241 to 246-1.0742e-23 6.8103e-43 -1.0742e-23 6.8103e-43 -1.0742e-23 6.8103e-43
Columns 247 to 252 4.2970e-01 4.2393e-01 -1.1069e-23 6.8103e-43 0.0000e+00 0.0000e+00
Columns 253 to 258 5.0396e-01 5.0000e-01 0.0000e+00 0.0000e+00 -1.0742e-23 6.8103e-43
Columns 259 to 264 5.1114e-01 5.1451e-01 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
Columns 265 to 270 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 -1.8184e-36 4.5908e-41
Columns 271 to 276 8.9129e+04 1.0350e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
Columns 277 to 282 0.0000e+00 0.0000e+00 4.6471e-01 4.5735e-01 4.5031e-01 4.4295e-01
Columns 283 to 288 4.3358e-01 4.2181e-01 4.0897e-01 3.9672e-01 3.8583e-01 3.7529e-01
Columns 289 to 294 3.6445e-01 3.5527e-01 3.4973e-01 3.4699e-01 3.4594e-01 3.4637e-01
Columns 295 to 300 3.4666e-01 3.4488e-01 3.4171e-01 3.3831e-01 3.3367e-01 3.2617e-01
Columns 301 to 306 3.1636e-01 3.0843e-01 3.0667e-01 3.1162e-01 3.2197e-01 3.3437e-01
Columns 307 to 312 3.4179e-01 3.3873e-01 3.2635e-01 3.1051e-01 2.9666e-01 2.8553e-01
Columns 313 to 318 2.7514e-01 2.6404e-01 2.5120e-01 2.3579e-01 2.1667e-01 1.9390e-01
Columns 319 to 324 1.7077e-01 1.5234e-01 1.4434e-01 1.4809e-01 1.5381e-01 1.5017e-01
Columns 325 to 330 1.3569e-01 1.1446e-01 9.1649e-02 7.1120e-02 5.3066e-02 3.6940e-02
Columns 331 to 336 2.2677e-02 1.1021e-02 2.9797e-03 -3.2208e-03 -9.0269e-03 -1.4329e-02
Columns 337 to 342-2.2239e-02 -3.3632e-02 -4.4148e-02 -5.2343e-02 -5.9880e-02 -6.9323e-02
Columns 343 to 348-8.2929e-02 -9.8573e-02 -1.1382e-01 -1.2942e-01 -1.4561e-01 -1.6303e-01
Columns 349 to 354-1.8139e-01 -1.9697e-01 -2.0819e-01 -2.1678e-01 -2.2688e-01 -2.4581e-01
Columns 355 to 360-2.7731e-01 -3.1671e-01 -3.5641e-01 -3.9151e-01 -4.2142e-01 -4.4712e-01
Columns 361 to 366-4.7063e-01 -4.9630e-01 -5.2582e-01 -5.5605e-01 -5.8340e-01 -6.0428e-01
Columns 367 to 372-6.1668e-01 -6.2303e-01 -6.2938e-01 -6.4352e-01 -6.6841e-01 -6.9415e-01
Columns 373 to 378-7.0000e-01 -6.7009e-01 -6.0787e-01 -5.2685e-01 -4.4050e-01 -3.7221e-01
Columns 379 to 384-3.4480e-01 -3.5365e-01 -3.7503e-01 -3.9094e-01 -3.9547e-01 -3.9074e-01
Columns 385 to 390-3.8215e-01 -3.7611e-01 -3.7830e-01 -3.9017e-01 -4.0926e-01 -4.2847e-01
Columns 391 to 396-4.3616e-01 -4.2378e-01 -3.9458e-01 -3.6263e-01 -3.4188e-01 -3.3779e-01
Columns 397 to 402-3.4927e-01 -3.6984e-01 -3.8943e-01 -4.0187e-01 -4.0720e-01 -4.0921e-01
Columns 403 to 408-4.1265e-01 -4.1879e-01 -4.2694e-01 -4.3649e-01 -4.4532e-01 -4.5174e-01
Columns 409 to 414-4.5421e-01 -4.5082e-01 -4.4352e-01 -4.3686e-01 -4.3415e-01 -4.3489e-01
Columns 415 to 420-4.3342e-01 -4.2635e-01 -4.1559e-01 -4.0253e-01 -3.8982e-01 -3.8171e-01
Columns 421 to 426-3.7727e-01 -3.7273e-01 -3.6684e-01 -3.6184e-01 -3.6070e-01 -3.6278e-01
Columns 427 to 432-3.6539e-01 -3.6697e-01 -3.6699e-01 -3.6660e-01 -3.6660e-01 -3.6614e-01
Columns 433 to 438-3.6438e-01 -3.6037e-01 -3.5455e-01 -3.4859e-01 -3.4298e-01 -3.3779e-01
Columns 439 to 444-3.3229e-01 -3.2536e-01 -3.1872e-01 -3.1430e-01 -3.1184e-01 -3.0891e-01
Columns 445 to 450-3.0017e-01 -2.8205e-01 -2.5545e-01 -2.2501e-01 -1.9967e-01 0.0000e+00
[ cpuFloatType{1,450} ]
after fc1:
Columns 1 to 6 0.0000e+00 7.8770e+34 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
Columns 7 to 12 0.0000e+00 0.0000e+00 1.8537e+35 0.0000e+00 0.0000e+00 0.0000e+00
Columns 13 to 18 1.3528e+35 0.0000e+00 5.2141e+35 0.0000e+00 3.3142e+34 1.8497e+35
Columns 19 to 24 0.0000e+00 0.0000e+00 2.7433e+35 0.0000e+00 2.7988e+35 4.7219e+33
Columns 25 to 30 0.0000e+00 6.9118e+34 1.2590e+35 0.0000e+00 0.0000e+00 0.0000e+00
Columns 31 to 36 1.8690e+35 0.0000e+00 5.8213e+34 0.0000e+00 0.0000e+00 0.0000e+00
Columns 37 to 42 1.7089e+35 0.0000e+00 4.6019e+35 4.6340e+34 7.8131e+34 0.0000e+00
Columns 43 to 48 0.0000e+00 0.0000e+00 5.5950e+35 2.5632e+35 1.1927e+35 0.0000e+00
Columns 49 to 54 0.0000e+00 0.0000e+00 3.2757e+34 0.0000e+00 2.1214e+35 0.0000e+00
Columns 55 to 60 4.9851e+35 0.0000e+00 6.9772e+34 4.5953e+35 1.6406e+34 0.0000e+00
Columns 61 to 64 3.4096e+35 1.2904e+35 0.0000e+00 1.3653e+35
[ cpuFloatType{1,64} ]
after dropout:
Columns 1 to 6 0.0000e+00 1.5754e+35 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
Columns 7 to 12 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
Columns 13 to 18 2.7056e+35 0.0000e+00 1.0428e+36 0.0000e+00 0.0000e+00 3.6993e+35
Columns 19 to 24 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 5.5975e+35 0.0000e+00
Columns 25 to 30 0.0000e+00 1.3824e+35 2.5180e+35 0.0000e+00 0.0000e+00 0.0000e+00
Columns 31 to 36 0.0000e+00 0.0000e+00 1.1643e+35 0.0000e+00 0.0000e+00 0.0000e+00
Columns 37 to 42 3.4178e+35 0.0000e+00 9.2039e+35 9.2679e+34 1.5626e+35 0.0000e+00
Columns 43 to 48 0.0000e+00 0.0000e+00 0.0000e+00 5.1264e+35 2.3855e+35 0.0000e+00
Columns 49 to 54 0.0000e+00 0.0000e+00 6.5513e+34 0.0000e+00 0.0000e+00 0.0000e+00
Columns 55 to 60 9.9702e+35 0.0000e+00 1.3954e+35 0.0000e+00 0.0000e+00 0.0000e+00
Columns 61 to 64 6.8193e+35 2.5809e+35 0.0000e+00 0.0000e+00
[ cpuFloatType{1,64} ]
after fc2:
Columns 1 to 6 1.2524e+35 0.0000e+00 0.0000e+00 2.4940e+35 0.0000e+00 2.4364e+35
Columns 7 to 12 2.6166e+35 0.0000e+00 1.7298e+35 0.0000e+00 0.0000e+00 0.0000e+00
Columns 13 to 18 1.2573e+35 7.5406e+34 0.0000e+00 8.5736e+33 0.0000e+00 1.8340e+35
Columns 19 to 24 1.4026e+35 6.5115e+34 0.0000e+00 0.0000e+00 1.6587e+35 0.0000e+00
Columns 25 to 30 6.3401e+34 1.1294e+35 0.0000e+00 8.4040e+34 0.0000e+00 3.9660e+34
Columns 31 to 32 0.0000e+00 0.0000e+00
[ cpuFloatType{1,32} ]
after fc3:
-6.2916e+34 0.0000e+00 -1.1594e+35
[ cpuFloatType{1,3} ]
ratingArray:
1,ratingTensor:
1
0
0
[ cpuFloatType{3} ]
target_val:
1 0 0
[ cpuFloatType{1,3} ]
prediction_val:
-6.2916e+34 0.0000e+00 -1.1594e+35
[ cpuFloatType{1,3} ]
Exception thrown at 0x00007FFA0B7D3E49 in AudioPluginHost.exe: Microsoft C++ exception: c10::Error at memory location 0x000000C5539CC010.
Unhandled exception at 0x00007FFA0B7D3E49 in AudioPluginHost.exe: Microsoft C++ exception: c10::Error at memory location 0x000000C5539CC010.
因此,馈入网络的张量的大小正确,其他所有大小也似乎正确。输出也似乎是正确的大小(1,3)。
此外,target_val
(由getratingTensor
产生的张量)看起来大小也正确,其值也正确。
我正在使用JUCE音频插件,因此我正在使用projucer来包含/链接libtorch库,并且正在使用Visual Studio 2019对其进行编译。
我的projucer设置是:
要链接的外部库:
E:\Programming\Downloads\libtorch\lib\c10.lib
E:\Programming\Downloads\libtorch\lib\c10_cuda.lib
E:\Programming\Downloads\libtorch\lib\caffe2_nvrtc.lib
E:\Programming\Downloads\libtorch\lib\torch.lib
E:\Programming\Downloads\libtorch\lib\torch_cpu.lib
E:\Programming\Downloads\libtorch\lib\torch_cuda.lib
E:\Programming\Downloads\libtorch\include\
E:\Programming\Downloads\libtorch\include\torch\csrc\api\include
额外的图书馆搜索路径:
E:\Programming\Downloads\libtorch\lib
最小可复制示例
我在Windows 10和Visual Studio 2019上使用JUCE 6.0.1
path\to\libtorch\lib\c10.lib
path\to\libtorch\lib\c10_cuda.lib
path\to\libtorch\lib\caffe2_nvrtc.lib
path\to\libtorch\lib\torch.lib
path\to\libtorch\lib\torch_cpu.lib
path\to\libtorch\lib\torch_cuda.lib
- 单击“调试”。
在
Header Search Paths
字段中,输入:
path\to\libtorch\include\
path\to\libtorch\include\torch\csrc\api\include
在字段Extra Library Search Paths
中,输入:
path\to\libtorch\lib
- 在projucer中,单击
File Exporter
,右键单击Source
文件夹,然后选择名为“Debug
”的“ Add New CPP&Header File”。 - 现在在您的音频插件文件夹中,有一个子文件夹:
Debug\Builds\VisualStudio2019\x64\Debug\Standalone Plugin
。创建从.dll
中的path\to\libtorch\lib
文件到Standalone Plugin
文件夹的硬链接。 - 使用projucer中的按钮,在Visual Studio 2019中打开您的项目。
- 在
Debug.h
中,输入:
#pragma once
#include <torch/torch.h>
#include <JuceHeader.h>
#define WAVEFORM_SIZE 450
using namespace std;
using namespace juce;
class NeuralNetwork {
public:
NeuralNetwork();
struct Net;
float trainingData[WAVEFORM_SIZE];
int epoch = 0;
void Train();
torch::Tensor getratingTensor();
template <class T>
static string Item2String(T x);
};
- 在
Debug.cpp
中,输入:
#include "Debug.h"
// Define a new Module.
struct NeuralNetwork::Net : torch::nn::Module {
int _inputSize;
Net(int inputSize) {
_inputSize = inputSize;
// Construct and register two Linear submodules.
fc1 = register_module("fc1",3));
}
// Implement the Net's algorithm.
torch::Tensor forward(torch::Tensor x) {
int tensorDims = 0;
for (int i : x.sizes())
tensorDims++;
if (tensorDims == 1)
x = torch::unsqueeze(x,fc3{ nullptr };
};
NeuralNetwork::NeuralNetwork()
{
}
void NeuralNetwork::Train()
{
// Create saw-tooth wave,WAVEFORM_SIZE samples
for (int i = 0; i < WAVEFORM_SIZE; i++) {
trainingData[i] = (i - floor((float)WAVEFORM_SIZE / 2)) / ((float)WAVEFORM_SIZE/2);
}
// Convert waveform.samples into torch::Tensor
c10::DeviceType deviceType;
if (torch::cuda::is_available()) {
deviceType = torch::kCUDA;
}
else {
deviceType = torch::kcpu;
}
// Adjust dimensions (unsqueeze array)
float newArr[1][WAVEFORM_SIZE] = { {0} };
for (int i = 0; i < WAVEFORM_SIZE; i++)
newArr[0][i] = trainingData[i];
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(deviceType);
torch::Tensor inputTensor = torch::from_blob(newArr,options);
inputTensor.set_requires_grad(true);
// Now we have inputTensor
// Get target Tensor
torch::Tensor targetTensor = getratingTensor();
// Train model
// Create a new Net.
auto net = std::make_shared<Net>(WAVEFORM_SIZE);
// Instantiate an SGD optimization algorithm to update our Net's parameters.
torch::optim::SGD optimizer(net->parameters(),/*lr=*/0.01);
size_t batch_index = 0;
// Check trainingSample.sampleTensor's length to make sure it works
if (inputTensor.size(1) != WAVEFORM_SIZE) {
throw std::logic_error("Input must match WAVEFORM_SIZE");
}
// Reset gradients.
optimizer.zero_grad();
// Execute the model on the input data.
torch::Tensor prediction = net->forward(inputTensor);
// Compute a loss value to judge the prediction of our model.
std::ostringstream os_tensor0;
os_tensor0 << targetTensor;
DBG("target_val: \n" + os_tensor0.str());
std::ostringstream os_tensor1;
os_tensor1 << prediction;
DBG("prediction_val: \n" + os_tensor1.str());
torch::Tensor loss = torch::nll_loss(prediction,targetTensor);
// Compute gradients of the loss w.r.t. the parameters of our model.
loss.backward();
// Update the parameters based on the calculated gradients.
optimizer.step();
// Output the loss and checkpoint every 100 batches.
if (++batch_index % 100 == 0) {
std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
<< " | Loss: " << loss.item<float>() << std::endl;
// Serialize your model periodically as a checkpoint.
torch::save(net,"net.pt");
}
}
torch::Tensor NeuralNetwork::getratingTensor()
{
int rating = 0;
c10::DeviceType deviceType;
if (torch::cuda::is_available()) {
deviceType = torch::kCUDA;
}
else {
deviceType = torch::kcpu;
}
float ratingArray[1][3] = { {0} };
ratingArray[0][rating] = 1;
ostringstream os0;
for (int i = 0; i < (sizeof(ratingArray[0]) / sizeof(ratingArray[0][0])); i++) {
os0 << ratingArray[0][i];
os0 << ",options);
ostringstream os1;
os1 << ratingTensor[0];
DBG("ratingTensor: \n" + os1.str());
return ratingTensor.clone();
}
template <class T>
string NeuralNetwork::Item2String(T x) {
std::ostringstream os;
os << x;
return os.str();
}
- 在
PluginEditor.cpp
的其他两个#include
旁边,添加#include "Debug.h"
- 在同一文件的
DebugAudioProcessorEditor::DebugAudioProcessorEditor
构造函数中,添加以下代码:
NeuralNetwork neuralNet = NeuralNetwork();
neuralNet.Train();
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)