package lu.tudor.santec.gecamed.core.gui.utils;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.UnknownHostException;
import java.util.Collection;
import java.util.LinkedList;
import java.util.TreeSet;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;



/**
 * @author jens.ferring(at)tudor.lu
 * 
 * Scans the given ports of the given host in separate threads. 
 * After a port has been scanned it is put into the list of open or of closed ports.
 * 
 * This makes it quick and easy to check whether all the ports your 
 * application needs are open.
 */
public class PortScanner 
{
	/* ======================================== */
	// 		FINAL MEMBERS
	/* ======================================== */
	
	/**
	 * The default socket timeout in milliseconds
	 */
	private static final int DEFAULT_SOCKET_TIMEOUT 	= 1200;
	
	/**
	 * The default no of threads created to scan the ports simultaneous
	 */
	private static final int DEFAULT_SIMULTANEOUS_SCANS = 5000;
	
	/** The lowest port number accepted for scanning */
	private static final int MIN_PORT 					= 0;
	
	/** The highest valid TCP port number */
	private static final int MAX_PORT 					= 0xFFFF;
	
	/** Returned by {@link #getNextPort()} when every port has been handed out */
	private static final int NO_PORT_LEFT 				= -1;
	
	/**
	 * Returned by {@link #getNextPort()} when the calling thread has to stop, 
	 * because the number of scan threads is being reduced. In this case the 
	 * thread has already been removed from {@link #scansRunning}.
	 */
	private static final int RETIRED 					= -2;
	
	
	
	/* ======================================== */
	// 		MEMBERS
	/* ======================================== */
	
	/** the logger Object for this class */
	private static final Logger logger 	= Logger.getLogger(PortScanner.class.getName());
	
	/** Connect timeout in milliseconds; volatile, because the scan threads read it */
	private static volatile int socketTimeout 	= DEFAULT_SOCKET_TIMEOUT;
	
	
	private final String 		host;
	
	/** A private copy of the ports to scan */
	private final int[] 		portsToScan;
	
	/** Index of the next port in portsToScan to hand out (guarded by this) */
	private int 				scanIndex;
	
	/** Open ports, sorted (guarded by this) */
	private Collection<Integer> openPorts;
	
	/** Closed ports, in the order they were scanned (guarded by this) */
	private Collection<Integer> closedPorts;
	
	private final Scan[] 		scans;
	
	/** Whether the "host not available" error was already logged (guarded by this) */
	private boolean 			unknownHostException = false;
	
	private long 				scantime = -1;
	
	/** Number of scan threads that have not finished yet (guarded by this) */
	private int 				scansRunning;
	
	/** If >= 0, the number of scan threads to shrink down to (guarded by this) */
	private int 				maxScans = -1;
	
	
	
	/* ======================================== */
	// 		CONSTRUCTORS
	/* ======================================== */
	
	/**
	 * @param host The host to scan ports on
	 * @param ports The ports to check if open
	 * @throws WrongArgumentException if no ports are given, the host is 
	 * 			<code>null</code> or a port is outside the range [0, 65535]
	 */
	public PortScanner (String host, int ... ports) 
	{
		this (DEFAULT_SIMULTANEOUS_SCANS, host, ports);
	}
	
	
	/**
	 * @param simultaneousScans The number of threads created, to scan the ports simultaneous.
	 * 			It is clamped to the range [1, ports.length].
	 * @param host The host to scan ports on
	 * @param ports The ports to check if open
	 * @throws WrongArgumentException if no ports are given, the host is 
	 * 			<code>null</code> or a port is outside the range [0, 65535]
	 */
	public PortScanner (int simultaneousScans, String host, int ... ports)
	{
		if (ports == null || ports.length <= 0)
			throw new WrongArgumentException();
		if (host == null)
			throw new WrongArgumentException("The host to scan must not be null.");
		for (int port : ports)
		{
			// a negative port would otherwise be mistaken for the "no port left" marker
			if (port < MIN_PORT || port > MAX_PORT)
				throw new WrongArgumentException("Port out of range [" + MIN_PORT + ", " 
						+ MAX_PORT + "]: " + port);
		}
		
		this.host 			= host;
		// copy, so the caller cannot change the ports while they are scanned
		this.portsToScan 	= ports.clone();
		this.scanIndex 		= 0;
		this.openPorts 		= new TreeSet<Integer>();
		this.closedPorts 	= new LinkedList<Integer>();
		
		if (simultaneousScans > ports.length)
			simultaneousScans = ports.length;
		else if (simultaneousScans <= 0)
			simultaneousScans = 1;
		
		this.scans 			= new Scan[simultaneousScans];
	}
	
	
	
