#!/usr/bin/python

import sys, string, copy, time
import sequence
from math import log, pow, floor, exp
from hypergeometric import getLogFETPvalue

#
# turn on psyco to speed up by 3X
#
if __name__ == '__main__':
    # psyco is optional: use it when it can be imported, otherwise carry on
    # without it.  Only an ImportError is tolerated here.
    psyco_found = False
    try:
        import psyco
        psyco.full()
        psyco_found = True
    except ImportError:
        pass
    if psyco_found:
        print >> sys.stderr, "Using psyco..."

# Format for printing very small numbers; used by sprint_logx.
# It consumes a (mantissa, exponent) pair: the mantissa with one decimal,
# a literal 'e', then the exponent as a signed, zero-padded number.
_pv_format = "%4.1fe%+04.0f"
# Natural log of 10; dividing a natural log by it gives a base-10 log.
_log10 = log(10)

# Nucleotide code -> string of the concrete bases that the code matches.
# get_min_hamming_dists looks up each letter of the query word here.
_ambig_to_dna = dict(
        # unambiguous bases
        A='A', C='C', G='G', T='T',
        # codes matching two bases
        R='AG', Y='CT', K='GT', M='AC', S='CG', W='AT',
        # codes matching three bases
        B='CGT', D='AGT', H='ACT', V='ACG',
        # matches any base
        N='ACGT'
        )

# get array of int zeros (numpy is not standard)
def int_zeros(size):
    """Return a new list containing `size` integer zeros."""
    return [0 for _ in range(size)]

# Too slow!!!  Not used...
def hamdist(str1, str2):
    """Count the positions at which str1 and str2 differ.

    A position where str1 holds 'N' is never counted as a difference.
    The strings are walked in lockstep, so when their lengths differ only
    the common prefix length is compared.
    """
    return sum(1 for a, b in zip(str1, str2) if a != 'N' and a != b)

# print very large or small numbers
def sprint_logx(logx, prec, format):
    """ Format the number x, given only its natural log logx.

    x is split into a mantissa in [1, 10) and a base-10 exponent, so
    values far outside the floating point range can still be shown.
    prec is the number of digits after the decimal point that the
    mantissa will be shown with; format must consume the pair
    (mantissa, exponent).
    Returns the formatted string."""
    log10x = logx / _log10
    exponent = floor(log10x)
    mantissa = pow(10, log10x - exponent)
    # A mantissa that would be rounded up to 10 at the requested precision
    # is carried into the exponent instead.
    half_ulp = .5 * pow(10, -prec)
    if mantissa + half_ulp >= 10:
        mantissa = 1
        exponent += 1
    return format % (mantissa, exponent)

# Complement table for nucleotide codes (U is complemented to A), built once
# when the module is loaded instead of on every call.  string.maketrans only
# exists in Python 2; str.maketrans is the Python 3 equivalent.
_RC_TABLE = (getattr(string, "maketrans", None) or str.maketrans)(
        "ACGTURYKMBVDHSWN", "TGCAAYRMKVBHDSWN")

def get_rc(word):
    """
    Return the reverse complement of word.

    Upper-case nucleotide codes, including the ambiguity codes, are
    complemented; every other character (lower case included) is copied
    unchanged.  The result is then reversed.
    """
    return word.translate(_RC_TABLE)[::-1]

def get_strings_from_seqs(seqs):
    """ Return the upper-cased string of every record in seqs.

    Each record must provide a getString() method; the order of the
    records is preserved.
    """
    return [record.getString().upper() for record in seqs]

