#include <linux/init.h>
#include <linux/module.h>
#include <linux/netfilter.h>
#include <linux/netfilter_bridge.h>
#include <linux/list.h>
#include <linux/if_ether.h>
#include <linux/etherdevice.h>
#include <net/netlink.h>
#include <linux/string.h>
#include <linux/timer.h>
#include <linux/smp.h>
#include <linux/spinlock.h>
#include <linux/kobject.h>
#include <linux/version.h>
#include <linux/jhash.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/udp.h>
#include <linux/if_vlan.h>

#include "traffic_separation.h"
#include "mtkmapfilter.h"

#define MAX_VLAN_NUM 48
#define INVALID_VLAN_ID		4095
#define MAX_NUM_TRANSPARENT_VLAN 128

#ifndef BIT
#define BIT(x) (1U << (x))
#endif

extern unsigned char traffic_seperation;


/* One bucket of the client hash table: an RCU-readable list plus the lock that serializes writers. */
struct client_context {
	struct hlist_head client_head;	/* list of struct client_db, linked via ->hlist */
	spinlock_t hash_lock;		/* taken with spin_lock_bh() around list modifications */
};

/* A station MAC address and the VLAN id assigned to it. */
struct client_db {
	struct hlist_node hlist;	/* link in client_context.client_head */
	struct rcu_head rcu;		/* used by call_rcu() to defer the kfree() after unlinking */
	unsigned char mac[ETH_ALEN];
	unsigned short vid;
};

/* VLAN ids whose traffic bypasses traffic-separation processing. */
struct transparent_vlan {
	/* Count passed to the last handle_transparent_vlan(); non-zero enables the bypass check. */
	unsigned char transparent_enabled;
	/*
	 * Bitmap indexed by VID: word vid / size_of_int, bit vid % size_of_int.
	 * With a 32-bit int, 128 words cover VIDs 0..4095.
	 */
	unsigned int bitmap_transparent_vids[128];
};

/* Complete traffic-separation state (single global instance: ts_config). */
struct traffic_separation_config {
	unsigned short primary_vid;	/* set by handle_ts_default_8021q(); tagged on untagged rx traffic */
	unsigned char default_pcp;	/* PCP placed in the primary-VID tag */
	unsigned char pvid_num;		/* policy entry count; non-zero means traffic separation is enabled */
	unsigned int bitmap_pvid[128];	/* policy VIDs, same bit layout as bitmap_transparent_vids */
	struct transparent_vlan trans_vlan;
	struct client_context clients[HASH_TABLE_SIZE];	/* client MAC -> VID hash table */
};
/* Bits per bitmap word (sizeof(int) * 8); splits a VID into a word index and a bit offset. */
unsigned short size_of_int = sizeof(int) * 8;

/* Global traffic-separation state; zeroed and initialized by ts_init(). */
struct traffic_separation_config ts_config;

/*
 * create_client_db - allocate a client entry for @mac with VLAN id @vid.
 *
 * Uses GFP_ATOMIC, so it may be called from atomic context. The entry is
 * returned unlinked; the caller inserts it into the hash table.
 * Returns NULL when the allocation fails.
 */
struct client_db *create_client_db(unsigned char mac[], unsigned short vid)
{
	struct client_db *entry = kmalloc(sizeof(*entry), GFP_ATOMIC);

	if (!entry)
		return NULL;

	memcpy(entry->mac, mac, ETH_ALEN);
	entry->vid = vid;

	return entry;
}

/*
 * get_client_db - look up the client entry whose MAC equals @mac.
 *
 * Walks the bucket with the RCU list iterator, so the caller must hold
 * rcu_read_lock() or the bucket's hash_lock. Returns the entry, or NULL
 * if no entry matches.
 */
struct client_db *get_client_db(unsigned char mac[])
{
	/* Kept as unsigned char so every caller maps a MAC to the same bucket. */
	unsigned char bucket = MAC_ADDR_HASH_INDEX(mac);
	struct client_db *entry;

