#include "xccom-def.h"
#include <stdio.h>
#include <ocidl.h>
#include <olectl.h>

/* Evaluate an HRESULT-returning expression, store it in the caller's local
 * `hres`, and jump to the label named by ERRTAG unless it is S_OK.
 * The caller must declare `hres` and #define ERRTAG to a label in scope.
 * Wrapped in do { } while (0) so that `CHECK(x);` is a single statement and
 * cannot capture a following `else` inside an unbraced if/else. */
#define CHECK(expr) do { if ((hres = (expr)) != S_OK) goto ERRTAG; } while (0)

/* IUnknown implementation for objects */

/* IUnknown::QueryInterface for an xccom_object.
 *
 * IID_IUnknown resolves to the object's own unk_iface. IID_IDispatch is
 * delegated to the standard dispatch object when one was created by
 * xccom_object_create_dispatch_interface. Any other IID is looked up in the
 * object's registered interface list.
 *
 * Follows the COM contract: NULL ppvObject -> E_POINTER; on failure
 * *ppvObject is set to NULL; on success the object is AddRef'd once. */
STDMETHODIMP xccom_object_queryinterface(xccom_object *This, REFIID riid, void **ppvObject)
{
	LPOLESTR str;
	xc_iface *iface;

	/* Debug trace. StringFromIID allocates the string with the COM task
	 * allocator, so it must be released with CoTaskMemFree. */
	if (SUCCEEDED(StringFromIID(riid, &str))) {
		printf("%p - Looking for interface %ws\n", (void*)This, str);
		CoTaskMemFree(str);
	}

	if (ppvObject == NULL)
		return E_POINTER;
	*ppvObject = NULL;

	if (InlineIsEqualGUID(riid, &IID_IUnknown)) {
		*ppvObject = &This->unk_iface;
	} else if (This->disp_iface != NULL && InlineIsEqualGUID(riid, &IID_IDispatch)) {
		/* The dispatch object performs its own AddRef. */
		return IUnknown_QueryInterface(This->disp_iface, riid, ppvObject);
	} else {
		for (iface = This->iface_list; iface != NULL; iface = iface->next) {
			if (InlineIsEqualGUID(riid, iface->iid)) {
				*ppvObject = &iface->iface;
				break;
			}
		}
		if (iface == NULL)
			return E_NOINTERFACE;
	}

	IUnknown_AddRef(&This->unk_iface);

	return S_OK;
}

/* IUnknown::AddRef for an xccom_object: bump the shared reference count and
 * report the new value. */
STDMETHODIMP_(ULONG) xccom_object_addref(xccom_object *This)
{
	This->ref_count++;
	return This->ref_count;
}

/* IUnknown::Release for an xccom_object.
 * When the count reaches zero, the aggregated standard dispatch object (if
 * any) is released and the object is destroyed through its delete_fun hook.
 * Returns the remaining reference count (0 once destroyed). */
STDMETHODIMP_(ULONG) xccom_object_release(xccom_object *This)
{
	if (--This->ref_count != 0)
		return This->ref_count;

	/* Last reference gone: tear the object down. */
	if (This->disp_iface != NULL)
		IUnknown_Release(This->disp_iface);
	This->delete_fun(This);
	return 0;
}

/* IUnknown vtable installed on every xccom_object's unk_iface by
 * xccom_object_init. The implementations take an xccom_object* where the
 * vtable expects an IUnknown*, hence the casts. This only works if unk_iface
 * sits at offset 0 of xccom_object.
 * NOTE(review): offset-0 layout is assumed here; confirm in xccom-def.h. */
static IUnknownVtbl xccom_object_iunknown = {
	(HRESULT(STDMETHODCALLTYPE *)(IUnknown*,REFIID,void**))xccom_object_queryinterface,
	(ULONG(STDMETHODCALLTYPE *)(IUnknown*))xccom_object_addref,
	(ULONG(STDMETHODCALLTYPE *)(IUnknown*))xccom_object_release
};

/* IUnknown implementation for interfaces */

/* QueryInterface slot for a registered interface.
 * Asking for the interface's own IID returns this interface (AddRef'd through
 * its vtable, which forwards to the owning object). Every other IID is
 * resolved by the owning object's QueryInterface. */