def get_min_hamming_dists(seqs, word, use_rc=True):
    """
    Return a list of the number of sequences whose minimum
    Hamming distance to the given word or its reverse complement
    is X, for X in [0..w].

    Also returns a list with the best distance for each sequence.

    Also returns a list with the offset (1-relative) of the leftmost
    match to the word in each sequence. (Matches to reverse
    complement of the word have negative offsets.)

    seqs is a list of strings; word and each sequence are upper-cased
    before comparison.  Each letter of word is looked up in
    _ambig_to_dna, and a sequence character mismatches unless it is one
    of the bases listed there.  A window containing 'N' in the sequence
    is never accepted as a site.  With use_rc false only the word itself
    is scanned, not its reverse complement.

    A site is recorded only when its distance is strictly below w, so a
    sequence with no such site (including one shorter than the word) is
    reported with distance w and offset 0, and is counted in counts[w].

    Returns: counts, dists, offsets
    """

    w = len(word)
    word = word.upper()

    # get reverse complement
    if use_rc: rc_word = get_rc(word)

    # get minimum distances
    counts = int_zeros(w+1)
    dists = int_zeros(len(seqs))
    offsets = int_zeros(len(seqs))
    for seqno in range(len(seqs)):
        s = seqs[seqno]
        s = s.upper()
        slen = len(s)

        # loop over all words in current sequence to get minimum distance
        min_dist = w                    # best distance so far; w means "nothing found"
        best_offset = 0                 # no site in sequence
        for i in range(0, slen-w+1):

            # positive strand
            dist = 0
            for j in range(w):
                ch1 = word[j]
                ch2 = s[i+j]
                #FIXME: this should be done better
                # don't allow ambiguous characters in match:
                # w+1 can never beat min_dist, so the window is rejected
                if ch2 == 'N': dist = w+1; break
                #if ch1 != 'N' and ch1 != ch2: dist += 1
                if ch2 not in _ambig_to_dna[ch1]: dist += 1
                # stop early: this window can no longer beat the best so far
                if dist >= min_dist: break
            # strict '<' keeps the leftmost of equally good sites
            if dist < min_dist:
                min_dist = dist
                best_offset = i+1

            # negative strand (checked after the positive strand, so on a
            # tie at the same position the positive strand is kept)
            if use_rc:
                dist = 0
                for j in range(w):
                    ch1 = rc_word[j]
                    ch2 = s[i+j]
                    #FIXME: this should be done better
                    # don't allow ambiguous characters in match
                    if ch2 == 'N': dist = w+1; break
                    #if ch1 != 'N' and ch1 != ch2: dist += 1
                    if ch2 not in _ambig_to_dna[ch1]: dist += 1
                    if dist >= min_dist: break
                if (dist < min_dist):
                    min_dist = dist
                    best_offset = -(i+1)

            # an exact match cannot be improved on; stop scanning this sequence
            if min_dist == 0: break

        counts[min_dist] += 1
        dists[seqno] = min_dist         # best distance for sequence
        offsets[seqno] = best_offset    # best offset in sequence

    return counts, dists, offsets


def get_best_hamming_enrichment(word, pos_seqs, neg_seqs, print_dists=False, given_only=False):
    """
    Find the most enriched Hamming distance for given word.
    Returns (best_dist, best_log_pvalue, pos_dists, pos_offsets, best_p, best_n)
    Pos_dists[s] is best distance for sequence s.
    Pos_offsets[s] is offset to the leftmost best match in sequence s.
    Offsets are 1-relative; match on reverse complement has negative offset.

    print_dists is forwarded to get_best_hamming_enrichment_from_counts,
    which prints one line per distance when it is set.
    If given_only is set, only the word itself is scanned, not its
    reverse complement.
    """

    use_rc = not given_only

    # compute Hamming distance from input word
    (pos_counts, pos_dists, pos_offsets) = get_min_hamming_dists(pos_seqs, word, use_rc)
    # only the per-distance counts are needed for the negative sequences
    neg_counts = get_min_hamming_dists(neg_seqs, word, use_rc)[0]

    # print_dists used to be dropped here, so the flag had no effect
    (best_dist, best_log_pvalue, best_p, best_n) = \
            get_best_hamming_enrichment_from_counts(len(word), len(pos_seqs), len(neg_seqs), \
                    pos_counts, neg_counts, print_dists)

    return best_dist, best_log_pvalue, pos_dists, pos_offsets, best_p, best_n


