#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <sys/stat.h>
#include <linux/compiler.h>
#include <linux/types.h>

#include "../mptbase.h"

/* Typedefs
 *
 * Short fixed-purpose integer aliases used throughout this tool.
 *
 * NOTE: #ifndef only tests whether a preprocessor MACRO of that name
 * exists; it cannot see a typedef.  Each guard below therefore skips its
 * typedef only when some earlier header #defines the name as a macro.
 */
#ifndef uchar
typedef unsigned char uchar;
#endif

#ifndef u8
typedef unsigned char u8;
#endif

#ifndef u16
typedef unsigned short u16;
#endif

#ifndef u32
typedef unsigned int u32;
#endif

#ifndef u64
typedef unsigned long long u64;
#endif


/*
 * min_t(type, x, y) - smaller of x and y after converting both to `type`.
 * Written as a GCC-style statement expression ({ ... }) so that each
 * argument is evaluated exactly once.
 */
#ifndef min_t
#define min_t(type,x,y) \
        ({ type __x = (x); type __y = (y); __x < __y ? __x: __y; })
#endif

/*
 * max_t(type, x, y) - larger of x and y after converting both to `type`.
 * Same single-evaluation statement-expression form as min_t().
 */
#ifndef max_t
#define max_t(type,x,y) \
        ({ type __x = (x); type __y = (y); __x > __y ? __x: __y; })
#endif


/* Driver Includes
 */
#include "../mptctl.h"
#include "../lsi/mpi_type.h"
#include "../lsi/mpi.h"
#include "../lsi/mpi_ioc.h"
#include "../lsi/mpi_cnfg.h"
#include "../lsi/mpi_init.h"
#include "../lsi/mpi_raid.h"

/*=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=*/

/*
 *	Tool identification string.  main() prints it starting at offset 4
 *	in its usage message (presumably to skip the leading what(1) magic
 *	contributed by WHAT_MAGIC_STRING -- confirm against its definition).
 */
static char *mptflash_name_ver	= WHAT_MAGIC_STRING "MptFlash-" MPT_LINUX_VERSION_COMMON;

/*
 *	Special/reserved names...
 *
 *	ctlname   - path of the control device main() opens for the
 *	            firmware download ioctl.
 *	versearch - marker get_oldver() looks for in /proc/mpt/iocN/summary;
 *	            the hex firmware revision is parsed right after it.
 *	step      - number of the step main() is currently executing; it is
 *	            printed in progress messages and returned as the exit
 *	            status when a step fails.
 */
static char *ctlname	= MPT_MISCDEV_PATHNAME;
static char *versearch	= MPT_FW_REV_MAGIC_ID_STRING;
static int step = 0;

/*
 *  Forward protos...
 *
 *  get_oldver()  - format the adapter's current f/w version into results_buf
 *  what_search() - locate a "@(#)" what-string inside a memory buffer
 *  verify_pid()  - compare the image's ProductID with the adapter's
 */
static char *get_oldver(char *results_buf, int iocnum);
static char *what_search(const char *buf, int size_in, int *size_out);
int verify_pid(char *fwbuf, struct mpt_fw_xfer *fwdata);

/*=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=*/
/*
 *  main() - command line driver: load a firmware image from disk, sanity
 *  check it, and hand it to the mptctl driver for download to an adapter.
 *
 *  Usage: <prog> fwfname.rom [ioc#]		(ioc# defaults to "0")
 *
 *  The file-scope `step` counter tracks progress and is printed with every
 *  status line:
 *    1  command line check
 *    2  query current firmware version (get_oldver)
 *    3  open the image file
 *    4  fstat() it to learn its size
 *    5  allocate a buffer for the whole image
 *    6  read the image, then verify ProductID and checksum
 *    7  open the control device (trying "insmod mptctl" once on failure)
 *    8  ask for confirmation and issue the download ioctl
 *
 *  Returns:
 *	0	download succeeded, or the user declined the download
 *	step	the number of the step that failed
 *	1	ProductID verification or checksum failed
 *
 *  Every resource acquired here (image fd, image buffer, control fd) is
 *  released on the single exit path at the `out` label.
 */
