/*
	Copyright (C) 2004-2005 Christopher E. Miller

	This software is provided 'as-is', without any express or implied
	warranty.  In no event will the authors be held liable for any damages
	arising from the use of this software.

	Permission is granted to anyone to use this software for any purpose,
	including commercial applications, and to alter it and redistribute it
	freely, subject to the following restrictions:

	1. The origin of this software must not be misrepresented; you must not
	   claim that you wrote the original software. If you use this software
	   in a product, an acknowledgment in the product documentation would be
	   appreciated but is not required.
	2. Altered source versions must be plainly marked as such, and must not
	   be misrepresented as being the original software.
	3. This notice may not be removed or altered from any source
	   distribution.

	socket.d 1.3
	Jan 2005

 	Thanks to Benjamin Herr for his assistance.
*/

/*
	-- based on socket.d 1.3 by Christopher E. Miller and Benjamin Herr --
	2006-02-04 Thomas Kuehne - thomas(at)kuehne.cn
		- added InternetAddress4 and InternetAddress6
		- added IPv6 handling to InternetAddress and InternetHost
		- extended unittests
		- added TcpSocket.this(char[] host, ushort port) shortcut
*/

module std.socket;

private import std.string, std.stdint, std.c.stdlib;


// Linux provides BSD-style sockets.
version(linux)
{
	version = BsdSockets;
}

// Per-platform bindings used by the rest of the module:
//   socket_t       native socket handle type; its .init is the invalid handle
//   _SOCKET_ERROR  value returned by socket calls on failure
//   _lasterr()     last socket error code for the calling code to report
version(Win32)
{
	private import std.c.windows.windows, std.c.windows.winsock;
	private alias std.c.windows.winsock.timeval _ctimeval;

	typedef SOCKET socket_t = INVALID_SOCKET;
	private const int _SOCKET_ERROR = SOCKET_ERROR;


	// Winsock reports errors through WSAGetLastError() rather than errno.
	private int _lasterr()
	{
		return WSAGetLastError();
	}
}
else version(BsdSockets)
{
	version(linux)
	{
		private import std.c.linux.linux, std.c.linux.socket;
		private alias std.c.linux.linux.timeval _ctimeval;
	}

	typedef int32_t socket_t = -1;
	private const int _SOCKET_ERROR = -1;


	// BSD sockets report errors through errno.
	private int _lasterr()
	{
		return getErrno();
	}

}
else
{
	static assert(0); // No socket support yet.
}


/// Thrown when a socket operation fails.
///
/// errorCode holds the platform-specific error number supplied by the
/// thrower (0 when none). On Linux a positive errorCode also has its
/// strerror() text appended to the message.
class SocketException: Exception
{
	int errorCode; // Platform-specific error code.


	this(char[] msg, int err = 0)
	{
		errorCode = err;

		version(linux)
		{
			if(errorCode > 0)
			{
				char* cs;
				size_t len;

				cs = strerror(errorCode);
				len = cs ? strlen(cs) : 0;

				// Trim a trailing "\n" and then "\r". The length checks stop an
				// empty strerror() text from being indexed at cs[-1].
				if(len > 0 && cs[len - 1] == '\n')
					len--;
				if(len > 0 && cs[len - 1] == '\r')
					len--;
				msg = msg ~ ": " ~ cs[0 .. len];
			}
		}

		super(msg);
	}
}


// Module constructor: on Win32, start Winsock before any socket call is made.
static this()
{
	version(Win32)
	{
		WSADATA wd;

		// Request Winsock 2.2 (MAKEWORD(2, 2) == 0x0202) for IPv6.
		// Winsock will still load if an older version is present;
		// the version is just a request.
		int val;
		val = WSAStartup(0x0202, &wd);
		if(val)
			throw new SocketException("Unable to initialize socket library", val);
	}
}


// Module destructor: releases Winsock on Win32; a no-op on other platforms.
static ~this()
{
	version(Win32)
		WSACleanup();
}


/// Address families; each member mirrors the platform AF_* constant of the
/// same name.
enum AddressFamily: int
{
	UNSPEC =     AF_UNSPEC,
	UNIX =       AF_UNIX,
	INET =       AF_INET,
	IPX =        AF_IPX,
	APPLETALK =  AF_APPLETALK,
	INET6 =      AF_INET6,
}


/// Socket types; each member mirrors the platform SOCK_* constant of the
/// same name.
enum SocketType: int
{
	STREAM =     SOCK_STREAM,
	DGRAM =      SOCK_DGRAM,
	RAW =        SOCK_RAW,
	RDM =        SOCK_RDM,
	SEQPACKET =  SOCK_SEQPACKET,
}


/// IP protocol numbers; each member mirrors the platform IPPROTO_* constant
/// of the same name.
enum ProtocolType: int
{
	IP =    IPPROTO_IP,
	ICMP =  IPPROTO_ICMP,
	IGMP =  IPPROTO_IGMP,
	GGP =   IPPROTO_GGP,
	TCP =   IPPROTO_TCP,
	PUP =   IPPROTO_PUP,
	UDP =   IPPROTO_UDP,
	IDP =   IPPROTO_IDP,
	IPV6 =  IPPROTO_IPV6,
}


/// A protocol entry from the system protocol database.
class Protocol
{
	ProtocolType type;
	char[] name;
	char[][] aliases;


	/// Copies proto's number, name and aliases into this object; the strings
	/// are duplicated, so nothing here points into proto afterwards.
	void populate(protoent* proto)
	{
		type = cast(ProtocolType)proto.p_proto;
		name = std.string.toString(proto.p_name).dup;

		// p_aliases ends with a null entry; count the entries before it.
		int count = 0;
		while(proto.p_aliases[count])
			count++;

		if(!count)
		{
			aliases = null;
			return;
		}

		aliases = new char[][count];
		for(int j = 0; j < count; j++)
			aliases[j] = std.string.toString(proto.p_aliases[j]).dup;
	}


	/// Looks name up with getprotobyname(); returns false if it is unknown.
	bit getProtocolByName(char[] name)
	{
		protoent* found = getprotobyname(toStringz(name));
		if(found)
		{
			populate(found);
			return true;
		}
		return false;
	}


	/// Looks type up with getprotobynumber(); returns false if it is unknown.
	bit getProtocolByType(ProtocolType type)
	{
		protoent* found = getprotobynumber(type);
		if(found)
		{
			populate(found);
			return true;
		}
		return false;
	}
}


unittest
{
	// Look TCP up by number and print what the protocol database reports.
	Protocol tcp = new Protocol;
	assert(tcp.getProtocolByType(ProtocolType.TCP));
	printf("About protocol TCP:\n\tName: %.*s\n", tcp.name);
	for(size_t n = 0; n < tcp.aliases.length; n++)
		printf("\tAlias: %.*s\n", tcp.aliases[n]);
}


/// A service entry from the system services database.
///
/// port is in host byte order, both in this field and in the arguments of
/// getServiceByPort().
class Service
{
	char[] name;
	char[][] aliases;
	ushort port;
	char[] protocolName;


	/// Copies serv into this object; strings are duplicated and s_port is
	/// converted to host byte order.
	void populate(servent* serv)
	{
		name = std.string.toString(serv.s_name).dup;
		port = ntohs(serv.s_port);
		protocolName = std.string.toString(serv.s_proto).dup;

		int i;
		for(i = 0;; i++)
		{
			if(!serv.s_aliases[i])
				break;
		}

		if(i)
		{
			aliases = new char[][i];
			for(i = 0; i != aliases.length; i++)
			{
				aliases[i] = std.string.toString(serv.s_aliases[i]).dup;
			}
		}
		else
		{
			aliases = null;
		}
	}


	/// Looks up the service called name for protocolName; false if not found.
	bit getServiceByName(char[] name, char[] protocolName)
	{
		servent* serv;
		serv = getservbyname(toStringz(name), toStringz(protocolName));
		if(!serv)
			return false;
		populate(serv);
		return true;
	}


	// Any protocol name will be matched.
	bit getServiceByName(char[] name)
	{
		servent* serv;
		serv = getservbyname(toStringz(name), null);
		if(!serv)
			return false;
		populate(serv);
		return true;
	}


