# fieldbackup.py © 2024 by Marc Rochkind is licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 
# License: https://creativecommons.org/licenses/by-nc-sa/4.0
# (Non-commercial personal use only with attribution; derivative works must contain exactly the same license.)

import boto3
import os
import sys
import threading
import time
import hashlib
from tkinter import *
from tkinter.ttk import *
from tkinter import ttk
from tkinter import filedialog
from tkinter import messagebox
from tkinter import scrolledtext
from tkinter.simpledialog import askstring
import win32api
import shutil
import datetime
from credentials import *

# credentials.py should contain something like this:
	# def secret_key():
	#     return <secret_key>

	# def access_key():
	#     return <access_key>

# Local folder that receives files copied from the card; backupS3() uploads
# everything found under it.
ingest_path = 'C:/field_ingest'

try:
	os.makedirs(ingest_path, exist_ok=True)
except OSError as e:
	# Keep going so the GUI still starts, but say why the folder is unusable
	# (e.g. permissions, or a plain file already sitting at that path).
	print('Error creating directories: ' + str(e))

# https://stackoverflow.com/questions/12186993/what-is-the-algorithm-to-compute-the-amazon-s3-etag-for-a-file-larger-than-5gb
def calculate_s3_etag(file_path, chunk_size=8 * 1024 * 1024, multipart_threshold=None):
    """Compute the ETag S3 should report for *file_path* after a boto3 upload.

    boto3 sends files smaller than its multipart threshold in a single PUT,
    whose ETag is the plain MD5 of the content. Files at or above the
    threshold go as a multipart upload in *chunk_size* parts, whose ETag is
    the MD5 of the concatenated binary part digests followed by
    ``-<part count>``. Both boto3 defaults are 8 MiB. A file of exactly one
    chunk is therefore a one-part multipart upload (``"...-1"``), not a plain
    MD5.

    Args:
        file_path: path of the local file to hash.
        chunk_size: multipart part size in bytes.
        multipart_threshold: size in bytes at which boto3 switches to
            multipart; defaults to *chunk_size*.

    Returns:
        The ETag wrapped in double quotes, e.g. ``'"<hex>"'`` or
        ``'"<hex>-3"'``.

    NOTE(review): for very large files boto3 may enlarge the part size to
    stay within S3's part-count limit; that adjustment is not modelled here.
    """
    # Based on:
    # https://stackoverflow.com/questions/12186993/what-is-the-algorithm-to-compute-the-amazon-s3-etag-for-a-file-larger-than-5gb
    if multipart_threshold is None:
        multipart_threshold = chunk_size
    single_part = os.path.getsize(file_path) < multipart_threshold

    whole = hashlib.md5()
    part_digests = []
    with open(file_path, 'rb') as fp:
        while True:
            data = fp.read(chunk_size)
            if not data:
                break
            if single_part:
                whole.update(data)
            else:
                part_digests.append(hashlib.md5(data).digest())

    # An empty read in the multipart branch can only mean an empty file, whose
    # ETag is the MD5 of no bytes, which is exactly what `whole` holds.
    if single_part or not part_digests:
        return '"{}"'.format(whole.hexdigest())

    combined = hashlib.md5(b''.join(part_digests))
    return '"{}-{}"'.format(combined.hexdigest(), len(part_digests))

class ProgressPercentage(object):
	"""boto3 upload ``Callback`` that prints running progress for one file.

	Each call receives the number of bytes transferred since the previous
	call and writes one line of the form
	``<seq><filename>  <bytes so far> / <file size>  (<percent>%)``.
	"""
	def __init__(self, filename, seq):
		# seq: prefix such as "3 of 10: " prepended to every progress line.
		self._seq = seq
		self._filename = filename
		self._size = float(os.path.getsize(filename))
		self._seen_so_far = 0
		# Guards the running total in case boto3 invokes the callback from
		# several transfer threads at once.
		self._lock = threading.Lock()

	def __call__(self, bytes_amount):
		# To simplify, assume this is hooked up to a single filename
		with self._lock:
			self._seen_so_far += bytes_amount
			if self._size:
				percentage = (self._seen_so_far / self._size) * 100
			else:
				# Empty file: report it as complete instead of dividing by zero.
				percentage = 100.0
			s = self._seq + "%s  %s / %s  (%.2f%%)" % (
				self._filename, self._seen_so_far, self._size,
				percentage)
			sys.stdout.write(s + '\n')
			sys.stdout.flush()