	hlist_for_each_entry_rcu(entry, &ts_config.clients[bucket].client_head, hlist) {
		if (memcmp(entry->mac, mac, ETH_ALEN) == 0)
			return entry;
	}

	return NULL;
}

/*
 * free_client_db - RCU callback that releases a client entry.
 * @head: the rcu_head embedded in the struct client_db being freed.
 */
static void free_client_db(struct rcu_head *head)
{
	struct client_db *victim;

	victim = container_of(head, struct client_db, rcu);
	printk("----->free client db rcu %p,%p\n", head, victim);
	kfree(victim);
}

/*
 * remove_client_db - unlink the entry for @mac, if present.
 *
 * The entry is unlinked under the bucket lock, and its memory is released
 * by free_client_db() after an RCU grace period, so concurrent readers
 * stay safe.
 */
void remove_client_db(unsigned char mac[])
{
	unsigned char bucket_idx = MAC_ADDR_HASH_INDEX(mac);
	struct client_context *bucket = &ts_config.clients[bucket_idx];
	struct client_db *pos;
	struct client_db *victim = NULL;

	spin_lock_bh(&bucket->hash_lock);

	hlist_for_each_entry_rcu(pos, &bucket->client_head, hlist) {
		if (memcmp(pos->mac, mac, ETH_ALEN) == 0) {
			victim = pos;
			break;
		}
	}

	if (victim) {
		hlist_del_rcu(&victim->hlist);
		call_rcu(&victim->rcu, free_client_db);
	}

	spin_unlock_bh(&bucket->hash_lock);
}

void clear_client_db(void)
{
	struct client_db *client = NULL;
	struct hlist_node *n;
	struct client_context *head = NULL;
	int i = 0;
	struct hlist_head clear_list;

	INIT_HLIST_HEAD(&clear_list);

	for (i = 0; i < HASH_TABLE_SIZE; i++) {
		head = &ts_config.clients[i];

		spin_lock_bh(&head->hash_lock);
		hlist_for_each_entry_safe(client, n, &head->client_head, hlist) {
			hlist_del_rcu(&client->hlist);
			hlist_add_head_rcu(&client->hlist, &clear_list);
		}
		spin_unlock_bh(&head->hash_lock);
	}

	hlist_for_each_entry_safe(client, n, &clear_list, hlist) {
		hlist_del_rcu(&client->hlist);
		synchronize_rcu();
		kfree(client);
	}
}

void ts_init(void)
{
	int i = 0;

	memset(&ts_config, 0, sizeof(struct traffic_separation_config));

	for (i = 0; i < HASH_TABLE_SIZE; i++) {
		INIT_HLIST_HEAD(&ts_config.clients[i].client_head);
		spin_lock_init(&ts_config.clients[i].hash_lock);
	}
}

/* ts_deinit - tear down traffic separation by emptying the client table. */
void ts_deinit(void)
{
	clear_client_db();
}

/*
 * is_ts_enable - report whether traffic separation is active.
 *
 * Returns the policy entry count stored by handle_ts_policy(); 0 means disabled.
 */
unsigned char is_ts_enable(void)
{
	const unsigned char policy_entries = ts_config.pvid_num;

	return policy_entries;
}

unsigned char handle_ts_default_8021q(unsigned short primary_vid, unsigned char default_pcp)
{
	ts_config.primary_vid = primary_vid;
	ts_config.default_pcp = default_pcp;
	printk("handle_ts_default_8021q vid=%d, pcp=%d\n", primary_vid, default_pcp);
	return 0;
}

/*
 * handle_ts_policy - load the set of VLAN ids allowed by the policy.
 * @policy: SSID-to-VLAN mapping with policy->num entries in ssid_2_vid[].
 *
 * More than MAX_VLAN_NUM entries is rejected. In that case pvid_num is zeroed,
 * which disables traffic separation, and -1 is returned (255 once converted to
 * the unsigned char return type).
 *
 * Otherwise the policy bitmap is cleared and the entry count is stored
 * (non-zero enables traffic separation). One bit is set per VID in 1..4094;
 * out-of-range VIDs are skipped. Returns 0.
 */
