#!/usr/bin/env python3

r'''Convert a kalibr 1-parameter distortion model to something more standard

SYNOPSIS

  $ analyses/mrcal-convert-lensmodel-from-kalibr-fov \
      --num-trials 10                                \
      --radius 0                                     \
      --viz LENSMODEL_OPENCV8                        \
      640 480 251.1 249.4 325.4 238.1 0.91

  RMS error of this solution: 0.038970851989073026 pixels.
  RMS error of this solution: 0.312567345983669 pixels.
  RMS error of this solution: 0.03897085198907192 pixels.
  RMS error of this solution: 0.31102107617891206 pixels.
  RMS error of this solution: 0.03897085198907148 pixels.
  RMS error of this solution: 0.31024552395361005 pixels.
  RMS error of this solution: 0.038970851989072686 pixels.
  RMS error of this solution: 0.03897085198907228 pixels.
  RMS error of this solution: 0.31256734598368013 pixels.
  RMS error of this solution: 232.49658418723286 pixels.
  RMS error of the BEST solution: 0.03897085198907148 pixels.
  # generated on 2025-11-10 23:53:53 with   ./mrcal-convert-lensmodel-from-kalibr-fov --num-trials 10 --radius 0 --viz LENSMODEL_OPENCV8 640 480 251.1 249.4 325.4 238.1 0.91
  {
      'lensmodel':  'LENSMODEL_OPENCV8',

      # intrinsics are fx,fy,cx,cy,distortion0,distortion1,....
      'intrinsics': [ 269.7069092, 267.8824348, 325.3977834, 238.1017166, 0.4354014227, 0.02216823609, -1.263764376e-06, 1.317529587e-06, 6.778953407e-05, 0.7476900385, 0.09048943655, 0.001241101467,],

      'rt_cam_ref': [ 0, 0, 0, 0, 0, 0,],
      'extrinsics': [ 0, 0, 0, 0, 0, 0,], # for compatibility with mrcal < 2.5

      'imagersize': [ 640, 480,],
  }

This is a cut-down version of 'mrcal-convert-lensmodel --sampled', specifically
for the kalibr "fov" model. mrcal doesn't support this model, so that tool can't
be used directly. This is a quick hack to interoperate with tools that use this
model.

The kalibr models are documented here:

  https://github.com/ethz-asl/kalibr/wiki/supported-models

See the docs for mrcal-convert-lensmodel for usage details

'''



import sys
import argparse
import re
import os