STDMETHODIMP xc_iface_queryinterface(xc_iface *This, REFIID iid, void **ppvObject)
{
	if (!InlineIsEqualGUID(iid, This->iid))
		return IUnknown_QueryInterface(&This->root->unk_iface, iid, ppvObject);

	*ppvObject = &This->iface;
	IUnknown_AddRef(&This->iface);
	return S_OK;
}

/* AddRef slot for a registered interface: interfaces have no count of their
 * own, so the owning object's reference count is incremented. */
STDMETHODIMP_(ULONG) xc_iface_addref(xc_iface *This)
{
	xccom_object *owner = This->root;

	return IUnknown_AddRef(&owner->unk_iface);
}

/* Release slot for a registered interface: forwards to the owning object,
 * which may destroy itself (and this interface) when its count hits zero. */
STDMETHODIMP_(ULONG) xc_iface_release(xc_iface *This)
{
	xccom_object *owner = This->root;

	return IUnknown_Release(&owner->unk_iface);
}

/* Utility functions */

/* Reset an xccom_object to a blank state: zero every field (reference count
 * 0, no interfaces, no dispatch object), install the IUnknown vtable and use
 * xccom_object_delete as the default destructor. */
void xccom_object_init(xccom_object *This)
{
	ZeroMemory(This, sizeof *This);
	This->delete_fun = (void(*)(void*))xccom_object_delete;
	This->unk_iface.lpVtbl = &xccom_object_iunknown;
}

/* Default destructor for a registered interface node.
 * First destroys the remainder of the list through each node's own
 * delete_fun (so specialised nodes can clean up), then frees this node. */
static void xc_iface_delete(xc_iface *iface)
{
	xc_iface *rest = iface->next;

	if (rest != NULL)
		rest->delete_fun(rest);
	LocalFree(iface);
}

/* Default destructor for an xccom_object: destroys the interface list
 * (starting from its head node's delete_fun) and frees the object itself. */
void xccom_object_delete(xccom_object *This)
{
	xc_iface *head = This->iface_list;

	if (head != NULL)
		head->delete_fun(head);
	LocalFree(This);
}

/* Register a caller-allocated interface node on This.
 *
 * The IUnknown slots of the supplied vtable are overwritten so that the
 * interface delegates QueryInterface/AddRef/Release to the owning object.
 * Note that this writes into the vtable itself, which is shared by every
 * interface that uses it. The node is appended at the tail of the object's
 * interface list, keeping registration order. */
static void xccom_object_add_interface_internal(xccom_object *This, REFIID iid, void *pVtbl, xc_iface *iface)
{
	xc_iface **tail;

	iface->root = This;
	iface->iid = iid;
	iface->next = NULL;
	iface->delete_fun = (void(*)(void*))xc_iface_delete;

	iface->iface.lpVtbl = pVtbl;
	iface->iface.lpVtbl->QueryInterface = (HRESULT(STDMETHODCALLTYPE *)(IUnknown*,REFIID,void**))xc_iface_queryinterface;
	iface->iface.lpVtbl->AddRef = (ULONG(STDMETHODCALLTYPE *)(IUnknown*))xc_iface_addref;
	iface->iface.lpVtbl->Release = (ULONG(STDMETHODCALLTYPE *)(IUnknown*))xc_iface_release;

	for (tail = &This->iface_list; *tail != NULL; tail = &(*tail)->next)
		;
	*tail = iface;
}

/* Allocate and register a plain interface node for iid using the given
 * vtable (whose IUnknown slots are filled in during registration).
 *
 * The function cannot report failure. If the node cannot be allocated, the
 * interface is simply not registered, so a later interface-list lookup for
 * iid will not find it. Previously a NULL allocation was written through. */
void xccom_object_add_interface(xccom_object *This, REFIID iid, void *pVtbl)
{
	xc_iface *iface = (xc_iface*)LocalAlloc(LMEM_FIXED | LMEM_ZEROINIT, sizeof(xc_iface));

	if (iface == NULL)
		return;
	xccom_object_add_interface_internal(This, iid, pVtbl, iface);
}

/* (Re)create the standard IDispatch implementation for This.
 *
 * Any existing dispatch object is released first. The registered interface
 * whose IID equals iid is then wrapped by CreateStdDispatch, aggregated with
 * This as the controlling unknown and described by ptinfo.
 * Returns the CreateStdDispatch result, or E_INVALIDARG when iid is not
 * registered on the object (in which case no dispatch object remains). */
