// Copyright (C) 2022 The Qt Company Ltd.
// Copyright (C) 2019 Giulio Girardi <giulio.girardi@protechgroup.it>
// SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only

#include "qgrpcchannel.h"
#include "qgrpcchannel_p.h"
#include "qgrpcchanneloperation.h"

#include "qabstractgrpcclient.h"

#include <QtCore/QFuture>
#include <QtCore/QList>
#include <QtCore/QThread>
#include <QtCore/QTimer>
#include <QtCore/qloggingcategory.h>
#include <QtProtobuf/QProtobufSerializer>
#include <qtgrpcglobal_p.h>

#include <grpcpp/channel.h>
#include <grpcpp/create_channel.h>
#include <grpcpp/impl/codegen/byte_buffer.h>
#include <grpcpp/impl/codegen/client_unary_call.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/slice.h>
#include <grpcpp/impl/codegen/status.h>
#include <grpcpp/impl/codegen/sync_stream.h>
#include <grpcpp/security/credentials.h>

#include <unordered_map>

#if QT_CONFIG(ssl)
#  include <QtNetwork/QSslKey>
#endif

QT_BEGIN_NAMESPACE

using namespace Qt::StringLiterals;

/*!
    \class QGrpcChannel
    \inmodule QtGrpc

    \brief The QGrpcChannel class is an HTTP/2 implementation of
    QAbstractGrpcChannel, based on the reference gRPC C++ API.

    QGrpcChannel creates its native channel with grpc::CreateChannel, using the
    grpc::ChannelCredentials kind selected by QGrpcChannel::NativeGrpcChannelCredentials.
    \sa{https://grpc.github.io/grpc/cpp/classgrpc_1_1_channel_credentials.html}{gRPC ChannelCredentials}.
*/

// Flattens a gRPC ByteBuffer into data by appending the bytes of all of its slices, in order.
// Returns the status of ByteBuffer::Dump(); data is left untouched when that fails.
static grpc::Status parseByteBuffer(const grpc::ByteBuffer &buffer, QByteArray &data)
{
    std::vector<grpc::Slice> slices;
    grpc::Status status = buffer.Dump(&slices);
    if (!status.ok())
        return status;

    // Grow the destination once instead of once per slice.
    qsizetype total = 0;
    for (const grpc::Slice &slice : slices)
        total += static_cast<qsizetype>(slice.size());
    data.reserve(data.size() + total);

    for (const grpc::Slice &slice : slices) {
        // Slice::begin() yields raw bytes (const uint8_t *); QByteArray wants const char *.
        data.append(reinterpret_cast<const char *>(slice.begin()),
                    static_cast<qsizetype>(slice.size()));
    }

    return grpc::Status::OK;
}

// Wraps bytearray into a single-slice gRPC ByteBuffer.
// grpc::Slice(const void *, size_t) is documented by gRPC as copying the buffer, so the result
// does not reference bytearray's storage and may outlive it.
static grpc::ByteBuffer parseQByteArray(QByteArrayView bytearray)
{
    grpc::Slice slice(bytearray.data(), static_cast<size_t>(bytearray.size()));
    return grpc::ByteBuffer(&slice, 1);
}

// Copies the bytes referenced by view, unchanged, into an owning std::string.
static std::string toStdString(QLatin1StringView view)
{
    const char *const first = view.data();
    return std::string(first, first + view.size());
}

// Builds the gRPC method path "/<service>/<method>" used to address an RPC.
static QByteArray buildRpcName(QLatin1StringView service, QLatin1StringView method)
{
    QByteArray path;
    path.reserve(service.size() + method.size() + 2);
    path.append('/');
    path.append(service.data(), service.size());
    path.append('/');
    path.append(method.data(), method.size());
    return path;
}