	/// Looks up the service on port (host byte order) for protocolName;
	/// false if not found.
	bit getServiceByPort(ushort port, char[] protocolName)
	{
		servent* serv;
		// getservbyport() expects the port in network byte order.
		serv = getservbyport(htons(port), toStringz(protocolName));
		if(!serv)
			return false;
		populate(serv);
		return true;
	}


	// Any protocol name will be matched. port is in host byte order.
	bit getServiceByPort(ushort port)
	{
		servent* serv;
		// getservbyport() expects the port in network byte order.
		serv = getservbyport(htons(port), null);
		if(!serv)
			return false;
		populate(serv);
		return true;
	}
}


unittest
{
	// epmap may be absent from the services database, so absence is only reported.
	Service svc = new Service;
	if(!svc.getServiceByName("epmap", "tcp"))
	{
		printf("No service for epmap.\n");
		return;
	}

	printf("About service epmap:\n\tService: %.*s\n\tPort: %d\n\tProtocol: %.*s\n",
		svc.name, svc.port, svc.protocolName);
	for(size_t n = 0; n < svc.aliases.length; n++)
		printf("\tAlias: %.*s\n", svc.aliases[n]);
}


/// Host-related failure (for example raised by InternetHost.validHostent for
/// an unsupported address family).
class HostException: Exception
{
	int errorCode; // Platform-specific error code, 0 if none was given.

	this(char[] msg, int err = 0)
	{
		super(msg);
		errorCode = err;
	}
}

/// Result of a host lookup: a name, its aliases and the distinct IPv4/IPv6
/// addresses that were found.
class InternetHost{
	char[] name;
	char[][] aliases;
	InternetAddress[] addrList;

	/// Throws AddressException if he's address length does not match its
	/// family, and HostException if the family is neither INET nor INET6.
	void validHostent(hostent* he){
		if(he.h_addrtype == AddressFamily.INET){
			if(he.h_length != 4){
				throw new AddressException("Bad IPv4 address");
			}
		}else if(he.h_addrtype == AddressFamily.INET6){
			if(he.h_length != 16){
				throw new AddressException("Bad IPv6 address");
			}
		}else{
			throw new HostException("unsupported address family", _lasterr());
		}
	}

	/// Appends newAddress unless an equal address (same family, port and
	/// address bytes) is already in addrList.
	void addAddress(InternetAddress newAddress){
		foreach(InternetAddress oldAddress; addrList){
			if(oldAddress == newAddress){
				return;
			}
		}
		addrList ~= newAddress;
	}

	/// Fills this object from a hostent. Strings are duplicated and the
	/// address bytes are copied into new InternetAddress objects.
	/// Throws: AddressException if no INET/INET6 address was found.
	void populate(hostent* he){
		size_t i;

		name = std.string.toString(he.h_name).dup;

		for(i = 0; ; i++){
			if(! he.h_aliases[i]){
				break;
			}
		}

		if(i){
			aliases = new char[][i];
			for(i = 0; i < aliases.length; i++){
				aliases[i] = std.string.toString(he.h_aliases[i]).dup;
			}
		}else{
			aliases = null;
		}

		// Start from a fresh array: truncating with ".length = 0" keeps the
		// old buffer, and the appends below would then overwrite an addrList
		// that a caller may still hold.
		addrList = null;
		for(i = 0; he.h_addr_list[i]; i++){
			if((he.h_addrtype == AddressFamily.INET) || (he.h_addrtype == AddressFamily.INET6)){
				addAddress(new InternetAddress((cast(ubyte*)he.h_addr_list[i])[0 .. he.h_length]));
			}
		}

		if(addrList.length < 1){
			throw new AddressException("found no supported addresses");
		}
	}

	/// Fills this object from a getaddrinfo() result list: ai_canonname
	/// entries become aliases and INET/INET6 entries become addresses.
	/// Everything kept is copied, so the list may be freed afterwards.
	/// Throws: AddressException if no INET/INET6 address was found.
	void populate(addrinfo* ai, char[] name = null){
		this.name = name.dup;
		// Fresh arrays rather than truncation, for the same reason as in
		// populate(hostent*).
		aliases = null;
		addrList = null;

		for(; ai; ai = ai.ai_next){
			if(ai.ai_canonname){
				aliases ~= std.string.toString(ai.ai_canonname).dup;
			}
			// Copy the socket address out of the list before wrapping it.
			if(ai.ai_family == AF_INET){
				sockaddr_in* sin = new sockaddr_in;
				*sin = *(cast(sockaddr_in*) ai.ai_addr);
				addAddress(new InternetAddress(sin));
			}else if(ai.ai_family == AF_INET6){
				sockaddr_in6* sin6 = new sockaddr_in6;
				*sin6 = *(cast(sockaddr_in6*) ai.ai_addr);
				addAddress(new InternetAddress(sin6));
			}
		}

		if(addrList.length < 1){
			throw new AddressException("found no supported addresses");
		}
	}

	/// Resolves name for the given service (port number or service name).
	bit getHostByName(char[] name, char[] service){
		return getHostByName(name, null, service);
	}

	/// Resolves name (host name or numeric address) with getaddrinfo().
	/// hint and service are passed through unchanged; null means "none".
	/// Returns false if getaddrinfo() fails.
	/// Throws: AddressException if the result holds no INET/INET6 address.
	bit getHostByName(char[] name, addrinfo* hint = null, char[] service = null){
		addrinfo* res;

		int error = getaddrinfo(toStringz(name), service ? toStringz(service) : null, hint, &res);

		if(error != 0){
			return false;
		}

		// populate() copies everything it keeps, so the list is released
		// here, including when populate() throws.
		try{
			populate(res, name);
		}finally{
			freeaddrinfo(res);
		}

		return true;
	}

	/// $(D_COMMENT IPv4 only)
	/// Stores addr (host byte order) as the only address and its dotted form
	/// as name; no lookup is performed.
	bit getHostByAddr(uint addr){
		aliases = null;
		addrList = null;
		addrList ~= new InternetAddress(addr);
		name = addrList[0].toAddrString();
		return true;
	}

	/// Parses the numeric address literal addr (any family, AI_NUMERICHOST).
	bit getHostByAddr(char[] addr){
		addrinfo* hint = new addrinfo;
		hint.ai_family = AF_UNSPEC;
		hint.ai_flags = AI_NUMERICHOST;
		return getHostByName(addr, hint);
	}

	/// Comma-separated list of name, aliases and addresses.
	char[] toString(){
		// dup: appending directly onto name's buffer could extend it in place.
		char[] back = name.dup;
		foreach(char[] a; aliases){
			if(back.length > 0){
				back ~= ", ";
			}
			back ~= a;
		}
		foreach(InternetAddress addr; addrList){
			if(back.length > 0){
				back ~= ", ";
			}
			back ~= addr.toAddrString();
		}
		return back;
	}
}

unittest
{
	InternetHost host = new InternetHost;
	assert(host.getHostByName("localhost"));
	printf("addrList.length = %d\n", host.addrList.length);
	assert(host.addrList.length);
	InternetAddress first = host.addrList[0];
	printf("IP address = %.*s\nname = %.*s\n", first.toAddrString(), host.name);
	for(int n = 0; n < host.aliases.length; n++)
		printf("aliases[%d] = %.*s\n", n, host.aliases[n]);

	printf("---\n");

	// Reverse direction: resolve the numeric form of the first address.
	assert(host.getHostByAddr(host.addrList[0].toAddrString()));
	printf("name = %.*s\n", host.name);
	for(int n = 0; n < host.aliases.length; n++)
		printf("aliases[%d] = %.*s\n", n, host.aliases[n]);
}


/// Raised for unusable, unsupported or unresolvable addresses.
class AddressException: Exception
{
	this(char[] msg) { super(msg); }
}


/// Base class for socket addresses: exposes the raw sockaddr pointer and its
/// length, the address family and a printable form.
abstract class Address
{
	protected sockaddr* name();
	protected int nameLen();
	AddressFamily addressFamily();
	char[] toString();
}


/// An address of unrecognized family, held in a plain sockaddr.
class UnknownAddress: Address
{
	protected sockaddr sa;

