{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline in the notebook output\n%matplotlib inline"
      ]
    },
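    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Neuroimaging non-cartesian reconstruction using Stacked3DNFFT\n\nAuthor: Chaithya G R\n\nIn this tutorial we will reconstruct a 3D MRI volume from non-cartesian\nk-space measurements, using the Stacked3DNFFT operator.\n\n## Import neuroimaging data\n\nWe use the toy datasets available in pysap, more specifically the `3d-pmri`\nvolume (reduced to a single image by root-sum-of-squares over its first axis)\nand a 2D radial sampling pattern (`mri-radial-samples`) that is stacked on a\nsubset of planes along the z-axis.\n\n"
      ]
    },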
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Package import\nfrom mri.operators import Stacked3DNFFT, WaveletN\nfrom mri.operators.utils import convert_locations_to_mask, \\\n    gridded_inverse_fourier_transform_stack, get_stacks_fourier, \\\n    convert_mask_to_locations\nfrom mri.reconstructors import SingleChannelReconstructor\nimport pysap\nfrom pysap.data import get_sample_data\n\n# Third party import\nfrom modopt.math.metrics import ssim\nfrom modopt.opt.linear import Identity\nfrom modopt.opt.proximity import SparseThreshold\nimport numpy as np\n\n# Fix NumPy's global random seed so that the random choice of acquired\n# planes below (and any later use of the global generator) is reproducible.\nSEED = 0\nnp.random.seed(SEED)\n\n# Loading input data: the '3d-pmri' sample is reduced over its first axis\n# (presumably the coil channels) by root-sum-of-squares.\nimage = get_sample_data('3d-pmri')\nimage = pysap.Image(data=np.sqrt(np.sum(np.abs(image.data)**2, axis=0)))\n\n# Reducing the size of the volume for faster computation\nimage.data = image.data[:, :, 48: -48]\n\n# Obtain MRI non-cartesian sampling plane\nmask_radial = get_sample_data(\"mri-radial-samples\")\n\n# Choose the acquired planes along z: each plane is kept with probability 1/2\n# and planes 22 to 41 are always kept. For full sampling along z, use instead\n# (integer dtype keeps Nz a valid repeat count for np.tile):\n# sampling_z = np.ones(image.shape[2], dtype=int)\nsampling_z = np.random.randint(2, size=image.shape[2])\nsampling_z[22: 42] = 1\nNz = sampling_z.sum()  # Number of acquired planes\n\n# Tile the radial plane along z: one copy of the 2D trajectory per acquired\n# plane, with the z location of that plane appended as the last column.\nz_locations = np.repeat(convert_mask_to_locations(sampling_z),\n                        mask_radial.shape[0])\nz_locations = z_locations[:, np.newaxis]\nkspace_loc = np.hstack([np.tile(mask_radial.data, (Nz, 1)),\n                        z_locations])\nmask = pysap.Image(data=np.moveaxis(\n    convert_locations_to_mask(kspace_loc, image.shape), -1, 0))\n\n# View Input\n# image.show()\n# mask.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Generate the kspace\n\nFrom the 3D volume and the stacked radial sampling pattern, we retrospectively\nundersample the k-space with the Stacked3DNFFT operator.\nWe then compute the zero-order (gridded) solution as a baseline.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Build the stacked non-cartesian Fourier operator on the sampling locations\n# defined above and simulate the acquired k-space data from the image.\nfourier_op = Stacked3DNFFT(kspace_loc=kspace_loc,\n                           shape=image.shape,\n                           implementation='cpu',\n                           n_coils=1)\nkspace_obs = fourier_op.op(image.data)\n\n# Gridded solution: the samples of each acquired plane are interpolated\n# (method='linear') on a regular in-plane grid spanning [-0.5, 0.5] and\n# inverse transformed with gridded_inverse_fourier_transform_stack.\ngrid_space = [np.linspace(-0.5, 0.5, num=img_shape)\n              for img_shape in image.shape[:-1]]\ngrid = np.meshgrid(*grid_space)\nkspace_plane_loc, z_sample_loc, sort_pos, idx_mask_z = get_stacks_fourier(\n    kspace_loc,\n    image.shape)\ngrid_soln = gridded_inverse_fourier_transform_stack(\n    kspace_data_sorted=kspace_obs[sort_pos],\n    kspace_plane_loc=kspace_plane_loc,\n    idx_mask_z=idx_mask_z,\n    grid=tuple(grid),\n    volume_shape=image.shape,\n    method='linear')\n\n# Keep the magnitude, as done for the FISTA reconstruction, so that the\n# baseline is a real-valued image like the reference.\nimage_rec0 = pysap.Image(data=np.abs(grid_soln))\n# image_rec0.show()\nbase_ssim = ssim(image_rec0, image)\nprint('The Base SSIM is : ' + str(base_ssim))"
      ]
    },
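    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## FISTA optimization\n\nWe now compute a sparsity-regularized reconstruction with FISTA and compare it\nwith the zero-order baseline.\nThe cost function is the sum of a data-fidelity (gradient) term and a\nsparsity-promoting (proximity) term on the wavelet coefficients.\n\n"
      ]
    },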
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Setup the operators: a 3D 'sym8' wavelet transform with 4 scales as the\n# sparsifying transform, and soft thresholding as the proximity operator.\n# In the synthesis formulation the problem is solved over the wavelet\n# coefficients, hence Identity() as the linear operator of the threshold.\n# TODO: tune the regularization weight mu for this dataset.\nMU = 6 * 1e-9  # regularization weight (threshold) of the soft thresholding\nN_ITERATIONS = 10  # number of FISTA iterations\nlinear_op = WaveletN(wavelet_name=\"sym8\", nb_scales=4, dim=3)\nregularizer_op = SparseThreshold(Identity(), MU, thresh_type=\"soft\")\n# Setup Reconstructor\nreconstructor = SingleChannelReconstructor(\n    fourier_op=fourier_op,\n    linear_op=linear_op,\n    regularizer_op=regularizer_op,\n    gradient_formulation='synthesis',\n    verbose=1,\n)\n# Start Reconstruction (the returned costs and metrics are not used here)\nx_final, costs, metrics = reconstructor.reconstruct(\n    kspace_data=kspace_obs,\n    optimization_alg='fista',\n    num_iterations=N_ITERATIONS,\n)\nimage_rec = pysap.Image(data=np.abs(x_final))\n# image_rec.show()\nrecon_ssim = ssim(image_rec, image)\nprint('The Reconstruction SSIM is : ' + str(recon_ssim))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}