# Module-level S3 client used by upload_file() and check_etag().
try:
	s3 = boto3.client('s3', aws_access_key_id=access_key(), aws_secret_access_key=secret_key())
except Exception as e:
	# logit() cannot be used here: it is defined further down the file and
	# writes to the GUI log widget, which does not exist yet, so calling it
	# would raise NameError and abort startup. Report on the console instead;
	# `s3` stays unbound, so later S3 operations will fail.
	print('ERROR creating S3 client: ' + str(e))

def upload_file(path, seq):
	"""Upload *path* to the 'mjr-field' bucket, retrying until it verifies.

	The S3 key is the file's base name. A round succeeds only when all three
	of these hold:
	  * the upload call succeeds;
	  * the object's ETag can be read back;
	  * that ETag equals calculate_s3_etag(path).
	Every failed round increments the global total_restart, waits 30 seconds
	and starts over. The function therefore only ever returns True.
	"""
	global total_restart

	key = os.path.basename(path)
	while True:
		try:
			s3.upload_file(path, 'mjr-field', key, Callback=ProgressPercentage(path, seq))
		except Exception as err:
			print('ERROR from upload_file: ' + str(err))
		else:
			try:
				attrs = s3.get_object_attributes(Bucket='mjr-field', Key=key, ObjectAttributes=['ETag'])
			except Exception as err:
				print(seq + key + "ERROR: can't get ETag for uploaded file. " + str(err))
			else:
				# The local ETag is quoted, so S3's value is wrapped in quotes
				# before comparing (i.e. it is assumed to come back unquoted).
				if calculate_s3_etag(path) == '"' + attrs['ETag'] + '"':
					return True
				print(seq + key + "ERROR: ETags do not match")
		# Reached only after a failure above: count it, back off, retry.
		total_restart += 1
		status("Sleeping for 30 sec. after error")
		time.sleep(30)
		status("Retrying " + key)

def check_etag(source, base, seq):
	"""Report whether the key *base* already exists in the 'mjr-field' bucket.

	Returns False when S3 raises NoSuchKey. If the object exists, its ETag is
	compared with calculate_s3_etag(source). A difference is counted in the
	global total_mismatch and logged, but True is still returned because the
	name is present. Any other error increments total_restart and is retried
	after a 30-second pause.
	"""
	global total_mismatch, total_restart

	while True:
		try:
			attrs = s3.get_object_attributes(Bucket='mjr-field', Key=base, ObjectAttributes=['ETag'])
		except s3.exceptions.NoSuchKey:
			return False
		except Exception as err:
			print(seq + base + "ERROR: can't get ETag for uploaded file. " + str(err))
		else:
			if calculate_s3_etag(source) != '"' + attrs['ETag'] + '"':
				total_mismatch += 1
				logit("ETag mismatch with S3: " + base)
			return True
		# Transient failure: count it and try the lookup again after a pause.
		total_restart += 1
		status("Sleeping for 30 sec. after error")
		time.sleep(30)
		status("Retrying " + base)

def backupS3():
	"""Back up every file under ``ingest_path`` to S3 and log a summary.

	Files whose name already exists in the bucket are counted as already
	backed up; check_etag() counts any ETag mismatches among them. All other
	files are uploaded with upload_file(). The global total_mismatch and
	total_restart counters are reset here and reported at the end.

	NOTE(review): the S3 key is only the file's base name. Two files with the
	same name in different date folders therefore map to one object, and the
	second is reported as an ETag mismatch instead of being uploaded. Fixing
	that would change the bucket layout, so it is left as is.
	"""
	global total_restart, total_mismatch

	total_uploaded = 0
	total_mismatch = 0
	total_already = 0
	total_restart = 0

	# Walk the tree once, so the "n of total" progress always matches the
	# files actually processed.
	sources = [
		(folder, name)
		for folder, _subdirs, names in os.walk(ingest_path)
		for name in names
	]
	total = len(sources)

	for n, (folder, base) in enumerate(sources, start=1):
		seq = str(n) + ' of ' + str(total) + ': '
		source = os.path.join(folder, base)
		if check_etag(source, base, seq):  # name already exists in S3
			total_already += 1
			status(seq + base + ' already backed up')
		else:
			status(seq + base + ' backing up ...')
			upload_file(source, seq)
			total_uploaded += 1
			status(seq + base + ' backed up to S3')

	logit(f"Backup: Total Uploaded: {total_uploaded} of {total}")
	logit(f"Backup: Total Already Uploaded: {total_already} of {total}")
	logit(f"Backup: Total Already Uploaded With ETag Mismatch: {total_mismatch} of {total}")
	logit(f"Backup: Total Restarts: {total_restart}")
	if total_uploaded + total_already == total and total_mismatch == 0:
		logit("Backup: *** All good ***")
	else:
		logit("Backup: ERRORS")