	protected sockaddr* name()
	{
		return &sa;
	}

	protected int nameLen()
	{
		return sa.sizeof;
	}

	/// The family stored in the sockaddr, whatever it is.
	AddressFamily addressFamily()
	{
		return cast(AddressFamily)sa.sa_family;
	}

	char[] toString()
	{
		return "Unknown";
	}
}

/// An IPv4 or IPv6 endpoint: address plus port.
///
/// Every public constructor replaces `this` with an InternetAddress4 or
/// InternetAddress6 instance; the accessors declared here throw Error
/// unless a subclass overrides them.
class InternetAddress : Address{
	const uint ADDR_ANY = INADDR_ANY;
	const uint ADDR_NONE = INADDR_NONE;
	const ubyte[16] ADDR_ANY6 = 0;

	const ushort PORT_ANY = 0;

	// Passed by the subclasses to reach this non-redirecting constructor.
	protected bool DUMMY;
	protected this(bool dummy){
	}

	/// IPv4 ADDR_ANY with PORT_ANY.
	this(){
		this = new InternetAddress4();
	}

	/// Wraps a generic sockaddr, dispatching on sa_family.
	/// Throws: AddressException for families other than INET and INET6.
	this(sockaddr* addr){
		if(addr.sa_family == AddressFamily.INET){
			this = new InternetAddress4(cast(sockaddr_in*) addr);
		}else if(addr.sa_family == AddressFamily.INET6){
			this = new InternetAddress6(cast(sockaddr_in6*) addr);
		}else{
			throw new AddressException("unsupported address family: " ~ std.string.toString(addr.sa_family));
		}
	}

	unittest{
		sockaddr a;
		a.sa_family = AddressFamily.INET;
		new InternetAddress(&a);

		// A plain sockaddr is smaller than sockaddr_in6; use a real
		// sockaddr_in6 so the IPv6 path stays within the variable.
		sockaddr_in6 b;
		b.sin6_family = AddressFamily.INET6;
		new InternetAddress(cast(sockaddr*) &b);

		sockaddr c;
		c.sa_family = AddressFamily.IPX;
		try{
			new InternetAddress(&c);
		}catch(AddressException e){
			return;
		}

		assert(0);
	}

	/// IPv4 only; InternetAddress4 rejects a sin_family other than INET.
	this(sockaddr_in* addr){
		this = new InternetAddress4(addr);
	}

	unittest{
		sockaddr_in a;
		assert((new InternetAddress(&a)).toString() == "0.0.0.0:0");

		sockaddr_in b;
		b.sin_family = AddressFamily.INET6;

		try{
			new InternetAddress(&b);
		}catch(AddressException ae){
			return;
		}
		assert(0);
	}

	/// IPv6 only; InternetAddress6 rejects a sin6_family other than INET6.
	this(sockaddr_in6* addr){
		this = new InternetAddress6(addr);
	}

	unittest{
		sockaddr_in6 a;
		assert((new InternetAddress(&a)).toString() == "[::]:0");

		sockaddr_in6 b;
		b.sin6_family = AddressFamily.INET;

		try{
			new InternetAddress(&b);
		}catch(AddressException ae){
			return;
		}
		assert(0);
	}

	/// Raw address bytes in network order: 4 bytes give IPv4, 16 give IPv6.
	/// Throws: AddressException for any other length.
	this(ubyte[] addr, ushort port = PORT_ANY){
		if(addr.length == 4){
			this = new InternetAddress4(addr, port);
		}else if(addr.length == 16){
			this = new InternetAddress6(addr, port);
		}else{
			throw new AddressException("unknown address family");
		}
	}

	unittest{
		const ubyte[] b = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15];
		InternetAddress ia = new InternetAddress(b);
		assert(ia.addressFamily() == AddressFamily.INET6);
		assert(ia.toString() == "[1:203:405:607:809:A0B:C0D:E0F]:0");

		ubyte[] c = b[0 .. 4].dup;
		ia = new InternetAddress(c);
		assert(ia.addressFamily() == AddressFamily.INET);
		assert(ia.toString() == "0.1.2.3:0");

