KNN过时了!ANNs比它快了整整380倍

灰太狼 2023-01-02 01:30 171阅读 0赞

全文共5889字,预计学习时长15分钟

KNN过时了!ANNs比它快了整整380倍

图源:Google

我们正在经历一场灭绝性的大事件,颇受欢迎的KNN算法正面临淘汰,而几乎每门数据科学课上都会学习这种算法!

KNN背景

寻找与给定项目的K个相似项的做法在机器学习界被广泛称为“相似”搜索或“最近邻”(NN)搜索。最广为人知的最近邻搜索算法便是K最近邻(KNN)算法。

它的用途很广,在已有的物品集如手机电商目录中,运用KNN,便可以从这一整个目录中找到一个少数(K)最近邻来发起新的搜索请求。

比如说,在下面这个例子中,如果将K设为3,那每一个“iphone”的3最近邻便是其他“iphone”。类似地,对每一个“Samsung”的3最近邻便是所有的Samsungs。

KNN过时了!ANNs比它快了整整380倍

6款手机的目录

KNN的问题

尽管KNN很适合寻找类似项,但是它使用的是穷尽匹配距离计算方法。如果数据中有1000项的话,要找到一个新产品中K=3的最近邻,该算法就会把新产品与数据库中其他产品一起执行1000次距离计算。

这还不算太糟糕。但试想在现实中,顾客对顾客(C2C)的市场数据库里有着上百万的产品,并且每天都可能会上传新的产品。把新产品与所有数百万的产品进行比对的做法确实太浪费时间了,也就是说根本无法拓展。

解决方法

让最近邻在大量的数据中也适用的解决方法就是彻底规避这种暴力的距离计算法,代之以更复杂的一类算法,名为近似最近邻(ANN)。

KNN过时了!ANNs比它快了整整380倍

近似最近邻(ANN)

严格来说,近似最近邻这种算法在最近邻搜索中允许少量的错误。但在现实中的C2C市场里,“真正的”最近邻的数字要比被搜索的“K”最近邻高。相比暴力的KNN,ANN能够在短时间内达到惊人的准确率。下列有几种ANN算法:

· Spotify的【ANNOY】

· Google的【ScaNN】

· Facebook的【Faiss】

· 还有个人最爱:分层可导航小世界图【HNSW】

下面我们把焦点从Python的 sklearn中的KNN算法转向在Python的hnswlib 包中的HNSW图这一出色的ANN算法。接下来将使用大型的【Amazon product dataset】,其中包含‘手机&配件’分类中的527000个产品,以此来证明HNSW的速度非常快(准确说是快380倍),同时还能得到与sklearn的KNN百分之99.3相同的结果。

在HNSW中【paper @ arxiv】,作者使用多层图来描述ANN算法。在插入元素的时候,HNSW图是通过随机选择每个元素的最大层,以指数递减的概率分布逐步建立的。这保证了layer=0时有很多元素来进行精确搜索,而layer=2时有e^-2较少的元素,便于粗略搜索。

最近邻的搜索从最顶层开始粗略搜索,继而向下层递进,直至在最下层使用贪心算法的线路来遍历全图,找到所需数字的近邻。

KNN过时了!ANNs比它快了整整380倍

HNSW图的结构。在最顶层开始最近邻的搜索(粗略搜索),在最底层结束搜索(精细搜索)。

HNSWPython包

整个HNSW算法都是通过C++写成的,并与Python绑定,能用Python包管理工具(pip),通过打字安装在你的机器里:pip install hnswlib。安装这个包并导入后,创建HNSW图需要几个步骤,将其包装成下列的便捷函数。

  1. import hnswlib
  2. import numpy as npdef fit_hnsw_index(features, ef=100, M=16,save_index_file=False):
  3. # Convenience function to create HNSWgraph
  4. # features : list of lists containingthe embeddings
  5. # ef, M: parameters to tune the HNSWalgorithm
  6. num_elements = len(features)
  7. labels_index =np.arange(num_elements) EMBEDDING_SIZE= len(features[0]) # Declaring index
  8. # possible space options are l2,cosine or ip
  9. p = hnswlib.Index(space='l2',dim=EMBEDDING_SIZE) # Initing index -the maximum number of elements should be known
  10. p.init_index(max_elements=num_elements, ef_construction=ef, M=M) # Element insertion
  11. int_labels = p.add_items(features,labels_index) # Controlling the recallby setting ef
  12. # ef should always be > k
  13. p.set_ef(ef)
  14. # If you want to save the graph to afile
  15. if save_index_file:
  16. p.save_index(save_index_file)
  17. return p

