在 Julia 中将数据集拆分为训练集和测试集

问题描述

我正在尝试在 Julia 中将数据集拆分为训练子集和测试子集。到目前为止,我尝试使用 MLDataUtils.jl 包来完成,但结果并不符合预期。以下是我的代码和遇到的问题:

代码

# the inputs are
using DataFrames

a = DataFrame(A = [1,2,3,4,5,6,7,8,9,10], B = [1,2,3,4,5,6,7,8,9,10],
              C = [1,2,3,4,5,6,7,8,9,10])
b = [1,2,3,4,5,6,7,8,9,10]

using MLDataUtils
# NOTE(review): `stratifiedobs` splits each distinct label in `b` separately. Every
# label here occurs exactly once, which is presumably why all rows land in (x1, y1)
# and (x2, y2) comes back empty. For a plain random 70/30 split use instead e.g.
# `(x1, y1), (x2, y2) = splitobs(shuffleobs((a, b)), at = 0.7)`.
(x1,y1),(x2,y2) = stratifiedobs((a,b),p=0.7)

#Output of this operation is: (which is not the expectation)
println("x1 is: $x1")
x1 is:
10×3 DataFrame
│ Row │ A     │ B     │ C     │
│     │ Int64 │ Int64 │ Int64 │
├─────┼───────┼───────┼───────┤
│ 1   │ 1     │ 1     │ 1     │
│ 2   │ 2     │ 2     │ 2     │
│ 3   │ 3     │ 3     │ 3     │
│ 4   │ 4     │ 4     │ 4     │
│ 5   │ 5     │ 5     │ 5     │
│ 6   │ 6     │ 6     │ 6     │
│ 7   │ 7     │ 7     │ 7     │
│ 8   │ 8     │ 8     │ 8     │
│ 9   │ 9     │ 9     │ 9     │
│ 10  │ 10    │ 10    │ 10    │

println("y1 is: $y1")
y1 is:
10-element Array{Int64,1}:
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10

# but x2 is printed as 
(0×3 SubDataFrame,Float64[]) 

# while y2 as 
0-element view(::Array{Float64,1},Int64[]) with eltype Float64)

但是,我希望将此数据集分成两部分:70% 的数据作为训练集,30% 的数据作为测试集。请问在 Julia 中有没有更好的方法来实现?提前致谢。

解决方法

也许 MLJ.jl 的开发者可以展示如何借助通用的机器学习生态来完成这项工作。下面是一个仅使用 DataFrames.jl 的解决方案:

julia> using DataFrames, Random

julia> a = DataFrame(A = [1,2,3,4,5,6,7,8,9,10], B = [1,2,3,4,5,6,7,8,9,10],
                     C = [1,2,3,4,5,6,7,8,9,10])
10×3 DataFrame
 Row │ A      B      C     
     │ Int64  Int64  Int64 
─────┼─────────────────────
   1 │     1      1      1
   2 │     2      2      2
   3 │     3      3      3
   4 │     4      4      4
   5 │     5      5      5
   6 │     6      6      6
   7 │     7      7      7
   8 │     8      8      8
   9 │     9      9      9
  10 │    10     10     10

julia> function splitdf(df, pct; rng = Random.default_rng())
           # Randomly split the rows of `df` (a DataFrame or a matrix) into two views:
           # the first holds round(Int, nrows * pct) rows, the second the remaining
           # rows. Rows keep their original order inside each part.
           0 <= pct <= 1 || throw(ArgumentError("pct must be between 0 and 1, got $pct"))
           n = size(df, 1)
           # Round to an integer first: comparing against the raw float n * pct loses a
           # row when it lands just below a whole number (100 * 0.29 == 28.999999999999996).
           ntrain = round(Int, n * pct)
           # Entries of a random permutation that are <= ntrain mark exactly ntrain
           # randomly chosen rows.
           sel = randperm(rng, n) .<= ntrain
           return view(df, sel, :), view(df, .!sel, :)
       end
splitdf (generic function with 1 method)

julia> splitdf(a,0.7)  # returns a tuple of two views: the ~70% part first, the remaining ~30% second
(7×3 SubDataFrame
 Row │ A      B      C     
     │ Int64  Int64  Int64 
─────┼─────────────────────
   1 │     3      3      3
   2 │     4      4      4
   3 │     6      6      6
   4 │     7      7      7
   5 │     8      8      8
   6 │     9      9      9
   7 │    10     10     10,3×3 SubDataFrame
 Row │ A      B      C     
     │ Int64  Int64  Int64 
─────┼─────────────────────
   1 │     1      1      1
   2 │     2      2      2
   3 │     5      5      5)

我使用视图(view)来节省内存;如果需要,也可以把训练集和测试集具体化为独立的 DataFrame(例如对返回的 SubDataFrame 调用 `DataFrame(...)`)。

,

下面是我在 Beta Machine Learning Toolkit(BetaML.jl)中针对通用数组的实现方式:

"""
    partition(data,parts;shuffle=true)
Partition (by rows) one or more matrices according to the shares in `parts`.
# Parameters
* `data`: A matrix/vector or a vector of matrices/vectors
* `parts`: A vector of the required shares (must sum to 1)
* `shufle`: Wheter to randomly shuffle the matrices (preserving the relative order between matrices)
 """
function partition(data::AbstractArray{T,1},parts::AbstractArray{Float64,1};shuffle=true) where T <: AbstractArray
        n = size(data[1],1)
        if !all(size.(data,1) .== n)
            @error "All matrices passed to `partition` must have the same number of rows"
        end
        ridx = shuffle ? Random.shuffle(1:n) : collect(1:n)
        return partition.(data,Ref(parts);shuffle=shuffle,fixedRIdx = ridx)
end