		try{
			ia = new InternetAddress(c ~ c);
		}catch(AddressException e1){
			try{
				ia = new InternetAddress(c[0 .. 0]);
			}catch(AddressException e2){
				return;
			}
			assert(0);
		}
		assert(0);
	}

	unittest{
		const ubyte[] b = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15];
		InternetAddress ia = new InternetAddress(b, 12345);
		assert(ia.toString() == "[1:203:405:607:809:A0B:C0D:E0F]:12345");

		ubyte[] c = b[0 .. 4];
		ia = new InternetAddress(c, 54321);
		assert(ia.toString() == "0.1.2.3:54321");

		try{
			ia = new InternetAddress(c ~ c, 12);
		}catch(AddressException e1){
			try{
				ia = new InternetAddress(c[0 .. 0], 13);
			}catch(AddressException e2){
				return;
			}
			assert(0);
		}
		assert(0);
	}

	/// $(D_COMMENT IPv4 only)
	/// addr and port are in host byte order.
	this(uint addr, ushort port = PORT_ANY){
		this = new InternetAddress4(addr, port);
	}

	unittest{
		InternetAddress ia = new InternetAddress(0xFF_01_02_30, 1456);
		assert(ia.toString() == "255.1.2.48:1456");
	}

	/// $(D_COMMENT IPv4 only)
	/// ADDR_ANY with the given port.
	this(ushort port){
		this = new InternetAddress4(ADDR_ANY, port);
	}

	unittest{
		InternetAddress ia = new InternetAddress(cast(ushort)4321);
		assert(ia.toString() == "0.0.0.0:4321");
	}

	/// Resolves host (name or numeric literal) and takes the first address.
	/// Throws: AddressException if nothing could be resolved.
	this(char[] host, ushort port = PORT_ANY){
		InternetHost ih = new InternetHost();

		if(!ih.getHostByName(host, std.string.toString(port)) || ih.addrList.length < 1){
			throw new AddressException("unable to resolve host '" ~ host ~ "'");
		}

		this = ih.addrList[0];
	}

	override sockaddr* name(){
		throw new Error("InternetAddress.name() has to be overriden by sub-classes");
	}

	override int nameLen(){
		throw new Error("InternetAddress.nameLen() has to be overriden by sub-classes");
	}

	override AddressFamily addressFamily(){
		throw new Error("InternetAddress.addressFamily() has to be overriden by sub-classes");
	}

	unittest{
		assert((new InternetAddress("::1")).addressFamily() == AddressFamily.INET6);
		assert((new InternetAddress("127.0.0.1")).addressFamily() == AddressFamily.INET);
	}

	/// Port in host byte order.
	ushort port(){
		throw new Error("InternetAddress.port() has to be overriden by sub-classes");
	}

	unittest{
		assert((new InternetAddress()).port() == 0);
		assert((new InternetAddress("127.0.0.1", 0xA01B)).port() == 0xA01B);
		assert((new InternetAddress("1::", 0xFFF1)).port() == 0xFFF1);
	}

	/// Address bytes in network order (4 for IPv4, 16 for IPv6).
	ubyte[] getAddr(){
		throw new Error("InternetAddress.getAddr() has to be overriden by sub-classes");
	}

	unittest{
		ubyte[] a = (new InternetAddress("127.3.0.1")).getAddr();
		assert(a.length == 4);
		assert((a[0] == 127) && (a[1] == 3) && (a[2] == 0) && (a[3] == 1));

		a = (new InternetAddress("A01::3")).getAddr();
		assert(a.length == 16);
		assert((a[0] == 10) && (a[1] == 1) && (a[15] == 3));

		assert(a[2] == 0);
		assert(a[3] == 0);
		assert(a[4] == 0);
		assert(a[5] == 0);
		assert(a[6] == 0);
		assert(a[7] == 0);
		assert(a[8] == 0);
		assert(a[9] == 0);
		assert(a[10] == 0);
		assert(a[11] == 0);
		assert(a[12] == 0);
		assert(a[13] == 0);
		assert(a[14] == 0);
	}

	/// The port as a decimal string.
	char[] toPortString(){
		return std.string.toString(port());
	}

	unittest{
		assert((new InternetAddress()).toPortString() == "0");
		assert((new InternetAddress("127.0.0.1", 81)).toPortString() == "81");
		assert((new InternetAddress("1::", 80)).toPortString() == "80");
	}

	/// The address in textual form, without the port.
	char[] toAddrString(){
		throw new Error("InternetAddress.toAddrString() has to be overriden by sub-classes");
	}

	unittest{
		assert((new InternetAddress("127.0.0.1")).toAddrString() == "127.0.0.1");
		assert((new InternetAddress("127:02:00A::03:1")).toAddrString() == "127:2:A::3:1");
		assert((new InternetAddress("0FDA:0C:00B::0:0")).toAddrString() == "FDA:C:B::");
		assert((new InternetAddress("0:0:00::B:0")).toAddrString() == "::B:0");
	}

	char[] toString(){
		throw new Error("InternetAddress.toString() has to be overriden by sub-classes");
	}

	unittest{
		assert((new InternetAddress("127.2.3.4", 8)).toString()=="127.2.3.4:8");
		assert((new InternetAddress("ABCD:02:0::0:3D", 7)).toString()=="[ABCD:2::3D]:7");
	}

	/// $(D_COMMENT IPv4 only)
	/// Deprecated: use "ubyte[] InternetAddress.getAddr()" or "uint InternetAddress4.addr()"
	uint addr(){
		throw new AddressException("use \"ubyte[] InternetAddress.getAddr()\" or \"uint InternetAddress4.addr()\"");
	}

	/// $(D_COMMENT IPv4 only)
	/// Deprecated: use "static uint InternetAddress4.parse(char[])"
	static uint parse(char[] addr){
		return InternetAddress4.parse(addr);
	}

	override int opEquals(Object o){
		return super.opEquals(o);
	}

	/// Equal when family, port and address bytes all match; never equal to null.
	int opEquals(InternetAddress o){
		if(o is null){
			return 0;
		}
		return
			(addressFamily() == o.addressFamily())
			&& (port() == o.port())
			&& (getAddr() == o.getAddr());
	}

	int opCmp(Object o){
		return super.opCmp(o);
	}

	/// Orders by address family, then address bytes, then port.
	/// Any address sorts after null.
	int opCmp(InternetAddress o){
		if(o is null){
			return 1;
		}

		AddressFamily myFamily = addressFamily();
		AddressFamily otherFamily = o.addressFamily();
		if(myFamily != otherFamily){
			return (myFamily > otherFamily) ? 1 : -1;
		}

		ubyte[] myAddr = getAddr();
		ubyte[] otherAddr = o.getAddr();
		if(myAddr != otherAddr){
			return (myAddr > otherAddr) ? 1 : -1;
		}

		ushort myPort = port();
		ushort otherPort = o.port();
		if(myPort != otherPort){
			return (myPort > otherPort) ? 1 : -1;
		}

		return 0;
	}

	unittest{
		InternetAddress a = new InternetAddress("::1", 2);
		InternetAddress b = new InternetAddress("127.0.0.1", 8);
		InternetAddress c = new InternetAddress("::2", 1);
		InternetAddress d = new InternetAddress("::1", 3);
		InternetAddress e = new InternetAddress("127.0.0.1", 6);
		InternetAddress f = new InternetAddress("126.0.0.1", 6);
		InternetAddress g = new InternetAddress("0::1", 2);
		InternetAddress h = new InternetAddress("127.0.0.1", 6);

		assert(!(h != h));

		assert(h == h);

		assert(!(h > h));

		assert(!(h < h));

		assert(!(g != g));
		assert(g != h);

		assert(g == g);
		assert(!(g == h));

		assert(!(g > g));
		assert(g > h);

		assert(!(g < g));
		assert(!(g < h));

		assert(!(f != f));
		assert(f != g);
		assert(f != h);

		assert(f == f);
		assert(!(f == g));
		assert(!(f == h));

		assert(!(f > f));
		assert(!(f > g));
		assert(!(f > h));

		assert(!(f < f));
		assert(f < g);
		assert(f < h);

		assert(!(e != e));
		assert(e != f);
		assert(e != g);
		assert(!(e != h));

		assert(e == e);
		assert(!(e == f));
		assert(!(e == g));
		assert(e == h);

		assert(!(e > e));
		assert(e > f);
		assert(!(e > g));
		assert(!(e > h));

		assert(!(e < e));
		assert(!(e < f));
		assert(e < g);
		assert(!(e < h));

		assert(!(d != d));
		assert(d != e);
		assert(d != f);
		assert(d != g);
		assert(d != h);

		assert(d == d);
		assert(!(d == e));
		assert(!(d == f));
		assert(!(d == g));
		assert(!(d == h));

		assert(!(d > d));
		assert(d > e);
		assert(d > f);
		assert(d > g);
		assert(d > h);

		assert(!(d < d));
		assert(!(d < e));
		assert(!(d < f));
		assert(!(d < g));
		assert(!(d < h));

		assert(!(c != c));
		assert(c != d);
		assert(c != e);
		assert(c != f);
		assert(c != g);
		assert(c != h);

		assert(c == c);
		assert(!(c == d));
		assert(!(c == e));
		assert(!(c == f));
		assert(!(c == g));
		assert(!(c == h));

		assert(!(c > c));
		assert(c > d);
		assert(c > e);
		assert(c > f);
		assert(c > g);
		assert(c > h);

		assert(!(c < c));
		assert(!(c < d));
		assert(!(c < e));
		assert(!(c < f));
		assert(!(c < g));
		assert(!(c < h));

		assert(b == b);
		assert(!(b == c));
		assert(!(b == d));
		assert(!(b == e));
		assert(!(b == f));
		assert(!(b == g));
		assert(!(b == h));

		assert(!(b != b));
		assert(b != c);
		assert(b != d);
		assert(b != e);
		assert(b != f);
		assert(b != g);
		assert(b != h);

		assert(!(b > b));
		assert(!(b > c));
		assert(!(b > d));
		assert(b > e);
		assert(b > f);
		assert(!(b > g));
		assert(b > h);

		assert(!(b < b));
		assert(b < c);
		assert(b < d);
		assert(!(b < e));
		assert(!(b < f));
		assert(b < g);
		assert(!(b < h));

		assert(!(a != a));
		assert(a != b);
		assert(a != c);
		assert(a != d);
		assert(a != e);
		assert(a != f);
		assert(!(a != g));
		assert(a != h);

		assert(a == a);
		assert(!(a == b));
		assert(!(a == c));
		assert(!(a == d));
		assert(!(a == e));
		assert(!(a == f));
		assert(a == g);
		assert(!(a == h));

		assert(!(a > a));
		assert(a > b);
		assert(!(a > c));
		assert(!(a > d));
		assert(a > e);
		assert(a > f);
		assert(!(a > g));
		assert(a > h);

		assert(!(a < a));
		assert(!(a < b));
		assert(a < c);
		assert(a < d);
		assert(!(a < e));
		assert(!(a < f));
		assert(!(a < g));
		assert(!(a < h));

		// Comparisons against null no longer dereference the argument.
		InternetAddress none = null;
		assert(!a.opEquals(none));
		assert(a.opCmp(none) > 0);
	}
}

/// IPv4 implementation of InternetAddress, backed by a heap-allocated
/// sockaddr_in owned by this object.
class InternetAddress4 : InternetAddress{
	protected sockaddr_in* sa4;

	/// ADDR_ANY with PORT_ANY.
	override this(){
		this(ADDR_ANY, PORT_ANY);
	}

	/// addr and port are in host byte order.
	this(uint addr, ushort port = PORT_ANY){
		sa4 = new sockaddr_in;
		sa4.sin_addr.s_addr = htonl(addr);
		sa4.sin_port = htons(port);

		super(DUMMY);
	}

	this(sockaddr* addr){
		this(cast(sockaddr_in*) addr);
	}

	/// Copies *addr. Throws: AddressException unless sin_family is INET.
	this(sockaddr_in* addr){
		if(addr.sin_family != AddressFamily.INET){
			throw new AddressException("this is no IPv4 address");
		}

		// Keep a private copy rather than the caller's pointer, which may
		// refer to stack memory or to an addrinfo list that gets freed.
		sa4 = new sockaddr_in;
		*sa4 = *addr;

		super(DUMMY);
	}

