
Using your own models

Adding custom models

If you want to use your own model and/or make it public, you just need to implement a class that inherits from the Custom class in thingsvision.custom_models.custom and implement its create_model method. There you can build or download the model and its weights. The constructor expects a device (str) and kwargs (dict) in which you can pass model parameters. The backend attribute must be set to either pt (PyTorch) or tf (TensorFlow). The create_model method must return the model and an optional preprocessing function. If no preprocessing is set, the default ImageNet preprocessing is used. Afterwards, you can put the file in the custom_models directory and open a pull request to include the model in the official GitHub repository.

from thingsvision.custom_models.custom import Custom
import torchvision.models as torchvision_models
import torch

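class VGG16_ecoset(Custom):
    def __init__(self, device, **kwargs) -> None:
        # let the base class store the device and any model parameters
        super().__init__(device, **kwargs)
        self.backend = 'pt'
        self.preprocess = None  # None falls back to the default ImageNet preprocessing

    def create_model(self):
        model = torchvision_models.vgg16(weights=None, num_classes=565)
        path_to_weights = ''  # URL of the pretrained weights
        state_dict = torch.hub.load_state_dict_from_url(path_to_weights, map_location=self.device)
        # copy the downloaded weights into the model
        model.load_state_dict(state_dict)
        return model, self.preprocess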

Use custom models with the get_extractor_from_model function

Alternatively, you can use the get_extractor_from_model() helper function to directly create an extractor for any PyTorch or TensorFlow model, without the need for creating a custom class. Simply pass the model, the device, the backend (either pt or tf) and optionally a preprocessing function to the get_extractor_from_model function. The function will return an extractor that can be used with the extract method, just like the built-in extractor classes.

from thingsvision import get_extractor_from_model
from torchvision.models import alexnet, AlexNet_Weights
import torch

# initialize a model of your choice, here we use AlexNet from torchvision 
# and load the ImageNet weights
model_weights = AlexNet_Weights.DEFAULT
model = alexnet(weights=model_weights)

# you can also pass a custom preprocessing function that is applied to every 
# image before extraction
transforms = model_weights.transforms()

# provide the backend of the model (either 'pt' or 'tf')
backend = 'pt'

# set the device to 'cuda' if you have a GPU available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

extractor = get_extractor_from_model(
    model=model,
    device=device,
    transforms=transforms,
    backend=backend,
)

Custom forward and flatten functions for PyTorch models

The created extractor passes image batches to the model’s forward function (self.model(batch)) and optionally flattens the activations (act.view(act.size(0), -1)). The forward function of your model (e.g., a CLIP variant) may have a different signature, or its activations may require a different way of flattening. In that case, you can pass custom forward and flatten functions as the forward_fn and flatten_fn arguments of get_extractor_from_model. They are used instead of the default ones.
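To see why a different flattening can be useful, the sketch below compares the default flattening quoted above with a hypothetical alternative that averages over a token dimension. The exact signature expected for flatten_fn is not stated here, so both functions are written as standalone helpers that take only the activation tensor. The function names and tensor shapes are illustrative assumptions.

```python
import torch


def default_flatten(act: torch.Tensor) -> torch.Tensor:
    # Same operation as the default: keep the batch dimension, collapse the rest.
    return act.view(act.size(0), -1)


def mean_pool_tokens(act: torch.Tensor) -> torch.Tensor:
    # Hypothetical alternative for (batch, tokens, dim) activations:
    # average over the token dimension instead of concatenating all tokens.
    return act.mean(dim=1)


conv_act = torch.zeros(2, 64, 7, 7)   # e.g. a convolutional feature map
token_act = torch.zeros(2, 50, 768)   # e.g. transformer token embeddings

flat_conv = default_flatten(conv_act)     # shape (2, 3136)
flat_tokens = default_flatten(token_act)  # shape (2, 38400)
pooled_tokens = mean_pool_tokens(token_act)  # shape (2, 768)
```

Pooling over tokens is only one possible choice. Which reduction is appropriate depends on the model and on how the features are used afterwards.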

from thingsvision import get_extractor_from_model
import torch
import torch.nn as nn

class ComplexForwardModel(nn.Module):
    def __init__(self):
        super().__init__()
        # model definition 
        # ...

    def forward(self, x, y):
        # forward function with custom signature
        # ...
        return x

def custom_forward_fn(self, batch):
    # custom forward function that passes the batch to the model's forward function
    return self.model(batch, y=None)

model = ComplexForwardModel()
backend = 'pt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

extractor = get_extractor_from_model(
    model=model,
    device=device,
    backend=backend,
    forward_fn=custom_forward_fn,
)

Note that the custom forward function needs to have the following signature: forward_fn(self, batch), as it replaces the default forward_fn function of the extractor. The self argument is the extractor instance and the batch argument is the batch of images that is passed to the model’s forward function.
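The role of the self argument can be illustrated with a small pure-Python toy. ToyExtractor below is a hypothetical stand-in and not the thingsvision implementation. It only shows one way a plain function with the signature forward_fn(self, batch) can be bound to an instance so that it replaces a default method and can reach self.model.

```python
import types


class ToyExtractor:
    """Hypothetical stand-in for an extractor; not the thingsvision class."""

    def __init__(self, model, forward_fn=None):
        self.model = model
        if forward_fn is not None:
            # Bind the plain function to this instance so that `self`
            # inside forward_fn refers to the extractor.
            self.forward = types.MethodType(forward_fn, self)

    def forward(self, batch):
        # Default behaviour: pass the batch straight to the model.
        return self.model(batch)


def two_argument_model(x, y):
    # Toy "model" whose call signature needs a second argument.
    return (x, y)


def custom_forward_fn(self, batch):
    # Same shape as the custom forward function above.
    return self.model(batch, y=None)


default_extractor = ToyExtractor(model=lambda batch: batch)
custom_extractor = ToyExtractor(model=two_argument_model, forward_fn=custom_forward_fn)

default_out = default_extractor.forward([1, 2])
custom_out = custom_extractor.forward([1, 2])
```

Because the function is bound to the instance, it can use any attribute of the extractor, which is why the first parameter must be self even though the function is defined outside a class.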