unsigned char handle_ts_policy(struct ts_policy *policy)
{
	unsigned char n = 0, word = 0, bit = 0;

	if (policy->num > MAX_VLAN_NUM) {
		printk("invalid traffic separtion policy number=%d\n", policy->num);
		ts_config.pvid_num = 0;
		return -1;
	}

	memset(ts_config.bitmap_pvid, 0, sizeof(ts_config.bitmap_pvid));
	ts_config.pvid_num = policy->num;
	printk("ts policy number=%d\n", ts_config.pvid_num);

	for (n = 0; n < policy->num; n++) {
		if (policy->ssid_2_vid[n].vlan_id >= 1 &&
		    policy->ssid_2_vid[n].vlan_id < INVALID_VLAN_ID) {
			word = policy->ssid_2_vid[n].vlan_id / size_of_int;
			bit = policy->ssid_2_vid[n].vlan_id % size_of_int;
			ts_config.bitmap_pvid[word] |= BIT(bit);
			printk("ts policy(%d) vid=%d ([%d]|1<<%d)=%d\n", n,
				policy->ssid_2_vid[n].vlan_id, word, bit, ts_config.bitmap_pvid[word]);
		}
	}

	return 0;
}

unsigned char handle_transparent_vlan(struct transparent_vids *tvids)
{
	unsigned char i = 0, index = 0, offset = 0;

	memset(ts_config.trans_vlan.bitmap_transparent_vids, 0,
		sizeof(ts_config.trans_vlan.bitmap_transparent_vids));
	
	printk("transparent_vid_num=%d\n", tvids->num);

	ts_config.trans_vlan.transparent_enabled = tvids->num;
	memset(ts_config.trans_vlan.bitmap_transparent_vids, 0,
		sizeof(ts_config.trans_vlan.bitmap_transparent_vids));
	
	for (i = 0; i < tvids->num; i++) {
		if (tvids->vids[i] < 1 || tvids->vids[i] >= INVALID_VLAN_ID)
			continue;
		index = tvids->vids[i] / size_of_int;
		offset = tvids->vids[i] % size_of_int;
		ts_config.trans_vlan.bitmap_transparent_vids[index] |= BIT(offset);
		printk("transparent vids[%d] vid=%d ([%d]|1<<%d)=%d\n", i, tvids->vids[i],
			index, offset, ts_config.trans_vlan.bitmap_transparent_vids[index]);
	}

	return 0;
}

unsigned char handle_client_vid(struct client_vid *client)
{
	unsigned char hash_idx = MAC_ADDR_HASH_INDEX(client->client_mac);
	struct client_context *client_head = NULL;
	struct client_db *cli_db = NULL;

	if (client->vid < 1 || client->vid >= 4095) {
		printk("client vid invlaid(%d)\n", client->vid);
		return 0;
	}
	client_head = &ts_config.clients[hash_idx];

	if (client->status == STATION_JOIN) {
		rcu_read_lock();
		cli_db = get_client_db(client->client_mac);
		if (!cli_db) {
			rcu_read_unlock();
			cli_db = create_client_db(client->client_mac, client->vid);
			spin_lock_bh(&client_head->hash_lock);
			hlist_add_head_rcu(&cli_db->hlist, &client_head->client_head);
			spin_unlock_bh(&client_head->hash_lock);  
		} else {
			cli_db->vid = client->vid;
			rcu_read_unlock();
		}
	} else if(client->status == STATION_LEAVE) {
		remove_client_db(client->client_mac);
	}

	return 0;
}



/*
 * is_vid_in_policy - test whether @vid is one of the policy VLAN ids.
 *
 * Returns the matching (non-zero) bit if @vid was loaded by handle_ts_policy(),
 * otherwise 0. The parameter is an unsigned short while VIDs are 12-bit, so
 * values beyond the bitmap return 0 instead of reading past bitmap_pvid[].
 */
