[튜토리얼2] 불균형 데이터 분류하기

이 튜토리얼에서는 한 클래스의 데이터 수가 다른 클래스의 데이터 수보다 훨씬 많은 불균형 데이터셋을 분류하는 방법을 보여 줍니다. Kaggle에서 호스팅되는 신용 카드 사기 탐지 데이터셋을 사용하겠습니다. 해당 데이터셋은 총 284,807건의 거래에서 492건의 부정거래를 적발하는 것을 목표로 하고 있습니다. 케라스(Keras) API를 사용하여 모델을 정의하고 class weight를 사용하여 모델이 불균형 데이터로부터 잘 학습하도록 해보겠습니다.

이번 튜토리얼은 다음의 과정을 포함합니다.

import warnings
warnings.simplefilter('ignore')

import tensorflow as tf
from tensorflow import keras

import os
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

목차

  1. 데이터 전처리와 탐색
    • 1.1 캐글 신용 카드 사기 데이터셋 다운로드하기
    • 1.2 클래스 레이블 불균형 검사하기
    • 1.3 데이터를 정제하고 나누고 정규화하기
    • 1.4 데이터 분포 확인하기
  2. 모델과 메트릭스 정의하기
    • 2.1 유용한 메트릭스 이해하기
  3. 베이스라인 모델
    • 3.1 모델 구축하기
    • 3.2 선택사항: 편향(bias)을 알맞게 초기화합니다.
    • 3.3 초기 가중치 체크포인트 만들기
    • 3.4 편향 수정이 도움이 되는지 확인하기
    • 3.5 모델 학습시키기
    • 3.6 학습 히스토리 확인하기
    • 3.7 평가 메트릭스
    • 3.8 ROC 그리기
  4. 클래스 가중치
    • 4.1 클래스 가중치 계산하기
    • 4.2 클래스 가중치로 모델 학습시키기
    • 4.3 학습 히스토리 확인하기
    • 4.4 평가 메트릭스
    • 4.5 ROC 그리기
  5. 오버샘플링(Oversampling)
    • 5.1 소수 클래스 오버샘플링하기
      • 넘파이(NumPy) 사용하기
      • tf.data 사용하기
    • 5.2 오버샘플링한 데이터 학습시키기
    • 5.3 학습 히스토리 확인하기
    • 5.4 재학습시키기
    • 5.5 학습 히스토리 다시 확인하기
    • 5.6 평가 메트릭스
    • 5.7 ROC 그리기
  6. 주어진 문제에 이 튜토리얼 적용하기

1. 데이터 전처리와 탐색하기

1.1 캐글 신용 카드 사기 데이터셋 다운로드하기

판다스(Pandas)는 구조화된 데이터를 불러오고 작업하는 데 유용한 기능들을 포함하고 있는 파이썬 라이브러리이며, CSV를 데이터 프레임으로 다운로드하는 데 사용할 수 있습니다.

file = tf.keras.utils
raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv')
raw_df.head()

.describe() 메소드를 사용하여 각 피쳐에 대한 통계를 살펴보겠습니다.

raw_df[['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V26', 'V27', 'V28', 'Amount', 'Class']].describe()

1.2 클래스 레이블 불균형 검사하기

데이터셋의 불균형을 살펴봅시다. Class 피쳐 값이 1이면 부정 거래, 0이면 정상 거래를 의미합니다.:

neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('데이터 :\n    전체: {}개\n    부정 거래: {}개 (전체의 {:.2f}%)\n'.format(
    total, pos, 100 * pos / total))

전체 중 약 0.17% 정도가 부정 거래로, 매우 불균형한 데이터셋이라는 것을 확인할 수 있습니다.

1.3 데이터를 정제하고 나누고 정규화하기

아무런 처리를 하지 않은 raw 데이터에는 다음과 같은 몇 가지 문제가 있습니다. 먼저 TimeAmount 피쳐는 너무 가변적이어서 직접 사용할 수 없다는 점입니다. 따라서 필요 없는 Time 피쳐를 드랍하고 범위가 너무 넓은 Amount 피쳐에 로그를 취해 범위를 줄입니다.

cleaned_df = raw_df.copy()

