Files
C-exp-collection/exp1/matmult.c
2026-06-09 06:43:13 +02:00

164 lines
4.0 KiB
C
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
任务5(选做): 编写一个N×N矩阵乘法函数的并行线程化版本, 程序保存为matmult.c;
设计一种方案, 验证并行程序正确性。
编译、调试、运行程序, 设2^k ≤ N < 2^(k+1),
给出线程数为1、2、4、…、2^(k-1)、2^k等情况下的运行时间, 计算加速比与效率, 并对结果给出解释。
*/
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <pthread.h>
#include <sys/time.h>
/*
 * Current wall-clock time in seconds (microsecond resolution) taken from
 * gettimeofday(). Only the difference between two calls is meaningful here.
 * NOTE(review): gettimeofday follows system clock adjustments (NTP, manual
 * changes); clock_gettime(CLOCK_MONOTONIC) would be immune to them if the
 * build enables POSIX.1b interfaces.
 */
double get_time(void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec / 1e6;
}
/*
 * Work description handed to each mul_thread: compute rows
 * [start_row, end_row) of C = A * B, where all matrices are N x N and
 * stored row-major (element (r, c) at index r * N + c).
 */
typedef struct {
    int id;          /* thread index; set by the creator, not read by mul_thread */
    int start_row;   /* first row of C to compute (inclusive) */
    int end_row;     /* one past the last row of C to compute */
    int N;           /* matrix dimension */
    const double *A; /* left operand, only read */
    const double *B; /* right operand, only read */
    double *C;       /* result matrix; a thread writes only its own rows */
} thread_arg_t;
/*
 * pthread entry point: computes rows [start_row, end_row) of C = A * B for
 * the N x N row-major matrices described by the thread_arg_t behind arg.
 * It writes only those rows of C, so threads given non-overlapping row
 * ranges need no locking. Does nothing for N <= 0 or an empty/negative
 * row range. Always returns NULL.
 */
void *mul_thread(void *arg)
{
    const thread_arg_t *ta = arg;

    if (ta->N <= 0 || ta->start_row < 0 || ta->end_row <= ta->start_row)
    {
        return NULL;
    }

    /* size_t indices: i * N + k would overflow a 32-bit int once N > 46340. */
    const size_t n = (size_t)ta->N;
    const size_t row_begin = (size_t)ta->start_row;
    const size_t row_end = (size_t)ta->end_row;
    const double *A = ta->A;
    const double *B = ta->B;
    double *C = ta->C;

    for (size_t i = row_begin; i < row_end; i++)
    {
        for (size_t j = 0; j < n; j++)
        {
            double sum = 0.0;
            for (size_t k = 0; k < n; k++)
            {
                sum += A[i * n + k] * B[k * n + j];
            }
            C[i * n + j] = sum;
        }
    }
    return NULL;
}
/*
 * Single-threaded reference multiplication C = A * B for N x N row-major
 * matrices. Serves both as the correctness oracle and as the timing
 * baseline for the speedup figure; it uses the same i-j-k loop and
 * k-summation order as the threaded kernel so the comparison is fair.
 * Does nothing when N <= 0.
 */
void matmul_serial(const double *A, const double *B, double *C, int N)
{
    if (N <= 0)
    {
        return;
    }

    /* size_t indices: i * N + k would overflow a 32-bit int once N > 46340. */
    const size_t n = (size_t)N;
    for (size_t i = 0; i < n; i++)
    {
        for (size_t j = 0; j < n; j++)
        {
            double sum = 0.0;
            for (size_t k = 0; k < n; k++)
            {
                sum += A[i * n + k] * B[k * n + j];
            }
            C[i * n + j] = sum;
        }
    }
}
/*
 * Compare two N x N matrices element by element.
 * Returns 1 when every pair differs by at most 1e-6 (absolute), otherwise
 * prints the first mismatching index with both values and returns 0.
 * A NaN in either matrix counts as a mismatch. N <= 0 means there is
 * nothing to compare and returns 1.
 */
int verify(const double *C1, const double *C2, int N)
{
    const double tol = 1e-6;

    if (N <= 0)
    {
        return 1;
    }

    const size_t count = (size_t)N * (size_t)N;
    for (size_t i = 0; i < count; i++)
    {
        const double diff = C1[i] - C2[i];
        /* Negated in-range test: every comparison with NaN is false, so a
         * NaN difference lands in the mismatch branch instead of passing. */
        if (!(diff <= tol && diff >= -tol))
        {
            printf("Mismatch at index %zu: %f vs %f\n", i, C1[i], C2[i]);
            return 0;
        }
    }
    return 1;
}
/*
 * Parse s as a strictly positive decimal int.
 * Returns -1 if s is empty, has trailing characters, is <= 0, or does not
 * fit in an int.
 */
static int parse_positive_int(const char *s)
{
    char *end = NULL;
    long v = strtol(s, &end, 10);
    if (end == s || *end != '\0' || v <= 0 || v > INT_MAX)
    {
        return -1;
    }
    return (int)v;
}