unsigned int is_vid_in_policy(unsigned short vid)
{
	unsigned int word = vid / size_of_int;

	if (word >= sizeof(ts_config.bitmap_pvid) / sizeof(ts_config.bitmap_pvid[0]))
		return 0;

	return ts_config.bitmap_pvid[word] & BIT(vid % size_of_int);
}

/*
 * is_transparent_vlan_on - report whether the transparent-VLAN bypass is active.
 *
 * Returns the count stored by handle_transparent_vlan(); 0 means off.
 */
unsigned char is_transparent_vlan_on(void)
{
	const unsigned char configured = ts_config.trans_vlan.transparent_enabled;

	return configured;
}

/*
 * is_transparent_vlan - test whether @vid is configured as a transparent VLAN.
 *
 * Returns the matching (non-zero) bit if set, otherwise 0. Values beyond the
 * bitmap return 0 instead of reading past bitmap_transparent_vids[].
 */
unsigned int is_transparent_vlan(unsigned short vid)
{
	unsigned int word = vid / size_of_int;

	if (word >= sizeof(ts_config.trans_vlan.bitmap_transparent_vids) /
		    sizeof(ts_config.trans_vlan.bitmap_transparent_vids[0]))
		return 0;

	return ts_config.trans_vlan.bitmap_transparent_vids[word] & BIT(vid % size_of_int);
}

/*
 * VID mask and an accessor extracting the VID from the skb's accelerated
 * tag (skb->vlan_tci).
 * NOTE(review): <linux/if_vlan.h> may already define these on some kernel
 * versions. An identical redefinition is harmless; a differing one would
 * warn — confirm against the target kernel.
 */
#define VLAN_VID_MASK		0x0fff /* VLAN Identifier */
#define vlan_tx_tag_get_id(__skb)	((__skb)->vlan_tci & VLAN_VID_MASK)

/*
 * RtmpOsCsumAdd - one's-complement addition of two 32-bit partial checksums.
 *
 * Returns csum + addend with any carry out of bit 31 folded back into
 * bit 0 (end-around carry).
 */
unsigned int RtmpOsCsumAdd(unsigned int csum, unsigned int addend)
{
	unsigned int sum = csum + addend;
	unsigned int carry = (sum < addend) ? 1U : 0U;

	return sum + carry;
}

/*
 * RtmpOsSkbPullRcsum - pull @len bytes from the front of @skb and keep the
 * receive checksum consistent, like the kernel's skb_pull_rcsum().
 *
 * Does nothing if @len exceeds skb->len.
 *
 * For CHECKSUM_COMPLETE, the checksum of the removed bytes is subtracted
 * from skb->csum. For CHECKSUM_PARTIAL, the skb falls back to CHECKSUM_NONE
 * if the checksum start now lies before skb->data.
 */
void RtmpOsSkbPullRcsum(struct sk_buff *skb, unsigned int len)
{
	unsigned char *start;

	if (len > skb->len)
		return;

	/*
	 * The fixup must cover the bytes being pulled. Record where they live
	 * before skb_pull() advances skb->data past them.
	 */
	start = skb->data;
	skb_pull(skb, len);
	if (skb->ip_summed == CHECKSUM_COMPLETE)
		skb->csum = RtmpOsCsumAdd(skb->csum, ~csum_partial(start, len, 0));
	else if (skb->ip_summed == CHECKSUM_PARTIAL &&
		 (skb->csum_start - (skb->data - skb->head)) < 0)
		skb->ip_summed = CHECKSUM_NONE;
}

/*
 * remove_vlan_tag - strip an in-band 4-byte VLAN header from @skb.
 *
 * Steps:
 *  - The 12 MAC-address bytes at skb->data are moved forward by 4,
 *    overwriting the 4 bytes that follow them.
 *  - The first 4 bytes are pulled via RtmpOsSkbPullRcsum().
 *  - The mac, network and transport header offsets are all reset to the
 *    new skb->data, and the mac length is recomputed.
 *
 * Assumes skb->data points at the destination MAC and that a VLAN header
 * follows the two addresses — TODO confirm at any call site (no caller is
 * visible in this part of the file).
 *
 * NOTE(review): the checksum fixup in RtmpOsSkbPullRcsum() does not run over
 * the removed VLAN-header bytes. Compare with the kernel's __skb_vlan_pop(),
 * which adjusts the checksum over the tag before moving the addresses.
 * Resetting the network/transport headers to the MAC header also looks
 * unintended. Verify both before using this.
 */
