Transformers 脚本可以正常运行,但在 PyCharm 调试器中中断

问题描述

我在调试模式下使用以下脚本来更好地了解 Transformers 的 model.generate() 函数的内部工作原理。它是我为客户构建的 API 的一部分,因此请忽略 Flask 代码——此处的关键问题是让调试器正常工作,以便我可以在模型的生成过程中跟踪标记化文本。Transformers 库会在调试时中断吗?为什么会这样呢?:

import os
import shutil
import subprocess

import numpy as np
import torch
from flask import Flask,request
from flask_restful import Api,Resource,reqparse

import transformers
from transformers import CONfig_NAME,WEIGHTS_NAME,GPT2LMHeadModel,GPT2Tokenizer


'''
Model Imports
'''


# Flask application plus the Flask-RESTful wrapper that hosts the endpoint below.
app = Flask(__name__)
api = Api(app)


def get_model(model_dir):
    """Populate *model_dir* with model weights synced from S3 via the AWS CLI.

    Creates the directory (including any missing parents) if needed.  The
    sync itself is best-effort: a missing or unconfigured AWS CLI is
    reported on stdout but does not raise.
    """
    if not os.path.exists(model_dir):
        print(f'Building model directory at {model_dir}')
        # makedirs creates intermediate directories too; plain mkdir would
        # fail for a nested path such as ./models/gpt2.
        os.makedirs(model_dir)
    try:
        # NOTE(review): 'AWS_BUCKET' looks like a placeholder for a real
        # s3:// URI -- confirm before deploying.
        command = f'aws s3 sync AWS_BUCKET {model_dir}'
        subprocess.call(command.split())
    except (OSError, subprocess.SubprocessError):
        # Narrow catch: only failures to launch the CLI (e.g. binary not
        # found), instead of the original bare except that swallowed
        # everything including KeyboardInterrupt.
        print('AWS commandline call Failed. Have you configured the AWS cli yet?')

# Local cache directory for the model; populated from S3 on first run only.
MODEL_DIR = "./model"
if not os.path.exists(MODEL_DIR):
    get_model(MODEL_DIR)

# Regex: whitespace, one or more digits, then optional letters (e.g. ' 12mg').
NUM_PATTERN = r'\s\d+[A-Za-z]*'

# NOTE(review): transformers exports CONFIG_NAME -- 'CONfig_NAME' is almost
# certainly a transcription error in this listing; confirm against the
# import at the top of the file.
# These two paths are built but never used below: from_pretrained() locates
# the weight/config files inside MODEL_DIR on its own.
output_model_file = os.path.join(MODEL_DIR,WEIGHTS_NAME)
output_config_file = os.path.join(MODEL_DIR,CONfig_NAME)

# Re-load the saved model and vocabulary
print('Loading model!')
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)

'''
Arg Parser
'''
# Every numeric parameter declares an explicit ``type`` so values arriving as
# query-string/form *strings* are coerced before use.  Without this,
# ``num_return_sequences`` (used as a list-multiplication count) and the
# numeric kwargs passed to ``model.generate`` would be str and fail.
parser = reqparse.RequestParser()
parser.add_argument('prompt', type=str, help='Main input string to be transformed. required.', required=True)
parser.add_argument('max_length', type=int, help='Max length for generation.', default=20)
parser.add_argument('repetition_penalty', type=float, help='Penalty for word repetition. Higher = fewer repetitions.', default=5)
parser.add_argument('length_penalty', type=float, help='Exponential penalty for length. Higher = shorter sentences.', default=1)
parser.add_argument('num_beams', type=int, help='# Beams to use for beam search.', default=5)
parser.add_argument('temperature', type=float, help='Temperature of the softmax operation used in generation.', default=3)
parser.add_argument('top_k', type=int, help='Top words to select from during text generation.', default=50)
parser.add_argument('top_p', type=float, help='Top-P for Nucleus Sampling. Lower = more restrictive search.', default=0.8)
parser.add_argument('num_return_sequences', type=int, help='Number of sequences to generate.', default=1)

