進(jìn)階使用
前置知識(shí)
BM25簡(jiǎn)介
BM25算法(Best Matching 25)是一種廣泛用于信息檢索領(lǐng)域的排名函數(shù),用于在給定查詢(Query)時(shí)對(duì)一組文檔(Document)進(jìn)行評(píng)分和排序。BM25在計(jì)算Query和Document之間的相似度時(shí),本質(zhì)上是依次計(jì)算Query中每個(gè)單詞和Document的相關(guān)性,然后對(duì)每個(gè)單詞的相關(guān)性進(jìn)行加權(quán)求和。BM25算法一般可以表示為如下形式:
上式中,q 和 d 分別表示用來(lái)計(jì)算相似度的Query和Document, qi 表示 q 的第 i 個(gè)單詞,R(qi, d) 表示單詞 qi 和文檔 d 的相關(guān)性,Wi 表示單詞 qi 的權(quán)重,計(jì)算得到的 score(q, d) 表示 q 和 d 的相關(guān)性得分,得分越高表示 q 和 d 越相似。Wi 和 R(qi, d) 一般可以表示為如下形式:
其中,N 表示總文檔數(shù),N(qi) 表示包含單詞 qi 的文檔數(shù),tf(qi, d) 表示qi 在文檔 d 中的詞頻,Ld 表示文檔 d 的長(zhǎng)度,Lavg 表示平均文檔長(zhǎng)度,k1 和 b 是分別用來(lái)控制 tf(qi, d) 和 Ld 對(duì)得分影響的超參數(shù)。
稀疏向量生成
在檢索場(chǎng)景中,為了讓BM25算法的Score方便進(jìn)行計(jì)算,通常分別對(duì)Document和Query進(jìn)行編碼,然后通過(guò)點(diǎn)積的方式計(jì)算出兩者的相似度。得益于BM25原理的特性,其原生支持將Score拆分為兩部分Sparse Vector,DashText提供了encode_document
以及encode_query
兩個(gè)接口來(lái)分別實(shí)現(xiàn)這兩部分向量的生成,其生成鏈路如下圖所示:
最終生成的稀疏向量可表示為:
Score/距離計(jì)算
生成d和q的稀疏向量后,就可以通過(guò)簡(jiǎn)單的點(diǎn)積進(jìn)行距離計(jì)算,即將相同單詞上的值對(duì)應(yīng)相乘再求和,通過(guò)稀疏向量計(jì)算距離的方式如下所示:
上述計(jì)算方式本質(zhì)上是通過(guò)點(diǎn)積來(lái)計(jì)算的,score 越大表示越相似,如果需要結(jié)合Dense Vector一起進(jìn)行距離度量時(shí),需要對(duì)齊距離度量方式。也就是說(shuō),在結(jié)合Dense Vector+Sparse Vector的場(chǎng)景中,距離計(jì)算只支持點(diǎn)積度量方式。
如何自訓(xùn)練模型
考慮到內(nèi)置的BM25 Model是基于通用語(yǔ)料(中文Wiki語(yǔ)料)訓(xùn)練得到,在特定領(lǐng)域下通常不能表現(xiàn)出最佳的效果。因此,在一些特定場(chǎng)景下,通常建議訓(xùn)練自定義BM25模型。使用DashText來(lái)訓(xùn)練自定義模型時(shí)一般需要遵循以下步驟:
Step1:確認(rèn)使用場(chǎng)景
當(dāng)準(zhǔn)備使用SparseVector來(lái)進(jìn)行信息檢索時(shí),應(yīng)提前考慮當(dāng)前場(chǎng)景下的Query以及Document來(lái)源,通常需要提前準(zhǔn)備好一定數(shù)量Document來(lái)入庫(kù),這些Document通常需要和特定的業(yè)務(wù)場(chǎng)景直接相關(guān)。
Step2:準(zhǔn)備語(yǔ)料
根據(jù)BM25原理,語(yǔ)料直接決定了BM25模型的參數(shù)。通常應(yīng)按照以下幾個(gè)原則來(lái)準(zhǔn)備語(yǔ)料:
語(yǔ)料來(lái)源應(yīng)盡可能反映對(duì)應(yīng)場(chǎng)景的特性,盡可能讓 N(qi) 能夠反映對(duì)應(yīng)真實(shí)場(chǎng)景的詞頻信息。
調(diào)節(jié)合理的語(yǔ)料切片長(zhǎng)度和切片數(shù)量,避免出現(xiàn)語(yǔ)料當(dāng)中只有少量長(zhǎng)文本的情況。
一般情況下,如無(wú)特殊要求或限制,可以直接將Step1準(zhǔn)備的一系列Document組織為語(yǔ)料即可。
Step3:準(zhǔn)備Tokenizer
Tokenizer決定了分詞的結(jié)果,分詞的結(jié)果則直接影響Sparse Vector的生成,在特定領(lǐng)域下使用自定義Tokenizer會(huì)達(dá)到更好的效果。DashText提供了兩種擴(kuò)展Tokenizer的方式:
使用自定義詞表:DashText內(nèi)置的Jieba Tokenizer支持傳入自定義詞表。(Java SDK暫不支持該功能)
from dashtext import TextTokenizer, SparseVectorEncoder
my_tokenizer = TextTokenizer.from_pretrained(model_name='Jieba', dict='dict.txt')
my_encoder = SparseVectorEncoder(tokenize_function=my_tokenizer.tokenize)
使用自定義Tokenizer:DashText支持任務(wù)自定義的Tokenizer,只需提供一個(gè)符合
Callable[[str], List[str]]
簽名的Tokenize函數(shù)即可。
from dashtext import SparseVectorEncoder
from transformers import BertTokenizer
my_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
my_encoder = SparseVectorEncoder(tokenize_function=my_tokenizer.tokenize)
import com.aliyun.dashtext.common.DashTextException;
import com.aliyun.dashtext.common.ErrorCode;
import com.aliyun.dashtext.encoder.SparseVectorEncoder;
import com.aliyun.dashtext.tokenizer.BaseTokenizer;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class Main {
public static class MyTokenizer implements BaseTokenizer {
@Override
public List<String> tokenize(String s) throws DashTextException {
if (s == null) {
throw new DashTextException(ErrorCode.INVALID_ARGUMENT);
}
// 使用正則表達(dá)式將文本按空白符和標(biāo)點(diǎn)符號(hào)分割,并轉(zhuǎn)換為小寫
return Arrays.stream(s.split("\\s+|(?<!\\d)[.,](?!\\d)"))
.map(String::toLowerCase)
.filter(token -> !token.isEmpty()) // 過(guò)濾掉空字符串
.collect(Collectors.toList());
}
}
public static void main(String[] args) {
SparseVectorEncoder encoder = new SparseVectorEncoder(new MyTokenizer());
}
}
Step4:訓(xùn)練模型
實(shí)際上,這里的“訓(xùn)練”本質(zhì)上是一個(gè)“統(tǒng)計(jì)”參數(shù)的過(guò)程。由于訓(xùn)練自定義模型的過(guò)程中包含著大量Tokenizing/Hashing過(guò)程,所以可能會(huì)耗費(fèi)一定的時(shí)間。DashText提供了SparseVectorEncoder.train
接口可以用來(lái)訓(xùn)練模型。
Step5:調(diào)參優(yōu)化(可選)
模型訓(xùn)練完成后,可以準(zhǔn)備部分驗(yàn)證數(shù)據(jù)集以及通過(guò)微調(diào) k1 和 b 來(lái)達(dá)到最佳的召回效果。調(diào)節(jié)k1和b一般需要遵循以下原則:
調(diào)節(jié)k1 (1.2 < k1 < 2)可控制Document詞頻對(duì)Score的影響,k1 越大Document的詞頻對(duì)Score的貢獻(xiàn)越小。
調(diào)節(jié)b (0 < b < 1)可控制文檔長(zhǎng)度對(duì)Score的影響,b 越大表示文檔長(zhǎng)度對(duì)權(quán)重的影響越大
一般情況下,如無(wú)特殊要求或限制,不需要調(diào)整 k1 和 b。
Step6:Finetune模型(可選)
實(shí)際場(chǎng)景下,可能會(huì)存在需要補(bǔ)充訓(xùn)練語(yǔ)料來(lái)增量式地更新BM25模型參數(shù)的情況。DashText的SparseVectorEncoder.train
接口原生支持模型的增量更新。需要注意的是,模型更改之后,使用舊模型進(jìn)行編碼并已入庫(kù)的向量就失去了時(shí)效性,一般需要重新入庫(kù)。
示例代碼
以下是一個(gè)簡(jiǎn)單完整的自訓(xùn)練模型示例。
from dashtext import SparseVectorEncoder
from pydantic import BaseModel
from typing import Dict, List
class Result(BaseModel):
doc: str
score: float
def calculate_score(query_vector: Dict[int, float], document_vector: Dict[int, float]) -> float:
score = 0.0
for key, value in query_vector.items():
if key in document_vector:
score += value * document_vector[key]
return score
# 創(chuàng)建空SparseVectorEncoder(可以設(shè)置自定義Tokenizer)
encoder = SparseVectorEncoder()
# step1: 準(zhǔn)備語(yǔ)料以及Documents
corpus_document: List[str] = [
"The quick brown fox rapidly and agilely leaps over the lazy dog that lies idly by the roadside.",
"Never jump over the lazy dog quickly",
"A fox is quick and jumps over dogs",
"The quick brown fox",
"Dogs are domestic animals",
"Some dog breeds are quick and jump high",
"Foxes are wild animals and often have a brown coat",
]
# step2: 訓(xùn)練BM25 Model
encoder.train(corpus_document)
# step3: 調(diào)參優(yōu)化BM25 Model
query: str = "quick brown fox"
print(f"query: {query}")
k1s = [1.0, 1.5]
bs = [0.5, 0.75]
for k1, b in zip(k1s, bs):
print(f"current k1: {k1}, b: ")
encoder.b = b
encoder.k1 = k1
query_vector = encoder.encode_queries(query)
results: List[Result] = []
for idx, doc in enumerate(corpus_document):
doc_vector = encoder.encode_documents(doc)
score = calculate_score(query_vector, doc_vector)
results.append(Result(doc=doc, score=score))
results.sort(key=lambda r: r.score, reverse=True)
for result in results:
print(result)
# step4: 選擇最優(yōu)參數(shù)并保存模型
encoder.b = 0.75
encoder.k1 = 1.5
encoder.dump("./model.json")
# step5: 后續(xù)使用時(shí)可以加載模型
new_encoder = SparseVectorEncoder()
bm25_model_path = "./model.json"
new_encoder.load(bm25_model_path)
# step6: 對(duì)模型進(jìn)行finetune并保存
extra_corpus: List[str] = [
"The fast fox jumps over the lazy, chubby dog",
"A swift fox hops over a napping old dog",
"The quick fox leaps over the sleepy, plump dog",
"The agile fox jumps over the dozing, heavy-set dog",
"A speedy fox vaults over a lazy, old dog lying in the sun"
]
new_encoder.train(extra_corpus)
new_bm25_model_path = "new_model.json"
new_encoder.dump(new_bm25_model_path)
import com.aliyun.dashtext.encoder.SparseVectorEncoder;
import java.io.*;
import java.util.*;
public class Main {
public static class Result {
public String doc;
public float score;
public Result(String doc, float score) {
this.doc = doc;
this.score = score;
}
@Override
public String toString() {
return String.format("Result(doc=%s, score=%f)", doc, score);
}
}
public static float calculateScore(Map<Long, Float> queryVector, Map<Long, Float> documentVector) {
float score = 0.0f;
for (Map.Entry<Long, Float> entry : queryVector.entrySet()) {
if (documentVector.containsKey(entry.getKey())) {
score += entry.getValue() * documentVector.get(entry.getKey());
}
}
return score;
}
public static void main(String[] args) throws IOException {
// 創(chuàng)建空SparseVectorEncoder(可以設(shè)置自定義Tokenizer)
SparseVectorEncoder encoder = new SparseVectorEncoder();
// step1: 準(zhǔn)備語(yǔ)料以及Documents
List<String> corpusDocument = Arrays.asList(
"The quick brown fox rapidly and agilely leaps over the lazy dog that lies idly by the roadside.",
"Never jump over the lazy dog quickly",
"A fox is quick and jumps over dogs",
"The quick brown fox",
"Dogs are domestic animals",
"Some dog breeds are quick and jump high",
"Foxes are wild animals and often have a brown coat"
);
// step2: 訓(xùn)練BM25 Model
encoder.train(corpusDocument);
// step3: 調(diào)參優(yōu)化BM25 Model
String query = "quick brown fox";
System.out.println("query: " + query);
float[] k1s = {1.0f, 1.5f};
float[] bs = {0.5f, 0.75f};
for (int i = 0; i < k1s.length; i++) {
float k1 = k1s[i];
float b = bs[i];
System.out.println("current k1: " + k1 + ", b: " + b);
encoder.setB(b);
encoder.setK1(k1);
Map<Long, Float> queryVector = encoder.encodeQueries(query);
List<Result> results = new ArrayList<>();
for (String doc : corpusDocument) {
Map<Long, Float> docVector = encoder.encodeDocuments(doc);
float score = calculateScore(queryVector, docVector);
results.add(new Result(doc, score));
}
results.sort((r1, r2) -> Float.compare(r2.score, r1.score));
for (Result result : results) {
System.out.println(result);
}
}
// step4: 選擇最優(yōu)參數(shù)并保存模型
encoder.setB(0.75f);
encoder.setK1(1.5f);
encoder.dump("./model.json");
// step5: 后續(xù)使用時(shí)可以加載模型
SparseVectorEncoder newEncoder = new SparseVectorEncoder();
newEncoder.load("./model.json");
// step6: 對(duì)模型進(jìn)行finetune并保存
List<String> extraCorpus = Arrays.asList(
"The fast fox jumps over the lazy, chubby dog",
"A swift fox hops over a napping old dog",
"The quick fox leaps over the sleepy, plump dog",
"The agile fox jumps over the dozing, heavy-set dog",
"A speedy fox vaults over a lazy, old dog lying in the sun"
);
newEncoder.train(extraCorpus);
newEncoder.dump("./new_model.json");
}
}
API參考
DashText API詳情可參考:https://pypi.org/project/dashtext/