static inline void remove_vlan_tag(struct sk_buff *skb)
{
	unsigned short VLAN_LEN = 4;
	unsigned char extra_field_offset = 2 * ETH_ALEN;	/* dst + src MAC */

	/* Slide the addresses over the tag; memmove because the ranges overlap. */
	memmove(skb->data + VLAN_LEN,
		skb->data, extra_field_offset);
	RtmpOsSkbPullRcsum(skb, 4);
	skb_reset_mac_header(skb);
	skb_reset_network_header(skb);
	skb_reset_transport_header(skb);
	skb_reset_mac_len(skb);
}
/* Forward declaration; add_vlan_tag() is defined later in this file. */
unsigned char add_vlan_tag(struct sk_buff *skb, unsigned short vlan_id, unsigned char vlan_pcp);

unsigned int ts_tx_process(struct sk_buff *skb, unsigned char wan_tag)
{
	unsigned short vid = 0;
	unsigned char vlan_in_header = 0;

	/*firstly, check the transparent vlan*/
	if (is_transparent_vlan_on()) {
		if (skb->vlan_proto == htons(ETH_P_8021Q) && skb->vlan_tci) {
			vid = vlan_tx_tag_get_id(skb);
			if (is_transparent_vlan(vid))
				return NF_ACCEPT;
		}
	}

	/*then, check normal traffic separation configuration*/
	if (!is_ts_enable())
		return NF_ACCEPT;

	/*if vlan info is in skb->vlan_tci*/
	if (skb->vlan_proto == htons(ETH_P_8021Q) && skb->vlan_tci) {
		vid = vlan_tx_tag_get_id(skb);
	} else {
	/*if vlan info is in ethernet header*/
		if (skb->protocol ==  htons(ETH_P_8021Q)) {
			struct vlan_hdr *vhdr;
			unsigned short vlan_tci;

			vhdr = (struct vlan_hdr *) skb->data;
			vlan_tci = ntohs(vhdr->h_vlan_TCI);
			vid = vlan_tci & VLAN_VID_MASK;
			vlan_in_header = 1;
		}
	}

	if (vid) {
		if (vid == ts_config.primary_vid || wan_tag)
			skb->vlan_tci = 0;
		/*whether need to check if the vid is inlcued in the traffic policy???*/
	}

	return NF_ACCEPT;
}

/*
 * ts_rx_process - traffic-separation handling for the rx direction.
 * @skb:     packet being processed.
 * @wan_tag: non-zero when the packet came in on the WAN side.
 *
 * Steps:
 *  - A transparent VID in the accelerated tag bypasses everything.
 *  - With traffic separation enabled, an already 802.1Q-tagged packet is
 *    dropped if its VID is the primary VID or is not in the policy.
 *  - On the WAN side, a known destination client with a secondary VID
 *    gets that VID as its tag.
 *  - Otherwise the primary VID (with default PCP) is tagged, unless
 *    primary_vid is INVALID_VLAN_ID.
 *
 * Returns NF_DROP or NF_ACCEPT.
 */