def logit(s):
	"""Print *s* to the console and append it as a new line in the GUI log.

	*s* is converted with ``str()`` first, so non-string values such as an
	exception object can be logged without a TypeError.
	"""
	text = str(s)
	print(text)
	pagetext.insert(END, text + '\n')
	pagetext.see(END)  # scroll so the newest line is visible
	# Process pending Tk events now, so the line appears while the caller's
	# long-running loop is still in progress.
	pagetext.update()

def status(s):
	"""Show *s* in the one-line status label and repaint it immediately."""
	status_label.configure(text=s)
	# Flush pending Tk events so the text shows during long-running work.
	status_label.update()

def find_card():
	"""Look for a single memory card and enable "Ingest & Backup" if found.

	A candidate is any drive whose root contains a ``DCIM`` folder. The card
	folder is the single entry inside that ``DCIM``. The ingest button is
	enabled, and the global ``card_path`` set, only when exactly one drive
	has a ``DCIM`` folder and that folder has exactly one entry.
	"""
	global card_path

	ingest_button.state(["disabled"])
	pagetext.delete("1.0", END)
	# The drive string is NUL-separated; [:-1] drops the empty piece after
	# the trailing NUL.
	drives = win32api.GetLogicalDriveStrings().split('\000')[:-1]

	dcim_count = 0
	found = None
	for d in drives:
		path = d + 'DCIM'
		if os.path.exists(path):
			dcim_count += 1
			entries = os.listdir(path)
			if len(entries) == 1:
				found = path + '\\' + entries[0]
				logit(found)

	# A lone DCIM folder with zero or several entries gives ingest() no
	# unambiguous folder to read, so it does not count as a found card.
	if dcim_count == 1 and found is not None:
		card_path = found
		ingest_button.state(["!disabled"])
		logit("Card found")
	else:
		logit('Zero or multiple card paths found')

def computer_path(card_path_full):
    """Return where *card_path_full* should be copied inside ``ingest_path``.

    Files are grouped in a folder named after the card file's modification
    date (YYYY-MM-DD, local time), e.g. ``<ingest_path>/2024-05-17/IMG_0001.JPG``.
    The date folder is created if it does not exist yet.
    """
    mtime = os.path.getmtime(card_path_full)
    date = datetime.datetime.fromtimestamp(mtime).date().isoformat()
    date_dir = os.path.join(ingest_path, date)
    try:
        os.makedirs(date_dir, exist_ok=True)
    except OSError:
        # Best effort: if the folder can't be created, the caller's later
        # copy or existence check fails and is reported there.
        pass
    return os.path.join(date_dir, os.path.basename(card_path_full))

def ingest_file(card_path_full, n, total):
	"""Copy one card file into its dated ingest folder unless already present.

	Args:
		card_path_full: full path of the file on the card.
		n: 1-based position of this file, used in the status text.
		total: number of files on the card, used in the status text.

	Returns:
		True if the file was copied or was already there; False if the copy
		failed (the error is logged). Updates the global total_already or
		total_ingested counter.
	"""
	global total_already, total_ingested

	seq = str(n) + ' of ' + str(total) + ': '
	base = os.path.basename(card_path_full)
	target = computer_path(card_path_full)
	if os.path.exists(target):
		total_already += 1
		status(seq + base + ' already ingested')
		return True

	status(seq + base + ' ingesting')
	try:
		shutil.copyfile(card_path_full, target)
	except OSError:
		logit('ERROR copying file from card: ' + card_path_full)
		# target did not exist before this call, so anything there now is a
		# partial copy from the failed attempt. Remove it; otherwise the
		# exists() check above would report "already ingested" on every rerun.
		try:
			os.remove(target)
		except OSError:
			pass
		return False
	total_ingested += 1
	return True

