import os
import glob
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta
import numpy as np
from typing import List, Tuple, Dict, Any

from resistics.common.checks import isMagnetic, isElectric, consistentChans
from resistics.common.print import blockPrint
from resistics.time.data import TimeData
from resistics.time.clean import removeZeros, removeZerosSingle, removeNansSingle

SPAM data has the following characteristics:

- SPAM raw data is single precision floats with unit Volts.
- Getting unscaled samples returns data with unit mV for both the electric and magnetic fields. This is because gain is removed in unscaled samples to ensure consistency when a single recording is made up of multiple data files, each with different gain settings
- The start time in XTR files is the time of the first sample in the data
- The end time in XTR files is the time of the last sample in the data

In situations where a SPAM dataset is recorded in multiple small files, it is required that the recording is continuous.

Attributes
----------
recChannels : Dict
Channels in each data file
dtype : np.float32
The data type
numDataFiles : int
The number of data files

Methods
-------
setParameters()
Set parameters specific to a data format
getUnscaledSamples(**kwargs)
Get raw, unscaled data
getDataFilesForSamples(startSample, endSample)
Get the data files that contribute to the requested samples so they can be read
getPhysicalSamples(**kwargs)
Get data in physical units
spamHeaders()
Get sections and section headers to be read in for SPAM data
chanDefaults()
Get default values for channel headers
Merge the headers from all the data files
printDataFileList()
Get data file information as a list of strings
printDataFileInfo()
Print data file information to terminal

Notes
-----
Getting unscaled samples for SPAM data removes the gain rather than returning exactly the values in the data files. In cases where there are multiple data files, it is not necessary that they have been recorded with the same gain. Therefore, to ensure consistency when looking at raw data, the gain is removed at the getUnscaledSamples stage rather than in getPhysicalSamples, where it would probably have been more appropriate. This means that getUnscaledSamples returns data where all channels are in mV.

The scalings to convert the raw data to mV are stored in the ts_lsb chan header and calculated out as the header files are being read.

.. todo::
"""

def setParameters(self) -> None:
    """Set the parameters specific to the SPAM data format

    Finds the .RAW data files in the data path and initialises the per-file
    bookkeeping, the size of a single sample in bytes and the data type.
    """
    # list of the data files in the folder
    self.dataF = glob.glob(os.path.join(self.dataPath, "*.RAW"))
    # byte offsets and recorded channels might be different for each file,
    # so both are dictionaries keyed by data file
    self.dataByteOffset: Dict = {}
    self.recChannels = {}
    # samples are single precision floats, i.e. 4 bytes each
    self.dataByteSize = 4
    self.dtype = np.float32
    # the number of data files found
    self.numDataFiles: int = len(self.dataF)

[docs]    def getUnscaledSamples(self, **kwargs) -> TimeData:
"""Get raw data from the data files, returned in mV

SPAM raw data is single precision float with unit Volts. Calling this applies the ts_lsb calculated when the headers are read. This is because when a recording consists of multiple data files, each channel of each data file might have a different scaling. The only way to make the data consistent is to apply the ts_lsb scaling.

Therefore, this method returns the data in mV for all channels.

Parameters
----------
chans : List[str], optional
List of channels to return if not all are required
startSample : int, optional
First sample to return
endSample : int, optional
Last sample to return (inclusive)

