How to take advantage of an optimizer for your non-PyTorch project
The goal
PyTorch provides optimizers. Their job is to update your parameters (e.g. weights) using a gradient. How can we use one for our own simulation, without using the rest of PyTorch?
Questions to David Rotermund
As a first step we build a dummy torch layer from an nn.Module. It contains one registered parameter (here, the weight).
Dummy torch layer
import torch
# Dummy layer: an empty nn.Module that only holds our parameter.
my_module: torch.nn.Module = torch.nn.Module()
my_parameter: torch.nn.Parameter = torch.nn.Parameter(
    torch.ones((10, 5)), requires_grad=True
)
my_module.register_parameter("w", my_parameter)
if my_module._parameters["w"] is None:
    raise Exception("Parameter w is missing.")
else:
    # Placeholder gradient with the same shape and dtype as the parameter.
    my_gradient = torch.zeros_like(my_module._parameters["w"])
The parameter consists of data
my_module._parameters["w"].data
and the gradient
my_module._parameters["w"].grad
We can access the gradient and the data of the parameter via .grad and .data. It is important that both are torch.Tensor objects and have the same dtype (the dtype probably must not change, but I didn’t test that).
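As a minimal sketch of this interface, assume the non-PyTorch simulation keeps its weights and gradient as NumPy arrays. The names external_weights and external_gradient are hypothetical and not part of the original code. NumPy defaults to float64, so the arrays are converted to the parameter's dtype before they are written:

```python
import numpy as np
import torch

# Dummy layer with one registered parameter, as in the section above.
my_module = torch.nn.Module()
my_module.register_parameter("w", torch.nn.Parameter(torch.ones((10, 5))))
w = my_module._parameters["w"]

# Hypothetical data coming from the external simulation (float64 in NumPy).
external_weights = np.random.rand(10, 5)
external_gradient = np.random.rand(10, 5)

# Convert to the parameter's dtype so that .data and .grad match.
w.data = torch.tensor(external_weights, dtype=w.dtype)
w.grad = torch.tensor(external_gradient, dtype=w.dtype)

print(w.data.dtype, w.grad.dtype)
```

The parameter object itself stays the same; only its .data and .grad are exchanged.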
The optimizer
Now we can connect the parameters to an optimizer. Here I will use Adam:
optimizer = torch.optim.Adam(my_module.parameters())
To use the optimizer, we write our current gradient and the old weights into the parameter’s .grad and .data and call
optimizer.step()
Now we can read the optimized weights from .data and put them into our non-PyTorch network. In addition, we can add an lr scheduler.
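The full cycle (weights and gradient in, step, weights out) could look roughly like the sketch below. The function external_gradient is a hypothetical stand-in for whatever the simulation computes, and ExponentialLR with lr=0.1 and gamma=0.9 is only one possible scheduler setup, chosen here for illustration:

```python
import torch

# Dummy layer holding the weights of the external network.
my_module = torch.nn.Module()
my_module.register_parameter("w", torch.nn.Parameter(torch.ones((10, 5))))
w = my_module._parameters["w"]

optimizer = torch.optim.Adam(my_module.parameters(), lr=0.1)
# ExponentialLR is an arbitrary choice; any lr scheduler is attached the same way.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


def external_gradient(weights: torch.Tensor) -> torch.Tensor:
    # Hypothetical simulation: gradient of 0.5 * sum(weights ** 2).
    return weights.clone()


weights = torch.ones((10, 5))  # weights as stored by the external network
for _ in range(20):
    w.data = weights                     # old weights in
    w.grad = external_gradient(weights)  # current gradient in
    optimizer.step()                     # updates w.data in place
    lr_scheduler.step()                  # adjusts the learning rate
    weights = w.data.clone()             # optimized weights out

print(weights[0, 0].item(), optimizer.param_groups[0]["lr"])
```

The same parameter object w is reused in every iteration, so the optimizer keeps its internal state (for Adam, the running averages) across steps.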
Notes:
- Don’t replace my_module._parameters["w"] with a new parameter object. The optimizer would then lose its connection to the parameter. Only use .data and .grad!
- Don’t reuse this for different parameters. Say you have two layers: then you need to create TWO fake layers and TWO optimizers (and TWO lr_schedulers). An optimizer has internal memory, so it should only deal with its own parameter.
- You need to tell the optimizer whether you want old parameter - update (default) or old parameter + update (maximize=True).
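The last two notes can be illustrated with a small sketch. The helper make_dummy_layer is hypothetical and only wraps the steps shown earlier; it creates one dummy layer with its own Adam optimizer, and it is called twice so that each parameter has a separate optimizer. The maximize argument of torch.optim.Adam requires a reasonably recent PyTorch version:

```python
import torch


def make_dummy_layer(maximize: bool):
    # Hypothetical helper: one dummy layer plus its own Adam optimizer.
    module = torch.nn.Module()
    module.register_parameter("w", torch.nn.Parameter(torch.ones((10, 5))))
    optimizer = torch.optim.Adam(module.parameters(), lr=0.1, maximize=maximize)
    return module._parameters["w"], optimizer


# Two parameters, two optimizers: one minimizes, one maximizes.
w_min, opt_min = make_dummy_layer(maximize=False)
w_max, opt_max = make_dummy_layer(maximize=True)

# The same positive gradient for both parameters.
w_min.grad = torch.ones_like(w_min.data)
w_max.grad = torch.ones_like(w_max.data)
opt_min.step()
opt_max.step()

print(w_min.data[0, 0].item())  # below 1.0: old parameter - update
print(w_max.data[0, 0].item())  # above 1.0: old parameter + update
```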
The source code is Open Source and can be found on GitHub.