LanceDB integrates with PyTorch for both training and inference. You store your data in LanceDB tables and feed those tables directly to PyTorch data loaders.

Quickstart

The Table class in LanceDB implements the torch.utils.data.Dataset interface, so you can pass a LanceDB table directly to a PyTorch DataLoader.
Python
import lancedb
import torch
import pyarrow as pa

mem_db = lancedb.connect("memory://")
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))

# Any LanceDB table can be used as a PyTorch Dataset
dataloader = torch.utils.data.DataLoader(
    table, batch_size=1024, shuffle=True
)

for batch in dataloader:
    print(batch)
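To make the Dataset contract concrete, here is a minimal sketch of a map-style dataset. It assumes nothing about LanceDB's internals: SquaresDataset is a hypothetical stand-in for a table, not LanceDB code. A DataLoader only needs __len__ and __getitem__ from such an object.

```python
import torch
from torch.utils.data import DataLoader


class SquaresDataset:
    """Hypothetical stand-in for a table: a map-style dataset of n rows."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        # Number of rows; the DataLoader's sampler uses this to generate indices.
        return self.n

    def __getitem__(self, idx: int) -> dict:
        # One row, returned as a dict of column name -> value.
        return {"a": idx, "a_squared": idx * idx}


loader = DataLoader(SquaresDataset(10), batch_size=4)
batches = list(loader)

# The default collate function turns each list of row dicts into a dict of tensors.
for batch in batches:
    print(batch["a"], batch["a_squared"])
```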
Although the Table class in LanceDB implements the torch.utils.data.Dataset interface, you may find that using a table Permutation is more flexible.
Python
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(permutation)

Output Formats

By default, a data loader built on a Table emits a pyarrow.RecordBatch. To convert to a different format (such as a torch.Tensor), you need to provide a custom collate function.

The Permutation class is more flexible. By default, its output is a list of dicts. This is the default output format of standard data loaders and is usually more convenient when you are getting started. However, converting from Arrow, Lance’s internal representation, to this default format carries a significant performance penalty.

To address this, the Permutation class provides a set of built-in transform functions that map the Arrow data into different formats. The arrow and polars formats always avoid data copies. The numpy, pandas, and torch_col formats avoid data copies in most cases. The python, python_col, and torch formats all require at least one full copy of the data and are the slowest options.

Using the torch_col format with a torch data loader

The torch_col format is the most efficient way to convert from Arrow to a torch.Tensor. It will convert the entire Arrow batch to a column-major torch.Tensor. In other words, given C columns and R rows, the resulting Tensor will have shape (C, R). However, this format generates an error if you are using a torch.utils.data.DataLoader with the default collation function:
Python
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
This error occurs because the default collation function does not expect a single two-dimensional tensor. It expects a list of tensors, which it then stacks. The torch format produces such a list, but it requires a data copy. To avoid both the error and the copy, provide a custom collation function in addition to specifying the torch_col format.
Python
from lancedb.permutation import Permutation

permutation = Permutation.identity(table).with_format("torch_col")
dataloader = torch.utils.data.DataLoader(permutation, collate_fn=lambda x: x)
This will now output a single two-dimensional tensor for each batch.
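The failure and the fix can be reproduced without LanceDB. In this sketch a plain two-dimensional tensor stands in for one torch_col batch (2 columns, 3 rows), and identity_collate is a hypothetical name. A named top-level function is used instead of a lambda because, unlike a lambda, it can be pickled if the loader later runs with spawn-based workers.

```python
import torch
from torch.utils.data import default_collate

# Stand-in for one torch_col batch: C=2 columns, R=3 rows, shape (C, R).
batch = torch.arange(6).reshape(2, 3)

# The default collate function tries to torch.stack() the batch as if it were
# a sequence of per-row tensors, which fails for a single Tensor.
try:
    default_collate(batch)
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)


def identity_collate(batch):
    """Return the batch untouched so the (C, R) tensor passes through as-is."""
    return batch


collated = identity_collate(batch)
print(collated.shape)

# If a model expects one row per sample, transpose to (R, C).
# .T returns a view and does not copy the data.
row_major = collated.T
print(row_major.shape)
```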

Selecting columns

By default, a Table returns all of its columns when used as input to PyTorch. If you only need a subset, you can significantly reduce I/O by selecting only the columns you need. The Permutation class provides a select_columns method for this.
Python
from lancedb.permutation import Permutation

permutation = Permutation.identity(table).select_columns(["id", "prompt"])
dataloader = torch.utils.data.DataLoader(
    permutation, batch_size=1024, shuffle=True
)

for batch in dataloader:
    # Each batch contains only the selected columns
    print(batch)

Using multiple DataLoader workers

PyTorch’s DataLoader can fan out reads across worker processes by setting num_workers > 0. LanceDB tables and Permutation objects are picklable, so each worker reopens its own connection after the worker process starts. Because LanceDB is multi-threaded internally, use the spawn start method (not fork) when running with multiple workers. See the performance guide for more on safe multiprocessing patterns.
Python
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(
    permutation,
    batch_size=1024,
    shuffle=True,
    num_workers=4,
    multiprocessing_context="spawn",
    persistent_workers=True,
)
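With the spawn start method, each worker re-imports the main module, so the code that builds and iterates the DataLoader should sit behind an if __name__ == "__main__" guard. The sketch below shows that layout with a small TensorDataset standing in for a Permutation. The main function and the dataset contents are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def main() -> int:
    # Stand-in dataset; in practice this would be a Permutation over a table.
    dataset = TensorDataset(torch.arange(16))
    loader = DataLoader(
        dataset,
        batch_size=4,
        num_workers=2,
        multiprocessing_context="spawn",
        # Keep workers alive between epochs so they are not respawned each time.
        persistent_workers=True,
    )
    total = 0
    for _epoch in range(2):
        for (values,) in loader:
            total += int(values.sum())
    return total


if __name__ == "__main__":
    # Without this guard, each spawned worker would call main() again when it
    # re-imports this module.
    print(main())
```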

Remote tables in DataLoader workers

Tables opened from a remote LanceDB Enterprise connection (db://...) also work with multi-worker DataLoaders. The connection details needed to reopen the table — db_url, api_key, region, host_override, and the serializable parts of client_config — travel with the pickled table and are used to rebuild the connection in each worker.
Python
import lancedb
from lancedb.permutation import Permutation

db = lancedb.connect(
    "db://my-database",
    api_key="sk-...",
    region="us-east-1",
)
table = db.open_table("my_table")

permutation = Permutation.identity(table).select_columns(["id", "image"])
dataloader = torch.utils.data.DataLoader(
    permutation,
    batch_size=512,
    num_workers=4,
    multiprocessing_context="spawn",
)
This embeds the API key in the pickle sent to each worker. If you’d rather load credentials inside the worker — for example, from an environment variable or a secret manager — use the connection factory escape hatch described below. A factory is also required when your client_config uses a non-serializable header_provider.
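To see why the API key ends up in the pickle, consider this stdlib-only sketch. The dictionaries are hypothetical stand-ins for the connection details a table carries, not LanceDB's actual serialization format. Pickling the key itself places it in the payload, while pickling only the name of an environment variable does not.

```python
import pickle

# Hypothetical connection details that travel with a pickled table.
embedded = {
    "db_url": "db://my-database",
    "api_key": "sk-example-not-a-real-key",
    "region": "us-east-1",
}
payload = pickle.dumps(embedded)
# Check whether the secret is stored verbatim in the bytes sent to each worker.
key_in_payload = b"sk-example-not-a-real-key" in payload
print(key_in_payload)

# Alternative: ship only the name of the environment variable and let each
# worker look the key up after it starts.
deferred = {
    "db_url": "db://my-database",
    "api_key_env": "LANCEDB_API_KEY",
    "region": "us-east-1",
}
deferred_payload = pickle.dumps(deferred)
key_in_deferred = b"sk-example-not-a-real-key" in deferred_payload
print(key_in_deferred)
```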

Providing a custom connection factory

Permutation.with_connection_factory lets you control how each worker reopens the base table. The factory takes the base table name and returns a LanceDB table. It must be picklable, which in practice means a top-level function, a functools.partial of one, or an instance of a picklable class with __call__ — lambdas and closures over local variables will not work.
Python
import os
import lancedb
from lancedb.permutation import Permutation

def open_table(name: str):
    db = lancedb.connect(
        "db://my-database",
        api_key=os.environ["LANCEDB_API_KEY"],
        region="us-east-1",
    )
    return db.open_table(name)

permutation = (
    Permutation.identity(table)
    .with_connection_factory(open_table)
)
dataloader = torch.utils.data.DataLoader(
    permutation,
    batch_size=512,
    num_workers=4,
    multiprocessing_context="spawn",
)
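Whether a factory is picklable can be checked up front with the standard pickle module. In the sketch below, open_table_in, TableOpener, and is_picklable are hypothetical names, and lancedb is imported inside the factory so that it is only needed when a worker actually calls it. The connect arguments mirror the example above.

```python
import functools
import os
import pickle


def is_picklable(obj) -> bool:
    """Return True if obj survives pickle.dumps, as a spawn worker requires."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


def open_table_in(db_url: str, region: str, name: str):
    """Top-level function: pickled by reference to its module and name."""
    import lancedb  # deferred so the import happens inside the worker

    db = lancedb.connect(
        db_url, api_key=os.environ["LANCEDB_API_KEY"], region=region
    )
    return db.open_table(name)


class TableOpener:
    """Picklable callable: its state is two plain strings."""

    def __init__(self, db_url: str, region: str):
        self.db_url = db_url
        self.region = region

    def __call__(self, name: str):
        return open_table_in(self.db_url, self.region, name)


# functools.partial fixes the leading arguments, leaving only the table name.
partial_factory = functools.partial(open_table_in, "db://my-database", "us-east-1")
class_factory = TableOpener("db://my-database", "us-east-1")
lambda_factory = lambda name: open_table_in("db://my-database", "us-east-1", name)

for label, factory in [
    ("partial", partial_factory),
    ("class instance", class_factory),
    ("lambda", lambda_factory),
]:
    print(label, is_picklable(factory))
```

A factory that passes this check can then be given to Permutation.with_connection_factory as shown above.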