#!/usr/bin/env python3

# (C) Copyright 2025- ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


import sys
from pathlib import Path

import numpy as np

# Add the directory containing stream.py to the import path
sys.path.append(str(Path(__file__).parent))
from stream import Stream


def dtype_uint(little_endian, size):
    """Return the NumPy unsigned-integer dtype of *size* bytes and given endianness.

    Args:
        little_endian: if true the dtype is little-endian, otherwise big-endian.
        size: width of one element in bytes; 1, 2, 4 and 8 are supported.

    Returns:
        A ``numpy.dtype`` with an explicit byte order.

    Raises:
        KeyError: if *size* is not one of the supported widths.
    """
    unsigned_types = {1: np.uint8, 2: np.uint16, 4: np.uint32, 8: np.uint64}
    order = "<" if little_endian else ">"
    return np.dtype(unsigned_types[size]).newbyteorder(order)


def dtype_float(little_endian, size):
    """Return the NumPy floating-point dtype of *size* bytes and given endianness.

    Args:
        little_endian: if true the dtype is little-endian, otherwise big-endian.
        size: width of one element in bytes; 4 (float32) or 8 (float64).

    Returns:
        A ``numpy.dtype`` with an explicit byte order.

    Raises:
        KeyError: if *size* is neither 4 nor 8.
    """
    float_types = {4: np.float32, 8: np.float64}
    native = np.dtype(float_types[size])
    if little_endian:
        return native.newbyteorder("<")
    return native.newbyteorder(">")


def main():
    """Command-line entry point: print information about MIR CSR matrix files.

    For every path given on the command line the header fields (row count,
    column count, non-zero count, endianness flag and three ``sizeof``
    values) are read through ``Stream``, followed by three blobs that are
    interpreted as the outer index array, the inner index array and the data
    array.  Depending on the options the matrix is printed as JSON
    (``--json``), as ``row col value`` triplets (``--coo``, needs SciPy), or
    as a human-readable summary.

    Paths that do not exist are reported on standard error and skipped.
    """
    import argparse

    # Defined once per call rather than once per processed file.
    class StreamArray:
        """One blob read from a stream, viewed as a NumPy array.

        ``len`` is ``len()`` of the value returned by ``read_large_blob``
        (reported as a byte count in the summary); ``array`` is that buffer
        reinterpreted with the requested dtype.
        """

        def __init__(self, stream, dtype):
            blob = stream.read_large_blob()
            self.len = len(blob)
            self.array = np.frombuffer(blob, dtype=dtype)

    parser = argparse.ArgumentParser(description="Show MIR CSR matrix information.")
    parser.add_argument("files", nargs="+", type=Path, help="One or more matrix files.")
    parser.add_argument(
        "--no-summarization", action="store_true", help="Print without summarization"
    )

    fmt = parser.add_mutually_exclusive_group()
    fmt.add_argument(
        "--json", "-j", action="store_true", help="Print as JSON (0-based indexing)"
    )
    fmt.add_argument(
        "--coo",
        action="store_true",
        help="Print as coordinates (0-based indexing, requires SciPy)",
    )

    args = parser.parse_args()

    if args.no_summarization:
        # Make NumPy print every element instead of eliding the middle.
        np.set_printoptions(threshold=np.inf)

    for path in args.files:
        if not path.exists():
            print(f"Error: {path} not found.", file=sys.stderr)
            continue

        with open(path, "rb") as f:
            s = Stream(f)

            # Header: matrix dimensions and number of stored entries.
            nr = s.read_unsigned_long()
            nc = s.read_unsigned_long()
            nnz = s.read_unsigned_long()

            # Header: layout of the binary arrays that follow.
            little_endian = s.read_int() != 0
            sizeof_index = s.read_unsigned_long()
            sizeof_scalar = s.read_unsigned_long()
            sizeof_size = s.read_unsigned_long()

            # The three arrays are read in this fixed order.
            outer = StreamArray(s, dtype_uint(little_endian, sizeof_index))
            inner = StreamArray(s, dtype_uint(little_endian, sizeof_index))
            data = StreamArray(s, dtype_float(little_endian, sizeof_scalar))

            if args.json:
                from json import dumps

                # Convert to plain Python numbers so they are JSON-serializable.
                obj = dict(
                    nr=nr,
                    nc=nc,
                    nnz=nnz,
                    outer=outer.array.astype(int).tolist(),
                    inner=inner.array.astype(int).tolist(),
                    data=data.array.astype(float).tolist(),
                )

                print(dumps(obj, indent=4))
            elif args.coo:
                # Imported lazily so SciPy is only required for --coo.
                from scipy.sparse import csr_matrix

                M = csr_matrix(
                    (data.array, inner.array, outer.array), shape=(nr, nc)
                ).tocoo()
                for i, j, v in zip(M.row, M.col, M.data):
                    print(i, j, v)
            else:
                print(f"'{path}'")
                # The number of stored entries cannot exceed the number of
                # matrix cells, hence ">=" (the label used to read "<=").
                print(f"nr x nc >= nnz: {nr} x {nc} = {nr * nc} >= {nnz}")

                print(f"little_endian: {little_endian}")
                print(f"sizeof(Index): {sizeof_index} bytes")
                print(f"sizeof(Scalar): {sizeof_scalar} bytes")
                print(f"sizeof(Size): {sizeof_size} bytes")

                print(f"Outer array ({outer.len} bytes): ", outer.array)
                print(f"Inner array ({inner.len} bytes): ", inner.array)
                print(f"Data array ({data.len} bytes): ", data.array)


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