Returns
-------
TimeData
Time data object
"""
# initialise chans, startSample and endSample with the whole dataset
options = self.parseGetDataKeywords(kwargs)

# get the files to read and the samples to take from them, in the correct order
# NOTE(review): the opening of this call is missing in this listing (only its arguments and closing bracket remain) - presumably a call to getDataFilesForSamples; restore before use
options["startSample"], options["endSample"]
)
# the sample range is inclusive of both ends, hence the + 1
numSamples = options["endSample"] - options["startSample"] + 1
# set up the dictionary to hold the data, one preallocated array per requested channel
data = {}
for chan in options["chans"]:
data[chan] = np.zeros(shape=(numSamples), dtype=self.dtype)

# loop through the data files and get the data
# sampleCounter tracks the position in the output arrays of the next samples to write
sampleCounter = 0
# NOTE(review): the loop header over the data files and the definitions of dFile, sToRead, scalar and dSamples are not visible in this listing - TODO restore
# get samples - this is inclusive
# spam files always record 5 channels
# byte offset = offset of the data within the file + the samples to skip (all recorded channels, dataByteSize bytes each)
byteOff = (
self.dataByteOffset[dFile]
+ sToRead[0] * self.recChannels[dFile] * self.dataByteSize
)
dFilePath = os.path.join(self.dataPath, dFile)
# NOTE(review): the call these arguments belong to is missing in this listing - TODO restore
dFilePath,
dtype=self.dtype,
mode="r",
offset=byteOff,
)
# now need to unpack this
for chan in options["chans"]:
# check to make sure channel exists
self.checkChan(chan)
# get the channel index - the chanIndex should give the right order in the data file
# as it is the same order as in the header file
chanIndex = self.chanMap[chan]
# use the range sampleCounter -> sampleCounter +  dSamples, because this actually means sampleCounter + dSamples - 1 as python ranges are not inclusive of the end value
# scale by the lsb scalar here - note that these can be different for each file in the run
# NOTE(review): the left operand of the multiplication (the samples read for this channel) is missing in this listing
data[chan][sampleCounter : sampleCounter + dSamples] = (
* scalar[chan]
)
# increment sample counter
sampleCounter = sampleCounter + dSamples  # get ready for the next data read

# convert the sample range to times and return the data
startTime, stopTime = self.sample2time(
options["startSample"], options["endSample"]
)
# NOTE(review): the calls that wrapped the two messages below are missing in this listing
"Unscaled data {} to {} read in from measurement {}, samples {} to {}".format(
startTime,
stopTime,
self.dataPath,
options["startSample"],
options["endSample"],
)
)
"Data scaled to mV for all channels using scalings in header files"
)
return TimeData(
sampleFreq=self.getSampleFreq(),
startTime=startTime,
stopTime=stopTime,
data=data,
)

[docs]    def getDataFilesForSamples(
self, startSample: int, endSample: int
) -> Tuple[List[str], List[List[int]], List[float]]:
"""Get the data files that have to be read to cover the sample range

Parameters
----------
startSample : int
Starting sample of the sample range
endSample : int
Ending sample of the sample range

Returns
-------
Tuple[List[str], List[List[int]], List[float]]
Per the return annotation, presumably the data files to read, the samples to read from each of them and the scalings for each of them. TODO confirm - the return statement is not visible in this listing
"""
# have the datafiles saved in sample order beginning with the earliest first
# go through each datafile and find the range to be read
# NOTE(review): only the scalings list is initialised in this listing; the other output lists appear to be missing
scalings = []
for idx, dFile in enumerate(self.dataFileList):
# dataRanges holds the [first, last] sample of each data file
fileStartSamp = self.dataRanges[idx][0]
fileEndSamp = self.dataRanges[idx][1]
if fileStartSamp > endSample or fileEndSamp < startSample:
continue  # nothing to read from this file
# in this case, there is some overlap with the samples to read
# readFrom and readTo are relative to the start of the data file
readFrom = 0  # i.e. the first sample in the datafile
readTo = fileEndSamp - fileStartSamp  # this the last sample in the file
# NOTE(review): the bodies of the two conditions below (presumably adjusting readFrom and readTo) are missing in this listing
if fileStartSamp < startSample:
if fileEndSamp > endSample:
scalings.append(self.scalings[idx])

[docs]    def getPhysicalSamples(self, **kwargs):
"""Get data scaled to physical values

resistics uses field units, meaning physical samples will return the following:

- Electrical channels in mV/km
- Magnetic channels in mV
- To get magnetic fields in nT, calibration needs to be performed

Notes
-----
The method getUnscaledSamples multiplies the raw data by the ts_lsb converting it to mV. Because gain is removed when getting the unscaledSamples and all channel data is in mV, the only calculation that has to be done is to divide by the dipole lengths (east-west spacing and north-south spacing).

To get magnetic fields in nT, they have to be calibrated.

Parameters
----------
chans : List[str]
List of channels to return if not all are required
startSample : int
First sample to return
endSample : int
Last sample to return
remaverage : bool
Remove average from the data (True by default)
remzeros : bool
Remove zeroes from the data (False by default)
remnans: bool
Remove NaNs from the data (False by default)