# `Time` 열은 필요없습니다.
cleaned_df.pop('Time')

# `Amount` 열은 범위가 매우 넓습니다. log-space로 바꿔줍니다.
eps=0.001 # 0 => 0.1¢
cleaned_df['Log Ammount'] = np.log(cleaned_df.pop('Amount')+eps)

데이터셋을 학습, 검증 그리고 테스트 셋으로 분할합니다. 검증 데이터셋은 모델에 적용하는 과정에서 손실과 메트릭스를 평가하는 데 사용되지만 모델에 직접 검증 데이터를 학습시키지는 않습니다. 테스트 데이터셋은 학습 과정에서 전혀 사용되지 않으며, 모델이 새로운 데이터에 얼마나 일반화되는지 평가하기 위해 마지막 단계에서만 사용됩니다. 이는 학습 데이터의 부족으로 인해 과대적합(오버피팅) 이 크게 우려되는 불균형 데이터셋에서 특히 중요합니다.

# sklearn에서 유틸리티를 가져와 데이터셋을 분할하고 섞는 데 사용합니다.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# 레이블과 피쳐를 넘파이 배열로 바꿔줍니다.
train_labels = np.array(train_df.pop('Class'))
bool_train_labels = train_labels != 0
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

sklearnStandardScaler를 사용하여 입력 피쳐를 정규화합니다. 정규화를 통해 평균은 0으로, 표준 편차는 1로 설정됩니다.

scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('학습 레이블셋 shape:', train_labels.shape)
print('검증 레이블셋 shape:', val_labels.shape)
print('테스트 레이블셋 shape:', test_labels.shape)
print()
print('학습 피쳐셋 shape:', train_features.shape)
print('검증 피쳐셋 shape:', val_features.shape)
print('테스트 피쳐셋 shape:', test_features.shape)

1.4 데이터 분포 확인하기

그런 다음 몇 가지 피쳐에 대해 부정 거래와 정상 거래 데이터의 분포를 비교하여 확인해보겠습니다. 이 시점에서 확인해 볼 수 있는 좋은 질문은 다음과 같습니다.

pos_df = pd.DataFrame(train_features[ bool_train_labels], columns = train_df.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns = train_df.columns)

sns.jointplot(pos_df['V5'], pos_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
plt.suptitle("class=1 distribution")

sns.jointplot(neg_df['V5'], neg_df['V6'],
              kind='hex', xlim = (-5,5), ylim = (-5,5))
_ = plt.suptitle("class=0 distribution")

2. 모델과 메트릭스 정의하기

밀도 있게 연결된 은닉 레이어(Dense Layer), 과대적합(오버피팅)을 줄이기 위한 드롭아웃 레이어, 트랜잭션의 부정확성을 반환하는 시그모이드 출력 레이어로 구성된 단순한 신경망을 생성하는 함수를 정의합니다.

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'), 
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

def make_model(metrics = METRICS, output_bias=None):
    if output_bias is not None:
        output_bias = tf.keras.initializers.Constant(output_bias)
    model = keras.Sequential([
        keras.layers.Dense(
            16, activation='relu',
            input_shape=(train_features.shape[-1],)),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(1, activation='sigmoid',
                         bias_initializer=output_bias),
    ])

    model.compile(
        optimizer=keras.optimizers.Adam(lr=1e-3),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=metrics)

    return model

2.1 유용한 메트릭스 이해하기

아래와 같이 정의된 몇 가지 메트릭스는 모델을 통해 계산할 수 있으며 모델의 성능을 평가할 때 유용합니다.

$Accuracy = \frac{\text{true samples}}{\text{total samples}}$

$Precision = \frac{\text{true positives}}{\text{true positives + false positives}}$

$Recall = \frac{\text{true positives}}{\text{true positives + false negatives}}$

참고: 만약 모델이 모든 예측 대상 데이터에 대해 항상 정상 거래(class=0)라고 예측한다면, 모델의 정확도를 99.8% 이상으로 높일 수 있습니다.따라서 정확도는 불균형 데이터셋 분류에는 유용한 메트릭스가 아닙니다.

3. 베이스라인 모델

3.1 모델 구축하기

이제 앞에서 정의한 함수를 사용하여 모델을 만들고 학습해보겠습니다. 모델은 기본 배치 크기인 2048보다 큰 배치 크기를 사용하는 것이 적합합니다. 각 배치에서 최소 몇 개 이상의 부정 거래(class=1)를 포함하도록 하는 것이 중요합니다. 배치 크기가 너무 작으면 학습할 수 있는 부정 거래 예시가 없을 수 있습니다.

EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', 
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
model = make_model()
model.summary()

모델을 실행하여 테스트해봅시다:

model.predict(train_features[:10])

3.2 선택사항: 편향(bias)을 알맞게 초기화합니다.

초기 수렴을 돕기 위하여 불균형 데이터를 반영하여 편향을 초기화할 수 있습니다. 디폴트로 편향을 초기화한 것에서는 손실이 약 math.log(2) = 0.69314 정도입니다.

results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))