unsigned int ts_rx_process(struct sk_buff *skb, unsigned char wan_tag)
{
	unsigned short tag_vid;
	unsigned short tci;

	if (is_transparent_vlan_on() &&
	    skb->vlan_proto == htons(ETH_P_8021Q) && skb->vlan_tci &&
	    is_transparent_vlan(vlan_tx_tag_get_id(skb)))
		return NF_ACCEPT;

	if (!is_ts_enable())
		return NF_ACCEPT;

	/* NOTE(review): 802.1ad (ETH_P_8021AD) tags are not handled here. */
	if (skb->vlan_proto == htons(ETH_P_8021Q) && skb->vlan_tci) {
		tag_vid = vlan_tx_tag_get_id(skb);
		return (tag_vid == ts_config.primary_vid || !is_vid_in_policy(tag_vid)) ?
			NF_DROP : NF_ACCEPT;
	}

	if (wan_tag) {
		struct ethhdr *eth = eth_hdr(skb);
		struct client_db *entry;
		unsigned short client_vid = 0;	/* 0 fails the range check below */

		rcu_read_lock();
		entry = get_client_db(eth->h_dest);
		if (entry)
			client_vid = entry->vid;
		rcu_read_unlock();

		/*
		 * Secondary VID for a wireless station reached over the WAN.
		 * NOTE(review): the lower bound 3 differs from the 1..4094 range
		 * used elsewhere — confirm intent.
		 */
		if (client_vid >= 3 && client_vid <= 4094 &&
		    client_vid != ts_config.primary_vid) {
			__vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q), client_vid);
			return NF_ACCEPT;
		}
	}

	if (ts_config.primary_vid != INVALID_VLAN_ID) {
		/* TCI = PCP in bits 15-13, VID in bits 11-0 */
		tci = ((ts_config.default_pcp & 0x7) << 13) | (ts_config.primary_vid & 0x0FFF);
		__vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q), tci);
	}

	return NF_ACCEPT;
}

/*
 * ts_local_in_process - strip the accelerated 802.1Q tag from packets
 * delivered locally.
 *
 * Packets whose tag carries a transparent VID are left alone. Otherwise,
 * when traffic separation is enabled, skb->vlan_tci is cleared.
 * Always returns NF_ACCEPT.
 */
unsigned int ts_local_in_process(struct sk_buff *skb)
{
	int tagged = skb->vlan_proto == htons(ETH_P_8021Q) && skb->vlan_tci;

	if (is_transparent_vlan_on() && tagged &&
	    is_transparent_vlan(vlan_tx_tag_get_id(skb)))
		return NF_ACCEPT;

	if (is_ts_enable() && tagged)
		skb->vlan_tci = 0;

	return NF_ACCEPT;
}

/*
 * add_vlan_tag - insert an in-band 802.1Q header into @skb.
 * @skb:      packet. vlan_insert_tag() works relative to skb->data, so callers
 *            presumably position skb->data at the MAC header — TODO confirm.
 * @vlan_id:  VLAN identifier; only the low 12 bits are used.
 * @vlan_pcp: priority code point; only the low 3 bits are used.
 *
 * On success, skb->protocol is set to 802.1Q and the accelerated tag is
 * cleared. Returns 0 on success, 1 on failure. On failure vlan_insert_tag()
 * has already freed @skb, so the caller must not touch it again.
 */
unsigned char add_vlan_tag(struct sk_buff *skb, unsigned short vlan_id, unsigned char vlan_pcp)
{
	/*
	 * The TCI is 16 bits (PCP in bits 15-13, VID in bits 11-0). An
	 * unsigned char here would truncate the VID and drop the PCP entirely.
	 */
	unsigned short vlan_tci = (vlan_id & 0x0fff) | ((vlan_pcp & 0x7) << 13);

	skb = vlan_insert_tag(skb, htons(ETH_P_8021Q), vlan_tci);
	if (!skb)
		return 1;

	skb->protocol = htons(ETH_P_8021Q);
	skb->vlan_tci = 0;
	return 0;
}