	/// addr holds the 4 address bytes in network order; port is in host order.
	/// Throws: AddressException if addr is not exactly 4 bytes long.
	this(ubyte[] addr, ushort port = PORT_ANY){
		if(addr.length != 4){
			throw new AddressException("this is no IPv4 address");
		}

		sa4 = new sockaddr_in;
		// Byte-wise copy: addr.ptr need not be suitably aligned for a uint load.
		(cast(ubyte*) &sa4.sin_addr.s_addr)[0 .. 4] = addr[];
		sa4.sin_port = htons(port);

		super(DUMMY);
	}

	/// Resolves host restricted to IPv4 and takes the first result.
	/// Throws: AddressException if no IPv4 address was found.
	this(char[] host, ushort port = PORT_ANY){
		super(DUMMY);
		InternetHost ih = new InternetHost();
		addrinfo hint;
		hint.ai_family = AddressFamily.INET;

		if(!ih.getHostByName(host, &hint, std.string.toString(port)) || ih.addrList.length < 1){
			throw new AddressException("unable to resolve IPv4 address for host: " ~ host);
		}

		this = cast(InternetAddress4) ih.addrList[0];
	}

	override sockaddr* name(){
		return cast(sockaddr*) sa4;
	}

	override int nameLen(){
		return sockaddr_in.sizeof;
	}

	override AddressFamily addressFamily(){
		return cast(AddressFamily) sa4.sin_family;
	}

	/// Port in host byte order.
	override ushort port(){
		return ntohs(sa4.sin_port);
	}

	/// Address as a host-byte-order uint.
	override uint addr(){
		return ntohl(sa4.sin_addr.s_addr);
	}

	unittest{
		assert((new InternetAddress("255.1.16.0")).addr() == 0xFF011000);

		try{
			(new InternetAddress("2::1")).addr();
		}catch(AddressException e){
			return;
		}

		assert(0);
	}

	/// The 4 address bytes in network order; a slice of this object's storage.
	override ubyte[] getAddr(){
		return (cast(ubyte*) &(sa4.sin_addr.s_addr))[0 .. 4];
	}

	/// Dotted-decimal form, e.g. "127.0.0.1".
	override char[] toAddrString(){
		ubyte[] addr = getAddr();

		return format("%s.%s.%s.%s", addr[0], addr[1], addr[2], addr[3]);
	}

	/// "address:port".
	override char[] toString(){
		return toAddrString() ~ ":" ~ toPortString();
	}

	/// inet_addr() result converted to host byte order.
	static uint parse(char[] addr){
		return ntohl(inet_addr(std.string.toStringz(addr)));
	}

	unittest{
		assert(0xFF010203 == parse("255.1.2.3"));
	}

	invariant{
		assert(sa4.sin_family == AddressFamily.INET);
	}
}

/// IPv6 implementation of InternetAddress, backed by a heap-allocated
/// sockaddr_in6 owned by this object.
class InternetAddress6 : InternetAddress{
	protected sockaddr_in6* sa6;

	/// The all-zero address (ADDR_ANY6) with PORT_ANY.
	this(){
		this(ADDR_ANY6, PORT_ANY);
	}

	/// addr holds the 16 address bytes in network order; port is in host order.
	/// Throws: AddressException if addr is not exactly 16 bytes long.
	this(ubyte[] addr, ushort port = PORT_ANY){
		if(addr.length != 16){
			throw new AddressException("this is no IPv6 address");
		}

		sa6 = new sockaddr_in6;
		sa6.sin6_addr.s6_addr8[] = addr;
		sa6.sin6_port = htons(port);

		super(DUMMY);
	}

	this(sockaddr* addr){
		this(cast(sockaddr_in6*) addr);
	}

	/// Copies *addr. Throws: AddressException unless sin6_family is INET6.
	this(sockaddr_in6* addr){
		if(addr.sin6_family != AddressFamily.INET6){
			throw new AddressException("this is no IPv6 address");
		}

		// Keep a private copy rather than the caller's pointer, which may
		// refer to stack memory or to an addrinfo list that gets freed.
		sa6 = new sockaddr_in6;
		*sa6 = *addr;

		super(DUMMY);
	}

	/// Resolves host restricted to IPv6 and takes the first result.
	/// Throws: AddressException if no IPv6 address was found.
	this(char[] host, ushort port = PORT_ANY){
		super(DUMMY);

		InternetHost ih = new InternetHost();
		addrinfo hint;
		hint.ai_family = AddressFamily.INET6;

		if(!ih.getHostByName(host, &hint, std.string.toString(port)) || ih.addrList.length < 1){
			throw new AddressException("unable to resolve IPv6 address for host: " ~ host);
		}

		this = cast(InternetAddress6) ih.addrList[0];
	}

	override sockaddr* name(){
		return cast(sockaddr*) sa6;
	}

	override int nameLen(){
		return sockaddr_in6.sizeof;
	}

	override AddressFamily addressFamily(){
		return cast(AddressFamily) sa6.sin6_family;
	}

	/// Port in host byte order.
	override ushort port(){
		return ntohs(sa6.sin6_port);
	}

	deprecated override uint addr(){
		throw new AddressException("this is no IPv4 address, use \"ubyte[] getAddr()\"");
	}

	/// The 16 address bytes in network order; a slice of this object's storage.
	override ubyte[] getAddr(){
		return sa6.sin6_addr.s6_addr8;
	}

	/// Colon-hex form with uppercase digits and no leading zeros. The longest
	/// run of two or more zero groups (the first one on a tie) is written as
	/// "::", following RFC 5952 section 4.2.
	override char[] toAddrString(){
		ubyte[] addr = getAddr();

		// Decode the eight big-endian 16-bit groups.
		ushort[8] groups;
		for(size_t g = 0; g < 8; g++){
			groups[g] = cast(ushort)((addr[2 * g] << 8) | addr[2 * g + 1]);
		}

		// Locate the longest run of zero groups.
		int bestStart = -1;
		int bestLen = 0;
		for(int g = 0; g < 8; ){
			if(groups[g] != 0){
				g++;
				continue;
			}
			int start = g;
			while(g < 8 && groups[g] == 0){
				g++;
			}
			if(g - start > bestLen){
				bestStart = start;
				bestLen = g - start;
			}
		}
		// A single zero group is written out, not compressed.
		if(bestLen < 2){
			bestStart = -1;
		}

		char[] back;
		for(int g = 0; g < 8; g++){
			if(g == bestStart){
				back ~= "::";
				g += bestLen - 1;
				continue;
			}
			// Separate from the previous group unless "::" was just written.
			if(back.length > 0 && back[back.length - 1] != ':'){
				back ~= ":";
			}
			back ~= format("%X", groups[g]);
		}

		return back;
	}

	/// "[address]:port".
	override char[] toString(){
		return "[" ~ toAddrString() ~ "]:" ~ toPortString();
	}

	deprecated static uint parse(char[] addr){
		throw new AddressException("use \"uint InternetAddress4.parse(char[])\"");
	}

	unittest{
		// Longest zero run wins over an earlier, shorter one.
		ubyte[] a = new ubyte[16];
		a[1] = 1;       // 0:1:0:0:0:0:0:0
		assert((new InternetAddress6(a)).toAddrString() == "0:1::");

		// A lone zero group is not compressed.
		ubyte[] b = new ubyte[16];
		for(size_t i = 0; i < 16; i += 2){
			b[i + 1] = cast(ubyte)(i / 2 + 1);
		}
		b[3] = 0;       // 1:0:3:4:5:6:7:8
		assert((new InternetAddress6(b)).toAddrString() == "1:0:3:4:5:6:7:8");

		assert((new InternetAddress6()).toAddrString() == "::");
	}

	invariant{
		assert(sa6.sin6_family == AddressFamily.INET6);
	}
}

/// SocketException subclass; by its name, presumably raised when accepting
/// a connection fails.
class SocketAcceptException: SocketException
{
	this(char[] msg, int err = 0) { super(msg, err); }
}


