#!/usr/bin/env python
# coding: utf-8

# In[24]:


import pymc3 as pm 
import matplotlib.pyplot as plt
import numpy as np 
import theano
import theano.tensor as tt
from scipy import stats
import pymc3 as pm 

# Synthetic data set: n events with ascending timestamps in [0, 1),
# each labelled with one of k categories chosen uniformly at random.
n = 100
k = 10
time_list = np.sort(np.random.random(n))
T = time_list[-1]  # latest event time (array is sorted ascending)
cat_list = np.random.choice(k, n)
data = [cat_list, time_list]


# In[25]:


def likelihood(A, W, nu, lambda0, X):
    """Build the log-probability function for a Hawkes-style event model.

    Parameters
    ----------
    A : (k, k) symbolic matrix
        Binary adjacency between categories (does an edge exist).
    W : (k, k) symbolic matrix
        Edge weights; the effective influence matrix is ``W * A``.
    nu : symbolic scalar
        Temporal decay rate of the parent-to-child influence kernel.
    lambda0 : (k,) symbolic vector
        Background (no-parent) rate per category.
    X : list of symbolic integers
        ``X[i]`` is the index of event i's parent event; a value of 0
        is treated as "no parent" (background event).

    Returns
    -------
    callable
        logp function suitable for ``pm.DensityDist(..., observed=data)``.
    """

    def logp_(data1):
        # NOTE(review): assumes the observed container exposes the raw
        # (cat_list, time_list) pair via ``.value`` — confirm against the
        # PyMC3 version in use.
        cat_list, time_list = data1.value
        N = time_list.size  # number of observations

        # Wrap the observed arrays as shared variables so they can be
        # indexed by the *symbolic* parent indices below.
        cat_list = theano.shared(np.array([int(cat) for cat in cat_list]))
        time_list = theano.shared(time_list)

        WA = W * A  # effective category-to-category influence

        eps = 0.000000001   # floor to keep log() finite
        beta = 0.000000001  # weak regularization term on nu
        ll = 0
        for index in range(1, N):
            c_time = time_list[index]
            c_cat = cat_list[index]
            parent_idx = X[index]

            # Probability if this event was triggered by its parent:
            # exponential time-decay kernel times the edge weight.
            p_time = time_list[parent_idx]
            p_cat = cat_list[parent_idx]
            triggered_prob = tt.exp(-tt.abs_(c_time - p_time) * nu) * WA[c_cat][p_cat]

            # BUG FIX: ``parent_idx`` is a symbolic tensor, so the original
            # Python ``if tt.lt(0, parent_idx):`` could not branch on its
            # runtime value (a Theano tensor has no usable truth value at
            # graph-build time).  Select symbolically instead:
            # parent_idx > 0 means "has a parent", otherwise fall back to
            # the category's background rate.
            current_occurance_prob = tt.switch(
                tt.gt(parent_idx, 0), triggered_prob, lambda0[c_cat]
            )

            # Clamp into (eps, 1] before taking the log.
            ll += tt.log(tt.minimum(tt.maximum(current_occurance_prob, eps), 1))

        return beta / nu + ll

    return logp_


# In[26]:



with pm.Model() as model:

    # One categorical parent indicator per event, uniform over all
    # n possible parent positions (0 acts as "no parent" in the
    # likelihood).
    X = [pm.Categorical('X_{}'.format(i), p=np.ones(n) / n) for i in range(n)]

    # Latent network between the k categories: edge existence A,
    # edge weights W, temporal decay nu, background rates lambda0.
    A = pm.Bernoulli('A', p=0.5, shape=(k, k))
    W = pm.Normal('W', mu=0, sd=1, shape=(k, k))
    nu = pm.Gamma('nu', alpha=2, beta=1)
    lambda0 = pm.Gamma('lambda0', alpha=2, beta=1, shape=k)

    # Custom likelihood over the observed (categories, times) pair.
    like = pm.DensityDist('like', likelihood(A, W, nu, lambda0, X), observed=data)

    trace = pm.sample()


# In[ ]:





# In[ ]:




