Research Notes

[R] K-Nearest Neighbor (KNN) 본문

Programming Language/R

[R] K-Nearest Neighbor (KNN)

jiachoi 2023. 7. 3. 12:01

1. 분류에서의 적용기법 : K-인접기법(k-nearest neighbor)

- 분류 : 타겟값이 존재하는 지도학습. 분류 규칙을 생성하고 새로운 데이터를 분류하는 기법

 

2. KNN

k-인접방법(KNN) : k개의 가장 가까운 이웃들을 사용해서 분류하는 방법

: 위의 예제는 녹색을 분류하기 위함. k=3이면 가장 가까운 것들 (세모2개, 네모1개)에 따라 새로운 객체는 세모가 됨. k=5일때는 파랑색이 더 많기때문에 네모로 분류됨. 이 분류문제는 최적의 k를 찾는것이 중요하다.

- 최적의 K는?

k가 너무 크면 데이터 구조를 파악하기 어렵고 너무 작으면 과적합 위험이 있음.

교차검증으로 정확도가 높은 k를 선정.

- 장점 : 단순/효율, 데이터 분산 추정 필요 없음, 빠른 훈련단계

- 단점 : 모델 생성하지 않음, 느린 분류단계, 메모리 소요 많음, 결측치를 위한 훈련 필요

 

3. KNN을 R로 수행하기

1) KNN을 수행하기 위한 추가 패키지 설치

- class : KNN수행을 위한 패키지

- gmodels : 분류분석 후 검증에 사용되는 cross table을 위한 패키지

- scales : 최적의 k등 그래프를 위한 패키지

# packages
install.packages("class")#no weighted value knn
install.packages("gmodels")#crosstable
install.packages("scales")#for graph
library(class)
library(gmodels)
library(scales) # 패키지를 사용할때는 꼭 library는 설정해야 한다. 

# set working directory
setwd("/Users/choijia/postech_ai/ML/week10_1_new")

2) Iris 데이터 ; 데이터 불러들이기, 학습데이터와 검증데이터의 분할

2-1) 데이터 불러들이기

# read csv file
iris<-read.csv("iris.csv")
# head(iris)
# str(iris)
attach(iris)

2-2) 데이터 분할(학습데이터 2/3, 검증데이터 1/3)

# training/ test data : n=150
set.seed(1000, sample.kind="Rounding")
N=nrow(iris)
tr.idx=sample(1:N, size=N*2/3, replace=FALSE)

2-3) 학습데이터의 독립변수/종속변수, 검증데이터의 독립변수/종속변수

# attributes in training and test
iris.train<-iris[tr.idx,-5]
iris.test<-iris[-tr.idx,-5]
# target value in training and test
trainLabels<-iris[tr.idx,5]
testLabels<-iris[-tr.idx,5]

train<-iris[tr.idx,]
test<-iris[-tr.idx,]

3) KNN의 수행과 결과

- knn함수 : knn(train=학습데이터, test=검증데이터, cl=타겟변수, k=k)

# knn (5-nearest neighbor)
md1<-knn(train=iris.train,test=iris.test,cl=trainLabels,k=5)
md1
help(knn) # knn의 매뉴얼을 보고싶을 경우

knn으로 분류된 것들

4) knn의 결과 - 정확도 (k=5) --> CrossTable

# accuracy of 5-nearest neighbor classification
CrossTable(x=testLabels,y=md1, prop.chisq=FALSE)
help(CrossTable)

- KNN함수를 이용하여 얻은 결과 setosa는 모두 다 분류가 됨, versica는 1개가 오분류됨. virginica도 다 분류가 됨.

- test셋에서의 오분류율은 오분류한것(1개) /전체(50개) = 0.02%

 

4. KNN에서의 최적 k 탐색

1) 최적 k의 탐색 : 1 to nrow(train_data)/2 (여기서는 1 to 50까지)

# 최적 K를 찾는 코드 

# optimal k selection (1 to n/2)
accuracy_k <- NULL
# try k=1 to nrow(train)/2, may use nrow(train)/3(or 4,5) depending the size of n in train data
nnum<-nrow(iris.train)/2
for(kk in c(1:nnum))
{
  set.seed(1234, sample.kind="Rounding")
  knn_k<-knn(train=iris.train,test=iris.test,cl=trainLabels,k=kk)
  accuracy_k<-c(accuracy_k,sum(knn_k==testLabels)/length(testLabels))
}

# plot for k=(1 to n/2) and accuracy
test_k<-data.frame(k=c(1:nnum), accuracy=accuracy_k[c(1:nnum)])
plot(formula=accuracy~k, data=test_k,type="o",ylim=c(0.5,1), pch=20, col=3, main="validation-optimal k")
with(test_k,text(accuracy~k,labels = k,pos=1,cex=0.7))

# minimum k for the highest accuracy
min(test_k[test_k$accuracy %in% max(accuracy_k),"k"])

2) 최종 KNN모형 (k=7)

#k=7 knn
md1<-knn(train=iris.train,test=iris.test,cl=trainLabels,k=7)
CrossTable(x=testLabels,y=md1, prop.chisq=FALSE)

3) KNN의 결과 - 그래픽

# graphic display
plot(formula=Petal.Length ~ Petal.Width,
     data=iris.train,col=alpha(c("purple","blue","green"),0.7)[trainLabels],
     main="knn(k=7)")
points(formula = Petal.Length~Petal.Width,
       data=iris.test,
       pch = 17,
       cex= 1.2,
       col=alpha(c("purple","blue","green"),0.7)[md1]
)
legend("bottomright",
       c(paste("train",levels(trainLabels)),paste("test",levels(testLabels))),
       pch=c(rep(1,3),rep(17,3)),
       col=c(rep(alpha(c("purple","blue","green"),0.7),2)),
       cex=0.9
)

Petal.width와 Petal.length에 산점도를 그려보면 setosa는 잘 분류됨. virginica와 versicolor는 분류가 잘 되지 않음.

 

5. 가중치 k기법 (Weighted kNN)

- 거리에 따라 가중치를 부여하는 두 가지 알고리즘 존재

1) KNN가중치 기법을 사용하기 위한 패키지 설치

## Weighted KNN packages
install.packages("kknn")#weighted value knn
library(kknn)
help("kknn")

2) k=5일때 distance=1

# weighted knn
md2<-kknn(Species~., train=train,test=iris.test,k=5,distance=1,kernel="triangular")
md2

2-1) k=5일때의 결과확인

# to see results for weighted knn
md2_fit<-fitted(md2)
md2_fit
# accuracy of weighted knn
CrossTable(x=testLabels,y=md2_fit,prop.chisq=FALSE,prop.c=FALSE)

3) k=7일때의 가중치 knn 기법

# weighted knn (k=7, distance=2)
md3<-kknn(Species~., train=train,test=iris.test,k=7,distance=2,kernel="triangular")
md3
# to see results for weighted knn
md3_fit<-fitted(md3)
md3_fit
# accuracy of weighted knn
CrossTable(x=testLabels,y=md3_fit,prop.chisq=FALSE,prop.c=FALSE)