问题描述
我正在尝试将数据集拆分为 Julia 中的训练和测试子集。到目前为止,我已经尝试使用 MLDataUtils.jl 包进行此操作,但结果并不符合预期。 以下是我的发现和问题:
# the inputs are
a = DataFrame(A = [1,2,3,4,5,6,7,8,9,10],B = [1,2,3,4,5,6,7,8,9,10],C = [1,2,3,4,5,6,7,8,9,10]
)
b = [1,2,3,4,5,6,7,8,9,10]
using MLDataUtils
(x1,y1),(x2,y2) = stratifiedobs((a,b),p=0.7)
#Output of this operation is: (which is not the expectation)
println("x1 is: $x1")
x1 is:
10×3 DataFrame
│ Row │ A │ B │ C │
│ │ Int64 │ Int64 │ Int64 │
├─────┼───────┼───────┼───────┤
│ 1 │ 1 │ 1 │ 1 │
│ 2 │ 2 │ 2 │ 2 │
│ 3 │ 3 │ 3 │ 3 │
│ 4 │ 4 │ 4 │ 4 │
│ 5 │ 5 │ 5 │ 5 │
│ 6 │ 6 │ 6 │ 6 │
│ 7 │ 7 │ 7 │ 7 │
│ 8 │ 8 │ 8 │ 8 │
│ 9 │ 9 │ 9 │ 9 │
│ 10 │ 10 │ 10 │ 10 │
println("y1 is: $y1")
y1 is:
10-element Array{Int64,1}:
1
2
3
4
5
6
7
8
9
10
# but x2 is printed as
(0×3 SubDataFrame,Float64[])
# while y2 as
0-element view(::Array{Float64,1},Int64[]) with eltype Float64)
但是,我希望将此数据集分成 2 部分,其中 70% 的数据在训练中,30% 的数据在测试中。 请提出一种更好的方法来在 julia 中执行此操作。 提前致谢。
解决方法
可能 MLJ.jl 开发人员可以向您展示如何使用通用生态系统进行操作。这是仅使用 DataFrames.jl 的解决方案:
julia> using DataFrames,Random
julia> a = DataFrame(A = [1,2,3,4,5,6,7,8,9,10],B = [1,2,3,4,5,6,7,8,9,10],C = [1,2,3,4,5,6,7,8,9,10]
)
10×3 DataFrame
Row │ A B C
│ Int64 Int64 Int64
─────┼─────────────────────
1 │ 1 1 1
2 │ 2 2 2
3 │ 3 3 3
4 │ 4 4 4
5 │ 5 5 5
6 │ 6 6 6
7 │ 7 7 7
8 │ 8 8 8
9 │ 9 9 9
10 │ 10 10 10
julia> function splitdf(df, pct)
           # Split `df` by rows into two views: the rows picked for the `pct` share come
           # first, all remaining rows second.
           @assert 0 <= pct <= 1
           # Give every row a randomly permuted label taken from the row indices.
           order = shuffle!(collect(axes(df, 1)))
           # A row goes to the first view when its label is at most `nrow(df) * pct`.
           cutoff = nrow(df) * pct
           infirst = order .<= cutoff
           return (view(df, infirst, :), view(df, .!infirst, :))
       end