/// Which direction(s) of a connection to shut down; members mirror the
/// platform SD_* constants.
enum SocketShutdown: int
{
	RECEIVE =  SD_RECEIVE,
	SEND =     SD_SEND,
	BOTH =     SD_BOTH, // both directions, not just receiving
}


/// Flags for send/receive calls; members mirror the platform MSG_* constants.
enum SocketFlags: int
{
	NONE =       0,

	OOB =        MSG_OOB, //out of band
	PEEK =       MSG_PEEK, //only for receiving
	DONTROUTE =  MSG_DONTROUTE, //only for sending
}


/// Time value (seconds plus microseconds), with the C field names kept as
/// deprecated aliases.
/// NOTE(review): both fields are declared int; this matches the C timeval
/// layout only where tv_sec/tv_usec are 32-bit — confirm for 64-bit targets.
extern(C) struct timeval
{
	// D interface
	int seconds;
	int microseconds;

	// C interface
	deprecated
	{
		alias seconds tv_sec;
		alias microseconds tv_usec;
	}
}


//a set of sockets for Socket.select()
class SocketSet
{
	private:
	uint nbytes; // Win32: excludes uint.sizeof "count".
	byte* buf;


	version(Win32)
	{
		uint count()
		{
			return *(cast(uint*)buf);
		}


		void count(int setter)
		{
			*(cast(uint*)buf) = setter;
		}


		socket_t* first()
		{
			return cast(socket_t*)(buf + uint.sizeof);
		}
	}
	else version(BsdSockets)
	{
		int maxfd = -1;


		socket_t* first()
		{
			return cast(socket_t*)buf;
		}
	}


	fd_set* _fd_set()
	{
		return cast(fd_set*)buf;
	}


	public:
	this(uint max)
	{
		version(Win32)
		{
			nbytes = max * socket_t.sizeof;
			buf = new byte[nbytes + uint.sizeof];
			count = 0;
		}
		else version(BsdSockets)
		{
			nbytes = max / NFDBITS * socket_t.sizeof;
			if(max % NFDBITS)
				nbytes += socket_t.sizeof;
			buf = new byte[nbytes]; // new initializes to 0.
		}
	}


	this()
	{
		this(FD_SETSIZE);
	}


	void reset()
	{
		version(Win32)
		{
			count = 0;
		}
		else version(BsdSockets)
		{
			maxfd = -1;
			buf[0 .. nbytes] = 0;
		}
	}


	void add(socket_t s)
	in
	{
		// Make sure too many sockets don't get added.
		version(Win32)
		{
			assert(count < max);
		}
		else version(BsdSockets)
		{
			assert(FDELT(s) < nbytes / socket_t.sizeof);
		}
	}
	body
	{
		FD_SET(s, _fd_set);

		version(BsdSockets)
		{
			if(s > maxfd)
				maxfd = s;
		}
	}


	void add(Socket s)
	{
		add(s.sock);
	}


	void remove(socket_t s)
	{
		FD_CLR(s, _fd_set);
	}


	void remove(Socket s)
	{
		remove(s.sock);
	}


	int isSet(socket_t s)
	{
		return FD_ISSET(s, _fd_set);
	}


	int isSet(Socket s)
	{
		return isSet(s.sock);
	}


	// Max sockets that can be added, like FD_SETSIZE.
	uint max()
	{
		version(Win32)
		{
			return nbytes / socket_t.sizeof;
		}
		else version(BsdSockets)
		{
			return nbytes / socket_t.sizeof * NFDBITS;
		}
		else
		{
			static assert(0);
		}
	}


	fd_set* toFd_set()
	{
		return _fd_set;
	}


	int selectn()
	{
		version(Win32)
		{
			return 0;
		}
		else version(BsdSockets)
		{
			return maxfd + 1;
		}
	}
}


/// Level argument for socket options: SOL_SOCKET, or a protocol number
/// taken from ProtocolType.
enum SocketOptionLevel: int
{
	SOCKET =  SOL_SOCKET,
	IP =      ProtocolType.IP,
	ICMP =    ProtocolType.ICMP,
	IGMP =    ProtocolType.IGMP,
	GGP =     ProtocolType.GGP,
	TCP =     ProtocolType.TCP,
	PUP =     ProtocolType.PUP,
	UDP =     ProtocolType.UDP,
	IDP =     ProtocolType.IDP,
	IPV6 =    ProtocolType.IPV6,
}


/// D mirror of the C linger structure; the C field names l_onoff/l_linger
/// are kept as deprecated aliases. Field widths differ per platform.
extern(C) struct linger
{
	// D interface
	version(Win32)
	{
		uint16_t on;
		uint16_t time;
	}
	else version(BsdSockets)
	{
		int32_t on;
		int32_t time;
	}

	// C interface
	deprecated
	{
		alias on l_onoff;
		alias time l_linger;
	}
}


/// Socket option names; members mirror the platform SO_*, TCP_* and IPV6_*
/// constants. The comments give the level each group belongs to.
enum SocketOption: int
{
	DEBUG =                SO_DEBUG,
	BROADCAST =            SO_BROADCAST,
	REUSEADDR =            SO_REUSEADDR,
	LINGER =               SO_LINGER,
	OOBINLINE =            SO_OOBINLINE,
	SNDBUF =               SO_SNDBUF,
	RCVBUF =               SO_RCVBUF,
	DONTROUTE =            SO_DONTROUTE,

	// SocketOptionLevel.TCP:
	TCP_NODELAY =          .TCP_NODELAY,

	// SocketOptionLevel.IPV6:
	IPV6_UNICAST_HOPS =    .IPV6_UNICAST_HOPS,
	IPV6_MULTICAST_IF =    .IPV6_MULTICAST_IF,
	IPV6_MULTICAST_LOOP =  .IPV6_MULTICAST_LOOP,
	IPV6_JOIN_GROUP =      .IPV6_JOIN_GROUP,
	IPV6_LEAVE_GROUP =     .IPV6_LEAVE_GROUP,
}


// Object wrapper around a native socket handle (socket_t).
// Setup operations (create, bind, connect, listen, accept, options, select,
// address queries) throw SocketException on failure. The send/receive family
// instead returns the native result, which is Socket.ERROR (-1) on error.
class Socket
{
	private:
	socket_t sock;          // Native handle; socket_t.init means "no socket".
	AddressFamily _family;  // Family given at creation, or copied by accept().

	// On Win32, blocking() reports this cached value instead of querying the
	// handle. Winsock creates sockets in blocking mode, so start at true;
	// otherwise blocking() would report a fresh socket as non-blocking.
	version(Win32)
		bit _blocking = true;


	// For use with accepting(): builds an object with no handle; accept()
	// fills the handle in afterwards.
	protected this()
	{
	}


	public:
	// Create a socket for the given family, type and protocol.
	// Throws SocketException if the native socket() call fails.
	this(AddressFamily af, SocketType type, ProtocolType protocol)
	{
		sock = cast(socket_t)socket(af, type, protocol);
		if(sock == socket_t.init)
			throw new SocketException("Unable to create socket", _lasterr());
		_family = af;
	}


	// A single protocol exists to support this socket type within the
	// protocol family, so the ProtocolType is assumed.
	this(AddressFamily af, SocketType type)
	{
		this(af, type, cast(ProtocolType)0); // Pseudo protocol number.
	}


	// Create a socket whose protocol is looked up by name through
	// getprotobyname(). Throws SocketException if the name is not found.
	this(AddressFamily af, SocketType type, char[] protocolName)
	{
		protoent* proto;
		proto = getprotobyname(toStringz(protocolName));
		if(!proto)
			throw new SocketException("Unable to find the protocol", _lasterr());
		this(af, type, cast(ProtocolType)proto.p_proto);
	}


	// Release the handle (if still open) when the object is finalized.
	~this()
	{
		close();
	}


	// Get underlying socket handle.
	socket_t handle() // getter
	{
		return sock;
	}


