##############################
# Required packages:
# ggplot2
# reshape2
##############################
# Read in tables with:
# read.data <- read.table("RNA_data_file_name",header=T,row.names=1,sep="\t")
# sample.data <- read.table("sample_info_file_name",header=T,sep="\t")
##############################
# Use with:
# simple.RNA.QC(read.data, sample.data, "some_name")
##############################

# Produce a simple set of QC plots for an RNA-seq count table.
#
# In order, the plots are: a barplot of read counts per sample, a heatmap of
# between-sample distances computed on CPM-normalised counts, a PCA plot
# labelled with sample names, and a PCA plot coloured by the second column of
# sample.data.
#
# Args:
#   count.table: table of counts with one column per sample. A row named
#                "NoFeature" (anywhere in the table) is split off and
#                excluded from the gene counts; if there is no such row a
#                row of zeros is used in its place.
#   sample.data: table whose second column is used to group the samples in
#                the final PCA plot; the name of that column is passed on as
#                the group name.
#   name:        if not NULL, plots are written to "<name>.pdf"; otherwise
#                they go to the current graphics device.
#
# Returns NULL invisibly; called for its plots.
simple.RNA.QC <- function(count.table, sample.data, name=NULL){

	# Extract gene counts and noFeature
	# The row is located by name, so it does not have to be the last one
	noFeature.row <- match("NoFeature", row.names(count.table))
	if(!is.na(noFeature.row)){
		noFeature <- count.table[noFeature.row, , drop=FALSE]
		count.table <- count.table[-noFeature.row, , drop=FALSE]
	} else {
		# No NoFeature row: build an all-zero row with the same columns
		noFeature <- head(count.table,1)
		noFeature <- noFeature-noFeature
		row.names(noFeature) <- "NoFeature"
	}

	# Read summary
	read.summary <- rbind(noFeature,colSums(count.table))
	row.names(read.summary) <- c("Unmapped","Genes")

	# cpm normalisation
	normalised.count.table <- sweep(count.table, 2, colSums(count.table),
			FUN="/")
	normalised.count.table <- normalised.count.table * 1000000

	# set up output pdf, if name given
	if(!is.null(name)){
		pdf(paste0(name,".pdf"))
		# Close this device on exit, including when a plotting step fails,
		# so an error does not leave the PDF open
		pdf.device <- dev.cur()
		on.exit(dev.off(pdf.device), add=TRUE)
	}

	# barplots
	plot.counts.barplot(read.summary)

	# Distance matrix heatmap
	dists <- make.dist.heatmap(normalised.count.table)
	plot.dist.heatmap(as.matrix(dists))

	# PCA
	PC.dat <- calc.pca(normalised.count.table)
	plot.PCA.ids(PC.dat)
	# NOTE(review): the groups are matched to samples by position, so this
	# assumes the rows of sample.data are in the same order as the columns
	# of count.table - confirm against the caller
	plot.PCA.group(PC.dat,sample.data[,2],colnames(sample.data)[2], TRUE)

	invisible(NULL)
}

######################
# Helper functions for plots adapted from RNA_QC_Plot_Functions.R

# Stacked barplot of read counts, one bar per sample.
#
# Args:
#   summary: table of read counts with one row per category and one column
#            per sample; it is transposed and melted to long format before
#            plotting.
#   prop:    if TRUE each bar is rescaled to show the proportion of reads in
#            each category; if FALSE (default) bar heights are in millions
#            of reads.
#
# The plot is printed to the current graphics device.
plot.counts.barplot <- function(summary, prop=FALSE) {

	library(ggplot2)
	library(reshape2)

	if(prop){
		# position="fill" stacks the bars and normalises each to height 1
		bar.position <- "fill"
		ylbl <- "Proportion of reads"
		plot.values <- summary
	} else {
		bar.position <- "stack"
		ylbl <- "Million reads"
		# Scale to millions so the bar heights agree with the axis label
		plot.values <- summary/10^6
	}

	tmp <- melt(t(plot.values)) # long format, correct order of categories
	names(tmp)[2] <- "Category"

	theplot <- ggplot(tmp, aes(x=Var1, y=value, fill=Category)) +
			geom_col(position=bar.position) +
			labs(title="Read counts", x="Samples", y=ylbl) +
			theme(axis.text.x=element_text(angle=45, vjust=1, hjust=1))

	print(theplot)

}

# Distances between the columns of masked.data.
#
# Despite the name, nothing is drawn here: the result is the "dist" object
# that the heatmap is later built from.
make.dist.heatmap <- function(masked.data){
	# dist() compares rows, so transpose to compare columns instead
	dist(t(masked.data))
}

# Draw a distance matrix as a heatmap titled "Distance matrix heatmap",
# using the non-reversed colour palette of heatmap.helper.
plot.dist.heatmap <- function(dists) {
	heatmap.helper(mat=dists, main="Distance matrix heatmap",
		reverse.palette=FALSE)
}