def parse_args():
    r'''Parse and validate the command line

    Returns the argparse.Namespace. Exits the process on invalid input:

    - --title together with --extratitle: prints an error and exits with
      status 1
    - --num-trials below 1, or a --gridn value below 1: reported through
      parser.error(), which exits with status 2
    '''

    parser = \
        argparse.ArgumentParser(description = __doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('--gridn',
                        type=int,
                        default = (30,20),
                        nargs = 2,
                        help='''How densely we should sample the imager. By default we use
                        a 30x20 grid''')
    parser.add_argument('--where',
                        type=float,
                        nargs=2,
                        help='''I use a subset of the imager to compute the
                        fit. The active region is a circle centered on this
                        point. If omitted, we will focus on the center of the
                        imager''')
    parser.add_argument('--radius',
                        type=float,
                        help='''I use a subset of the imager to compute the
                        fit. The active region is a circle with a radius given
                        by this parameter. If radius == 0, I'll use the whole
                        imager for the fit. If radius < 0, this parameter
                        specifies the width of the region at the corners that I
                        should ignore: I will use sqrt(width^2 + height^2)/2. -
                        abs(radius). This is valid ONLY if we're focusing at the
                        center of the imager. By default I ignore a large-ish
                        area at the corners: I use radius =
                        -min(width,height)/4''')

    parser.add_argument('--viz',
                        action='store_true',
                        help='''Visualize the differences between the input and
                        output models''')
    parser.add_argument('--cbmax',
                        type=float,
                        default=4,
                        help='''Used if --viz. Maximum range of the colorbar''')
    parser.add_argument('--title',
                        type=str,
                        default = None,
                        help='''Used if --viz. Title string for the diff plot.
                        Overrides the default title. Exclusive with
                        --extratitle''')
    parser.add_argument('--extratitle',
                        type=str,
                        default = None,
                        help='''Used if --viz. Additional string for the plot to
                        append to the default title. Exclusive with --title''')
    parser.add_argument('--hardcopy',
                        type=str,
                        help='''Used if --viz. Write the diff output to disk,
                        instead of making an interactive plot''')
    parser.add_argument('--terminal',
                        type=str,
                        help=r'''Used if --viz. gnuplotlib terminal. The default
                        is good almost always, so most people don't need this
                        option''')
    parser.add_argument('--set',
                        type=str,
                        action='append',
                        help='''Used if --viz. Extra 'set' directives to
                        gnuplotlib. Can be given multiple times''')
    parser.add_argument('--unset',
                        type=str,
                        action='append',
                        help='''Used if --viz. Extra 'unset' directives to
                        gnuplotlib. Can be given multiple times''')

    # NOTE(review): --force and --outdir are parsed, but nothing in this tool
    # reads them: the output model is always written to stdout
    parser.add_argument('--force', '-f',
                        action='store_true',
                        default=False,
                        help='''By default existing models on disk are not
                        overwritten. Pass --force to overwrite them without
                        complaint''')
    parser.add_argument('--outdir',
                        type=lambda d: d if os.path.isdir(d) else \
                        parser.error(f"--outdir requires an existing directory as the arg, but got '{d}'"),
                        help='''Directory to write the output into. If omitted,
                        we use the directory of the input model''')
    parser.add_argument('--num-trials',
                        type    = int,
                        default = 1,
                        help='''If given, run the solve more than once. Useful
                        in case random initialization produces noticeably
                        different results. By default we run just one trial''')
    parser.add_argument('to',
                        type=str,
                        help='The target lens model')

    parser.add_argument('w_h_fx_fy_cx_cy_s',
                        type=float,
                        nargs=7,
                        help='''Input camera model: the imager width,height
                        followed by the kalibr "fov" parameters fx,fy,cx,cy,s.
                        mrcal doesn't support this model, so I use these
                        manually''')

    args = parser.parse_args()

    if args.title      is not None and \
       args.extratitle is not None:
        print("Error: --title and --extratitle are exclusive", file=sys.stderr)
        sys.exit(1)

    # Catch these here, with a clear message. Otherwise 0 trials fails later
    # with "No valid intrinsics found!", and an empty grid yields nothing to fit
    if args.num_trials < 1:
        parser.error(f"--num-trials must be >= 1, but got {args.num_trials}")
    if min(args.gridn) < 1:
        parser.error(f"--gridn values must be >= 1, but got {args.gridn}")

    return args

# Parse the commandline before the heavy imports below: --help then works
# without building anything, which is what allows the manpages and the README
# to be generated
args = parse_args()



# Search the directory above this script first, so that the LOCAL mrcal is
# imported instead of any installed copy
sys.path.insert(0, f"{os.path.dirname(os.path.realpath(__file__))}/..")

import numpy as np
import numpysane as nps
import time
import copy
import mrcal



def fov_unproject_normalized( q, fxy, cxy, s ):
    r'''Unproject pixel coordinates through the kalibr "fov" camera model

    The reference implementation is kalibr's FovDistortion undistort():

      https://github.com/ethz-asl/kalibr/blob/1f60227442d25e36365ef5f72cd80b9666d73467/aslam_cv/aslam_cameras/include/aslam/cameras/implementation/FovDistortion.hpp#L89

    kMaxValidAngle is defined here:

      https://github.com/ethz-asl/kalibr/blob/1f60227442d25e36365ef5f72cd80b9666d73467/aslam_cv/aslam_cameras/include/aslam/cameras/FovDistortion.hpp#L146

    ARGUMENTS

    - q: array of shape (..., 2): the pixel coordinates to unproject. Any
      number of leading dimensions is supported, including none

    - fxy: the focal lengths fx,fy. Broadcasts against q

    - cxy: the center pixel cx,cy. Broadcasts against q

    - s: the scalar "fov" distortion parameter

    RETURNED VALUE

    Array of shape (..., 3): unit-length observation vectors. Pixels with
    |r_d*s| beyond kMaxValidAngle don't produce a meaningful ray. Their vectors
    are set to NaN, so callers can drop them with np.isfinite(). If s == 0
    there is no distortion, and this is a plain pinhole unprojection
    '''

    # Largest valid |r_d*s|, in radians: kalibr's kMaxValidAngle. Past pi/2,
    # tan() wraps around and would silently produce flipped, bogus rays
    max_valid_angle = 89. * np.pi / 180.

    # shape (..., 2). Normalized, still-distorted coordinates
    xy = (np.asarray(q, dtype=float) - cxy) / fxy

    # shape (...). The distorted radius
    r_d = np.sqrt(np.sum(xy*xy, axis=-1))

    mul2tanwby2 = 2.0 * np.tan(s / 2.0)

    if mul2tanwby2 == 0:
        # s == 0: no distortion. This is also the s -> 0 limit of the
        # expression below: tan(r_d*s) / (2 tan(s/2) r_d) -> 1
        r_u = np.ones(np.shape(r_d))
    else:
        # If r_d is ~0 the expression below is 0/0, so I use its limit there:
        #   np.tan(r_d * s) / mul2tanwby2 / r_d ~
        #   np.tan(r_d * s) / (r_d*s) * s / mul2tanwby2  ~
        #   s / mul2tanwby2
        i_center = r_d < 1e-6
        r_d_safe = np.where(i_center, 1., r_d) # to not /0
        r_u = np.where(i_center,
                       s / mul2tanwby2,
                       np.tan(r_d_safe * s) / mul2tanwby2 / r_d_safe)

    # shape (...,2). Undistorted, normalized coordinates
    xy = np.expand_dims(r_u, axis=-1) * xy

    # shape (...,3). Append z = 1, and normalize
    v = np.concatenate((xy, np.ones(xy.shape[:-1] + (1,))), axis=-1)
    v /= np.linalg.norm(v, axis=-1, keepdims=True)

    # Out-of-range pixels are reported as failed unprojections
    v[np.abs(r_d * s) > max_valid_angle] = np.nan

    return v






# The target lens model. Validated up-front, before doing any work
lensmodel_to = args.to

try:
    meta = mrcal.lensmodel_metadata_and_config(lensmodel_to)
except Exception as e:
    print(f"Invalid lens model '{lensmodel_to}': couldn't get the metadata: {e}",
          file=sys.stderr)
    sys.exit(1)
# The solve below requires the model's gradients
if not meta['has_gradients']:
    print(f"lens model {lensmodel_to} is not supported at this time: its gradients aren't implemented",
          file=sys.stderr)
    sys.exit(1)

# The intrinsics are fx,fy,cx,cy followed by the distortion parameters
try:
    Ndistortions = mrcal.lensmodel_num_params(lensmodel_to) - 4
except Exception:
    # "except Exception", not a bare "except:". This doesn't swallow
    # KeyboardInterrupt or SystemExit
    print(f"Unknown lens model: '{lensmodel_to}'", file=sys.stderr)
    sys.exit(1)



# The resulting model is always written to stdout
file_output = sys.stdout

# The imager width,height, in integer pixels
dims = np.array(args.w_h_fx_fy_cx_cy_s[0:2], dtype = np.int32)

if args.radius is None:
    # No --radius given: ignore the corners by 1/4 of the smaller imager
    # dimension. The negation happens before the floor division
    args.radius = (-dims.min()) // 4
    print(f"Default radius: {args.radius}. We're ignoring the regions {-args.radius} pixels from each corner",
          file=sys.stderr)

    # This (negative) default radius only makes sense when focusing on the
    # imager center
    if args.where is not None:
        if nps.norm2(args.where - 0.5*(dims - 1.)) > 1e-3:
            print("A radius <0 is only implemented if we're focusing on the imager center: use an explicit --radius, or omit --where",
                  file=sys.stderr)
            sys.exit(1)


# Alrighty. Let's actually do the work. I do this:
#
# 1. Sample the imager space with the known model
# 2. Unproject to get the 3d observation vectors
# 3. Solve a new model that fits those vectors to the known observations, but
#    using the new model

### I sample the pixels in an Nx-by-Ny grid spanning the whole imager
Nx,Ny = args.gridn

qx = np.linspace(0, dims[0]-1, Nx)
qy = np.linspace(0, dims[1]-1, Ny)

# q is (Ny*Nx, 2). Each slice of q[:] is an (x,y) pixel coord
q = np.ascontiguousarray( nps.transpose(nps.clump( nps.cat(*np.meshgrid(qx,qy)), n=-2)) )
if args.radius != 0:
    # we use a subset of the input data for the fit: the grid points within a
    # circle of radius r around focus_center
    if args.where is None:
        focus_center = (dims - 1.) / 2.
    else:
        focus_center = args.where

    if args.radius > 0:
        r = args.radius
    else:
        # radius < 0: shrink the imager half-diagonal by abs(radius). This is
        # only meaningful when focusing at the imager center
        if nps.norm2(focus_center - (dims - 1.) / 2) > 1e-3:
            print("A radius <0 is only implemented if we're focusing on the imager center",
                  file=sys.stderr)
            sys.exit(1)
        r = nps.mag(dims)/2. + args.radius
        if r <= 0:
            # Without this check a negative r would be squared below, silently
            # selecting a circle of radius abs(r) instead of nothing
            print(f"--radius {args.radius} trims away the whole imager: its half-diagonal is only {nps.mag(dims)/2.} pixels",
                  file=sys.stderr)
            sys.exit(1)

    grid_off_center = q - focus_center
    i = nps.norm2(grid_off_center) < r*r
    q = q[i, ...]

    if len(q) == 0:
        print(f"No sample points lie within the fit region (radius {r} pixels around {focus_center}). Increase --radius or --gridn",
              file=sys.stderr)
        sys.exit(1)


# To visualize the sample grid:
# import gnuplotlib as gp
# gp.plot(q[:,0], q[:,1], _with='points pt 7 ps 2', xrange=[0,dims[0]-1],yrange=[dims[1]-1,0], wait=1, square=1)
# sys.exit()

### I unproject this, with broadcasting
# shape (Npoints, 3)
p = fov_unproject_normalized( q,
                              np.array(args.w_h_fx_fy_cx_cy_s[2:4]),
                              np.array(args.w_h_fx_fy_cx_cy_s[4:6]),
                              args.w_h_fx_fy_cx_cy_s[6] )

# Ignore any failed unprojections. I require every component of each
# observation vector to be finite, not just the x component
i_finite = np.all(np.isfinite(p), axis=-1)
p = p[i_finite]
q = q[i_finite]
Npoints = len(q)
if Npoints == 0:
    # Nothing to fit. Report this here instead of letting the solver fail
    print("No usable sample points in the fit region: every pixel failed to unproject. Try a different --radius or --where",
          file=sys.stderr)
    sys.exit(1)

# each point has weight 1.0
weights = np.ones((Npoints,), dtype=float)

### Solve!

### I solve the optimization a number of times with different random seed
### values, taking the best-fitting results. This is required for the richer
### models such as LENSMODEL_OPENCV8. Any finite RMS error beats the initial
### np.inf
err_rms_best         = np.inf
intrinsics_data_best = None

# Everything from here to the loop is the same for each trial, so I compute it
# once

# The core (fx,fy,cx,cy) seed for the new intrinsics comes from the input model
intrinsics_core = np.array(args.w_h_fx_fy_cx_cy_s[2:6])

# The observed pixels, each with weight 1.0. Shape (Npoints,3): x,y,weight
observations_points = nps.glue(q, nps.transpose(weights), axis=-1)
observations_points = np.ascontiguousarray(observations_points) # must be contiguous. mrcal.optimize() should really be more lax here

# Which points we're observing. This is dense and kinda silly for this
# application. Each slice is (i_point,i_camintrinsics,i_camextrinsics). I do
# everything in camera-0 coordinates (extrinsics index -1), and I do not move
# the extrinsics
indices_point_camintrinsics_camextrinsics = np.zeros((Npoints,3), dtype=np.int32)
indices_point_camintrinsics_camextrinsics[:,0] = \
    np.arange(Npoints,    dtype=np.int32)
indices_point_camintrinsics_camextrinsics[:,1] = 0
indices_point_camintrinsics_camextrinsics[:,2] = -1

# splined models have a core, but those variables are largely redundant with
# the spline parameters. So I lock down the core when targetting splined
# models
do_optimize_intrinsics_core = \
    not re.match("LENSMODEL_SPLINED_STEREOGRAPHIC_", lensmodel_to)

for i_trial in range(args.num_trials):

    # random seed for the new distortions. The result has shape
    # (1, 4+Ndistortions): a single camera
    distortions          = (np.random.rand(Ndistortions) - 0.5) * 1e-3
    intrinsics_to_values = nps.dummy(nps.glue(intrinsics_core, distortions, axis=-1),
                                     axis=-2)

    optimization_inputs = \
        dict(intrinsics                                = intrinsics_to_values,
             rt_cam_ref                                = None,
             rt_ref_frame                              = None, # no frames. Just points
             points                                    = p,
             observations_board                        = None, # no board observations
             indices_frame_camintrinsics_camextrinsics = None, # no board observations
             # each trial gets its own copy, so every trial starts from
             # identical observations
             observations_point                        = observations_points.copy(),
             indices_point_camintrinsics_camextrinsics = indices_point_camintrinsics_camextrinsics,
             lensmodel                                 = lensmodel_to,

             imagersizes                               = nps.atleast_dims(dims, -2),

             # I'm not optimizing the point positions (frames), so these
             # need to be set to be inactive, and to include the ranges I do
             # have
             point_min_range                           = 1e-3,
             point_max_range                           = 1e3,

             # I optimize the lens parameters. That's the whole point
             do_optimize_intrinsics_core               = do_optimize_intrinsics_core,
             do_optimize_intrinsics_distortions        = True,

             do_optimize_extrinsics                    = False,

             # NOT optimizing the observed point positions
             do_optimize_frames                        = False )

    stats = mrcal.optimize(**optimization_inputs,
                           # No outliers. I have the points that I have
                           do_apply_outlier_rejection        = False,
                           verbose                           = False)

    err_rms = stats['rms_reproj_error__pixels']
    print(f"RMS error of this solution: {err_rms} pixels.",
          file=sys.stderr)
    if err_rms < err_rms_best:
        err_rms_best         = err_rms
        # The solved intrinsics are read back out of optimization_inputs
        intrinsics_data_best = optimization_inputs['intrinsics'][0,:].copy()

if intrinsics_data_best is None:
    print("No valid intrinsics found!", file=sys.stderr)
    sys.exit(1)

if args.num_trials > 1:
    print(f"RMS error of the BEST solution: {err_rms_best} pixels.",
          file=sys.stderr)


# Package the best-fitting intrinsics into a camera model
m_to = mrcal.cameramodel( intrinsics = (lensmodel_to, intrinsics_data_best.ravel()),
                          imagersize = dims )

# Record when and how (the full, shell-quoted commandline) this model was made
note = f"generated on {time.strftime('%Y-%m-%d %H:%M:%S')} with   " + \
    ' '.join(map(mrcal.shellquote, sys.argv)) + "\n"

m_to.write(file_output, note=note)

# file_output is the stdout stream here, so this message is only printed if it
# is ever changed to a filename
if isinstance(file_output, str):
    print(f"Wrote '{file_output}'", file=sys.stderr)


if args.viz:
    # Optionally plot the fit error: the pixel difference between the input
    # fov model and the fitted model, as a heat map with contours over the
    # whole imager

    # Extra gnuplotlib options from the commandline. NOTE: plot_options
    # (below) is this same dict, not a copy
    plotkwargs_extra = {}
    if args.set is not None:
        plotkwargs_extra['set'] = args.set
    if args.unset is not None:
        plotkwargs_extra['unset'] = args.unset

    if args.title is not None:
        plotkwargs_extra['title'] = args.title


    # I compute the reprojections again. Similar to the code used in the solve,
    # but the spacing might be different AND I do not ignore corners. This grid
    # is fixed at 80x50, independent of --gridn
    gridn_width,gridn_height = 80,50

    qx = np.linspace(0, dims[0]-1, gridn_width)
    qy = np.linspace(0, dims[1]-1, gridn_height)

    # shape (gridn_height,gridn_width,2)
    q = np.ascontiguousarray( nps.mv( nps.cat(*np.meshgrid(qx,qy)),
                                      0,-1))

    # Unproject with the fov model, and reproject with the fitted model.
    # difflen is the magnitude of the pixel difference at each grid point
    p = fov_unproject_normalized( q,
                                  np.array(args.w_h_fx_fy_cx_cy_s[2:4]),
                                  np.array(args.w_h_fx_fy_cx_cy_s[4:6]),
                                  args.w_h_fx_fy_cx_cy_s[6] )
    q_to = mrcal.project(p, *m_to.intrinsics())
    diff    = q - q_to
    difflen = nps.mag(diff)











    import gnuplotlib as gp

    # Stored even when None
    plotkwargs_extra['hardcopy'] = args.hardcopy
    plotkwargs_extra['terminal'] = args.terminal

    # Settings for the contours drawn over the heat map
    contour_increment     = None
    contour_labels_styles = 'boxed'
    contour_labels_font   = None



    # The default title, used unless --title was given
    if 'title' not in plotkwargs_extra:
        title = f"Diff in fitted fov model to {lensmodel_to}"
        if args.extratitle is not None:
            title += ": " + args.extratitle
        plotkwargs_extra['title'] = title


    plot_options = plotkwargs_extra
    gp.add_plot_option(plot_options,
                       cbrange = [0,args.cbmax])
    # color is difflen itself, not a copy: the assignment below modifies
    # difflen too
    color = difflen

    # Any invalid values (nan or inf) are set to an effectively infinite
    # difference
    color[~np.isfinite(color)] = 1e6

    # Private mrcal helper. As its first-argument comment says, it updates
    # plotkwargs_extra (== plot_options). Its return value is used as the
    # curve options for the heat map
    curve_options = \
        mrcal.visualization._options_heatmap_with_contours(
            # update these plot options
            plotkwargs_extra,

            contour_max           = args.cbmax,
            contour_increment     = contour_increment,
            imagersize            = args.w_h_fx_fy_cx_cy_s[:2],
            gridn_width           = gridn_width,
            gridn_height          = gridn_height,
            contour_labels_styles = contour_labels_styles,
            contour_labels_font   = contour_labels_font,
            do_contours           = True)

    plot_data_args = [ (color, curve_options) ]


    data_tuples = plot_data_args

    plot = gp.gnuplotlib(**plot_options)
    plot.plot(*data_tuples)

    # Interactive plot (no --hardcopy): wait for the user before exiting
    if args.hardcopy is None:
        plot.wait()
