TensorFlow Lite 解释器在 iOS 上抛出“数据字节数不匹配”错误

问题描述

我正在使用 TensorFlowLiteSwift，所用的模型负责把按梯形（任意四边形）区域裁剪出的图像矫正、展平为矩形图像。TensorFlow Lite 在这方面的文档很少，因此我一直参照官方示例项目来实现。

但问题是，运行时抛出错误，提示“提供的数据字节数必须与所需的字节数一致”，而所需的字节数是 4。我在 Interpreter.swift 中追溯了 byteCount 的来源，却找不到实际设置它的地方（setter）。

那么，这个“所需字节数”是由 .tflite 模型决定的吗？如果不是，又该在哪里设置它？

以下是我认为有助于理解该问题的代码片段：

/// Performs image preprocessing, invokes the `Interpreter`, and processes the inference results.
///
/// `Interpreter.copy(_:toInputAt:)` only accepts a buffer whose length equals the
/// input tensor's byte size. Each input below is therefore serialized with an
/// explicit byte layout instead of copying raw Swift values.
///
/// - Parameter item: The image to process, together with its original size and corner points.
/// - Returns: The image decoded from output tensor 0. Returns `nil` if preprocessing
///   fails, if an interpreter call throws (the error is dumped to the console), or if
///   `postprocessImageData(data:size:)` returns `nil`.
    func runModel(on item: ImageProcessInfo) -> UIImage? {
        let inputImageSize = CGSize(width: 1000, height: 900)
        let outputImageSize = CGSize(width: 1000, height: 900)

        // `scaledData(with:byteCount:isQuantized:)` writes one value per R, G and B
        // component (alpha is dropped), so the buffer must hold 3 values per pixel.
        let rgbChannelCount = 3
        guard let rgbData = item.resizedImage.scaledData(
            with: inputImageSize,
            byteCount: inputWidth * inputHeight * rgbChannelCount * batchSize,
            isQuantized: false
        ) else {
            return nil
        }

        // NOTE(review): the types of `originalHeight` and `originalWidth` are not visible here.
        // As before, only the first 4 bytes of each value's in-memory representation are
        // sent. That is meaningful only if it matches the model's scalar input dtype
        // (e.g. a Float32 value for a float32 tensor). Confirm against the model.
        let heightData = withUnsafeBytes(of: item.originalHeight) { Data($0.prefix(4)) }
        let widthData = withUnsafeBytes(of: item.originalWidth) { Data($0.prefix(4)) }

        // Appending `&corner` with length 4 copied 4 bytes of the Swift container that
        // holds the points, not the coordinates. Send every coordinate instead, as
        // contiguous Float32 values x0, y0, x1, y1, ...
        let cornerPairs = item.corners.map { $0.map { point -> (Float, Float) in
            (Float(point.x), Float(point.y))
        } }
        let cornersData = float32Data(fromCoordinatePairs: cornerPairs)

        do {
            try interpreter.copy(rgbData, toInputAt: 0)
            try interpreter.copy(heightData, toInputAt: 1)
            try interpreter.copy(widthData, toInputAt: 2)
            try interpreter.copy(cornersData, toInputAt: 3)
            try interpreter.invoke()

            let outputTensor = try interpreter.output(at: 0)
            guard let cgImage = postprocessImageData(data: outputTensor.data, size: outputImageSize) else {
                return nil
            }
            return UIImage(cgImage: cgImage)
        } catch {
            // `copy(_:toInputAt:)` throws `invalidTensorDataCount` when a buffer's length
            // differs from the input tensor's byte size. That size comes from the model's
            // input tensor shape and type (inspect them via `interpreter.input(at:)`).
            dump(error)
            return nil
        }
    }

    /// Serializes `(x, y)` pairs into a contiguous, native-endian Float32 buffer laid
    /// out as x0, y0, x1, y1, ... A `nil` list yields an empty buffer.
    // NOTE(review): the declaration of `item.corners` is not visible here. This overload
    // and the next one cover the two shapes `.map { $0.map ... }` can produce
    // (an optional flat list, or a list of lists).
    private func float32Data(fromCoordinatePairs pairs: [(Float, Float)]?) -> Data {
        let values = (pairs ?? []).flatMap { [$0.0, $0.1] }
        return values.withUnsafeBufferPointer { Data(buffer: $0) }
    }

    /// Serializes a list of `(x, y)` pair lists into one contiguous, native-endian
    /// Float32 buffer, in row-major order.
    private func float32Data(fromCoordinatePairs pairLists: [[(Float, Float)]]) -> Data {
        let values = pairLists.flatMap { list in list.flatMap { [$0.0, $0.1] } }
        return values.withUnsafeBufferPointer { Data(buffer: $0) }
    }