# Draw a matrix as a tile heatmap with ggplot2 and print it.
#
# Args:
#   mat:             object coerced with as.matrix(); each cell becomes a tile.
#   main:            plot title.
#   reverse.palette: if TRUE the fill gradient runs from yellow (low) to
#                    firebrick (high); otherwise from firebrick (low) to
#                    yellow (high).
heatmap.helper <- function(mat, main, reverse.palette=F){

        library(ggplot2)
        library(reshape2)

        # Fill colours in low, high order
        fill.colours <- if(reverse.palette){
            c("yellow","firebrick")
        } else {
            c("firebrick","yellow")
        }

        # Long format with columns Var1, Var2 and value; symmetric anyway
        cells <- melt(as.matrix(mat))
        # Force both axis columns to factors built from their text labels
        for(axis.col in c("Var1","Var2")){
            cells[[axis.col]] <- as.factor(as.character(cells[[axis.col]]))
        }

        # Bare look: no legend title, ticks or panel background
        bare.theme <- theme(legend.title=element_blank(),
            axis.ticks = element_blank(),
            axis.text.x = element_text(angle=45, hjust=1, vjust=1),
            panel.background = element_blank())

        heat <- ggplot(cells, aes(Var1, Var2)) +
            geom_tile(aes(fill=value), colour="white") +
            scale_fill_gradient(low=fill.colours[1], high=fill.colours[2]) +
            labs(x="", y="", title=main) +
            bare.theme +
            coord_fixed(ratio=1)

        # This will need something to change the order of samples as per dendrogram eventually

        print(heat)
}

# Principal component analysis of the columns of data.
#
# The input is transposed before prcomp(), so each column of data becomes
# one observation in the PCA.
#
# Returns a list with:
#   PC:      data frame of the principal component scores.
#   PC1_var: percentage of variance explained by the first component,
#            rounded to one decimal place.
#   PC2_var: the same for the second component.
calc.pca <- function(data)
{
	fit <- prcomp(t(data))

	# Second row of the importance table is the proportion of variance
	importance <- summary(fit)$importance
	pct.var <- round(importance[2,]*100, 1)

	list("PC"=data.frame(fit$x), "PC1_var"=pct.var[1],
            "PC2_var"=pct.var[2])
}

# PCA plot in which each point is drawn as its row name from PC.dat$PC.
#
# PC.dat is a list with elements PC, PC1_var and PC2_var, as returned by
# calc.pca. The plot is titled with "Sample Names" and labels are drawn at
# a text magnification of 0.6.
plot.PCA.ids <- function(PC.dat){
	scores <- PC.dat$PC
	pca.ids(scores, PC.dat$PC1_var, PC.dat$PC2_var, title="Sample Names",
		labels=rownames(scores), mag=0.6)
}

# Base-graphics scatter of PC1 against PC2 with text labels instead of points.
#
# Args:
#   PC.norm: data frame with columns PC1 and PC2.
#   pc1:     value shown as the percentage in the x axis label.
#   pc2:     value shown as the percentage in the y axis label.
#   title:   text appended to "PCA plot: " to form the main title.
#   labels:  text drawn at each (PC1, PC2) position.
#   mag:     character expansion (cex) used for the labels.
pca.ids <- function(PC.norm, pc1, pc2, title, labels, mag)
{
	# Single-panel layout; note this changes the device's mfrow setting
	par(mfrow=c(1,1))

	x <- PC.norm$PC1
	y <- PC.norm$PC2

	# Empty frame first (type="n"), then the labels are placed on top of it
	plot(x, y, type="n",
            main=paste0("PCA plot: ", title), cex.main=1,
            xlab=paste0("PC1 (", pc1, "%)"),
            ylab=paste0("PC2 (", pc2, "%)"))
	text(x, y, labels, cex=mag)
}

# PCA plot with points coloured and shaped by group.
#
# Args:
#   PC.dat:    list with elements PC, PC1_var and PC2_var, as returned by
#              calc.pca.
#   groups:    group of each sample, passed on to pca.colours.
#   groupName: name of the grouping, passed on as the title argument.
#   col.pca:   accepted but not used by this function.
plot.PCA.group <- function(PC.dat,groups,groupName, col.pca=T) {

	pca.colours(PC.norm=PC.dat$PC, pc1=PC.dat$PC1_var, pc2=PC.dat$PC2_var,
		title=groupName, group=groups)
}

# ggplot2 scatter of PC1 against PC2, coloured and shaped by group.
#
# Args:
#   PC.norm: data frame with columns PC1 and PC2.
#   pc1:     value shown as the percentage in the x axis label.
#   pc2:     value shown as the percentage in the y axis label.
#   title:   accepted but not used; the plot title is always "PCA by group".
#   group:   grouping of the rows of PC.norm; converted to a factor.
pca.colours <- function(PC.norm, pc1, pc2, title, group)
{
	library(ggplot2)
	library(reshape2)

	# Attach the grouping as an extra column next to the scores
	scores <- cbind(PC.norm, Group=factor(group))

	x.label <- paste0("PC1 (",pc1,"%)")
	y.label <- paste0("PC2 (",pc2,"%)")

	grouped.plot <- ggplot(scores,
				aes(x=PC1, y=PC2, col=Group, shape=Group)) +
			geom_point(show.legend=TRUE) +
			xlab(x.label) +
			ylab(y.label) +
			labs(title="PCA by group")

	print(grouped.plot)

}