# Partition a single array along its first dimension (rows) according to `parts`.
# `fixedRIdx`, when non-empty, is used as the row order instead of drawing a new one;
# the vector-of-arrays method passes it so all arrays share one permutation.
# Throws `ArgumentError` if `parts` does not sum to 1 or has negative shares, and
# `DimensionMismatch` if `fixedRIdx` does not have one entry per row.
function partition(data::AbstractArray{T,N} where N, parts::AbstractArray{Float64,1};
                   shuffle=true, fixedRIdx=Int64[]) where T
    n = size(data, 1)
    # `length`, not `size`: `size(parts)` is a tuple and never equals a loop index.
    nParts = length(parts)
    sum(parts) ≈ 1 || throw(ArgumentError("The sum of `parts` in `partition` should total to 1."))
    all(>=(0), parts) || throw(ArgumentError("`parts` in `partition` must be non-negative."))
    ridx = if isempty(fixedRIdx)
        shuffle ? Random.shuffle(1:n) : collect(1:n)
    else
        fixedRIdx
    end
    length(ridx) == n || throw(DimensionMismatch(
        "`fixedRIdx` must have one entry per row of `data` ($n), got $(length(ridx))"))
    # Last row of each part = rounded cumulative share times n; the final part always
    # ends at n so rounding can never drop trailing rows.
    stops = [i == nParts ? n : round(Int, c * n) for (i, c) in enumerate(cumsum(parts))]
    starts = [1; stops[1:end-1] .+ 1]
    # Index only the first dimension and keep the others whole, so vectors stay vectors
    # and N-d arrays keep their shape (`data[idx, :]` made vectors n×1 matrices).
    trailing = ntuple(_ -> Colon(), ndims(data) - 1)
    return [data[ridx[s:e], trailing...] for (s, e) in zip(starts, stops)]
end

用法示例:

julia> x = hcat(1:10, 11:20)   # 10×2 matrix
julia> y = Vector(31:40)       # 10-element vector
julia> (xtrain, xtest), (ytrain, ytest) = partition([x, y], [0.7, 0.3])   # x and y share one row shuffle

此外,也可以分成三个或更多部分,而且要划分的数组个数同样可以任意。

默认情况下,数据会被随机打乱,但可以通过关键字参数 `shuffle=false` 关闭打乱。

,

public class emergencyContact extends AppCompatActivity { private static final int CONTACT_PICKER_REQUEST = 202; private Button addContactsBtn; private ArrayList<ContactResult> list; public static final String SHARE_PREFS = "sheredPrefs"; RecyclerView recyclerView; ContactAdapter adapter; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_emergency_contact); loadData(); addContactsBtn = findViewById(R.id.add_contacts_btn); addContactsBtn.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { new MultiContactPicker.Builder(emergencyContact.this) //Activity/fragment context .theme(R.style.MyCustomPickerTheme) //Optional - default: MultiContactPicker.Azure .hideScrollbar(false) //Optional - default: false .showTrack(true) //Optional - default: true .searchIconColor(Color.WHITE) //Option - default: White .setChoiceMode(MultiContactPicker.CHOICE_MODE_MULTIPLE) //Optional - default: CHOICE_MODE_MULTIPLE .handleColor(ContextCompat.getColor(emergencyContact.this,R.color.azureColorPrimary)) //Optional - default: Azure Blue .bubbleColor(ContextCompat.getColor(emergencyContact.this,R.color.azureColorPrimary)) //Optional - default: Azure Blue .bubbleTextColor(Color.WHITE) //Optional - default: White .setTitleText("Select Contacts") //Optional - default: Select Contacts .setLoadingType(MultiContactPicker.LOAD_ASYNC) //Optional - default LOAD_ASYNC (wait till all loaded vs stream results) .limitToColumn(LimitColumn.NONE) //Optional - default NONE (Include phone + email,limiting to one can improve loading time) .setActivityAnimations(android.R.anim.fade_in,android.R.anim.fade_out,android.R.anim.fade_in,android.R.anim.fade_out) //Optional - default: No animation overrides .showPickerForResult(CONTACT_PICKER_REQUEST); } }); } @Override protected void onActivityResult(int requestCode,int resultCode,Intent data) { super.onActivityResult(requestCode,resultCode,data); 
if(requestCode == CONTACT_PICKER_REQUEST){ if(resultCode == RESULT_OK) { list = MultiContactPicker.obtainResult(data); buildRecycleView(); saveData(); } else if(resultCode == RESULT_CANCELED){ System.out.println("User closed the picker without selecting items."); } } } private void saveData() { SharedPreferences pref = getSharedPreferences(SHARE_PREFS,MODE_PRIVATE); SharedPreferences.Editor editor = pref.edit(); Gson gson = new Gson(); String jsonString = gson.toJson(list); editor.putString("List Key",jsonString); Log.e("MyTag","size="+ list.size()); editor.apply(); } private void loadData(){ if(list == null){ list = new ArrayList<>(); Log.e("loadData","Size="+list.size()); } SharedPreferences pref = getSharedPreferences(SHARE_PREFS,MODE_PRIVATE); Gson gson = new Gson(); String jsonString = pref.getString("List Key",null); Type type = new TypeToken<ArrayList<ContactResult>>(){}.getType(); list = gson.fromJson(jsonString,type); Log.e("MTag","Size="+list.size()); } private void buildRecycleView(){ recyclerView = findViewById(R.id.contact_rv); LinearLayoutManager LayoutManager= new LinearLayoutManager(this); ContactAdapter adapter = new ContactAdapter(list,this); recyclerView.setLayoutManager(LayoutManager); recyclerView.setAdapter(adapter); adapter.notifyDataSetChanged(); } 还有一个位置参数,在第二个位置需要一个百分比来分割。