extension UIImage {
    /// Scales the image to `size` and returns its RGB components, dropping alpha.
    ///
    /// - Parameters:
    ///   - size: Size the image is redrawn at before its components are read.
    ///   - byteCount: Number of RGB component values in the returned buffer. This is
    ///     only a byte count when `isQuantized` is `true`; otherwise each value takes
    ///     4 bytes. It must be at least `width * height * 3` for `size`. Any extra
    ///     trailing values are left as zero.
    ///   - isQuantized: `true` returns the raw `UInt8` components. `false` returns
    ///     `Float32` values computed as `(v - 127.5) / 127.5`.
    /// - Returns: The component buffer, or `nil` in any of these cases: the image has no
    ///   `CGImage` backing, it cannot be redrawn, or `byteCount` is negative or too small
    ///   to hold every component.
    func scaledData(with size: CGSize, byteCount: Int, isQuantized: Bool) -> Data? {
        guard byteCount >= 0,
              let cgImage = self.cgImage, cgImage.width > 0, cgImage.height > 0 else { return nil }
        guard let imageData = imageData(from: cgImage, with: size) else { return nil }

        var scaledBytes = [UInt8](repeating: 0, count: byteCount)
        var index = 0
        // `imageData` is RGBA with one byte per component, so every 4th byte is alpha.
        for (offset, component) in imageData.enumerated() where offset % 4 != 3 {
            // A buffer that is too small used to trap on the subscript below; fail softly instead.
            guard index < scaledBytes.count else { return nil }
            scaledBytes[index] = component
            index += 1
        }

        if isQuantized { return Data(scaledBytes) }
        let scaledFloats = scaledBytes.map { (Float32($0) - 127.5) / 127.5 }
        return Data(copyingBufferOf: scaledFloats)
    }

    /// Redraws `cgImage` at `size` into an 8-bit-per-component RGBA bitmap and returns its bytes.
    ///
    /// The context always uses 8 bits per component and a row stride of exactly
    /// `width * 4` bytes, independent of the source image's format. Grayscale, 16-bit
    /// or row-padded sources therefore produce the same 4-bytes-per-pixel layout that
    /// `scaledData` expects. Deriving the stride from the source image broke those cases.
    private func imageData(from cgImage: CGImage, with size: CGSize) -> Data? {
        let width = Int(size.width)
        let height = Int(size.height)
        guard width > 0, height > 0 else { return nil }

        let bytesPerPixel = 4
        let bitmapInfo = CGBitmapInfo(
            rawValue: CGBitmapInfo.byteOrder32Big.rawValue | CGImageAlphaInfo.premultipliedLast.rawValue
        )
        guard let context = CGContext(
            data: nil,
            width: width,
            height: height,
            bitsPerComponent: 8,
            bytesPerRow: width * bytesPerPixel,
            space: CGColorSpaceCreateDeviceRGB(),
            bitmapInfo: bitmapInfo.rawValue
        ) else {
            return nil
        }
        context.draw(cgImage, in: CGRect(origin: .zero, size: size))
        return context.makeImage()?.dataProvider?.data as Data?
    }
}

/// Copies the given data to the input `Tensor` at the given index.
///
/// This method never sets the required length. It is whatever `TfLiteTensorByteSize`
/// reports for the input tensor, which per the TensorFlow Lite C API is the byte size
/// of the tensor's data. That size follows from the tensor's shape and element type,
/// as loaded from the model or as changed by resizing the input and re-allocating tensors.
///
/// - Parameters:
///   - data: The raw bytes to copy. Their count must equal the tensor's byte size.
///   - index: The index of the input tensor.
/// - Throws: `InterpreterError.invalidTensorIndex` if `index` is out of range.
///   `InterpreterError.allocateTensorsRequired` if the input tensor is unavailable.
///   `InterpreterError.invalidTensorDataCount` if `data.count` differs from the tensor's
///   byte size. `InterpreterError.failedToCopyDataToInputTensor` if the copy fails.
/// - Returns: The input tensor after the copy.
@discardableResult
  public func copy(_ data: Data, toInputAt index: Int) throws -> Tensor {
    let maxIndex = inputTensorCount - 1
    // Use a half-open range: `0...maxIndex` traps at runtime (instead of throwing)
    // when the interpreter has no inputs and `maxIndex` is -1.
    guard (0..<inputTensorCount).contains(index) else {
      throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
    }
    guard let cTensor = TfLiteInterpreterGetInputTensor(cInterpreter, Int32(index)) else {
      throw InterpreterError.allocateTensorsRequired
    }

    // This is where the "required" count in `invalidTensorDataCount` comes from.
    let byteCount = TfLiteTensorByteSize(cTensor)
    guard data.count == byteCount else {
      throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
    }

    #if swift(>=5.0)
      let status = data.withUnsafeBytes {
        TfLiteTensorCopyFromBuffer(cTensor, $0.baseAddress, data.count)
      }
    #else
      let status = data.withUnsafeBytes { TfLiteTensorCopyFromBuffer(cTensor, $0, data.count) }
    #endif  // swift(>=5.0)
    guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
    return try input(at: index)
  }

解决方法

暂未找到可以解决该问题的有效方法，小编正在努力寻找整理中！

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)