/*
 * Compute C = A * B with nthreads threads running mul_thread. Each thread
 * gets a contiguous band of rows; the first N % nthreads threads get one
 * extra row. *elapsed receives the wall time spent creating and joining the
 * threads (bookkeeping allocation is excluded).
 * Returns 0 on success, -1 if the bookkeeping arrays cannot be allocated,
 * or the error number from a failed pthread_create. Threads that did start
 * are always joined before returning.
 */
static int matmul_parallel(double *A, double *B, double *C, int N, int nthreads,
                           double *elapsed)
{
    pthread_t *tid = malloc((size_t)nthreads * sizeof *tid);
    thread_arg_t *targs = malloc((size_t)nthreads * sizeof *targs);
    int err = 0;
    int started = 0;

    if (tid == NULL || targs == NULL)
    {
        free(tid);
        free(targs);
        return -1;
    }

    double t_start = get_time();
    int rows_per_thread = N / nthreads;
    int extra = N % nthreads;
    int current_row = 0;
    for (int i = 0; i < nthreads; i++)
    {
        targs[i].id = i;
        targs[i].start_row = current_row;
        targs[i].end_row = current_row + rows_per_thread + (i < extra ? 1 : 0);
        targs[i].N = N;
        targs[i].A = A;
        targs[i].B = B;
        targs[i].C = C;
        current_row = targs[i].end_row;
        err = pthread_create(&tid[i], NULL, mul_thread, &targs[i]);
        if (err != 0)
        {
            break; /* tid[i] is not valid; only join the ones before it */
        }
        started++;
    }
    for (int i = 0; i < started; i++)
    {
        pthread_join(tid[i], NULL);
    }
    *elapsed = get_time() - t_start;

    free(tid);
    free(targs);
    return err;
}

/*
 * Usage: matmult <N> <nthreads>
 * Multiplies two pseudo-random N x N matrices with nthreads threads and
 * again serially. Prints both times and the verification result, plus
 * speedup and efficiency when the parallel time is positive.
 * Exits with EXIT_FAILURE on bad arguments, allocation or thread-creation
 * failure, or when verification fails.
 */
int main(int argc, char **argv)
{
    if (argc != 3)
    {
        fprintf(stderr, "Usage: %s <N> <nthreads>\n", argv[0]);
        return EXIT_FAILURE;
    }

    const int N = parse_positive_int(argv[1]);
    const int nthreads = parse_positive_int(argv[2]);
    if (N < 0 || nthreads < 0)
    {
        fprintf(stderr, "Error: N and nthreads must be positive integers.\n");
        return EXIT_FAILURE;
    }

    /* The element count N * N must fit in size_t ((size_t)-1 is SIZE_MAX);
     * calloc then checks count * sizeof(double) itself. */
    const size_t n = (size_t)N;
    if (n > (size_t)-1 / n)
    {
        fprintf(stderr, "Error: N is too large.\n");
        return EXIT_FAILURE;
    }
    const size_t count = n * n;

    int rc = EXIT_FAILURE;
    double *A = calloc(count, sizeof *A);
    double *B = calloc(count, sizeof *B);
    double *C = calloc(count, sizeof *C);
    double *C_serial = calloc(count, sizeof *C_serial);
    if (A == NULL || B == NULL || C == NULL || C_serial == NULL)
    {
        fprintf(stderr, "Error: out of memory for %d x %d matrices.\n", N, N);
        goto out;
    }

    /* Fixed seed so every run, whatever the thread count, multiplies the
     * same matrices. */
    srand(42);
    for (size_t i = 0; i < count; i++)
    {
        A[i] = (double)(rand() % 100) / 10.0;
        B[i] = (double)(rand() % 100) / 10.0;
    }

    /* Timed threaded run. */
    double t_parallel = 0.0;
    int err = matmul_parallel(A, B, C, N, nthreads, &t_parallel);
    if (err != 0)
    {
        fprintf(stderr, "Error: parallel multiplication failed (code %d).\n", err);
        goto out;
    }

    /* Timed serial run: provides both the speedup baseline and the
     * reference result for verification. */
    double t_ser_start = get_time();
    matmul_serial(A, B, C_serial, N);
    double t_serial = get_time() - t_ser_start;

    printf("Matrix size: %d x %d, Threads: %d\n", N, N, nthreads);
    printf("Serial time: %f seconds\n", t_serial);
    printf("Parallel time: %f seconds\n", t_parallel);
    if (verify(C, C_serial, N))
    {
        printf("Verification: PASSED - parallel result matches serial.\n");
        rc = EXIT_SUCCESS;
    }
    else
    {
        /* rc stays EXIT_FAILURE so benchmark scripts can detect it. */
        printf("Verification: FAILED.\n");
    }
    if (t_parallel > 0)
    {
        printf("Speedup: %f\n", t_serial / t_parallel);
        printf("Efficiency: %f\n", t_serial / (nthreads * t_parallel));
    }

out:
    free(A);
    free(B);
    free(C);
    free(C_serial);
    return rc;
}