#!/usr/bin/env python

import argparse
import sys

import numpy as np


def _translation_params(img, axis):
    """Validate a ``[1, 1, Z, T]`` moco parameter image and return its ``(Z, T)`` array.

    Parameters
    ----------
    img : spinalcordtoolbox.image.Image
        Image loaded from the ``moco_params_<axis>.nii.gz`` file.
    axis : str
        ``"x"`` or ``"y"``; only used to build the CLI flag / file name in the
        error message.

    Returns
    -------
    numpy.ndarray
        ``img.data[0, 0, :, :]``, i.e. one value per (slice, volume).

    Exits the program via ``sys.exit`` if ``img.data`` is not 4D with two
    leading singleton axes.
    """
    data = img.data
    if data.ndim != 4 or data.shape[:2] != (1, 1):
        sys.exit(
            f"ERROR: Expected --warp-{axis} to have shape [1, 1, Z, T], "
            f"got {data.shape}.\n"
            f"Make sure you are passing the moco_params_{axis}.nii.gz file, "
            "not a full 4D image."
        )
    return data[0, 0, :, :]


def _check_matches_ref(params, axis, nz, nt):
    """Exit with an explanatory error unless ``params`` has shape ``(nz, nt)``."""
    if params.shape != (nz, nt):
        sys.exit(
            f"ERROR: {axis}-translation parameters have shape {params.shape} "
            f"(Z={params.shape[0]}, T={params.shape[1]}), "
            f"but the reference image has Z={nz}, T={nt}.\n"
            "Make sure --warp-x, --warp-y, and --im-ref all come from the same fMRI run."
        )


def build_warp(fname_tx, fname_ty, fname_ref, fname_out):
    """Combine per-slice x/y translations into one 5D displacement field and save it.

    The x and y parameter images must have shape ``[1, 1, Z, T]``. The
    reference image must be 4D ``[X, Y, Z, T]`` with matching ``Z`` and ``T``.
    The output has shape ``[X, Y, Z, T, 3]``:

    * component 0 holds the x translation of slice ``z`` at volume ``t``;
    * component 1 holds the y translation;
    * component 2 is zero.

    Each translation is constant within a slice. The output reuses the
    reference's header and affine, is stored as float32, and carries the
    NIfTI ``vector`` intent.

    Parameters
    ----------
    fname_tx, fname_ty : str
        Paths to the x / y translation parameter images.
    fname_ref : str
        Path to the 4D reference image that defines the output grid.
    fname_out : str
        Output path for the warping field.

    Any import failure or shape mismatch terminates the program via
    ``sys.exit`` with an explanatory message.
    """
    try:
        from spinalcordtoolbox.image import Image
    except ImportError:
        sys.exit(
            "ERROR: Could not import spinalcordtoolbox.\n"
            "Make sure you have activated the SCT virtual environment:\n"
            "  source $SCT_DIR/python/bin/activate venv_sct"
        )

    print(f"Loading x-translation parameters : {fname_tx}")
    tx = _translation_params(Image(fname_tx), "x")

    print(f"Loading y-translation parameters : {fname_ty}")
    ty = _translation_params(Image(fname_ty), "y")

    print(f"Loading reference image          : {fname_ref}")
    ref = Image(fname_ref)
    if ref.data.ndim != 4:
        sys.exit(
            f"ERROR: Expected --im-ref to be a 4D image, got shape {ref.data.shape}.\n"
            "Pass the original 4D fMRI run as the reference."
        )
    nx, ny, nz, nt = ref.data.shape

    _check_matches_ref(tx, "x", nz, nt)
    _check_matches_ref(ty, "y", nz, nt)

    print(f"Building 5D warping field [{nx} x {ny} x {nz} x {nt} x 3] ...")
    disp = np.zeros((nx, ny, nz, nt, 3), dtype=np.float32)
    # disp[..., k] has shape (X, Y, Z, T). A (Z, T) array broadcasts over the
    # leading (X, Y) axes, so every voxel in slice z at volume t receives that
    # slice's translation, exactly as the per-(z, t) loop did.
    # NOTE(review): values are copied unchanged from the moco parameter files.
    # Units and the sign/axis convention expected by the consumer are not
    # visible here. Confirm them against sct_fmri_moco / sct_apply_transfo.
    disp[..., 0] = tx
    disp[..., 1] = ty
    # No z translation is provided, so component 2 stays zero.

    # Copy the header so the reference image's header object is not mutated.
    im_out = Image(disp, hdr=ref.hdr.copy())
    im_out.affine = ref.affine
    im_out.hdr.set_data_shape(disp.shape)
    # The reference run may be integer-typed on disk. nibabel writes data using
    # a supplied header's dtype, so force float32 to avoid quantizing
    # sub-voxel displacements.
    im_out.hdr.set_data_dtype(np.float32)
    im_out.hdr.set_intent('vector', (), '')
    im_out.save(fname_out)

    print(f"Warping field saved to           : {fname_out}")


def main():
    """Command-line entry point: parse the four required file arguments and run build_warp."""
    description = (
        "Combine sct_fmri_moco x/y translation parameter images into a\n"
        "single 5D warping field (ITK format) for use with sct_apply_transfo.\n"
        "\n"
        "Activate the SCT virtual environment before running:\n"
        "  source $SCT_DIR/python/bin/activate venv_sct"
    )
    parser = argparse.ArgumentParser(
        prog="build_warp.py",
        description=description,
        # Keep the explicit line breaks in the description above.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Every option is a required file path; only the flag and help text differ.
    file_options = (
        (
            "--warp-x",
            "x-translation parameter image from sct_fmri_moco (moco_params_x.nii.gz).",
        ),
        (
            "--warp-y",
            "y-translation parameter image from sct_fmri_moco (moco_params_y.nii.gz).",
        ),
        (
            "--im-ref",
            "Reference 4D NIfTI image (e.g. the original fMRI run). "
            "Its header and affine are used to define the output voxel grid.",
        ),
        (
            "--im-out",
            "Output filename for the 5D warping field (e.g. warp_moco.nii.gz).",
        ),
    )
    for flag, help_text in file_options:
        parser.add_argument(flag, required=True, metavar="FILE", help=help_text)

    opts = parser.parse_args()
    build_warp(opts.warp_x, opts.warp_y, opts.im_ref, opts.im_out)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