HRESULT xccom_object_create_dispatch_interface(xccom_object *This, REFIID iid, ITypeInfo *ptinfo)
{
	xc_iface *target;

	if (This->disp_iface != NULL) {
		IUnknown_Release(This->disp_iface);
		This->disp_iface = NULL;
	}

	for (target = This->iface_list; target != NULL; target = target->next) {
		if (InlineIsEqualGUID(iid, target->iid))
			return CreateStdDispatch(&This->unk_iface, &target->iface, ptinfo, &This->disp_iface);
	}
	return E_INVALIDARG;
}

/* COM event functions */

/* One sink advised on a connection point.
 * cData.pUnk holds the reference obtained by QueryInterface in Advise; it is
 * released by Unadvise or when the connection point is deleted.
 * cData.dwCookie is the token returned to the client for Unadvise. */
typedef struct xc_connection {
	CONNECTDATA cData;
	/* next connection in the singly linked list (NULL terminates) */
	struct xc_connection *next;
} xc_connection;

/* A connection point is its own COM object (own xccom_object and reference
 * count) that exposes IConnectionPoint and keeps the list of advised sinks. */
typedef struct xc_connectionpoint {
	/* embedded COM object; TOCP casts the result of TOCOMOBJ to this struct,
	 * so this presumably must stay the first member — confirm vs xccom-def.h */
	xccom_object com;
	/* container that created this point; stored without AddRef */
	IConnectionPointContainer *owner;
	/* outgoing interface served; the pointer is stored, not the IID value */
	const IID *iid;
	/* source of cookies handed out by Advise */
	int cookieCounter;
	/* advised sinks, most recently advised first */
	xc_connection *c_list;
} xc_connectionpoint;

/* Recover the xc_connectionpoint behind an interface pointer such as the
 * IConnectionPoint* received by its methods (assumes TOCOMOBJ from
 * xccom-def.h yields the owning object's address). The whole expansion is
 * parenthesized so TOCP(p)->field casts before the member access. */
#define TOCP(This) ((xc_connectionpoint*)TOCOMOBJ(This))

/* Destructor for a connection point: releases and frees every advised sink
 * connection, then destroys the embedded COM object (which frees This). */
void xc_connectionpoint_delete(xc_connectionpoint *This)
{
	xc_connection *conn;
	xc_connection *following;

	for (conn = This->c_list; conn != NULL; conn = following) {
		following = conn->next;
		IUnknown_Release(conn->cData.pUnk);
		LocalFree(conn);
	}
	xccom_object_delete(&This->com);
}

/* IConnectionPoint::Advise.
 *
 * Connects pUnk as a sink. The sink is queried for the point's outgoing
 * interface, falling back to IDispatch, and the obtained reference is kept
 * until Unadvise. Cookies start at 1 because the Advise contract reserves
 * 0 for "no connection established". *pdwCookie is therefore 0 on every
 * failure path.
 *
 * Returns S_OK, E_POINTER for NULL arguments, CONNECT_E_CANNOTCONNECT when
 * the sink supports neither interface, or E_OUTOFMEMORY. */
STDMETHODIMP xc_connectionpoint_advise(IConnectionPoint *This, IUnknown *pUnk, DWORD *pdwCookie)
{
	IUnknown *iface = NULL;
	xc_connectionpoint *cp = TOCP(This);
	xc_connection *conn;

	if (pdwCookie == NULL)
		return E_POINTER;
	*pdwCookie = 0;
	if (pUnk == NULL)
		return E_POINTER;

	if (IUnknown_QueryInterface(pUnk, cp->iid, (void**)&iface) != S_OK
	    && IUnknown_QueryInterface(pUnk, &IID_IDispatch, (void**)&iface) != S_OK)
		return CONNECT_E_CANNOTCONNECT;

	conn = (xc_connection*)LocalAlloc(LMEM_FIXED|LMEM_ZEROINIT, sizeof(xc_connection));
	if (conn == NULL) {
		/* Do not leak the sink reference we just acquired. */
		IUnknown_Release(iface);
		return E_OUTOFMEMORY;
	}

	conn->cData.pUnk = iface;
	conn->cData.dwCookie = (DWORD)++cp->cookieCounter;
	conn->next = cp->c_list;
	cp->c_list = conn;
	*pdwCookie = conn->cData.dwCookie;

	return S_OK;
}

