Skip to content

Basic usage

pytorch_ndonnx is build around the idea of duck-typing. Existing code that is written to be called with pytorch.Tensor objects is instead called using pytorch_ndonnx.Tensor objects. The latter is a thin wrapper around a ndonnx.Array which in turn keeps track of any operation applied to it. Ultimately, the entire computational graph can be reconstructed from a given output. See the ndonnx documentation for further details.

Example

import ndonnx as ndx
import onnx
import pytorch_ndonnx as tx
import torch.nn as nn
import torch.nn.functional as F


class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)

    def forward(self, x):
        return F.relu(F.max_pool2d(self.conv1(x), 2))

# 'arguments' are inputs to a graph
arr = ndx.argument(shape=("N", 1, 28, 28), dtype=ndx.float32)
# Wrap the resulting ndonnx array in a pytorch_ndonnx.Tensor
tensor = tx.Tensor(arr)
# Use 'tensor' to simply ducktype the existing code
result = MyModule()(tensor)

# Create the onnx graph in between the specified arguments and resulting arrays.
model_proto = ndx.build({"x": arr}, {"y": result.to_ndonnx()})
# Write the model to disk
onnx.save_model(model_proto, "example.onnx")

Calling pytorch functions

Pytorch provides an extension protocol that allows for calling its various free functions with custom tensor objects. pytorch_ndonnx provides support for the functions listed here.