Skip to content

Commit

Permalink
remove __half (it was causing weird memory deallocation errors) remov…
Browse files Browse the repository at this point in the history
…e Pythonkit and bring swift-testing with few more fixes
  • Loading branch information
machineko committed Jul 11, 2024
1 parent 85e710b commit 071392f
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 133 deletions.
6 changes: 2 additions & 4 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ let package = Package(
dependencies:
[
.package(url: "https://github.com/machineko/SwiftCU", branch: "main"),
.package(url: "https://github.com/pvieito/PythonKit.git", branch: "master"),
.package(url: "https://github.com/apple/swift-docc-plugin", from: "1.3.0"),

.package(url: "https://github.com/apple/swift-testing.git", from: "0.10.0"),
],
targets: [
.target(
Expand Down Expand Up @@ -48,7 +46,7 @@ let package = Package(

dependencies: [
"SwiftCU", "cxxCUBLAS", "SwiftCUBLAS",
.product(name: "PythonKit", package: "PythonKit")
.product(name: "Testing", package: "swift-testing"),
],
swiftSettings: [
.interoperabilityMode(.Cxx),
Expand Down
42 changes: 16 additions & 26 deletions Sources/SwiftCUBLAS/CUBLASUtils/CUBlasUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ extension CUBLASParamsMixed {
return CUDA_R_8I
case is Int32.Type:
return CUDA_R_32I
case is __half.Type:
case is Float16.Type:
return CUDA_R_16F
default:
fatalError("Unsupported CUBLAS data type")
fatalError("\(inputType.self) not supported")
}
}

Expand All @@ -49,14 +49,14 @@ extension CUBLASParamsMixed {
return CUDA_R_32F
case is Double.Type:
return CUDA_R_64F
case is UInt8.Type:
return CUDA_R_8U
case is Int8.Type:
return CUDA_R_8I
case is Int32.Type:
return CUDA_R_32I
case is __half.Type:
case is Float16.Type:
return CUDA_R_16F
default:
fatalError("Unsupported CUBLAS data type")
fatalError("\(inputType.self) not supported")
}
}
}
Expand Down Expand Up @@ -223,24 +223,11 @@ extension CUBLASHandle {
let status = cublasSgemm_v2(
self.handle, transposeA.ascublas, transposeB.ascublas, params.m, params.n,
params.k, &params.alpha, params.A, params.lda, params.B, params.ldb, &params.beta, params.C, params.ldc
)
return status.asSwift
}

/// Performs half-precision general matrix multiplication (HGEMM) using CUBLAS.
/// - Parameters:
/// - transposeA: Specifies whether to transpose matrix A.
/// - transposeB: Specifies whether to transpose matrix B.
/// - params: The parameters for the HGEMM operation.
/// - Returns: The status of the HGEMM operation.
public func hgemm(
transposeA: cublasOperation = .cublas_op_n, transposeB: cublasOperation = .cublas_op_n, params: inout CUBLASParams<__half>
) -> cublasStatus {
let status = cublasHgemm(
self.handle, transposeA.ascublas, transposeB.ascublas, params.m, params.n,
params.k, &params.alpha, params.A, params.lda, params.B, params.ldb, &params.beta, params.C, params.ldc
)
return status.asSwift
).asSwift
#if safetyCheck
status.safetyCheckCondition(message: "Can't run sgemm cublasSgemm_v2 function \(status)")
#endif
return status
}

/// Performs mixed-precision general matrix multiplication (GEMM) using CUBLAS.
Expand Down Expand Up @@ -269,7 +256,10 @@ extension CUBLASHandle {
params.C, params.outputCUDAType, params.ldc,
computeType.ascublas,
cublasGemmAlgo.ascublas
)
return status.asSwift
).asSwift
#if safetyCheck
status.safetyCheckCondition(message: "Can't run cublasGemmEx function \(status)")
#endif
return status
}
}
1 change: 0 additions & 1 deletion Sources/SwiftCUBLAS/SwiftCUBLAS.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ extension Float: CUBLASDataType {}
extension Double: CUBLASDataType {}
extension Int8: CUBLASDataType {}
extension Int32: CUBLASDataType {}
extension __half: CUBLASDataType {}

/// A structure that manages a CUBLAS handle.
public struct CUBLASHandle: ~Copyable {
Expand Down
1 change: 0 additions & 1 deletion Sources/cxxCUBLAS/include/cublas_head.hpp
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
#include <cublas_v2.h>
#include <cuda_fp16.h>
Loading

0 comments on commit 071392f

Please sign in to comment.