Returns
-------
TimeData
Time data object
"""
# initialise chans, startSample and endSample with the whole dataset
options = self.parseGetDataKeywords(kwargs)
# get the data, already scaled to mV
timeData = self.getUnscaledSamples(
chans=options["chans"],
startSample=options["startSample"],
endSample=options["endSample"],
)
# Scalars are applied in getUnscaledSamples to convert to mV - this is for ease of calculation and because each data file in the run might have a separate scaling
# all that is left is to divide by the dipole length in km and remove the average
for chan in options["chans"]:
if chan == "Ex":
# multiply by 1000/self.getChanDx same as dividing by dist in km
timeData.data[chan] = 1000 * timeData.data[chan] / self.getChanDx(chan)
# NOTE(review): the call that wrapped this message is missing in this listing
"Dividing channel {} by electrode distance {} km to give mV/km".format(
chan, self.getChanDx(chan) / 1000.0
)
)
if chan == "Ey":
# multiply by 1000/self.getChanDy same as dividing by dist in km
timeData.data[chan] = 1000 * timeData.data[chan] / self.getChanDy(chan)
# NOTE(review): the call that wrapped this message is missing in this listing
"Dividing channel {} by electrode distance {} km to give mV/km".format(
chan, self.getChanDy(chan) / 1000.0
)
)

# if remove zeros - False by default
if options["remzeros"]:
timeData.data[chan] = removeZerosSingle(timeData.data[chan])
# if remove nans - False by default
if options["remnans"]:
timeData.data[chan] = removeNansSingle(timeData.data[chan])
# remove the average from the data - True by default
if options["remaverage"]:
timeData.data[chan] = timeData.data[chan] - np.average(
timeData.data[chan]
)

# NOTE(review): the call that wrapped this message is missing in this listing
"Remove zeros: {}, remove nans: {}, remove average: {}".format(
options["remzeros"], options["remnans"], options["remaverage"]
)
)
return timeData

[docs]    def spamHeaders(self) -> Tuple[List[str], Dict[str, str]]:
"""Get the sections in SPAM header files (XTR and XTRX)

Returns
-------
sections : List[str]
The sections in the header files
sectionHeaders : Dict
The header names of each section, keyed by section name. NOTE(review): only the TITLE entry and no return statement are visible in this listing - confirm against the full method
"""
sections = ["STATUS", "TITLE", "PROJECT", "FILE", "SITE", "CHANNAME", "DATA"]
# NOTE(review): sectionHeaders is assigned to without any visible initialisation - its creation, the other sections and the return statement appear to be missing here
sectionHeaders["TITLE"] = ["AUTHOR", "VERSION", "DATE", "COMMENT"]

def chanDefaults(self) -> Dict[str, Any]:
    """Get the default values for the channel headers

    A new dictionary is built on every call, so callers can modify the
    result without affecting the defaults of other channels.

    Returns
    -------
    Dict[str, Any]
        Dictionary of headers for channels and default values
    """
    return {
        # channel input information
        "gain_stage1": 1,
        "gain_stage2": 1,
        "hchopper": 0,  # this depends on sample frequency
        "echopper": 0,
        # channel output information
        "ats_data_file": "",
        "num_samples": 0,
        "sensor_type": "",
        "channel_type": "",
        "ts_lsb": 1,
        # the lsb/scaling is not applied. data is raw voltage which needs to be scaled
        # an lsb is constructed from the scaling in the XTR/XTRX file to take the data to mV
        "scaling_applied": False,  # check this
        "pos_x1": 0,
        "pos_x2": 0,
        "pos_y1": 0,
        "pos_y2": 0,
        "pos_z1": 0,
        "pos_z2": 0,
        "sensor_sernum": 0,
    }

For SPAM data, there may be more than one header file, as data can be split up into smaller files as it is recorded. In that case, the header information needs to be merged.

All sampling frequencies should be the same
"""
else:

# check to make sure no gaps, calculate out the sample ranges and list the data files for each sample

The raw data for SPAM is in single precision Volts. However, if there are multiple data files for a single recording, each one may have a different gain. Therefore, a scaling has to be calculated for each data file and channel. This scaling will convert all channels to mV.

For the most part, this method only reads recording information. However, it does additionally calculate out the lsb scaling and store it in the ts_lsb channel header. More information is provided in the notes.

Notes
-----
The raw data for SPAM is stored as single precision floats and records the raw voltage measurements of the sensors. However, if there are multiple data files for a single continuous recording, each one may have a different gain. Therefore, a scaling has to be calculated for each data file.

For electric channels, the scaling begins with the scaling provided in the header file in the DATA section. This incorporates any gain occurring in the device. This scaling is further amended by a conversion to mV and polarity reversal,

.. math::