	// True if the socket is in blocking mode.
	// Win32: the mode last set via blocking(bit) or inherited through accept().
	// BSD sockets: read from the handle's O_NONBLOCK flag via fcntl().
	bit blocking() // getter
	{
		version(Win32)
		{
			return _blocking;
		}
		else version(BsdSockets)
		{
			return !(fcntl(handle, F_GETFL, 0) & O_NONBLOCK);
		}
	}


	// Switch the socket to blocking (true) or non-blocking (false) mode.
	// Throws SocketException if the native call fails.
	void blocking(bit byes) // setter
	{
		version(Win32)
		{
			uint num = !byes;
			if(_SOCKET_ERROR == ioctlsocket(sock, FIONBIO, &num))
				goto err;
			_blocking = byes;
		}
		else version(BsdSockets)
		{
			int x = fcntl(sock, F_GETFL, 0);
			if(-1 == x)
				goto err;
			if(byes)
				x &= ~O_NONBLOCK;
			else
				x |= O_NONBLOCK;
			if(-1 == fcntl(sock, F_SETFL, x))
				goto err;
		}
		return; // Success.

		err:
		throw new SocketException("Unable to set socket blocking", _lasterr());
	}


	// Address family this socket was created with.
	AddressFamily addressFamily() // getter
	{
		return _family;
	}


	// True when getsockopt(SO_TYPE) succeeds, i.e. the handle still refers
	// to a valid socket.
	bit isAlive() // getter
	{
		int type, typesize = type.sizeof;
		return !getsockopt(sock, SOL_SOCKET, SO_TYPE, cast(char*)&type, &typesize);
	}


	// Bind to a local address. Throws SocketException on failure.
	void bind(Address addr)
	{
		if(_SOCKET_ERROR == .bind(sock, addr.name(), addr.nameLen()))
			throw new SocketException("Unable to bind socket", _lasterr());
	}


	// Connect to a remote address. On a non-blocking socket, an in-progress
	// result (WSAEWOULDBLOCK on Win32, EINPROGRESS on Linux) is not an error:
	// the connection completes later, and select() reports it as writeable.
	// Any other failure throws SocketException.
	void connect(Address to)
	{
		if(_SOCKET_ERROR == .connect(sock, to.name(), to.nameLen()))
		{
			int err;
			err = _lasterr(); // Capture before any other native call.

			if(!blocking)
			{
				version(Win32)
				{
					if(WSAEWOULDBLOCK == err)
						return;
				}
				else version(linux)
				{
					if(EINPROGRESS == err)
						return;
				}
				else
				{
					static assert(0);
				}
			}
			throw new SocketException("Unable to connect socket", err);
		}
	}


	//need to bind() first
	// `backlog` is passed straight to the native listen().
	// Throws SocketException on failure.
	void listen(int backlog)
	{
		if(_SOCKET_ERROR == .listen(sock, backlog))
			throw new SocketException("Unable to listen on socket", _lasterr());
	}


	// Override to use a derived class.
	// The returned socket's handle must not be set.
	protected Socket accepting()
	{
		return new Socket;
	}


	// Accept a pending connection and wrap it in the object produced by
	// accepting(). The new socket copies this socket's family (and, on Win32,
	// its cached blocking mode). Throws SocketAcceptException when the native
	// accept() fails. If accepting() throws, the new handle is closed before
	// the exception is rethrown.
	Socket accept()
	{
		socket_t newsock;
		//newsock = cast(socket_t).accept(sock, null, null); // DMD 0.101 error: found '(' when expecting ';' following 'statement
		alias .accept topaccept;
		newsock = cast(socket_t)topaccept(sock, null, null);
		if(socket_t.init == newsock)
			throw new SocketAcceptException("Unable to accept socket connection", _lasterr());

		Socket newSocket;
		try
		{
			newSocket = accepting();
			assert(newSocket.sock == socket_t.init);

			newSocket.sock = newsock;
			version(Win32)
				newSocket._blocking = _blocking; //inherits blocking mode
			newSocket._family = _family; //same family
		}
		catch(Object o)
		{
			_close(newsock);
			throw o;
		}

		return newSocket;
	}


	// Disable sending and/or receiving. The native result is ignored.
	void shutdown(SocketShutdown how)
	{
		.shutdown(sock, cast(int)how);
	}


	// Close a raw handle using the platform's close call.
	private static void _close(socket_t sock)
	{
		version(Win32)
		{
			.closesocket(sock);
		}
		else version(BsdSockets)
		{
			.close(sock);
		}
	}


	//calling shutdown() before this is recommended
	//for connection-oriented sockets
	// Safe to call repeatedly: once the handle is released it is reset to
	// socket_t.init, and later calls (including the one from the destructor)
	// do nothing instead of closing an invalid handle.
	void close()
	{
		if(sock != socket_t.init)
		{
			_close(sock);
			sock = socket_t.init;
		}
	}


	// Allocate an empty Address object matching this socket's family, for
	// native calls that fill in a peer or local address.
	private Address newFamilyObject()
	{
		Address result;
		switch(_family)
		{
			case cast(AddressFamily)AddressFamily.INET:
				result = new InternetAddress4();
				break;
			case cast(AddressFamily)AddressFamily.INET6:
				result = new InternetAddress6();
				break;

			default:
				result = new UnknownAddress;
		}
		return result;
	}


	// Returns the local machine's host name. Idea from mango.
	// Throws SocketException if the native gethostname() fails.
	static char[] hostName() // getter
	{
		char[256] result; // Host names are limited to 255 chars.
		if(_SOCKET_ERROR == .gethostname(result, result.length))
			throw new SocketException("Unable to obtain host name", _lasterr());

		// A truncated name need not be NUL-terminated, and D fills char
		// arrays with 0xFF rather than 0, so bound the scan by the buffer
		// size instead of reading past the end of the array.
		size_t len = 0;
		while(len < result.length && result[len] != 0)
			len++;
		return result[0 .. len].dup;
	}


	// Address of the connected peer. Throws SocketException on failure.
	Address remoteAddress()
	{
		Address addr = newFamilyObject();
		int nameLen = addr.nameLen();
		if(_SOCKET_ERROR == .getpeername(sock, addr.name(), &nameLen))
			throw new SocketException("Unable to obtain remote socket address", _lasterr());
		assert(addr.addressFamily() == _family);
		return addr;
	}


	// Address this socket is bound to. Throws SocketException on failure.
	Address localAddress()
	{
		Address addr = newFamilyObject();
		int nameLen = addr.nameLen();
		if(_SOCKET_ERROR == .getsockname(sock, addr.name(), &nameLen))
			throw new SocketException("Unable to obtain local socket address", _lasterr());
		assert(addr.addressFamily() == _family);
		return addr;
	}


	// Value returned by send/receive methods on error.
	const int ERROR = _SOCKET_ERROR;


	//returns number of bytes actually sent, or -1 on error
	int send(void[] buf, SocketFlags flags)
	{
		int sent = .send(sock, buf, buf.length, cast(int)flags);
		return sent;
	}


	int send(void[] buf)
	{
		return send(buf, SocketFlags.NONE);
	}


	// Send a datagram to an explicit destination.
	// Returns the native result: bytes sent, or -1 on error.
	int sendTo(void[] buf, SocketFlags flags, Address to)
	{
		int sent = .sendto(sock, buf, buf.length, cast(int)flags, to.name(), to.nameLen());
		return sent;
	}


	int sendTo(void[] buf, Address to)
	{
		return sendTo(buf, SocketFlags.NONE, to);
	}


	//assumes you connect()ed
	int sendTo(void[] buf, SocketFlags flags)
	{
		int sent = .sendto(sock, buf, buf.length, cast(int)flags, null, 0);
		return sent;
	}


	//assumes you connect()ed
	int sendTo(void[] buf)
	{
		return sendTo(buf, SocketFlags.NONE);
	}


	//returns number of bytes actually received, 0 on connection closure, or -1 on error
	int receive(void[] buf, SocketFlags flags)
	{
		if(!buf.length) //return 0 and don't think the connection closed
			return 0;
		int read = .recv(sock, buf, buf.length, cast(int)flags);
		// if(!read) //connection closed
		return read;
	}


	int receive(void[] buf)
	{
		return receive(buf, SocketFlags.NONE);
	}