def get_best_hamming_enrichment_from_counts(w, P, N, pos_counts, neg_counts, print_dists=False):
    """
    Find the most enriched Hamming distance for given counts.
    Returns (best_dist, best_log_pvalue, best_p, best_n)
    """

    # get cumulative counts
    cum_pos_counts = copy.copy(pos_counts)
    cum_neg_counts = copy.copy(neg_counts)
    for i in range(1, len(pos_counts)):
        cum_pos_counts[i] += cum_pos_counts[i-1]
        cum_neg_counts[i] += cum_neg_counts[i-1]
        #print >> sys.stderr, "cum counts", i, pos_counts[i-1], cum_pos_counts[i]

    # compute hypergeometric enrichment at each distance and save best
    best_dist = w
    best_log_pvalue = 1     # infinity
    best_p = 0
    best_n = 0
    for i in range(len(pos_counts)):
        p = cum_pos_counts[i]
        n = cum_neg_counts[i]
        log_pvalue = getLogFETPvalue(p, P, n, N, best_log_pvalue)
        if log_pvalue < best_log_pvalue:
            best_log_pvalue = log_pvalue
            best_dist = i
            best_p = p
            best_n = n
    #FIXME:
        #print >> sys.stderr, "Computing FET", i, p, P, n, N, sprint_logx(log_pvalue, 1, _pv_format)

        if print_dists:
            pv_string = sprint_logx(log_pvalue, 1, _pv_format)
            print "d %d : %d %d %d %d %s" % (i, p, P, n, N, pv_string)

    return best_dist, best_log_pvalue, best_p, best_n


def get_enrichment_and_neighbors(word, alphabet, min_log_pvalue, pos_seqs, neg_seqs, given_only=False):
    """
    Find the Hamming enrichment of the given word.
    If pvalue under given pvalue, estimate the Hamming
    enrichment for each distance-1 neighbor of the given word.
    Returns word_pvalue_record, neighbor_pvalue_records (dict).
    pvalue_record = [p, P, n, N, log_pvalue, dist]

    alphabet is the sequence of letters a column may be changed to; every
    letter of word must occur in it.
    given_only defaults to False, as in the other functions of this file,
    so that callers which do not care about it may omit it; when set, only
    the word itself is scanned, not its reverse complement.
    The neighbor records are estimates derived from the site alignments
    of the word itself, not exact scans.  neighbors is empty when the
    word's own log p-value is not below min_log_pvalue.
    """

    use_rc = not given_only

    # exact counts, distances and offsets for the word itself
    (pos_counts, pos_dists, pos_offsets) = get_min_hamming_dists(pos_seqs, word, use_rc)
    (neg_counts, neg_dists, neg_offsets) = get_min_hamming_dists(neg_seqs, word, use_rc)

    P = len(pos_seqs)
    N = len(neg_seqs)
    w = len(word)
    (dist, log_pvalue, p, n) = get_best_hamming_enrichment_from_counts(w, P, N, pos_counts, neg_counts)

    # create p-value record
    pvalue_record = [p, P, n, N, log_pvalue, dist]
    neighbors = {}
    print >> sys.stderr, "Exact p-value of", word, "is", sprint_logx(log_pvalue, 1, _pv_format), pvalue_record

    # short-circuit: the word is not good enough to be worth expanding
    if log_pvalue >= min_log_pvalue:
        return pvalue_record, neighbors

    # get inverse alphabet: letter -> index
    inv_alphabet = dict((letter, index) for index, letter in enumerate(alphabet))

    # get PWM and nsites arrays for each exact Hamming distance
    (pos_freqs, pos_nsites) = get_freqs_and_nsites(w, pos_dists, pos_offsets, pos_seqs)
    (neg_freqs, neg_nsites) = get_freqs_and_nsites(w, neg_dists, neg_offsets, neg_seqs)

    # get estimated counts of sequences at each Hamming distance for each possible Hamming-1 move.
    print >> sys.stderr, "Estimating Hamming distance counts in positive sequences..."
    pos_neighbor_counts = estimate_new_hamming_counts(word, alphabet, inv_alphabet, pos_counts, pos_freqs, pos_nsites)
    print >> sys.stderr, "Estimating Hamming distance counts in negative sequences..."
    neg_neighbor_counts = estimate_new_hamming_counts(word, alphabet, inv_alphabet, neg_counts, neg_freqs, neg_nsites)

    # compute the estimated enrichment p-value for each Hamming-1 move
    print >> sys.stderr, "Finding best neighbors..."
    for col in range(w):
        current = inv_alphabet[word[col]]
        for let, letter in enumerate(alphabet):
            if let == current: continue         # not a move
            new_word = word[:col] + letter + word[col+1:]
            p_counts = pos_neighbor_counts[col][let]
            n_counts = neg_neighbor_counts[col][let]
            (dist, log_pvalue, p, n) = get_best_hamming_enrichment_from_counts(w, P, N, p_counts, n_counts)
            neighbors[new_word] = [p, P, n, N, log_pvalue, dist]

    return pvalue_record, neighbors


