MxNet Gluon,在 GPU 上训练 Sequential() 模型的正确方法

问题描述

我一直在尝试使用服务器上的 GPU 训练模型,但出现错误

Check Failed: e == CUBLAS_STATUS_SUCCESS (13 vs. 0) : cuBLAS: CUBLAS_STATUS_EXECUTION_Failed

我在论坛上发现了几个类似的话题,他们似乎都指出了安装的CUDA版本存在一些问题。 但是,在同一台机器上,我可以在同一环境下在 GPU 上训练一些对象检测模型而不会出现任何错误,所以我猜问题出在我的代码中。

这是我目前所做的:

import mxnet as mx
from mxnet import autograd,gluon,init
from mxnet import ndarray as nd
import numpy as np
from mxnet.gluon import nn
from mxnet import init
from mxnet.gluon import loss as gloss
import pandas as pd


def main()

    ## read,preprocess and split data
    df_data = pd.read_csv('some_file.csv')
    df_data = pre_process(df_data)
    X_train,y_train,X_test,y_test = split_data(df_data)


    train(X_train,y_test,lr,batch_size,nr_epochs)


def train(X_train,nr_epochs):
    ctx = mx.gpu(3)
    y_train = mx.nd.array(y_train.to_numpy().reshape(-1,1),dtype=np.float32,ctx=ctx)
    y_test = mx.nd.array(y_test.to_numpy().reshape(-1,ctx=ctx)
    X_train = mx.nd.array(X_train.to_numpy(),ctx=ctx)
    X_test = mx.nd.array(X_test.to_numpy(),ctx=ctx)

    ##--------------------
    ##   building model
    ##--------------------
    batch = batch_size
    epochs = nr_epochs
    dataset = gluon.data.dataset.ArrayDataset(X_train,y_train)
    data_loader = gluon.data.DataLoader(dataset,batch_size=batch,shuffle=True)

    model = nn.Sequential()
    model.add(nn.Dense(64,activation='relu'))
    model.add(nn.Dense(1))
    model.initialize(init.normal(sigma=0.01),ctx)
    model.collect_params().reset_ctx(ctx)
    loss = gloss.L2Loss()
    trainer = gluon.Trainer(model.collect_params(),'sgd',{'learning_rate': lr})

    ##--------------------
    ##   training
    ##--------------------
    for epoch in range(1,epochs + 1):
        for X_batch,Y_batch in data_loader:
            with autograd.record():
                l = loss(model(X_batch),Y_batch)
            l.backward()
            trainer.step(batch)

即使我尝试以下简单操作,也会发生错误

print(l)

在 for 循环中,所以我认为这里有问题?

我正在使用:
mxnet-cu90mkl=1.5.1.post0

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...