unsigned int ts_local_out_process(struct sk_buff *skb, unsigned char out_type)
{
	unsigned short vid = 0;
	struct ethhdr *hdr;
	struct client_db *cli_db = NULL;
	
	/*firstly, check the transparent vlan*/
	if (is_transparent_vlan_on()) {
		if (skb->vlan_proto == htons(ETH_P_8021Q) && skb->vlan_tci) {
			vid = vlan_tx_tag_get_id(skb);
			
			if (is_transparent_vlan(vid))
				goto end;
		}
	}
	
	if (!is_ts_enable())
		goto end;
	hdr = eth_hdr(skb);

	rcu_read_lock();
	cli_db = get_client_db(hdr->h_dest);
	if (!cli_db) {
		rcu_read_unlock();
		goto end;
	}

	vid = cli_db->vid;
	rcu_read_unlock();

	if (vid >= 1 && vid <= 4094) {
		/*if out device is ethernet and its vlan id equals to primary vlan id, do not add vlan tag for it*/
		if (out_type == ETH && vid == ts_config.primary_vid)
			goto end;
		__vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q), vid);
#if 0
		/*if out device is ethernet and its vlan id equals to primary vlan id, do not add vlan tag for it*/
		if (out_type == ETH && vid == ts_config.primary_vid)
			goto end;

		/*if it is already 802.1q header*/
		if (skb->protocol == htons(ETH_P_8021Q)) {
			struct vlan_hdr *vhdr;
			unsigned short vlan_tci = 0;

			vhdr = (struct vlan_hdr *) skb->data;
			vlan_tci |= 0x0fff & vid;
			vhdr->h_vlan_TCI = htons(vlan_tci);
		} else {
		/*if it is not 802.1q header, insert the vlan header to packet*/
			skb_push(skb, ETH_HLEN);
			add_vlan_tag(skb, vid, 0);
			skb_pull(skb, ETH_HLEN);
		}
#endif
	}
end:
	return NF_ACCEPT;
}

/*
 * ts_ip_pre_routing_process - tag IP pre-routing packets with the VLAN id
 * of the destination client.
 *
 * Steps:
 *  - The network header is reset to skb->data.
 *  - The destination MAC is read from the ETH_HLEN bytes just before
 *    skb->data, via a push/pull pair.
 *  - If that MAC maps to a VID in 1..4094, the VID is put into the
 *    accelerated tag.
 *
 * Always returns NF_ACCEPT.
 */
unsigned int ts_ip_pre_routing_process(struct sk_buff *skb)
{
	struct client_db *entry;
	unsigned char *eth_dst;
	unsigned short client_vid;

	if (!is_ts_enable())
		return NF_ACCEPT;

	skb_set_network_header(skb, 0);
	/* skb_push() returns the new skb->data, i.e. the start of the Ethernet header. */
	eth_dst = skb_push(skb, ETH_HLEN);
	skb_pull(skb, ETH_HLEN);

	rcu_read_lock();
	entry = get_client_db(eth_dst);
	client_vid = entry ? entry->vid : 0;	/* 0 fails the range check below */
	rcu_read_unlock();

	if (client_vid >= 1 && client_vid <= 4094)
		__vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q), client_vid);

	return NF_ACCEPT;
}


void dump_ts_info(void)
{
	unsigned short i = 0;
	struct client_context *head = NULL;
	struct client_db * client = NULL;

	printk("traffic separation config information\n");
	printk("ts switch %s\n", traffic_seperation ? "on" : "off");
	printk("ts %s\n", is_ts_enable() ? "enabled" : "disabled");
	printk("primary vlan id=%d, primary pcp=%d\n", ts_config.primary_vid, ts_config.default_pcp);
	for (i = 1; i < INVALID_VLAN_ID; i++) {
		if (is_vid_in_policy(i))
			printk("ts policy vids=%d\n", i);
	}


	for (i = 1; i < INVALID_VLAN_ID; i++) {
		if (is_transparent_vlan(i))
			printk("transparent vids=%d\n", i);
	}

	rcu_read_lock();
	for (i = 0; i < HASH_TABLE_SIZE; i++) {
		head = &ts_config.clients[i];
		hlist_for_each_entry_rcu(client, &head->client_head, hlist) {
			printk("mac(%02x:%02x:%02x:%02x:%02x:%02x) vid=%d\n", PRINT_MAC(client->mac), client->vid);
		}
	}
	rcu_read_unlock();
}