/* IConnectionPoint::EnumConnections: enumeration is not supported.
 * Per COM out-parameter rules, *ppEnum is set to NULL on this failure. */
STDMETHODIMP xc_connectionpoint_enumconnections(IConnectionPoint *This, IEnumConnections **ppEnum)
{
	(void)This;
	if (ppEnum != NULL)
		*ppEnum = NULL;
	return E_NOTIMPL;
}

/* IConnectionPoint::GetConnectionInterface: copy the outgoing interface IID
 * served by this point into *iid. Returns E_POINTER when iid is NULL. */
STDMETHODIMP xc_connectionpoint_getconnectioninterface(IConnectionPoint *This, IID *iid)
{
	xc_connectionpoint *cp = TOCP(This);

	if (iid == NULL)
		return E_POINTER;
	*iid = *cp->iid;
	return S_OK;
}

/* IConnectionPoint::GetConnectionPointContainer: return an AddRef'd pointer
 * to the container that created this point, obtained via its QueryInterface. */
STDMETHODIMP xc_connectionpoint_getconnectionpointcontainer(IConnectionPoint *This, IConnectionPointContainer **ppCPC)
{
	IConnectionPointContainer *container = (TOCP(This))->owner;

	return IConnectionPointContainer_QueryInterface(container, &IID_IConnectionPointContainer, (void**)ppCPC);
}

/* IConnectionPoint::Unadvise.
 *
 * Removes the connection whose cookie equals dwCookie, releasing the sink
 * reference taken by Advise. The original code used `=` instead of `==`,
 * which overwrote cookies and removed the wrong connection.
 * Returns S_OK, or CONNECT_E_NOCONNECTION when no connection matches. */
STDMETHODIMP xc_connectionpoint_unadvise(IConnectionPoint *This, DWORD dwCookie)
{
	xc_connectionpoint *cp = TOCP(This);
	xc_connection **link = &cp->c_list;
	xc_connection *conn;

	/* Walk through the link pointing at each node so the match can be
	 * unlinked without tracking a separate "previous" pointer. */
	while ((conn = *link) != NULL && conn->cData.dwCookie != dwCookie)
		link = &conn->next;

	if (conn == NULL)
		return CONNECT_E_NOCONNECTION;

	*link = conn->next;
	IUnknown_Release(conn->cData.pUnk);
	LocalFree(conn);
	return S_OK;
}

/* IConnectionPoint vtable shared by all connection points.
 * The three IUnknown slots are left NULL here. They are overwritten with the
 * delegating xc_iface_* implementations by
 * xccom_object_add_interface_internal when make_connectionpoint registers
 * IID_IConnectionPoint. The table is therefore written at run time and
 * cannot be const. */
static IConnectionPointVtbl cp_vtbl = {
	/* IUnknown */
	NULL,
	NULL,
	NULL,
	/* IConnectionPoint */
	xc_connectionpoint_getconnectioninterface,
	xc_connectionpoint_getconnectionpointcontainer,
	xc_connectionpoint_advise,
	xc_connectionpoint_unadvise,
	xc_connectionpoint_enumconnections
};

/* Allocate a connection point for outgoing interface iid, owned by the
 * container `owner` (stored without AddRef).
 *
 * The returned object has a reference count of 0; the caller must obtain a
 * reference (e.g. QueryInterface for IID_IConnectionPoint) to keep it.
 * Only the pointer iid is stored, so it must outlive the connection point.
 * NOTE(review): callers appear to pass IIDs with static storage; confirm.
 * Returns NULL if the allocation fails. */
xc_connectionpoint* make_connectionpoint(IConnectionPointContainer *owner, REFIID iid)
{
	xc_connectionpoint *cp = (xc_connectionpoint*)LocalAlloc(LMEM_FIXED|LMEM_ZEROINIT, sizeof(xc_connectionpoint));

	if (cp == NULL)
		return NULL;

	cp->owner = owner;
	cp->iid = iid;
	/* Only zeroes the embedded `com` member, so owner/iid survive. */
	xccom_object_init(&cp->com);
	xccom_object_add_interface(&cp->com, &IID_IConnectionPoint, &cp_vtbl);
	cp->com.delete_fun = (void(*)(void*))xc_connectionpoint_delete;

	return cp;
}