def decode(output):
    """Turn a sequence of token ids back into plain text.

    Special tokens (BOS/EOS/pad) are dropped from the decoded string.
    """
    text = tokenizer.decode(output, skip_special_tokens=True)
    return str(text)

class TransformerAPI(Resource):
    """GET endpoint that completes a text prompt with the loaded GPT-2 model."""

    def get(self):
        """Generate text for ``prompt`` and return it as JSON.

        Returns a dict with ``completion`` (list of decoded strings, one per
        requested sequence) and ``model_used`` (the model directory).
        """
        args = parser.parse_args()
        app.logger.info(f'Using model loaded from {MODEL_DIR}.')

        # Defensive coercion: if the parser registered these args without an
        # explicit ``type=`` they arrive as strings; int()/float() on an
        # already-numeric value is a harmless no-op.
        max_length = int(args['max_length'])
        num_return_sequences = int(args['num_return_sequences'])

        ids = tokenizer.encode(args['prompt'])
        # Add a leading batch dimension -> shape (1, seq_len), as expected
        # by model.generate().
        inp = torch.tensor(np.array(ids)[np.newaxis, :])

        # Account for generation limits < input length: the prompt already
        # meets/exceeds max_length, so truncate it instead of asking
        # generate() for an impossible output length.
        if inp.shape[1] >= max_length:
            # Replaced leftover debug print()s with proper debug logging.
            app.logger.debug(
                f'Prompt length {inp.shape[1]} >= max_length {max_length}; truncating.')
            truncated = inp[:, :max_length]
            decoded = [decode(truncated.tolist()[0])] * num_return_sequences
            return {'completion': decoded, 'model_used': MODEL_DIR}

        result = model.generate(
            input_ids=inp,
            max_length=max_length,
            repetition_penalty=float(args['repetition_penalty']),
            length_penalty=float(args['length_penalty']),
            do_sample=True,
            num_beams=int(args['num_beams']),
            temperature=float(args['temperature']),
            top_k=int(args['top_k']),
            top_p=float(args['top_p']),
            num_return_sequences=num_return_sequences,
        )
        decoded = [decode(seq.tolist()) for seq in result]
        return {'completion': decoded, 'model_used': MODEL_DIR}

# Expose the resource at a versioned path.
api.add_resource(TransformerAPI,'/api/v1')

if __name__ == '__main__':
    #app.run(debug=True)
    # Ad-hoc debug harness: run a single generation directly (bypassing
    # Flask) so generate()'s internals can be stepped through in a debugger.
    ids = tokenizer.encode('The present invention')
    inp = torch.tensor(np.array(ids)[np.newaxis,:])
    result = model.generate(input_ids=inp,max_length=15,repetition_penalty=5,length_penalty=1,num_beams=5,temperature=3,num_return_sequences=1)
    print(result)

python app.py执行得很好,但是在调试模式下(在PyCharm上)运行相同会遇到错误

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py",line 1438,in _exec
    pydev_imports.execfile(file,globals,locals)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py",line 18,in execfile
    exec(compile(contents+"\n",file,'exec'),glob,loc)
  File "/Users/mgb/Desktop/Work/Apteryx_Clients_2/bao/bao-ai/apteryx_apis/patformer/app.py",line 45,in <module>
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
  File "/Users/mgb/opt/anaconda3/envs/transformers/lib/python3.8/site-packages/transformers/tokenization_utils.py",line 282,in from_pretrained
    return cls._from_pretrained(*inputs,**kwargs)
  File "/Users/mgb/opt/anaconda3/envs/transformers/lib/python3.8/site-packages/transformers/tokenization_utils.py",line 411,in _from_pretrained
    tokenizer = cls(*init_inputs,**init_kwargs)
  File "/Users/mgb/opt/anaconda3/envs/transformers/lib/python3.8/site-packages/transformers/tokenization_gpt2.py", line 118, in __init__
    super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
  File "/Users/mgb/opt/anaconda3/envs/transformers/lib/python3.8/site-packages/transformers/tokenization_utils.py", line 232, in __init__
    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
AssertionError

这似乎与GPT2Tokenizer对象有关。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...