#!/usr/bin/env python

import sys

# Print usage and exit with status 1 unless both required input files were
# given. The .vtp names below match what the output loop actually writes.
if len(sys.argv) < 3:
    print("\nwritevtkhalos: Create a VTK PolyData file of halo particle positions.")
    print("Usage: writevtkhalos <position file> <bound halo file> [Nmin] [cull]")
    print("[Nmin] optional minimum particle count per group cutoff (inclusive)")
    print("[cull] remove subhalos")
    print("Output: <positionfile>.pvd")
    print("        <positionfile>.0...(N-1).hop.vtp")
    print("        N = number of halos")
    print("Exiting\n")
    sys.exit(1)

import array
import os
import struct

import vtk

def cull_subgroups(groups):
    """Remove, in place, every group contained in an earlier group of the list.

    Group ``i`` is dropped when its member set is a subset of the member set
    of some group ``j`` with ``j < i``. Only earlier entries can absorb later
    ones: a group that appears *before* its superset is kept, and of several
    identical groups only the first survives.

    NOTE(review): the j < i rule presumably relies on parent halos preceding
    their subhalos in the input (e.g. largest first) -- TODO confirm.

    :param groups: list of iterables of particle indices; modified in place.
    :returns: None
    """
    # Build each member set once instead of once per (i, j) pair.
    member_sets = [set(g) for g in groups]

    # Walk from the back so popping index i never shifts an index that is
    # still to be visited (all remaining ones are < i). Because entries < i
    # are never popped before they are compared, member_sets[j] stays aligned
    # with groups[j] for every j < i.
    for i in reversed(range(len(groups))):
        little = member_sets[i]
        if any(little <= member_sets[j] for j in range(i)):
            groups.pop(i)

# Positional arguments: position file, halo file, optional Nmin, optional
# trailing 'cull' flag.
fname = sys.argv[1]
fnamehalos = sys.argv[2]
np_min = 0
cull = False

# Work on a copy so the process argv is left untouched.
args = sys.argv[:]

# Strip the trailing flag so it is not mistaken for Nmin. pop() removes the
# entry just tested; list.remove() would delete the FIRST argument equal to
# 'cull', which need not be the trailing one.
if args[-1] == 'cull':
    cull = True
    args.pop()

if len(args) > 3:
    try:
        np_min = int(args[3])
    except ValueError:
        sys.exit("writevtkhalos: Nmin must be an integer, got %r" % args[3])

# Position file layout, as read here: an int32 (native byte order) particle
# count N, then N doubles of x, N doubles of y, N doubles of z.
x = array.array('d')
y = array.array('d')
z = array.array('d')
with open(fname, 'rb') as fin:
    np = struct.unpack('=i', fin.read(4))[0]
    # fromfile() is the supported spelling of the deprecated array.read();
    # it raises EOFError if the file holds fewer than np items.
    x.fromfile(fin, np)
    y.fromfile(fin, np)
    z.fromfile(fin, np)

# Halo file layout, as read here: an int32 group count, then for each group
# an int32 member count followed by that many C-int particle indices into
# the x/y/z arrays.
halos = []
with open(fnamehalos, 'rb') as fin:
    nh = struct.unpack('=i', fin.read(4))[0]
    for _ in range(nh):
        nph = struct.unpack('=i', fin.read(4))[0]
        if not nph:
            continue  # empty group: nothing to read, nothing to keep
        members = array.array('i')
        # Always consume the indices so the stream stays aligned for the next
        # group, even when this one is then dropped by the Nmin cutoff.
        members.fromfile(fin, nph)
        if nph >= np_min:  # cutoff is inclusive
            halos.append(members)

# Report how many groups survived the Nmin cutoff, then optionally drop
# subgroups.
strout = "Found %d groups" % len(halos)
if np_min:
    # The cutoff applied while reading is inclusive (nph >= np_min).
    strout += " with particle count of at least %d" % np_min

print(strout)

if cull:
    cull_subgroups(halos)
    print("%d groups left after culling" % len(halos))


# Write one .vtp PolyData file per halo, plus a .pvd collection listing each
# one as a part, so the whole set can be opened as a single dataset.
with open(fname + '.pvd', 'w') as pvtp:
    pvtp.write('<?xml version="1.0"?>\n')
    pvtp.write(' <VTKFile type="Collection" version="0.1" byte_order="BigEndian" compressor="vtkZLibDataCompressor">\n')
    pvtp.write('   <Collection>\n')

    for i, members in enumerate(halos):
        points = vtk.vtkPoints()
        scalars = vtk.vtkFloatArray()
        data = vtk.vtkPolyData()

        # Every point of halo i carries the scalar value i (its halo number).
        for pid in members:
            scalars.InsertNextValue(float(i))
            points.InsertNextPoint(x[pid], y[pid], z[pid])

        data.SetPoints(points)
        data.GetPointData().SetScalars(scalars)

        writer = vtk.vtkXMLPolyDataWriter()
        # VTK 6 replaced SetInput() with SetInputData(); support both APIs.
        if hasattr(writer, 'SetInputData'):
            writer.SetInputData(data)
        else:
            writer.SetInput(data)

        fname_i = fname + '.' + str(i) + '.hop.vtp'
        # Relative file names inside a .pvd are resolved against the .pvd's
        # own directory. The .vtp files are written next to it, so record
        # only the basename; the full path would be doubled whenever fname
        # contains a directory component.
        pvtp.write('     <DataSet part="' + str(i) + '" file="'
                   + os.path.basename(fname_i) + '"/>\n')

        writer.SetFileName(fname_i)
        writer.Write()

    pvtp.write('   </Collection>\n')
    pvtp.write('</VTKFile>\n')

sys.exit(0)
