< [핸즈온 머신러닝 2판] 2.7 모델 세부 튜닝

프로그래밍 공부/핸즈온 머신러닝 2판

[핸즈온 머신러닝 2판] 2.7 모델 세부 튜닝

Rocketbabydolls 2022. 10. 21. 19:00

혼공머신에서 공부한 그리드서치(교차검증), 개념은 같지만 사용방법이 다르고 중요한건 매개변수의 정확한 뜻이 기억이 잘 안나서ㅠ 다시 기록해놓았다.

from sklearn.model_selection import GridSearchCV

param_grid = [
    # 12(=3 X 4)개의 하이퍼 파라미터 조합을 시도합니다.
    {'n_estimators':[3, 10, 30], 'max_features': [2, 4, 6, 8]},
    # bootstrap은 false로 하고 6(=2 X 3)개의 조합을 시도합니다.
    {'bootstrap': [False], 'n_estimators': [3, 10], 'max_features': [2, 3, 4]},
]

forest_reg = RandomForestRegressor(random_state=42)
# 다섯 개의 폴드로 훈련하면 총 (12+6)*5=90번의 훈련이 일어납니다.

grid_search = GridSearchCV(forest_reg, param_grid, cv=5,
                          scoring='neg_mean_squared_error',
                          return_train_score=True)
grid_search.fit(housing_prepared, housing_labels)

-그리드서치의 파라미터 

n_estimators : 전체 결정 트리의 개수 

max_features : 최대 특성의 개수

bootstrap : 매개변수 샘플을 랜덤하게 뽑을지 결정함

cv : 교차 검증에 사용할 겹의 개수(몇 개로 나누어서 검증하는지)

 

설정한 결정 트리와 특성의 개수를 곱한 만큼 조합을 시도한다. cv가 5이므로 첫 번쨰 파라미터 조합 3* 4 = 12 개에 두 번째 파라미터 조합 2*3 = 6 을 더한 것이 5번 일어나므로

총 90번의 훈련이 일어나게 된다.

 

final_model = grid_search.best_estimator_

X_test = strat_test_set.drop("median_house_value", axis=1)
y_test = strat_test_set["median_house_value"].copy()

X_test_prepared = full_pipeline.transform(X_test)
final_predictions = final_model.predict(X_test_prepared)

final_mse = mean_squared_error(y_test, final_predictions)
final_rmse = np.sqrt(final_mse)
final_rmse

 

앞에서 랜덤 포레스트를 사용해 모델을 훈련시키기로 결정했다. ( 많은 모델들이 있지만 이 책은 초심자에게도 권장되는 책이므로 선형회귀, 결정트리, 랜덤포레스트만 사용한 것으로 추정) 

 

따라서 GridSearch(교차검증)을 통해 얻은 최고의 하이퍼파라미터를 final_model 변수에 저장한 뒤, 우리가 구하고 싶은 것은 median_house_value 의 회귀 예측이므로 제외 후 모델을 학습시켜서 실제 값과의 평균 제곱근 오차를 구한다.

값은 시행마다 조금씩 다를 수밖에 없어 책에 나온 47730 보다는 높은 48000 이 나왔다. 
평균적으로 48000달러 정도 오차가 발생한다는 것.
이 추정값이 정확한지 알기 위해 신뢰 구간을 계산한다. (scipy.stats.t.interval() 사용) 
 
from scipy import stats

confidence = 0.95
squared_errors = (final_predictions - y_test) ** 2
np.sqrt(stats.t.interval(confidence, len(squared_errors) - 1, <---- 자유도!!
                         loc=squared_errors.mean(),  
                         scale=stats.sem(squared_errors)))

자유도에 대한 자세한 설명은 https://ondemandstore.tistory.com/2 여기서 볼 수 있다.