CUDA: deduplicate mmq code #7397

Merged

Conversation

JohannesGaessler (Collaborator)

This PR deduplicates the CUDA code related to mul_mat_q, mostly the code around launching the kernels. There is still duplication around the __global__ functions because I don't know how to handle __launch_bounds__ in a way that isn't an overly complicated, difficult-to-understand macro.
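For context, __launch_bounds__ is an attribute that must be written into each __global__ declaration with constant-expression arguments, which is why it resists being factored out. A generic illustration, not code from this PR:

// Generic CUDA sketch: __launch_bounds__ caps the threads per block and
// hints the minimum resident blocks per SM. Because it has to appear in
// every kernel declaration, per-kernel tuning tends to get copy-pasted
// or macro-generated. The arguments may depend on template parameters.
template <int nwarps>
static __global__ void __launch_bounds__(32*nwarps, 2) example_kernel(float * dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = 0.0f;
    }
}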

ggml-cuda/mmq.cu Outdated
Comment on lines 16 to 20
typedef struct mmq_arch_config_t {
int x;
int y;
int nwarps;
} mmq_arch_config_t;
Collaborator

The typedef struct idiom is not necessary in C++; just write struct mmq_arch_config_t { ... };
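That is, the equivalent declaration taken directly from the snippet above:

// In C++ the struct tag is already a type name, so no typedef is needed.
struct mmq_arch_config_t {
    int x;
    int y;
    int nwarps;
};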

ggml-cuda/mmq.cu Outdated
Comment on lines 1195 to 1210
#if __CUDA_ARCH__ >= MIN_CC_DP4A
constexpr mmq_config_t mmq_config = MMQ_CONFIG_Q4_0;

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
const int mmq_x = MMQ_X_Q4_0_RDNA2;
const int mmq_y = MMQ_Y_Q4_0_RDNA2;
const int nwarps = NWARPS_Q4_0_RDNA2;
constexpr mmq_arch_config_t arch_config = mmq_config.rdna2;
#else
const int mmq_x = MMQ_X_Q4_0_RDNA1;
const int mmq_y = MMQ_Y_Q4_0_RDNA1;
const int nwarps = NWARPS_Q4_0_RDNA1;
constexpr mmq_arch_config_t arch_config = mmq_config.rdna1;
#endif // defined(RDNA3) || defined(RDNA2)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
constexpr mmq_arch_config_t arch_config = mmq_config.ampere;
#else
constexpr mmq_arch_config_t arch_config = mmq_config.pascal;
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
Collaborator

You can use a constexpr function to deduplicate this code, e.g.:

constexpr static __device__ mmq_arch_config_t get_mmq_arch_config(const mmq_config_t mmq_config) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    return mmq_config.rdna2;
#else
    return mmq_config.rdna1;
#endif // defined(RDNA3) || defined(RDNA2)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
    return mmq_config.ampere;
#else
    return mmq_config.pascal;
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_0.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    mul_mat_q4_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_mmq_arch_config(MMQ_CONFIG_Q4_0);

    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_0<arch_config.y>,
        load_tiles_q4_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

slaren (Collaborator) commented May 19, 2024

It also seems possible to use a constexpr function with __launch_bounds__, but it may require more changes. This compiles:

constexpr static __device__ int get_config_launch_bounds(const mmq_config_t config) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    return WARP_SIZE*config.rdna2.nwarps;
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    return 0;
}

template <bool need_check> static __global__ void
    __launch_bounds__(get_config_launch_bounds(MMQ_CONFIG_Q4_0), 2)
    mul_mat_q4_0(...

C++11 constexpr functions only allow a single return statement, but the requirements are much more relaxed in C++14. C++17 also has if constexpr. I think it would be fine to raise the C++ standard of the CUDA backend to C++17 if necessary; it is supported everywhere CUDA is, and it is required for CUTLASS anyway.
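As a hypothetical C++17 sketch of that idea, with the preprocessor branch replaced by a template parameter (the amd_rdna flag is an assumed stand-in, not an identifier from mmq.cu):

// Hypothetical C++17 sketch: if constexpr discards the untaken branch at
// compile time, so both cases can live in one function body. amd_rdna is
// an assumed example parameter standing in for the RDNA preprocessor checks.
template <bool amd_rdna>
constexpr static __device__ int get_config_launch_bounds(const mmq_config_t config) {
    if constexpr (amd_rdna) {
        return WARP_SIZE*config.rdna2.nwarps;
    } else {
        return 0; // as in the snippet above: no explicit bound outside the RDNA path
    }
}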

@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on May 19, 2024
@mofosyne added the refactoring (Refactoring) and Review Complexity : High (Generally require in-depth knowledge of LLMs or GPUs) labels on May 20, 2024
github-actions bot (Contributor) commented May 20, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 523 iterations 🚀

Details (performance-related PR only):
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8913.74ms p(95)=21161.84ms fails=, finish reason: stop=469 truncated=54
  • Prompt processing (pp): avg=103.27tk/s p(95)=405.07tk/s
  • Token generation (tg): avg=45.83tk/s p(95)=51.1tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-deduplicate-mmq commit=d9a738cf00b10815858887583225ead9d5b57989

[Benchmark charts: llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, and llamacpp:requests_processing, each plotted over the 10m run (timestamps 1716300085 to 1716300707) of llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, 523 iterations.]

JohannesGaessler (Collaborator, Author)

I tried wrangling the code into producing the correct tile sizes for each arch without having to compile all tile sizes for all archs, but after some time I gave up; I don't know how to correctly pass the templates for e.g. allocating tiles. For now I would like to merge this PR as-is. I will soon look into how to utilize tensor cores via PTX, at which point I will likely overhaul the MMQ kernels to be more like #4801 and replace the current system for determining tile sizes anyway.

@JohannesGaessler merged commit d8ee902 into ggerganov:master on May 21, 2024 (61 of 73 checks passed)
ggerganov (Owner)

ggml-ci is failing to build after this PR:

https://github.com/ggml-org/ci/blob/67a2d49a31cb9441c0b34053563f44a086738c2f/llama.cpp/d8/ee90222791afff2ab666ded4cb6195fd94cced/ggml-4-x86-cuda-v100/stdall#L150-L153

/home/ggml/work/llama.cpp/ggml-cuda/mmq.cu(1185): error #177-D: function "get_arch_config_device" was declared but never referenced
                             mmq_arch_config_t get_arch_config_device(mmq_config_t mmq_config) {
                                               ^

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

1 error detected in the compilation of "/home/ggml/work/llama.cpp/ggml-cuda/mmq.cu".
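For reference, one generic way to silence this class of diagnostic is to mark the helper as potentially unused; a hypothetical sketch, not necessarily the fix applied in #7442:

// Hypothetical sketch: [[maybe_unused]] (C++17) tells the compiler that some
// compilation paths may legitimately never reference this helper. The body is
// abbreviated; the fallback return is illustrative only.
[[maybe_unused]] static constexpr __device__ mmq_arch_config_t get_arch_config_device(mmq_config_t mmq_config) {
    // ... architecture selection as in the earlier snippets ...
    return mmq_config.pascal;
}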

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request May 21, 2024
mofosyne (Collaborator) commented May 22, 2024

Note: CI is failing due to ggml-org / ggml-4-x86-cuda-v100 - failure 2 in 0:15.28, but this was confirmed fixed a few commits later by CUDA: fix unused warning in mmq.cu (https://github.com/ggerganov/llama.cpp/pull/7442, commit https://github.com/ggerganov/llama.cpp/commit/fcf6538ba6702c55eaec70da9a75c81d04900a72).

teleprint-me pushed a commit to teleprint-me/llama.cpp that referenced this pull request May 23, 2024
Labels: Nvidia GPU, ggml, refactoring, Review Complexity : High
4 participants