改进并验证机器学习代码在 Ray

问题描述

我有一段代码,我的同行认为它不会跨多个射线节点并行执行。下面是粘贴的代码,它从数据库获取数据,数据类型为 MLDataset。 MLDataset 是 ray.util.data.Dataset.py 中的一个类。获取它后,我使用 ray.util.iter 中的异步方法进行迭代,然后将其拆分为训练和测试数据。使用 TensorFlow API 转换为张量切片,然后将其输入 Ray 的 TFTrainer 类。Tf Trainer 类仅接受张量数据集。所以这里的需求是改进第一行之后的代码,验证跨多个节点的并行性。我可以分享整个代码,任何帮助表示赞赏。

import tensorflow as tf
import ray
from ray.util.sgd.tf.tf_trainer import TFTrainer,TFTrainable
from sklearn.model_selection import train_test_split

 def fetch_values_from_database():
     custom_dataset = <Custom Method of return type MLDataset.from_parallel_it >
      resultList = []
      for df in custom_dataset .gather_async():
            for value in df.values:
                resultList += [[va for va in value]]
      resultColumn = [value[-1] for value in resultList]
      trainColumns = [value[3:-1] for value in resultList]
      X_train,X_test,y_train,y_test = train_test_split(trainColumns,resultColumn,test_size=0.20,shuffle=True)
      train_dataset = tf.data.Dataset.from_tensor_slices((X_train,y_train)).batch(32)
      test_dataset = tf.data.Dataset.from_tensor_slices((X_test,y_test)).batch(32)
      return train_dataset,test_dataset


trainer = TFTrainer(
        model_creator=<invoke a method which returns TF model>,data_creator=fetch_values_from_database,verbose=True
        )

我基于此链接开发了一个完整的工作示例 https://docs.ray.io/en/master/raysgd/raysgd_tensorflow.html

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)