// Chooses the deadline for one RPC. A per-call deadline wins over the channel-wide one;
// std::nullopt means neither is configured.
static std::optional<std::chrono::milliseconds> deadlineForCall(
        const QGrpcChannelOptions &channelOptions, const QGrpcCallOptions &callOptions)
{
    const auto callDeadline = callOptions.deadline();
    if (!callDeadline) {
        const auto channelDeadline = channelOptions.deadline();
        return channelDeadline ? std::optional<std::chrono::milliseconds>(*channelDeadline)
                               : std::nullopt;
    }
    return *callDeadline;
}

// Opens a server-streaming RPC named method on channel, sending data as the request.
// The reader is created here, synchronously. Reading happens on a worker thread that is
// launched by start(); that thread emits dataReady() per message and stores the final status.
QGrpcChannelStream::QGrpcChannelStream(grpc::Channel *channel, QLatin1StringView method,
                                       QByteArrayView data)
{
    // Translates a native gRPC status into its QGrpcStatus counterpart.
    const auto toQGrpcStatus = [](const grpc::Status &nativeStatus) {
        return QGrpcStatus{ static_cast<QGrpcStatus::StatusCode>(nativeStatus.error_code()),
                            QString::fromStdString(nativeStatus.error_message()) };
    };

    const grpc::ByteBuffer requestBuffer = parseQByteArray(data);
    const std::string methodName = toStdString(method);
    reader = grpc::internal::ClientReaderFactory<grpc::ByteBuffer>::Create(
            channel,
            grpc::internal::RpcMethod(methodName.c_str(),
                                      grpc::internal::RpcMethod::SERVER_STREAMING),
            &context, requestBuffer);

    thread = QThread::create([this, toQGrpcStatus] {
        grpc::ByteBuffer chunk;
        while (reader->Read(&chunk)) {
            QByteArray payload;
            const grpc::Status parseResult = parseByteBuffer(chunk, payload);
            if (!parseResult.ok()) {
                // Malformed message: record the failure and leave the thread.
                // Finish() is not called on this path.
                status = toQGrpcStatus(parseResult);
                return;
            }
            // Emitted from the worker thread.
            emit dataReady(payload);
        }
        // Read() returned false, so the stream is over; collect the final RPC status.
        status = toQGrpcStatus(reader->Finish());
    });

    connect(thread, &QThread::finished, this, &QGrpcChannelStream::finished);
}

void QGrpcChannelStream::start()
{
    thread->start();
}

// Cancels the RPC, joins the worker thread, then releases the thread and the native reader.
QGrpcChannelStream::~QGrpcChannelStream()
{
    // Requests cancellation so that a worker blocked in reader->Read() can return.
    cancel();
    thread->wait();
    // The worker has been joined above, so the QThread can be deleted right away.
    // deleteLater() would leak it if no event loop processes the deferred delete
    // (e.g. during application shutdown).
    delete thread;
    delete reader;
}

// Requests best-effort cancellation of the server-streaming RPC.
void QGrpcChannelStream::cancel()
{
    // gRPC documents ClientContext::TryCancel() as thread-safe and callable multiple times.
    // That makes it usable here while the worker thread is blocked in reader->Read(), and from
    // both the destructor and a deadline timer. The RPC may already have completed.
    context.TryCancel();
}

// Prepares a unary RPC named method on channel with data as the request. The blocking call
// runs on a worker thread launched by start(), which fills response and status.
QGrpcChannelCall::QGrpcChannelCall(grpc::Channel *channel, QLatin1StringView method,
                                   QByteArrayView data)
{
    // Translates a native gRPC status into its QGrpcStatus counterpart.
    const auto toQGrpcStatus = [](const grpc::Status &nativeStatus) {
        return QGrpcStatus{ static_cast<QGrpcStatus::StatusCode>(nativeStatus.error_code()),
                            QString::fromStdString(nativeStatus.error_message()) };
    };

    // The worker runs after this constructor returns, so it keeps its own copies of the
    // serialized request and of the method name.
    grpc::ByteBuffer requestBuffer = parseQByteArray(data);
    thread = QThread::create([this, toQGrpcStatus, requestBuffer, channel,
                              methodName = toStdString(method)] {
        grpc::ByteBuffer reply;
        const grpc::Status rpcStatus = grpc::internal::BlockingUnaryCall(
                channel,
                grpc::internal::RpcMethod(methodName.c_str(),
                                          grpc::internal::RpcMethod::NORMAL_RPC),
                &context, requestBuffer, &reply);
        // Only a successful RPC has a reply worth decoding into response.
        const grpc::Status outcome =
                rpcStatus.ok() ? parseByteBuffer(reply, response) : rpcStatus;
        status = toQGrpcStatus(outcome);
    });

    connect(thread, &QThread::finished, this, &QGrpcChannelCall::finished);
}