int
main (int argc, char **argv)
{
	char userinp[128];
	char newver[64];
	char oldver[64];
	char what2[64];
	char *oldverp;
	struct mpt_fw_xfer fwdata;
	unsigned long ctlnum = MPTFWDOWNLOAD;
	char *fwbuf = NULL;
	int fwfd = -1;
	int ctlfd = -1;
	int size;
	int nread;
	int rc = 0;
	struct stat fwstat;
	char *myname;
	char *fwfname = NULL;
	char *arg2 = NULL;
	char *whatstr;
	int whatsz;
	int i;
	u32 checksum;

	/* Reduce argv[0] to its basename; tolerate an empty argument vector. */
	if (argc < 1 || argv[0] == NULL) {
		myname = "mptflash";
	} else if ((myname = strrchr(argv[0], '/')) == NULL) {
		myname = argv[0];
	} else {
		myname++;
		argv[0] = myname;
	}
	if (argc > 1)
		fwfname = argv[1];
	if (argc > 2)
		arg2 = argv[2];
	else
		arg2 = "0";

	step = 1;				/* step #1 */

	/* isdigit() needs an unsigned char value; plain char may be negative. */
	if (fwfname == NULL || fwfname[0] == '\0' ||
	    !isdigit((unsigned char)arg2[0]))
	{
		fprintf(stderr, "%s\n", mptflash_name_ver+4);
		fprintf(stderr, "Usage: %s fwfname.rom [0|1|2|...|7 for ioc# ]\n", myname);
		return step;
	}
	fprintf(stderr, "  step%d: Cmd line check: Ok\n", step);

	/*
	 * NOTE(review): an ioc number equal to MPT_MAX_ADAPTERS is accepted
	 * here; if that macro is a count rather than a highest index this
	 * should probably be ">=" -- confirm against its definition.
	 */
	fwdata.iocnum = atoi(arg2);
	if (fwdata.iocnum > MPT_MAX_ADAPTERS || fwdata.iocnum < 0) {
		fprintf(stderr, "  step%db: Cmd line Oops, resetting ioc from %d to 0\n", step, fwdata.iocnum);
		fwdata.iocnum = 0;
	}

	++step;
	if ((oldverp = get_oldver(oldver, fwdata.iocnum)) == NULL) {
		oldverp = "unknown";
	}
	fprintf(stderr, "  step%d: get_oldver() results: \"%s\"\n",
			step, oldverp);

	/*
	 *  Open f/w image file
	 */
	++step;
	if ((fwfd = open(fwfname, O_RDONLY)) < 0) {
		fprintf(stderr, "ERROR: open(\"%s\") FAILED: %s\n", fwfname, strerror(errno));
		return step;
	}
	fprintf(stderr, "  step%d: open(\"%s\"): Ok, fwfd=%d\n", step, fwfname, fwfd);

	/*
	 *  Stat the f/w image file to get size
	 */
	++step;
	if (fstat(fwfd, &fwstat) < 0) {
		fprintf(stderr, "ERROR: fstat(%d) FAILED: %s\n",
				fwfd, strerror(errno));
		rc = step;
		goto out;
	}
	size = (int)fwstat.st_size;
	/*
	 * Refuse empty files and sizes that do not survive the narrowing to
	 * int: everything below indexes the image with int offsets.
	 */
	if (size <= 0 || (off_t)size != fwstat.st_size) {
		fprintf(stderr, "ERROR: \"%s\" is empty or too large to be a firmware image\n",
				fwfname);
		rc = step;
		goto out;
	}
	fprintf(stderr, "  step%d: stat(\"%s\"): Ok, size=%d bytes\n",
			step, fwfname, size);

	/*
	 *  Allocate a zero-filled buffer big enough for the entire f/w image
	 */
	++step;
	if ((fwbuf = calloc(1, (size_t)size)) == NULL) {
		fprintf(stderr, "ERROR: Firmware buffer malloc %d bytes FAILED: %s\n",
				size, strerror(errno));
		rc = step;
		goto out;
	}
	fprintf(stderr, "  step%d: malloc(%d): Ok\n", step, size);

	fwdata.fwlen = size;
	fwdata.bufp = fwbuf;

	/*
	 *  Read all bytes from f/w image file to the buffer.  read() may
	 *  legitimately return fewer bytes than requested, so keep going
	 *  until the whole image is in memory or EOF/error stops us.
	 */
	++step;
	nread = 0;
	while (nread < size) {
		ssize_t got = read(fwfd, fwbuf + nread, (size_t)(size - nread));

		if (got < 0) {
			if (errno == EINTR)
				continue;
			fprintf(stderr, "ERROR: Firmware read from \"%s\" FAILED: %s\n",
					fwfname, strerror(errno));
			rc = step;
			goto out;
		}
		if (got == 0)
			break;			/* premature end of file */
		nread += (int)got;
	}
	if (nread != size) {
		/* Never flash a partially loaded image. */
		fprintf(stderr, "ERROR: Firmware read from \"%s\" FAILED: got only %d of %d bytes\n",
				fwfname, nread, size);
		rc = step;
		goto out;
	}
	fprintf(stderr, "  step%d: read(%d,,%d): Ok\n", step, fwfd, size);

	if (verify_pid(fwbuf, &fwdata) != 0) {
		rc = 1;
		goto out;
	}

	/*
	 * Checksum algorithm pulled from Stephen's lsiupdate utility
	 * for Solaris. This should work for both little and big endian
	 * systems.
	 *
	 * Each byte is added at the bit position it would occupy in a
	 * little-endian 32-bit word; the total must wrap to exactly zero.
	 */
	checksum = 0;
	for (i = 0; i < nread; i++)
		checksum += ((U32)((U8)fwbuf[i])) << (8 * (i & 3));
	if (checksum != 0) {
		fprintf (stderr, "ERROR: Bad checksum for specified firmwware"
			 " image (%s).\n", fwfname);
		rc = 1;
		goto out;
	}

	/*
	 *  Scan file for new version (magic what string identifier).
	 *  what_search() stops at a NUL byte, so the located text contains
	 *  none and a bounded memcpy() plus explicit terminator is enough.
	 *  A second what-string, if any, is collected into what2 but is not
	 *  currently displayed.
	 */
	newver[0] = what2[0] = '\0';
	if ((whatstr = what_search(fwbuf, size, &whatsz)) != NULL && whatsz) {
		int ncopy = min_t(int, sizeof(newver)-1, whatsz);

		memcpy(newver, whatstr, (size_t)ncopy);
		newver[ncopy] = '\0';
		if ((whatstr = what_search(whatstr, size-(whatstr-fwbuf+whatsz), &whatsz)) != NULL && whatsz) {
			ncopy = min_t(int, sizeof(what2)-1, whatsz);
			memcpy(what2, whatstr, (size_t)ncopy);
			what2[ncopy] = '\0';
		}
	}
	fprintf(stderr, "         Current F/W Version = \"%s\"\n", oldverp);
	fprintf(stderr, "             New F/W Version = \"%s\"\n", newver);

	/*
	 *  Open /dev/mptctl for ioctl() call; if that fails, try loading
	 *  the driver once and open again.
	 */
	++step;
	if ((ctlfd = open(ctlname, O_RDWR)) < 0) {
		fprintf(stderr, "Driver mptctl not loaded. Executing insmod mptctl\n");
		/* The retry of open() below is the real success test. */
		(void) system ("insmod mptctl");
		if ((ctlfd = open(ctlname, O_RDWR)) < 0) {
			fprintf(stderr, "ERROR: open(\"%s\") FAILED: %s\n", ctlname, strerror(errno));
			rc = step;
			goto out;
		}
	}
	fprintf(stderr, "  step%d: open(\"%s\"): Ok, ctlfd=%d\n", step, ctlname, ctlfd);

	++step;
	fprintf(stderr, "  pre-step%d: ioctl(%d,MPTFWDOWNLOAD,)\n", step, ctlfd);

	fprintf(stdout, "\n  You are at irreversible step %d of %d in this process.\n",
			step, step);
	fprintf(stdout, "  Are you absolutely sure you want to proceed with download? (y/[n]) ");
	fflush(stdout);
	/*
	 * On EOF or a read error fgets() leaves the buffer untouched, so
	 * treat that as the default answer "no" instead of inspecting
	 * uninitialised memory.
	 */
	if (fgets(userinp, sizeof(userinp), stdin) == NULL)
		userinp[0] = '\0';

	if (userinp[0] == 'y') {
		if (ioctl(ctlfd, ctlnum, (char *) &fwdata) < 0) {
			fprintf(stderr, "ERROR: Firmware transfer to ioc%d FAILED: %s\n",
					fwdata.iocnum, strerror(errno));
			rc = step;
			goto out;
		}
		fprintf(stdout, "\n  !!! MPT firmware transfer to ioc%d SUCCEEDED !!!\n", fwdata.iocnum);
		fprintf(stdout, "  (%d of %d f/w update steps were successful)\n\n",
				step, step);
		fprintf(stdout, "  CHANGE EFFECTIVE ONLY AFTER NEXT RESET / POWER CYCLE!\n\n");
	} else {
		fprintf(stderr, "User Aborted Download.\n");
	}

out:
	if (ctlfd >= 0)
		close(ctlfd);
	free(fwbuf);
	if (fwfd >= 0)
		close(fwfd);
	return rc;
}

