從圖片相似度學習圖片的表示
很多時候帶分類標注的圖片樣本是很難獲得的,但是圖片之間的相似度卻不難獲得。最簡單的方式有幾個:
- 視頻里相鄰的幀是相似的。
- 如果有用戶日志數據,可以基于協同過濾計算圖片之間的相似度。
最早用來從相似圖片數據集上學習圖片表示的網絡結構是siamese網絡。
siamese.png
兩幅圖通過兩個共享權重的CNN得到各自的表示。而各自表示的距離決定了他們是相似還是不相似。
在siamese網絡之后,又提出了用triplet loss來學習圖片的表示,大概思路如下:
- 拿到3張圖片A, B, C。其中A,B相似,A,C不相似。
- 學到A, B, C 的表示,使得A,B之間的距離盡量小,而A,C之間的距離盡量大。
用mxnet實現triplet loss很簡單,代碼如下:
def get_net(batch_size): same = mx.sym.Variable('same') diff = mx.sym.Variable('diff') anchor = mx.sym.Variable('anchor') one = mx.sym.Variable('one') one = mx.sym.Reshape(data = one, shape = (-1, 1)) conv_weight = [] conv_bias = [] for i in range(3): conv_weight.append(mx.sym.Variable('conv' + str(i) + '_weight')) conv_bias.append(mx.sym.Variable('conv' + str(i) + '_bias')) fc_weight = mx.sym.Variable('fc_weight') fc_bias = mx.sym.Variable('fc_bias') fa = get_conv(anchor, conv_weight, conv_bias, fc_weight, fc_bias) fs = get_conv(same, conv_weight, conv_bias, fc_weight, fc_bias) fd = get_conv(diff, conv_weight, conv_bias, fc_weight, fc_bias) fs = fa - fs fd = fa - fd fs = fs * fs fd = fd * fd fs = mx.sym.sum(fs, axis = 1, keepdims = 1) fd = mx.sym.sum(fd, axis = 1, keepdims = 1) loss = fd - fs loss = one - loss loss = mx.sym.Activation(data = loss, act_type = 'relu') return mx.sym.MakeLoss(loss)
這里conv_weight[], fc_weight, conv_bias[], fc_bias是兩個CNN網絡共享的模型。理論上這里可以用任何的CNN網絡(AlexNet, GoogleNet, ResNet)。我們用了一個特別簡單的CNN,如下:
def get_conv(data, conv_weight, conv_bias, fc_weight, fc_bias): cdata = data ks = [5, 3, 3] for i in range(3): cdata = mx.sym.Convolution(data=cdata, kernel=(ks[i],ks[i]), num_filter=32, weight = conv_weight[i], bias = conv_bias[i], name = 'conv' + str(i)) cdata = mx.sym.Pooling(data=cdata, pool_type="avg", kernel=(2,2), stride=(1, 1)) cdata = mx.sym.Activation(data=cdata, act_type="relu") cdata = mx.sym.Flatten(data = cdata) cdata = mx.sym.FullyConnected(data = cdata, num_hidden = 1024, weight = fc_weight, bias = fc_bias, name='fc') cdata = mx.sym.L2Normalization(data = cdata) return cdata
Triple loss用的 Simultaneous Feature Learning and Hash Coding with Deep Neural Networks 里的定義:
Triple Loss
下面是在cifar10數據集上測試的結果。為了形象的表示,采用了圖片檢索的方式(因為不是論文,所以就不那么嚴謹了)。在訓練集上學習圖片的表示,然后對于測試集的一張隨機圖片,找到測試集上和他最相似的其他圖片:
cifar_triple.png
在其他的論文中還有一些其它評測方式,比如學習到表示后,用一個SVM去學習分類,看看分類的準確度相比End-End的CNN如何。基本的結論都是精度會稍微低一些,但是沒用明顯區別。這說明學到的表示是靠譜的。
來自:http://www.jianshu.com/p/70a66c8f73d3
本文由用戶 suiaing 自行上傳分享,僅供網友學習交流。所有權歸原作者,若您的權利被侵害,請聯系管理員。
轉載本站原創文章,請注明出處,并保留原始鏈接、圖片水印。
本站是一個以用戶分享為主的開源技術平臺,歡迎各類分享!