void QGrpcChannelCall::start()
{
    thread->start();
}

// Cancels the RPC, joins the worker thread and releases it.
QGrpcChannelCall::~QGrpcChannelCall()
{
    // Requests cancellation so that the blocking unary call can return early.
    cancel();
    thread->wait();
    // The worker has been joined above, so the QThread can be deleted right away.
    // deleteLater() would leak it if no event loop processes the deferred delete
    // (e.g. during application shutdown).
    delete thread;
}

// Requests best-effort cancellation of the unary RPC.
void QGrpcChannelCall::cancel()
{
    // gRPC documents ClientContext::TryCancel() as thread-safe and callable multiple times.
    // That makes it usable here while the worker thread is blocked in BlockingUnaryCall(), and
    // from both the destructor and a deadline timer. The RPC may already have completed.
    context.TryCancel();
}

// Blocks until the worker thread exits or deadline expires. Whether the wait timed out is not
// reported: QThread::wait()'s result is discarded.
void QGrpcChannelCall::waitForFinished(const QDeadlineTimer &deadline)
{
    static_cast<void>(thread->wait(deadline));
}

// Selects the native gRPC credentials for credentialsType, then creates the native channel to
// the host configured in channelOptions.
QGrpcChannelPrivate::QGrpcChannelPrivate(const QGrpcChannelOptions &channelOptions,
                                         QGrpcChannel::NativeGrpcChannelCredentials credentialsType)
    : m_channelOptions(channelOptions)
{
    switch (credentialsType) {
    case QGrpcChannel::InsecureChannelCredentials:
        m_credentials = grpc::InsecureChannelCredentials();
        break;
    case QGrpcChannel::GoogleDefaultCredentials:
        m_credentials = grpc::GoogleDefaultCredentials();
        break;
    case QGrpcChannel::SslDefaultCredentials:
#if QT_CONFIG(ssl)
        if (auto maybeSslConfig = m_channelOptions.sslConfiguration()) {
            // Concatenates the PEM encoding of every certificate, in order.
            // The list is taken once by const reference, so iteration stays inside a single
            // container. Calling peerCertificateChain()/localCertificateChain() separately for
            // begin() and end() yields iterators into two different temporary lists, and a
            // non-const begin()/end() may detach. That is undefined behavior.
            const auto toPemBundle = [](const QList<QSslCertificate> &certificates) {
                std::string pem;
                for (const QSslCertificate &certificate : certificates)
                    pem += certificate.toPem().toStdString();
                return pem;
            };

            grpc::SslCredentialsOptions options;
            // NOTE(review): peerCertificateChain() is normally filled by a TLS handshake and is
            // typically empty for a user-supplied configuration. caCertificates() may have been
            // intended for the root certificates; confirm.
            options.pem_root_certs = toPemBundle(maybeSslConfig->peerCertificateChain());
            options.pem_cert_chain = toPemBundle(maybeSslConfig->localCertificateChain());
            options.pem_private_key = maybeSslConfig->privateKey().toPem().toStdString();
            m_credentials = grpc::SslCredentials(options);
        } else {
            m_credentials = grpc::SslCredentials(grpc::SslCredentialsOptions());
        }
#else
        m_credentials = grpc::SslCredentials(grpc::SslCredentialsOptions());
#endif
        break;
    }

    // Every credential kind targets the same host, so the channel is created once, here.
    m_channel = grpc::CreateChannel(m_channelOptions.host().toString().toStdString(),
                                    m_credentials);
}