/*=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=*/
/*
 * Map the product field of a SCSI ProductID (already masked with
 * MPI_FW_HEADER_PID_PROD_MASK) to the name used in the mismatch warning.
 */
static const char *
mptflash_scsi_prod_name(unsigned int prod)
{
	switch (prod) {
	case MPI_FW_HEADER_PID_PROD_INITIATOR_SCSI:
		return "Initiator";
	case MPI_FW_HEADER_PID_PROD_TARGET_INITIATOR_SCSI:
		return "Initiator-Target";
	case MPI_FW_HEADER_PID_PROD_TARGET_SCSI:
		return "Target";
	case MPI_FW_HEADER_PID_PROD_IM_SCSI:
		return "Integrated Mirroring";
	case MPI_FW_HEADER_PID_PROD_IS_SCSI:
		return "Integrated Striping";
	case MPI_FW_HEADER_PID_PROD_CTX_SCSI:
		return "Context";
	default:
		return "Unknown";
	}
}

/*
 * Verify that the ProductID in the FW matches the board we're
 * sending the image to. ProductID format (16 bits, most significant first):
 * 4 bits - type { FC, SCSI, Unknown }
 * 4 bits - Product { Initiator, Target, ...}
 * 8 bits - Family { 1030 A0, 1030 B0, ...}
 *
 * The board's ProductID is parsed from the "ProductID" line of
 * /proc/mpt/iocN/info.  The image's ProductID is the little-endian 16-bit
 * value at byte offsets 34/35 of the image, except for FC909 boards where
 * the image is instead scanned (from offset 0x017000) for an 8-byte
 * signature and accepted as soon as it is found.
 *
 * fwbuf:  the complete firmware image
 * fwdata: supplies iocnum (which board) and fwlen (image length in bytes)
 *
 * A type or family mismatch is fatal.  A product mismatch only triggers a
 * warning and an interactive confirmation on stdin.
 *
 * Return:
 *	 0	image is acceptable
 *	-1 	unable to open proc file system, or no usable ProductID line
 *	-2	mismatch in type field
 *	-3	mismatch in family field
 *	-4	mismatch in product field and the user did not confirm
 *	-5	image too short to contain a ProductID
 */
