Skip to content
Snippets Groups Projects
Distances.ipynb 6.26 KiB
Newer Older
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us write some functions that create point sets in the 2D plane. We start with a small fixed configuration: the five corners of a simple house outline (a square topped by a roof apex) ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def points_house():\n",
    "    return np.asarray([(0,0),(2,0),(0,2),(2,2),(1,3)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... and visualise it in a plot:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "p = points_house()\n",
    "\n",
    "# use equal axis scaling so the shape is not distorted\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_aspect(1)\n",
    "ax.grid()\n",
    "ax.scatter(p[:, 0], p[:, 1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let us add generators for random points in a rectangle, Gaussian-distributed points around the origin, and a regular square lattice:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "from sklearn.datasets import make_blobs\n",
    "\n",
    "# Randomly create count points within a width x height rectangle \n",
    "def points_rectangle(count, width, height):\n",
    "    return np.asarray([(random.random()*width, random.random()*height) for i in range(count)])\n",
    "\n",
    "def points_circle(count, random_state=None):\n",
    "    \"\"\"Draw `count` points from an isotropic 2D Gaussian centred at the origin.\n",
    "\n",
    "    Despite the name, the points are not confined to a circle: make_blobs samples\n",
    "    a normal distribution (default cluster_std=1.0), so the cloud is only round\n",
    "    on average and has no hard boundary.\n",
    "\n",
    "    random_state is forwarded to make_blobs; pass an int for reproducible output.\n",
    "    The default None keeps the previous, unseeded behaviour.\n",
    "    Returns a (count, 2) float array.\n",
    "    \"\"\"\n",
    "    points, _labels = make_blobs(count, centers=[[0, 0]], random_state=random_state)\n",
    "    return points\n",
    "\n",
    "# Create evenly spaced points on a quadratic lattice of size k\n",
    "def points_lattice(k):\n",
    "    return np.asarray([(i,j) for j in range(k) for i in range(k)])\n",
    "\n",
    "r = points_rectangle(10, 2, 3)\n",
    "c = points_circle(100)\n",
    "l = points_lattice(5)\n",
    "\n",
    "# overlay the three point sets on one equal-aspect, gridded plot\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_aspect(1)\n",
    "ax.grid()\n",
    "for pts, name in ((r, \"rectangle\"), (c, \"circle\"), (l, \"lattice\")):\n",
    "    ax.scatter(*zip(*pts), label=name)\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from scipy.spatial import distance_matrix\n",
    "from scipy.linalg import svd\n",
    "import icp\n",
    "\n",
    "def rmse(a, b):\n",
    "    return np.sqrt(((a - b) ** 2).mean())\n",
    "\n",
    "# apply rotation R and translation t to A\n",
    "def apply(A, R, t):\n",
    "    result = np.empty_like(A)\n",
    "    for i in range(np.shape(A)[0]):\n",
    "        result[i] = R @ A[i] + t\n",
    "    return result\n",
    "\n",
    "# see https://math.stackexchange.com/questions/156161/\n",
    "def get_m(d):\n",
    "    m = np.empty_like(d)\n",
    "    shp = np.shape(d)\n",
    "    for i in range(shp[0]):\n",
    "        for j in range(shp[1]):\n",
    "            m[i,j] = (d[1,j]**2 + d[i,1]**2 - d[i,j]**2) / 2\n",
    "    return m\n",
    "\n",
    "# cf. https://math.stackexchange.com/questions/156161/\n",
    "def points_from_distances(d, dim=2):\n",
    "    \"\"\"Recover point coordinates, up to a rigid motion, from a distance matrix.\n",
    "\n",
    "    d:   (n, n) matrix of pairwise distances, possibly noisy.\n",
    "    dim: number of coordinates to keep (default 2, the previously hard-coded value).\n",
    "    Returns an (n, dim) array (fewer columns if dim > n) and prints it.\n",
    "    \"\"\"\n",
    "    # Gram matrix relative to a reference point\n",
    "    m = get_m(d)\n",
    "    # Noisy distances make M non-symmetric; keep its symmetric part (the closest\n",
    "    # symmetric matrix). For a symmetric d this is exactly a no-op.\n",
    "    m = (m + m.T) / 2\n",
    "    # for a symmetric positive semi-definite M the SVD is the eigendecomposition M = U S U'\n",
    "    u, s, _vt = svd(m, full_matrices=True)\n",
    "    # leading columns of U sqrt(S), without forming the full n x n product\n",
    "    q = u[:, :dim] * np.sqrt(s[:dim])\n",
    "    print(\"q = \", q, sep=\"\\n\")\n",
    "\n",
    "    return q\n",
    "\n",
    "def plot_points(p, q, qr):\n",
    "    \"\"\"Scatter the original points p against the aligned reconstruction qr.\n",
    "\n",
    "    q (the unaligned reconstruction) is accepted but not drawn.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.grid()\n",
    "    for pts, label in ((p, \"points\"), (qr, \"restored and rotated points\")):\n",
    "        ax.scatter(*zip(*pts), label=label)\n",
    "    ax.legend()\n",
    "    ax.set_aspect(1)\n",
    "    plt.show()\n",
    "    \n",
    "def noise_and_restore(p, scale):\n",
    "    \"\"\"Add Gaussian noise to the distance matrix of p and reconstruct the points.\n",
    "\n",
    "    p:     (n, k) array of points, one per row.\n",
    "    scale: standard deviation of the noise added to each pairwise distance.\n",
    "    Returns (q, qr): the raw reconstruction and the reconstruction aligned to p\n",
    "    with the R, t from icp.best_fit_transform. Intermediate values are printed.\n",
    "\n",
    "    The noise matrix is symmetric with a zero diagonal, since d[i, j] and\n",
    "    d[j, i] are the same measurement and a point is at distance 0 from itself.\n",
    "    Perturbing all n*n entries independently, as before, produced\n",
    "    non-symmetric \"distance\" matrices with non-zero self-distances.\n",
    "    \"\"\"\n",
    "    print(\"p =\", p, sep=\"\\n\")\n",
    "\n",
    "    # measure their distances\n",
    "    d = distance_matrix(p, p)\n",
    "    print(\"d =\", d, sep=\"\\n\")\n",
    "\n",
    "    # add symmetric noise: draw the strict upper triangle and mirror it\n",
    "    upper = np.triu(np.random.normal(0, scale, np.shape(d)), k=1)\n",
    "    noise = upper + upper.T\n",
    "    print(\"noise =\", noise, sep=\"\\n\")\n",
    "    d += noise\n",
    "    print(\"d =\", d, sep=\"\\n\")\n",
    "    # RMSE over all n*n entries (zero diagonal included)\n",
    "    d_rmse = rmse(noise, np.zeros_like(noise))\n",
    "    print(\"RMSE distances:\", d_rmse)\n",
    "\n",
    "    # restore points\n",
    "    q = points_from_distances(d)\n",
    "\n",
    "    # rotate points onto p (https://github.com/ClayFlannigan/icp)\n",
    "    T, R, t = icp.best_fit_transform(q, p)\n",
    "    print(\"T =\", T, sep=\"\\n\")\n",
    "    qr = apply(q, R, t)\n",
    "    print(\"qr =\", qr, sep=\"\\n\")\n",
    "\n",
    "    return q, qr\n",
    "    \n",
    "if __name__ == '__main__':\n",
    "    # fix the NumPy seed so the noise, and hence the plot, is reproducible\n",
    "    np.random.seed(0)\n",
    "\n",
    "    # point set to reconstruct; other options: points_rectangle(5, 2, 4),\n",
    "    # points_circle(100), points_house()\n",
    "    p = points_lattice(10)\n",
    "\n",
    "    # noise standard deviations 10**-1 ... 10**-9\n",
    "    x = []\n",
    "    y = []\n",
    "    for e in np.arange(-1, -10, -1.0):\n",
    "        q, qr = noise_and_restore(p, 10**e)\n",
    "\n",
    "        # coordinate-wise difference between original and restored points\n",
    "        delta = p - qr\n",
    "        print(\"p - qr =\", delta, sep=\"\\n\")\n",
    "\n",
    "        # root-mean-square error of the reconstruction\n",
    "        qr_rmse = rmse(p, qr)\n",
    "        print(\"RMSE qr:\", qr_rmse)\n",
    "\n",
    "        x.append(e)\n",
    "        y.append(qr_rmse)\n",
    "\n",
    "    # plot RMSE; x holds the exponent e of the noise standard deviation 10**e\n",
    "    plt.plot(x, y, marker=\"o\")\n",
    "    plt.yscale(\"log\")\n",
    "    plt.xlabel(\"log10(noise standard deviation)\")\n",
    "    plt.ylabel(\"RMSE of restored points\")\n",
    "    plt.title(\"Reconstruction error vs. distance noise\")\n",
    "    plt.grid()\n",
    "    plt.show()\n",
    "\n",
    "    # plot points of the last (smallest-noise) run\n",
    "    # plot_points(p, q, qr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}