decodeAFs Very slow #43

Open
carlsummer opened this issue Oct 13, 2022 · 2 comments

@carlsummer

`decodeAFs` contains many nested `for` loops, which makes it very slow. Could you write a version that runs in PyTorch?

@carlsummer
Author

Decoding took 140 ms before the change and 130 ms with the modified version below.
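Timings like these vary with warm-up and input size, so taking the median over repeated runs makes the before/after comparison more reliable. A minimal standard-library sketch; `time_per_call_ms` is a hypothetical helper, and on real data the callable would be something like `lambda: decodeAFs(BW, VAF, HAF)` applied to actual network outputs.

```python
import statistics
import time
from typing import Callable, List


def time_per_call_ms(fn: Callable[[], object], repeats: int = 20, warmup: int = 3) -> float:
    """Return the median wall-clock time of fn() in milliseconds (hypothetical helper)."""
    for _ in range(warmup):
        fn()  # run a few times first so caches and lazy initialisation don't skew the result
    samples: List[float] = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


if __name__ == "__main__":
    # Placeholder workload; swap in the decoder call being measured.
    ms = time_per_call_ms(lambda: sum(range(100_000)))
    print(f"{ms:.2f} ms per call")
```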
```python
import cv2
import numpy as np


def decodeAFs(BW, VAF, HAF, fg_thresh=128, err_thresh=5, viz=False):
    output = np.zeros_like(BW, dtype=np.uint8)  # initialize output array
    lane_end_pts = []  # keep track of latest lane points
    next_lane_id = 1  # next available lane ID

    if viz:
        im_color = cv2.applyColorMap(BW, cv2.COLORMAP_JET)
        cv2.imshow('BW', im_color)
        cv2.waitKey(0)

    # threshold the whole image once and keep only rows that contain foreground
    BW_fg_thresh_rows, BW_fg_thresh_cols = np.where(BW > fg_thresh)
    bw_row_unique = np.unique(BW_fg_thresh_rows)
    # start decoding from last row to first
    for i in range(len(bw_row_unique) - 1, -1, -1):
        row = bw_row_unique[i]
        cols = BW_fg_thresh_cols[BW_fg_thresh_rows == row]  # get fg cols
        clusters = [[]]
        prev_col = cols[0]

        # parse horizontally
        for col in cols:
            # new cluster if too far from the last point, or if HAF flips from negative to non-negative
            if col - prev_col > err_thresh or (HAF[row, prev_col] < 0 and HAF[row, col] >= 0):
                clusters.append([])
            clusters[-1].append(col)
            prev_col = col

        # parse vertically
        # assign existing lanes
        assigned = [False for _ in clusters]
        C = np.inf * np.ones((len(lane_end_pts), len(clusters)), dtype=np.float64)
        for r, pts in enumerate(lane_end_pts):  # for each end point in an active lane
            for c, cluster in enumerate(clusters):
                # mean of current cluster
                cluster_mean = np.array([[np.mean(cluster), row]], dtype=np.float32)
                # get vafs from lane end points
                vafs = np.array([VAF[int(round(x[1])), int(round(x[0])), :] for x in pts], dtype=np.float32)
                vafs = vafs / np.linalg.norm(vafs, axis=1, keepdims=True)
                # get predicted cluster center by adding vafs
                pred_points = pts + vafs * np.linalg.norm(pts - cluster_mean, axis=1, keepdims=True)
                # get error between predicted cluster center and actual cluster center
                error = np.mean(np.linalg.norm(pred_points - cluster_mean, axis=1))
                C[r, c] = error
        # assign clusters to lanes (in ascending order of error)
        row_ind, col_ind = np.unravel_index(np.argsort(C, axis=None), C.shape)
        for r, c in zip(row_ind, col_ind):
            if C[r, c] >= err_thresh:
                break
            if assigned[c]:
                continue
            assigned[c] = True
            # update best lane match with current pixel
            output[row, clusters[c]] = r + 1
            lane_end_pts[r] = np.stack((np.array(clusters[c], dtype=np.float32), row * np.ones_like(clusters[c])), axis=1)
        # initialize unassigned clusters to new lanes
        for c, cluster in enumerate(clusters):
            if not assigned[c]:
                output[row, cluster] = next_lane_id
                lane_end_pts.append(np.stack((np.array(cluster, dtype=np.float32), row * np.ones_like(cluster)), axis=1))
                next_lane_id += 1

    if viz:
        im_color = cv2.applyColorMap(40 * output, cv2.COLORMAP_JET)
        cv2.imshow('Output', im_color)
        cv2.waitKey(0)

    return output
```
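Two hot spots remain in the horizontal pass: the mask `BW_fg_thresh_rows == row`, which rescans every foreground pixel once per row, and the per-pixel `for col in cols` loop. A minimal NumPy sketch that removes both, assuming `BW` is a 2-D (H, W) array and `HAF` is (H, W) or (H, W, 1); `group_fg_cols_by_row` and `cluster_cols` are hypothetical helper names, not part of the repository. It relies on `np.nonzero` returning indices in row-major order, so each row's columns form one contiguous, ascending run.

```python
import numpy as np


def group_fg_cols_by_row(BW, fg_thresh=128):
    """Yield (row, cols) for every row with foreground pixels, bottom row first.

    np.nonzero returns indices in row-major order, so rows are sorted and each
    row's columns are one contiguous, ascending slice of `cols`.
    """
    rows, cols = np.nonzero(BW > fg_thresh)
    if rows.size == 0:
        return
    # indices where the row number changes mark the start of the next row's run
    starts = np.flatnonzero(np.diff(rows)) + 1
    row_ids = rows[np.concatenate(([0], starts))]
    col_runs = np.split(cols, starts)
    for row, row_cols in zip(row_ids[::-1], col_runs[::-1]):
        yield int(row), row_cols


def cluster_cols(cols, haf_row, err_thresh=5):
    """Split one row's sorted foreground columns into horizontal clusters.

    Same rule as the loop: break between two consecutive columns when the gap
    exceeds err_thresh or when HAF goes from negative to non-negative.
    """
    haf_row = np.asarray(haf_row).reshape(-1)  # accept (W,) or (W, 1)
    gap = np.diff(cols) > err_thresh
    flip = (haf_row[cols[:-1]] < 0) & (haf_row[cols[1:]] >= 0)
    breaks = np.flatnonzero(gap | flip) + 1
    return np.split(cols, breaks)
```

Inside `decodeAFs`, the outer loop could then read `for row, cols in group_fg_cols_by_row(BW, fg_thresh):` with `clusters = [list(c) for c in cluster_cols(cols, HAF[row], err_thresh)]`, leaving the vertical pass unchanged.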

@carlsummer
Author

Decoding is very slow. If you could port it to PyTorch, that would be perfect. @arangesh
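Most of the remaining per-row work is the nested lane × cluster loop that fills the cost matrix `C`, which also re-reads and re-normalises the same VAF vectors for every cluster. A minimal NumPy sketch of that step using broadcasting, assuming `VAF` has shape (H, W, 2) and every lane has at least one end point (true in the loop above, since clusters are never empty); `lane_cluster_costs` is a hypothetical helper, not part of the repository. The same broadcasting pattern should carry over to PyTorch tensors, with the per-lane segment sum replaced by a PyTorch equivalent such as `Tensor.index_add_`.

```python
import numpy as np


def lane_cluster_costs(lane_end_pts, clusters, row, VAF):
    """Cost matrix C[lane, cluster] computed without the nested Python loops.

    lane_end_pts: list of (P_i, 2) arrays of (x, y) end points, one per lane.
    clusters:     list of 1-D column arrays/lists for the current row.
    VAF:          (H, W, 2) vertical affinity field.
    """
    n_lanes, n_clusters = len(lane_end_pts), len(clusters)
    if n_lanes == 0 or n_clusters == 0:
        return np.full((n_lanes, n_clusters), np.inf)
    # cluster centres (mean column, current row), shape (K, 2)
    means = np.array([[np.mean(c), row] for c in clusters], dtype=np.float64)
    # stack all lanes' end points into one (N, 2) array and record where each lane starts
    counts = np.array([len(p) for p in lane_end_pts])
    offsets = np.concatenate(([0], np.cumsum(counts)[:-1]))
    pts = np.concatenate(lane_end_pts).astype(np.float64)
    # look up and normalise each point's VAF once, not once per (lane, cluster) pair
    ys = np.round(pts[:, 1]).astype(int)
    xs = np.round(pts[:, 0]).astype(int)
    vafs = VAF[ys, xs, :].astype(np.float64)
    vafs /= np.linalg.norm(vafs, axis=1, keepdims=True)
    # broadcast points (1, N, 2) against cluster centres (K, 1, 2)
    diff = pts[None, :, :] - means[:, None, :]
    pred = pts[None, :, :] + vafs[None, :, :] * np.linalg.norm(diff, axis=2, keepdims=True)
    err = np.linalg.norm(pred - means[:, None, :], axis=2)  # (K, N) per-point errors
    # mean of the per-point errors within each lane: (K, N) -> (K, n_lanes)
    per_lane = np.add.reduceat(err, offsets, axis=1) / counts
    return per_lane.T
```

In `decodeAFs` this would replace the `C = np.inf * ...` line and the two nested loops; the greedy assignment in ascending order of error stays as it is.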