int verify_pid(char *fwbuf, struct mpt_fw_xfer *fwdata)
{
	/* The type field is the top nibble of the 16-bit ProductID. */
	enum { PID_TYPE_SHIFT = 12, PID_HDR_END = 36 };
	u16 fwpid, iocpid = 0;
	FILE *iocinfofp;
	char iipath[256], fwpidbuf[256], *fpbptr;
	char userinp[128];
	int have_iocpid = 0;
	int i;

	snprintf(iipath, sizeof(iipath), "/proc/mpt/ioc%d/info", fwdata->iocnum);

	if ((iocinfofp = fopen(iipath, "r")) == NULL) {
		fprintf(stderr, "ERROR: unable to open %s to get ProductID of"
			" ioc%d.\n", iipath, fwdata->iocnum);
		return (-1);
	}

	while (fgets(fwpidbuf, sizeof(fwpidbuf), iocinfofp) != NULL) {
		if (strstr(fwpidbuf, "ProductID") == NULL)
			continue;
		/* A ProductID line without '=' carries no value; keep looking. */
		if ((fpbptr = strchr(fwpidbuf, '=')) == NULL)
			continue;
		/* strtol() skips the blank that follows the '=' by itself. */
		iocpid = (u16) strtol(fpbptr + 1, NULL, 16);
		have_iocpid = 1;
		break;
	}
	/* The proc file is no longer needed on any path below. */
	fclose(iocinfofp);

	if (!have_iocpid) {
		fprintf (stderr, "ERROR: unable to find ProductID "
			 "string in %s.\n", iipath);
		return (-1);
	}

	if (iocpid == MPI_MANUFACTPAGE_DEVICEID_FC909) {
		u8 *fwptr = (u8 *) fwbuf;

		/*
		 * Look for the FC909 device id followed by
		 * 00 00 00 00 ff 03 anywhere from 0x017000 onwards.
		 */
		for (i = 0x017000; (i+7) < fwdata->fwlen; i++) {
			if ((((fwptr[i+1] << 8)| fwptr[i]) == MPI_MANUFACTPAGE_DEVICEID_FC909)
			    && (fwptr[i+2] == 0x00) && (fwptr[i+3] == 0x00)
			    && (fwptr[i+4] == 0x00) && (fwptr[i+5] == 0x00)
			    && (fwptr[i+6] == 0xff) && (fwptr[i+7] == 0x03)) {
				return (0);
			}
		}
		/* Signature absent: use a value that cannot match the board. */
		fwpid = 0xdead;
	} else {
		/* Bytes 34 and 35 must exist before they can be read. */
		if (fwdata->fwlen < PID_HDR_END) {
			fprintf(stderr, "ERROR: firmware image is too short "
				"(%lu bytes) to contain a ProductID.\n",
				(unsigned long) fwdata->fwlen);
			return (-5);
		}
		fwpid = (((u16)((u8)fwbuf[35])) << 8) | ((u16)((u8)fwbuf[34]));
	}

	/*
	 * Check the Type (FC vs SCSI), upper nibble of pid. MUST Match
	 */
	if ((fwpid & MPI_FW_HEADER_PID_TYPE_MASK) != (iocpid & MPI_FW_HEADER_PID_TYPE_MASK)){
		/*
		 * Shift the nibble down so that it can be compared with the
		 * small integers 0 (SCSI) and 1 (FC); the masked-only value
		 * could never equal 1.
		 */
		unsigned int fwtype = (unsigned int) fwpid >> PID_TYPE_SHIFT;
		unsigned int ioctype = (unsigned int) iocpid >> PID_TYPE_SHIFT;

		fprintf(stderr, "ERROR: This fw image is for %s product \n"
			"and the specified board/ioc type is %s.\n",
			fwtype==0?"a SCSI":(fwtype == 1?"a FC":"an unknown"),
			ioctype==0?"SCSI":(ioctype == 1?"FC":"unknown"));

		return (-2);
	}

	/*
	 * Check the Family, lower byte. MUST Match
	 */
	if ((fwpid & MPI_FW_HEADER_PID_FAMILY_MASK) != (iocpid & MPI_FW_HEADER_PID_FAMILY_MASK)){
		unsigned int fwfam = fwpid & MPI_FW_HEADER_PID_FAMILY_MASK;
		unsigned int iocfam = iocpid & MPI_FW_HEADER_PID_FAMILY_MASK;

		fprintf(stderr, "ERROR: This fw image is for revision %u\n"
			" product and the specified board/ioc chip "
			" revision is %u.\n",
			fwfam, iocfam);

		return (-3);
	}


	/*
	 * Check the Product.  A mismatch is allowed, but only after the
	 * user explicitly confirms.
	 */
	if ((fwpid & MPI_FW_HEADER_PID_PROD_MASK) != (iocpid & MPI_FW_HEADER_PID_PROD_MASK)){
		unsigned int fwprod = fwpid & MPI_FW_HEADER_PID_PROD_MASK;
		unsigned int iocprod = iocpid & MPI_FW_HEADER_PID_PROD_MASK;
		const char *fwProdString = "Unknown";
		const char *iocProdString = "Unknown";

		/* Product names are only known for SCSI parts. */
		if ((fwpid & MPI_FW_HEADER_PID_TYPE_MASK) == MPI_FW_HEADER_PID_TYPE_SCSI) {
			fwProdString = mptflash_scsi_prod_name(fwprod);
			iocProdString = mptflash_scsi_prod_name(iocprod);
		}

		fprintf(stderr, "WARNING! The fw image contains %s firmware. \n"
			" The specified board/ioc contains %s firmware. \n"
			" Removing Integrated Mirroring or Striping firmware \n"
			" may leave your system unable to boot.  \n",
			fwProdString, iocProdString);
		fprintf(stdout, "  Are you absolutely sure you want to update \n"
			" the firwmare? (y/[n]) ");
		fflush(stdout);
		/*
		 * EOF or a read error leaves userinp unset; treat that as
		 * the default answer "no".
		 */
		if (fgets(userinp, sizeof(userinp), stdin) == NULL
		    || userinp[0] != 'y')
			return (-4);
	}

	return (0);
}