def get_freqs_and_nsites(w, dists, offsets, seqs):
    """
    Build one frequency matrix per exact Hamming distance.

    For each distance d in [0..w] the sites whose distance is exactly d
    are aligned and a DNA PWM is set from that alignment.
    Returns (freqs, nsites): freqs[d] is the PWM's frequency matrix and
    nsites[d] the number of aligned sites for distance d.
    """

    freqs, nsites = [], []
    for hamming_dist in range(w+1):
        # sites at exactly this distance
        sites = get_aln_from_dists_and_offsets(w, hamming_dist, hamming_dist, dists, offsets, seqs)
        # FIXME: the alphabet is hard-coded to DNA
        matrix = sequence.PWM(sequence.getAlphabet('DNA'))
        matrix.setFromAlignment(sites)
        freqs.append(matrix.getFreq())
        nsites.append(len(sites))

    return freqs, nsites


def estimate_new_hamming_counts(word, alphabet, inv_alphabet, counts, freqs, nsites):
    """
    Estimate the per-distance sequence counts of every Hamming-1 neighbor.

    Returns neighbor_counts, where neighbor_counts[col][let] is the counts
    list estimated for the word obtained by replacing the letter in column
    col with alphabet letter number let.  inv_alphabet maps a letter to
    its index in alphabet and must contain 'N'.
    """

    w = len(word)
    letter_indices = range(len(alphabet))
    neighbor_counts = []                            # [col][let] = counts
    for col in range(w):                            # column to change
        old_let = inv_alphabet[word[col]]
        per_letter = [
                estimate_new_hamming_counts_col_let(counts, freqs, nsites, w, col, old_let, let, inv_alphabet['N'])
                for let in letter_indices]          # new letter index
        neighbor_counts.append(per_letter)

    return neighbor_counts