	/* ======================================== */
	// 		CLASS BODY
	/* ======================================== */
	
	/**
	 * Creates a new PortScanner and starts scanning. Will wait until all ports are scanned.
	 * 
	 * @param host The host to scan ports on
	 * @param ports The ports to check if open
	 * @return The PortScanner used for scanning or <code>null</code> if the scan couldn't be started.
	 */
	public static PortScanner scan (String host, int ... ports)
	{
		return scan(DEFAULT_SIMULTANEOUS_SCANS, host, ports);
	}
	
	
	/**
	 * Creates a new PortScanner and starts scanning. Will wait until all ports are scanned.
	 * 
	 * @param simultaneousScans The number of threads created, to scan the ports simultaneous.
	 * @param host The host to scan ports on
	 * @param ports The ports to check if open
	 * @return The PortScanner used for scanning or <code>null</code> if the scan couldn't be started.
	 */
	public static PortScanner scan (int simultaneousScans, String host, int ... ports)
	{
		return scan(new PortScanner(simultaneousScans, host, ports));
	}
	
	
	/**
	 * Creates the scan threads and starts them. Returns immediately, without 
	 * waiting for the scan to finish. The results of a previous scan on this 
	 * object are discarded.
	 * <p>
	 * If not all threads can be created (OutOfMemoryError), the scan goes on 
	 * with the threads already started, and their number is reduced further.
	 * 
	 * @return <code>true</code> if the scan was started, <code>false</code> if 
	 * 		a scan is still running on this object or an error occurred.
	 */
	public synchronized boolean startScanning ()
	{
		int started = 0;
		
		try 
		{
			if (scansRunning > 0)
			{
				logger.log(Level.WARN, "There is still a scan running on this object. " +
						"The scaning is therefore canceled.");
				return false;
			}
			
			// reset the state of a previous run, so that a re-scan really scans again
			scanIndex 				= 0;
			openPorts 				= new TreeSet<Integer>();
			closedPorts 			= new LinkedList<Integer>();
			unknownHostException 	= false;
			maxScans 				= -1;
			scansRunning 			= 0;
			
			// The new threads block in getNextPort() until this method releases 
			// the monitor, so maxScans is set before any of them asks for a port.
			while (started < scans.length)
			{
				Scan scan 		= new Scan();
				scans[started] 	= scan;
				scan.start();
				started++;
				scanStarted();
			}
		}
		catch (OutOfMemoryError e)
		{
			// thread creation failed; go on with the threads already running
			if (started <= 0)
				throw e;
			
			maxScans 	= started / 2 + 1;
			logger.log(Level.WARN, 
					e.getClass() + ": " +
					e.getMessage() + 
					"\n\tRunning with " + started + 
					" scan threads only, reducing threads to " + maxScans + ".");
		}
		catch (Throwable e)
		{
			logger.log(Level.ERROR, e.getMessage(), e);
			return false;
		}
		return true;
	}
	
	
	/**
	 * @return <code>true</code> if at least one thread is still scanning a port else <code>false</code>.
	 */
	public synchronized boolean isScanning ()
	{
		return scansRunning > 0;
	}
	
	
	/**
	 * @return All ports that are open (sorted) or <code>null</code> if the ports are still being scanned.
	 */
	public synchronized Collection<Integer> getOpenPorts ()
	{
		return isScanning() ? null : openPorts;
	}
	