/*=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=*/
/*
 *  See if we can open "/proc/mpt/iocN/summary" to obtain
 *  current f/w version.
 */
static char *
get_oldver(char *results_buf, int iocnum)
{
	/*
	 *  Read "/proc/mpt/ioc<iocnum>/summary", extract the hex firmware
	 *  revision that follows the `versearch` marker and format it as a
	 *  human readable version string.
	 *
	 *  results_buf: caller's output buffer.  The longest string produced
	 *               here is 31 characters, so it must hold at least 32
	 *               bytes.
	 *  iocnum:      adapter number used to build the /proc path.
	 *
	 *  Returns results_buf on success.  Returns NULL -- leaving
	 *  results_buf untouched -- when the summary cannot be opened, read
	 *  or parsed, so that the caller can report the version as unknown
	 *  rather than a made-up "0.00.00".
	 */
	char pname[40];
	char oldver_scratch[256];
	int procfd;
	ssize_t nread;
	unsigned int oldver_value = 0;
	int have_version = 0;
	int isscsi = 0;
	int mdbg = 0;
	int len;

	(void) snprintf(pname, sizeof(pname), "/proc/mpt/ioc%d/summary", iocnum);
	if ((procfd = open(pname, O_RDONLY)) >= 0) {
		/*
		 *  Dependency: Current format of /proc/mpt/iocN/summary output:
		 *    iocN: ... FwRev=0e0c1800h (Exp 1218), ...\n
		 *  or
		 *    iocN: ... FwRev=01010100h, ...\n
		 *
		 *                     ^^^^^^---- What we want!
		 */
		/* Leave room for the terminator strstr()/sscanf() rely on. */
		nread = read(procfd, oldver_scratch, sizeof(oldver_scratch) - 1);
		if (nread < 0) {
			fprintf(stderr, "         OldVer-read-fail-WARNING\n");
		} else {
			char *found;

			oldver_scratch[nread] = '\0';
			if ((found = strstr(oldver_scratch, versearch)) == NULL) {
				fprintf(stderr, "         OldVer-search-fail-WARNING\n");
			} else {
				/* SCSI parts identify themselves as LSI53C... */
				if (strstr(oldver_scratch, "LSI53C") != NULL)
					isscsi = 1;
				if (sscanf(found + strlen(versearch), "%x", &oldver_value) == 1)
					have_version = 1;
				else
					fprintf(stderr, "         OldVer-sscanf-fail-WARNING\n");
			}
		}
		close(procfd);
	}

	if (!have_version)
		return NULL;

	/*
	 *	Changes to handle both old and new-style FWVersion encoding
	 */
	if (isscsi) {
		/* SCSI: three hex fields, optional fourth field in decimal. */
		len = sprintf(results_buf, "LSI53Cxx-%02x.%02x.%02x",
				oldver_value >> 24,
				(oldver_value >> 16) & 0x00FF,
				(oldver_value >> 8) & 0x00FF);
		if ((oldver_value & 0x0FF) != 0)
			(void) sprintf(results_buf+len, ".%02u",
					oldver_value & 0x0FF);
	} else if (oldver_value < 0x00010000) {
		/* Old-style 16-bit encoding; bit 7 flags an MDBG build. */
		if (oldver_value & 0x0080) {
			oldver_value &= ~0x0080u;
			mdbg++;
		}

		if ((oldver_value & 0xE000) == 0xE000) {
			(void) sprintf(results_buf, "LSIFC9x9-%d.%02u.%02u%s",
					2001,
					(oldver_value & 0x0F00) >> 8,
					oldver_value & 0x00FF,
					mdbg ? " (MDBG)" : "" );
		} else {
			(void) sprintf(results_buf, "LSIFC9x9-%u.%02u.%02u%s",
					(oldver_value & 0xF000) >> 12,
					(oldver_value & 0x0F00) >> 8,
					oldver_value & 0x00FF,
					mdbg ? " (MDBG)" : "" );
		}
	} else {
		/* New-style 32-bit encoding; bit 15 flags an MDBG build. */
		if (oldver_value & 0x00008000) {
			oldver_value &= ~0x00008000u;
			mdbg++;
		}

		if ((oldver_value & 0x0E000000) == 0x0E000000) {
			oldver_value &= 0x0FFFFFF;
			(void) sprintf(results_buf, "LSIFC9x9-%d.%02u.%02u%s",
					2001,
					(oldver_value >> 16) & 0x00FF,
					(oldver_value >> 8) & 0x00FF,
					mdbg ? " (MDBG)" : "" );
		} else {
			len = sprintf(results_buf, "LSIFC9x9-%u.%02u.%02u",
					oldver_value >> 24,
					(oldver_value >> 16) & 0x00FF,
					(oldver_value >> 8) & 0x00FF);
			if ((oldver_value & 0x0FF) != 0)
				(void) sprintf(results_buf+len, ".%02u",
						oldver_value & 0x0FF);
			if (mdbg)
				strcat(results_buf, " (MDBG)");
		}
	}

	return results_buf;
}

