/*
 * Copyright (C) by Argonne National Laboratory
 *     See COPYRIGHT in top-level directory
 */

#include "mpiimpl.h"
#include "algo_common.h"
#include "treealgo.h"

/* Number of chunk receives pre-posted in the nonblocking variant when the
 * caller did not request that every receive be pre-posted. */
#define BCAST_PIPELINE_MAX_PREPOSTED_RECVS 3

/* Algorithm: Pipelined bcast
 * For large messages, we use the tree-based pipelined algorithm.
 * Time = Time to send the first chunk to the rightmost child + (num_chunks - 1) * software overhead to inject a chunk
 *
 * The message (packed into a temporary buffer if the datatype is not
 * contiguous) is treated as a byte stream and split into num_chunks chunks:
 * the first chunk is chunk_size_floor bytes, the remaining ones are
 * chunk_size_ceil bytes.  Every non-root rank receives chunk i from its parent
 * and forwards it to all of its children before handling chunk i + 1.
 *
 * Parameters:
 *   buffer, count, datatype, root, comm_ptr - as in MPI_Bcast
 *   tree_type        - MPIR_TREE_TYPE_* value; the KARY tree is computed inline,
 *                      all other types are built with MPIR_Treealgo_tree_create*
 *   branching_factor - fan-out passed to the tree construction
 *   is_nb            - nonzero: use MPIC_Isend (and pre-posted MPIC_Irecv) and
 *                      complete everything with MPIC_Waitall; zero: use blocking
 *                      MPIC_Send/MPIC_Recv
 *   chunk_size       - requested pipeline chunk size, forwarded to
 *                      MPIR_Algo_calculate_pipeline_chunk_info with a unit
 *                      element size (presumably bytes — confirm there)
 *   recv_pre_posted  - only meaningful with is_nb: pre-post a receive for every
 *                      chunk instead of only the first
 *                      BCAST_PIPELINE_MAX_PREPOSTED_RECVS
 *   coll_attr        - forwarded to the send calls
 *
 * Returns MPI_SUCCESS or an MPI error code.
 */
