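Google AutoML Tables: how to set the target column in a dataset using the Java client library

Problem description

My requirement is to automate the entire cycle:

  1. Extract data from the database in CSV format.
  2. Upload it to a Google Cloud Storage bucket.
  3. Create a dataset in Google AutoML Tables.
  4. Import the uploaded CSV from the bucket into the dataset.
  5. Train a model based on the created dataset.
  6. Deploy the model after training.

I have completed steps 1 through 4 successfully, and step 6 as well, but I am stuck on step 5, where I need to define which column of my CSV file is the target column.

To solve this, I went through the AutoML documentation but could not find any way to do it through the AutoML Tables Java client library. The documented sample needs two important parameters, tableSpecId and columnSpecId. Here is the code from the documentation: https://cloud.google.com/automl-tables/docs/train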

import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.automl.v1beta1.AutoMlClient;
import com.google.cloud.automl.v1beta1.ColumnSpec;
import com.google.cloud.automl.v1beta1.ColumnSpecName;
import com.google.cloud.automl.v1beta1.LocationName;
import com.google.cloud.automl.v1beta1.Model;
import com.google.cloud.automl.v1beta1.OperationMetadata;
import com.google.cloud.automl.v1beta1.TablesModelMetadata;
import java.io.IOException;
import java.util.concurrent.ExecutionException;

class TablesCreateModel {

  public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String tableSpecId = "YOUR_TABLE_SPEC_ID";
    String columnSpecId = "YOUR_COLUMN_SPEC_ID";
    String displayName = "YOUR_DATASET_NAME";
    createModel(projectId, datasetId, tableSpecId, columnSpecId, displayName);
  }

  // Create a model
  static void createModel(
      String projectId,
      String datasetId,
      String tableSpecId,
      String columnSpecId,
      String displayName)
      throws IOException, ExecutionException, InterruptedException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {
      // A resource that represents Google Cloud Platform location.
      LocationName projectLocation = LocationName.of(projectId, "us-central1");

      // Get the complete path of the column. The path includes the dataset and the table spec.
      ColumnSpecName columnSpecName =
          ColumnSpecName.of(projectId, "us-central1", datasetId, tableSpecId, columnSpecId);

      // Build the target column spec from its full resource name.
      ColumnSpec targetColumnSpec =
          ColumnSpec.newBuilder().setName(columnSpecName.toString()).build();

      // Set model metadata.
      TablesModelMetadata metadata =
          TablesModelMetadata.newBuilder()
              .setTargetColumnSpec(targetColumnSpec)
              .setTrainBudgetMilliNodeHours(24000)
              .build();

      Model model =
          Model.newBuilder()
              .setDisplayName(displayName)
              .setDatasetId(datasetId)
              .setTablesModelMetadata(metadata)
              .build();

      // Create a model with the model metadata in the region.
      OperationFuture<Model, OperationMetadata> future =
          client.createModelAsync(projectLocation, model);
      // OperationFuture.get() will block until the model is created, which may take several hours.
      // You can use OperationFuture.getInitialFuture to get a future representing the initial
      // response to the request, which contains information while the operation is in progress.
      System.out.format(
          "Training operation name: %s%n", future.getInitialFuture().get().getName());
      System.out.println("Training started...");
    }
  }
}

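Of these two parameters, I can get tableSpecId after successfully importing my CSV into the dataset (step 4), but I cannot get columnSpecId. As I understand it, columnSpecId captures the correlation between columns and also determines which column is the target column.

After more research online, I found the following REST API request for setting the target column in a dataset: https://automl.clients6.google.com/v1beta1/projects/[projectId]/locations/[Location]/datasets/[DatasetId]?updateMask=tablesDatasetMetadata.targetColumnSpecId&key=[auth stuff] However, I am using Spring Boot, which means I am using the AutoML Java client library, and in that library I cannot find any method that explicitly sends a request telling Google AutoML Tables that a particular column is my target column. I suspect such a method is available in the Python client library but not in the Java one. Thanks in advance.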
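For illustration, the REST request above could be assembled from plain Java (11 or later) without the client library. This is only a sketch under several assumptions that the URL alone does not confirm: that the call is an HTTP PATCH, that the JSON body mirrors the field path named in updateMask, and that a bearer token is accepted in place of the key parameter. TargetColumnRequest and its parameters are hypothetical names, and the code only builds the request without sending it.

```java
import java.net.URI;
import java.net.http.HttpRequest;

/** Hypothetical helper: builds (but does not send) the dataset update request. */
final class TargetColumnRequest {

  private TargetColumnRequest() {}

  static HttpRequest build(
      String baseUrl,
      String projectId,
      String location,
      String datasetId,
      String columnSpecId,
      String accessToken) {
    // URL layout copied from the request quoted in the question.
    URI uri =
        URI.create(
            String.format(
                "%s/projects/%s/locations/%s/datasets/%s"
                    + "?updateMask=tablesDatasetMetadata.targetColumnSpecId",
                baseUrl, projectId, location, datasetId));

    // Assumed body shape: it mirrors the field path named in updateMask.
    String body =
        String.format(
            "{\"tablesDatasetMetadata\":{\"targetColumnSpecId\":\"%s\"}}", columnSpecId);

    // Assumptions: PATCH as the HTTP method and a bearer token for authentication.
    return HttpRequest.newBuilder(uri)
        .header("Authorization", "Bearer " + accessToken)
        .header("Content-Type", "application/json")
        .method("PATCH", HttpRequest.BodyPublishers.ofString(body))
        .build();
  }
}
```

The resulting request could be passed to java.net.http.HttpClient.send once the base URL and credentials are known to be valid for the project.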

Error (when passing an empty string as columnSpecId):

com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
java.util.concurrent.ExecutionException: com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at com.google.common.util.concurrent.AbstractFuture.getDoneValue(AbstractFuture.java:566)
    at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:547)
    at com.google.common.util.concurrent.FluentFuture$TrustedFuture.get(FluentFuture.java:86)
    at com.google.common.util.concurrent.ForwardingFuture.get(ForwardingFuture.java:62)
    at com.realcoderz.ai.AutoMlTables.TablesCreateModel.createModel(TablesCreateModel.java:76)
    at com.realcoderz.ai.AutoMlTables.TablesCreateModel.getTablesCreateModel(TablesCreateModel.java:29)
    at com.realcoderz.ai.controller.aimasterController.CreateModel(aimasterController.java:105)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:64)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:564)
    at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:197)
    at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:141)
    at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:106)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:894)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:808)
    at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
    at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:1060)
    at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:962)
    at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:1006)
    at org.springframework.web.servlet.FrameworkServlet.doGet(FrameworkServlet.java:898)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:626)
    at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:883)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:733)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:53)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:100)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.FormContentFilter.doFilterInternal(FormContentFilter.java:93)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:201)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:202)
    at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:97)
    at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:542)
    at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:143)
    at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:92)
    at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:78)
    at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:343)
    at org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:374)
    at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:65)
    at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:888)
    at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1597)
    at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
    at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
    at java.base/java.lang.Thread.run(Thread.java:832)
Caused by: com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:49)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:72)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:60)
    at com.google.api.gax.grpc.GrpcExceptionCallable$ExceptionTransformingFuture.onFailure(GrpcExceptionCallable.java:97)
    at com.google.api.core.ApiFutures$1.onFailure(ApiFutures.java:68)
    at com.google.common.util.concurrent.Futures$CallbackListener.run(Futures.java:1041)
    at com.google.common.util.concurrent.DirectExecutor.execute(DirectExecutor.java:30)
    at com.google.common.util.concurrent.AbstractFuture.executeListener(AbstractFuture.java:1215)
    at com.google.common.util.concurrent.AbstractFuture.complete(AbstractFuture.java:983)
    at com.google.common.util.concurrent.AbstractFuture.setException(AbstractFuture.java:771)
    at io.grpc.stub.ClientCalls$GrpcFuture.setException(ClientCalls.java:563)
    at io.grpc.stub.ClientCalls$UnaryStreamToFuture.onClose(ClientCalls.java:533)
    at io.grpc.internal.DelayedClientCall$DelayedListener$3.run(DelayedClientCall.java:464)
    at io.grpc.internal.DelayedClientCall$DelayedListener.delayOrExecute(DelayedClientCall.java:428)
    at io.grpc.internal.DelayedClientCall$DelayedListener.onClose(DelayedClientCall.java:461)
    at io.grpc.internal.ClientCallImpl.closeObserver(ClientCallImpl.java:617)
    at io.grpc.internal.ClientCallImpl.access$300(ClientCallImpl.java:70)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInternal(ClientCallImpl.java:803)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInContext(ClientCallImpl.java:782)
    at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
    at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:123)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
    at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
    at java.base/java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.run(ScheduledThreadPoolExecutor.java:304)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
    ... 1 more
Caused by: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at io.grpc.Status.asRuntimeException(Status.java:533)
    ... 16 more

Here are two images for comparison. In image 1, I sent the request through my Spring Boot application and left columnSpecId empty. In image 2, I set the target manually in the Google Cloud Console UI and trained the model.

Solution

You can get the columnSpecId by calling ListColumnSpecs on the TableSpec: https://googleapis.dev/java/google-cloud-clients/latest/com/google/cloud/automl/v1beta1/AutoMlClient.html#listColumnSpecs-com.google.cloud.automl.v1beta1.ListColumnSpecsRequest-

Like this:

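// List every column of the table spec and pick the one whose display name matches.
TableSpecName parent = TableSpecName.of("[PROJECT]", "[LOCATION]", "[DATASET]", "[TABLE_SPEC]");
ColumnSpec targetColumnSpec = null;
for (ColumnSpec element : autoMlClient.listColumnSpecs(parent).iterateAll()) {
  if (element.getDisplayName().equals("[MY_TARGET_COLUMN]")) {
    targetColumnSpec = element;
    break;
  }
}
// targetColumnSpec is still null here if no column has that display name.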
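The loop above yields a full ColumnSpec, whereas the createModel sample in the question expects only the trailing ID. As far as I know, ColumnSpec.getName() returns a resource name of the form projects/{project}/locations/{location}/datasets/{dataset}/tableSpecs/{tableSpec}/columnSpecs/{columnSpec}, so the ID should be the last path segment. A minimal sketch under that assumption follows; ColumnSpecIds and extractColumnSpecId are hypothetical helper names, not part of the client library.

```java
/** Hypothetical helper: not part of the AutoML client library. */
final class ColumnSpecIds {

  private ColumnSpecIds() {}

  /**
   * Returns the segment after the last "/columnSpecs/" in the given resource name.
   * Assumes the name ends with .../columnSpecs/{columnSpec}.
   */
  static String extractColumnSpecId(String resourceName) {
    String marker = "/columnSpecs/";
    int index = resourceName.lastIndexOf(marker);
    // Reject names without the marker or with nothing after it.
    if (index < 0 || index + marker.length() >= resourceName.length()) {
      throw new IllegalArgumentException("Not a column spec name: " + resourceName);
    }
    return resourceName.substring(index + marker.length());
  }
}
```

With the lookup above, the value to pass as columnSpecId would then be ColumnSpecIds.extractColumnSpecId(targetColumnSpec.getName()), after checking that targetColumnSpec is not null.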