	/**
	 * @return All ports that are closed or <code>null</code> if the ports are still being scanned.
	 */
	public synchronized Collection<Integer> getClosedPorts ()
	{
		return isScanning() ? null : closedPorts;
	}
	
	
	/**
	 * @return The time in milliseconds that was needed for scanning, or -1 if 
	 * 		the scan was not started by one of the static <code>scan</code> methods.
	 */
	public long getScanTime ()
	{
		return scantime;
	}
	
	
	/**
	 * @return The time set to wait for response when connecting the socket.
	 */
	public static int getSocketTimeout ()
	{
		return socketTimeout;
	}
	
	
	/**
	 * @param timeout The time to wait for response when connecting the socket.
	 */
	public static void setSocketTimeout (int timeout)
	{
		socketTimeout = timeout;
	}
	
	
	
	/* ======================================== */
	// 		HELP METHODS
	/* ======================================== */
	
	/**
	 * Starts scanning with the given PortScanner. Will wait until all ports are scanned.
	 * 
	 * @param ps The PortScanner used for scanning.
	 * @return The PortScanner used for scanning or <code>null</code> if the scan couldn't be started.
	 */
	private static PortScanner scan(PortScanner ps) 
	{
		logger.log(Level.INFO, "start port scanning ...");
		long time = System.currentTimeMillis();
		
		if (!ps.startScanning())
			return null;
		
		// wait for the threads to be finished with scanning
		ps.awaitScanEnd();
		
		// take the time necessary for scanning
		time = System.currentTimeMillis() - time;
		ps.scantime = time;
		
		logger.log(Level.INFO, "port scanning took " + time + " ms");
		
		return ps;
	}
	
	
	/**
	 * Blocks until no scan thread is running any more. An interrupt does not 
	 * end the waiting; it is remembered and the interrupt flag is restored afterwards.
	 */
	private synchronized void awaitScanEnd ()
	{
		boolean interrupted = false;
		
		while (scansRunning > 0)
		{
			try
			{
				// woken up by changeScansRunning() once the count drops to 0
				wait();
			}
			catch (InterruptedException e)
			{
				logger.log(Level.ERROR, e.getMessage(), e);
				interrupted = true;
			}
		}
		
		if (interrupted)
			Thread.currentThread().interrupt();
	}
	
	
	/**
	 * A new scan has started
	 */
	private void scanStarted ()
	{
		changeScansRunning(+1);
	}
	
	
	/**
	 * A scan has finished
	 */
	private void scanFinished ()
	{
		changeScansRunning(-1);
	}
	
	
	/**
	 * @param added Change the number of scans by this.
	 */
	private synchronized void changeScansRunning (int added)
	{
		scansRunning += added;
		
		// wake up threads waiting in awaitScanEnd()
		if (scansRunning <= 0)
			notifyAll();
	}
	
	
	/**
	 * @return The next port to scan, {@link #NO_PORT_LEFT} if there is no port 
	 * 		left or {@link #RETIRED} if the calling thread has to stop. In the 
	 * 		last case the caller has already been removed from the running scans.
	 */
	private synchronized int getNextPort ()
	{
		if (maxScans >= 0 && scansRunning > maxScans)
		{
			// Reduce the number of running threads to maxScans. Deregister the 
			// caller right now, so that the other threads see the reduced count 
			// and do not retire as well (which could leave ports unscanned).
			changeScansRunning(-1);
			return RETIRED;
		}
		else if (scanIndex < portsToScan.length)
			// there are unchecked ports left, return the next one
			return portsToScan[scanIndex++];
		else 
			// no unchecked ports left
			return NO_PORT_LEFT;
	}
	
	
	/**
	 * Tries to connect to the given port of the host, using a fresh socket 
	 * (a Socket cannot be connected a second time).
	 * 
	 * @param port The port to check
	 * @return <code>true</code> if a connection could be established, else <code>false</code>
	 * @throws UnknownHostException if the host cannot be resolved
	 */
	private boolean isPortOpen (int port) throws UnknownHostException
	{
		Socket socket = new Socket();
		
		try
		{
			socket.connect(new InetSocketAddress(host, port), socketTimeout);
			return true;
		}
		catch (UnknownHostException e)
		{
			// must not be treated as a closed port
			throw e;
		}
		catch (IOException e)
		{
			// nothing listening, filtered or timed out
			return false;
		}
		finally
		{
			try
			{
				socket.close();
			}
			catch (IOException e)
			{
				logger.log(Level.ERROR, e.getMessage(), e);
			}
		}
	}
	
	
	/**
	 * Logs that the host is not available, only once per scan.
	 */
	private synchronized void reportUnknownHost ()
	{
		if (!unknownHostException)
		{
			unknownHostException = true;
			logger.log(Level.ERROR, "The host is not available");
		}
	}
	
	
	/**
	 * Adds the given port to the list of ports that are NOT open.
	 * 
	 * @param port The port to add
	 */
	private synchronized void addClosedPort (int port)
	{
		closedPorts.add(Integer.valueOf(port));
	}
	

