In [1]:
"""<순환신경망(Recurrent Neural Network, RNN)>
-RNN은 텍스트 처리를 위해 고안된 모델(계층)
-바로 이전의 데이터 (텍스트)를 재사용하는 신경망 계층임
-텍스트 기반
<순환신경망 종류>
-심플 순환신경망(Simple RNN) ;문장이 긴경우에 앞쪽 문장은 기억 못함
-장기기억 순환신경망(LSTM)
-게이트웨이 반복 순환신경망(GRU)
"""
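All three variants listed above are available as ready-made Keras layers. A minimal sketch for orientation (the unit count 8 is an arbitrary illustration, not taken from the notebook below):

from tensorflow import keras

simple = keras.layers.SimpleRNN(8)  # plain RNN; tends to forget the start of long sentences
lstm   = keras.layers.LSTM(8)       # Long Short-Term Memory
gru    = keras.layers.GRU(8)        # Gated Recurrent Unit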
In [2]:
"""사용라이브러리"""
from tensorflow import keras
from tensorflow.keras.layers import Dense
import tensorflow as tf
tf.keras.utils.set_random_seed(42)
### 연산 고정
tf.config.experimental.enable_op_determinism()
In [3]:
"""사용 데이터 셋"""
from tensorflow.keras.datasets import imdb
In [4]:
"""
<IMDB : 영화리뷰 감상평 데이터>
- 순환신경망에서 대표적으로 사용되는 데이터셋(외국)
- 케라스에서 영어로된 문장을 정수(숫자로)변환하여 제공하는 데이터셋
- 감상평을 긍정과 부정으로 라벨링 되어 있음
- 총 50000개의 샘플로 되어 있으며 훈련 및 테스트로 각각 25000개씩 분리하여 제공됨
"""
Loading the IMDB data¶
In [5]:
"""
* num_words=500 : 말뭉치 사전의 갯수 500개만 추출하겠다는 의미
"""
imdb.load_data(num_words=500)
Out[5]:
((array([list([1, 14, 22, 16, 43, 2, 2, 2, 2, 65, 458, 2, 66, 2, 4, 173, 36, 256, 5, 25, 100, 43, 2, 112, 50, 2, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 2, 2, 17, 2, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2, 19, 14, 22, 4, 2, 2, 469, 4, 22, 71, 87, 12, 16, 43, 2, 38, 76, 15, 13, 2, 4, 22, 17, 2, 17, 12, 16, 2, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2, 2, 16, 480, 66, 2, 33, 4, 130, 12, 16, 38, 2, 5, 25, 124, 51, 36, 135, 48, 25, 2, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 2, 15, 256, 4, 2, 7, 2, 5, 2, 36, 71, 43, 2, 476, 26, 400, 317, 46, 7, 4, 2, 2, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2, 56, 26, 141, 6, 194, 2, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 2, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 2, 88, 12, 16, 283, 5, 16, 2, 113, 103, 32, 15, 16, 2, 19, 178, 32]), list([1, 194, 2, 194, 2, 78, 228, 5, 6, 2, 2, 2, 134, 26, 4, 2, 8, 118, 2, 14, 394, 20, 13, 119, 2, 189, 102, 5, 207, 110, 2, 21, 14, 69, 188, 8, 30, 23, 7, 4, 249, 126, 93, 4, 114, 9, 2, 2, 5, 2, 4, 116, 9, 35, 2, 4, 229, 9, 340, 2, 4, 118, 9, 4, 130, 2, 19, 4, 2, 5, 89, 29, 2, 46, 37, 4, 455, 9, 45, 43, 38, 2, 2, 398, 4, 2, 26, 2, 5, 163, 11, 2, 2, 4, 2, 9, 194, 2, 7, 2, 2, 349, 2, 148, 2, 2, 2, 15, 123, 125, 68, 2, 2, 15, 349, 165, 2, 98, 5, 4, 228, 9, 43, 2, 2, 15, 299, 120, 5, 120, 174, 11, 220, 175, 136, 50, 9, 2, 228, 2, 5, 2, 2, 245, 2, 5, 4, 2, 131, 152, 491, 18, 2, 32, 2, 2, 14, 9, 6, 371, 78, 22, 2, 64, 2, 9, 8, 168, 145, 23, 4, 2, 15, 16, 4, 2, 5, 28, 6, 52, 154, 462, 33, 89, 78, 285, 16, 145, 95]), list([1, 14, 47, 8, 30, 31, 7, 4, 249, 108, 7, 4, 2, 54, 61, 369, 13, 71, 149, 14, 22, 112, 4, 2, 311, 12, 16, 2, 33, 75, 43, 2, 296, 4, 86, 320, 35, 2, 19, 263, 2, 2, 4, 2, 33, 89, 78, 12, 66, 16, 4, 360, 7, 4, 58, 316, 334, 11, 4, 2, 43, 2, 2, 8, 257, 85, 2, 42, 2, 2, 83, 68, 2, 15, 36, 165, 2, 278, 36, 69, 2, 2, 8, 106, 14, 2, 2, 18, 6, 22, 12, 215, 28, 2, 40, 6, 87, 326, 23, 2, 21, 23, 22, 12, 272, 40, 57, 31, 11, 4, 22, 47, 6, 2, 51, 9, 170, 23, 2, 116, 2, 2, 13, 191, 79, 2, 89, 2, 14, 9, 8, 106, 2, 2, 35, 2, 6, 227, 7, 129, 113]), ..., list([1, 11, 6, 230, 245, 2, 9, 6, 2, 446, 2, 45, 2, 84, 2, 2, 21, 4, 2, 84, 2, 325, 2, 134, 2, 2, 84, 5, 36, 28, 57, 2, 21, 8, 140, 8, 2, 5, 2, 84, 56, 18, 2, 14, 9, 31, 7, 4, 2, 2, 2, 2, 2, 18, 6, 20, 207, 110, 2, 12, 8, 2, 2, 8, 97, 6, 20, 53, 2, 74, 4, 460, 364, 2, 29, 270, 11, 2, 108, 45, 40, 29, 2, 395, 11, 6, 2, 2, 7, 2, 89, 364, 70, 29, 140, 4, 64, 2, 11, 4, 2, 26, 178, 4, 2, 443, 2, 5, 27, 2, 117, 2, 2, 165, 47, 84, 37, 131, 2, 14, 2, 10, 10, 61, 2, 2, 10, 10, 288, 2, 2, 34, 2, 2, 4, 65, 496, 4, 231, 7, 2, 5, 6, 320, 234, 2, 234, 2, 2, 7, 496, 4, 139, 2, 2, 2, 2, 5, 2, 18, 4, 2, 2, 250, 11, 2, 2, 4, 2, 2, 2, 2, 372, 2, 2, 2, 2, 7, 4, 59, 2, 4, 2, 2]), list([1, 2, 2, 69, 72, 2, 13, 2, 2, 8, 12, 2, 23, 5, 16, 484, 2, 54, 349, 11, 2, 2, 45, 58, 2, 13, 197, 12, 16, 43, 23, 2, 5, 62, 30, 145, 402, 11, 2, 51, 2, 32, 61, 369, 71, 66, 2, 12, 2, 75, 100, 2, 8, 4, 105, 37, 69, 147, 2, 75, 2, 44, 257, 390, 5, 69, 263, 2, 105, 50, 286, 2, 23, 4, 123, 13, 161, 40, 5, 421, 4, 116, 16, 2, 13, 2, 40, 319, 2, 112, 2, 11, 2, 121, 25, 70, 2, 4, 2, 2, 13, 18, 31, 62, 40, 8, 2, 4, 2, 7, 14, 123, 5, 2, 25, 8, 2, 12, 145, 5, 202, 12, 160, 2, 202, 12, 6, 52, 58, 2, 92, 401, 2, 12, 39, 14, 251, 8, 15, 251, 5, 2, 12, 38, 84, 80, 124, 12, 9, 23]), list([1, 17, 6, 194, 337, 7, 4, 204, 22, 45, 254, 8, 106, 14, 123, 4, 2, 270, 2, 5, 2, 2, 2, 2, 101, 405, 39, 14, 2, 4, 2, 9, 115, 50, 305, 12, 47, 4, 168, 5, 235, 7, 38, 
111, 2, 102, 7, 4, 2, 2, 9, 24, 6, 78, 2, 17, 2, 2, 21, 27, 2, 2, 5, 2, 2, 92, 2, 4, 2, 7, 4, 204, 42, 97, 90, 35, 221, 109, 29, 127, 27, 118, 8, 97, 12, 157, 21, 2, 2, 9, 6, 66, 78, 2, 4, 2, 2, 5, 2, 272, 191, 2, 6, 2, 8, 2, 2, 2, 2, 5, 383, 2, 2, 2, 2, 497, 2, 8, 2, 2, 2, 21, 60, 27, 239, 9, 43, 2, 209, 405, 10, 10, 12, 2, 40, 4, 248, 20, 12, 16, 5, 174, 2, 72, 7, 51, 6, 2, 22, 4, 204, 131, 9])], dtype=object), array([1, 0, 0, ..., 0, 1, 0], dtype=int64)), (array([list([1, 2, 202, 14, 31, 6, 2, 10, 10, 2, 2, 5, 4, 360, 7, 4, 177, 2, 394, 354, 4, 123, 9, 2, 2, 2, 10, 10, 13, 92, 124, 89, 488, 2, 100, 28, 2, 14, 31, 23, 27, 2, 29, 220, 468, 8, 124, 14, 286, 170, 8, 157, 46, 5, 27, 239, 16, 179, 2, 38, 32, 25, 2, 451, 202, 14, 6, 2]), list([1, 14, 22, 2, 6, 176, 7, 2, 88, 12, 2, 23, 2, 5, 109, 2, 4, 114, 9, 55, 2, 5, 111, 7, 4, 139, 193, 273, 23, 4, 172, 270, 11, 2, 2, 4, 2, 2, 109, 2, 21, 4, 22, 2, 8, 6, 2, 2, 10, 10, 4, 105, 2, 35, 2, 2, 19, 2, 2, 5, 2, 2, 45, 55, 221, 15, 2, 2, 2, 14, 2, 4, 405, 5, 2, 7, 27, 85, 108, 131, 4, 2, 2, 2, 405, 9, 2, 133, 5, 50, 13, 104, 51, 66, 166, 14, 22, 157, 9, 4, 2, 239, 34, 2, 2, 45, 407, 31, 7, 41, 2, 105, 21, 59, 299, 12, 38, 2, 5, 2, 15, 45, 2, 488, 2, 127, 6, 52, 292, 17, 4, 2, 185, 132, 2, 2, 2, 488, 2, 47, 6, 392, 173, 4, 2, 2, 270, 2, 4, 2, 7, 4, 65, 55, 73, 11, 346, 14, 20, 9, 6, 2, 2, 7, 2, 2, 2, 5, 2, 30, 2, 2, 56, 4, 2, 5, 2, 2, 8, 4, 2, 398, 229, 10, 10, 13, 2, 2, 2, 14, 9, 31, 7, 27, 111, 108, 15, 2, 19, 2, 2, 2, 2, 14, 22, 9, 2, 21, 45, 2, 5, 45, 252, 8, 2, 6, 2, 2, 2, 39, 4, 2, 48, 25, 181, 8, 67, 35, 2, 22, 49, 238, 60, 135, 2, 14, 9, 290, 4, 58, 10, 10, 472, 45, 55, 2, 8, 169, 11, 374, 2, 25, 203, 28, 8, 2, 12, 125, 4, 2]), list([1, 111, 2, 2, 2, 2, 2, 4, 87, 2, 2, 7, 31, 318, 2, 7, 4, 498, 2, 2, 63, 29, 2, 220, 2, 2, 5, 17, 12, 2, 220, 2, 17, 6, 185, 132, 2, 16, 53, 2, 11, 2, 74, 4, 438, 21, 27, 2, 2, 8, 22, 107, 2, 2, 2, 2, 8, 35, 2, 2, 11, 22, 231, 54, 29, 2, 29, 100, 2, 2, 34, 2, 2, 2, 5, 2, 98, 31, 2, 33, 6, 58, 14, 2, 2, 8, 4, 365, 7, 2, 2, 356, 346, 4, 2, 2, 63, 29, 93, 11, 2, 11, 2, 33, 6, 58, 54, 2, 431, 2, 7, 32, 2, 16, 11, 94, 2, 10, 10, 4, 2, 2, 7, 4, 2, 2, 2, 2, 8, 2, 8, 2, 121, 31, 7, 27, 86, 2, 2, 16, 6, 465, 2, 2, 2, 2, 17, 2, 42, 4, 2, 37, 473, 6, 2, 6, 2, 7, 328, 212, 70, 30, 258, 11, 220, 32, 7, 108, 21, 133, 12, 9, 55, 465, 2, 2, 53, 33, 2, 2, 37, 70, 2, 4, 2, 2, 74, 476, 37, 62, 91, 2, 169, 4, 2, 2, 146, 2, 2, 5, 258, 12, 184, 2, 2, 5, 2, 2, 7, 4, 22, 2, 18, 2, 2, 2, 7, 4, 2, 71, 348, 425, 2, 2, 19, 2, 5, 2, 11, 2, 8, 339, 2, 4, 2, 2, 7, 4, 2, 10, 10, 263, 2, 9, 270, 11, 6, 2, 4, 2, 2, 121, 4, 2, 26, 2, 19, 68, 2, 5, 28, 446, 6, 318, 2, 8, 67, 51, 36, 70, 81, 8, 2, 2, 36, 2, 8, 2, 2, 18, 6, 2, 4, 2, 26, 2, 2, 11, 14, 2, 2, 12, 426, 28, 77, 2, 8, 97, 38, 111, 2, 2, 168, 2, 2, 137, 2, 18, 27, 173, 9, 2, 17, 6, 2, 428, 2, 232, 11, 4, 2, 37, 272, 40, 2, 247, 30, 2, 6, 2, 54, 2, 2, 98, 6, 2, 40, 2, 37, 2, 98, 4, 2, 2, 15, 14, 9, 57, 2, 5, 2, 6, 275, 2, 2, 2, 2, 98, 6, 2, 10, 10, 2, 19, 14, 2, 267, 162, 2, 37, 2, 2, 98, 4, 2, 2, 90, 19, 6, 2, 7, 2, 2, 2, 4, 2, 2, 2, 8, 2, 90, 4, 2, 8, 4, 2, 17, 2, 2, 2, 4, 2, 8, 2, 189, 4, 2, 2, 2, 4, 2, 5, 95, 271, 23, 6, 2, 2, 2, 2, 33, 2, 6, 425, 2, 2, 2, 2, 7, 4, 2, 2, 469, 4, 2, 54, 4, 150, 2, 2, 280, 53, 2, 2, 18, 339, 29, 2, 27, 2, 5, 2, 68, 2, 19, 2, 2, 4, 2, 7, 263, 65, 2, 34, 6, 2, 2, 43, 159, 29, 9, 2, 9, 387, 73, 195, 2, 10, 10, 2, 4, 58, 2, 54, 14, 2, 117, 22, 16, 93, 5, 2, 4, 192, 15, 12, 16, 93, 34, 6, 2, 2, 33, 4, 2, 7, 15, 2, 2, 2, 325, 12, 62, 30, 2, 8, 67, 14, 17, 6, 2, 
44, 148, 2, 2, 203, 42, 203, 24, 28, 69, 2, 2, 11, 330, 54, 29, 93, 2, 21, 2, 2, 27, 2, 7, 2, 4, 22, 2, 17, 6, 2, 2, 7, 2, 2, 2, 100, 30, 4, 2, 2, 2, 2, 42, 2, 11, 4, 2, 42, 101, 2, 7, 101, 2, 15, 2, 94, 2, 180, 5, 9, 2, 34, 2, 45, 6, 2, 22, 60, 6, 2, 31, 11, 94, 2, 96, 21, 94, 2, 9, 57, 2]), ..., list([1, 13, 2, 15, 8, 135, 14, 9, 35, 32, 46, 394, 20, 62, 30, 2, 21, 45, 184, 78, 4, 2, 2, 2, 2, 2, 395, 2, 5, 2, 11, 119, 2, 89, 2, 4, 116, 218, 78, 21, 407, 100, 30, 128, 262, 15, 7, 185, 2, 284, 2, 2, 37, 315, 4, 226, 20, 272, 2, 40, 29, 152, 60, 181, 8, 30, 50, 2, 362, 80, 119, 12, 21, 2, 2]), list([1, 11, 119, 241, 9, 4, 2, 20, 12, 468, 15, 94, 2, 2, 2, 39, 4, 86, 107, 8, 97, 14, 31, 33, 4, 2, 7, 2, 46, 2, 9, 2, 5, 4, 2, 47, 8, 79, 90, 145, 164, 162, 50, 6, 2, 119, 7, 9, 4, 78, 232, 15, 16, 224, 11, 4, 333, 20, 4, 2, 200, 5, 2, 5, 9, 2, 8, 79, 357, 4, 20, 47, 220, 57, 206, 139, 11, 12, 5, 55, 117, 212, 13, 2, 92, 124, 51, 45, 2, 71, 2, 13, 2, 14, 20, 6, 2, 7, 470]), list([1, 6, 52, 2, 430, 22, 9, 220, 2, 8, 28, 2, 2, 2, 6, 2, 15, 47, 6, 2, 2, 8, 114, 5, 33, 222, 31, 55, 184, 2, 2, 2, 19, 346, 2, 5, 6, 364, 350, 4, 184, 2, 9, 133, 2, 11, 2, 2, 21, 4, 2, 2, 2, 50, 2, 2, 9, 6, 2, 17, 6, 2, 2, 21, 17, 6, 2, 232, 2, 2, 29, 266, 56, 96, 346, 194, 308, 9, 194, 21, 29, 218, 2, 19, 4, 78, 173, 7, 27, 2, 2, 2, 2, 2, 9, 6, 2, 17, 210, 5, 2, 2, 47, 77, 395, 14, 172, 173, 18, 2, 2, 2, 82, 127, 27, 173, 11, 6, 392, 217, 21, 50, 9, 57, 65, 12, 2, 53, 40, 35, 390, 7, 11, 4, 2, 7, 4, 314, 74, 6, 2, 22, 2, 19, 2, 2, 2, 382, 4, 91, 2, 439, 19, 14, 20, 9, 2, 2, 2, 4, 2, 25, 124, 4, 31, 12, 16, 93, 2, 34, 2, 2])], dtype=object), array([0, 1, 1, ..., 0, 0, 0], dtype=int64)))
In [6]:
# each list is one sentence (review)
(train_input, train_target), (test_input, test_target)= imdb.load_data(num_words=500)
print (f"{train_input.shape} / {train_target.shape}")
print (f"{test_input.shape} / {test_target.shape}")
(25000,) / (25000,)
(25000,) / (25000,)
In [7]:
print(len(train_input[1]),train_input[1] )
# number of words in the sentence / the word indices (vocabulary IDs)
189 [1, 194, 2, 194, 2, 78, 228, 5, 6, 2, 2, 2, 134, 26, 4, 2, 8, 118, 2, 14, 394, 20, 13, 119, 2, 189, 102, 5, 207, 110, 2, 21, 14, 69, 188, 8, 30, 23, 7, 4, 249, 126, 93, 4, 114, 9, 2, 2, 5, 2, 4, 116, 9, 35, 2, 4, 229, 9, 340, 2, 4, 118, 9, 4, 130, 2, 19, 4, 2, 5, 89, 29, 2, 46, 37, 4, 455, 9, 45, 43, 38, 2, 2, 398, 4, 2, 26, 2, 5, 163, 11, 2, 2, 4, 2, 9, 194, 2, 7, 2, 2, 349, 2, 148, 2, 2, 2, 15, 123, 125, 68, 2, 2, 15, 349, 165, 2, 98, 5, 4, 228, 9, 43, 2, 2, 15, 299, 120, 5, 120, 174, 11, 220, 175, 136, 50, 9, 2, 228, 2, 5, 2, 2, 245, 2, 5, 4, 2, 131, 152, 491, 18, 2, 32, 2, 2, 14, 9, 6, 371, 78, 22, 2, 64, 2, 9, 8, 168, 145, 23, 4, 2, 15, 16, 4, 2, 5, 28, 6, 52, 154, 462, 33, 89, 78, 285, 16, 145, 95]
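For reference, the integer codes can be mapped back to words with imdb.get_word_index(); a sketch I am adding here (Keras reserves indices 0, 1 and 2 for padding, start-of-sequence and out-of-vocabulary tokens, so real words are offset by 3; with num_words=500 most rare words decode as <unk>):

word_index = imdb.get_word_index()                      # word -> raw index
index_word = {i + 3: w for w, i in word_index.items()}  # loaded data offsets real words by 3
index_word.update({0: "<pad>", 1: "<start>", 2: "<unk>"})
print(" ".join(index_word.get(i, "?") for i in train_input[1]))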
In [8]:
"""0:부정, 1:긍정"""
train_target
Out[8]:
array([1, 0, 0, ..., 0, 1, 0], dtype=int64)
In [9]:
"""훈련 : 검증 = 8:2로 분리하기"""
from sklearn.model_selection import train_test_split
train_input,val_input, train_target, val_target = train_test_split( train_input, train_target, test_size=0.2, random_state=42)
In [10]:
train_input.shape, train_target.shape, val_input.shape, val_target.shape , test_input.shape, test_target.shape
Out[10]:
((20000,), (20000,), (5000,), (5000,), (25000,), (25000,))
Normalization (text data)¶
In [11]:
"""
-텍스트 기반의 데이터인 경우 정규화는 스케일링 처리가 아닌 문자열의 길이를 통일시키는 처리를 진행합니다.
-훈련 모델은 정해진 행렬의 사이즈를 기준을 훈련하기 때문에
"""
In [12]:
import numpy as np
In [13]:
print(len(train_input),train_input[0] )
20000 [1, 73, 89, 81, 25, 60, 2, 6, 20, 141, 17, 14, 31, 127, 12, 60, 28, 2, 2, 66, 45, 6, 20, 15, 497, 8, 79, 17, 491, 8, 112, 6, 2, 20, 17, 2, 2, 4, 436, 20, 9, 2, 6, 2, 7, 493, 2, 6, 185, 250, 24, 55, 2, 5, 23, 350, 7, 15, 82, 24, 15, 2, 66, 10, 10, 45, 2, 15, 4, 20, 2, 8, 30, 17, 2, 5, 2, 17, 2, 190, 4, 20, 9, 43, 32, 99, 2, 18, 15, 8, 157, 46, 17, 2, 4, 2, 5, 2, 9, 32, 2, 5, 2, 267, 17, 73, 17, 2, 36, 26, 400, 43, 2, 83, 4, 2, 247, 74, 83, 4, 250, 2, 82, 4, 96, 4, 250, 2, 8, 32, 4, 2, 9, 184, 2, 13, 384, 48, 14, 16, 147, 2, 59, 62, 69, 2, 12, 46, 50, 9, 53, 2, 74, 2, 11, 14, 31, 151, 10, 10, 4, 20, 9, 2, 364, 352, 5, 45, 6, 2, 2, 33, 269, 8, 2, 142, 2, 5, 2, 17, 73, 17, 204, 5, 2, 19, 55, 2, 2, 92, 66, 104, 14, 20, 93, 76, 2, 151, 33, 4, 58, 12, 188, 2, 151, 12, 215, 69, 224, 142, 73, 237, 6, 2, 7, 2, 2, 188, 2, 103, 14, 31, 10, 10, 451, 7, 2, 5, 2, 80, 91, 2, 30, 2, 34, 14, 20, 151, 50, 26, 131, 49, 2, 84, 46, 50, 37, 80, 79, 6, 2, 46, 7, 14, 20, 10, 10, 470, 158]
In [14]:
"""훈련 독립변수의 각 데이터 (값)의 길이를 배열(리스트)형태로 추출하기"""
# length = []
# for i in train_input:
# length.append(len(train_input[i]))
lengths = np.array([len(train_input[i]) for i in train_input])
In [15]:
lengths.shape, lengths
Out[15]:
((20000,), array([259, 520, 290, ..., 300, 70, 77]))
In [16]:
"""lengths의 값을 이용해서 전체 평균과 중앙값 출력하기"""
np.mean(lengths), np.median(lengths)
Out[16]:
(239.00925, 178.0)
In [17]:
"""시각화"""
import matplotlib.pyplot as plt
from matplotlib import font_manager, rc
plt.rc("font", family="Malgun Gothic")
plt.title("텍스트 길이 분포 확인")
plt.hist(lengths)
plt.xlabel("length(단어갯수)")
plt.ylabel("빈도")
plt.show()
"""단어 갯수의 분포를 이용해서
- 훈련시에 사용할 값들의 길이 기준 정의
-전체적으로 왼편에 집중되어 있으며
- x축 125정도에 많은 빈도를 나타내고 있음
- 따라서 독립변수 각 값들의 길이를 100일로 통일(정규화 작업)
"""
In [18]:
train_input[0]
Out[18]:
[1, 73, 89, 81, 25, 60, 2, 6, 20, 141, 17, 14, 31, 127, 12, 60, 28, 2, 2, 66, 45, 6, 20, 15, 497, 8, 79, 17, 491, 8, 112, 6, 2, 20, 17, 2, 2, 4, 436, 20, 9, 2, 6, 2, 7, 493, 2, 6, 185, 250, 24, 55, 2, 5, 23, 350, 7, 15, 82, 24, 15, 2, 66, 10, 10, 45, 2, 15, 4, 20, 2, 8, 30, 17, 2, 5, 2, 17, 2, 190, 4, 20, 9, 43, 32, 99, 2, 18, 15, 8, 157, 46, 17, 2, 4, 2, 5, 2, 9, 32, 2, 5, 2, 267, 17, 73, 17, 2, 36, 26, 400, 43, 2, 83, 4, 2, 247, 74, 83, 4, 250, 2, 82, 4, 96, 4, 250, 2, 8, 32, 4, 2, 9, 184, 2, 13, 384, 48, 14, 16, 147, 2, 59, 62, 69, 2, 12, 46, 50, 9, 53, 2, 74, 2, 11, 14, 31, 151, 10, 10, 4, 20, 9, 2, 364, 352, 5, 45, 6, 2, 2, 33, 269, 8, 2, 142, 2, 5, 2, 17, 73, 17, 204, 5, 2, 19, 55, 2, 2, 92, 66, 104, 14, 20, 93, 76, 2, 151, 33, 4, 58, 12, 188, 2, 151, 12, 215, 69, 224, 142, 73, 237, 6, 2, 7, 2, 2, 188, 2, 103, 14, 31, 10, 10, 451, 7, 2, 5, 2, 80, 91, 2, 30, 2, 34, 14, 20, 151, 50, 26, 131, 49, 2, 84, 46, 50, 37, 80, 79, 6, 2, 46, 7, 14, 20, 10, 10, 470, 158]
Unifying each sample to a length of 100 (normalization)¶
In [19]:
"""덱스트 길이 정규화 라이브러리"""
from tensorflow.keras.preprocessing.sequence import pad_sequences
In [20]:
"""훈련 독립변수 각 데이터 100개로 통일 시키기
-pad_sequences() : 텍스트의 길이를 maxlen 갯수로 통일시키기
-maxlen 보다 작으면 0으로 채우고, 크면 제거합니다.
-결과 값은 -> 2차원 리스트로 반환합니다.
"""
train_seq = pad_sequences(train_input, maxlen=100)
train_seq.shape
Out[20]:
(20000, 100)
In [21]:
print(train_input[5],train_seq[5])
[1, 2, 195, 19, 49, 2, 2, 190, 4, 2, 352, 2, 183, 10, 10, 13, 82, 79, 4, 2, 36, 71, 269, 8, 2, 25, 19, 49, 7, 4, 2, 2, 2, 2, 2, 10, 10, 48, 25, 40, 2, 11, 2, 2, 40, 2, 2, 5, 4, 2, 2, 95, 14, 238, 56, 129, 2, 10, 10, 21, 2, 94, 364, 352, 2, 2, 11, 190, 24, 484, 2, 7, 94, 205, 405, 10, 10, 87, 2, 34, 49, 2, 7, 2, 2, 2, 2, 2, 290, 2, 46, 48, 64, 18, 4, 2] [ 0 0 0 0 1 2 195 19 49 2 2 190 4 2 352 2 183 10 10 13 82 79 4 2 36 71 269 8 2 25 19 49 7 4 2 2 2 2 2 10 10 48 25 40 2 11 2 2 40 2 2 5 4 2 2 95 14 238 56 129 2 10 10 21 2 94 364 352 2 2 11 190 24 484 2 7 94 205 405 10 10 87 2 34 49 2 7 2 2 2 2 2 290 2 46 48 64 18 4 2]
In [101]:
val_input
Out[101]:
array([list([1, 225, 6, 2, 2, 200, 2, 35, 204, 2, 2, 7, 129, 2, 2, 2, 2, 5, 399, 40, 2, 2, 2, 5, 27, 2, 28, 224, 19, 2, 2, 5, 331, 2, 40, 12, 4, 22, 2, 2, 2, 2, 83, 6, 2, 11, 4, 2, 15, 2, 8, 2, 2, 17, 31, 103, 4, 85, 2, 8, 14, 2, 2, 116, 2, 2, 2, 45, 24, 196, 159, 4, 369, 471, 23, 31, 160, 5, 70, 2, 2, 4, 2, 7, 31, 160, 76, 329, 181, 8, 30, 11, 4, 172, 2, 17, 98, 17, 2, 17, 12, 32, 2, 225, 6, 2, 2, 496, 4, 2, 2, 7, 14, 22, 15, 100, 28, 2, 11, 4, 2, 7, 6, 329, 2, 2, 472, 51, 75, 130, 56, 19, 9, 2, 2, 105, 2, 2, 2, 186, 8, 30, 8, 168, 307, 33, 4, 454, 8, 97, 4, 2, 2, 53, 2, 6, 2, 229, 38, 2, 2, 17, 8, 2, 4, 2, 2, 46, 4, 414, 5, 450, 2, 2, 15, 2, 117, 53, 74, 11, 31, 2, 35, 23, 268, 2, 34, 167, 2, 14, 9, 2, 22, 231, 11, 450, 2, 2, 14, 58, 2]), list([1, 54, 6, 392, 2, 2, 2, 2, 2, 6, 2, 2, 39, 27, 2, 4, 2, 2, 2, 5, 2, 15, 4, 2, 9, 35, 2, 2, 7, 35, 445, 465, 2, 144, 4, 2, 130, 56, 145, 11, 27, 2, 29, 80, 2, 27, 2, 5, 2, 2, 2, 2, 5, 27, 2, 369, 270, 46, 23, 6, 6, 2, 8, 2, 4, 2, 19, 6, 2, 7, 2, 14, 9, 35, 2, 2, 7, 4, 356, 2, 17, 12, 64, 2, 4, 86, 320, 7, 4, 65, 2, 14, 9, 35, 2, 5, 2, 2, 22, 10, 10, 4, 2, 9, 2, 224, 19, 2, 63, 9, 2, 120, 412, 206, 2, 2, 2, 2, 73, 19, 4, 364, 352, 29, 16, 348, 4, 22, 82, 2, 6, 2, 228, 2, 34, 2, 2, 15, 2, 175, 136, 50, 26, 6, 171, 114, 2, 19, 4, 229, 21, 15, 47, 8, 30, 2, 2, 4, 204, 2, 16, 8, 97, 4, 274, 2, 83, 107, 108, 76, 69, 8, 2, 11, 4, 86, 31, 61, 2, 2, 9, 49, 7, 4, 109, 2, 2, 16, 6, 227, 99, 2, 137, 4, 85, 2, 2, 340, 2, 4, 85, 105, 26, 165, 73, 398, 18, 4, 268, 5, 4, 2, 156, 81, 6, 87, 292, 13, 16, 2, 15, 2, 9, 165, 6, 227, 53, 2, 8, 4, 114, 4, 2, 2, 26, 53, 2, 74, 2, 137, 4, 2, 26, 2, 5, 2, 10, 10, 160, 439, 9, 15, 4, 445, 2, 2, 9, 446, 2, 2, 8, 4, 2, 444, 13, 104, 6, 117, 53, 278, 5, 128, 2, 62, 28, 224, 14, 6, 176, 7, 2, 21, 50, 9, 142, 2, 44, 12, 2, 2, 93, 6, 2, 2, 7, 231, 268, 2, 7, 134, 356, 2, 4, 22, 2, 2, 39, 112, 2, 34, 4, 412, 206, 108, 21, 45, 131, 6, 87, 20, 18, 2, 2, 7, 32, 2]), list([1, 4, 105, 26, 2, 5, 2, 19, 316, 112, 345, 2, 428, 2, 42, 2, 445, 4, 116, 9, 99, 2, 18, 12, 8, 30, 78, 11, 35, 2, 120, 4, 350, 96, 45, 2, 2, 48, 335, 6, 2, 2, 337, 88, 41, 109, 9, 24, 6, 2, 253, 2, 2, 2, 109, 59, 43, 214, 8, 2, 2, 6, 176, 10, 10, 23, 4, 226, 2, 61, 322, 2, 13, 258, 4, 20, 8, 30, 221, 2, 42, 2, 33, 32]), ..., list([1, 14, 22, 2, 19, 107, 2, 2, 2, 37, 26, 2, 2, 2, 2, 5, 2, 2, 2, 37, 157, 17, 2, 2, 5, 2, 32, 2, 7, 2, 5, 349, 19, 49, 78, 2, 2, 2, 9, 2, 34, 27, 2, 322, 2, 2, 2, 37, 495, 33, 260, 383, 2, 120, 4, 2, 2, 9, 112, 2, 34, 31, 2, 37, 494, 8, 81, 2, 183, 8, 41, 5, 59, 2, 18, 27, 339, 367, 19, 450, 85, 2, 2, 2, 28, 53, 58, 23, 68, 2, 5, 400, 140, 8, 2, 2, 2, 2, 42, 2, 46, 11, 6, 2, 2, 121, 32, 2, 7, 2, 183, 140, 23, 50, 9, 2, 7, 2, 2, 5, 2, 7, 2, 2, 2, 15, 2, 14, 22, 340, 180, 8, 6, 194, 2, 92, 437, 129, 58, 14, 22, 2, 72, 64, 2, 2, 5, 15, 16, 99, 76]), list([1, 14, 20, 64, 188, 6, 300, 88, 25, 191, 202, 6, 2, 48, 25, 28, 6, 2, 2, 33, 32, 92, 106, 2, 2, 84, 25, 92, 181, 8, 106, 345, 12, 166, 84, 2, 13, 2, 13, 2, 12, 2, 4, 64, 2, 2, 26, 4, 420, 5, 4, 192, 15, 94, 6, 283, 65, 94, 66, 66, 2, 2, 15, 14, 232, 2, 33, 2, 137, 29, 2, 84, 131, 235, 2, 54, 25, 2, 11, 15, 117, 2, 13, 92, 29, 69, 2, 8, 175, 2, 11, 2, 13, 2, 4, 20, 12, 16, 24, 2, 12, 16, 2, 2, 81, 129, 2, 6, 2, 5, 2, 2, 2, 2, 14, 20, 13, 104, 12, 2, 44, 2, 8, 97, 15, 2, 68, 120, 2, 156]), list([1, 13, 2, 2, 4, 2, 2, 7, 85, 2, 36, 26, 99, 2, 382, 88, 36, 26, 2, 11, 4, 2, 13, 62, 40, 8, 67, 6, 2, 39, 294, 37, 69, 115, 110, 2, 2, 300, 382, 294, 55, 185, 13, 317, 4, 438, 2, 2, 15, 13, 69, 24, 77, 
2, 34, 4, 2, 2, 2, 2, 303, 13, 244, 131, 269, 8, 169, 2, 11, 4, 114, 21, 13, 2, 104, 7, 233, 2, 61, 2, 8, 316, 9, 67, 12, 18, 2, 5, 97, 56, 129, 205, 330, 10, 10, 12, 2, 6, 2, 2, 8, 2, 2, 300, 21, 4, 114, 9, 329, 2, 12, 131, 317, 72, 2, 33, 4, 130, 21, 11, 6, 53, 2, 96, 2, 2, 9, 17, 2, 5, 445, 17, 159, 5, 2, 41, 2, 153, 2, 73, 14, 2, 41, 2, 217, 2, 2, 16, 2, 60, 151, 29, 9, 57, 488, 2, 7, 4, 2, 177, 13, 2, 423, 2, 2, 17, 4, 2, 2])], dtype=object)
In [23]:
"""검증 데이터 길이도 100으로 통일 시키기 정규화"""
val_seq = pad_sequences(val_input, maxlen=100)
val_seq.shape
Out[23]:
(5000, 100)
In [24]:
"""텍스트 길이 조정 속성 (매개 변수)
* truncating : 추출 위치 ( 앞 또는 뒤 부터)
- pre: 뒤쪽 부터 추출하기(기본값)
- post: 앞쪽부터 추출하기
* padding: 채울위치(앞 또는 뒤 부터)
-pre: 앞쪽을 0으로 채우기(기본값)
-post: 뒤쪽을 0으로 채우기
"""
train_seq = pad_sequences(train_input, maxlen=100, truncating="pre", padding="pre")
In [99]:
train_seq.shape
Out[99]:
(20000, 100)
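A toy example of the two parameters (my own illustration, not part of the IMDB data):

demo = [[1, 2, 3, 4, 5], [6, 7]]
print(pad_sequences(demo, maxlen=3))                                     # [[3 4 5] [0 6 7]] : cut front, pad front
print(pad_sequences(demo, maxlen=3, truncating="post", padding="post"))  # [[1 2 3] [6 7 0]] : cut back, pad back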
Simple RNN (simple recurrent neural network)¶
In [26]:
"""모델 생성하기"""
model = keras.Sequential()
model
Out[26]:
<keras.engine.sequential.Sequential at 0x1ca85eb9130>
In [27]:
"""계층 생성 및 모델에 추가하기
- input_shape=(100, 500) : 100은 특성 갯수 , 500은 말뭉치 갯수
- 입력계층이면서 RNN 계층
"""
model.add(keras.layers.SimpleRNN(8, input_shape=(100,500)))
"""출력계층"""
model.add(keras.layers.Dense(1, activation="sigmoid"))
In [28]:
"""<RNN 에서 사용할 독립변수 처리방식>
-RNN 모델에서는 독립변수의 데이터를
--> 원-핫 인코딩 데이터 또는 임베딩 처리를 통해서 훈련을 시켜야 합니다.
<원-핫 인코딩 (One-hot 인코딩 방식>
- 각 데이터 값 중에 1개의 단어당 분석을 위해 500개의 말뭉치와 비교하여야 함
- 이때 비교하기 위해 원 - 핫 인코딩으로 변환하여 비교하는 방식을 따름
- keras.util.to_categorical()함수 사용
- 프로그램을 통해 변환해야 함 (별도 계층이 존재하지는 않음)
- 각 단어 별로 원 - 핫 인코딩 처리가 되기에 데이터가 많아지며 속도가 느림
-데이터가 많아지기 때문에 많은 메모리 공간을 차지함
<단어 임베딩 (Embedding) 방식>
- 원-핫 인코딩의 느린 속도를 개선하기 위하여 개선된 방식 (메모리 활용)
- 많은 공간을 사용하지 않음
- keras.layers.Embedding() 계층을 사용함 (프로그램 처리방식이 아님)
"""
Converting to one-hot encoded data¶
In [29]:
"""
- to_categorical() : 원-핫 인코딩 함수
- 말뭉치 갯수만큼의 차원이 발생함 ( 총 3차원)
"""
train_oh = keras.utils.to_categorical(train_seq)
val_oh = keras.utils.to_categorical(val_seq)
train_oh.shape, val_oh.shape
Out[29]:
((20000, 100, 500), (5000, 100, 500))
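A rough size comparison (my own arithmetic, assuming float32 arrays) shows why the one-hot approach is memory-hungry while the embedding table used later stays tiny:

print(train_oh.nbytes / 1024**3, "GiB")  # 20000 x 100 x 500 x 4 bytes, roughly 3.7 GiB
print(500 * 16 * 4 / 1024, "KiB")        # the 500 x 16 embedding table: about 31 KiB of weights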
In [30]:
print(train_seq[0])
len(train_oh)
[ 10 4 20 9 2 364 352 5 45 6 2 2 33 269 8 2 142 2 5 2 17 73 17 204 5 2 19 55 2 2 92 66 104 14 20 93 76 2 151 33 4 58 12 188 2 151 12 215 69 224 142 73 237 6 2 7 2 2 188 2 103 14 31 10 10 451 7 2 5 2 80 91 2 30 2 34 14 20 151 50 26 131 49 2 84 46 50 37 80 79 6 2 46 7 14 20 10 10 470 158]
Out[30]:
20000
In [31]:
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 simple_rnn (SimpleRNN)      (None, 8)                 4072
 dense (Dense)               (None, 1)                 9
=================================================================
Total params: 4,081
Trainable params: 4,081
Non-trainable params: 0
_________________________________________________________________
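The parameter counts above can be checked by hand (my own arithmetic): the SimpleRNN has an input kernel of shape 500 x 8, a recurrent kernel of 8 x 8 and 8 biases, and the Dense layer has 8 weights plus 1 bias.

print(500 * 8 + 8 * 8 + 8, 8 * 1 + 1)  # 4072 9, matching the summary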
In [32]:
"""모델설정하기
순환신경망(RNN)에서 주로 사용되는 옵티마이저는 RMSprop
RMSprop: 먼 거리는 조금 반영 , 가까운 거리는 많이 반영하는 개념을 적용함
"""
rmsprop = keras.optimizers.RMSprop(learning_rate=0.0001)
model.compile(optimizer=rmsprop, loss="binary_crossentropy",
metrics="accuracy")
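A minimal NumPy sketch of the RMSprop idea described above (textbook form, not the Keras internals): the running average of squared gradients decays over time, so recent gradients dominate the effective step size.

import numpy as np

def rmsprop_step(w, grad, cache, lr=1e-4, rho=0.9, eps=1e-7):
    """One RMSprop update: older gradients fade out of the cache, recent ones weigh more."""
    cache = rho * cache + (1 - rho) * grad ** 2
    w = w - lr * grad / (np.sqrt(cache) + eps)
    return w, cache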
Defining callback functions¶
In [33]:
"""자동 저장 및 종료하는 콜백함수 정의하기"""
model_path = "./model/best_simpleRNN_model.h5"
"""모델 자동 저장 콜백함수"""
checkpoint_cb = keras.callbacks.ModelCheckpoint(model_path)
"""자동 저장 종료 콜백함수"""
ealry_stopping_cb = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
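Note that ModelCheckpoint as written saves the model after every epoch; if only the best validation model should be kept on disk, save_best_only=True can be passed (a variation I am noting, not what the notebook ran):

# variation, not used below: keep only the model with the best validation loss
checkpoint_best_cb = keras.callbacks.ModelCheckpoint(model_path, save_best_only=True)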
Training the model¶
In [34]:
train_target.shape
Out[34]:
(20000,)
In [35]:
"""
* batch_size: 훈련시 데이터를 배치사이즈 만큼 잘라서 종속변수와 검증 실행
- 배치 사이즈 만큼 종속변수와 비교시 틀리게 되면,
- 다음 배치사이즈에서 오류를 줄여가면서 훈련을 진행하게 됩니다.
- 배치사이즈의 값을 정의하지 않으면 전체데이터를 기준으로 종속변수와 비교 및 훈련 반복시 오류 조정 진행됨
- 배치 사이즈의 값은 정의된 값은 없으며 보통32, 64정도를 주로 사용함
- 튜닝 대상 하이퍼파라메터 변수임
"""
history = model.fit(train_oh,train_target, epochs=100, batch_size=64,validation_data=(val_oh, val_target),
callbacks=[checkpoint_cb, ealry_stopping_cb])
Epoch 1/100 313/313 [==============================] - 8s 21ms/step - loss: 0.6989 - accuracy: 0.5041 - val_loss: 0.6978 - val_accuracy: 0.5038 Epoch 2/100 313/313 [==============================] - 6s 21ms/step - loss: 0.6946 - accuracy: 0.5102 - val_loss: 0.6949 - val_accuracy: 0.5102 Epoch 3/100 313/313 [==============================] - 6s 21ms/step - loss: 0.6920 - accuracy: 0.5200 - val_loss: 0.6930 - val_accuracy: 0.5146 Epoch 4/100 313/313 [==============================] - 6s 20ms/step - loss: 0.6901 - accuracy: 0.5302 - val_loss: 0.6916 - val_accuracy: 0.5230 Epoch 5/100 313/313 [==============================] - 6s 20ms/step - loss: 0.6883 - accuracy: 0.5408 - val_loss: 0.6905 - val_accuracy: 0.5270 Epoch 6/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6866 - accuracy: 0.5472 - val_loss: 0.6894 - val_accuracy: 0.5356 Epoch 7/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6845 - accuracy: 0.5563 - val_loss: 0.6880 - val_accuracy: 0.5416 Epoch 8/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6819 - accuracy: 0.5659 - val_loss: 0.6856 - val_accuracy: 0.5482 Epoch 9/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6777 - accuracy: 0.5828 - val_loss: 0.6807 - val_accuracy: 0.5726 Epoch 10/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6614 - accuracy: 0.6260 - val_loss: 0.6585 - val_accuracy: 0.6240 Epoch 11/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6407 - accuracy: 0.6653 - val_loss: 0.6425 - val_accuracy: 0.6586 Epoch 12/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6272 - accuracy: 0.6857 - val_loss: 0.6294 - val_accuracy: 0.6762 Epoch 13/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6140 - accuracy: 0.6982 - val_loss: 0.6197 - val_accuracy: 0.6848 Epoch 14/100 313/313 [==============================] - 6s 19ms/step - loss: 0.6000 - accuracy: 0.7131 - val_loss: 0.6080 - val_accuracy: 0.6944 Epoch 15/100 313/313 [==============================] - 6s 20ms/step - loss: 0.5878 - accuracy: 0.7224 - val_loss: 0.5960 - val_accuracy: 0.7098 Epoch 16/100 313/313 [==============================] - 6s 20ms/step - loss: 0.5754 - accuracy: 0.7329 - val_loss: 0.5864 - val_accuracy: 0.7126 Epoch 17/100 313/313 [==============================] - 6s 19ms/step - loss: 0.5646 - accuracy: 0.7390 - val_loss: 0.5739 - val_accuracy: 0.7266 Epoch 18/100 313/313 [==============================] - 6s 19ms/step - loss: 0.5548 - accuracy: 0.7442 - val_loss: 0.5647 - val_accuracy: 0.7326 Epoch 19/100 313/313 [==============================] - 6s 19ms/step - loss: 0.5445 - accuracy: 0.7503 - val_loss: 0.5544 - val_accuracy: 0.7398 Epoch 20/100 313/313 [==============================] - 6s 19ms/step - loss: 0.5343 - accuracy: 0.7566 - val_loss: 0.5461 - val_accuracy: 0.7418 Epoch 21/100 313/313 [==============================] - 6s 19ms/step - loss: 0.5270 - accuracy: 0.7597 - val_loss: 0.5383 - val_accuracy: 0.7500 Epoch 22/100 313/313 [==============================] - 6s 19ms/step - loss: 0.5164 - accuracy: 0.7650 - val_loss: 0.5329 - val_accuracy: 0.7472 Epoch 23/100 313/313 [==============================] - 6s 19ms/step - loss: 0.5098 - accuracy: 0.7697 - val_loss: 0.5286 - val_accuracy: 0.7484 Epoch 24/100 313/313 [==============================] - 6s 19ms/step - loss: 0.5025 - accuracy: 0.7732 - val_loss: 0.5224 - val_accuracy: 0.7556 Epoch 25/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4956 
- accuracy: 0.7763 - val_loss: 0.5184 - val_accuracy: 0.7576 Epoch 26/100 313/313 [==============================] - 6s 20ms/step - loss: 0.4886 - accuracy: 0.7793 - val_loss: 0.5140 - val_accuracy: 0.7568 Epoch 27/100 313/313 [==============================] - 6s 20ms/step - loss: 0.4847 - accuracy: 0.7824 - val_loss: 0.5085 - val_accuracy: 0.7634 Epoch 28/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4779 - accuracy: 0.7871 - val_loss: 0.5070 - val_accuracy: 0.7628 Epoch 29/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4743 - accuracy: 0.7897 - val_loss: 0.5017 - val_accuracy: 0.7634 Epoch 30/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4695 - accuracy: 0.7914 - val_loss: 0.5009 - val_accuracy: 0.7642 Epoch 31/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4660 - accuracy: 0.7937 - val_loss: 0.4960 - val_accuracy: 0.7676 Epoch 32/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4616 - accuracy: 0.7954 - val_loss: 0.4923 - val_accuracy: 0.7704 Epoch 33/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4576 - accuracy: 0.7983 - val_loss: 0.4911 - val_accuracy: 0.7730 Epoch 34/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4561 - accuracy: 0.7990 - val_loss: 0.4907 - val_accuracy: 0.7674 Epoch 35/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4518 - accuracy: 0.8011 - val_loss: 0.4896 - val_accuracy: 0.7720 Epoch 36/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4494 - accuracy: 0.8037 - val_loss: 0.4853 - val_accuracy: 0.7750 Epoch 37/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4465 - accuracy: 0.8042 - val_loss: 0.4820 - val_accuracy: 0.7752 Epoch 38/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4451 - accuracy: 0.8051 - val_loss: 0.4820 - val_accuracy: 0.7758 Epoch 39/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4430 - accuracy: 0.8057 - val_loss: 0.4798 - val_accuracy: 0.7760 Epoch 40/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4400 - accuracy: 0.8086 - val_loss: 0.4885 - val_accuracy: 0.7718 Epoch 41/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4384 - accuracy: 0.8079 - val_loss: 0.4762 - val_accuracy: 0.7792 Epoch 42/100 313/313 [==============================] - 6s 20ms/step - loss: 0.4369 - accuracy: 0.8111 - val_loss: 0.4828 - val_accuracy: 0.7770 Epoch 43/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4349 - accuracy: 0.8102 - val_loss: 0.4788 - val_accuracy: 0.7778 Epoch 44/100 313/313 [==============================] - 6s 20ms/step - loss: 0.4332 - accuracy: 0.8116 - val_loss: 0.4736 - val_accuracy: 0.7798 Epoch 45/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4320 - accuracy: 0.8110 - val_loss: 0.4804 - val_accuracy: 0.7808 Epoch 46/100 313/313 [==============================] - 6s 20ms/step - loss: 0.4293 - accuracy: 0.8140 - val_loss: 0.4710 - val_accuracy: 0.7800 Epoch 47/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4285 - accuracy: 0.8130 - val_loss: 0.4706 - val_accuracy: 0.7786 Epoch 48/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4272 - accuracy: 0.8146 - val_loss: 0.4700 - val_accuracy: 0.7784 Epoch 49/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4255 - accuracy: 0.8149 - val_loss: 0.4673 - val_accuracy: 0.7818 Epoch 50/100 
313/313 [==============================] - 6s 19ms/step - loss: 0.4239 - accuracy: 0.8164 - val_loss: 0.4662 - val_accuracy: 0.7838 Epoch 51/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4228 - accuracy: 0.8180 - val_loss: 0.4662 - val_accuracy: 0.7828 Epoch 52/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4210 - accuracy: 0.8183 - val_loss: 0.4652 - val_accuracy: 0.7832 Epoch 53/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4199 - accuracy: 0.8194 - val_loss: 0.4753 - val_accuracy: 0.7764 Epoch 54/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4187 - accuracy: 0.8198 - val_loss: 0.4635 - val_accuracy: 0.7842 Epoch 55/100 313/313 [==============================] - 6s 20ms/step - loss: 0.4173 - accuracy: 0.8210 - val_loss: 0.4615 - val_accuracy: 0.7856 Epoch 56/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4164 - accuracy: 0.8198 - val_loss: 0.4606 - val_accuracy: 0.7850 Epoch 57/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4147 - accuracy: 0.8201 - val_loss: 0.4596 - val_accuracy: 0.7874 Epoch 58/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4135 - accuracy: 0.8210 - val_loss: 0.4607 - val_accuracy: 0.7844 Epoch 59/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4121 - accuracy: 0.8224 - val_loss: 0.4631 - val_accuracy: 0.7850 Epoch 60/100 313/313 [==============================] - 6s 19ms/step - loss: 0.4111 - accuracy: 0.8227 - val_loss: 0.4604 - val_accuracy: 0.7856
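The 313/313 in the log above is simply the number of batches per epoch: 20,000 training samples divided by the batch size of 64, rounded up.

import math
print(math.ceil(20000 / 64))  # 313 batches per epoch, matching the log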
In [36]:
model_f = keras.models.load_model("./model/best_simpleRNN_model.h5")
In [37]:
# evaluate the restored model on the validation data
model_f.evaluate(val_oh, val_target)
In [38]:
history.history
Out[38]:
{'loss': [0.6988863348960876, 0.694591224193573, 0.691993772983551, 0.6900757551193237, 0.688349723815918, 0.6865505576133728, 0.6844810843467712, 0.6818762421607971, 0.6776968836784363, 0.6613922119140625, 0.6407214403152466, 0.6272233128547668, 0.6139527559280396, 0.6000173091888428, 0.5878065824508667, 0.5753841400146484, 0.5646007061004639, 0.5547738075256348, 0.544468343257904, 0.5343042612075806, 0.5270196795463562, 0.5164255499839783, 0.5098249316215515, 0.5025126338005066, 0.49561241269111633, 0.4885960519313812, 0.48474764823913574, 0.47787153720855713, 0.4742515981197357, 0.46946385502815247, 0.46600741147994995, 0.46163251996040344, 0.4576233923435211, 0.45611807703971863, 0.45180943608283997, 0.4494328200817108, 0.44650301337242126, 0.4451380968093872, 0.4429549276828766, 0.44003602862358093, 0.43837177753448486, 0.4369499087333679, 0.4349014163017273, 0.4331761300563812, 0.4320429265499115, 0.42930543422698975, 0.4284878075122833, 0.4272317886352539, 0.42552992701530457, 0.4238724112510681, 0.42280313372612, 0.4209577143192291, 0.41992518305778503, 0.4186874032020569, 0.417283296585083, 0.41643744707107544, 0.4146941304206848, 0.41347652673721313, 0.41209378838539124, 0.41113418340682983], 'accuracy': [0.5040500164031982, 0.5101500153541565, 0.5200499892234802, 0.5302000045776367, 0.5407999753952026, 0.547249972820282, 0.5562999844551086, 0.565850019454956, 0.5827500224113464, 0.6260499954223633, 0.6652500033378601, 0.685699999332428, 0.698199987411499, 0.7130500078201294, 0.7224000096321106, 0.73294997215271, 0.7390000224113464, 0.7442499995231628, 0.7502999901771545, 0.756600022315979, 0.7597000002861023, 0.7650499939918518, 0.7697499990463257, 0.7732499837875366, 0.7763000130653381, 0.7792500257492065, 0.7823500037193298, 0.7870500087738037, 0.789650022983551, 0.7914000153541565, 0.7936999797821045, 0.7954000234603882, 0.79830002784729, 0.7990000247955322, 0.8010500073432922, 0.8036500215530396, 0.8041999936103821, 0.8050500154495239, 0.8057000041007996, 0.8086000084877014, 0.8078500032424927, 0.8111000061035156, 0.8101500272750854, 0.8116000294685364, 0.8109999895095825, 0.8140000104904175, 0.8129500150680542, 0.8145999908447266, 0.8148999810218811, 0.8163999915122986, 0.8180000185966492, 0.8182500004768372, 0.819350004196167, 0.8197500109672546, 0.8210499882698059, 0.8198000192642212, 0.8201000094413757, 0.8210499882698059, 0.8224499821662903, 0.8226500153541565], 'val_loss': [0.697847306728363, 0.69490647315979, 0.6930205821990967, 0.6915863752365112, 0.6905214786529541, 0.6893666982650757, 0.6879813075065613, 0.6856301426887512, 0.6807393431663513, 0.6584729552268982, 0.6424593329429626, 0.6293952465057373, 0.6197000741958618, 0.6079766154289246, 0.5959619879722595, 0.5864059329032898, 0.5739489197731018, 0.5647232532501221, 0.5544484257698059, 0.5461227893829346, 0.5383333563804626, 0.5329364538192749, 0.5286375880241394, 0.5224370360374451, 0.518394947052002, 0.5140067934989929, 0.5084830522537231, 0.5070354342460632, 0.501668393611908, 0.5008770823478699, 0.49603110551834106, 0.492311030626297, 0.4910542368888855, 0.4906723201274872, 0.48963719606399536, 0.4852820932865143, 0.4819522500038147, 0.48204702138900757, 0.47984471917152405, 0.48854535818099976, 0.4762129783630371, 0.4827572703361511, 0.4787876605987549, 0.47356390953063965, 0.4803864657878876, 0.4710353910923004, 0.47056445479393005, 0.4700363278388977, 0.4673305153846741, 0.4661930203437805, 0.4662083089351654, 0.46522432565689087, 0.475283145904541, 0.46350961923599243, 0.4614780843257904, 
0.46060898900032043, 0.4596421420574188, 0.460712730884552, 0.46307188272476196, 0.4604252576828003], 'val_accuracy': [0.5037999749183655, 0.510200023651123, 0.5145999789237976, 0.5230000019073486, 0.5270000100135803, 0.5356000065803528, 0.5415999889373779, 0.5482000112533569, 0.5726000070571899, 0.6240000128746033, 0.6585999727249146, 0.6761999726295471, 0.6848000288009644, 0.6944000124931335, 0.7098000049591064, 0.7125999927520752, 0.7265999913215637, 0.7325999736785889, 0.739799976348877, 0.7418000102043152, 0.75, 0.7472000122070312, 0.7483999729156494, 0.7555999755859375, 0.7576000094413757, 0.7567999958992004, 0.7634000182151794, 0.7627999782562256, 0.7634000182151794, 0.76419997215271, 0.7675999999046326, 0.7703999876976013, 0.7730000019073486, 0.7674000263214111, 0.7720000147819519, 0.7749999761581421, 0.7752000093460083, 0.7757999897003174, 0.7760000228881836, 0.7717999815940857, 0.77920001745224, 0.7770000100135803, 0.7778000235557556, 0.7797999978065491, 0.7807999849319458, 0.7799999713897705, 0.7785999774932861, 0.7784000039100647, 0.7817999720573425, 0.7838000059127808, 0.782800018787384, 0.7832000255584717, 0.7764000296592712, 0.7842000126838684, 0.7856000065803528, 0.7850000262260437, 0.7874000072479248, 0.7843999862670898, 0.7850000262260437, 0.7856000065803528]}
In [39]:
"""훈련 및 검증에 대한 손실 곡선 그리기"""
plt.plot(history.epoch, history.history["loss"])
plt.plot(history.epoch, history.history["val_loss"])
plt.grid()
plt.legend(["loss","val_loss"])
plt.show()
In [40]:
"""훈련 및 검증에 대한 손실 정확도 그리기"""
plt.plot(history.epoch, history.history["accuracy"])
plt.plot(history.epoch, history.history["val_accuracy"])
plt.grid()
plt.legend(["accuracy","val_accuracy"])
plt.show()
Using the word-embedding approach¶
In [42]:
train_seq
Out[42]:
array([[ 10, 4, 20, ..., 10, 470, 158], [206, 2, 26, ..., 6, 2, 2], [ 2, 7, 2, ..., 2, 2, 12], ..., [ 2, 37, 299, ..., 7, 14, 2], [ 0, 0, 0, ..., 25, 170, 2], [ 0, 0, 0, ..., 25, 194, 2]])
In [81]:
"""모델 생성하기"""
model2 = keras.Sequential()
model2
Out[81]:
<keras.engine.sequential.Sequential at 0x1caade91a60>
In [82]:
"""계층 생성 및 추가하기
입력 계층 생성하기( 단어 임베딩 계층으로 생성)
-500: 말뭉치 갯수
-16: 출력크기 (갯수)
-input_length: 사용할 특성 갯수(input_shape와 동일)
"""
"""입력계층 생성하기 ( 단어 임베딩 계층으로 생성)"""
model2.add(keras.layers.Embedding(500,16, input_length=100))
"""simple RNN 계층 추가"""
model2.add(keras.layers.SimpleRNN(8))
"""출력계층 추가"""
model2.add(keras.layers.Dense(1, activation="sigmoid"))
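A quick check I am adding (not in the original run) of what the Embedding layer does: each of the 100 integer tokens in a sample is looked up as a 16-dimensional vector.

emb_out = model2.layers[0](train_seq[:2])  # look up two padded samples
print(emb_out.shape)                       # (2, 100, 16)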
In [83]:
model2.summary()
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 embedding_3 (Embedding)     (None, 100, 16)           8000
 simple_rnn_4 (SimpleRNN)    (None, 8)                 200
 dense_4 (Dense)             (None, 1)                 9
=================================================================
Total params: 8,209
Trainable params: 8,209
Non-trainable params: 0
_________________________________________________________________
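Again the counts can be verified by hand (my own arithmetic): the embedding table is 500 x 16, and the SimpleRNN now receives 16-dimensional vectors instead of 500-dimensional one-hot vectors, which is where the parameter savings in the recurrent layer come from.

print(500 * 16, 16 * 8 + 8 * 8 + 8, 8 * 1 + 1)  # 8000 200 9, matching the summary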
In [84]:
# rmsprop = keras.optimizers.RMSprop(learning_rate=0.0001)
model2.compile(optimizer=rmsprop, loss="binary_crossentropy",
               metrics="accuracy")
In [85]:
train_seq
Out[85]:
array([[ 10, 4, 20, ..., 10, 470, 158], [206, 2, 26, ..., 6, 2, 2], [ 2, 7, 2, ..., 2, 2, 12], ..., [ 2, 37, 299, ..., 7, 14, 2], [ 0, 0, 0, ..., 25, 170, 2], [ 0, 0, 0, ..., 25, 194, 2]])
In [86]:
val_seq
Out[86]:
array([[ 32, 2, 225, ..., 14, 58, 2], [ 53, 2, 8, ..., 7, 32, 2], [ 0, 0, 0, ..., 2, 33, 32], ..., [383, 2, 120, ..., 16, 99, 76], [106, 345, 12, ..., 120, 2, 156], [ 4, 114, 21, ..., 4, 2, 2]])
In [87]:
"""모델 훈련 후 손실 및 정확도 시각화 까지 해주세요"""
history = model2.fit(train_seq,train_target, epochs=100, batch_size=64,validation_data=(val_seq, val_target),
callbacks=[checkpoint_cb, ealry_stopping_cb])
Epoch 1/100 313/313 [==============================] - 4s 11ms/step - loss: 0.6957 - accuracy: 0.5072 - val_loss: 0.6945 - val_accuracy: 0.5054 Epoch 2/100 313/313 [==============================] - 3s 10ms/step - loss: 0.6864 - accuracy: 0.5531 - val_loss: 0.6772 - val_accuracy: 0.6120 Epoch 3/100 313/313 [==============================] - 3s 10ms/step - loss: 0.6639 - accuracy: 0.6537 - val_loss: 0.6558 - val_accuracy: 0.6762 Epoch 4/100 313/313 [==============================] - 3s 10ms/step - loss: 0.6411 - accuracy: 0.7019 - val_loss: 0.6386 - val_accuracy: 0.6944 Epoch 5/100 313/313 [==============================] - 3s 10ms/step - loss: 0.6198 - accuracy: 0.7281 - val_loss: 0.6172 - val_accuracy: 0.7320 Epoch 6/100 313/313 [==============================] - 3s 10ms/step - loss: 0.5998 - accuracy: 0.7472 - val_loss: 0.6004 - val_accuracy: 0.7324 Epoch 7/100 313/313 [==============================] - 3s 10ms/step - loss: 0.5816 - accuracy: 0.7583 - val_loss: 0.5838 - val_accuracy: 0.7512 Epoch 8/100 313/313 [==============================] - 3s 10ms/step - loss: 0.5647 - accuracy: 0.7696 - val_loss: 0.5733 - val_accuracy: 0.7494 Epoch 9/100 313/313 [==============================] - 3s 10ms/step - loss: 0.5485 - accuracy: 0.7745 - val_loss: 0.5547 - val_accuracy: 0.7654 Epoch 10/100 313/313 [==============================] - 3s 10ms/step - loss: 0.5339 - accuracy: 0.7810 - val_loss: 0.5415 - val_accuracy: 0.7702 Epoch 11/100 313/313 [==============================] - 3s 11ms/step - loss: 0.5194 - accuracy: 0.7883 - val_loss: 0.5314 - val_accuracy: 0.7662 Epoch 12/100 313/313 [==============================] - 3s 11ms/step - loss: 0.5064 - accuracy: 0.7918 - val_loss: 0.5211 - val_accuracy: 0.7686 Epoch 13/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4953 - accuracy: 0.7961 - val_loss: 0.5131 - val_accuracy: 0.7728 Epoch 14/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4846 - accuracy: 0.7987 - val_loss: 0.5095 - val_accuracy: 0.7698 Epoch 15/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4754 - accuracy: 0.8033 - val_loss: 0.5180 - val_accuracy: 0.7548 Epoch 16/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4675 - accuracy: 0.8049 - val_loss: 0.4956 - val_accuracy: 0.7714 Epoch 17/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4592 - accuracy: 0.8079 - val_loss: 0.4889 - val_accuracy: 0.7762 Epoch 18/100 313/313 [==============================] - 3s 10ms/step - loss: 0.4530 - accuracy: 0.8084 - val_loss: 0.4890 - val_accuracy: 0.7748 Epoch 19/100 313/313 [==============================] - 3s 10ms/step - loss: 0.4464 - accuracy: 0.8138 - val_loss: 0.4844 - val_accuracy: 0.7752 Epoch 20/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4412 - accuracy: 0.8122 - val_loss: 0.4799 - val_accuracy: 0.7768 Epoch 21/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4360 - accuracy: 0.8152 - val_loss: 0.4783 - val_accuracy: 0.7774 Epoch 22/100 313/313 [==============================] - 4s 11ms/step - loss: 0.4317 - accuracy: 0.8165 - val_loss: 0.4810 - val_accuracy: 0.7740 Epoch 23/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4275 - accuracy: 0.8195 - val_loss: 0.4766 - val_accuracy: 0.7786 Epoch 24/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4238 - accuracy: 0.8218 - val_loss: 0.4761 - val_accuracy: 0.7804 Epoch 25/100 313/313 [==============================] - 3s 10ms/step - loss: 0.4206 
- accuracy: 0.8224 - val_loss: 0.4747 - val_accuracy: 0.7800 Epoch 26/100 313/313 [==============================] - 3s 10ms/step - loss: 0.4168 - accuracy: 0.8238 - val_loss: 0.4770 - val_accuracy: 0.7784 Epoch 27/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4147 - accuracy: 0.8252 - val_loss: 0.4732 - val_accuracy: 0.7804 Epoch 28/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4121 - accuracy: 0.8262 - val_loss: 0.4724 - val_accuracy: 0.7776 Epoch 29/100 313/313 [==============================] - 3s 10ms/step - loss: 0.4100 - accuracy: 0.8268 - val_loss: 0.4754 - val_accuracy: 0.7774 Epoch 30/100 313/313 [==============================] - 3s 11ms/step - loss: 0.4072 - accuracy: 0.8295 - val_loss: 0.4741 - val_accuracy: 0.7822 Epoch 31/100 313/313 [==============================] - 3s 10ms/step - loss: 0.4054 - accuracy: 0.8278 - val_loss: 0.4777 - val_accuracy: 0.7790
In [91]:
"""훈련 및 검증에 대한 손실 곡선 그리기"""
plt.plot(history.epoch, history.history["loss"])
plt.plot(history.epoch, history.history["val_loss"])
plt.grid()
plt.legend(["loss","val_loss"])
plt.show()
In [89]:
"""훈련 및 검증에 대한 손실 정확도 그리기"""
plt.plot(history.epoch, history.history["accuracy"])
plt.plot(history.epoch, history.history["val_accuracy"])
plt.grid()
plt.legend(["accuracy","val_accuracy"])
plt.show()
Performance evaluation¶
In [92]:
"""원-핫 데이터 모델로 검증데이터 평가하기"""
model.evaluate(val_oh, val_target)
157/157 [==============================] - 2s 8ms/step - loss: 0.4596 - accuracy: 0.7874
Out[92]:
[0.4596419334411621, 0.7874000072479248]
In [93]:
"""임베딩 모델로 검증데이터 평가하기"""
model2.evaluate(val_seq, val_target)
157/157 [==============================] - 0s 2ms/step - loss: 0.4724 - accuracy: 0.7776
Out[93]:
[0.47240811586380005, 0.7775999903678894]
Evaluating the one-hot model and the word-embedding model on the test data¶
In [100]:
"""테스트 데이터 100개 길이로 통일(정규화)하기"""
test_seq = pad_sequences(test_input, maxlen=100, truncating="pre", padding="pre")
"""원핫 모델에 사용할 데이터 생성"""
test_oh = keras.utils.to_categorical(test_seq)
In [97]:
"""원 핫 인코딩 모델 검증"""
model.evaluate(test_oh, test_target)
782/782 [==============================] - 8s 7ms/step - loss: 0.4588 - accuracy: 0.7886
Out[97]:
[0.45879143476486206, 0.7885599732398987]
In [98]:
"""단어 임베딩 모델 검증"""
model2.evaluate(test_seq, test_target)
782/782 [==============================] - 2s 2ms/step - loss: 0.4657 - accuracy: 0.7850
Out[98]:
[0.4657423794269562, 0.7850000262260437]
In [102]:
y_pred = model.predict(test_oh)
# convert predicted probabilities to 0/1 class labels with a 0.5 threshold
binary_pred = (y_pred > 0.5).astype(int)
binary_pred
Out[102]:
array([[0], [0], [1], ..., [0], [0], [1]])
In [104]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
# scikit-learn expects (y_true, y_pred) in that order
precision = precision_score(test_target, binary_pred)
recall = recall_score(test_target, binary_pred)
f1 = f1_score(test_target, binary_pred)
cm = confusion_matrix(test_target, binary_pred)
print(precision, recall, f1)
print(cm)
0.7877313337587747 0.79 0.7888640357884645
[[9839 2661]
 [2625 9875]]
In [105]:
import seaborn as sns
plt.title("Confusion matrix for the one-hot encoding model")
sns.heatmap(cm, annot=True, fmt="d",cmap= "Blues",cbar=False)
plt.show()