From fae80d80bbfe7168535d45775f0e60abb897bf1b Mon Sep 17 00:00:00 2001
From: jwansek
Date: Thu, 2 Dec 2021 17:31:36 +0000
Subject: added query expansion, linked term searches

---
 search.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 63 insertions(+), 22 deletions(-)

diff --git a/search.py b/search.py
index e6c3330..60cbd2a 100644
--- a/search.py
+++ b/search.py
@@ -1,6 +1,11 @@
+from nltk.corpus import wordnet
+from nltk import pos_tag
+import collections
+import itertools
 import database
 import logging
 import terms
+import time
 import sys
 import re
 
@@ -12,32 +17,68 @@ logging.basicConfig(
         logging.StreamHandler()
     ])
 
+# Maps Penn Treebank POS tags to the WordNet POS constants that
+# wordnet.synsets() accepts (a list allows several POS per tag).
+WORDNET_POS_MAP = {
+    'NN': wordnet.NOUN,
+    'NNS': wordnet.NOUN,
+    'NNP': wordnet.NOUN,
+    'NNPS': wordnet.NOUN,
+    'JJ': [wordnet.ADJ, wordnet.ADJ_SAT],
+    'JJS': [wordnet.ADJ, wordnet.ADJ_SAT],
+    'RB': wordnet.ADV,
+    'RBR': wordnet.ADV,
+    'RBS': wordnet.ADV,
+    'RP': [wordnet.ADJ, wordnet.ADJ_SAT],
+    'VB': wordnet.VERB,
+}
+
 def main(search_words):
-
-    txt = [re.sub(r"[^a-zA-Z\s]", "", i).rstrip().lower() for i in search_words]
-
-    search_words = []
-    for i in txt:
-        search_words += re.split(r"\s+|\n", i)
-
-    search_words = [terms.LEM.lemmatize(i) for i in search_words if i != "" and i not in terms.STOPWORDS]
-    logging.info("Started searching. Using terms: %s" % " ".join(search_words))
+    starttime = time.time()
+    pos_tags = [(token, tag) for token, tag in pos_tag(search_words) if token.lower().replace(",", "") not in terms.STOPWORDS]
+
+    single_terms = [w.lower() for w in search_words]
+    logging.info("Started with the terms: %s" % str(single_terms))
+    with database.Database() as db:
+        linked_words = db.attempt_get_linked_words(single_terms)
+        linked_terms = collections.Counter([",".join(i) for i in linked_words])
+        # add the counts again so linked terms from the database start with a weight of 2
+        linked_terms += collections.Counter([",".join(i) for i in linked_words])
+    logging.info("Found the linked terms: %s" % str(linked_words))
+
+    synsets = [wordnet.synsets(token, WORDNET_POS_MAP[tag]) for token, tag in pos_tags if tag in WORDNET_POS_MAP]
+    synonyms = list(itertools.chain.from_iterable([[lemma.name().lower().replace("_", ",") for syn in synset for lemma in syn.lemmas()] for synset in synsets]))
+
+    # for syn in synsets:
+    #     for sy in syn:
+    #         print([w for s in sy.closure(lambda s:s.hyponyms()) for w in s.lemma_names()])
+
+    # multi-word synonyms become linked terms; single words join the plain terms
+    for synonym in synonyms:
+        if len(synonym.split(",")) > 1:
+            linked_terms[synonym] = 1
+        else:
+            single_terms.append(synonym)
+
+    single_terms = collections.Counter(single_terms)
+
+    logging.info("Expanded single terms to: %s" % str(single_terms))
+    logging.info("Expanded linked terms to: %s" % str(linked_terms))
+    logging.info("\n\n")
 
     with database.Database() as db:
-        tf_idf_scores = []
-        for term in search_words:
-            tf_idf_scores.append(db.get_tf_idf_score(term, tf_idf_thresh = 1, limit = 1000))
-            logging.info("Fetched %d scores for term '%s'..." % (len(tf_idf_scores[-1]), term))
-
-        merged_scores = {i: 0 for i in range(1, db.get_num_documents() + 1)}
-        for scorelist in tf_idf_scores:
-            for doc_id, score in scorelist.items():
-                merged_scores[doc_id] += score
-        logging.info("Merged scores...")
-
-        sorted_scores = list(reversed(sorted(merged_scores.items(), key = lambda i: i[1])))
+        tf_idf_scores = collections.Counter()
+        for single_term, search_weight in single_terms.items():
+            scores = collections.Counter(db.get_tf_idf_score_single(single_term, tf_idf_thresh = 1, limit = 1000, multiplier = search_weight))
+            logging.info("Got %d scores for term '%s' (multiplier %d)" % (len(scores), single_term, search_weight))
+            tf_idf_scores += scores
+
+        for linked_term, search_weight in linked_terms.items():
+            scores = db.get_tf_idf_score_linked(linked_term.split(","), tf_idf_thresh = 0, multiplier = search_weight)
+            logging.info("Got %d scores for linked term '%s' (multiplier %d)" % (len(scores), linked_term, search_weight))
+            tf_idf_scores += scores
+
+        sorted_scores = list(reversed(sorted(tf_idf_scores.items(), key = lambda i: i[1])))
         toshow = 30
 
         logging.info("Sorted scores...")
+        logging.info("Results:\n\n")
 
         for i, j in enumerate(sorted_scores, 0):
             if i >= toshow:
@@ -46,7 +87,7 @@ def main(search_words):
             docid, score = j
             logging.info("%.2f - %d - %s" % (score, docid, db.get_document_name_by_id(docid)))
 
-    logging.info("%d results found in total" % len([i[1] for i in sorted_scores if i[1] > 0.1]))
+    logging.info("Got %d results in total. Took %.2f minutes (%.2fs per term)" % (len(tf_idf_scores), (time.time() - starttime) / 60, (time.time() - starttime) / (len(single_terms) + len(linked_terms))))
 
 if __name__ == "__main__":
--
cgit v1.2.3
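
The expansion step in this patch boils down to three NLTK calls: pos_tag() to
tag the query, wordnet.synsets() to look up senses for the surviving tags, and
Lemma.name() to collect synonym strings. Below is a minimal standalone sketch
of that pipeline; the two-case tag map and the example tokens are illustrative
only (the patch's WORDNET_POS_MAP covers more tags), and it assumes the NLTK
tagger and wordnet data packages have been downloaded.

    from nltk import pos_tag
    from nltk.corpus import wordnet

    def expand(tokens):
        synonyms = set()
        for token, tag in pos_tag(tokens):
            # Map Penn Treebank tags to WordNet POS constants (trimmed
            # here to nouns and verbs for brevity).
            if tag.startswith("NN"):
                pos = wordnet.NOUN
            elif tag.startswith("VB"):
                pos = wordnet.VERB
            else:
                continue
            for synset in wordnet.synsets(token, pos):
                for lemma in synset.lemmas():
                    # Multi-word lemmas arrive underscore-joined, e.g.
                    # "search_engine"; the patch rewrites them with commas
                    # so they become linked terms.
                    synonyms.add(lemma.name().lower().replace("_", ","))
        return synonyms

    print(expand(["search", "engine"]))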
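
The scoring merge itself is plain Counter arithmetic: each term's document
scores are scaled by the term's weight and summed per document. Note that
Counter addition drops non-positive totals, which is harmless here because
tf-idf scores are positive. The sketch below uses invented scores in place of
what db.get_tf_idf_score_single() and db.get_tf_idf_score_linked() return
(those methods are presumably defined in database.py, which this patch does
not touch):

    import collections

    def merge_scores(weighted_scores):
        # weighted_scores: (doc_id -> tf-idf score dict, weight) pairs,
        # one per search term.
        total = collections.Counter()
        for scores, weight in weighted_scores:
            total += collections.Counter({doc: s * weight for doc, s in scores.items()})
        return total

    # Invented example: "engine" gets weight 2, as if it were both a user
    # term and a WordNet synonym of another term.
    total = merge_scores([
        ({1: 0.8, 2: 0.3}, 1),   # scores for "search"
        ({2: 0.5, 3: 0.9}, 2),   # scores for "engine"
    ])
    print(total.most_common(2))  # [(3, 1.8), (2, 1.3)]

Counter.most_common() is equivalent to the patch's
list(reversed(sorted(tf_idf_scores.items(), key = lambda i: i[1]))) and could
replace it.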