	/**
	 * Adds the given port to the list of ports that are open.
	 * 
	 * @param port The port to add
	 */
	private synchronized void addOpenPort (int port)
	{
		openPorts.add(Integer.valueOf(port));
	}
	
	
	
	/* ======================================== */
	// 		CLASS: SCAN
	/* ======================================== */
	
	/**
	 * @author jens.ferring(at)tudor.lu
	 * 
	 * The Thread that does the port scanning. It keeps taking ports from 
	 * getNextPort() until there are none left or it is told to retire.
	 */
	private class Scan extends Thread
	{
		/* ======================================== */
		// 		RUNNABLE
		/* ======================================== */
		
		@Override
		public void run()
		{
			// true if getNextPort() already removed this thread from scansRunning
			boolean retired = false;
			
			try
			{
				int port;
				
				// a negative value means: stop (no port left or retired)
				while ((port = getNextPort()) >= 0)
				{
					if (isPortOpen(port))
						 addOpenPort(port);
					else addClosedPort(port);
				}
				retired = (port == RETIRED);
			}
			catch (UnknownHostException e)
			{
				reportUnknownHost();
			}
			finally
			{
				// notify that this scan is finished, so the no of running scans is reduced
				if (!retired)
					scanFinished();
			}
		}
	}
	
	
	
	/* ======================================== */
	// 		CLASS: WRONG ARGUMENT EXCEPTION
	/* ======================================== */
	
	/**
	 * Thrown by the constructors when the ports or the host to scan are invalid.
	 */
	public static class WrongArgumentException extends IllegalArgumentException
	{
		private static final long serialVersionUID = 1L;

		public WrongArgumentException()
		{
			super("There must at least be 1 port defined to scan.");
		}
		
		/**
		 * @param message Describes the invalid argument
		 */
		public WrongArgumentException(String message)
		{
			super(message);
		}
	}
	
	
	
	/* ======================================== */
	// 		MAIN FOR TESTING
	/* ======================================== */
	
	public static void main(String[] args)
	{
//		int from 	= 10000;
//		int to 		= 11000;
//		
//		int[] ports = new int[to-from+1];
//		
//		for (int i = 0; i < ports.length; i++)
//			ports[i] = from + i;
//		
//		PortScanner ps = PortScanner.scan(3000, "localhost", ports);
		PortScanner ps = PortScanner.scan("localhost", 1098, 1099, 3837, 4444, 8080, 8093);
//		PortScanner ps = PortScanner.scan("localhost", 80,135,443,445,664,1064,1187,1188,1190,1257,1258,1322,1324,1326);
		
		// scan() returns null if the scan could not be started
		if (ps == null)
		{
			System.out.println("The port scan could not be started.");
			return;
		}
		
		Collection<Integer> openPorts 	= ps.getOpenPorts();
		Collection<Integer> closedPorts = ps.getClosedPorts();
		
		System.out.println(openPorts.size() + " ports open, " + closedPorts.size() + " ports closed");
		
		StringBuilder output;
		output = new StringBuilder("open port: ");
		for (Integer port : openPorts)
			output.append("\n   ").append(port);
		System.out.println(output.toString());
//		
//		System.out.println();
//		
//		output = new StringBuilder("closed ports: ");
//		for (Integer port : closedPorts)
//			output.append("\n   ").append(port);
//		System.out.println(output.toString());
	}
}