아래의 공식을 통해 편향을 올바르게 설정합니다.

여기서 $pos$란 class값이 1인 경우 즉, 부정 거래인 경우를 의미하고 $neg$는 class값이 0인 경우 즉, 정상 거래를 의미합니다.:

\(p_0 = pos/(pos + neg) = 1/(1+e^{-b_0})\) \(b_0 = -log_e(1/p_0 - 1)\) \(b_0 = log_e(pos/neg)\)

initial_bias = np.log([pos/neg])
initial_bias

이를 초기 편향으로 설정하면 모델은 훨씬 더 합리적인 초기 추측을 할 수 있을 것입니다.

pos/total = 0.0018에 가까울 것입니다.

model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])

이 초기화에서의 초기 손실은 대략 다음과 같아야 합니다:

\[-p_0log(p_0)-(1-p_0)log(1-p_0) = 0.01317\]
results = model.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print("Loss: {:0.4f}".format(results[0]))

위와 같은 초기 손실은 디폴트로 초기화했을 때보다 약 50배 더 작습니다.

이런 식으로 편향을 초기화할 경우, 모델은 긍정인 데이터(부정 거래)가 거의 없다는 것을 학습하는 데 처음 몇 에포크를 보낼 필요가 없습니다. 이것은 또한 학습 중에 손실된 부분을 더 쉽게 파악할 수 있게 해줍니다.

3.3 초기 가중치 체크포인트 만들기

다양한 학습 과정을 비교하기 위해 초기 모델의 가중치를 체크포인트 파일에 보관하고 학습 전에 각 모델에 로드합니다.

initial_weights = os.path.join(tempfile.mkdtemp(),'initial_weights')
model.save_weights(initial_weights)

3.4 편향 수정이 도움이 되는지 확인하기

다음 단계로 넘어가기 전에 세심한 편향의 초기화가 실제로 도움이 되었는지 빠르게 확인해보겠습니다.

세심한 초기화가 있는 모델과 없는 모델을 20 에포크 동안 학습시키고, 두 모델의 손실을 비교합니다.

model = make_model()
model.load_weights(initial_weights)
model.layers[-1].bias.assign([0.0])
zero_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
model = make_model()
model.load_weights(initial_weights)
careful_bias_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_data=(val_features, val_labels), 
    verbose=0)
def plot_loss(history, label, n):
    plt.semilogy(history.epoch,  history.history['loss'],
               color=colors[n], label='Train '+label)
    plt.semilogy(history.epoch,  history.history['val_loss'],
          color=colors[n], label='Val '+label,
          linestyle="--")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
  
    plt.legend()
plot_loss(zero_bias_history, "Zero Bias", 0)
plot_loss(careful_bias_history, "Careful Bias", 1)

위의 그림에서 명확히 알 수 있듯이 검증 손실 측면에서 세심한 초기화는 분명한 이점을 가지고 있습니다.

3.5 모델 학습시키기

model = make_model()
model.load_weights(initial_weights)
baseline_history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels))

3.6 학습 히스토리 확인하기

이 섹션에서는 학습 및 검증 데이터셋에 대한 모델의 정확도와 손실을 보여 줍니다.

또한 위에서 생성한 다양한 메트리스에 대해 아래와 같은 그림을 생성할 수 있습니다.

