Matching Network

Matching Network 的思想是 学习一个 Bi-LSTM 来对有类标的图片进行 embedding,学习一个 Attention-LSTM 来对待预测的图片进行 embedding,然后计算他们的余弦距离进行分类

如上图所示,$g_\theta$ 是 support set 的 embedding 方法,$f_\theta$ 是 query set 的 embedding 方法。假设 $S = \{x_1, x_2, \cdots, x_k\}$、$\{y_1, y_2, \cdots, y_k\}$ 分别是 support set 的 feature 和 label,$\hat{x}$、$\hat{y}$ 分别是 query set 某个样本的 feature 和 label,那么该样本的预测结果为

$$ \hat{y} = \sum_{i = 1}^k a(\hat{x}, x_i) y_i $$

,其中 $a$ 为 Attention Kernel,用来度量 $\hat{x}$ 与 $x_i$ 的匹配度,是根据 feature 的余弦距离(cosine distance)再加上一个 softmax 函数计算得到的:

$$ a(\hat{x}, x_i) = \frac{e^{\operatorname{Cosine}(f(\hat{x}), g(x_i))}} {\sum_{j = 1}^k e^{\operatorname{Cosine}(f(\hat{x}), g(x_j))}} $$

对于两个 embedding 网络 $g$ 和 $f$,其设计又别有洞天:

  • 对于 support set embedding 网络 $g$,作者使用了一个 Bi-LSTM 对所有样本进行特征提取,然后用第 $i$ 个样本在双向 LSTM 中的对应位置的两份特征,作为该样本的 embedding。这样设计是为了让每个 task 中的样本 embedding 都是基于当前 support set 得到的,即

$$ g(x_i) = \operatorname{Bi-LSTM}(x_i, S) $$

  • 而对于 query set embedding 网络 $f$,作者则使用了 Attention-LSTM 模型,使得 query set 样本的 embedding 同样根据当前 support set 得到,即

$$ f(\hat{x}) = \operatorname{Attention-LSTM}(\hat{x}, S) $$