Estimator 이해 및 fit(), predict() 메서드

사이킷런은 API 일관성과 개발 편의성을 제공하기 위한 노력이 엿보이는 머신러닝 학습에 최적인 패키지입니다.

사이킷런은 머신러닝 모델 학습을 위해서 fit() 메서드와 학습된 모델의 예측을 위해 predict() 메서드를 제공합니다.

 

사이킷런에서는 분류 알고리즘을 구현한 클래스를 Classifier로, 회귀 알고리즘을 구현한 클래스를 Regressor로 지칭하고, 이 둘을 합쳐 Estimator 클래스라고 부릅니다.(지도학습의 모든 알고리즘을 구현한 클래스를 통칭함)

 

이 Estimator 클래스는 fit()과 predict()만을 이용해 간단하게 학습과 예측 결과를 반환합니다.

Scikit-learn class 구현 클래스
Estimator
(분류+회귀)
Classifier (분류) DecisionTreeClassifier
RandomForestClassifier
GradientBoostingClassifer
GaussianNB
SVC
Regressor (회귀) LinearRegression
Ridge, Lasso
RandomForestRegressor
GradientBoostingRegressor

Estimator의 fit(), predict()

cross_val_score()와 같은 evaluation 함수, GridSearchCV와 같은 하이퍼 파라미터 튜닝을 지원하는 클래스의 경우 이 Estimator를 인자로 받습니다. 인자로 받은 Estimator에 대해서 cross_val_score(), GridSearchCV.fit() 함수 내에서 Estimator의 fit()과 predict()를 호출해서 평가를 하거나 하이퍼 파라미터 튜닝을 수행하는 것입니다.


비지도 학습의 fit(), transform()

사이킷런 비지도 학습의 차원축소, 클러스터링, 피처 추출 등을 구현한 클래스 역시 fit() 과 transform()을 적용하지만, 이것은 지도학습에서의 fit(), transform()과 다른 의미입니다.

 

비지도 학습에서의 fit()은 입력 데이터의 형태에 맞춰 데이터를 변환하기 위한 사전 구조를 맞추는 작업이고,

transform()은 fit 이후 입력 데이터의 차원 변환, 클러스터랑, 피처 추출 등의 실제 작업을 수행합니다.

 

 

'💡 AI > ML' 카테고리의 다른 글

ML - 평가(evaluation)  (0) 2021.11.12
ML - fit(), transform() 과 fit_transform()의 차이  (0) 2021.09.15
ML - 타이타닉 생존자 예측  (0) 2021.09.15
ML - 정규화  (0) 2021.09.15
ML - 레이블 인코딩, 원핫 인코딩  (0) 2021.09.15
복사했습니다!