// Defaulted: destroying the members is all the cleanup this class needs.
QGrpcChannelPrivate::~QGrpcChannelPrivate() = default;

// Runs a unary RPC for channelOperation on a QGrpcChannelCall worker thread.
//
// The QGrpcChannelCall is owned through a QSharedPointer captured by the lambdas below. It stays
// alive until those connections are torn down and, when a deadline is set, until the deadline
// timer has fired. Both connection handles are shared so that whichever lambda runs first can
// disconnect the other.
void QGrpcChannelPrivate::call(std::shared_ptr<QGrpcChannelOperation> channelOperation)
{
    const QByteArray rpcName =
            buildRpcName(channelOperation->service(), channelOperation->method());
    QSharedPointer<QGrpcChannelCall> call(new QGrpcChannelCall(
            m_channel.get(), QLatin1StringView(rpcName), channelOperation->arg()));
    auto connection = std::make_shared<QMetaObject::Connection>();
    auto abortConnection = std::make_shared<QMetaObject::Connection>();

    // Forward the outcome: the response on success, the error status otherwise; then finish.
    *connection = QObject::connect(call.get(), &QGrpcChannelCall::finished, channelOperation.get(),
                                   [call, channelOperation, connection, abortConnection] {
                                       QObject::disconnect(*connection);
                                       QObject::disconnect(*abortConnection);
                                       if (call->status == QGrpcStatus::Ok)
                                           emit channelOperation->dataReady(call->response);
                                       else
                                           emit channelOperation->errorOccurred(call->status);
                                       emit channelOperation->finished();
                                   });

    // On client-side cancellation, stop forwarding results and cancel the in-flight RPC, as
    // startServerStream() does for streams. Otherwise the RPC would only be cancelled by
    // ~QGrpcChannelCall(), which the deadline timer's reference to `call` can postpone.
    *abortConnection = QObject::connect(channelOperation.get(), &QGrpcChannelOperation::cancelled,
                                        call.get(), [call, connection, abortConnection] {
                                            QObject::disconnect(*connection);
                                            QObject::disconnect(*abortConnection);
                                            call->cancel();
                                        });

    call->start();
    // The deadline is enforced by cancelling the call when the timer fires.
    if (auto deadline = deadlineForCall(m_channelOptions, channelOperation->options()))
        QTimer::singleShot(*deadline, call.get(), [call] { call->cancel(); });
}

// Starts a server-streaming RPC for channelOperation on a QGrpcChannelStream worker thread.
// Each message read by the worker is forwarded to the operation as dataReady().
void QGrpcChannelPrivate::startServerStream(std::shared_ptr<QGrpcChannelOperation> channelOperation)
{
    const QByteArray rpcName =
            buildRpcName(channelOperation->service(), channelOperation->method());

    // The stream is kept alive by the QSharedPointer copies captured in the finished and
    // cancelled lambdas and in the deadline timer below.
    QSharedPointer<QGrpcChannelStream> sub(new QGrpcChannelStream(
            m_channel.get(), QLatin1StringView(rpcName), channelOperation->arg()));

    // Connection handles are shared so that any of the lambdas can tear down all three.
    auto abortConnection = std::make_shared<QMetaObject::Connection>();
    auto readConnection = std::make_shared<QMetaObject::Connection>();
    auto connection = std::make_shared<QMetaObject::Connection>();

    auto disconnectAllConnections = [abortConnection, readConnection, connection]() {
        QObject::disconnect(*connection);
        QObject::disconnect(*readConnection);
        QObject::disconnect(*abortConnection);
    };

    // Forward every message the worker thread reads to the operation.
    *readConnection =
            QObject::connect(sub.get(), &QGrpcChannelStream::dataReady, channelOperation.get(),
                             [channelOperation](QByteArrayView data) {
                                 channelOperation->dataReady(data.toByteArray());
                             });

    // The worker thread exited: forward a non-OK status as an error, then signal completion.
    // The debug message is logged on every exit, not only when the server closes the stream.
    *connection = QObject::connect(sub.get(), &QGrpcChannelStream::finished, channelOperation.get(),
                                   [disconnectAllConnections, sub, channelOperation] {
                                       qGrpcDebug()
                                               << "Stream ended with server closing connection";
                                       disconnectAllConnections();

                                       if (sub->status != QGrpcStatus::Ok)
                                           emit channelOperation->errorOccurred(sub->status);
                                       emit channelOperation->finished();
                                   });

    // Client-side cancellation: drop all connections first, so no finished() is forwarded
    // afterwards, then cancel the RPC.
    *abortConnection =
            QObject::connect(channelOperation.get(), &QGrpcChannelOperation::cancelled, sub.get(),
                             [disconnectAllConnections, sub, channelOperation] {
                                 qGrpcDebug() << "Server stream was cancelled by client";
                                 disconnectAllConnections();
                                 sub->cancel();
                             });

    sub->start();
    // NOTE(review): the deadline is enforced by cancelling the stream from a timer, so expiry is
    // presumably reported with whatever status Finish() returns after TryCancel() (likely
    // CANCELLED) rather than DEADLINE_EXCEEDED; confirm.
    if (auto deadline = deadlineForCall(m_channelOptions, channelOperation->options()))
        QTimer::singleShot(*deadline, sub.get(), [sub] { sub->cancel(); });
}

