Replicate versions all of the models you train and stores them on Amazon S3 or Google Cloud Storage, so you can pull them down into your production inference systems.
Using the Replicate Python API, you can load a model directly from within your inference script. For example, if you did this in your training script:
```python
import torch
import replicate

def train():
    experiment = replicate.init(path=".", params={...})
    for epoch in range(num_epochs):
        # ...
        torch.save(model, "model.pth")
        experiment.checkpoint(
            path="model.pth",
            metrics={"loss": loss},
            primary_metric=("loss", "minimize"),
        )
```
Then you can use this in your inference script to get the model back:
```python
import torch
import replicate

experiment = replicate.experiments.get("e510303")
checkpoint = experiment.best()
model = torch.load(checkpoint.open("model.pth"))
```
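To put this in context, here is a minimal sketch of how the loaded model might be wired into an inference entry point. The `predict` function and the eval-mode handling are illustrative assumptions, not part of Replicate's API:

```python
import torch
import replicate

# Load the best checkpoint from the experiment, as shown above
experiment = replicate.experiments.get("e510303")
checkpoint = experiment.best()
model = torch.load(checkpoint.open("model.pth"))
model.eval()  # assumption: the saved object is an nn.Module still in training mode

def predict(inputs: torch.Tensor) -> torch.Tensor:
    # Illustrative handler: run a forward pass without tracking gradients
    with torch.no_grad():
        return model(inputs)
```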
You can also get files using the command-line interface. This might be useful if you want the model weights on disk, or if you're building a Docker image with the weights inside.
For example, if you run this for the example training script above:
```shell
replicate checkout e510303 -o weights/
```
Then the model weights will be written to `weights/model.pth`.
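Once the weights are on disk, your production code no longer needs the Replicate client at all. A minimal sketch, assuming the checked-out path above and a model saved with `torch.save`:

```python
import torch

# Load the checked-out weights directly from disk; no Replicate client needed
model = torch.load("weights/model.pth")
```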
Note: Either an experiment ID or a checkpoint ID can be passed to `replicate checkout`. A checkpoint ID is the better versioning identifier, because it points at one specific version of your model weights.
Currently, the Python API only accepts experiment IDs; support for checkpoint IDs is being worked on. See this GitHub issue for more details.
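In the meantime, you can usually reach a specific checkpoint through its experiment. A sketch, assuming the experiment object exposes a `checkpoints` list whose items carry an `id` attribute (check this against the current API), with `"abc123"` as a hypothetical checkpoint ID prefix:

```python
import torch
import replicate

experiment = replicate.experiments.get("e510303")

# Assumption: `experiment.checkpoints` is a list of checkpoint objects with
# an `id` attribute; "abc123" stands in for a real checkpoint ID prefix.
checkpoint = next(c for c in experiment.checkpoints if c.id.startswith("abc123"))
model = torch.load(checkpoint.open("model.pth"))
```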