function [route,H,numExpanded] = AStar(input_map, start_coords, dest_coords, path_step_time, neighbourhood_type)
% Run A* algorithm on a grid.
% Inputs :
%   input_map : a logical array where the freespace cells are false and
%   the obstacles are true.
%   start_coords and dest_coords : Coordinates of the start and end cell
%   respectively, the first entry is the row and the second the column.
%   path_step_time : pause (in seconds) between drawing successive cells
%   of the final route.
%   neighbourhood_type : 'moore' selects 8-connected moves (diagonals
%   allowed); any other value selects 4-connected (von Neumann) moves.
%   Every move costs 1.
% Output :
%    route : A row vector containing the linear indices of the cells along
%    the shortest route from start to dest, or an empty array if there is
%    no route.
%    H : the matrix of heuristic values, same size as input_map. Output
%    for debugging.
%    numExpanded : the total number of nodes expanded during search.

% set up color map for display
% 1 - white = clear cell
% 2 - black = obstacle
% 3 - red = visited
% 4 - blue = on list
% 5 - grey = start
% 6 - yellow = destination
% 7 - green = best path

cmap = [1 1 1; ...
        0 0 0; ...
        1 0 0; ...
        0 0 1; ...
        0.5 0.5 0.5; ...
        1 1 0; ...
        0 1 0];

colormap(cmap);

% Update the figure every time step?
drawMapEveryTime = true;

% Get the size of the map
[nrows, ncols] = size(input_map);

% map - a table that keeps track of the state of each grid cell
map = zeros(nrows,ncols);

map(~input_map) = 1;   % Mark free cells
map(input_map)  = 2;   % Mark obstacle cells

% Generate linear indices of start and dest nodes
start_node = sub2ind(size(map), start_coords(1), start_coords(2));
dest_node  = sub2ind(size(map), dest_coords(1),  dest_coords(2));

map(start_node) = 5;
map(dest_node)  = 6;

% parent(k) is the linear index of the cell we reached k from (0 = none)
parent = zeros(nrows,ncols);

% X(r,c) == c (column index) and Y(r,c) == r (row index). Both are
% nrows-by-ncols, so H lines up with map directly (no transpose), which
% also keeps it correct for non-square maps.
[X, Y] = meshgrid(1:ncols, 1:nrows);

dest_row = dest_coords(1);
dest_col = dest_coords(2);
row_dist = abs(Y - dest_row);
col_dist = abs(X - dest_col);

% Evaluate Heuristic function, H, for each grid cell. It must never
% overestimate the number of unit-cost moves still needed: with diagonal
% moves (Moore) that is the Chebyshev distance, otherwise Manhattan.
if strcmp(neighbourhood_type, 'moore')
    H = max(row_dist, col_dist);
else
    H = row_dist + col_dist;
end

% Initialize cost arrays: g = best known cost from start, f = g + H for
% cells on the open list (Inf = not on the open list).
f = Inf(nrows,ncols);
g = Inf(nrows,ncols);

g(start_node) = 0;
f(start_node) = H(start_node);

% keep track of the number of nodes that are expanded
numExpanded = 0;

% Main Loop

while true

    % Re-mark start/dest, which expansion below may have recoloured
    map(start_node) = 5;
    map(dest_node) = 6;

    % make drawMapEveryTime = true if you want to see how the
    % nodes are expanded on the grid.
    if drawMapEveryTime
        drawSearchMap(map);
        drawnow;
    end

    % Find the node with the minimum f value
    [min_f, current] = min(f(:));

    if current == dest_node || isinf(min_f)
        break;
    end

    % Close the current node
    map(current) = 3;
    f(current) = Inf; % remove this node from further consideration

    % Get the list of neighbours
    neighbours = getNeighbours(g, current, neighbourhood_type);

    % Relax every neighbour that is neither an obstacle (2) nor already
    % expanded (3). The start cell (5) can never be improved because
    % g(start) == 0, so it needs no special case.
    for k = 1:numel(neighbours)
        neighbour = neighbours(k);
        if map(neighbour) == 2 || map(neighbour) == 3
            continue;
        end

        tentative_g = g(current) + 1;
        if tentative_g < g(neighbour)
            % Found a cheaper way to reach this cell: update its cost,
            % priority and parent together so they stay consistent.
            g(neighbour) = tentative_g;
            f(neighbour) = tentative_g + H(neighbour);
            parent(neighbour) = current;
            map(neighbour) = 4;
        end
    end

    % Keep track of the number of map locations considered
    numExpanded = numExpanded + 1;

end

%% Construct route from start to dest by following the parent links
if isinf(g(dest_node))
    route = [];
else
    % Walk back to the start (whose parent is 0), then reverse; this
    % avoids re-allocating the vector on every prepend.
    route = dest_node;
    node = dest_node;
    while parent(node) ~= 0
        node = parent(node);
        route(end + 1) = node; %#ok<AGROW>
    end
    route = fliplr(route);

    % Visualise the path (excluding the start and destination cells)
    for k = 2:numel(route) - 1
        map(route(k)) = 7;
        pause(path_step_time);
        drawSearchMap(map);
    end
end

end

function drawSearchMap(map)
% Draw the cell-state table with the current colormap. Placing cell (1,1)'s
% centre at (1.5,1.5) puts the cell boundaries on integer grid lines.
image(1.5, 1.5, map);
grid on;
axis image;
end

function neighbours = getNeighbours(distanceFromStart, node, neighbourhood_type)
% getNeighbours  Linear indices of the in-bounds cells adjacent to NODE.
%   Only the size of DISTANCEFROMSTART is used, to bound the grid and to
%   convert between subscripts and linear indices. NEIGHBOURHOOD_TYPE
%   'moore' yields up to 8 neighbours (diagonals included); any other value
%   yields up to 4 von Neumann neighbours. Obstacles are NOT filtered here;
%   the caller must do that. Returns a row vector, or [] when no adjacent
%   cell lies inside the grid.

    grid_size = size(distanceFromStart);
    [row, col] = ind2sub(grid_size, node);

    % Offsets (row, column) to each candidate neighbour, in output order
    if strcmp(neighbourhood_type, 'moore')
        offsets = [-1 -1; -1 0; -1 1; 0 -1; 1 0; 1 1; 0 1; 1 -1];
    else
        offsets = [-1 0; 1 0; 0 -1; 0 1];
    end

    candidate_rows = row + offsets(:, 1);
    candidate_cols = col + offsets(:, 2);

    % Keep only candidates that fall inside the grid
    in_bounds = candidate_rows >= 1 & candidate_rows <= grid_size(1) & ...
                candidate_cols >= 1 & candidate_cols <= grid_size(2);

    if ~any(in_bounds)
        neighbours = [];
        return;
    end

    neighbours = sub2ind(grid_size, candidate_rows(in_bounds), candidate_cols(in_bounds)).';
end