def plot_metrics(history):
    metrics =  ['loss', 'auc', 'precision', 'recall']
    for n, metric in enumerate(metrics):
        name = metric.replace("_"," ").capitalize()
        plt.subplot(2,2,n+1)
        plt.plot(history.epoch,  history.history[metric], color=colors[0], label='Train')
        plt.plot(history.epoch, history.history['val_'+metric],
             color=colors[0], linestyle="--", label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        if metric == 'loss':
            plt.ylim([0, plt.ylim()[1]])
        elif metric == 'auc':
            plt.ylim([0.8,1])
        else:
            plt.ylim([0,1])

    plt.legend()
plot_metrics(baseline_history)

참고: 검증 곡선은 일반적으로 학습 곡선보다 성능이 우수합니다. 이는 보통 모델을 평가할 때는 드롭아웃(Dropout) 레이어가 활성화되지 않기 때문입니다.

3.7 평가 메트릭스

혼동 행렬(confusion matrix)을 사용하여 실제 레이블 대 예측 레이블을 요약할 수 있습니다. 여기서 X 축은 예측 레이블이고 Y 축은 실제 레이블입니다.

train_predictions_baseline = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model.predict(test_features, batch_size=BATCH_SIZE)
def plot_cm(labels, predictions, p=0.5):
    cm = confusion_matrix(labels, predictions > p)
    plt.figure(figsize=(5,5))
    sns.heatmap(cm, annot=True, fmt="d")
    plt.title('Confusion matrix @{:.2f}'.format(p))
    plt.ylabel('Actual label')
    plt.xlabel('Predicted label')

    print('정상 거래를 잘 예측한 경우 (True Negatives): ', cm[0][0])
    print('정상 거래를 부정 거래라고 잘못 예측한 경우 (False Positives): ', cm[0][1])
    print('부정 거래를 정상 거래라고 잘못 예측한 경우 (False Negatives): ', cm[1][0])
    print('부정 거래를 잘 예측한 경우 (True Positives): ', cm[1][1])
    print('전체 부정 거래: ', np.sum(cm[1]))

테스트 데이터셋에서 모델을 평가하고 위에서 생성한 메트릭스의 결과를 확인합니다.

baseline_results = model.evaluate(test_features, test_labels,
                                  batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, baseline_results):
    print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_baseline)

모델이 모든 것을 완벽하게 예측했다면 잘못된 예측을 나타내는 대각행렬(Diagonal matrix)의 대각선 값은 0이 됩니다. 이 경우 매트릭스에 거짓 긍정(false positive)이 상대적으로 적다는 것이 나타나며, 이는 부정 거래로 잘못 예측한 정상 거래가 상대적으로 적다는 것을 의미합니다.

하지만, 여러분은 거짓 긍정(false positive)수를 증가시킬 수 있는 가능성에도 불구하고 거짓 부정(false negative)의 수를 훨씬 더 적게 갖고 싶어할 것입니다. 거짓 부정으로 인해 부정 거래가 발생할 수 있는 반면, 거짓 긍정으로 인해 고객에게 카드 활동을 확인하라는 이메일을 보낼 수 있기 때문에 거짓 부정을 줄이는 것이 더 바람직할 수 있기 때문입니다.

3.8 ROC 그리기

이제 ROC를 그립니다. 이 그림은 출력 임계값을 조정하는 것만으로 모델이 도달할 수 있는 성능 범위를 한눈에 보여 주기 때문에 유용합니다.

def plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

    plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
    plt.xlabel('False positives [%]')
    plt.ylabel('True positives [%]')
    plt.xlim([-0.5,20])
    plt.ylim([80,100.5])
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')
plt.legend(loc='lower right')

정밀도(precision)는 비교적 높은 것 같지만, 재현율(Recall)과 ROC 곡선 아래의 영역(AUC)은 원하는 만큼 높지 않습니다. 클래스는 정밀도와 재현율을 모두 최대화하려고 할 때 어려움을 겪는 경우가 많은데, 이는 불균형 데이터셋으로 작업할 때 특히 그렇습니다.

관심 있는 문제의 맥락에서 다양한 유형의 오류에 대한 비용을 고려하는 것이 중요합니다. 이 예에서 거짓 음성(부정 거래 누락)은 금전적 비용이 들 수 있는 반면, 거짓 긍정(긍정 거래를 부정 거래로 잘못 예측함)은 사용자의 충성도를 감소시킬 수 있습니다.