/* Entry in a container's list of connection points. `cp` holds the
 * reference obtained in xccom_object_add_connectionpoint and released by
 * xc_cpc_delete. */
typedef struct xc_cpc_item {
	IConnectionPoint *cp;
	/* next entry; NULL terminates the list */
	struct xc_cpc_item *next;
} xc_cpc_item;

/* IConnectionPointContainer interface node of an object.
 * `iface` must be the first member: the node is found in the object's
 * interface list and cast from xc_iface* (xccom_object_add_connectionpoint),
 * and the IConnectionPointContainer* received by its methods is cast back
 * to xc_CPcontainer* (xc_cpc_findconnectionpoint). */
typedef struct xc_CPcontainer {
	xc_iface iface;
	/* connection points, most recently added first */
	xc_cpc_item *cp_list;
} xc_CPcontainer;

/* Destructor for the container interface node: releases every connection
 * point it holds, then continues with the generic interface-node
 * destructor (which deletes the rest of the interface list and frees This). */
void xc_cpc_delete(xc_CPcontainer *This)
{
	xc_cpc_item *entry;
	xc_cpc_item *following;

	for (entry = This->cp_list; entry != NULL; entry = following) {
		following = entry->next;
		IConnectionPoint_Release(entry->cp);
		LocalFree(entry);
	}
	xc_iface_delete(&This->iface);
}

/* IConnectionPointContainer::EnumConnectionPoints: not supported.
 * Per COM out-parameter rules, *ppEnum is set to NULL on this failure. */
STDMETHODIMP xc_cpc_enumconnectionpoints(IConnectionPointContainer *This, IEnumConnectionPoints **ppEnum)
{
	(void)This;
	if (ppEnum != NULL)
		*ppEnum = NULL;
	return E_NOTIMPL;
}

/* IConnectionPointContainer::FindConnectionPoint.
 *
 * Returns in *ppCP an AddRef'd pointer to the connection point serving
 * outgoing interface iid. Returns E_POINTER when ppCP is NULL, or
 * CONNECT_E_NOCONNECTION (with *ppCP set to NULL) when no point serves iid. */
STDMETHODIMP xc_cpc_findconnectionpoint(IConnectionPointContainer *This, REFIID iid, IConnectionPoint **ppCP)
{
	xc_CPcontainer *cpc = (xc_CPcontainer*)This;
	xc_cpc_item *item;

	if (ppCP == NULL)
		return E_POINTER;
	*ppCP = NULL;

	for (item = cpc->cp_list; item != NULL; item = item->next) {
		if (InlineIsEqualGUID(iid, (TOCP(item->cp))->iid))
			return IConnectionPoint_QueryInterface(item->cp, &IID_IConnectionPoint, (void**)ppCP);
	}
	return CONNECT_E_NOCONNECTION;
}

/* IConnectionPointContainer vtable shared by all container nodes.
 * The IUnknown slots are left NULL here and are filled with the delegating
 * xc_iface_* implementations by xccom_object_add_interface_internal when
 * xccom_object_add_connectionpoint registers the container, so the table
 * cannot be const. */
static IConnectionPointContainerVtbl cpc_vtbl = {
	/* IUnknown */
	NULL,
	NULL,
	NULL,
	/* IConnectionPointContainer */
	xc_cpc_enumconnectionpoints,
	xc_cpc_findconnectionpoint
};

/* Add a connection point for outgoing interface iid to This.
 *
 * The object's IConnectionPointContainer node is reused if present and
 * created (and registered) on first use otherwise. The new connection point
 * is referenced once and pushed onto the container's list.
 *
 * The function cannot report failure. On any allocation failure it returns
 * without adding the point, and frees whatever it had allocated for it. */