scaling = 1000 * scaling , \\
scaling = -1000 * scaling , \\
ts_lsb = scaling ,

where the reason for the 1000 factor in line 2 is not clear, nor is the polarity reversal. However, this information was provided by people more familiar with the data format.

For magnetic channels, the scaling in the header file DATA section is ignored. This is because it includes a static gain correction, which would be duplicated at the calibration stage. Therefore, this is not included at this point.

.. math::

scaling = -1000 , \\
ts_lsb = scaling ,

This scaling converts the magnetic data from V to mV.

Parameters
----------
"""
# NOTE(review): fragment - the enclosing def line and the definitions of lines, headers and numChansInt are not visible in this listing
# collect the lines of the header file by section
sectionLines = {}
# let's get data
for line in lines:
line = line.strip()
line = line.replace("'", " ")
# continue if line is empty
if line == "":
continue
if "[" in line:
# a section title - drop the first and last characters (the surrounding brackets) to get the section name
sec = line[1:-1]
sectionLines[sec] = []
else:
# NOTE(review): sec is only bound once a section title has been seen - a data line before the first title would raise a NameError
sectionLines[sec].append(line)
# the base class is built around a set of headers based on ATS headers
# though this is a bit more work here, it saves lots of code repetition
# recording information (start_time, start_date, stop_time, stop_date, ats_data_file)
fileLine = sectionLines["FILE"][0]
fileSplit = fileLine.split()
timeLine = sectionLines["FILE"][2]
timeSplit = timeLine.split()
# these are the unix time stamps
# NOTE(review): whole and fractional parts are joined as strings - assumes the fractional field is zero padded, TODO confirm against the XTR format
# NOTE(review): datetime.utcfromtimestamp is deprecated since Python 3.12
startDate = float(timeSplit[1] + "." + timeSplit[2])
datetimeStart = datetime.utcfromtimestamp(startDate)
stopDate = float(timeSplit[3] + "." + timeSplit[4])
datetimeStop = datetime.utcfromtimestamp(stopDate)
# here calculate number of samples
deltaSeconds = (datetimeStop - datetimeStart).total_seconds()
# calculate number of samples - have to add one because the time given in SPAM recording is the actual time of the last sample
# NOTE(review): int() truncates, so floating point error in the product could lose a sample - consider rounding
numSamples = int(deltaSeconds * headers["sample_freq"]) + 1
# put these in headers for ease of future calculations in merge headers
# spam datasets only have the one data file for all channels
# data information (meas_channels, sample_freq)
chanLine = sectionLines["CHANNAME"][0]
# this gets reformatted to an int later
# deal with the channel headers
for iChan in range(0, numChansInt):
chanH = self.chanDefaults()
# set the sample frequency from the main headers
# line data - read through the data in the correct channel order
# the first line of the CHANNAME and DATA sections is skipped, hence iChan + 1
chanLine = sectionLines["CHANNAME"][iChan + 1]
chanSplit = chanLine.split()
dataLine = sectionLines["DATA"][iChan + 1]
dataSplit = dataLine.split()
# channel input information (gain_stage1, gain_stage2, hchopper, echopper)
chanH["gain_stage1"] = 1
chanH["gain_stage2"] = 1
# channel output information (sensor_type, channel_type, ts_lsb, pos_x1, pos_x2, pos_y1, pos_y2, pos_z1, pos_z2, sensor_sernum)
chanH["ats_data_file"] = fileSplit[1]
chanH["num_samples"] = numSamples

# channel information
# spams often use Bx, By - use H within the software as a whole
chanH["channel_type"] = consistentChans(chanSplit[2])
# the sensor number is a bit of a hack - want MFSXXe or something - add MFS in front of the sensor number - this is liable to break
# at the same time, set the chopper
# the calibration line of each channel is in a section named 200<channel number>003
calLine = sectionLines["200{}003".format(iChan + 1)][0]
calSplit = calLine.split()
if isMagnetic(chanH["channel_type"]):
chanH["sensor_sernum"] = calSplit[
2
]  # the last three digits is the serial number
# the last two characters of the second underscore separated part give the sensor type number
sensorType = calSplit[1].split("_")[1][-2:]
chanH["sensor_type"] = "MFS{:02d}".format(int(sensorType))
if "LF" in calSplit[1]:
chanH["hchopper"] = 1
else:
chanH["sensor_type"] = "ELC00"
if "LF" in calLine:
chanH["echopper"] = 1

# data is raw voltage of sensors
# both E and H fields need polarity reversal (from email with Reinhard)
scaling = float(dataSplit[-2])
if isElectric(chanH["channel_type"]):
# the factor of 1000 is not entirely clear
lsb = 1000.0 * scaling
# volts to millivolts and a minus to switch polarity giving data in mV
lsb = -1000.0 * lsb
else:
# volts to millivolts and a minus to switch polarity giving data in mV
# scaling in header file is ignored because it duplicates static gain correction in calibration
lsb = -1000.0
chanH["ts_lsb"] = lsb

# the distances - the value read from the DATA line is halved and used for both positions of the channel
if chanSplit[2] == "Ex":
chanH["pos_x1"] = float(dataSplit[4]) / 2
chanH["pos_x2"] = chanH["pos_x1"]
if chanSplit[2] == "Ey":
chanH["pos_y1"] = float(dataSplit[4]) / 2
chanH["pos_y2"] = chanH["pos_y1"]
if chanSplit[2] == "Ez":
chanH["pos_z1"] = float(dataSplit[4]) / 2
chanH["pos_z2"] = chanH["pos_z1"]

# append chanHeaders to the list
# check information from raw file headers

Parameters
----------
"""
raise NotImplementedError("Support for XTRX files has not yet been implemented")

