package com.cube.cql;

import com.cube.sql.SQLParser;
import com.cube.storage.StorageEngine;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Minimal CQL query executor backed by a {@link StorageEngine}.
 *
 * <p>Supported statements: CREATE TABLE (records the table's primary-key column), DROP TABLE,
 * INSERT, SELECT, UPDATE and DELETE. Every row is stored under the key
 * {@code keyspace:table:<primary-key value>} as UTF-8 text made of newline-separated
 * {@code column=value} lines.
 *
 * <p>WHERE clauses are treated as conjunctions of equality tests. When they pin the registered
 * primary-key column the row is fetched directly; otherwise the whole table is scanned. The
 * primary-key registry is a plain {@link HashMap}; this class does no locking of its own.
 */
public class QueryExecutor {

    private final StorageEngine storage;
    // map of "keyspace.table" -> primaryKeyColumn
    private final Map<String, String> primaryKeys = new HashMap<>();

    /**
     * Creates an executor that reads and writes rows through the given storage engine.
     *
     * @param storage backing key/value store
     */
    public QueryExecutor(StorageEngine storage) {
        this.storage = storage;
    }

    /**
     * Re-parses the query text with {@code SQLParser} and executes it.
     *
     * <p>Any {@link Exception} raised while parsing or executing (including storage I/O errors)
     * is reported as an unsuccessful {@link Result} instead of being thrown.
     *
     * @param parsedQuery query whose raw text ({@code getCql()}) is executed
     * @return outcome of the statement
     */
    public Result execute(CQLParser.ParsedQuery parsedQuery) {
        try {
            String cql = parsedQuery.getCql();
            SQLParser.ParsedSQL sql = SQLParser.parse(cql);

            switch (sql.getType()) {
                case CREATE_TABLE:
                    return handleCreateTable(sql);
                case INSERT:
                    return handleInsert(sql);
                case SELECT:
                    return handleSelect(sql);
                case UPDATE:
                    return handleUpdate(sql);
                case DELETE:
                    return handleDelete(sql);
                case DROP_TABLE:
                    return handleDropTable(sql);
                default:
                    return failure("Unsupported CQL type: " + sql.getType());
            }

        } catch (Exception e) {
            // Some exceptions (NullPointerException, NoSuchElementException, ...) carry no
            // message; fall back to the exception type so the caller gets something actionable.
            String detail = e.getMessage() != null ? e.getMessage() : e.getClass().getName();
            return failure("Execution error: " + detail);
        }
    }

    /** Registers the table's primary-key column (explicit, or else the first defined column). */
    private Result handleCreateTable(SQLParser.ParsedSQL sql) {
        String key = tableKey(sql);
        String pk = sql.getPrimaryKey();
        if (pk == null) {
            // No explicit PRIMARY KEY: fall back to the first column in getColumnDefinitions()
            // iteration order (presumably declaration order — TODO confirm in SQLParser).
            Map<String, ?> defs = sql.getColumnDefinitions();
            if (defs != null && !defs.isEmpty()) {
                pk = defs.keySet().iterator().next();
            }
        }
        if (pk != null) {
            primaryKeys.put(key, pk);
        }
        return new Result(true, "Table created: " + key, Collections.emptyList(), 0);
    }

    /**
     * Deletes every stored row of the table and forgets its primary-key column.
     *
     * <p>Rows must be removed as well: storage keys are derived only from keyspace, table and
     * key value, so leftover rows would reappear if a table with the same name were re-created.
     */
    private Result handleDropTable(SQLParser.ParsedSQL sql) throws IOException {
        String key = tableKey(sql);
        for (String rowKey : findMatchingRows(sql, null).keySet()) {
            storage.delete(rowKey);
        }
        primaryKeys.remove(key);
        return new Result(true, "Table dropped: " + key, Collections.emptyList(), 0);
    }

    /**
     * Writes (or overwrites) one row keyed by its primary-key value.
     *
     * <p>If the table was never registered via CREATE TABLE, the first supplied column is used as
     * the key column.
     */
    private Result handleInsert(SQLParser.ParsedSQL sql) throws IOException {
        Map<String, String> cols = sql.getColumns();
        if (cols == null || cols.isEmpty()) {
            return failure("INSERT requires at least one column");
        }
        // Not getOrDefault: its default argument would be evaluated even when a key is registered.
        String pkCol = primaryKeys.get(tableKey(sql));
        if (pkCol == null) {
            pkCol = cols.keySet().iterator().next();
        }
        String pkVal = cols.get(pkCol);
        if (pkVal == null) {
            return failure("Primary key value missing for column: " + pkCol);
        }
        String invalid = findUnencodableColumn(cols);
        if (invalid != null) {
            return failure(invalid);
        }

        Map<String, byte[]> row = new LinkedHashMap<>();
        for (Map.Entry<String, String> e : cols.entrySet()) {
            row.put(e.getKey(), toBytes(e.getValue()));
        }
        storage.put(storageKey(sql.getKeyspace(), sql.getTable(), pkVal), encodeRow(row));
        return new Result(true, "Inserted", Collections.emptyList(), 1);
    }

    /** Returns every row matching the WHERE clause, projected onto the selected columns. */
    private Result handleSelect(SQLParser.ParsedSQL sql) throws IOException {
        List<String> selectCols = sql.getSelectColumns();
        List<Map<String, byte[]>> rows = new ArrayList<>();
        for (Map<String, byte[]> row : findMatchingRows(sql, sql.getWhereClause()).values()) {
            rows.add(filterColumns(row, selectCols));
        }
        return new Result(true, "OK", rows, rows.size());
    }

    /**
     * Applies the SET assignments to every row matching the WHERE clause.
     *
     * <p>Rows that do not exist are not created. Assigning the primary-key column is rejected.
     */
    private Result handleUpdate(SQLParser.ParsedSQL sql) throws IOException {
        Map<String, String> assignments = sql.getColumns();
        if (assignments == null) {
            return failure("UPDATE requires a SET clause");
        }
        String pkCol = primaryKeys.get(tableKey(sql));
        if (pkCol != null && assignments.containsKey(pkCol)) {
            // The storage key is derived from the primary-key value; rewriting that column in
            // place would leave the row filed under a key that no longer matches its content.
            return failure("Cannot update primary key column: " + pkCol);
        }
        String invalid = findUnencodableColumn(assignments);
        if (invalid != null) {
            return failure(invalid);
        }

        // Matches are fully collected before any write, so the store is never modified while a
        // scan iterator over it is still in use.
        Map<String, Map<String, byte[]>> matches = findMatchingRows(sql, sql.getWhereClause());
        for (Map.Entry<String, Map<String, byte[]>> match : matches.entrySet()) {
            Map<String, byte[]> row = match.getValue();
            for (Map.Entry<String, String> a : assignments.entrySet()) {
                row.put(a.getKey(), toBytes(a.getValue()));
            }
            storage.put(match.getKey(), encodeRow(row));
        }
        return new Result(true, "Updated", Collections.emptyList(), matches.size());
    }

    /** Deletes every row matching the WHERE clause; counts only deletions the store confirms. */
    private Result handleDelete(SQLParser.ParsedSQL sql) throws IOException {
        int deleted = 0;
        for (String rowKey : findMatchingRows(sql, sql.getWhereClause()).keySet()) {
            if (storage.delete(rowKey)) deleted++;
        }
        return new Result(true, "Deleted", Collections.emptyList(), deleted);
    }

    /**
     * Finds the rows of {@code sql}'s table that satisfy every condition in {@code where}.
     *
     * <p>If {@code where} constrains the registered primary-key column, only that single storage
     * key is read; the remaining conditions are still checked against the row. Otherwise the
     * whole table prefix is scanned.
     *
     * @param where equality conditions; {@code null} or empty matches every row
     * @return storage key -> decoded (mutable) row, in lookup/scan order
     */
    private Map<String, Map<String, byte[]>> findMatchingRows(SQLParser.ParsedSQL sql,
                                                              Map<String, String> where)
            throws IOException {
        Map<String, Map<String, byte[]>> matches = new LinkedHashMap<>();

        String pkKey = primaryKeyStorageKey(sql, where);
        if (pkKey != null) {
            byte[] value = storage.get(pkKey);
            if (value != null) {
                Map<String, byte[]> row = decodeRow(value);
                if (matchesWhere(row, where)) {
                    matches.put(pkKey, row);
                }
            }
            return matches;
        }

        Iterator<Map.Entry<String, byte[]>> it =
                storage.scanEntries(storagePrefix(sql.getKeyspace(), sql.getTable()));
        while (it.hasNext()) {
            Map.Entry<String, byte[]> entry = it.next();
            Map<String, byte[]> row = decodeRow(entry.getValue());
            if (matchesWhere(row, where)) {
                matches.put(entry.getKey(), row);
            }
        }
        return matches;
    }

    /**
     * Returns the storage key addressed by an equality on the registered primary-key column in
     * {@code where}, or {@code null} when the clause does not pin the key (or none is registered).
     */
    private String primaryKeyStorageKey(SQLParser.ParsedSQL sql, Map<String, String> where) {
        if (where == null || where.isEmpty()) return null;
        String pkCol = primaryKeys.get(tableKey(sql));
        if (pkCol == null || !where.containsKey(pkCol)) return null;
        return storageKey(sql.getKeyspace(), sql.getTable(), where.get(pkCol));
    }

    /** True when every WHERE condition equals the row's UTF-8-decoded value (absent == null). */
    private boolean matchesWhere(Map<String, byte[]> row, Map<String, String> where) {
        if (where == null || where.isEmpty()) return true;
        for (Map.Entry<String, String> cond : where.entrySet()) {
            byte[] val = row.get(cond.getKey());
            String sval = val == null ? null : new String(val, StandardCharsets.UTF_8);
            if (!Objects.equals(sval, cond.getValue())) return false;
        }
        return true;
    }

    /**
     * Projects {@code row} onto {@code selectCols}. A null/empty list or a lone {@code "*"}
     * returns the row itself; unknown columns map to {@code null}.
     */
    private Map<String, byte[]> filterColumns(Map<String, byte[]> row, List<String> selectCols) {
        if (selectCols == null || selectCols.isEmpty() || (selectCols.size() == 1 && "*".equals(selectCols.get(0)))) {
            return row;
        }
        Map<String, byte[]> out = new LinkedHashMap<>();
        for (String c : selectCols) {
            out.put(c, row.get(c));
        }
        return out;
    }

    /**
     * Returns a description of the first column that cannot round-trip through
     * {@code encodeRow}/{@code decodeRow}, or {@code null} if every column is safe.
     *
     * <p>A name must be non-null, non-empty and free of {@code '='} and newlines; a value must be
     * free of newlines. Anything else would be split or dropped when the row is decoded.
     */
    private static String findUnencodableColumn(Map<String, String> cols) {
        for (Map.Entry<String, String> e : cols.entrySet()) {
            String name = e.getKey();
            if (name == null || name.isEmpty() || name.indexOf('=') >= 0 || name.indexOf('\n') >= 0) {
                return "Unsupported column name: " + name;
            }
            String value = e.getValue();
            if (value != null && value.indexOf('\n') >= 0) {
                return "Unsupported newline in value for column: " + name;
            }
        }
        return null;
    }

    /** UTF-8 bytes of {@code value}, or {@code null} for a null value. */
    private static byte[] toBytes(String value) {
        return value == null ? null : value.getBytes(StandardCharsets.UTF_8);
    }

    /** Builds an unsuccessful result with no rows and zero rows affected. */
    private static Result failure(String message) {
        return new Result(false, message, Collections.emptyList(), 0);
    }

    /** Registry key ("keyspace.table") used in {@code primaryKeys}. */
    private static String tableKey(SQLParser.ParsedSQL sql) {
        return sql.getKeyspace() + "." + sql.getTable();
    }

    private String storageKey(String keyspace, String table, String pkVal) {
        return storagePrefix(keyspace, table) + pkVal;
    }

    /** Prefix shared by all storage keys of one table: {@code keyspace:table:}. */
    private String storagePrefix(String keyspace, String table) {
        return keyspace + ":" + table + ":";
    }

    /**
     * Parses the newline-separated {@code column=value} format. Lines without an {@code '='} or
     * with an empty column name are skipped; an empty value decodes to an empty byte array.
     */
    private Map<String, byte[]> decodeRow(byte[] bytes) {
        Map<String, byte[]> map = new LinkedHashMap<>();
        if (bytes == null || bytes.length == 0) return map;
        String s = new String(bytes, StandardCharsets.UTF_8);
        String[] lines = s.split("\n");
        for (String line : lines) {
            if (line.isEmpty()) continue;
            int idx = line.indexOf('=');
            if (idx <= 0) continue;
            String k = line.substring(0, idx);
            String v = line.substring(idx + 1);
            map.put(k, v.getBytes(StandardCharsets.UTF_8));
        }
        return map;
    }

    /**
     * Serializes a row as newline-separated {@code column=value} UTF-8 lines; a {@code null}
     * value is written as an empty value.
     */
    private byte[] encodeRow(Map<String, byte[]> row) {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<String, byte[]> e : row.entrySet()) {
            sb.append(e.getKey()).append("=");
            if (e.getValue() != null) sb.append(new String(e.getValue(), StandardCharsets.UTF_8));
            sb.append('\n');
        }
        return sb.toString().getBytes(StandardCharsets.UTF_8);
    }

    /** Outcome of one executed statement. */
    public static class Result {
        private final boolean success;
        private final String message;
        private final List<Map<String, byte[]>> rows;
        private final int rowsAffected;

        /**
         * @param success      whether the statement succeeded
         * @param message      human-readable status or error text
         * @param rows         result rows (column -> UTF-8 bytes); {@code null} becomes an empty list
         * @param rowsAffected rows returned (SELECT) or written/removed (DML)
         */
        public Result(boolean success, String message, List<Map<String, byte[]>> rows, int rowsAffected) {
            this.success = success;
            this.message = message;
            this.rows = rows != null ? rows : new ArrayList<>();
            this.rowsAffected = rowsAffected;
        }

        public boolean isSuccess() { return success; }
        public String getMessage() { return message; }
        public List<Map<String, byte[]>> getRows() { return rows; }
        public int getRowsAffected() { return rowsAffected; }
    }
}