def estimate_new_hamming_counts_col_let(counts, freqs, nsites, max_dist, col, old_let, new_let, match_any):
    """
    Return counts[dist] list estimated if old_let were replaced by new_let
    in column col.

    freqs[d][col] is the frequency column of the alignment of nsites[d] sites with
    Hamming distance d (PWM of sites at Hamming distance d).
    match_any is the index of the 'N' character.

    For each distance d the number of sites that would get one step closer
    to the word and the number that would get one step farther are
    estimated from the letter frequencies in the column, rounded to whole
    sites, and moved to bins d-1 and d+1.  Exactly the amount added to a
    neighboring bin is removed from bin d, so the result holds whole
    numbers and has the same total as counts.  (Previously the unrounded
    amounts were subtracted while rounded ones were added, which made the
    counts fractional and let the total drift.)

    counts is not modified; a new list is returned.

    FIXME: could add columns on either side of freqs array so we can do
    "lengthen" and "shift" moves.
    """

    # work on a copy so the caller's counts stay intact
    new_counts = copy.copy(counts)

    # done if same letter
    if old_let == new_let: return new_counts

    for d in range(max_dist+1):
        n = nsites[d]
        # no sites at this distance?  (len() rather than truth value, since
        # the matrix type comes from the PWM class)
        if len(freqs[d]) == 0: continue
        f = freqs[d][col]

        closer = farther = 0
        if old_let != match_any and new_let == match_any:
            # new N: sites that mismatched old_let now match
            closer = n * (1-f[old_let])
        elif old_let == match_any and new_let != match_any:
            # lost an N: sites without new_let no longer match
            farther = n * (1-f[new_let])
        else:
            # no N involved
            farther = n * f[old_let]            # matched before, mismatch now
            closer = n * f[new_let]             # mismatched before, match now

        # whole sites only, and only where a destination bin exists
        moved_closer = 0
        if d > 0:
            moved_closer = int(round(closer))
        moved_farther = 0
        if d < max_dist:
            moved_farther = int(round(farther))

        # take the sites out of this bin and put them into its neighbors
        new_counts[d] -= moved_closer + moved_farther
        if moved_closer:
            new_counts[d-1] += moved_closer
        if moved_farther:
            new_counts[d+1] += moved_farther

    return new_counts

def get_aln_from_dists_and_offsets(w, d1, d2, dists, offsets, seqs):
    """
    Collect the recorded site of every sequence whose distance lies in
    the closed range [d1, d2].

    dists[i] and offsets[i] describe the site in seqs[i]; offsets are
    1-relative, 0 means the sequence has no site, and a negative offset
    marks a site on the reverse strand, which is returned reverse
    complemented.  Each site is w characters long.
    Returns the list of site strings.
    """

    aln = []
    for i, seq in enumerate(seqs):
        # skip sequences outside the distance range
        if not (d1 <= dists[i] <= d2):
            continue
        offset = offsets[i]
        if offset == 0:
            continue                    # no site in this sequence
        start = abs(offset) - 1         # back to a 0-relative position
        site = seq[start : start+w]
        if offset < 0:
            site = get_rc(site)
        aln.append(site)
    return aln

def erase_word_distance(word, dist, seqs):
    """
    Mask the best match to word in each sequence by overwriting it with 'N's.

    Only the single site reported per sequence by get_min_hamming_dists
    (reverse complement included) is considered, and it is erased only
    when its Hamming distance equals dist exactly.
    NOTE(review): earlier documentation spoke of erasing all
    non-overlapping sites "within" the given distance; the code has only
    ever handled one site per sequence at exactly that distance -- confirm
    which is intended.

    seqs is modified in place; nothing is returned.
    """

    (counts, dists, offsets) = get_min_hamming_dists(seqs, word)

    w = len(word)
    mask = 'N' * w
    for i in range(len(seqs)):
        # only erase match at the requested distance
        if dists[i] != dist:
            continue
        offset = offsets[i]
        if offset == 0:
            continue                    # no site in this sequence
        start = abs(offset) - 1         # offsets are 1-relative, sign = strand
        seqs[i] = seqs[i][:start] + mask + seqs[i][start+w:]


def get_aln_from_word(word, d1, d2, seqs, given_only):
    """
    Scan seqs for word and align the best site of each sequence.
    Only sites with d1 <= distance <= d2 are included.
    If given_only is set, the reverse complement of word is not scanned.
    """
    use_rc = not given_only
    (site_counts, site_dists, site_offsets) = get_min_hamming_dists(seqs, word, use_rc)

    # the counts are not needed; the alignment is built from distances and offsets
    w = len(word)
    return get_aln_from_dists_and_offsets(w, d1, d2, site_dists, site_offsets, seqs)