4. 클래스 가중치

4.1 클래스 가중치 계산하기

목표는 부정 거래를 식별하는 것이지만, 실제로 사용할 수 있는 양성 샘플이 많지 않기 때문에, 존재하는 몇 가지 양성 샘플 예제를 가능한 크게 가중시킬수도 있습니다.

각 클래스의 케라스 가중치를 매개변수를 통해 전달하여 이 작업을 수행할 수 있는데, 이는 모델이 더 적은 클래스의 예제에 “더 많은 관심을 기울일” 수 있게 하는 것입니다.

# total/2 단위로 확장하면 손실 규모도 비슷한 수준으로 유지할 수 있습니다.
# 모든 예제의 가중치 합계는 동일합니다.
weight_for_0 = (1 / neg)*(total)/2.0 
weight_for_1 = (1 / pos)*(total)/2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))

4.2 클래스 가중치로 모델 학습시키기

이제 클래스 가중치로 모델을 재학습시키고 평가하여 예측에 어떤 영향을 미치는지 확인해보겠습니다.

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels),
    # 클래스 가중치는 여기서 결정됩니다.
    class_weight=class_weight) 

4.3 학습 히스토리 확인하기

plot_metrics(weighted_history)

4.4 평가 메트릭스

train_predictions_weighted = weighted_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weighted = weighted_model.predict(test_features, batch_size=BATCH_SIZE)
weighted_results = weighted_model.evaluate(test_features, test_labels,
                                           batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(weighted_model.metrics_names, weighted_results):
    print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)

클래스 가중치를 적용할 경우, 정확도(Accuracy)와 정밀도(Precision)가 이전의 결과보다 더 낮다는 것을 알 수 있습니다. 이는 거짓 긍정(false positive)이 더 많기 때문입니다. 하지만 반대로 모델에서 더 많은 실제 긍정(true positive, 90개)을 발견했기 때문에 재현성(Recall)과 AUC는 더 높습니다. 이 모델은 더 많은 부정 거래를 식별하였기 때문에 정확도는 낮지만 재현성이 더 높다고 표현할 수 있습니다. 물론 두 가지 유형의 오류에 모두 비용이 발생하기 때문에 다양한 유형의 오류 간의 절충을 신중하게 고려해야합니다.

4.5 ROC 그리기

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')


plt.legend(loc='lower right')

5. 오버샘플링(Oversampling)

5.1 소수 클래스 오버샘플링하기

오버샘플링이란, 불균형 데이터셋에서 자주 사용되는 샘플링 방법으로, 소수 클래스를 오버샘플링하여 데이터셋을 다시 샘플링하는 것입니다.

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

- 넘파이(NumPy) 사용하기

부정 거래 데이터에서 적당한 수의 인덱스를 랜덤으로 선택하여 데이터셋의 균형을 수동으로 조정할 수 있습니다.

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape

- tf.data 사용하기

tf.data를 사용하여 균형 잡힌 데이터를 만드는 가장 쉬운 방법은 positivenegative 데이터셋으로 시작하여 이들을 병합하는 것입니다.

BUFFER_SIZE = 100000

def make_ds(features, labels):
    ds = tf.data.Dataset.from_tensor_slices((features, labels))#.cache()
    ds = ds.shuffle(BUFFER_SIZE).repeat()
    return ds

pos_ds = make_ds(pos_features, pos_labels)
neg_ds = make_ds(neg_features, neg_labels)

각 데이터셋은 (feature, label)으로 되어있습니다:

for features, label in pos_ds.take(1):
    print("Features:\n", features.numpy())
    print()
    print("Label: ", label.numpy())

experimental.sample_from_datasets를 사용해서 이 둘을 합칩니다:

resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
for features, label in resampled_ds.take(1):
    print(label.numpy().mean())

이 데이터셋을 사용하려면 epoch당 스텝 수 설정이 필요합니다.

이 경우 “epoch”에 대한 정의는 덜 명확하기 때문에 각 부정적인 예를 한 번 보는 데 필요한 배치의 수라고 합시다:

resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch

5.2 오버샘플링한 데이터 학습시키기

이제 클래스 가중치를 사용하는 대신 오버샘플링 한 데이터셋으로 모델을 학습시켜 비교해봅시다.

참고: 긍정 예제를 복제하여 데이터의 균형을 유지했기 때문에 총 데이터셋 크기가 더 크고 각 에포크에서 더 많은 스텝으로 학습합니다.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# 이 데이터셋은 균형을 맞췄기 때문에 편향은 0으로 설정합니다.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(2) 

resampled_history = resampled_model.fit(
    resampled_ds,
    epochs=EPOCHS,
    steps_per_epoch=resampled_steps_per_epoch,
    callbacks = [early_stopping],
    validation_data=val_ds)

각 그래디언트를 업데이트할 때 학습 과정이 전체 데이터셋을 고려하는 경우, 이 오버샘플링은 기본적으로 클래스 가중치와 동일합니다.

그러나 여기서와 같이 배치별로 모델을 학습시킬 때, 오버 샘플링된 데이터는 보다 부드러운 그래디언트 신호를 제공합니다. 각각의 긍정인 예가 하나의 배치에서 큰 가중치를 가지지 않고, 매번 여러 다른 배치에서 작은 가중치를 가지기 때문입니다.

이렇게 부드러운 그래디언트 신호를 통해 모델을 보다 쉽게 학습시킬 수 있습니다.

5.3 학습 히스토리 확인하기

여기서는 학습 데이터가 검증 및 테스트 데이터와 전혀 다른 분포를 가지기 때문에 메트릭스의 분포가 달라집니다.

plot_metrics(resampled_history )

5.4 재학습시키기

균형 잡힌 데이터에 대한 학습이 더 쉽기 때문에 위의 학습 절차가 빠르게 오버피팅될 수 있습니다.

callbacks.EarlyStopping로 언제 훈련을 중단해야 하는지를 통제하여 에포크를 끝냅니다.

resampled_model = make_model()
resampled_model.load_weights(initial_weights)

# 이 데이터셋은 균형을 맞췄으므로 편형은 0으로 설정합니다.
output_layer = resampled_model.layers[-1] 
output_layer.bias.assign([0])

resampled_history = resampled_model.fit(
    resampled_ds,
    # 이는 실제 에포크가 아닙니다.
    steps_per_epoch = 20,
    epochs=10*EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_ds))

5.5 학습 히스토리 다시 확인하기

plot_metrics(resampled_history)

5.6 평가 메트릭스

train_predictions_resampled = resampled_model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_resampled = resampled_model.predict(test_features, batch_size=BATCH_SIZE)
resampled_results = resampled_model.evaluate(test_features, test_labels,
                                             batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(resampled_model.metrics_names, resampled_results):
    print(name, ': ', value)
print()

plot_cm(test_labels, test_predictions_weighted)

5.7 ROC 그리기

plot_roc("Train Baseline", train_labels, train_predictions_baseline, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0], linestyle='--')

plot_roc("Train Weighted", train_labels, train_predictions_weighted, color=colors[1])
plot_roc("Test Weighted", test_labels, test_predictions_weighted, color=colors[1], linestyle='--')

plot_roc("Train Resampled", train_labels, train_predictions_resampled,  color=colors[2])
plot_roc("Test Resampled", test_labels, test_predictions_resampled,  color=colors[2], linestyle='--')
plt.legend(loc='lower right')

6. 주어진 문제에 이 튜토리얼 적용하기

불균형 데이터 분류는 소수 클래스에 해당하는 학습할 샘플이 거의 없기 때문에 모델 학습에 데이터를 적용할 때 항상 주의해야합니다. 항상 데이터부터 시작하여 가능한 한 많은 샘플을 수집하고, 모델이 소수 클래스를 최대한 활용할 수 있도록 어떤 피쳐가 관련되어 있는지 충분히 숙고해야 합니다. 어느 시점에서는 모델이 원하는 결과를 개선하고 산출하는 데 어려움을 겪을 수 있으므로 문제의 문맥과 다른 유형의 오류 간에 균형을 유지하는 것이 중요합니다.

Copyright 2019 The TensorFlow Authors.

#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.