def verify_file(card_path_full, n, total):
	"""Check that the ingested copy of a card file matches the original.

	The card file and its copy are compared by calculate_s3_etag().

	Args:
		card_path_full: full path of the file on the card.
		n: 1-based position of this file, used in the messages.
		total: number of files on the card, used in the messages.

	Returns:
		True (and increments the global total_verified) if the copy exists
		and matches; False otherwise, after logging why.

	NOTE(review): on a mismatch the message says "rerun ingestion", but
	ingest_file() skips targets that already exist. The bad copy must be
	deleted first for a rerun to replace it.
	"""
	global total_verified

	seq = str(n) + ' of ' + str(total) + ': '
	base = os.path.basename(card_path_full)
	target = computer_path(card_path_full)
	if not os.path.exists(target):
		logit(seq + base + ' not ingested; rerun ingestion')
		return False
	# Hash the card file only once we know there is a copy to compare it with.
	etag_card = calculate_s3_etag(card_path_full)
	etag_target = calculate_s3_etag(target)
	if etag_card != etag_target:
		logit(seq + base + ' ETag mismatch; rerun ingestion')
		return False
	total_verified += 1
	status(seq + base + ' verified')
	return True

def ingest():
	"""Copy every file from the card, verify all copies, then back up to S3.

	Copying and verification stop at the first failing file; in that case
	False is returned and no backup is attempted. On success a summary is
	logged and backupS3() runs.
	"""
	global total_already, total_ingested, total_verified

	# logit()/status() call update(), which also processes queued clicks.
	# Disable this button for the whole run so a second click can't start a
	# nested run that resets the counters midway.
	# NOTE(review): the Find Card and Backup buttons stay clickable.
	ingest_button.state(["disabled"])
	try:
		# Snapshot the card folder in case the global is rebound mid-run.
		card_dir = card_path
		paths = os.listdir(card_dir)
		total = len(paths)
		total_ingested = 0
		total_already = 0
		total_verified = 0
		logit(str(total) + ' files found')

		for n, name in enumerate(paths, start=1):
			if not ingest_file(os.path.join(card_dir, name), n, total):
				logit('Ingestion stopped')
				return False

		for n, name in enumerate(paths, start=1):
			if not verify_file(os.path.join(card_dir, name), n, total):
				logit('Ingestion stopped')
				return False

		status('')
		logit('All processed OK; now backing up to S3')
		logit(f"Ingestion: Total Ingested: {total_ingested} of {total}")
		logit(f"Ingestion: Total Already Ingested: {total_already} of {total}")
		logit(f"Ingestion: Total Verified: {total_verified} of {total}")
		if total_ingested + total_already == total and total_verified == total:
			logit("Ingestion: *** All good ***")
		else:
			logit("Ingestion: ERRORS")
		backupS3()
	finally:
		ingest_button.state(["!disabled"])

# GUI setup. It runs at import time, and the functions above use these
# widgets as module globals.
root = Tk()
root.title("Field Backup")

# Left column: action buttons. Right column: scrolling log plus a status line.
leftframe = ttk.Frame(root, width=36)
leftframe.grid(column=0, row=0, sticky="nsw")

rightframe = ttk.Frame(root)
rightframe.grid(column=1, row=0)

ttk.Button(leftframe, text="Find Card", command=find_card).grid(column=0, row=0)
# Starts disabled; find_card() enables it when a card is detected.
ingest_button = ttk.Button(leftframe, text="Ingest & Backup", command=ingest, state=DISABLED)
ingest_button.grid(column=0, row=1)
ttk.Button(leftframe, text="Backup", command=backupS3).grid(column=0, row=2)

pagetext = scrolledtext.ScrolledText(rightframe, undo=True, wrap=WORD)
pagetext.grid(column=0, row=0)
status_label = ttk.Label(rightframe, text='')
status_label.grid(column=0, row=1, sticky='w')

# Uniform padding around every widget in both columns.
for frame in (leftframe, rightframe):
    for child in frame.winfo_children():
        child.grid_configure(padx=5, pady=5)

status('Ready')
find_card()  # look for a card straight away at startup

root.mainloop()