Searching for machine learning models using semantic search
Finding models on the Hugging Face hub using semantic search
- Finding candidate models
- Using the huggingface_hub API to download some model metadata
- Semantic search of model cards
- Can we search using model labels?
- Conclusion
The Hugging Face model hub has (at the time of last checking) 60,509 models publicly available. Some of these models are useful as base models for further fine-tuning; these include classics like bert-base-uncased.
The hub also has more obscure indie hits that might already do a good job on your desired downstream task or offer a closer starting point. For example, if one wanted to classify the genre of 18th Century books, it might make sense to start with a model for classifying 19th Century books.
Finding candidate models
Ideally, we'd like a quick way to identify if a model might already do close to what we want. From there, we would likely want to review a bunch of other info about the model before deciding if it might be helpful for us or not.
Unfortunately, finding suitable models on the hub isn't always that easy. Even when we know that models for genre classification exist on the hub, searching for them doesn't return any results.
It's not documented exactly how the search on the hub works, but it seems to be based mainly on the model's name rather than the README or other information. In this blog post, I will continue some previous experiments with embeddings to see if there might be different ways in which we could identify potential models.
This will be a very rough experiment and is more about establishing whether this is an avenue worth exploring rather than a fully fleshed-out approach.
First install some libraries we'll use:
import torch
deps = ["datasets", "sentence-transformers", "rich['jupyter']", "requests"]
# Pick the FAISS build that matches the available hardware
if torch.cuda.is_available():
    deps.append("faiss-gpu")
else:
    deps.append("faiss-cpu")
%%capture
!pip install {" ".join(deps)} --upgrade
!git config --global credential.helper store
These days I almost always have the rich extension loaded!
%load_ext rich
Using the huggingface_hub API to download some model metadata
Our goal is to see if we might be able to find suitable models more efficiently using some form of semantic search (i.e. using embeddings). To do this, we should grab some model data from the hub. The easiest way to do this is using the hub API.
from huggingface_hub import hf_api
import re
from rich import print
api = hf_api.HfApi()
api
We can take a look at some example models
all_models = api.list_models()
all_models[:3]
For a particular model we can also see what files there are.
files = api.list_repo_files(all_models[0].modelId)
files
Filtering
To limit the scope of this blog post, we'll focus only on PyTorch models and 'text classification' models. The metadata about the model type is usually pretty reliable. The model task metadata, on the other hand, is not always reliable in my experience. This means our dataset probably includes some models that aren't text-classification models and misses some that are. For now, we won't worry too much about this.
from huggingface_hub import ModelSearchArguments
model_args = ModelSearchArguments()
from huggingface_hub import ModelFilter
model_filter = ModelFilter(
task=model_args.pipeline_tag.TextClassification,
library=model_args.library.PyTorch
)
api.list_models(filter=model_filter)[0]
Now that we have a filter, we'll use it to grab all the models that match.
all_models = api.list_models(filter=model_filter)
all_models[0]
Let's see how many models that gives us.
len(all_models)
Later on in this blog, we'll want to work with the config.json files (we'll get back to why later!), so we'll quickly check that all our models have this.
def has_config(model):
    # True if any file in the repo has "config.json" in its name
    files = model.siblings
    for file in files:
        if "config.json" in file.rfilename:
            return True
    return False
has_config(all_models[0])
has_config = [model for model in all_models if has_config(model)]
Let's check how many we have now
len(has_config)
We can also download a particular file from the hub
from huggingface_hub import hf_hub_download
file = hf_hub_download(repo_id=all_models[0].modelId, filename="config.json")
file
import json
with open(file) as f:
data = json.load(f)
data
We can also check if the model has a README.md
def has_file_in_repo(model, file_name):
    # True if any file in the repo has `file_name` in its name
    files = model.siblings
    for file in files:
        if file_name in file.rfilename:
            return True
    return False
has_file_in_repo(has_config[0],'README.md')
has_readme = [model for model in has_config if has_file_in_repo(model,"README.md")]
We can see that there are more configs than READMEs
len(has_readme)
len(has_config)
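The two helper functions above do the same job, so they could be collapsed into a single check. Here is a minimal sketch, assuming each model object exposes a `siblings` list whose items have an `rfilename` attribute, as in the code above. The name `repo_has_file` and the `SimpleNamespace` objects are hypothetical stand-ins used for illustration, not real hub metadata. Note that this variant uses an exact filename match rather than the substring match used above.

```python
from types import SimpleNamespace


def repo_has_file(model, file_name):
    """Return True if any file in the repo is named exactly `file_name`."""
    # Guard against a missing or None `siblings` attribute by using an empty list
    siblings = getattr(model, "siblings", None) or []
    return any(sibling.rfilename == file_name for sibling in siblings)


# Hypothetical stand-in for a model object returned by the hub API
fake_model = SimpleNamespace(
    modelId="user/demo-model",
    siblings=[
        SimpleNamespace(rfilename="config.json"),
        SimpleNamespace(rfilename="pytorch_model.bin"),
    ],
)

print(repo_has_file(fake_model, "config.json"))
print(repo_has_file(fake_model, "README.md"))
```

An exact match avoids counting files such as `tokenizer_config.json` as a `config.json`, which the substring check would do.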
We now write some functions to grab both the README.md and config.json files from the hub.
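from functools import lru_cache
import concurrent.futures
import requests
from huggingface_hub import hf_hub_url
from requests.exceptions import JSONDecodeError
@lru_cache(maxsize=None)
def get_model_labels(model):
    # Returns (modelId, labels); labels is None if the config can't be parsed or has no label2id
    try:
        url = hf_hub_url(repo_id=model.modelId, filename="config.json")
        return model.modelId, list(requests.get(url).json()['label2id'].keys())
    except (KeyError, JSONDecodeError, AttributeError):
        return model.modelId, None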
get_model_labels(has_config[0])
def get_model_readme(model):
    url = hf_hub_url(repo_id=model.modelId, filename="README.md")
    return requests.get(url).text
def get_data(model):
    readme = get_model_readme(model)
    _, labels = get_model_labels(model)
    return model.modelId, labels, readme
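To make the label extraction easier to follow without hitting the network, here is a small offline sketch of the same idea: parse a config.json string and pull out the keys of `label2id`, returning `None` when that isn't possible. The function name `labels_from_config` and the example config strings are made up for illustration.

```python
import json


def labels_from_config(config_text):
    """Return the label names from a config.json string, or None if unavailable."""
    try:
        config = json.loads(config_text)
        return list(config["label2id"].keys())
    except (json.JSONDecodeError, KeyError, AttributeError, TypeError):
        # Not valid JSON, no label2id entry, or label2id is not a mapping
        return None


# Hypothetical config contents
example_config = json.dumps(
    {"model_type": "bert", "label2id": {"negative": 0, "positive": 1}}
)

print(labels_from_config(example_config))
print(labels_from_config('{"model_type": "bert"}'))
print(labels_from_config("<html>not json</html>"))
```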
Since this takes a little while, we make a progress bar and do this using multiple threads.
from tqdm.auto import tqdm
with tqdm(total=len(has_config)) as progress:
    with concurrent.futures.ThreadPoolExecutor() as e:
        tasks = []
        for model in has_config:
            future = e.submit(get_data, model)
            # Advance the progress bar each time a download finishes
            future.add_done_callback(lambda p: progress.update())
            tasks.append(future)
results = [task.result() for task in tasks]
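If the progress bar isn't needed, the same fan-out can be written more compactly with `executor.map`, which returns results in the order of the inputs. This is only a sketch: `fake_fetch` is a hypothetical stand-in for the real download function, and the model ids are invented.

```python
import concurrent.futures
import time


def fake_fetch(model_id):
    """Hypothetical stand-in for a slow network request."""
    time.sleep(0.01)
    return model_id, len(model_id)


model_ids = ["user/model-a", "user/model-b", "user/model-c"]

with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    # map() yields results in input order, even if the tasks finish out of order
    results = list(executor.map(fake_fetch, model_ids))

print(results)
```

One thing to keep in mind with either version: an exception raised inside a worker only surfaces when its result is read, so a single failed request can stop the collection step unless the worker function catches its own errors.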
Load our data using Pandas.
import pandas as pd
df = pd.DataFrame(results,columns=['modelId','label','readme'])
df
You can see we now have a DataFrame containing the modelId, the model labels and the README.md for each model (where it exists).
Since the README.md (the model card) is the obvious source of information about a model, we'll start here. One question we may have is how long the README.md is. Some models have very detailed model cards whilst others have very little information in the model card. We can get a bit of a sense of this by looking at the range of README.md lengths:
df['readme'].apply(len).describe()
We might want to filter on the length of the README so we'll store that info in a new column.
df['readme_len'] = df['readme'].apply(len)
Since we might want to work with this data again, let's load it into a datasets Dataset and use push_to_hub to store a copy.
from datasets import Dataset
ds = Dataset.from_pandas(df)
ds
from huggingface_hub import notebook_login
notebook_login()
ds.push_to_hub('davanstrien/hf_model_metadata')
We can now load it again using load_dataset.
from datasets import load_dataset
ds = load_dataset('davanstrien/hf_model_metadata', split='train')
Clean up some memory...
del df
Semantic search of model cards
We now get to the main point of all of this. Can we use semantic search to try and find models of interest? For this, we'll use the sentence-transformers library. This blog won't cover all the background of this library. The docs give a helpful overview and some tutorials.
To start, we'll see if we can search using the information in the README.md. This should, in theory, contain data that might be similar to the kinds of things we want to search for when finding candidate models. We might prefer to use semantic search over an exact match because the terms we use might be different, or there is a related concept/model that might be close enough to make it worthwhile for fine-tuning.
First, we import the SentenceTransformer class and some util functions.
from sentence_transformers import SentenceTransformer, util
We'll now download an embedding model. There are many we could choose from but since we're just trying things out at the moment we won't stress about the particular model we use here.
model = SentenceTransformer('all-MiniLM-L6-v2')
Let's start with the longer READMEs. By 'longer' I just mean READMEs that aren't extremely short; here, that means more than 100 characters.
ds_longer_readmes = ds.filter(lambda x: x['readme_len']>100)
We now create embeddings for the readme column and store these in a new embedding column.
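def encode_readme(readme):
    # Use the GPU when available, otherwise fall back to the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.encode(readme, device=device)
ds_with_embeddings = ds_longer_readmes.map(lambda example:
    {"embedding": encode_readme(example['readme'])}, batched=True, batch_size=16)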
ds_with_embeddings
We can now use the add_faiss_index method to create an index which allows us to efficiently query these embeddings.
ds_with_embeddings.add_faiss_index(column='embedding')
query_readme = ds_with_embeddings[35]['readme']
print(query_readme)
We pass this README into the model we used to create our embedding. This creates a query embedding for this README.
q = model.encode(query_readme)
We can use get_nearest_examples to look for the most similar results to this query.
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embedding', q, k=10)
Let's take a look at the first result
print(retrieved_examples['modelId'][0])
print(retrieved_examples["readme"][0])
and a lower similarity result
print(retrieved_examples["readme"][9])
The results seem pretty reasonable; the first result appears to be a duplicate. The lower result is for a slightly different task using social media data.
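To build some intuition for what the nearest-neighbour lookup is doing, here is a brute-force sketch in plain Python. It assumes the index is a flat (exact) index using L2 distance, which is my understanding of what add_faiss_index builds when no other options are passed; under that assumption the returned scores are distances, so lower means more similar. The 2-d vectors are toy values invented for illustration, not real sentence embeddings, and `nearest_examples` is a hypothetical helper, not part of any library.

```python
import heapq


def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_examples(vectors, query, k=2):
    """Return the k (distance, index) pairs closest to the query."""
    scored = ((squared_l2(vec, query), idx) for idx, vec in enumerate(vectors))
    return heapq.nsmallest(k, scored)


# Toy 2-d "embeddings"
toy_vectors = [[0.0, 0.0], [1.0, 1.0], [0.9, 1.1], [5.0, 5.0]]
toy_query = [1.0, 1.0]

for distance, idx in nearest_examples(toy_vectors, toy_query, k=2):
    print(idx, round(distance, 3))
```

A real index does the same comparison far more efficiently, but the ranking logic is the same: the exact duplicate of the query comes back first, followed by its closest neighbour.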
q = model.encode("fake news")
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embedding', q, k=10)
print(retrieved_examples["readme"][0])
print(retrieved_examples["readme"][1])
print(retrieved_examples["readme"][2])
Not a bad start. Let's try another one
q = model.encode("financial sentiment")
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embedding', q, k=10)
print(retrieved_examples["readme"][0])
print(retrieved_examples["readme"][1])
print(retrieved_examples["readme"][9])
These seem like a good starting point. However, we have a few issues with relying on model cards alone. Firstly, a lot of models don't include them, and the quality of those that exist can be mixed. It's debatable whether we'd want to use a model that has no model card at all, but even with a good model card, the README may not capture everything we'd need for searching.
Can we search using model labels?
We're only working with classification models in this case. For most PyTorch models on the hub, we have a config file. This config usually contains the model's labels. For example, 'positive', 'negative'.
Maybe instead of relying only on the metadata, we can search 'inside' the model. The labels will often be a helpful reflection of what we're looking for. For example, we want to find a sentiment classification model that roughly puts text into positive or negative sentiment. Again, relying on exact label matches may not work well, but maybe embeddings get around this problem. Let's try it out!
Let's look at an example label.
ds[0]['label']
Since we're expecting labels to match this format, let's filter out any that don't fit this structure.
ds = ds.filter(lambda example: isinstance(example['label'],list))
How to create embeddings for our labels?
How should we encode our labels? At the moment, we have a list of labels. One option would be to create an embedding for every single label, which would require us to query multiple embeddings to check for a match. Intuitively, we may prefer to have a single embedding for the combination of labels, because we probably learn more about the model type from all its labels together than from looking at one label at a time. We'll deal with the labels very crudely by joining them on , and creating a single string out of all the labels. I'm sure this isn't the best possible approach, but it might be a good place to start testing this idea.
ds = ds.map(lambda example: {"string_label": ",".join(example['label'])})
ds
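For comparison, the per-label alternative mentioned above could be sketched as follows: keep one vector per label and rank each model by its single closest label. This is not what the post does; it is a toy illustration of the trade-off. The 2-d vectors, the model ids and the helper name `rank_by_best_label` are all invented for the example.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def rank_by_best_label(models, query_vector):
    """Rank model ids by the distance of their closest label to the query."""
    best = {
        model_id: min(squared_l2(vec, query_vector) for vec in label_vectors)
        for model_id, label_vectors in models.items()
    }
    return sorted(best, key=best.get)


# Toy 2-d "embeddings", one per label, invented for illustration
toy_models = {
    "user/sentiment": [[1.0, 0.0], [-1.0, 0.0]],
    "user/topics": [[0.0, 1.0], [0.2, 0.9], [0.1, 0.8]],
}
toy_query = [0.9, 0.1]

print(rank_by_best_label(toy_models, toy_query))
```

The cost is that every label becomes its own entry to store and compare, and a model is judged by one label at a time rather than by its whole label set, which is the reason given above for preferring a single combined embedding.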
ds_with_embeddings = ds.map(lambda example:
    {"label_embedding": encode_readme(example['string_label'])}, batched=True, batch_size=16)
ds_with_embeddings
ds_with_embeddings[0]['string_label']
q = model.encode("negative")
ds_with_embeddings.add_faiss_index(column='label_embedding')
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('label_embedding', q, k=10)
retrieved_examples['label'][:10]
So far, these results look pretty good, although we haven't done anything we couldn't do with simple string matching. Let's see what happens if we use a slightly more abstract search.
q = model.encode("music")
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('label_embedding', q, k=10)
retrieved_examples['label'][:10]
We can see that we get back labels related to music genre: ['Dance', 'Heavy Metal', 'Hip Hop', 'Indie', 'Pop', 'Rock'], for our first four results. After that, we get back ['business', 'entertainment', 'sports'], which might not be too far off what we want if we searched for music.
How about another search term
q = model.encode("hateful")
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('label_embedding', q, k=10)
retrieved_examples['label'][:10]
Again here we have something quite close to what we'd get with string matching, but we have a bit more flexibility in how we spell/define our labels which might help surface more possible results.
We'll try a bunch more things...
def query_labels(query: str):
    q = model.encode(query)
    scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('label_embedding', q, k=10)
    print(f"results for: {query}")
    print(list(zip(retrieved_examples['label'][:10], retrieved_examples['modelId'][:10])))
query_labels("politics")
query_labels("fiction, non_fiction")
Let's try the set of emotions one should feel every day.
query_labels("worry, disgust, anxiety, fear")
This example of searching for a set of labels might be a better approach in general, since the query will more closely match the format of the comma-joined label strings we embedded.
Conclusion
It seems like there is some merit in exploring some of these ideas further. There are a lot of improvements that could be made:
- how the embeddings are created
- removing some 'noise' from the README, for example, by first parsing the Markdown
- improving how the embeddings are created for the labels
- combining the embeddings in some way, either upfront or when querying
- a bunch of other things...
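As a rough illustration of the README clean-up idea from the list above, here is a regex-based sketch that strips a leading YAML metadata block, fenced code blocks, link targets and some Markdown markers before embedding. It is not a real Markdown parser, and the function name `strip_readme_noise`, the sample text and the example.com URL are all made up for the example.

```python
import re

FENCE = "`" * 3  # three backticks, built this way to keep the example readable


def strip_readme_noise(readme):
    """Crudely reduce a Markdown README to plain prose."""
    # Drop a leading YAML metadata block delimited by '---' lines
    text = re.sub(r"\A---\n.*?\n---\n", "", readme, flags=re.DOTALL)
    # Drop fenced code blocks
    text = re.sub(FENCE + r".*?" + FENCE, "", text, flags=re.DOTALL)
    # Replace [text](url) links and ![alt](url) images with just the text
    text = re.sub(r"!?\[([^\]]+)\]\([^)]+\)", r"\1", text)
    # Strip heading markers at the start of lines
    text = re.sub(r"^#+\s*", "", text, flags=re.MULTILINE)
    # Strip emphasis asterisks and inline-code backticks
    text = re.sub(r"[*`]", "", text)
    # Collapse runs of whitespace into single spaces
    return re.sub(r"\s+", " ", text).strip()


sample = (
    "---\nlicense: mit\ntags:\n- text-classification\n---\n"
    "# Genre classifier\n\n"
    "A **fine-tuned** model, see [the docs](https://example.com).\n"
)

print(strip_readme_noise(sample))
```

Whether this helps retrieval is an open question; a proper Markdown parser would be more robust than these patterns.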
If I find some spare time, I plan to dig into these topics a bit further. This is also a nice excuse to play with one of the new open source embedding databases that have popped up in the last couple of years.