void xccom_object_add_connectionpoint(xccom_object *This, REFIID iid)
{
	xc_CPcontainer *cpc = NULL;
	xc_iface *iface;
	xc_connectionpoint *cp;
	xc_cpc_item *item;

	for (iface = This->iface_list; iface != NULL; iface = iface->next) {
		if (InlineIsEqualGUID(iface->iid, &IID_IConnectionPointContainer)) {
			cpc = (xc_CPcontainer*)iface;
			break;
		}
	}
	if (cpc == NULL) {
		cpc = (xc_CPcontainer*)LocalAlloc(LMEM_FIXED|LMEM_ZEROINIT, sizeof(xc_CPcontainer));
		if (cpc == NULL)
			return;
		xccom_object_add_interface_internal(This, &IID_IConnectionPointContainer, &cpc_vtbl, &cpc->iface);
		cpc->iface.delete_fun = (void(*)(void*))xc_cpc_delete;
	}

	item = (xc_cpc_item*)LocalAlloc(LMEM_FIXED|LMEM_ZEROINIT, sizeof(xc_cpc_item));
	if (item == NULL)
		return;

	cp = make_connectionpoint((IConnectionPointContainer*)&cpc->iface, iid);
	if (cp == NULL) {
		LocalFree(item);
		return;
	}

	/* This QueryInterface takes the point's only reference, which the item owns. */
	if (IUnknown_QueryInterface(&cp->com.unk_iface, &IID_IConnectionPoint, (void**)&item->cp) != S_OK) {
		/* Reference count is still 0: destroy the point directly. */
		cp->com.delete_fun(cp);
		LocalFree(item);
		return;
	}
	item->next = cpc->cp_list;
	cpc->cp_list = item;
}

HRESULT xccom_object_fire_event(xccom_object *This, REFIID iid, DISPID dispID)
{
	IConnectionPointContainer *cpc = NULL;
	IConnectionPoint *cp = NULL;
	HRESULT hres;

#define ERRTAG FAILED_fire_event
	CHECK(IUnknown_QueryInterface(&This->unk_iface, &IID_IConnectionPointContainer, (void**)&cpc));
	CHECK(IConnectionPointContainer_FindConnectionPoint(cpc, iid, &cp));
	{
		xc_connectionpoint *xcp = TOCP(cp);
		xc_connection *conn = xcp->c_list;
		DISPPARAMS dispParams = { NULL, NULL, 0, 0 };

		while (conn != NULL) {
			IDispatch_Invoke((IDispatch*)conn->cData.pUnk, dispID, &IID_NULL,
					LANG_NEUTRAL, DISPATCH_METHOD, &dispParams,
					NULL, NULL, NULL);
			conn = conn->next;
		}
	}
#undef ERRTAG

FAILED_fire_event:
	if (cpc != NULL)
		IConnectionPointContainer_Release(cpc);
	if (cp != NULL)
		IConnectionPoint_Release(cp);
	return hres;
}

/* Interface initialization / termination */

/* Load the ITypeInfo for each entry of the IID_NULL-terminated __com_iface
 * table from type library tl, stopping at the first failure.
 *
 * ITypeLib::GetTypeInfoOfGuid already returns an AddRef'd ITypeInfo, so no
 * extra AddRef is taken. The single ITypeInfo_Release in
 * COM_term_interfaces balances it. On failure the failing entry's ti is
 * cleared so it is never released.
 * Returns the HRESULT of the last GetTypeInfoOfGuid call (S_OK when all
 * succeed or the table is empty). */
HRESULT COM_init_interfaces(ITypeLib *tl)
{
	COM_IFACE *iface = __com_iface;
	HRESULT hres = S_OK;

	while (hres == S_OK && !InlineIsEqualGUID(iface->iid, &IID_NULL)) {
		hres = ITypeLib_GetTypeInfoOfGuid(tl, iface->iid, &iface->ti);
		if (FAILED(hres))
			iface->ti = NULL;
		iface++;
	}

	return hres;
}

HRESULT COM_term_interfaces()
{
	COM_IFACE *iface = __com_iface;
	while (!InlineIsEqualGUID(iface->iid, &IID_NULL)){
		ITypeInfo_Release(iface->ti);
		iface++;
	}
	return S_OK;
}

/* Look up the cached ITypeInfo for iid in the IID_NULL-terminated
 * __com_iface table.
 * Returns the first matching entry's ti as stored, without AddRef (the caller
 * borrows it). Returns NULL when iid is not listed. */
ITypeInfo* COM_get_typeinfo(REFIID iid)
{
	COM_IFACE *entry;

	for (entry = __com_iface; !InlineIsEqualGUID(entry->iid, &IID_NULL); entry++) {
		if (InlineIsEqualGUID(entry->iid, iid))
			return entry->ti;
	}
	return NULL;
}
