hyeye archive
[Python] sklearn의 train_test_split() 사용법 본문
1. 개요
머신러닝 모델을 학습하고 그 결과를 검증하기 위해서는 원래의 데이터를 Training, Validation, Testing의 용도로 나누어 다뤄야 한다. 그렇지 않고 Training에 사용한 데이터를 검증용으로 사용하면 시험문제를 알고 있는 상태에서 공부를 하고 그 지식을 바탕으로 시험을 치루는 꼴이므로 제대로 된 검증이 이루어지지 않기 때문이다.
딥러닝을 제외하고도 다양한 기계학습과 데이터 분석 툴을 제공하는 scikit-learn 패키지 중 model_selection에는 데이터 분할을 위한 train_test_split 함수가 들어있다.
2. Parameter & Return
from sklearn.model_selection import train_test_split
train_test_split(arrays, test_size, train_size, random_state, shuffle, stratify)
(1) Parameter
arrays : 분할시킬 데이터를 입력 (Python list, Numpy array, Pandas dataframe 등..)
test_size : 테스트 데이터셋의 비율(float)이나 갯수(int) (default = 0.25)
train_size : 학습 데이터셋의 비율(float)이나 갯수(int) (default = test_size의 나머지)
random_state : 호출할 때마다 동일한 학습/테스트용 데이터 세트를 생성하기 위해 주어지는 난수 값. train_test_split()는 랜덤으로 데이터를 분리하므로 random_state를 설정하지 않으면 수행할 때마다 다른 학습/테스트 데이터 세트가 생성된다. 따라서 random_state를 설정하여 수행 시 결과값을 동일하게 맞춰주는 것이다. 이 때 radom_state에는 어떤 숫자를 적든 그 기능은 같기 때문에 어떤 숫자를 적든 상관없다.
shuffle : 셔플여부설정 (default = True)
stratify : 지정한 Data의 비율을 유지한다. 예를 들어, Label Set인 Y가 25%의 0과 75%의 1로 이루어진 Binary Set일 때, stratify=Y로 설정하면 나누어진 데이터셋들도 0과 1을 각각 25%, 75%로 유지한 채 분할된다.
(2) Return
X_train, X_test, Y_train, Y_test : arrays에 데이터와 레이블을 둘 다 넣었을 경우의 반환이며, 데이터와 레이블의 순서쌍은 유지된다.
X_train, X_test : arrays에 레이블 없이 데이터만 넣었을 경우의 반환
3. Example
import numpy as np
from sklearn.model_selection import train_test_split
X = [[0,1],[2,3],[4,5],[6,7],[8,9]]
Y = [0,1,2,3,4]
# 데이터(X)만 넣었을 경우
X_train, X_test = train_test_split(X, test_size=0.2, random_state=123)
# X_train : [[0,1],[6,7],[8,9],[2,3]]
# X_test : [[4,5]]
# 데이터(X)와 레이블(Y)을 넣었을 경우
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.33, random_state=321)
# X_train : [[4,5],[0,1],[6,7]]
# Y_train : [2,0,3]
# X_test : [[2,3],[8,9]]
# Y_test : [1,4]
'Programming > Python' 카테고리의 다른 글
[Python] 클래스 상속 : super().__init__() 의미 (0) | 2023.09.15 |
---|---|
[Python] 클래스(class)와 객체(object), 인스턴스(instance), 생성자(constructor), 메소드(method) 의미 (0) | 2023.06.08 |
[Python] Numpy의 np.clip() 사용법 (0) | 2023.03.13 |
[Python] Pandas의 pd.get_dummies() 사용법 (0) | 2022.07.12 |
[Python] Pandas의 pd.read_csv() 사용법/csv 파일 불러오기 (0) | 2022.07.08 |