Read the headers from the raw file and figure out the data byte offset.

Parameters
----------
rawFile : str
The .RAW data file

Notes
-----
Open with encoding ISO-8859-1 because it has a value for all bytes, unlike other encodings. In particular, want to find the number of samples and the size of the header. The extended header is ignored.
"""
# NOTE(review): fragment - the enclosing def line, the reading of the general header and the loop over event headers are not visible in this listing
# NOTE(review): the file is opened without a context manager, so it stays open if anything below raises before close()
dFile = open(os.path.join(self.dataPath, rawFile), "r", encoding="ISO-8859-1")

# read EVENT HEADER - there can be multiple of these, but normally only the one
# Multiple events are largely deprecated. Only a single event is used
fileSize = os.path.getsize(os.path.join(self.dataPath, rawFile))
# records are counted from 1, so record n starts (n - 1) record lengths into the file
seekPt = (record - 1) * generalHeader["recLength"]
if not seekPt > fileSize:
# seek from beginning of file
dFile.seek(seekPt, 0)
# read extra to make sure
# NOTE(review): the read that defines eventString is missing in this listing
eventSplit = eventString.split()
# the event header fields, taken in order from the whitespace separated values
eH = {}
eH["start"] = int(eventSplit[0])
eH["startms"] = int(eventSplit[1])
eH["stop"] = int(eventSplit[2])
eH["stopms"] = int(eventSplit[3])
eH["cvalue1"] = float(eventSplit[4])
eH["cvalue2"] = float(eventSplit[5])
eH["cvalue3"] = float(eventSplit[6])
eH["EHInfile"] = int(eventSplit[7])
eH["nextEH"] = int(eventSplit[8])
eH["previousEH"] = int(eventSplit[9])
eH["numData"] = int(eventSplit[10])
eH["startData"] = int(eventSplit[11])
eH["extended"] = int(eventSplit[12])
record = eH["nextEH"]  # set to go to next eH
else:
break  # otherwise break out of for loops
# close the data file
dFile.close()
# now compare number of samples with that calculated previously
# NOTE(review): the condition guarding these warnings is missing in this listing
# NOTE(review): dFile is the (closed) file object, so this prints its repr - rawFile was probably intended
self.printWarning("Data file: {}".format(dFile))
# NOTE(review): the arguments to format() are missing in this listing
self.printWarning(
"Number of samples in raw file header {} does not equal that calculated from data {}".format(
)
)
self.printWarning("Number of samples calculated from data will be used")
# set the byte offset for the file
# NOTE(review): the right hand side of this assignment is missing in this listing
self.dataByteOffset[rawFile] = (

Checks all the header files to see if there are any gaps and calculates the sample ranges for each file together with the total number of samples. Sets the start and end time of the recording and class variables datetimeStart and datetimeStop.

Parameters
----------
List of headers from each data file
List of chan headers from each data file
"""
# NOTE(review): fragment - the enclosing def line and several loop headers and conditions are not visible in this listing
# take the first header as an example
# just fill in the data file list and data ranges
# single data file case: the one file covers samples 0 to num_samples - 1
self.dataRanges = [[0, self.headers["num_samples"] - 1]]
self.scalings = []
tmp = {}
self.scalings.append(tmp)
return  # then there was only one file - no need to do all the below