/*=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=*/
static char *
what_search(const char *buf, int size_in, int *size_out)
{
	/*
	 *  Find the first what(1)-style "@(#)" marker in buf[0..size_in-1]
	 *  and measure the text that follows it.
	 *
	 *  buf:      memory to scan (need not be NUL terminated)
	 *  size_in:  number of valid bytes at buf
	 *  size_out: receives the length of the text after the marker; the
	 *            text ends at the first NUL, '"', '>' or newline.  It is
	 *            always set, and is 0 when nothing is found.
	 *
	 *  Returns a pointer to the first byte after the marker, or NULL if
	 *  there is no marker.
	 *
	 *  Two historical limits are kept on purpose so results do not change:
	 *  the '@' must have at least 5 bytes after it inside the buffer, and
	 *  the very last byte of the buffer is never counted as text.
	 *
	 *  A NULL buffer or a size below 6 (including zero and negative
	 *  sizes) yields NULL without touching buf at all.
	 */
	static const char magic[] = "@(#)";
	const int magic_len = (int) sizeof(magic) - 1;
	const int min_span = magic_len + 2;	/* marker + 2 trailing bytes */
	int pos;

	*size_out = 0;
	if (buf == NULL || size_in < min_span)
		return NULL;

	for (pos = 0; pos <= size_in - min_span; pos++) {
		const char *text;
		int room;
		int len;

		if (memcmp(buf + pos, magic, (size_t) magic_len) != 0)
			continue;

		text = buf + pos + magic_len;
		/* Bytes available for text, excluding the buffer's last byte. */
		room = size_in - pos - magic_len - 1;
		for (len = 0; len < room; len++) {
			char c = text[len];

			if (c == '\0' || c == '"' || c == '>' || c == '\n')
				break;
		}
		*size_out = len;
		return (char *) text;
	}
	return NULL;
}

/*=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=*/