int MPIR_Bcast_intra_pipelined_tree(void *buffer,
                                    MPI_Aint count,
                                    MPI_Datatype datatype,
                                    int root, MPIR_Comm * comm_ptr, int tree_type,
                                    int branching_factor, int is_nb, int chunk_size,
                                    int recv_pre_posted, int coll_attr)
{
    int rank, comm_size, j, k, *p, src = -1, dst;
    int is_contig;
    int mpi_errno = MPI_SUCCESS;
    int tree_created = 0;       /* set once my_tree must be released in fn_exit */
    MPI_Status status;
    MPI_Aint i, type_size, num_chunks, chunk_size_floor, chunk_size_ceil;
    /* Byte offset of the current chunk in sendbuf.  Must be MPI_Aint: with an
     * int, messages larger than 2 GiB would overflow (undefined behavior). */
    MPI_Aint offset = 0;
    /* Number of leading chunks whose receives were pre-posted (0 if !is_nb). */
    MPI_Aint num_posted = 0;
    MPI_Aint true_lb, true_extent, recvd_size, actual_packed_unpacked_bytes, nbytes = 0;
    void *sendbuf = NULL;
    int parent = -1, num_children = 0, lrank = 0, num_req = 0;
    MPIR_Request **reqs = NULL;
    MPI_Status *statuses = NULL;
    MPIR_Treealgo_tree_t my_tree;
    MPIR_CHKLMEM_DECL();

    MPIR_COMM_RANK_SIZE(comm_ptr, rank, comm_size);

    /* If there is only one process, return */
    if (comm_size == 1)
        goto fn_exit;

    MPIR_Datatype_is_contig(datatype, &is_contig);

    if (is_contig) {
        MPIR_Datatype_get_size_macro(datatype, type_size);
    } else {
        MPIR_Pack_size(1, datatype, &type_size);
    }

    nbytes = type_size * count;

    if (nbytes == 0)
        goto fn_exit;

    if (is_contig) {    /* no need to pack */
        MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
        sendbuf = (char *) buffer + true_lb;
    } else {
        /* Stage the message in a packed buffer; only the root has data to pack,
         * the other ranks unpack into `buffer` once every chunk has arrived. */
        MPIR_CHKLMEM_MALLOC(sendbuf, nbytes);
        if (rank == root) {
            mpi_errno = MPIR_Typerep_pack(buffer, count, datatype, 0, sendbuf, nbytes,
                                          &actual_packed_unpacked_bytes, MPIR_TYPEREP_FLAG_NONE);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

    /* treat all cases as MPIR_BYTE_INTERNAL */
    MPIR_Algo_calculate_pipeline_chunk_info(chunk_size, 1, nbytes, &num_chunks,
                                            &chunk_size_floor, &chunk_size_ceil);

    if (tree_type == MPIR_TREE_TYPE_KARY) {
        /* k-ary tree over ranks renumbered so that the root is lrank 0 */
        lrank = (rank + (comm_size - root)) % comm_size;
        parent = (lrank == 0) ? -1 : (((lrank - 1) / branching_factor) + root) % comm_size;
        num_children = branching_factor;
        src = parent;
    } else {
        if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE ||
            tree_type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K) {
            mpi_errno =
                MPIR_Treealgo_tree_create_topo_aware(comm_ptr, tree_type, branching_factor, root,
                                                     MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE, &my_tree);
        } else if (tree_type == MPIR_TREE_TYPE_TOPOLOGY_WAVE) {
            mpi_errno =
                MPIR_Treealgo_tree_create_topo_wave(comm_ptr, branching_factor, root,
                                                    MPIR_CVAR_BCAST_TOPO_REORDER_ENABLE,
                                                    MPIR_CVAR_BCAST_TOPO_OVERHEAD,
                                                    MPIR_CVAR_BCAST_TOPO_DIFF_GROUPS,
                                                    MPIR_CVAR_BCAST_TOPO_DIFF_SWITCHES,
                                                    MPIR_CVAR_BCAST_TOPO_SAME_SWITCHES, &my_tree);
        } else {
            mpi_errno =
                MPIR_Treealgo_tree_create(rank, comm_size, tree_type, branching_factor, root,
                                          &my_tree);
        }
        MPIR_ERR_CHECK(mpi_errno);
        /* From here on every exit path (including errors) must free the tree. */
        tree_created = 1;
        num_children = my_tree.num_children;
        src = my_tree.parent;
    }

    if (is_nb) {
        /* Worst case: one receive per chunk plus one send per (chunk, child). */
        MPIR_CHKLMEM_MALLOC(reqs,
                            sizeof(MPIR_Request *) * (num_children * num_chunks + num_chunks));
        MPIR_CHKLMEM_MALLOC(statuses,
                            sizeof(MPI_Status) * (num_children * num_chunks + num_chunks));

        /* For a large number of chunks, pre-posting all the receives can add
         * overhead, so unless the caller requested it only a few receives are
         * posted to keep the pipeline going; later chunks use MPIC_Recv. */
        if (num_chunks > BCAST_PIPELINE_MAX_PREPOSTED_RECVS && !recv_pre_posted)
            num_posted = BCAST_PIPELINE_MAX_PREPOSTED_RECVS;
        else
            num_posted = num_chunks;

        if (src != -1) {        /* post receives from parent */
            for (i = 0; i < num_posted; i++) {
                MPI_Aint msgsize = (i == 0) ? chunk_size_floor : chunk_size_ceil;

                mpi_errno = MPIC_Irecv((char *) sendbuf + offset, msgsize, MPIR_BYTE_INTERNAL,
                                       src, MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++]);
                MPIR_ERR_CHECK(mpi_errno);
                offset += msgsize;
            }
        }
    }
    offset = 0;

    for (i = 0; i < num_chunks; i++) {
        MPI_Aint msgsize = (i == 0) ? chunk_size_floor : chunk_size_ceil;

        /* Chunk i must be received from the parent before it can be forwarded. */
        if (src != -1) {
            MPI_Status *chunk_status;

            if (i < num_posted) {
                /* Receives were posted before any send, so chunk i's request is reqs[i]. */
                mpi_errno = MPIC_Wait(reqs[i]);
                MPIR_ERR_CHECK(mpi_errno);
                chunk_status = &reqs[i]->status;
            } else {
                mpi_errno = MPIC_Recv((char *) sendbuf + offset, msgsize, MPIR_BYTE_INTERNAL,
                                      src, MPIR_BCAST_TAG, comm_ptr, &status);
                MPIR_ERR_CHECK(mpi_errno);
                chunk_status = &status;
            }
            MPIR_Get_count_impl(chunk_status, MPIR_BYTE_INTERNAL, &recvd_size);
            MPIR_ERR_CHKANDJUMP2(recvd_size != msgsize, mpi_errno, MPI_ERR_OTHER,
                                 "**collective_size_mismatch",
                                 "**collective_size_mismatch %d %d",
                                 (int) recvd_size, (int) msgsize);
        }

        if (tree_type == MPIR_TREE_TYPE_KARY) {
            /* Send data to the children: lranks lrank*branching_factor + 1..branching_factor */
            for (k = 1; k <= branching_factor; k++) {
                dst = lrank * branching_factor + k;
                if (dst >= comm_size)
                    break;

                dst = (dst + root) % comm_size;

                if (!is_nb) {
                    mpi_errno =
                        MPIC_Send((char *) sendbuf + offset, msgsize, MPIR_BYTE_INTERNAL, dst,
                                  MPIR_BCAST_TAG, comm_ptr, coll_attr);
                } else {
                    mpi_errno =
                        MPIC_Isend((char *) sendbuf + offset, msgsize, MPIR_BYTE_INTERNAL, dst,
                                   MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], coll_attr);
                }
                MPIR_ERR_CHECK(mpi_errno);
            }
        } else if (num_children) {
            /* Send data to the children listed in the generic tree */
            for (j = 0; j < num_children; j++) {
                p = (int *) utarray_eltptr(my_tree.children, j);
                dst = *p;
                if (!is_nb) {
                    mpi_errno =
                        MPIC_Send((char *) sendbuf + offset, msgsize, MPIR_BYTE_INTERNAL, dst,
                                  MPIR_BCAST_TAG, comm_ptr, coll_attr);
                } else {
                    mpi_errno =
                        MPIC_Isend((char *) sendbuf + offset, msgsize, MPIR_BYTE_INTERNAL, dst,
                                   MPIR_BCAST_TAG, comm_ptr, &reqs[num_req++], coll_attr);
                }
                MPIR_ERR_CHECK(mpi_errno);
            }
        }
        offset += msgsize;
    }

    if (is_nb) {
        /* Completes every request, including receives already waited on above. */
        mpi_errno = MPIC_Waitall(num_req, reqs, statuses);
        MPIR_ERR_CHECK(mpi_errno);
    }

    if (!is_contig && rank != root) {
        mpi_errno = MPIR_Typerep_unpack(sendbuf, nbytes, buffer, count, datatype, 0,
                                        &actual_packed_unpacked_bytes, MPIR_TYPEREP_FLAG_NONE);
        MPIR_ERR_CHECK(mpi_errno);
    }

  fn_exit:
    /* NOTE(review): on the error path with is_nb, requests posted above may
     * still be outstanding when the CHKLMEM buffers (sendbuf, reqs) are
     * released here — confirm whether they need to be cancelled/completed. */
    if (tree_created)
        MPIR_Treealgo_tree_free(&my_tree);
    MPIR_CHKLMEM_FREEALL();
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
