package com.cube.storage;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.nio.file.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Log-structured merge-tree (LSM) storage engine for the Cube database, written in pure Java.
 *
 * <p>Write path: every mutation is appended to the write-ahead log and then applied to the
 * active {@link MemTable}. Once the active memtable's {@code size()} reaches
 * {@link #MEMTABLE_FLUSH_THRESHOLD}, it is frozen into the immutable-memtable queue, and a
 * background task writes it out as a new {@link SSTable}. Deletes are writes of
 * {@code WriteAheadLog.TOMBSTONE}.
 *
 * <p>Read path: the newest version of a key wins. Lookups consult the active memtable, then the
 * frozen memtables from newest to oldest, then SSTables from newest to oldest. If the first
 * version found is a tombstone, the key is treated as deleted.
 *
 * <p>Locking:
 * <ul>
 *   <li>{@code memtableLock} is held in read mode while writing to or reading the active memtable,
 *       and in write mode while swapping it.</li>
 *   <li>{@code sstableLock} guards the SSTable list.</li>
 *   <li>{@code flushLock} serializes memtable flushes, so SSTables are appended in the order their
 *       memtables were frozen.</li>
 * </ul>
 *
 * <p>NOTE(review): concurrent {@link #put} calls share the read lock, so {@code wal.append} and
 * {@code MemTable.put} may run concurrently. This assumes both are thread-safe — confirm.
 */
public class LSMStorageEngine implements StorageEngine {

    private static final Logger logger = LoggerFactory.getLogger(LSMStorageEngine.class);

    /** Active-memtable {@code size()} at which the memtable is frozen and flushed. */
    private static final int MEMTABLE_FLUSH_THRESHOLD = 1024 * 1024; // 1MB

    /** How long {@link #close()} waits for each background executor to drain. */
    private static final long EXECUTOR_SHUTDOWN_TIMEOUT_SECONDS = 30;

    private final Path dataDirectory;
    private final Path walPath;

    private volatile MemTable activeMemtable;
    /** Frozen memtables awaiting flush; the head is the oldest. */
    private final Queue<MemTable> immutableMemtables;
    private final ReadWriteLock memtableLock;

    /** SSTables ordered oldest first; the last element holds the newest data. */
    private final List<SSTable> sstables;
    private final ReadWriteLock sstableLock;

    /** Serializes memtable flushes so SSTables keep memtable order and names never collide. */
    private final Lock flushLock = new ReentrantLock();

    /** Last timestamp used in an SSTable file name; kept strictly increasing. */
    private final AtomicLong lastSSTableTimestamp = new AtomicLong();

    private final ExecutorService flushExecutor;
    private final ExecutorService compactionExecutor;

    private final WriteAheadLog wal;

    /**
     * Opens (or creates) a database rooted at {@code dataDir}. Replays the WAL into the active
     * memtable and loads existing SSTables.
     *
     * @param dataDir directory holding SSTables and the {@code wal/} subdirectory
     * @throws IOException if directories cannot be created, or recovery or loading fails
     */
    public LSMStorageEngine(String dataDir) throws IOException {
        this.dataDirectory = Paths.get(dataDir);
        this.walPath = dataDirectory.resolve("wal");

        Files.createDirectories(dataDirectory);
        Files.createDirectories(walPath);

        this.activeMemtable = new MemTable();
        this.immutableMemtables = new ConcurrentLinkedQueue<>();
        this.memtableLock = new ReentrantReadWriteLock();

        this.sstables = new CopyOnWriteArrayList<>();
        this.sstableLock = new ReentrantReadWriteLock();

        this.wal = new WriteAheadLog(walPath.resolve("wal.log"));

        this.flushExecutor = Executors.newSingleThreadExecutor(daemonThreadFactory("CubeDB-Flush"));
        this.compactionExecutor =
            Executors.newSingleThreadExecutor(daemonThreadFactory("CubeDB-Compaction"));

        try {
            recoverFromWAL();
            loadSSTables();
        } catch (IOException | RuntimeException e) {
            // Do not leak the open WAL or the executor threads when startup fails.
            flushExecutor.shutdownNow();
            compactionExecutor.shutdownNow();
            try {
                wal.close();
            } catch (Exception closeFailure) { // close()'s checked exceptions are not visible here
                e.addSuppressed(closeFailure);
            }
            throw e;
        }

        logger.info("Cube database initialized at {}", dataDirectory);
    }

    /** Creates a thread factory producing daemon threads with the given name. */
    private static ThreadFactory daemonThreadFactory(String name) {
        return runnable -> {
            Thread thread = new Thread(runnable, name);
            thread.setDaemon(true);
            return thread;
        };
    }

    /**
     * Durably logs and stores {@code value} under {@code key}. Rotates the memtable if it is full.
     *
     * @throws IllegalArgumentException if {@code key} or {@code value} is null
     * @throws IOException if the WAL append fails
     */
    @Override
    public void put(String key, byte[] value) throws IOException {
        if (key == null || value == null) {
            throw new IllegalArgumentException("Key and value cannot be null");
        }

        boolean rotationNeeded;
        memtableLock.readLock().lock();
        try {
            wal.append(new WriteAheadLog.LogEntry(
                WriteAheadLog.OperationType.PUT, key, value));

            activeMemtable.put(key, value);
            rotationNeeded = activeMemtable.size() >= MEMTABLE_FLUSH_THRESHOLD;
        } finally {
            memtableLock.readLock().unlock();
        }

        // ReentrantReadWriteLock cannot upgrade read -> write (the attempt blocks forever),
        // so rotation is only attempted after the read lock has been released.
        if (rotationNeeded) {
            rotateMemtableIfFull();
        }
    }

    /**
     * Returns the newest value stored for {@code key}.
     *
     * @return the value, or {@code null} if the key is absent or its newest version is a tombstone
     * @throws IllegalArgumentException if {@code key} is null
     */
    @Override
    public byte[] get(String key) throws IOException {
        if (key == null) {
            throw new IllegalArgumentException("Key cannot be null");
        }

        byte[] newest = findNewestValue(key);
        return isTombstone(newest) ? null : newest;
    }

    /**
     * Finds the newest stored version of {@code key} (possibly a tombstone), or {@code null}.
     *
     * <p>Sources are checked newest to oldest. A flush registers the SSTable before it dequeues
     * the memtable, so data being flushed is always visible in at least one of them.
     */
    private byte[] findNewestValue(String key) throws IOException {
        memtableLock.readLock().lock();
        try {
            byte[] value = activeMemtable.get(key);
            if (value != null) {
                return value;
            }
        } finally {
            memtableLock.readLock().unlock();
        }

        // The queue's head is the oldest, so walk a snapshot of it backwards.
        List<MemTable> frozen = new ArrayList<>(immutableMemtables);
        for (int i = frozen.size() - 1; i >= 0; i--) {
            byte[] value = frozen.get(i).get(key);
            if (value != null) {
                return value;
            }
        }

        sstableLock.readLock().lock();
        try {
            for (int i = sstables.size() - 1; i >= 0; i--) {
                byte[] value = sstables.get(i).get(key);
                if (value != null) {
                    return value;
                }
            }
        } finally {
            sstableLock.readLock().unlock();
        }

        return null;
    }

    /** Returns whether {@code value} is the deletion marker. */
    private static boolean isTombstone(byte[] value) {
        return value != null && Arrays.equals(value, WriteAheadLog.TOMBSTONE);
    }

    /** Deletes {@code key} by writing a tombstone. Always returns {@code true}. */
    @Override
    public boolean delete(String key) throws IOException {
        put(key, WriteAheadLog.TOMBSTONE);
        return true;
    }

    /**
     * Returns the live (non-deleted) keys starting with {@code prefix}, in sorted order.
     *
     * <p>Derived from the resolved entries rather than a union of per-table key sets. A plain
     * union would return keys deleted in a newer table but still present in an older one.
     */
    @Override
    public Iterator<String> scan(String prefix) throws IOException {
        return collectLiveEntries(prefix).keySet().iterator();
    }

    /** Returns the live entries whose keys start with {@code prefix}, sorted by key. */
    @Override
    public Iterator<Map.Entry<String, byte[]>> scanEntries(String prefix) throws IOException {
        return collectLiveEntries(prefix).entrySet().iterator();
    }

    /**
     * Resolves the newest version of every key matching {@code prefix} and drops tombstones.
     *
     * <p>Reads go newest to oldest, keeping the first version seen. Going in this direction means
     * a concurrent rotation or flush can only make data appear twice (harmless), never vanish.
     */
    private NavigableMap<String, byte[]> collectLiveEntries(String prefix) throws IOException {
        NavigableMap<String, byte[]> entries = new TreeMap<>();

        memtableLock.readLock().lock();
        try {
            activeMemtable.scanEntries(prefix).forEach(entries::putIfAbsent);
        } finally {
            memtableLock.readLock().unlock();
        }

        List<MemTable> frozen = new ArrayList<>(immutableMemtables);
        for (int i = frozen.size() - 1; i >= 0; i--) {
            frozen.get(i).scanEntries(prefix).forEach(entries::putIfAbsent);
        }

        sstableLock.readLock().lock();
        try {
            for (int i = sstables.size() - 1; i >= 0; i--) {
                sstables.get(i).scanEntries(prefix).forEach(entries::putIfAbsent);
            }
        } finally {
            sstableLock.readLock().unlock();
        }

        entries.values().removeIf(LSMStorageEngine::isTombstone);
        return entries;
    }

    /**
     * Freezes the active memtable (if non-empty) and synchronously writes every frozen memtable to
     * an SSTable.
     *
     * @throws IOException if writing an SSTable fails; the failed memtable stays queued
     */
    @Override
    public void flush() throws IOException {
        memtableLock.writeLock().lock();
        try {
            if (!activeMemtable.isEmpty()) {
                rotateMemtable();
            }
        } finally {
            memtableLock.writeLock().unlock();
        }

        flushAllImmutableMemtables();
    }

    /**
     * Merges all SSTables into one, dropping tombstones.
     *
     * @throws IOException if the merged SSTable cannot be read or written
     */
    @Override
    public void compact() throws IOException {
        performCompaction();
    }

    /** Returns key/size totals over the SSTables plus the combined size of all memtables. */
    @Override
    public StorageStats getStats() {
        long memtableSize = activeMemtable.size();
        for (MemTable mt : immutableMemtables) {
            memtableSize += mt.size();
        }

        sstableLock.readLock().lock();
        try {
            long totalSize = 0;
            long totalKeys = 0;

            for (SSTable sst : sstables) {
                totalSize += sst.getSize();
                totalKeys += sst.getKeyCount();
            }

            return new StorageStats(totalKeys, totalSize, memtableSize, sstables.size());
        } finally {
            sstableLock.readLock().unlock();
        }
    }

    /**
     * Flushes all memtables, stops background work, then closes the WAL and SSTables.
     *
     * <p>Resources are released even if the flush fails.
     */
    @Override
    public void close() throws IOException {
        logger.info("Closing Cube database...");

        try {
            flush();
        } finally {
            // Stop background work before releasing what it writes to, so no flush can append
            // an SSTable after the SSTables have been closed.
            shutdownExecutor(flushExecutor);
            shutdownExecutor(compactionExecutor);
            try {
                wal.close();
            } finally {
                closeSSTables();
            }
        }

        logger.info("Cube database closed");
    }

    /** Shuts {@code executor} down and waits a bounded time for queued tasks to finish. */
    private static void shutdownExecutor(ExecutorService executor) {
        executor.shutdown();
        try {
            if (!executor.awaitTermination(EXECUTOR_SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                logger.warn("Background tasks did not finish within {}s",
                    EXECUTOR_SHUTDOWN_TIMEOUT_SECONDS);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Closes every registered SSTable under the SSTable write lock. */
    private void closeSSTables() throws IOException {
        sstableLock.writeLock().lock();
        try {
            for (SSTable sstable : sstables) {
                sstable.close();
            }
        } finally {
            sstableLock.writeLock().unlock();
        }
    }

    /**
     * Rotates the active memtable only if it is still over the threshold.
     *
     * <p>The re-check matters because another writer may have rotated while this thread waited for
     * the write lock. Must not be called while holding the memtable read lock.
     */
    private void rotateMemtableIfFull() {
        memtableLock.writeLock().lock();
        try {
            if (activeMemtable.size() >= MEMTABLE_FLUSH_THRESHOLD) {
                rotateMemtable();
            }
        } finally {
            memtableLock.writeLock().unlock();
        }
    }

    /**
     * Freezes the active memtable, starts a new one, rotates the WAL and schedules a background
     * flush.
     *
     * <p>Callers must not hold only the memtable read lock; the write lock is taken here and is
     * reentrant.
     */
    private void rotateMemtable() {
        memtableLock.writeLock().lock();
        try {
            immutableMemtables.add(activeMemtable);
            activeMemtable = new MemTable();

            // NOTE(review): if rotate() discards the previous log segment, entries held only in
            // the frozen memtable are not durable until it is flushed — confirm WAL semantics.
            try {
                wal.rotate();
            } catch (IOException e) {
                logger.error("Failed to rotate WAL", e);
            }

            // execute() rather than submit(): unexpected runtime failures reach the thread's
            // uncaught-exception handler instead of vanishing into an unread Future.
            flushExecutor.execute(this::flushInBackground);

        } finally {
            memtableLock.writeLock().unlock();
        }
    }

    /** Background-task entry point: flushes the oldest frozen memtable and logs any failure. */
    private void flushInBackground() {
        try {
            flushOldestImmutableMemtable();
        } catch (IOException e) {
            // The memtable stays at the head of the queue; the next flush attempt retries it.
            logger.error("Failed to flush memtable", e);
        }
    }

    /**
     * Writes the oldest frozen memtable to a new SSTable, then removes it from the queue.
     *
     * @return {@code true} if a memtable was flushed, {@code false} if the queue was empty
     * @throws IOException if the SSTable cannot be written; the memtable then stays queued
     */
    private boolean flushOldestImmutableMemtable() throws IOException {
        flushLock.lock();
        try {
            MemTable memtable = immutableMemtables.peek();
            if (memtable == null) {
                return false;
            }

            String sstableFile = nextSSTableFileName("sstable-");
            Path sstablePath = dataDirectory.resolve(sstableFile);

            SSTable sstable = SSTable.create(sstablePath, memtable.getEntries());

            sstableLock.writeLock().lock();
            try {
                sstables.add(sstable);
            } finally {
                sstableLock.writeLock().unlock();
            }

            // Dequeue only once the SSTable is visible, so readers never miss this data. Only
            // flushers remove from the queue and they hold flushLock, so the head is still
            // `memtable`.
            immutableMemtables.poll();

            logger.info("Flushed memtable to {} ({} keys)", sstableFile, memtable.getKeyCount());
            return true;
        } finally {
            flushLock.unlock();
        }
    }

    /**
     * Flushes frozen memtables on the calling thread until the queue is empty.
     *
     * <p>Flushes are serialized, and a memtable leaves the queue only after its SSTable is
     * registered. An empty queue therefore means every frozen memtable is on disk, and no sleep
     * or executor wait is needed.
     *
     * @throws IOException on the first flush failure, instead of retrying forever
     */
    private void flushAllImmutableMemtables() throws IOException {
        boolean flushed;
        do {
            flushed = flushOldestImmutableMemtable();
        } while (flushed);
    }

    /**
     * Returns {@code prefix + timestamp + ".db"}, where the timestamp is the current time in
     * milliseconds, bumped if necessary so it strictly increases within this engine instance.
     * Two flushes in the same millisecond therefore never overwrite each other's file.
     */
    private String nextSSTableFileName(String prefix) {
        long timestamp = lastSSTableTimestamp.updateAndGet(
            previous -> Math.max(previous + 1, System.currentTimeMillis()));
        return prefix + timestamp + ".db";
    }

    /**
     * Replays PUT entries from the WAL into the active memtable. Deletes are PUTs of the
     * tombstone, so they are replayed too.
     */
    private void recoverFromWAL() throws IOException {
        List<WriteAheadLog.LogEntry> entries = wal.replay();

        for (WriteAheadLog.LogEntry entry : entries) {
            if (entry.getType() == WriteAheadLog.OperationType.PUT) {
                activeMemtable.put(entry.getKey(), entry.getValue());
            }
        }

        if (!entries.isEmpty()) {
            logger.info("Recovered {} entries from WAL", entries.size());
        }
    }

    /**
     * Opens every {@code sstable-*.db} file in the data directory and orders them oldest first.
     * Files that fail to open are logged and skipped.
     */
    private void loadSSTables() throws IOException {
        try (DirectoryStream<Path> stream = Files.newDirectoryStream(
                dataDirectory, "sstable-*.db")) {

            for (Path path : stream) {
                try {
                    SSTable sstable = SSTable.open(path);
                    sstables.add(sstable);
                } catch (IOException e) {
                    logger.error("Failed to load SSTable {}", path, e);
                }
            }
        }

        sstables.sort(Comparator.comparing(SSTable::getCreationTime));

        if (!sstables.isEmpty()) {
            logger.info("Loaded {} SSTables", sstables.size());
        }
    }

    /**
     * Merges all SSTables into a single new one, then deletes the originals.
     *
     * <p>Tombstones can be dropped because every SSTable takes part: no older version remains
     * that a tombstone would need to shadow, and memtables are always newer than SSTables.
     *
     * @throws IOException if reading the old SSTables or writing the merged one fails
     */
    private void performCompaction() throws IOException {
        sstableLock.writeLock().lock();
        try {
            if (sstables.size() < 2) {
                return;
            }

            logger.info("Starting compaction of {} SSTables...", sstables.size());

            // Oldest first, so newer versions overwrite older ones.
            Map<String, byte[]> merged = new TreeMap<>();
            for (SSTable sstable : sstables) {
                merged.putAll(sstable.getAll());
            }

            merged.values().removeIf(LSMStorageEngine::isTombstone);

            String compactedFile = nextSSTableFileName("sstable-compacted-");
            Path compactedPath = dataDirectory.resolve(compactedFile);

            SSTable compacted = SSTable.create(compactedPath, merged);

            // Swap the list before deleting files. A failed delete then only leaves an orphan
            // file behind, rather than an engine pointing at deleted SSTables.
            List<SSTable> replaced = new ArrayList<>(sstables);
            sstables.clear();
            sstables.add(compacted);

            for (SSTable old : replaced) {
                try {
                    old.delete();
                } catch (Exception e) { // best-effort; delete()'s checked exceptions are not visible here
                    logger.warn("Failed to delete compacted-away SSTable", e);
                }
            }

            logger.info("Compaction complete: {} keys in {}", merged.size(), compactedFile);

        } finally {
            sstableLock.writeLock().unlock();
        }
    }

    /** Returns the root directory this engine stores its files in. */
    public Path getDataDirectory() {
        return dataDirectory;
    }
}