def get_best_hamming_alignment(word, pos_seqs, neg_seqs, print_dists=False, given_only=False):
    """
    Determine the most enriched Hamming distance to the word and align
    the sites of the positive sequences that lie within it (at most one
    site per sequence, i.e. ZOOPS).

    Returns dist, log_pvalue, p, P, n, N, aln
    where P and N are the numbers of positive and negative sequences.
    """

    # most enriched distance, plus per-sequence distances and offsets
    enrichment = get_best_hamming_enrichment(word, pos_seqs, neg_seqs, print_dists, given_only)
    (best_dist, best_log_pvalue, pos_dists, pos_offsets, p, n) = enrichment

    # sites of the positive sequences at distance 0..best_dist
    aln = get_aln_from_dists_and_offsets(len(word), 0, best_dist, pos_dists, pos_offsets, pos_seqs)

    P = len(pos_seqs)
    N = len(neg_seqs)
    return best_dist, best_log_pvalue, p, P, n, N, aln


def print_meme_header():
    """Write the MEME-format file header for the DNA alphabet to stdout."""

    symbols = sequence.getAlphabet('DNA').getSymbols()
    header = "\nMEME version 4.5\n\nALPHABET= %s\n\nstrands: + -\n\nBackground letter frequencies (from" % "".join(symbols)
    print >> sys.stdout, header
    # FIXME: put in real background frequencies
    background = "A 0.25 C 0.25 G 0.25 T 0.25\n"
    print >> sys.stdout, background


def print_meme_motif(word, nsites, ev_string, aln):

    # make the PWM
    pwm = sequence.PWM(sequence.getAlphabet('DNA'))
    pwm.setFromAlignment(aln)

    # print PWM in MEME format
    alphabet = sequence.getAlphabet('DNA').getSymbols()
    alen = len(alphabet)
    w = len(word)
    print "\nMOTIF %s\nletter-probability matrix: alength= %d w= %d nsites= %d E= %s" % \
            (word, alen, w, nsites, ev_string)
    for row in pwm.pretty(): print row
    print ""

def sorted_re_pvalue_keys(re_pvalues):
    """ Return the keys of a p-value dictionary, sorted by increasing p-value.

    re_pvalues maps a key to a p-value record whose element [4] is the
    (log) p-value.  Keys with equal p-values are ordered by the keys
    themselves.  An empty or None dictionary gives an empty list.
    """
    if not re_pvalues: return []
    # key-based sort: same order as comparing p-values first and keys
    # second, without the Python-2-only cmp machinery
    return sorted(re_pvalues, key=lambda k: (re_pvalues[k][4], k))

def get_best_neighbor(word, min_log_pvalue, pos_seqs, neg_seqs):
    # returns the log pvalue of the word and the best Hamming distance-1 neighbor
    #FIXME

    # estimate the p-values of Hamming distance-1 neighbors
    (pvalue_record, neighbors) = get_enrichment_and_neighbors(word, "ACGTN", min_log_pvalue, pos_seqs, neg_seqs)

    # if the p-value of the word is too large, quit
    if pvalue_record[4] >= min_log_pvalue or not neighbors:
        return pvalue_record[4], ""

    # return the best neighbor (estimated)
    (best_pvalue, best_word) = min([ (neighbors[tmp][4],tmp) for tmp in neighbors] )
    print  "original", word, sprint_logx(pvalue_record[4], 1, _pv_format), "best neighbor (estimated)", best_word, sprint_logx(best_pvalue, 1, _pv_format)
    return pvalue_record[4], best_word

def refine_consensus(word, pos_seqs, neg_seqs):
    """
    Greedily improve word by single-letter changes.

    Starting from word, the current candidate is scored exactly; while
    its log p-value is better than the best seen so far, it becomes the
    best word and its best estimated neighbor becomes the next candidate.
    Returns (best_word, best_log_pvalue).
    """

    best_word, best_log_pvalue = word, 1        # 1 is above any score compared against it
    candidate = word
    while True:
        (log_pvalue, neighbor) = get_best_neighbor(candidate, best_log_pvalue, pos_seqs, neg_seqs)
        if log_pvalue >= best_log_pvalue:
            # didn't improve: keep the previous best
            return (best_word, best_log_pvalue)
        # improved: accept the candidate and try its best neighbor next
        best_word, best_log_pvalue = candidate, log_pvalue
        candidate = neighbor