创建HNSW索引后,查询“K”最近邻就像调用下列一行代码一样简单。

  1. ann_neighbor_indices, ann_distances = p.knn_query(features, k)

KNNvs. ANN 基准实验

首先下载一个有50万行以上的大型数据集。接着用预先训练好的 fasttext句子向量,将文本列转换为 300d的嵌入向量。

然后训练KNN和HNSW ANN模型,输入长短不一的数据[1000, 10000,100000, len(data)],以此来测量数据大小对速度的影响。

最后,从两个模型中寻求K=10和100的最近邻来测量K对速度的影响。首先引入必要的包和模型。这会花一点时间,fasttext模型需要从网上下载。

  1. # Imports
  2. # For input data pre-processing
  3. import json
  4. import gzip
  5. import pandas as pd
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. import fasttext.util
  9. fasttext.util.download_model('en', if_exists='ignore') # English pre-trainedmodel
  10. ft = fasttext.load_model('cc.en.300.bin')# For KNN vs ANN benchmarking
  11. from datetime import datetime
  12. from tqdm import tqdm
  13. from sklearn.neighbors import NearestNeighbors
  14. import hnswlib

数据

使用【Amazon product dataset】,它包含了‘手机&配件’分类中527000个产品。从这个链接下载数据集,并运行以下代码来将其转换成一个数据框架。只需用到产品的title列,因为要用它来寻找相似产品。

  1. # Data: http://deepyeti.ucsd.edu/jianmo/amazon/
  2. data = []
  3. with gzip.open('meta_Cell_Phones_and_Accessories.json.gz') as f:
  4. for l in f:
  5. data.append(json.loads(l.strip()))# Pre-Processing: https://colab.research.google.com/drive/1Zv6MARGQcrBbLHyjPVVMZVnRWsRnVMpV#scrollTo=LgWrDtZ94w89
  6. # Convert list into pandas dataframe
  7. df = pd.DataFrame.from_dict(data)
  8. df.fillna('', inplace=True)# Filter unformatted rows
  9. df = df[~df.title.str.contains('getTime')]# Restrict to just 'Cell Phones andAccessories'
  10. df = df[df['main_cat']=='Cell Phones & Accessories']# Reset index
  11. df.reset_index(inplace=True, drop=True)# Only keep the title columns
  12. df = df[['title']]# Check the df
  13. print(df.shape)
  14. df.head()

如果一切运行顺利,就能得到如下输出结果。

KNN过时了!ANNs比它快了整整380倍

亚马逊产品数据集

嵌入

要在文本数据上运行相似性搜索,就必须要首先将其转变为数字向量。一个快捷的方法就是使用预先训练好的网络嵌入层,比如Facebook提供的【FastText】。因为所有行有要有相同长度的向量,因此不用考虑title里的字数,应该在df中将 get_sentence_vector方法运用到title列上。

在嵌入完成后,将emb列作为一个列表提取出来,输入到NN算法中。当然在这一步之前需要先进行文本清理预处理。另外,利用微调后的嵌入模型也是个不错的选择。

  1. # Title Embedding using FastText Sentence Embedding
  2. df['emb'] = df['title'].apply(ft.get_sentence_vector)# Extract out theembeddings column as a list of lists for input to our NN algos
  3. X = [item.tolist() for item in df['emb'].values]

基准

有了算法输入后,就可以开始基准测试了。以搜索空间中的产品数量和被搜索的K最近邻为循环,在循环中运行测试。

在每次迭代时,除了对每个算法所花费的时间进行计时外,还要检查pct_overlap ,作为KNN最近邻和同时被ANN抓取的最近邻数量之比。

