본문 바로가기
ML&DL/Recommender System

[Recommender System / Paper review] #24 DeepFM: A Factorization-Machine based Neural Network for CTR Prediction

by 거북이주인장 2023. 4. 25.

Summary

  • 구글에서 발표한 wide & deep neural network 논문에서 linear 특성을 반영하는 wide 부분을 factorization machine으로 변경한 모형을 제시한다.
  • 구글 논문은 단순히 변수에 weight을 줘서 합하는 형태였다면, deepfm은 fm에서 제안한 방법을 사용한다. 즉, higher order interaction을 latent vector의 내적을 통해 모형에 포함한다.
  • 중요한 부분은 fm에서 사용하는 변수의 dense latent feature을 deep neural network에서 범주형 변수를 임베딩 벡터로 만드는 과정에서 사용한다는 점이다.
  • 이를 통해 low & high interaction의 효과를 모두 모형에 포함한다.

Motivation

  • 추천 시스템 분야에서는 아이템에 대한 ctr 예측 값을 기준으로 랭킹을 매기고 이 랭킹 결과가 유저에게 추천된다. 따라서 양질의 피쳐를 통해 더 정확한 ctr을 예측하는 알고리즘을 만드는 것이 중요한데, 이때 보통 피쳐는 매우 복잡한 형태를 띄는 경우가 많다.
  • 피쳐를 interaction term으로 포함하여 ctr을 예측하는 알고리즘으로 대표적인 것이 factorization machine이다. fm은 2 way feature interaction을 모델링 할 수 있지만, 그보다 더 높은 interaction은 고려하지 못한다.
  • 더 복잡한 형태의 interaction term을 모델링하는 방법이 제안되었으나, 이들은 모두 feature engineering을 직접 손으로 해야한다는 단점이 있다. 피쳐의 개수가 많아질수록 이는 현실적으로 불가능한 방법이다.
  • 이와 같이, 더 고차원 피쳐를 모델링 시에 사용하고 동시에 수동을 feature engineering하는 수고로움을 덜기 위해 deepfm이 개발되었다. 모델의 fm 부분에서는 lower-order interaction을 모델링하고 neural network 부분에서는 higher-order interaction을 모델링한다.

Approach

Notation

  • 데이터는 n개의 $(\mathcal{X}, y)$로 구성되고 m개의 피쳐를 가지고 있다.
  • ctr 문제는 $y \in \{0, 1\}$인 확률을 예측하는 모형을 만드는 것이다. 여기서 $y=1$이라면 클릭됐다는 것이고 $y=0$이라면 클릭되지 않았다는 뜻이다.
  • m개의 범주를 가지는 범주형 변수는 m차원의 one-hot 벡터로 변환된다.

DeepFM

목표는 low & high order feature interaction을 학습하는 것이다. 기존의 fm은 2 way interaction만 모형에 포함하여 high order interaction을 포함하지 않는다는 단점이 있었다. deepfm의 아키텍쳐는 아래와 같다.

m개의 sparse features가 들어오면 임베딩 벡터로 변환하고 fm layer와 hidden layer을 태운다. fm layer에서는 2 way interaction의 형태로 low interaction을 학습한다. hidden layer은 mlp을 통해서 high order interaction을 학습한다. 구체적으로 fm 파트와 deep neural network의 파트를 살펴보자.

FM Component

  • factorization machine 논문에서 제안된 fm과 동일한 아키텍쳐이다.
  • linear interaction을 addition을 통해서 모델링하고 2 way interaction을 latent vector간의 내적을 통해 모델링한다.
  • fm이 가지는 의의는, interaction term을 단순히 $w_j$가 아닌, $<v_i, v_j>$와 같이 두 latent vector의 내적을 통해 모델링한다는 것이다. sparse data에서는 두 변수가 동시에 발생하는 데이터 자체가 없을 수 있고 이런 상황에서 학습을 진행하면 interaction term에 대한 계수 추정치가 0으로 나올 수 있는데, fm은 dense latent vector의 내적으로 계수를 추정하여 이런 상황을 방지할 수 있다.

Deep Component

  • 왼쪽 그림은 deepfm의 dnn 아키텍쳐로 sprase feature을 임베딩 벡터로 만들고 mlp을 태워서 high-order feature interaction을 모델링한다.
  • 오른쪽 그림은 상세한 mlp 아키텍쳐이다. 여기서 중요한 부분은 sparse data을 임베팅 벡터로 만드는데 가중치 행렬로 $V$가 사용된다는 것이다.
  • $V$는 fm 부분에서 latent vector로 사용된 행렬로써 입력 피쳐를 임베딩 벡터로 압축하는데 사용된다.
  • 이와 같이 deepfm에서 fm과 dnn이 동일한 피쳐 임베딩을 공유한다는 점이 deepfm의 큰 특징 중 하나이다. 이를 통해 raw feature로부터 low order interaction에서 high order interaction까지 학습할 수 있다는 것과 feature engineering을 손으로 직접 하지 않아도 된다는 장점이 있다.

Relationship with other neural networks

  • FNN
    • FM으로 초기화되는 feedforward neural network이다. 
    • FM으로 pre-training하는 것의 단점은 임베딩 파라미터가 fm에 의해서 영향을 받을 수 있고 pre-training 단계에서 효율성이 감소될 수 있다.
    • FNN은 higher-order interaction만 모델링한다.
    • deepfm은 pre-training 단계가 필요 없고 low & high order interaction 모두 모델링한다.
  • PNN
    • high order interaction을 모델링하기 위해 PNN은 product layer을 추가한다.
    • 하지만 FNN처럼 low order interaction을 모델링하지 못한다
  • Wide & Deep
    • 구글에서 제안한 모형으로 linearity을 통해 low interaction과 dnn을 통해 high interaction을 모델링하고자 한다.
    • 하지만 wide 파트에서 변수들을 interaction으로 묶는 과정에서 feature engineering이 필요하지만 deepfm은 필요하지 않다.
    • wide 파트를 fm으로 바꾸는것이 가장 간단한 변형인데, 이는 피쳐 임베딩을 deep component와 공유하지 않는 단점이 있고 low & high interaction을 모형에 제대로 반영하지 못할 수도 있다.

Results

Datasets

  • criteo dataset: 13개의 연속형 변수과 26개의 범주형 변수를 가지는 유저의 클릭 데이터
  • company* dataset: 한 회사로부터 논문 저자들이 직접 수집한 데이터

Evaluation metrics

  • ctr 예측 문제이므로 AUC
  • Logloss

Performance evaluation

  • LR 모형 대비 시간이 얼마나 걸렸는지 측정한 결과 deepfm이 적게 걸린 것으로 나타났다.

  • 두 데이터에서 deepfm의 auc가 가장 높고 logloss도 가장 작다.

Conclusion

  • mf을 지나서 rs에 neural network을 어떻게 적용하는지 흐름을 보고 있는데, 여태까지 읽은 논문들은 꽤나 간단한 방법론으로 보인다.
  • 이 논문은 fm의 연장선으로 ctr 예측 문제에서 범주형 변수가 많을 시에 이를 효과적으로 다룰 수 있는 방법을 제시한다는 점에서 시도해볼법한 알고리즘인 것 같다.
  • 인상적인 부분은 fm의 latent vector을 dnn에서 학습할 시에 가중치로 가져온다는 것이었고 이를 통해 low & high order interaction을 동시에 달성한다고 하였다. 과연 이것의 효과가 있을까.. 실험에서 FM & DNN이 방금 말한 구조인 것 같은데.. toy data을 가져와서 방법을 확인해보는게 가장 좋을 것 같다는 생각이 들었다.

댓글