std::shared_ptr<QAbstractProtobufSerializer> QGrpcChannelPrivate::serializer() const
{
    // TODO: make selection based on credentials or channel settings
    return std::make_shared<QProtobufSerializer>();
}

/*!
    Constructs a gRPC channel, with \a options and \a credentialsType.
*/
QGrpcChannel::QGrpcChannel(const QGrpcChannelOptions &options,
                           NativeGrpcChannelCredentials credentialsType)
    : QAbstractGrpcChannel(), dPtr(std::make_unique<QGrpcChannelPrivate>(options, credentialsType))
{
}

/*!
    Destroys the QGrpcChannel object.
*/
QGrpcChannel::~QGrpcChannel() = default; // out of line: QGrpcChannelPrivate is complete here

/*!
    \internal
    Implementation of unary gRPC call based on the
    reference gRPC C++ API.
*/
void QGrpcChannel::call(std::shared_ptr<QGrpcChannelOperation> channelOperation)
{
    dPtr->call(std::move(channelOperation));
}

/*!
    \internal
    Implementation of server-side gRPC stream based on the
    reference gRPC C++ API.
*/
void QGrpcChannel::startServerStream(std::shared_ptr<QGrpcChannelOperation> channelOperation)
{
    dPtr->startServerStream(std::move(channelOperation));
}

/*!
    \internal
    Implementation of client-side gRPC stream based on the
    reference gRPC C++ API.
*/
void QGrpcChannel::startClientStream(std::shared_ptr<QGrpcChannelOperation> channelOperation)
{
    QTimer::singleShot(0, channelOperation.get(), [channelOperation] {
        emit channelOperation->errorOccurred(
                { QGrpcStatus::Unknown,
                  "Client-side streaming support is not implemented in QGrpcChannel"_L1 });
    });
}

/*!
    \internal
    Implementation of bidirectional gRPC stream based on the
    reference gRPC C++ API.
*/
void QGrpcChannel::startBidirStream(std::shared_ptr<QGrpcChannelOperation> channelOperation)
{
    QTimer::singleShot(0, channelOperation.get(), [channelOperation] {
        emit channelOperation->errorOccurred(
                { QGrpcStatus::Unknown,
                  "Bidirectional streaming support is not implemented in QGrpcChannel"_L1 });
    });
}

/*!
    Returns the newly created QProtobufSerializer shared pointer.
*/
std::shared_ptr<QAbstractProtobufSerializer> QGrpcChannel::serializer() const
{
    return dPtr->serializer();
}

QT_END_NAMESPACE
