import os
import glob
from warnings import warn

# absolute path of the thrust header tree; "../../thrust/" is resolved against
# the current working directory at the time this script runs
thrust_abspath = os.path.abspath("../../thrust/")

# try to import an environment first
# NOTE(review): Import and Environment are not defined in this file --
# presumably Import comes from SCons and Environment from build-env.py; confirm.
try:
  Import('env')
except Exception:
  # No usable 'env' was available, so fall back to building our own.
  # exec(<string>) is valid on both Python 2 and Python 3 (the statement form
  # `exec open(...)` is Python 2 only), and the with-block closes the file.
  with open("../../build/build-env.py") as build_env_file:
    exec(build_env_file.read())
  env = Environment()

# this function builds a trivial source file from a Thrust header
def trivial_source_from_header(source, target, env):
  """Write a source file that does nothing but #include the given headers.

  Used as the action of the CUFile and CPPFile builders.

  source -- iterable of headers; each entry is converted with str() and then
            expressed relative to the module-level thrust_abspath
  target -- sequence whose first entry (via str()) names the file to write;
            an existing file is overwritten
  env    -- object providing subst(); only '$CC' is looked up

  Returns None.
  """
  target_filename = str(target[0])

  # the with-block closes the file even if building an include line raises
  with open(target_filename, 'w') as fid:
    # make sure we don't trip over <windows.h> when compiling with cl.exe
    if env.subst('$CC') == 'cl':
      fid.write('#include <windows.h>\n')

    for src in source:
      src_relpath = os.path.relpath(str(src), thrust_abspath)

      # spell the include with forward slashes on every platform;
      # os.path.join would put backslashes into the directive on Windows
      include = 'thrust/' + src_relpath.replace(os.sep, '/')

      fid.write('#include <' + include + '>\n')


# Two builders share the same action and both consume '.h' headers; the only
# difference between them is the suffix of the file they generate.

# CUFile: Thrust header -> trivial .cu file
cu_from_header_builder = Builder(
  action = trivial_source_from_header,
  src_suffix = '.h',
  suffix = '.cu')

# CPPFile: Thrust header -> trivial .cpp file
cpp_from_header_builder = Builder(
  action = trivial_source_from_header,
  src_suffix = '.h',
  suffix = '.cpp')

# register both builders with the environment under their method names
env.Append(BUILDERS = {'CUFile'  : cu_from_header_builder,
                       'CPPFile' : cpp_from_header_builder})


# find all user-includable .h files in the thrust tree and generate trivial source files #including them
extensions = ['.h']
folders = ['',              # main folder
           'iterator/',
           'system/',
           'system/cpp',
           'system/cuda',
           'system/cuda/experimental',
           'system/omp',
           'system/tbb']

# everything that will be linked into the tester program
sources = []

# absolute path of every header found, in discovery order
header_fullpaths = []

for folder in folders:
  for ext in extensions:
    pattern = os.path.join(thrust_abspath, folder, "*" + ext)

    # glob returns matches in arbitrary order; sort so that the generated
    # files (and therefore rebuild decisions) are stable from run to run
    for fullpath in sorted(glob.glob(pattern)):
      header_fullpaths.append(fullpath)

      # flatten the header's path into a single file name: drop only the real
      # extension (a '.h' elsewhere in the path must be left alone), then
      # replace slashes with '_slash_'
      stem = os.path.splitext(fullpath)[0]
      sourcefilename = stem.replace('/', '_slash_').replace('\\', '_slash_')

      cu  = env.CUFile(sourcefilename + '.cu', fullpath)
      cpp = env.CPPFile(sourcefilename + '_cpp.cpp', fullpath)

      sources.append(cu)
      sources.append(cpp)

      # ensure that all files #include <thrust/detail/config.h>
      with open(fullpath) as fid:
        includes_config = '#include <thrust/detail/config.h>' in fid.read()

      if not includes_config:
        # derive the name from the real path: entries in `folders` do not all
        # end with '/', so gluing folder + file name would drop a separator
        header_relpath = os.path.relpath(fullpath, thrust_abspath).replace(os.sep, '/')
        warn('Header <thrust/' + header_relpath + '> does not include <thrust/detail/config.h>')

# generate source files which #include all headers
all_headers_cu  = env.CUFile('all_headers.cu', header_fullpaths)
# the .cpp flavour goes through CPPFile, like every per-header .cpp target
all_headers_cpp = env.CPPFile('all_headers_cpp.cpp', header_fullpaths)

sources.append(all_headers_cu)
sources.append(all_headers_cpp)

# and the file with main()
sources.append('main.cu')

# link every generated source plus main.cu into a single program
tester = env.Program('tester', sources)

