ODE-Specialized Physics-Informed Neural Network (PINN) Solver

NNODE(chain, opt = OptimizationPolyalgorithms.PolyOpt(), init_params = nothing;
      autodiff = false, batch = 0, strategy = nothing, kwargs...)

Algorithm for solving ordinary differential equations (ODEs) with a neural network. It is a specialization of the physics-informed neural network (PINN) approach, used as a solver for a standard ODEProblem.
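
As a rough illustration of the idea, the sketch below builds a trial solution of the form u(t) = u0 + (t - t0) * NN(t), which satisfies the initial condition for any network parameters, and scores it by the mean squared ODE residual at a set of time points. This follows the construction of Lagaris, Likas, and Fotiadis (1998); that NNODE uses exactly this form internally is an assumption here. The network type `TinyNet`, its fixed weights, and the helper names `trial` and `residual_loss` are hypothetical, and no training is performed.

```julia
# Sketch of a Lagaris-style trial solution and ODE residual loss.
# Hypothetical: TinyNet, its weights, and all helper names below.

sigmoid(x) = 1 / (1 + exp(-x))

# A hand-written 1 -> 5 -> 1 network with fixed (untrained) parameters
struct TinyNet
    W1::Vector{Float64}
    b1::Vector{Float64}
    W2::Vector{Float64}
    b2::Float64
end
(net::TinyNet)(t::Real) = sum(net.W2 .* sigmoid.(net.W1 .* t .+ net.b1)) + net.b2

# Trial solution: equals u0 at t = t0 whatever the network outputs
trial(net, u0, t0, t) = u0 + (t - t0) * net(t)

# Mean squared residual of du/dt - f(u, p, t) over the points ts,
# with du/dt approximated by a forward difference of step h
function residual_loss(net, f, u0, t0, p, ts; h = 1e-6)
    total = 0.0
    for t in ts
        u = trial(net, u0, t0, t)
        dudt = (trial(net, u0, t0, t + h) - u) / h
        total += (dudt - f(u, p, t))^2
    end
    return total / length(ts)
end

f(u, p, t) = cos(2pi * t)
model = TinyNet([0.5, -1.0, 0.3, 2.0, -0.7], [0.1, 0.0, -0.2, 0.4, 0.3],
                [0.2, -0.4, 0.6, 0.1, -0.3], 0.05)
loss = residual_loss(model, f, 0.0, 0.0, nothing, 0.0:0.05:1.0)
```

Training would adjust the network weights to drive `loss` toward zero; in NNODE that optimization step is delegated to Optimization.jl.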


Note that NNODE only supports ODEs written in the out-of-place form, i.e. du = f(u,p,t), and not the in-place form f(du,u,p,t). If the problem is not declared out-of-place, NNODE exits with an error.
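
The two forms differ only in their calling convention. In the sketch below (the function names are illustrative), `f_oop` returns the derivative, which is the form NNODE accepts, while `f_iip!` writes it into a preallocated `du`, which NNODE rejects.

```julia
# Out-of-place form, accepted by NNODE: the derivative is returned
f_oop(u, p, t) = cos(2pi * t)

# In-place form, rejected by NNODE: the derivative is written into du
function f_iip!(du, u, p, t)
    du[1] = cos(2pi * t)
    return nothing
end

du = zeros(1)
f_iip!(du, [0.0], nothing, 0.25)   # fills du in place
value = f_oop(0.0, nothing, 0.25)  # returns the derivative
```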

Positional Arguments

  • chain: A neural network architecture, defined as either a Flux.Chain or a Lux.AbstractExplicitLayer.
  • opt: The optimizer used to train the neural network. Defaults to OptimizationPolyalgorithms.PolyOpt().
  • init_params: The initial parameters of the neural network. Defaults to nothing, in which case the random initialization provided by the neural network library is used.

Keyword Arguments

  • autodiff: Switches between automatic and numerical differentiation for the derivative that appears inside the loss function, i.e. the derivative of the network output with respect to time. The loss function itself is always differentiated with reverse-mode automatic differentiation (via Zygote); this option affects only the time derivative.
  • batch: The batch size used for the internal quadrature. Defaults to 0, which means the neural network is applied to one time point at a time. With batch > 0, the neural network is applied to a row vector of time values t simultaneously, i.e. batch is the batch size for the neural network evaluations. This requires a neural network that accepts batched input.
  • strategy: The training strategy used to choose the time points at which the loss is evaluated. The default, nothing, uses QuadratureTraining with QuadGK if no dt is given, and GridTraining with spacing dt if dt is given.
  • kwargs: Extra keyword arguments are splatted to the Optimization.jl solve call.
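
To make the autodiff switch concrete, the sketch below differentiates a stand-in for the network output, g(t), with respect to time in two ways: forward-mode automatic differentiation through a minimal hand-written dual-number type, and a forward finite difference with step sqrt(eps). The `Dual` type, the function `g`, and the step size are illustrative assumptions; NNODE relies on its own AD backend and its own finite-difference scheme, which may differ. Automatic differentiation recovers the derivative up to rounding error, whereas the finite difference also carries truncation and cancellation error.

```julia
# Minimal forward-mode AD with dual numbers, compared with a finite difference.
# Illustrative only: NNODE does not use this Dual type.
struct Dual
    val::Float64  # function value
    der::Float64  # derivative with respect to t
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)
Base.:+(a::Real, b::Dual) = b + a
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(a::Real, b::Dual) = Dual(a * b.val, a * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)

# Hypothetical stand-in for the network output as a function of time
g(t) = 0.5 * sin(3.0 * t) + t * t

# Seed dt/dt = 1 and read off the propagated derivative
autodiff_derivative(g, t) = g(Dual(t, 1.0)).der

# Forward difference with step sqrt(eps)
function finite_diff_derivative(g, t)
    h = sqrt(eps(Float64))
    return (g(t + h) - g(t)) / h
end

t = 0.4
exact = 1.5 * cos(3.0 * t) + 2.0 * t
ad = autodiff_derivative(g, t)
fd = finite_diff_derivative(g, t)
```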
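
With batch > 0 the network receives a 1 × N matrix whose columns are individual time points, so it must be written to handle matrix input. The sketch below uses a hypothetical hand-written 1 → 5 → 1 network with arbitrary weights to check that one batched call gives the same values as N single-point calls; `batched_net` and its weights are illustrative, not part of any library.

```julia
using LinearAlgebra

# Hypothetical 1 -> 5 -> 1 network written to accept a 1 x N matrix of times
sigmoid(x) = 1 / (1 + exp(-x))
W1 = reshape([0.5, -1.0, 0.3, 2.0, -0.7], 5, 1)  # 5 x 1
b1 = [0.1, 0.0, -0.2, 0.4, 0.3]                  # 5
W2 = [0.2 -0.4 0.6 0.1 -0.3]                      # 1 x 5
b2 = 0.05

# Each column of T is one time point; the output is 1 x N
batched_net(T::AbstractMatrix) = W2 * sigmoid.(W1 * T .+ b1) .+ b2

ts = [0.0 0.25 0.5 0.75 1.0]                      # 1 x 5 row vector of times
batched = batched_net(ts)                         # one call for all five points
single = [batched_net(reshape([t], 1, 1))[1] for t in vec(ts)]  # five calls
```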
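
The two default strategies differ in how a residual r(t) becomes a scalar loss: a grid strategy averages r(t)^2 over points spaced dt apart, while a quadrature strategy integrates r(t)^2 over tspan. The sketch below contrasts the two on a made-up residual; composite Simpson's rule stands in for the adaptive QuadGK quadrature, and `residual`, `grid_loss`, and `simpson_loss` are hypothetical names, not NeuralPDE functions.

```julia
# Hypothetical residual of a partially trained network
residual(t) = sin(2pi * t) - t

# Grid-style loss: mean of squared residuals at points spaced dt apart
function grid_loss(r, tspan, dt)
    ts = tspan[1]:dt:tspan[2]
    return sum(t -> r(t)^2, ts) / length(ts)
end

# Quadrature-style loss: integral of r(t)^2 over tspan (composite Simpson)
function simpson_loss(r, tspan; n = 200)  # n must be even
    a, b = tspan
    h = (b - a) / n
    s = r(a)^2 + r(b)^2
    for i in 1:n-1
        s += (isodd(i) ? 4 : 2) * r(a + i * h)^2  # Simpson weights 4, 2, 4, ...
    end
    return s * h / 3
end

lg = grid_loss(residual, (0.0, 1.0), 1 / 20)
lq = simpson_loss(residual, (0.0, 1.0))
```

For this residual the exact integral is 5/6 + 1/π ≈ 1.152, which the quadrature-style loss reproduces closely.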

Example

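using NeuralPDE, Flux

# du/dt = cos(2πt) with u(0) = 0 on t ∈ [0, 1], in out-of-place form
f(u, p, t) = cos(2pi * t)
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(f, u0, tspan)

# Network mapping t to u(t): 1 input, 5 sigmoid hidden units, 1 output
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
opt = Flux.ADAM(0.1)

# Giving dt = 1/20 selects GridTraining by default (see strategy above)
sol = solve(prob, NeuralPDE.NNODE(chain, opt), dt = 1/20f0, verbose = true,
            abstol = 1e-10, maxiters = 200)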
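
The ODE in this example has a closed-form solution, which gives a reference curve for checking the trained network. Integrating the right-hand side from the initial condition:

```latex
\frac{du}{dt} = \cos(2\pi t), \quad u(0) = 0
\;\Longrightarrow\;
u(t) = \int_0^t \cos(2\pi s)\,ds = \frac{\sin(2\pi t)}{2\pi}, \qquad t \in [0, 1].
```

Comparing sol(t) with sin(2πt)/(2π) on a fine grid of t in [0, 1] then gives a direct measure of how well training converged.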

Solution Notes

Note that the returned solution is saved at fixed time points determined by the standard output handlers, such as saveat and dt. However, the neural network defines a continuous solution, so sol(t) is an accurate interpolation at any t in tspan, up to the accuracy achieved by training. In addition, the underlying OptimizationSolution is returned as sol.k for further analysis.

