#include "oriand.h"

// Capacity of the per-element arrays. Elements are stored at indices 1..n and
// indices 0 and n + 1 are used as zero sentinels, so n must not exceed MAXN - 2.
const int MAXN = 1000000 + 10;
// Mask with the low 63 bits set (2^63 - 1). It is built from an unsigned value
// because (1LL << 63) - 1LL overflows signed long long, which is undefined
// behaviour.
const long long int ANDD = static_cast<long long int>(~0ULL >> 1);
// Size of the per-bit counter array.
const int MAXK = 64;

// n: number of stored elements; k: number of low bit positions examined.
int n, k;
// Input values, 1-indexed (a[1..n]); a[0] and a[n + 1] are zeroed by
// find_or() / find_and() and act as sentinels.
long long int a[MAXN];
// orp[l]: right end written by find_or() for left end l (n + 1 when none exists).
int orp[MAXN];
// andp[r]: left end written by find_and() for right end r (0 when none exists).
int andp[MAXN];
// cnt[b]: number of elements in the current sliding window that have bit b set.
int cnt[MAXK];

// Fills orp[1..n] with a two-pointer sweep from left to right.
//
// For k >= 1, orp[l] is the smallest index r such that each of the low k bits
// is set in at least one element of a[l..r]. Every l for which no such r <= n
// exists gets orp[l] = n + 1.
//
// cnt[b] tracks how many elements of the current window a[l..r] have bit b
// set; the window is "complete" when the smallest of those counts is positive.
void find_or()
{
    for(int b = MAXK - 1; b >= 0; b--)
        cnt[b] = 0;

    // a[0] is an empty sentinel, so dropping a[l - 1] is a no-op for l == 1.
    a[0] = 0;

    // The window starts as the single element a[1].
    int hi = 1;
    for(int b = 0; b < k; b++)
    {
        if(a[1] & (1LL << b))
            cnt[b]++;
    }

    // First left end for which the window ran off the array (n + 1: none did).
    int first_unreachable = n + 1;

    for(int lo = 1; lo <= n; lo++)
    {
        // Drop the element that just left the window and recompute the
        // smallest per-bit count.
        int fewest = n + 1;
        for(int b = 0; b < k; b++)
        {
            if(a[lo - 1] & (1LL << b))
                cnt[b]--;

            if(cnt[b] < fewest)
                fewest = cnt[b];
        }

        // Grow the window to the right until every bit is covered or the
        // right end passes n.
        while(hi <= n && fewest <= 0)
        {
            hi++;

            fewest = n + 1;
            for(int b = 0; b < k; b++)
            {
                if(a[hi] & (1LL << b))
                    cnt[b]++;

                if(cnt[b] < fewest)
                    fewest = cnt[b];
            }
        }

        // Once the right end has left the array, no later left end can
        // succeed either; those are filled in below.
        if(hi == n + 1)
        {
            first_unreachable = lo;
            break;
        }

        orp[lo] = hi;
    }

    for(int i = n; i >= first_unreachable; i--)
        orp[i] = n + 1;
}

// Fills andp[1..n] with a two-pointer sweep from right to left.
//
// andp[r] is the largest index l <= r such that none of the low k bits is set
// in every element of a[l..r]. Every r for which no such l >= 1 exists gets
// andp[r] = 0.
//
// cnt[b] tracks how many elements of the current window a[l..r] have bit b
// set; a bit survives in all elements exactly when its count equals the
// window length.
void find_and()
{
    for(int b = MAXK - 1; b >= 0; b--)
        cnt[b] = 0;

    // a[n + 1] is an empty sentinel, so dropping a[r + 1] is a no-op for r == n.
    a[n + 1] = 0;

    // The window starts as the single element a[n].
    int lo = n;
    for(int b = 0; b < k; b++)
    {
        if(a[n] & (1LL << b))
            cnt[b]++;
    }

    // Last right end for which the window ran off the array (-1: none did).
    int last_unreachable = -1;

    for(int hi = n; hi >= 1; hi--)
    {
        // Drop the element that just left the window and recompute the
        // largest per-bit count.
        int most = 0;
        for(int b = 0; b < k; b++)
        {
            if(a[hi + 1] & (1LL << b))
                cnt[b]--;

            if(most < cnt[b])
                most = cnt[b];
        }

        // Grow the window to the left while some bit is still present in
        // every element (its count reaches the window length).
        while(lo >= 1 && most >= hi - lo + 1)
        {
            lo--;

            most = 0;
            for(int b = 0; b < k; b++)
            {
                if(a[lo] & (1LL << b))
                    cnt[b]++;

                if(most < cnt[b])
                    most = cnt[b];
            }
        }

        // Once the left end has left the array, no earlier right end can
        // succeed either; those are filled in below.
        if(lo == 0)
        {
            last_unreachable = hi;
            break;
        }

        andp[hi] = lo;
    }

    for(int i = last_unreachable; i >= 1; i--)
        andp[i] = 0;
}

// Runs find_or() and find_and(), then counts the index pairs (l, r) with
// l < r and orp[l] < andp[r].
//
// For each l the pointer r is advanced to the first index (never moving
// backwards) where orp[l] < andp[r]; every index from r to n is then counted,
// which relies on andp being non-decreasing as built by find_and().
long long int solve()
{
    find_or();
    find_and();

    long long int total = 0;
    int r = 2;

    for(int l = 1; l <= n; l++)
    {
        // Skip right ends whose window starts at or before orp[l].
        while(r <= n && orp[l] >= andp[r])
            r++;

        if(l >= r)
            continue;

        // When r has run past n this adds n - r + 1 == 0 at most.
        if(orp[l] >= andp[r])
            continue;

        total += static_cast<long long int>(n - r + 1);
    }

    return total;
}

// Entry point: stores N and K in the globals n and k, copies the first N
// values of A into a[1..N] (shifting to 1-based indexing), and returns the
// result of solve().
long long oriand(int N, int K, std::vector<long long> A)
{
    k = K;
    n = N;

    for(int idx = 0; idx < n; idx++)
        a[idx + 1] = A[idx];

    return solve();
}