	// Receive a datagram and report the sender through `from`.
	// Returns the native result: bytes received, or -1 on error.
	int receiveFrom(void[] buf, SocketFlags flags, out Address from)
	{
		if(!buf.length) //return 0 and don't think the connection closed
			return 0;
		from = newFamilyObject();
		int nameLen = from.nameLen();
		int read = .recvfrom(sock, buf, buf.length, cast(int)flags, from.name(), &nameLen);
		// The sender address is only filled in on success, so the family
		// check must not fire (and mask the error return) when recvfrom failed.
		assert(read == _SOCKET_ERROR || from.addressFamily() == _family);
		// if(!read) //connection closed
		return read;
	}


	int receiveFrom(void[] buf, out Address from)
	{
		return receiveFrom(buf, SocketFlags.NONE, from);
	}


	//assumes you connect()ed
	int receiveFrom(void[] buf, SocketFlags flags)
	{
		if(!buf.length) //return 0 and don't think the connection closed
			return 0;
		int read = .recvfrom(sock, buf, buf.length, cast(int)flags, null, null);
		// if(!read) //connection closed
		return read;
	}


	//assumes you connect()ed
	int receiveFrom(void[] buf)
	{
		return receiveFrom(buf, SocketFlags.NONE);
	}


	//returns the length, in bytes, of the actual result - very different from getsockopt()
	// Throws SocketException on failure.
	int getOption(SocketOptionLevel level, SocketOption option, void[] result)
	{
		int len = result.length;
		if(_SOCKET_ERROR == .getsockopt(sock, cast(int)level, cast(int)option, result, &len))
			throw new SocketException("Unable to get socket option", _lasterr());
		return len;
	}


	// Common case for integer and boolean options.
	int getOption(SocketOptionLevel level, SocketOption option, out int32_t result)
	{
		return getOption(level, option, (&result)[0 .. 1]);
	}


	int getOption(SocketOptionLevel level, SocketOption option, out linger result)
	{
		//return getOption(cast(SocketOptionLevel)SocketOptionLevel.SOCKET, SocketOption.LINGER, (&result)[0 .. 1]);
		return getOption(level, option, (&result)[0 .. 1]);
	}


	// Set an option from a raw byte buffer. Throws SocketException on failure.
	void setOption(SocketOptionLevel level, SocketOption option, void[] value)
	{
		if(_SOCKET_ERROR == .setsockopt(sock, cast(int)level, cast(int)option, value, value.length))
			throw new SocketException("Unable to set socket option", _lasterr());
	}


	// Common case for integer and boolean options.
	void setOption(SocketOptionLevel level, SocketOption option, int32_t value)
	{
		setOption(level, option, (&value)[0 .. 1]);
	}


	void setOption(SocketOptionLevel level, SocketOption option, linger value)
	{
		//setOption(cast(SocketOptionLevel)SocketOptionLevel.SOCKET, SocketOption.LINGER, (&value)[0 .. 1]);
		setOption(level, option, (&value)[0 .. 1]);
	}


	//SocketSet's updated to include only those sockets which an event occured
	//returns the number of events, 0 on timeout, or -1 on interruption
	//for a connect()ing socket, writeability means connected
	//for a listen()ing socket, readability means listening
	//Winsock: possibly internally limited to 64 sockets per set
	// A null tv means wait without a timeout. Errors other than interruption
	// throw SocketException.
	static int select(SocketSet checkRead, SocketSet checkWrite, SocketSet checkError, timeval* tv)
	in
	{
		//make sure none of the SocketSet's are the same object
		if(checkRead)
		{
			assert(checkRead !is checkWrite);
			assert(checkRead !is checkError);
		}
		if(checkWrite)
		{
			assert(checkWrite !is checkError);
		}
	}
	body
	{
		fd_set* fr, fw, fe;
		int n = 0;

		version(Win32)
		{
			// Windows has a problem with empty fd_set`s that aren't null.
			fr = (checkRead && checkRead.count()) ? checkRead.toFd_set() : null;
			fw = (checkWrite && checkWrite.count()) ? checkWrite.toFd_set() : null;
			fe = (checkError && checkError.count()) ? checkError.toFd_set() : null;
		}
		else
		{
			// nfds must cover the largest selectn() of all supplied sets.
			if(checkRead)
			{
				fr = checkRead.toFd_set();
				n = checkRead.selectn();
			}
			else
			{
				fr = null;
			}

			if(checkWrite)
			{
				fw = checkWrite.toFd_set();
				int _n;
				_n = checkWrite.selectn();
				if(_n > n)
					n = _n;
			}
			else
			{
				fw = null;
			}

			if(checkError)
			{
				fe = checkError.toFd_set();
				int _n;
				_n = checkError.selectn();
				if(_n > n)
					n = _n;
			}
			else
			{
				fe = null;
			}
		}

		int result = .select(n, fr, fw, fe, cast(_ctimeval*)tv);

		version(Win32)
		{
			if(_SOCKET_ERROR == result && WSAGetLastError() == WSAEINTR)
				return -1;
		}
		else version(linux)
		{
			if(_SOCKET_ERROR == result && getErrno() == EINTR)
				return -1;
		}
		else
		{
			static assert(0);
		}

		if(_SOCKET_ERROR == result)
			throw new SocketException("Socket select error", _lasterr());

		return result;
	}


	// select() with a timeout given in microseconds.
	static int select(SocketSet checkRead, SocketSet checkWrite, SocketSet checkError, int microseconds)
	{
		timeval tv;
		// Split into whole seconds plus a sub-second remainder: some select()
		// implementations reject (EINVAL) a timeval whose microsecond field is
		// 1,000,000 or more, so timeouts of a second or longer would fail.
		tv.seconds = microseconds / 1_000_000;
		tv.microseconds = microseconds % 1_000_000;
		return select(checkRead, checkWrite, checkError, &tv);
	}


	//maximum timeout
	static int select(SocketSet checkRead, SocketSet checkWrite, SocketSet checkError)
	{
		return select(checkRead, checkWrite, checkError, null);
	}


	/+
	bit poll(events)
	{
		int WSAEventSelect(socket_t s, WSAEVENT hEventObject, int lNetworkEvents); // Winsock 2 ?
		int poll(pollfd* fds, int nfds, int timeout); // Unix ?
	}
	+/
}


// Stream socket using the TCP protocol.
class TcpSocket: Socket
{
	// Create an unconnected TCP socket for the given address family.
	this(AddressFamily family)
	{
		super(family, SocketType.STREAM, ProtocolType.TCP);
	}


	// Create an unconnected TCP socket for the INET family.
	this()
	{
		this(AddressFamily.INET);
	}


	//shortcut
	// Create a socket matching connectTo's family and connect it.
	this(Address connectTo)
	{
		this(connectTo.addressFamily());
		connect(connectTo);
	}

	//shortcut
	// Resolve `host`, then try each resolved address in order. This object
	// takes over the first connection that succeeds. If every attempt fails,
	// the last SocketException is rethrown. If the host resolved to no
	// addresses, a new SocketException is thrown.
	this(char[] host, ushort port)
	{
		InternetHost ih = new InternetHost();
		ih.getHostByName(host, std.string.toString(port));

		SocketException se;
		bit connected = false;

		foreach(InternetAddress ia; ih.addrList)
		{
			try
			{
				TcpSocket candidate = new TcpSocket(ia);

				// Adopt the connected handle, then detach it from the
				// temporary so its destructor does not close it.
				sock = candidate.sock;
				_family = candidate._family;
				version(Win32)
					_blocking = candidate._blocking;
				candidate.sock = socket_t.init;

				connected = true;
				break;
			}
			catch(SocketException e)
			{
				se = e; // Remember the failure and try the next address.
			}
		}

		if(!connected)
		{
			if(se is null)
				throw new SocketException(format("failed to connect to %s on port %s", host, port));
			throw se;
		}
	}
}


// Datagram socket using the UDP protocol.
class UdpSocket: Socket
{
	// Create a UDP socket for the INET family.
	this()
	{
		this(AddressFamily.INET);
	}


	// Create a UDP socket for the given address family.
	this(AddressFamily family)
	{
		super(family, SocketType.DGRAM, ProtocolType.UDP);
	}
}

