3
This commit is contained in:
163
exp1/matmult.c
Normal file
163
exp1/matmult.c
Normal file
@@ -0,0 +1,163 @@
|
||||
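/*
 * Task 5 (optional): write a parallel, multithreaded version of an N x N
 * matrix multiplication function and save the program as matmult.c.
 * Design a scheme to verify the correctness of the parallel program.
 * Compile, debug, and run the program. With 2^k <= N < 2^(k+1), report the
 * running time for 1, 2, 4, ..., 2^(k-1), 2^k threads, compute the speedup
 * and efficiency, and explain the results.
 */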
|
||||
|
||||
#include <errno.h>
#include <limits.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
|
||||
|
||||
double get_time()
|
||||
{
|
||||
struct timeval tv;
|
||||
gettimeofday(&tv, NULL);
|
||||
return tv.tv_sec + tv.tv_usec / 1000000.0;
|
||||
}
|
||||
|
||||
/*
 * Per-thread work description handed to mul_thread().
 *
 * Each worker computes rows [start_row, end_row) of C = A * B, where all
 * three matrices are N x N and stored contiguously in row-major order.
 */
typedef struct {
    int id;           /* index of the worker thread, assigned by the creator */
    int start_row;    /* first row of C this worker computes (inclusive) */
    int end_row;      /* one past the last row this worker computes */
    int N;            /* matrix dimension */
    const double *A;  /* left operand, read-only */
    const double *B;  /* right operand, read-only */
    double *C;        /* result matrix; each worker writes only its own rows */
} thread_arg_t;
|
||||
|
||||
void *mul_thread(void *arg)
|
||||
{
|
||||
thread_arg_t *ta = (thread_arg_t *)arg;
|
||||
for (int i = ta->start_row; i < ta->end_row; i++)
|
||||
{
|
||||
for (int j = 0; j < ta->N; j++)
|
||||
{
|
||||
double sum = 0.0;
|
||||
for (int k = 0; k < ta->N; k++)
|
||||
{
|
||||
sum += ta->A[i * ta->N + k] * ta->B[k * ta->N + j];
|
||||
}
|
||||
ta->C[i * ta->N + j] = sum;
|
||||
}
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/*
 * Serial reference implementation, used to verify the parallel result.
 *
 * Computes C = A * B for row-major N x N matrices using the same i-j-k loop
 * order as the threaded version. A and B are only read; every element of C
 * is overwritten. Does nothing when N <= 0.
 */
void matmul_serial(const double *A, const double *B, double *C, int N)
{
    if (N <= 0)
    {
        return;
    }

    /* Index arithmetic is done in size_t: i * N overflows int for large N. */
    const size_t n = (size_t)N;

    for (size_t i = 0; i < n; i++)
    {
        const double *a_row = A + i * n;
        double *c_row = C + i * n;

        for (size_t j = 0; j < n; j++)
        {
            double sum = 0.0;

            for (size_t k = 0; k < n; k++)
            {
                sum += a_row[k] * B[k * n + j];
            }
            c_row[j] = sum;
        }
    }
}
|
||||
|
||||
/*
 * Compare two N x N matrices element by element.
 *
 * Returns 1 when every pair of elements is equal or differs by at most
 * 1e-6 in absolute value, 0 otherwise. On the first mismatch the flat index
 * and both values are printed to stdout. A NaN in either matrix counts as a
 * mismatch. Returns 1 when N <= 0 (nothing to compare).
 */
int verify(const double *C1, const double *C2, int N)
{
    const double tolerance = 1e-6;

    if (N <= 0)
    {
        return 1;
    }

    /* Element count in size_t: N * N overflows int for large N. */
    const size_t count = (size_t)N * (size_t)N;

    for (size_t i = 0; i < count; i++)
    {
        const double diff = C1[i] - C2[i];

        /*
         * Exact equality is accepted first so equal infinities match.
         * The negated range test is false-safe for NaN: any comparison
         * involving NaN is false, so NaN falls into the mismatch branch.
         */
        if (C1[i] != C2[i] && !(diff <= tolerance && diff >= -tolerance))
        {
            printf("Mismatch at index %zu: %f vs %f\n", i, C1[i], C2[i]);
            return 0;
        }
    }
    return 1;
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
int N, nthreads;
|
||||
|
||||
if (argc != 3)
|
||||
{
|
||||
printf("Usage: %s <N> <nthreads>\n", argv[0]);
|
||||
exit(0);
|
||||
}
|
||||
N = atoi(argv[1]);
|
||||
nthreads = atoi(argv[2]);
|
||||
|
||||
double *A = (double *)malloc(N * N * sizeof(double));
|
||||
double *B = (double *)malloc(N * N * sizeof(double));
|
||||
double *C = (double *)malloc(N * N * sizeof(double));
|
||||
double *C_serial = (double *)malloc(N * N * sizeof(double));
|
||||
|
||||
srand(42);
|
||||
for (int i = 0; i < N * N; i++)
|
||||
{
|
||||
A[i] = (double)(rand() % 100) / 10.0;
|
||||
B[i] = (double)(rand() % 100) / 10.0;
|
||||
}
|
||||
|
||||
/* 并行版本计时 */
|
||||
pthread_t *tid = (pthread_t *)malloc(nthreads * sizeof(pthread_t));
|
||||
thread_arg_t *targs = (thread_arg_t *)malloc(nthreads * sizeof(thread_arg_t));
|
||||
|
||||
double t_start = get_time();
|
||||
int rows_per_thread = N / nthreads;
|
||||
int extra = N % nthreads;
|
||||
int current_row = 0;
|
||||
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
targs[i].id = i;
|
||||
targs[i].start_row = current_row;
|
||||
targs[i].end_row = current_row + rows_per_thread + (i < extra ? 1 : 0);
|
||||
targs[i].N = N;
|
||||
targs[i].A = A;
|
||||
targs[i].B = B;
|
||||
targs[i].C = C;
|
||||
current_row = targs[i].end_row;
|
||||
pthread_create(&tid[i], NULL, mul_thread, &targs[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
pthread_join(tid[i], NULL);
|
||||
}
|
||||
double t_end = get_time();
|
||||
double t_parallel = t_end - t_start;
|
||||
|
||||
/* 串行版本验证 */
|
||||
double t_ser_start = get_time();
|
||||
matmul_serial(A, B, C_serial, N);
|
||||
double t_ser_end = get_time();
|
||||
double t_serial = t_ser_end - t_ser_start;
|
||||
|
||||
printf("Matrix size: %d x %d, Threads: %d\n", N, N, nthreads);
|
||||
printf("Serial time: %f seconds\n", t_serial);
|
||||
printf("Parallel time: %f seconds\n", t_parallel);
|
||||
|
||||
if (verify(C, C_serial, N))
|
||||
{
|
||||
printf("Verification: PASSED - parallel result matches serial.\n");
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Verification: FAILED.\n");
|
||||
}
|
||||
|
||||
if (t_parallel > 0)
|
||||
{
|
||||
printf("Speedup: %f\n", t_serial / t_parallel);
|
||||
printf("Efficiency: %f\n", t_serial / (nthreads * t_parallel));
|
||||
}
|
||||
|
||||
free(A);
|
||||
free(B);
|
||||
free(C);
|
||||
free(C_serial);
|
||||
free(tid);
|
||||
free(targs);
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user