def main():
    """
    Command line entry point.

    Parses -w/-p/-n/-r/-h from sys.argv, reads the positive and negative
    FASTA files, optionally refines the word, and prints the best ZOOPS
    alignment as a MEME-format motif on stdout.  Progress and summary
    messages go to stderr.  Prints the usage message and exits with
    status 1 on -h, on an unknown option, on an option missing its value,
    or when any of the required options -w, -p and -n is absent.
    """

    word = None                     # no word specified
    pos_seq_file_name = None        # no positive sequence file specified
    neg_seq_file_name = None        # no negative sequence file specified
    refine = False

    #
    # get command line arguments
    #
    usage = """USAGE:
    %s [options]

    -w <word>               word (required)
    -p <file_name>          positive sequences FASTA file name (required)
    -n <file_name>          negative sequences FASTA file name (required)
    -r                      refine consensus by branching search
                            (distance 1 steps; beam size = 1).
    -h                      print this usage message

    Compute the Hamming distance from <word> to each FASTA sequence
    in the positive and negative files.  Apply Fisher Exact test to
    each distance.
    <word> may contain ambiguous characters, but only 'N' currently
    works with -r though.

    """ % (sys.argv[0])

    def usage_exit():
        # every command line problem ends the same way
        print >> sys.stderr, usage
        sys.exit(1)

    # no arguments: print usage
    if len(sys.argv) == 1:
        usage_exit()

    # parse command line
    i = 1
    while i < len(sys.argv):
        arg = sys.argv[i]
        if arg in ("-w", "-p", "-n"):
            # these options take a value
            i += 1
            if i >= len(sys.argv):
                usage_exit()
            value = sys.argv[i]
            if arg == "-w":
                word = value
            elif arg == "-p":
                pos_seq_file_name = value
            else:
                neg_seq_file_name = value
        elif arg == "-r":
            refine = True
        else:
            # -h and anything unrecognized
            usage_exit()
        i += 1

    # check that required arguments given; -w used to go unchecked, so
    # leaving it out crashed on word.upper() below
    if word is None or pos_seq_file_name is None or neg_seq_file_name is None:
        usage_exit()

    # keep track of time
    start_time = time.time()

    # read sequences
    print >> sys.stderr, "Reading sequences..."
    pos_seqs = get_strings_from_seqs(sequence.readFASTA(pos_seq_file_name,'Extended DNA'))
    neg_seqs = get_strings_from_seqs(sequence.readFASTA(neg_seq_file_name,'Extended DNA'))

    word = word.upper()

    if refine:
        # only the refined word is needed; its p-value is recomputed below
        best_word = refine_consensus(word, pos_seqs, neg_seqs)[0]
    else:
        best_word = word

    print >> sys.stderr, "Computing Hamming alignment..."
    (dist, log_pvalue, p, P, n, N, aln) = get_best_hamming_alignment(best_word, pos_seqs, neg_seqs)
    pv_string = sprint_logx(log_pvalue, 1, _pv_format)
    nsites = len(aln)
    print >> sys.stderr, "[", p, P, n, N, dist, "]"
    print >> sys.stderr, "Best ZOOPs alignment has %d sites / %d at distance %d with p-value %s" % (nsites, P, dist, pv_string)
    print_meme_header()
    print_meme_motif(best_word, nsites, pv_string, aln)

    # print elapsed time
    end_time = time.time()
    elapsed = end_time - start_time
    print >> sys.stderr, "elapsed time: %.2f seconds" % elapsed
    print >> sys.stdout, "#elapsed time: %.2f seconds" % elapsed


# run the command line tool only when executed as a script
if __name__ == '__main__':
    main()