splitdf (generic function with 1 method)
julia> splitdf(a,0.7)
(7×3 SubDataFrame
Row │ A B C
│ Int64 Int64 Int64
─────┼─────────────────────
1 │ 3 3 3
2 │ 4 4 4
3 │ 6 6 6
4 │ 7 7 7
5 │ 8 8 8
6 │ 9 9 9
7 │ 10 10 10,3×3 SubDataFrame
Row │ A B C
│ Int64 Int64 Int64
─────┼─────────────────────
1 │ 1 1 1
2 │ 2 2 2
3 │ 5 5 5)
我使用视图来节省内存,但如果您愿意,也可以只具体化训练和测试数据帧。
这是我在 Beta Machine Learning Toolkit 中针对通用数组的实现方式:
"""
    partition(data, parts; shuffle=true)

Partition (by rows) one or more arrays according to the shares in `parts`.

# Arguments
* `data`: A matrix/vector, or a vector of matrices/vectors that all have the same number
  of rows
* `parts`: A vector of the required shares (must sum to 1)

# Keywords
* `shuffle`: Whether to randomly shuffle the rows before splitting. When several arrays
  are passed, the same row order is applied to each of them, so rows stay aligned
  across arrays.

# Returns
For a single array, a vector with one row subset per entry of `parts` (each subset is
obtained with `data[rows, :]`). For a vector of arrays, a vector holding one such result
per input array.

# Throws
* `DimensionMismatch` if the arrays in `data` do not all have the same number of rows
* `ArgumentError` if `parts` does not sum to (approximately) 1
"""
function partition(
    data::AbstractArray{T,1}, parts::AbstractArray{Float64,1}; shuffle=true
) where {T<:AbstractArray}
    n = size(data[1], 1)
    # Raise instead of only logging: continuing with misaligned arrays is never valid.
    if !all(d -> size(d, 1) == n, data)
        throw(
            DimensionMismatch(
                "All matrices passed to `partition` must have the same number of rows"
            ),
        )
    end
    # Draw the row order once so that every array is split along the same rows.
    ridx = shuffle ? Random.shuffle(1:n) : collect(1:n)
    return partition.(data, Ref(parts); shuffle=shuffle, fixedRIdx=ridx)
end

# Single-array method. `fixedRIdx` lets the multi-array method impose a row order that was
# already drawn; when it is empty the order is computed here according to `shuffle`.
function partition(
    data::AbstractArray{T,N},
    parts::AbstractArray{Float64,1};
    shuffle=true,
    fixedRIdx=Int64[],
) where {T,N}
    if !(sum(parts) ≈ 1)
        throw(ArgumentError("The sum of `parts` in `partition` should total to 1."))
    end
    n = size(data, 1)
    # `length`, not `size`: `size(parts)` is a tuple and would never equal the loop index.
    nParts = length(parts)
    ridx = fixedRIdx
    if isempty(ridx)
        ridx = shuffle ? Random.shuffle(1:n) : collect(1:n)
    end
    toReturn = []
    current = 1
    cumPart = 0.0
    for (i, p) in enumerate(parts)
        cumPart += p
        # The last part always runs to the final row so rounding cannot drop any rows.
        final = i == nParts ? n : round(Int, cumPart * n)
        push!(toReturn, data[ridx[current:final], :])
        current = final + 1
    end
    return toReturn
