Weights & Biases sweep cannot import modules with PyTorch Lightning

Problem description

I am training a variational autoencoder with pytorch-lightning. My pytorch-lightning code works with the Weights & Biases logger. I am now trying to run a parameter sweep with W&B Sweeps.

The hyperparameter search setup is based on what I learned from this repo.

The sweep initializes correctly, but when my training script runs with the first set of hyperparameters, I get the following error:

2020-08-14 14:09:07,109 - wandb.wandb_agent - INFO - About to run command: /usr/bin/env python train_sweep.py --LR=0.02537477586974176
Traceback (most recent call last):
  File "train_sweep.py", line 1, in <module>
    import yaml
ImportError: No module named yaml

yaml is installed and works. I can train the network by setting the parameters manually, but not through the parameter sweep.

Here is my sweep script for training the VAE:

import yaml
import numpy as np
import ipdb
import torch
from vae_experiment import VAEXperiment
import torch.backends.cudnn as cudnn
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping
from vae_network import VanillaVAE
import os
import wandb
from utils import get_config, log_to_wandb

# Sweep parameters
hyperparameter_defaults = dict(
    root='data_semantics', gpus=1, batch_size=2, lr=1e-3, num_layers=5,
    features_start=64, bilinear=False, grad_batches=1, epochs=20
)

wandb.init(config=hyperparameter_defaults)
config = wandb.config

def main(hparams):

    model = VanillaVAE(hparams['exp_params']['img_size'], **hparams['model_params'])
    model.build_layers()
    experiment = VAEXperiment(model, hparams['exp_params'], hparams['parameters'])

    logger = WandbLogger(
        project='vae',
        name=config['logging_params']['name'],
        version=config['logging_params']['version'],
        save_dir=config['logging_params']['save_dir']
    )

    # the logger object is named `logger`, not `wandb_logger`
    logger.watch(model.net)

    early_stopping = EarlyStopping(
        monitor='val_loss', min_delta=0.00, patience=3, verbose=False, mode='min'
    )

    runner = Trainer(
        weights_save_path="../../Logs/", min_epochs=1, logger=logger,
        log_save_interval=10, train_percent_check=1., val_percent_check=1.,
        num_sanity_val_steps=5, early_stop_callback=early_stopping,
        **config['trainer_params']
    )

    runner.fit(experiment)

if __name__ == '__main__':
    main(config)

Why does this error occur?

Solution

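Do you start Python in your shell by typing python or python3? Your script may be calling Python 2 instead of Python 3. The agent log shows that it launches /usr/bin/env python train_sweep.py, so whichever interpreter "python" resolves to on the PATH is the one that runs your script, and yaml may not be installed for that interpreter.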
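To check this, a small diagnostic sketch (not part of wandb; the function name interpreter_report is hypothetical) can report which interpreter is running, what python and python3 resolve to on the current PATH (roughly what /usr/bin/env python would pick), and whether yaml is visible to the current interpreter. Run it with the interpreter you normally use; if python_on_path points to a Python 2 installation, that would explain the ImportError.

```python
import importlib.util
import shutil
import sys


def interpreter_report(module_name="yaml"):
    """Report the running interpreter, what python/python3 resolve to, and module visibility."""
    return {
        # The interpreter executing this code right now.
        "current_executable": sys.executable,
        "current_version": "%d.%d.%d" % sys.version_info[:3],
        # What a PATH lookup for `python` / `python3` finds (None if absent).
        "python_on_path": shutil.which("python"),
        "python3_on_path": shutil.which("python3"),
        # Whether the current interpreter can find the module, without importing it.
        "module_found": importlib.util.find_spec(module_name) is not None,
    }


if __name__ == "__main__":
    for key, value in interpreter_report().items():
        print("%s: %s" % (key, value))
```

Running /usr/bin/env python --version in the same shell shows directly which Python version the sweep agent would start.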

In that case, you can explicitly tell wandb to use python3. See this section of the documentation, in particular "Running sweeps with Python 3".
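A sketch of such a sweep configuration, written as a Python dict that could be passed to wandb.sweep() (the same keys can go in the sweep YAML file). The command list uses the ${env}, ${program} and ${args} macros with python3 in place of the default interpreter; the method, metric and LR range are assumptions for illustration only.

```python
# Sketch of a W&B sweep configuration as a Python dict.
# Assumptions: the script name comes from the agent log above; the method,
# metric and LR range are illustrative, not taken from the original post.
sweep_config = {
    "program": "train_sweep.py",
    "method": "random",  # illustrative choice
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "LR": {"min": 0.0001, "max": 0.1},  # hypothetical search range
    },
    # Name the interpreter explicitly so the agent runs
    # `/usr/bin/env python3 train_sweep.py --LR=...` instead of `python`.
    "command": ["${env}", "python3", "${program}", "${args}"],
}
```

With this config, sweep_id = wandb.sweep(sweep_config, project="vae") registers the sweep, and a command-line agent (wandb agent followed by the sweep ID that wandb.sweep prints) then launches the script with python3.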

Another answer:

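The problem was the structure of my code and the order in which I was running the wandb commands. This pytorch-lightning with wandb example shows the correct structure.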

Here is my refactored code:

#!/usr/bin/env python
import wandb
from utils import get_config

#---------------------------------------------------------------------------------------------

def main():

    """
    The training function used in each sweep of the model.
    For every sweep,this function will be executed as if it is a script on its own.
    """

    import wandb
    import yaml
    import numpy as np
    import torch
    from vae_experiment import VAEXperiment
    import torch.backends.cudnn as cudnn
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import WandbLogger
    from pytorch_lightning.callbacks import EarlyStopping
    from vae_network import VanillaVAE
    import os
    from utils import log_to_wandb, format_config

    path_to_config = 'sweep.yaml'
    config = get_config(path_to_config)

    path_to_defaults = 'defaults.yaml'
    param_defaults = get_config(path_to_defaults)

    wandb.init(config=param_defaults)

    config = format_config(config, wandb.config)
    model = VanillaVAE(config['meta']['img_size'],
                       hidden_dims=config['hidden_dims'],
                       latent_dim=config['latent_dim'])
    model.build_layers()

    experiment = VAEXperiment(model, config)

    early_stopping = EarlyStopping(
        # val_loss should decrease, so stop when it stops going down
        monitor='val_loss', min_delta=0.00, patience=3, verbose=False, mode='min'
    )

    runner = Trainer(
        weights_save_path=config['meta']['save_dir'], min_epochs=1,
        train_percent_check=1., val_percent_check=1., num_sanity_val_steps=5,
        early_stop_callback=early_stopping, **config['trainer_params']
    )

    runner.fit(experiment)
    log_to_wandb(config, runner, experiment, path_to_config)

#---------------------------------------------------------------------------------------------

path_to_yaml = 'sweep.yaml'
sweep_config = get_config(path_to_yaml)
sweep_id = wandb.sweep(sweep_config)
wandb.agent(sweep_id, function=main)

#---------------------------------------------------------------------------------------------
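The key change is that the sweep is driven by wandb.agent(sweep_id, function=main): as far as I understand, the agent then calls the training function in the current Python process instead of spawning a new interpreter through the sweep's command, so the /usr/bin/env python lookup never happens. A stripped-down sketch of that pattern, assuming a single LR parameter, illustrative defaults and the project name "vae"; the model and Trainer setup are elided.

```python
# Minimal sketch of a function-based W&B sweep.
# Assumptions: parameter name LR, its range, DEFAULTS and project="vae" are
# illustrative; building the VAE and the Trainer is omitted.

SWEEP_CONFIG = {
    "method": "random",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {"LR": {"min": 0.0001, "max": 0.1}},
}

DEFAULTS = {"LR": 1e-3, "epochs": 20}


def train():
    """One sweep trial, called by the agent inside the current Python process."""
    import wandb  # imported inside the function, as in the refactored main()

    run = wandb.init(config=DEFAULTS)  # values chosen by the sweep override DEFAULTS
    lr = run.config["LR"]
    print("training with LR =", lr)
    # ... build the VAE and the Trainer here and call fit() ...
    run.finish()


def launch(project="vae", count=None):
    """Register the sweep and run trials here (count=None: until the sweep is stopped)."""
    import wandb

    sweep_id = wandb.sweep(SWEEP_CONFIG, project=project)
    # function=train: the agent calls train() directly rather than running the
    # command from the sweep config (by default ${env} ${interpreter} ${program} ${args}).
    wandb.agent(sweep_id, function=train, count=count)
```

In a script, calling launch(count=5) under if __name__ == "__main__": would run five trials. Because no program or command is involved, the interpreter that runs the trials is the one you started the script with.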