# make sure that all headers have the same sample rate
# and save the start and stop times and dates
startTimes = []
stopTimes = []
numSamples = []
# NOTE(review): the conditions guarding the two errors below are missing in this listing
self.printError(
"Not all datasets in {} have the same sample frequency.\nExiting...".format(
self.dataPath
),
quitRun=True,
)
self.printError(
"Not all datasets in {} have the same number of channels.\nExiting...".format(
self.dataPath
),
quitRun=True,
)
# now store startTimes, stopTimes and numSamples
# do this as datetimes, will be easier
datetimeStart = datetime.strptime(startString, "%Y-%m-%d %H:%M:%S.%f")
datetimeStop = datetime.strptime(stopString, "%Y-%m-%d %H:%M:%S.%f")
startTimes.append(datetimeStart)
stopTimes.append(datetimeStop)
# check the start and end times
# sort by start times - sortIndices holds the file indices in order of start time
sortIndices = sorted(list(range(len(startTimes))), key=lambda k: startTimes[k])
# check stays True as long as each file starts one sample period after the previous one stops
check = True
# get the stop time of the previous dataset
stopTimePrev = stopTimes[sortIndices[i - 1]]
startTimeNow = startTimes[sortIndices[i]]
# NOTE(review): exact datetime equality - any rounding in the header times or in sampleTime would be reported as a gap
if startTimeNow != stopTimePrev + sampleTime:
self.printWarning(
"There is a gap between the datafiles in {}".format(self.dataPath)
)
self.printWarning(
"Please separate out datasets with gaps into separate folders"
)
# print out where the gap was found
self.printWarning("Gap found between datafiles:")
# NOTE(review): the messages of the two warnings below are missing in this listing
self.printWarning(
)
self.printWarning(
)
# set check as false
check = False
# if did not pass check, then exit
if not check:
self.printError(
"Gaps in data. All data for a single recording must be continuous. Exiting...",
quitRun=True,
)

# total number of samples over all the data files
totalSamples = sum(numSamples)

# get a list of all the datafiles, scalings and the sample ranges
self.dataFileList = []
self.dataRanges = []
self.scalings = []
# sample holds the last sample index assigned so far, -1 before the first file
sample = -1
# now need some sort of lookup table to say where the sample ranges are
iSort = sortIndices[i]  # get the sorted index
startSample = sample + 1
endSample = (
startSample + numSamples[iSort] - 1
)  # -1 because this is inclusive of the start sample
self.dataRanges.append([startSample, endSample])
# increment sample
sample = endSample
# save the scalings for each chan
tmp = {}
self.scalings.append(tmp)

# now set the LSB information for the chanHeaders
# i.e. if they change, this should reflect that
lsbSet = set()
for scalar in self.scalings:
# NOTE(review): the body of the if branch below is missing in this listing
if len(lsbSet) == 1:
else:
self.printWarning(
"Multiple different LSB values found for chan {}: {}".format(
chan, list(lsbSet)
)
)
self.printWarning(
"This is handled, but the header information given will show only a single LSB value"
)

# the recording starts at the earliest start time and stops at the latest stop time
datetimeStart = min(startTimes)
datetimeStop = max(stopTimes)
# set datafiles = the whole list of datafiles
# NOTE(review): the statement this format string belongs to is missing in this listing
"%H:%M:%S.%f"
)

def printDataFileList(self) -> List[str]:
    """Information about the data files as a list of strings

    The first entry is a column header, followed by one entry per data file
    giving its first and last sample, and finally the total number of samples.

    Returns
    -------
    List[str]
        List of information about the data files
    """
    textLst: List[str] = ["Data File\t\tSample Ranges"]
    # dataFileList and dataRanges are parallel lists, one entry per data file
    textLst.extend(
        "{}\t\t{} - {}".format(dFile, sRanges[0], sRanges[1])
        for dFile, sRanges in zip(self.dataFileList, self.dataRanges)
    )
    textLst.append("Total samples = {}".format(self.getNumSamples()))
    return textLst

def printDataFileInfo(self) -> None:
    """Print the data file information to the terminal

    The block is titled with the name of the reader class and lists the
    lines returned by printDataFileList.
    """
    title = "{} Data File List".format(self.__class__.__name__)
    blockPrint(title, self.printDataFileList())