人間だったら考えて

なんでよ?考えて考えてっ 人間だったら考えて

PyTorchを用いたListNetの実装

この記事はランク学習(Learning to Rank) Advent Calendar 2018 - Adventarの13本目の記事です

この記事は何?

ニューラルネットワークを用いたランク学習の手法として、ListNet*1が提案されています。

以前下の記事で、同じくニューラルネットワークを用いたランク学習の手法であるRankNetを紹介しましたが、ListNetはRankNetと異なり、Listwise手法に分類されます。
szdr.hatenablog.com

この記事では、PyTorchを用いたListNetの実装を紹介します。



ListNet

非常にざっくりとListNetの解説をします。詳細解説は以下文献が詳しいので、そちらをご確認ください。

著者スライド
Learning to Rank: from Pairwise Approach to Listwise Approach

日本語記事(Chainer実装付き)
qiita.com


訓練データにQ=\{q^{(1)}, q^{(2)}, \cdots, q^{(m)}\}個のクエリに関するデータ点が含まれており、クエリq^{(i)}に文書群d^{(i)}=(d_1^{(i)}, d_2^{(i)}, \cdots, d_{n^{(i)}}^{(i)})が紐づいているとします。

それぞれの文書には関連度評価値y^{(i)}=(y_1^{(i)}, y_2^{(i)}, \cdots, y_{n^{(i)}}^{(i)})(例えば、Excellent(4)・Perfect(3)・Good(2)・Fair(1)・Bad(0))が与えられています。

ランキングモデルをfとし、文書d_j^{(i)}の特徴量x_j^{(i)}を用いて、ランクスコアz_j^{(i)}=f(x_j^{(i)})が得られます。

…ここまで準備できたところで、論文中で"top one probability"と呼んでいる確率を定義します(ちょっとステップ飛ばしてますが)。

P_{y^{(i)}}(x_j^{(i)})=\exp(y_j^{(i)}) / \sum_{k=1}^{n^{(i)}} \exp (y_k^{(i)})

P_{z^{(i)}(f)}(x_j^{(i)})=\exp(f(x_j^{(i)})) / \sum_{k=1}^{n^{(i)}} \exp (f (x_k^{(i)}))

…なんか複雑っぽいですが、それぞれ文書関連度・ランクスコアに対するsoftmaxを計算しているだけです。


上の"top one probability"を使って、ListNetにおける損失関数(論文中では交差エントロピーを使用)は以下のように定義できます。

L \left( y ^ { ( i ) } , z ^ { ( i ) } \left( f \right) \right) = - \sum _ { j = 1 } ^ { n ^ { (i) } } P _ { y ^ { (i) } } \left( x _ { j } ^ { ( i ) } \right) \log \left( P _ { z ^ { (i) } \left( f \right) } \left( x _ { j } ^ { ( i ) } \right) \right)

…というわけで、損失関数が定義できたので、あとはこれを最適化するだけです。



PyTorchを用いたListNetの実装

それでは、本題のPyTorchを用いたListNetの実装を紹介します。

下の記事で紹介したRankNetの実装と重複しているコードも多いですが。。。
szdr.hatenablog.com


まずはネットワークの定義です。
今回は単純なfeed-forwardニューラルネットワークを使います。

class Net(nn.Module):
    def __init__(self, D):
        super(Net, self).__init__()
        self.l1 = nn.Linear(D, 10)
        self.l2 = nn.Linear(10, 1)

    def forward(self, x):
        x = torch.sigmoid(self.l1(x))
        x = self.l2(x)
        return x

次に、本実装のメインであるListNetの損失関数を実装します。

def listnet_loss(y_i, z_i):
    """
    y_i: (n_i, 1)
    z_i: (n_i, 1)
    """

    P_y_i = F.softmax(y_i, dim=0)
    P_z_i = F.softmax(z_i, dim=0)
    return - torch.sum(P_y_i * torch.log(P_z_i))

…めちゃくちゃあっさりしてますね… y_iは文書関連度のベクトル、z_iは予測スコアのベクトルを表しています。

PyTorchのCrossEntropyLoss使うともっとあっさり書けるんですかね?クエリによって文書数が異なるケース(CrossEntropyLossにおけるクラス数)でもうまく動くか分からなかったので、明示的に書いてみました。

上の実装ですが、一度に複数クエリに関するデータを受け取れない実装になっているので注意してください。その辺りはお好みで拡張を…


では実際に動かしてみます、精度評価はswapped-pairsとNDCGを使います。

def ndcg(ys_true, ys_pred):
    def dcg(ys_true, ys_pred):
        _, argsort = torch.sort(ys_pred, descending=True, dim=0)
        ys_true_sorted = ys_true[argsort]
        ret = 0
        for i, l in enumerate(ys_true_sorted, 1):
            ret += (2 ** l - 1) / np.log2(1 + i)
        return ret
    ideal_dcg = dcg(ys_true, ys_true)
    pred_dcg = dcg(ys_true, ys_pred)
    return pred_dcg / ideal_dcg


学習・精度評価を合わせたソースコードの全体像を公開しておきます。

github.com

上のコードを実行すると、epoch毎にvalidationにおけるswapped-pairsとndcgを出力します。

epoch: 1 valid swapped pairs: 1095/4950 ndcg: 0.8722
epoch: 2 valid swapped pairs: 787/4950 ndcg: 0.9366
epoch: 3 valid swapped pairs: 548/4950 ndcg: 0.9701
epoch: 4 valid swapped pairs: 385/4950 ndcg: 0.9841
epoch: 5 valid swapped pairs: 275/4950 ndcg: 0.9908
epoch: 6 valid swapped pairs: 224/4950 ndcg: 0.9937
epoch: 7 valid swapped pairs: 182/4950 ndcg: 0.9952
epoch: 8 valid swapped pairs: 146/4950 ndcg: 0.9966
epoch: 9 valid swapped pairs: 139/4950 ndcg: 0.9965
epoch: 10 valid swapped pairs: 113/4950 ndcg: 0.9972

学習を進めていくと、ちゃんとswapped-pairsの数が小さくなり、ndcgが向上していくことが分かります!



まとめ

この記事ではPyTorchを用いたListNetの実装を紹介しました。
ListNetはRankNetよりも効率的に学習でき、NDCGやMAPといった評価指標についても精度で勝つなど、かなり強力な手法だと思います。

*1:"Learning to Rank: From Pairwise Approach to Listwise Approach", Z. Cao, 2007.