注意!整个测试要在8核,30GB RAM的机器上全天候运行六天左右,所以会花一些时间。当然也可以通过多重处理来加速,因为每一次运行实际上是互相独立的。

  1. # Number of products for benchmark loop
  2. n_products = [1000, 10000, 100000, len(X)]# Number of neighbors for benchmarkloop
  3. n_neighbors = [10, 100]# Dictionary to save metric results for each iteration
  4. metrics = {'products':[], 'k':[], 'knn_time':[], 'ann_time':[],'pct_overlap':[]}for products in tqdm(n_products):
  5. # "products" number ofproducts included in the search space
  6. features = X[:products]
  7. for k in tqdm(n_neighbors):
  8. # "K" Nearest Neighborsearch
  9. # KNN
  10. knn_start = datetime.now()
  11. nbrs = NearestNeighbors(n_neighbors=k,metric='euclidean').fit(features)
  12. knn_distances,knn_neighbor_indices = nbrs.kneighbors(X)
  13. knn_end = datetime.now()
  14. metrics['knn_time'].append((knn_end - knn_start).total_seconds())
  15. # HNSW ANN
  16. ann_start = datetime.now()
  17. p = fit_hnsw_index(features,ef=k*10)
  18. ann_neighbor_indices,ann_distances = p.knn_query(features, k)
  19. ann_end = datetime.now()
  20. metrics['ann_time'].append((ann_end - ann_start).total_seconds())
  21. # Average Percent Overlap inNearest Neighbors across all "products"
  22. metrics['pct_overlap'].append(np.mean([len(np.intersect1d(knn_neighbor_indices[i],ann_neighbor_indices[i]))/k for i in range(len(features))]))
  23. metrics['products'].append(products)
  24. metrics['k'].append(k)
  25. metrics_df = pd.DataFrame(metrics)
  26. metrics_df.to_csv('metrics_df.csv', index=False)
  27. metrics_df

最后运行的输出结果如下图。可以看到,HNSW ANN完败KNN!

KNN过时了!ANNs比它快了整整380倍

结果

让我们以图表的形式来看看基准结果,真正体会到差异的大小。我将会用到标准matplotlib代码来绘制这些图。X轴以 log为单位。

差距是很大的!在寻求K=10和100的最近邻时,HNSW ANN完败KNN。当搜索空间包含约50万的产品时,用ANN搜寻100的最近邻要快380倍。同时,KNN和ANN都找到了99.3%相同的最近邻。

KNN过时了!ANNs比它快了整整380倍

( HNSW ANN能够以快380倍的速度,搜索包含约50万元素的空间,得出与Sklearn的KNN99.3%相同的最近邻,并且能够在搜索空间里找到每一个元素的K=100最近邻。)

KNN过时了!ANNs比它快了整整380倍

综上所述,KNN过时了!这点是毋庸置疑的,我们有了新的选择,没有理由再去用sklearn的KNN了。

KNN过时了!ANNs比它快了整整380倍

一起分享AI学习与发展的干货

欢迎关注全平台AI垂类自媒体 “读芯术”

format_png

(添加小编微信:dxsxbb,加入读者圈,一起讨论最新鲜的人工智能科技哦~)

发表评论

表情:
评论列表 (有 0 条评论,171人围观)

还没有评论,来说两句吧...

相关阅读

    相关 MyBatis 100

    比 MyBatis 效率快 100 倍的条件检索引擎,天生支持联表,使一行代码实现复杂列表检索成为可能! 2开源协议 使用Apache-2.0开源协议 3界面展示

    相关 MySQL801,太颠覆

    数字化时代,数据即价值。商战即信息战,如何从海量数据中提取精准的用户群体信息成为众多企业经营的重中之重,这就对开发工程师在速度和精准度方面的要求越来越高。 > 海量订单如何精

    相关 MySQL801,太颠覆

    行业内卷的话题热度居高不退,程序员群体的职业焦虑也尤为明显,在更新迭代日新月异的技术领域,对新技术软件保持敏感是最起码的职业尊重,尤其是在大数据领域,能否运用新技术解决实际问题

    相关 整整1年...

    一年前的今天我独自一人登上了齐齐哈尔开往上海的火车。那时的我还是在校的大学生。。。 而一年后的今天,我已经在另外一个城市,以一个民工的身份在苦苦地熬日子。。。 那时我是充满