I am trying to get the sample Core ML models that Apple demoed at WWDC 2017 running properly. I am using GoogLeNet to try to classify images (see the Apple Machine Learning Page). The model takes a CVPixelBuffer as input. I have an image named imageSample.jpg that I'm using for this demo. My code is below:
var sample = UIImage(named: "imageSample")?.cgImage
let bufferThree = getCVPixelBuffer(sample!)

let model = GoogLeNetPlaces()
guard let output = try? model.prediction(input: GoogLeNetPlacesInput(sceneImage: bufferThree!)) else {
    fatalError("Unexpected runtime error.")
}

print(output.sceneLabel)
I always get the "Unexpected runtime error" in the output rather than an image classification. My code to convert the image is below:
func getCVPixelBuffer(_ image: CGImage) -> CVPixelBuffer? {
    let imageWidth = Int(image.width)
    let imageHeight = Int(image.height)

    // Attributes that keep the buffer compatible with CGImage / bitmap-context drawing.
    let attributes: [NSObject: AnyObject] = [
        kCVPixelBufferCGImageCompatibilityKey: true as AnyObject,
        kCVPixelBufferCGBitmapContextCompatibilityKey: true as AnyObject
    ]

    var pxbuffer: CVPixelBuffer? = nil
    CVPixelBufferCreate(kCFAllocatorDefault,
                        imageWidth,
                        imageHeight,
                        kCVPixelFormatType_32ARGB,
                        attributes as CFDictionary?,
                        &pxbuffer)

    if let _pxbuffer = pxbuffer {
        let flags = CVPixelBufferLockFlags(rawValue: 0)
        CVPixelBufferLockBaseAddress(_pxbuffer, flags)
        let pxdata = CVPixelBufferGetBaseAddress(_pxbuffer)

        // Wrap the buffer's backing memory in a bitmap context and draw the image into it.
        let rgbColorSpace = CGColorSpaceCreateDeviceRGB()
        let context = CGContext(data: pxdata,
                                width: imageWidth,
                                height: imageHeight,
                                bitsPerComponent: 8,
                                bytesPerRow: CVPixelBufferGetBytesPerRow(_pxbuffer),
                                space: rgbColorSpace,
                                bitmapInfo: CGImageAlphaInfo.premultipliedFirst.rawValue)

        if let _context = context {
            _context.draw(image, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
        } else {
            CVPixelBufferUnlockBaseAddress(_pxbuffer, flags)
            return nil
        }

        CVPixelBufferUnlockBaseAddress(_pxbuffer, flags)
        return _pxbuffer
    }

    return nil
}
I got this code from a previous StackOverflow post (the last answer here). I recognize the code may not be correct, but I have no idea how to do this myself, and I believe this is the section that contains the error. The model calls for the following type of input: Image<RGB, 224, 224>
Solution
You don't need to do a bunch of image mangling yourself to use a Core ML model with an image: the new Vision framework can do that for you.
import Vision
import CoreML

let model = try VNCoreMLModel(for: MyCoreMLGeneratedModelClass().model)
let request = VNCoreMLRequest(model: model, completionHandler: myResultsMethod)
let handler = VNImageRequestHandler(url: myImageURL)
try handler.perform([request])

func myResultsMethod(request: VNRequest, error: Error?) {
    guard let results = request.results as? [VNClassificationObservation]
        else { fatalError("huh") }
    for classification in results {
        print(classification.identifier, // the scene label
              classification.confidence)
    }
}
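For the question's specific model, the generic names above map over directly. A minimal sketch, assuming the GoogLeNetPlaces class that Xcode generates from the .mlmodel and a file URL for the sample image (classifyScene and sceneImageURL are illustrative names, not part of any API):

import Vision
import CoreML

// Sketch: the generic snippet above, wired to the GoogLeNetPlaces model
// from the question. sceneImageURL is assumed to point at imageSample.jpg.
func classifyScene(at sceneImageURL: URL) throws {
    let visionModel = try VNCoreMLModel(for: GoogLeNetPlaces().model)
    let request = VNCoreMLRequest(model: visionModel) { request, error in
        guard let results = request.results as? [VNClassificationObservation] else {
            print(error?.localizedDescription ?? "no classification results")
            return
        }
        // Vision returns observations sorted by confidence; take the top one.
        if let best = results.first {
            print(best.identifier, best.confidence)
        }
    }
    // Vision handles scaling and cropping the image to the model's 224 x 224 input.
    let handler = VNImageRequestHandler(url: sceneImageURL)
    try handler.perform([request])
}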
The WWDC17 session on Vision should have more info: it's tomorrow afternoon.
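If you do still want to build the CVPixelBuffer by hand, the constraint to satisfy is the declared input type Image<RGB, 224, 224>: the buffer has to be exactly 224 x 224, and a buffer sized to the source image (as in the question's helper) never is. Here is a minimal sketch of a conversion that bakes the resize in; makeModelInputBuffer is an illustrative name of my own, and the plain stretch ignores aspect ratio:

import CoreGraphics
import CoreVideo

// Sketch: render a CGImage into a fresh pixel buffer at the model's
// expected dimensions rather than the image's own dimensions.
func makeModelInputBuffer(from image: CGImage, side: Int = 224) -> CVPixelBuffer? {
    let attributes: [CFString: Any] = [
        kCVPixelBufferCGImageCompatibilityKey: true,
        kCVPixelBufferCGBitmapContextCompatibilityKey: true
    ]
    var buffer: CVPixelBuffer?
    CVPixelBufferCreate(kCFAllocatorDefault, side, side,
                        kCVPixelFormatType_32ARGB,
                        attributes as CFDictionary, &buffer)
    guard let pixelBuffer = buffer else { return nil }

    CVPixelBufferLockBaseAddress(pixelBuffer, [])
    defer { CVPixelBufferUnlockBaseAddress(pixelBuffer, []) }

    guard let context = CGContext(data: CVPixelBufferGetBaseAddress(pixelBuffer),
                                  width: side, height: side,
                                  bitsPerComponent: 8,
                                  bytesPerRow: CVPixelBufferGetBytesPerRow(pixelBuffer),
                                  space: CGColorSpaceCreateDeviceRGB(),
                                  bitmapInfo: CGImageAlphaInfo.premultipliedFirst.rawValue)
        else { return nil }

    // Stretch the source image to fill the whole side x side buffer.
    context.draw(image, in: CGRect(x: 0, y: 0, width: side, height: side))
    return pixelBuffer
}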