end
用于:
julia> x = [1:10 11:20]
julia> y = collect(31:40)
julia> ((xtrain,xtest),(ytrain,ytest)) = partition([x,y],[0.7,0.3])
或者,您也可以分成三个或更多部分,并且要分区的数组数量也是可变的。
默认情况下,它们也会被打乱,但您可以通过参数 shuffle=false 来避免打乱。
...
public class emergencyContact extends AppCompatActivity {
private static final int CONTACT_PICKER_REQUEST = 202;
private Button addContactsBtn;
private ArrayList<ContactResult> list;
public static final String SHARE_PREFS = "sheredPrefs";
RecyclerView recyclerView;
ContactAdapter adapter;
/**
 * Sets up the screen: inflates the layout, restores the saved contact list and wires
 * the "add contacts" button to open the multi-contact picker.
 */
@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_emergency_contact);
    loadData();

    // Opens the picker; the selection comes back through onActivityResult under
    // CONTACT_PICKER_REQUEST.
    View.OnClickListener openPicker = new View.OnClickListener() {
        @Override
        public void onClick(View v) {
            int azure = ContextCompat.getColor(emergencyContact.this, R.color.azureColorPrimary);
            MultiContactPicker.Builder picker =
                    new MultiContactPicker.Builder(emergencyContact.this); // Activity/fragment context
            picker.theme(R.style.MyCustomPickerTheme) // Optional - default: MultiContactPicker.Azure
                    .hideScrollbar(false) // Optional - default: false
                    .showTrack(true) // Optional - default: true
                    .searchIconColor(Color.WHITE) // Optional - default: White
                    .setChoiceMode(MultiContactPicker.CHOICE_MODE_MULTIPLE) // Optional - default: CHOICE_MODE_MULTIPLE
                    .handleColor(azure) // Optional - default: Azure Blue
                    .bubbleColor(azure) // Optional - default: Azure Blue
                    .bubbleTextColor(Color.WHITE) // Optional - default: White
                    .setTitleText("Select Contacts") // Optional - default: Select Contacts
                    .setLoadingType(MultiContactPicker.LOAD_ASYNC) // Optional - default LOAD_ASYNC (wait till all loaded vs stream results)
                    .limitToColumn(LimitColumn.NONE) // Optional - default NONE (include phone + email; limiting to one can improve loading time)
                    .setActivityAnimations(
                            android.R.anim.fade_in,
                            android.R.anim.fade_out,
                            android.R.anim.fade_in,
                            android.R.anim.fade_out) // Optional - default: no animation overrides
                    .showPickerForResult(CONTACT_PICKER_REQUEST);
        }
    };

    addContactsBtn = findViewById(R.id.add_contacts_btn);
    addContactsBtn.setOnClickListener(openPicker);
}
/**
 * Receives the contact picker's result: on success the selection replaces {@code list},
 * is shown in the RecyclerView and is persisted; a cancelled picker is only reported
 * on standard output.
 */
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
    super.onActivityResult(requestCode, resultCode, data);

    // Results of any other request are not handled here.
    if (requestCode != CONTACT_PICKER_REQUEST) {
        return;
    }

    if (resultCode == RESULT_CANCELED) {
        System.out.println("User closed the picker without selecting items.");
        return;
    }

    if (resultCode == RESULT_OK) {
        list = MultiContactPicker.obtainResult(data);
        buildRecycleView();
        saveData();
    }
}
/**
 * Serialises {@code list} to JSON and stores it in the SHARE_PREFS preferences under
 * the key "List Key".
 */
private void saveData() {
    String serialized = new Gson().toJson(list);

    SharedPreferences.Editor prefsEditor =
            getSharedPreferences(SHARE_PREFS, MODE_PRIVATE).edit();
    prefsEditor.putString("List Key", serialized);
    Log.e("MyTag", "size=" + list.size());
    // apply() is called last, after the value has been staged and logged.
    prefsEditor.apply();
}
/**
 * Restores {@code list} from the JSON stored in the SHARE_PREFS preferences under the
 * key "List Key".
 *
 * When nothing has been stored yet (or the stored JSON deserialises to null),
 * {@code list} is left as a non-null, empty list instead of being overwritten with null.
 */
private void loadData() {
    if (list == null) {
        list = new ArrayList<>();
        Log.e("loadData", "Size=" + list.size());
    }

    SharedPreferences pref = getSharedPreferences(SHARE_PREFS, MODE_PRIVATE);
    String jsonString = pref.getString("List Key", null);

    if (jsonString != null) {
        Type type = new TypeToken<ArrayList<ContactResult>>() {}.getType();
        ArrayList<ContactResult> stored = new Gson().fromJson(jsonString, type);
        // fromJson may still yield null (e.g. for the literal "null" written when an
        // unset list was saved), so only replace the current list with a real value.
        if (stored != null) {
            list = stored;
        }
    }

    Log.e("MTag", "Size=" + list.size());
}
/**
 * Binds {@code list} to the contacts RecyclerView using a vertical
 * LinearLayoutManager and a new ContactAdapter.
 */
private void buildRecycleView() {
    recyclerView = findViewById(R.id.contact_rv);
    recyclerView.setLayoutManager(new LinearLayoutManager(this));

    // Store the adapter in the field; a local variable of the same name would shadow
    // it and leave the field unassigned.
    adapter = new ContactAdapter(list, this);
    recyclerView.setAdapter(adapter);
    adapter.notifyDataSetChanged();
}
还有一个位置参数,在第二个位置需要一个百分比来分割。