Web crawler in Rust

I have heard many good things about Rust for several years now. A couple of months ago, I finally decided to start learning it. I skimmed through the Book and did the exercises from rustlings. While they helped me get started, I learn best by building projects. So I decided to replace the crawler I used for my Ghost blog, a bash script built around wget, with something written in Rust.

And I was pleasantly surprised. I am by no means very knowledgeable in Rust: I still have to look up most of the operations on the Option and Result types, and I have to DuckDuckGo how to make HTTP requests, read and write files and so on. Still, I was able to write a minimal crawler in about 2-3 hours, and after about 10 hours of total work I had something that was both faster and had fewer bugs than the wget script.

So let's start writing a simple crawler that downloads all the HTML pages from a blog.

Initializing a Rust project

After installing Rust, let's create a project somewhere:

> cargo new rust_crawler

This initializes a Hello World program, which we can run to verify that everything works:

> cargo run
   Compiling rust_crawler v0.1.0 (D:\Programming\rust_crawler)
    Finished dev [unoptimized + debuginfo] target(s) in 9.31s
     Running `target\debug\rust_crawler.exe`
Hello, world!

Making HTTP requests

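Let's make our first HTTP request. For this, we will use the reqwest library. It has both blocking and asynchronous APIs for making HTTP calls. We'll start off with the blocking API, because it's easier. The blocking API sits behind reqwest's blocking Cargo feature, so that feature has to be enabled for the reqwest dependency in Cargo.toml.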

use std::io::Read;

fn main() {
    let client = reqwest::blocking::Client::new();
    let origin_url = "https://ghost.rolisz.ro/";
    let mut res = client.get(origin_url).send().unwrap();
    println!("Status for {}: {}", origin_url, res.status());

    let mut body  = String::new();
    res.read_to_string(&mut body).unwrap();
    println!("HTML: {}", &body[0..40]);
}
> cargo run
   Compiling rust_crawler v0.1.0 (D:\Programming\rust_crawler)
    Finished dev [unoptimized + debuginfo] target(s) in 2.30s
     Running `target\debug\rust_crawler.exe`
Status for https://ghost.rolisz.ro/: 200 OK
HTML: <!DOCTYPE html>
<html lang="en">
<head>

We create a new reqwest blocking client, build a GET request and send it. send returns a Result, which we simply unwrap for now. We print the status code to make sure the request succeeded. Then we read the response body into a mutable String and print its first 40 bytes. So far so good.
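&body[0..40] slices by bytes. It panics if the page is shorter than 40 bytes, or if byte 40 falls in the middle of a multi-byte UTF-8 character. For a quick preview that can't panic, one option is to cut at a character boundary. Here is a minimal sketch with a hypothetical preview helper, which is not part of the crawler:

```rust
/// Returns at most `max_chars` characters from the start of `s`,
/// never splitting a multi-byte UTF-8 character.
fn preview(s: &str, max_chars: usize) -> &str {
    match s.char_indices().nth(max_chars) {
        // `idx` is the byte offset of the first character we want to drop
        Some((idx, _)) => &s[..idx],
        // The string has `max_chars` characters or fewer: return all of it
        None => s,
    }
}

fn main() {
    let body = "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n<title>Über</title>";
    println!("HTML: {}", preview(body, 40));
}
```

Another option is body.get(..40), which returns None instead of panicking when the cut would be out of bounds or inside a character.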

Now let's parse the HTML and extract all the links we find. For this we will use the select crate, which can parse HTML and allows us to search through the nodes.

use std::io::Read;
use select::document::Document;
use select::predicate::Name;

fn main() {
    let client = reqwest::blocking::Client::new();
    let origin_url = "https://ghost.rolisz.ro/";
    let mut res = client.get(origin_url).send().unwrap();
    println!("Status for {}: {}", origin_url, res.status());

    let mut body  = String::new();
    res.read_to_string(&mut body).unwrap();
   
    Document::from(body.as_str())
        .find(Name("a"))
        .filter_map(|n| n.attr("href"))
        .for_each(|x| println!("{}", x));
}
> cargo run --color=always --package rust_crawler --bin rust_crawler
   Compiling rust_crawler v0.1.0 (D:\Programming\rust_crawler)
    Finished dev [unoptimized + debuginfo] target(s) in 2.65s
     Running `target\debug\rust_crawler.exe`
Status for https://ghost.rolisz.ro/: 200 OK
https://ghost.rolisz.ro
https://ghost.rolisz.ro
https://ghost.rolisz.ro/projects/
https://ghost.rolisz.ro/about-me/
https://ghost.rolisz.ro/uses/
https://ghost.rolisz.ro/tag/trips/
https://ghost.rolisz.ro/tag/reviews/
#subscribe
/2020/02/13/lost-in-space/
/2020/02/13/lost-in-space/
/author/rolisz/
/author/rolisz/
...
/2020/02/07/interview-about-wfh/
/2020/02/07/interview-about-wfh/
/2019/01/30/nas-outage-1/
/2019/01/30/nas-outage-1/
/author/rolisz/
/author/rolisz/
https://ghost.rolisz.ro
https://ghost.rolisz.ro
https://www.facebook.com/rolisz
https://twitter.com/rolisz
https://ghost.org
javascript:;
#

We search for all the anchor tags and keep only those that have an href attribute (attr returns an Option, and filter_map drops the None values). Then we print the values of those attributes.

We see all the links in the output, but there are some issues. First, some of the links are absolute, some are relative, and some are pseudo-links used to trigger JavaScript (javascript:; and #). Second, the links that point to posts are duplicated. Third, some links point to things outside my blog.

The duplicate problem is easy to fix: we put everything into a HashSet and then we'll get only a unique collection of URLs.

use std::io::Read;
use select::document::Document;
use select::predicate::Name;
use std::collections::HashSet;

fn main() {
    let client = reqwest::blocking::Client::new();
    let origin_url = "https://ghost.rolisz.ro/";
    let mut res = client.get(origin_url).send().unwrap();
    println!("Status for {}: {}", origin_url, res.status());

    let mut body  = String::new();
    res.read_to_string(&mut body).unwrap();

    let found_urls = Document::from(body.as_str())
        .find(Name("a"))
        .filter_map(|n| n.attr("href"))
        .map(str::to_string)
        .collect::<HashSet<String>>();
    println!("URLs: {:#?}", found_urls)
}

First we convert each URL from &str to String, so the set owns its strings instead of borrowing them from the parsed HTML. Then we insert all the strings into a HashSet using collect. collect can build any collection that implements FromIterator. Inserting a string that is already in the set has no effect, so the duplicates disappear.
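The deduplication step can be tried on its own with a few of the hrefs from the output above. This is a minimal sketch; unique_links is a hypothetical helper that wraps the same to_string plus collect chain:

```rust
use std::collections::HashSet;

/// Deduplicates hrefs like the crawler does: convert each &str to an
/// owned String, then let `collect` build a HashSet.
fn unique_links(hrefs: &[&str]) -> HashSet<String> {
    hrefs
        .iter()
        .copied() // &&str -> &str
        .map(str::to_string)
        .collect::<HashSet<String>>()
}

fn main() {
    // A few hrefs from the output above, duplicates included
    let hrefs = [
        "/2020/02/13/lost-in-space/",
        "/2020/02/13/lost-in-space/",
        "/author/rolisz/",
        "/author/rolisz/",
    ];
    let unique = unique_links(&hrefs);
    println!("{} hrefs, {} unique", hrefs.len(), unique.len()); // 4 hrefs, 2 unique
}
```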

To solve the other two problems we have to parse the URLs, using the Url type that reqwest re-exports from the url crate.

use std::io::Read;
use select::document::Document;
use select::predicate::{Name, Predicate}; // Predicate provides .or()
use std::collections::HashSet;
use reqwest::Url;

fn get_links_from_html(html: &str) -> HashSet<String> {
    // html is already a &str, so it can be passed to Document::from directly
    Document::from(html)
        .find(Name("a").or(Name("link")))
        .filter_map(|n| n.attr("href"))
        .filter_map(normalize_url)
        .collect::<HashSet<String>>()
}

fn normalize_url(url: &str) -> Option<String> {
    let new_url = Url::parse(url);
    match new_url {
        Ok(new_url) => {
            if new_url.has_host() && new_url.host_str().unwrap() == "ghost.rolisz.ro" {
                Some(url.to_string())
            } else {
                None
            }
        },
        Err(_e) => {
            // Relative urls are not parsed by Reqwest
            if url.starts_with('/') {
                Some(format!("https://ghost.rolisz.ro{}", url))
            } else {
                None
            }
        }
    }
}

fn main() {
    let client = reqwest::blocking::Client::new();
    let origin_url = "https://ghost.rolisz.ro/";
    let mut res = client.get(origin_url).send().unwrap();
    println!("Status for {}: {}", origin_url, res.status());

    let mut body = String::new();
    res.read_to_string(&mut body).unwrap();

    let found_urls = get_links_from_html(&body);
    println!("URLs: {:#?}", found_urls)
}

We moved the link extraction into a function, get_links_from_html, which now also looks at link tags besides anchors. We apply another filter_map to the links we find, which tries to parse each URL. If parsing succeeds, we keep the URL only if it has a host and that host is my blog. If parsing fails, we check whether the URL starts with a /. In that case it's a relative URL, and we prepend the blog's address. All other URLs are rejected.

Now it's time to follow these links so that we crawl the whole blog. We'll do a breadth-first traversal and keep track of the URLs we have already visited.

use std::time::Instant; // needed for timing the crawl

fn fetch_url(client: &reqwest::blocking::Client, url: &str) -> String {
    let mut res = client.get(url).send().unwrap();
    println!("Status for {}: {}", url, res.status());

    let mut body = String::new();
    res.read_to_string(&mut body).unwrap();
    body
}

fn main() {
    let now = Instant::now();

    let client = reqwest::blocking::Client::new();
    let origin_url = "https://ghost.rolisz.ro/";

    let body = fetch_url(&client, origin_url);

    let mut visited = HashSet::new();
    visited.insert(origin_url.to_string());
    let found_urls = get_links_from_html(&body);
    let mut new_urls = found_urls
        .difference(&visited)
        .map(|x| x.to_string())
        .collect::<HashSet<String>>();

    while !new_urls.is_empty() {
        // Fetch every page of the current level and merge all links found
        let found_urls: HashSet<String> = new_urls
            .iter()
            .map(|url| {
                let body = fetch_url(&client, url);
                let links = get_links_from_html(&body);
                println!("Visited: {} found {} links", url, links.len());
                links
            })
            .fold(HashSet::new(), |mut acc, x| {
                acc.extend(x);
                acc
            });
        visited.extend(new_urls);

        // Only links we have never visited form the next level
        new_urls = found_urls
            .difference(&visited)
            .map(|x| x.to_string())
            .collect::<HashSet<String>>();
        println!("New urls: {}", new_urls.len())
    }
    // The loop's found_urls is out of scope here, so print every visited URL
    println!("URLs: {:#?}", visited);
    println!("{}", now.elapsed().as_secs());
}

First, we moved the code to fetch a URL to its own function, because we will be using it in two places.

Then the idea is that we have a HashSet containing all the pages we have visited so far. When we visit a new page, we find all the links in that page and we subtract from them all the links that we have previously visited. These will be new links that we will have to visit. We repeat this as long as we have new links to visit.
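The traversal logic can be tried without any network access. The sketch below keeps the same bookkeeping: a visited set, a set of new URLs, and difference to find links not seen before. Fetching is replaced by a hypothetical in-memory link graph, a HashMap from a page to the links on it. crawl and the graph are illustrative only:

```rust
use std::collections::{HashMap, HashSet};

/// Breadth-first crawl over an in-memory link graph, a hypothetical stand-in
/// for fetching pages. Returns every URL reachable from `origin`.
fn crawl(graph: &HashMap<&str, Vec<&str>>, origin: &str) -> HashSet<String> {
    // "Fetch" a page: look up its links, or none if the page is unknown
    let links_of = |url: &str| -> HashSet<String> {
        match graph.get(url) {
            Some(links) => links.iter().map(|l| l.to_string()).collect(),
            None => HashSet::new(),
        }
    };

    let mut visited = HashSet::new();
    visited.insert(origin.to_string());
    let mut new_urls: HashSet<String> =
        links_of(origin).difference(&visited).cloned().collect();

    while !new_urls.is_empty() {
        // Visit the whole frontier (one BFS level) and merge every link found
        let found: HashSet<String> = new_urls
            .iter()
            .map(|url| links_of(url.as_str()))
            .fold(HashSet::new(), |mut acc, x| {
                acc.extend(x);
                acc
            });
        visited.extend(new_urls);
        // Links we have never visited form the next level
        new_urls = found.difference(&visited).cloned().collect();
    }
    visited
}

fn main() {
    let mut graph = HashMap::new();
    graph.insert("/", vec!["/a/", "/b/"]);
    graph.insert("/a/", vec!["/", "/c/"]);
    graph.insert("/b/", vec!["/a/"]);
    let visited = crawl(&graph, "/");
    println!("Visited {} pages: {:?}", visited.len(), visited);
}
```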

So we run this and we get the following output:

Status for https://ghost.rolisz.ro/favicon.ico: 200 OK
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Custom { kind: InvalidData, error: "stream did not contain valid UTF-8" }', src\libcore\result.rs:1165:5
stack backtrace:
   0: core::fmt::write
             at /rustc/73528e339aae0f17a15ffa49a8ac608f50c6cf14\/src\libcore\fmt\mod.rs:1028
   1: std::io::Write::write_fmt<std::sys::windows::stdio::Stderr>
             at /rustc/73528e339aae0f17a15ffa49a8ac608f50c6cf14\/src\libstd\io\mod.rs:1412
   2: std::sys_common::backtrace::_print
             at /rustc/73528e339aae0f17a15ffa49a8ac608f50c6cf14\/src\libstd\sys_common\backtrace.rs:65
   3: std::sys_common::backtrace::print
             at /rustc/73528e339aae0f17a15ffa49a8ac608f50c6cf14\/src\libstd\sys_common\backtrace.rs:50
...

The problem is that our crawler also tries to download pictures and other binaries as text. A Rust String must always contain valid UTF-8. So read_to_string returns an error when the body isn't valid UTF-8, and our unwrap turns that error into a panic. We could solve this in two ways. Either we download URLs as bytes and convert to strings only the ones we know are HTML, or we skip the ones that are not HTML. Because I am interested only in the textual content of my blog, I will implement the latter solution.
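For comparison, the first option reads the body as raw bytes (for example with Read::read_to_end on the response) and only then decides whether it is text. This minimal sketch leaves out the HTTP part and shows only the conversion step; body_as_text is a hypothetical helper:

```rust
/// Turns a downloaded body into text only if it is valid UTF-8,
/// returning None for binaries such as images.
fn body_as_text(bytes: Vec<u8>) -> Option<String> {
    // String::from_utf8 validates the bytes and returns an Err instead of panicking
    String::from_utf8(bytes).ok()
}

fn main() {
    let html = b"<!DOCTYPE html>".to_vec();
    // The first four bytes of every PNG file; 0x89 cannot start a UTF-8 character
    let png = vec![0x89, b'P', b'N', b'G'];
    println!("{:?}", body_as_text(html)); // Some("<!DOCTYPE html>")
    println!("{:?}", body_as_text(png)); // None
}
```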

use std::path::Path;

// True when the last path segment has no file extension
// (e.g. /2020/02/13/lost-in-space/), which on this blog means an HTML page
fn has_no_extension(url: &&str) -> bool {
    Path::new(url).extension().is_none()
}

fn get_links_from_html(html: &str) -> HashSet<String> {
    Document::from(html)
        .find(Name("a").or(Name("link")))
        .filter_map(|n| n.attr("href"))
        .filter(has_no_extension)
        .filter_map(normalize_url)
        .collect::<HashSet<String>>()
}

To decide whether a link points to an HTML page, we check whether its last path segment has a file extension. We use that as a filter in the function that extracts links from the HTML.
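A quick way to see which links from the earlier output pass the filter is to run a standalone copy of it. Here it is named has_no_extension, since it returns true when there is no extension. It takes &str instead of &&str so it can be called directly:

```rust
use std::path::Path;

/// True when the last path segment of `url` has no file extension.
fn has_no_extension(url: &str) -> bool {
    Path::new(url).extension().is_none()
}

fn main() {
    // A few links from the earlier output, plus the favicon that caused the panic
    let links = [
        "/2020/02/13/lost-in-space/",
        "https://ghost.rolisz.ro/favicon.ico",
        "https://ghost.rolisz.ro",
        "#subscribe",
        "javascript:;",
    ];
    for link in links.iter() {
        println!("{:<40} keep: {}", link, has_no_extension(link));
    }
}
```

Path treats the bare host in https://ghost.rolisz.ro as a file name with the extension ro, so homepage links without a path are filtered out too. That is harmless here, because the homepage is the starting URL. Pseudo-links like #subscribe and javascript:; pass this filter and are rejected afterwards by normalize_url.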

Writing the HTML to disk

We are now getting all the HTML we want, time to start writing it to disk.

use std::fs; // new import for this step

fn write_file(path: &str, content: &str) {
    // Like mkdir -p: creates static<path> and any missing parent folders
    fs::create_dir_all(format!("static{}", path)).unwrap();
    fs::write(format!("static{}/index.html", path), content).unwrap();
}

fn main() {
    let now = Instant::now();

    let client = reqwest::blocking::Client::new();
    let origin_url = "https://ghost.rolisz.ro/";

    let body = fetch_url(&client, origin_url);

    write_file("", &body);
    let mut visited = HashSet::new();
    visited.insert(origin_url.to_string());
    let found_urls = get_links_from_html(&body);
    let mut new_urls = found_urls
        .difference(&visited)
        .map(|x| x.to_string())
        .collect::<HashSet<String>>();

    while !new_urls.is_empty() {
        let found_urls: HashSet<String> = new_urls
            .iter()
            .map(|url| {
                let body = fetch_url(&client, url);
                // Strip the origin but keep the leading "/" of the path
                write_file(&url[origin_url.len() - 1..], &body);
                let links = get_links_from_html(&body);
                println!("Visited: {} found {} links", url, links.len());
                links
            })
            .fold(HashSet::new(), |mut acc, x| {
                acc.extend(x);
                acc
            });
        visited.extend(new_urls);
        new_urls = found_urls
            .difference(&visited)
            .map(|x| x.to_string())
            .collect::<HashSet<String>>();
        println!("New urls: {}", new_urls.len())
    }
    println!("URLs: {:#?}", visited);
    println!("{}", now.elapsed().as_secs());
}

We use the create_dir_all function, which works like mkdir -p in Linux to create the nested folder structure. We write the HTML page to the index.html file in the same folder structure as the URL structure. Most web servers will then serve the index.html file when going to the URL, so the output in the browser will be the same as the one from Ghost serving dynamic pages.
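The slice &url[origin_url.len() - 1..] strips the origin but keeps the leading /. As a result, https://ghost.rolisz.ro/2020/02/13/lost-in-space/ is written to static/2020/02/13/lost-in-space//index.html, and the doubled slash is normally tolerated by the operating system. A hedged alternative builds the target path with PathBuf. In this sketch, output_path is a hypothetical helper. It also skips . and .. segments, so such a segment cannot climb out of the static folder:

```rust
use std::path::PathBuf;

/// Hypothetical helper: maps a crawled URL to static/<path>/index.html,
/// mirroring the URL structure. Returns None for URLs outside `origin`.
fn output_path(origin: &str, url: &str) -> Option<PathBuf> {
    let rest = url.strip_prefix(origin.trim_end_matches('/'))?;
    // Reject look-alike hosts such as https://ghost.rolisz.rox/...
    if !rest.is_empty() && !rest.starts_with('/') {
        return None;
    }
    let mut path = PathBuf::from("static");
    for segment in rest.split('/') {
        // Skip empty segments (from "/" or "//") as well as "." and ".."
        if segment.is_empty() || segment == "." || segment == ".." {
            continue;
        }
        path.push(segment);
    }
    path.push("index.html");
    Some(path)
}

fn main() {
    let origin = "https://ghost.rolisz.ro/";
    let url = "https://ghost.rolisz.ro/2020/02/13/lost-in-space/";
    println!("{:?}", output_path(origin, url));
}
```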

Speeding it up

Letting this run on my blog takes about 110 seconds. Let's see if we can speed it up by downloading the pages in parallel.

use rayon::prelude::*; // brings par_iter() and into_par_iter() into scope

fn main() {
    let now = Instant::now();

    let client = reqwest::blocking::Client::new();
    let origin_url = "https://ghost.rolisz.ro/";

    let body = fetch_url(&client, origin_url);

    write_file("", &body);
    let mut visited = HashSet::new();
    visited.insert(origin_url.to_string());
    let found_urls = get_links_from_html(&body);
    let mut new_urls = found_urls
        .difference(&visited)
        .map(|x| x.to_string())
        .collect::<HashSet<String>>();

    while !new_urls.is_empty() {
        let found_urls: HashSet<String> = new_urls
            .par_iter()
            .map(|url| {
                let body = fetch_url(&client, url);
                write_file(&url[origin_url.len() - 1..], &body);

                let links = get_links_from_html(&body);
                println!("Visited: {} found {} links", url, links.len());
                links
            })
            .reduce(HashSet::new, |mut acc, x| {
                acc.extend(x);
                acc
            });
        visited.extend(new_urls);
        new_urls = found_urls
            .difference(&visited)
            .map(|x| x.to_string())
            .collect::<HashSet<String>>();
        println!("New urls: {}", new_urls.len())
    }
    println!("URLs: {:#?}", visited);
    println!("{}", now.elapsed().as_secs());
}

In Rust there is this awesome library called Rayon, which provides a very simple primitive for running functions in parallel: par_iter, short for parallel iterator. It's an almost drop-in replacement for iter, the standard library's iterator over collections. It runs the provided closure in parallel, taking care of boring stuff like thread scheduling. Besides changing iter to par_iter, we have to change fold to reduce. Instead of a single starting value, reduce takes a closure that returns the "zero" element (here an empty HashSet). Rayon may need several of them: each piece of work starts from its own empty set, and the partial sets are merged at the end.

This reduces the running time to 70 seconds, down from 110 seconds.
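To see why reduce needs a closure that produces the "zero" element, here is a std-only model of the idea. It uses std::thread::scope (Rust 1.63+) instead of Rayon, which is a swapped-in technique; Rayon schedules its work differently. The frontier is split into chunks, each thread folds its chunk into its own empty HashSet, and the partial sets are then merged. fake_links and crawl_level are hypothetical:

```rust
use std::collections::HashSet;
use std::thread;

// Hypothetical stand-in for fetching a page and extracting its links:
// every page "links" to "/" and to a child page below itself.
fn fake_links(url: &str) -> HashSet<String> {
    vec![format!("{}child/", url), "/".to_string()]
        .into_iter()
        .collect()
}

/// Visits one BFS level in parallel. Each thread folds its chunk into its own
/// empty HashSet (the "zero" element), then the partial sets are merged.
fn crawl_level(urls: &[String], n_threads: usize) -> HashSet<String> {
    let n = n_threads.max(1);
    // Ceiling division, at least 1 because chunks(0) would panic
    let chunk_size = ((urls.len() + n - 1) / n).max(1);

    thread::scope(|s| {
        let handles: Vec<_> = urls
            .chunks(chunk_size)
            .map(|chunk| {
                s.spawn(move || {
                    chunk.iter().fold(HashSet::new(), |mut acc, url| {
                        acc.extend(fake_links(url));
                        acc
                    })
                })
            })
            .collect();

        // Merge the per-thread sets: the role reduce plays in the Rayon version
        handles
            .into_iter()
            .map(|h| h.join().unwrap())
            .fold(HashSet::new(), |mut acc, part| {
                acc.extend(part);
                acc
            })
    })
}

fn main() {
    let frontier = vec!["/a/".to_string(), "/b/".to_string(), "/c/".to_string()];
    let found = crawl_level(&frontier, 2);
    println!("{} links found", found.len());
}
```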

Proper error handling

One more thing to fix in our program: error handling. Rust helps us a lot here with its built-in Option and Result types, but so far we've been ignoring them, liberally sprinkling unwrap everywhere. unwrap returns the inner value, or panics if there is an error (for Result) or a None value (for Option). To handle these correctly, we should create our own error type.

One use of unwrap that we can get rid of easily is in the normalize_url function, where the if condition is new_url.has_host() && new_url.host_str().unwrap() == "ghost.rolisz.ro". This can't panic, because we first check that a host exists, but there is a nicer Rust way to express it:

if let Some("ghost.rolisz.ro") = new_url.host_str() {
    Some(url.to_string())
} else {
    None
}

To my Rust newbie eyes, it looks really weird at first glance, but it does make sense eventually: the pattern only matches when host_str() returns Some with exactly that host.
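A standalone sketch of the same pattern shows how the string literal inside Some(...) acts as a pattern. It uses a hypothetical keep_if_blog helper that receives the value host_str() would return:

```rust
/// Keeps `url` only when `host` is exactly the blog's host.
/// `host` plays the role of Url::host_str() in normalize_url.
fn keep_if_blog(url: &str, host: Option<&str>) -> Option<String> {
    // Matches only Some("ghost.rolisz.ro"); other hosts and None fall through
    if let Some("ghost.rolisz.ro") = host {
        Some(url.to_string())
    } else {
        None
    }
}

fn main() {
    println!("{:?}", keep_if_blog("https://ghost.rolisz.ro/uses/", Some("ghost.rolisz.ro")));
    println!("{:?}", keep_if_blog("https://twitter.com/rolisz", Some("twitter.com")));
    println!("{:?}", keep_if_blog("javascript:;", None));
}
```

The same check can also be written as host == Some("ghost.rolisz.ro"), since Option<&str> values can be compared with ==.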

For the other cases we need to define our own Error type, which will be a wrapper around the other types, providing a uniform interface to all of them:

use std::io::Error as IoErr;

#[derive(Debug)]
enum Error {
    Write { url: String, e: IoErr },
    Fetch { url: String, e: reqwest::Error },
}

type Result<T> = std::result::Result<T, Error>;

impl<S: AsRef<str>> From<(S, IoErr)> for Error {
    fn from((url, e): (S, IoErr)) -> Self {
        Error::Write {
            url: url.as_ref().to_string(),
            e,
        }
    }
}

impl<S: AsRef<str>> From<(S, reqwest::Error)> for Error {
    fn from((url, e): (S, reqwest::Error)) -> Self {
        Error::Fetch {
            url: url.as_ref().to_string(),
            e,
        }
    }
}

We have two kinds of errors in our crawler: IoErr (std::io::Error) and reqwest::Error. The first is returned when reading a response body or writing a file, the second when we try to fetch a URL. Besides the original error, we add some context, such as the URL or path that was accessed when we got the error. We provide From implementations to convert each library error, paired with that context, into our own error type. We also define a Result helper type so that we don't always have to type out our Error type.

fn fetch_url(client: &reqwest::blocking::Client, url: &str) -> Result<String> {
    let mut res = client.get(url).send().map_err(|e| (url, e))?;
    println!("Status for {}: {}", url, res.status());

    let mut body = String::new();
    res.read_to_string(&mut body).map_err(|e| (url, e))?;
    Ok(body)
}

fn write_file(path: &str, content: &str) -> Result<()> {
    let dir = format!("static{}", path);
    fs::create_dir_all(&dir).map_err(|e| (&dir, e))?;
    let index = format!("static{}/index.html", path);
    fs::write(&index, content).map_err(|e| (&index, e))?;

    Ok(())
}

Our two functions that can produce errors now return our Result type. Every operation that can fail has a map_err applied to its result, which pairs the library error with the URL or path. The ? operator then converts that tuple into our Error through the From implementations above and returns it early.
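The conversion chain can be tried without reqwest: map_err builds a (context, error) tuple, and ? calls the matching From implementation. In this sketch std::num::ParseIntError stands in for reqwest::Error, and page_number and write_page are hypothetical helpers:

```rust
use std::fs;
use std::io::Error as IoErr;
use std::num::ParseIntError;

// Same shape as the post's Error enum, except that ParseIntError stands in
// for reqwest::Error so this sketch needs only the standard library.
#[derive(Debug)]
enum Error {
    Write { url: String, e: IoErr },
    Parse { url: String, e: ParseIntError },
}

type Result<T> = std::result::Result<T, Error>;

impl<S: AsRef<str>> From<(S, IoErr)> for Error {
    fn from((url, e): (S, IoErr)) -> Self {
        Error::Write { url: url.as_ref().to_string(), e }
    }
}

impl<S: AsRef<str>> From<(S, ParseIntError)> for Error {
    fn from((url, e): (S, ParseIntError)) -> Self {
        Error::Parse { url: url.as_ref().to_string(), e }
    }
}

/// Hypothetical helper: parses a page number taken from `url`.
fn page_number(url: &str, segment: &str) -> Result<u32> {
    // map_err pairs the error with context; `?` then uses the From impl
    // for (&str, ParseIntError) to turn the tuple into Error::Parse
    let n = segment.parse::<u32>().map_err(|e| (url, e))?;
    Ok(n)
}

/// Hypothetical helper: writes `content` to `path`, like write_file does.
fn write_page(path: &str, content: &str) -> Result<()> {
    fs::write(path, content).map_err(|e| (path, e))?;
    Ok(())
}

fn main() {
    println!("{:?}", page_number("https://ghost.rolisz.ro/page/2/", "2"));
    println!("{:?}", page_number("https://ghost.rolisz.ro/page/two/", "two"));
    println!("{:?}", write_page("/no-such-directory-xyz/index.html", "<html></html>"));
}
```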

let (found_urls, errors): (Vec<Result<HashSet<String>>>, Vec<_>) = new_urls
    .par_iter()
    .map(|url| -> Result<HashSet<String>> {
        let body = fetch_url(&client, url)?;
        write_file(&url[origin_url.len() - 1..], &body)?;

        let links = get_links_from_html(&body);
        println!("Visited: {} found {} links", url, links.len());
        Ok(links)
    })
    .partition(Result::is_ok);

Our main loop to download new URLs changes a bit. The closure now returns either a set of URLs or an error. To separate the two kinds of results, we partition the iterator with Result::is_ok. This gives us two vectors, one with the HashSets and one with the Errors, both still wrapped in Results.

visited.extend(new_urls);
new_urls = found_urls
    .into_par_iter()
    .map(Result::unwrap)
    .reduce(HashSet::new, |mut acc, x| {
        acc.extend(x);
        acc
    })
    .difference(&visited)
    .map(|x| x.to_string())
    .collect::<HashSet<String>>();
println!("New urls: {}", new_urls.len());

We handle each vector separately. For the successful results we unwrap each one and then merge all the HashSets into one.

println!(
    "Errors: {:#?}",
    errors
        .into_iter()
        .map(Result::unwrap_err)
        .collect::<Vec<Error>>()
);

For the Vec containing the errors, we unwrap each Err and then we just print them out.
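The partition-then-unwrap flow can be tried with the standard library's Iterator::partition, which works like Rayon's version on a sequential iterator. In this minimal sketch, visit is a hypothetical stand-in for the crawl closure: URLs containing "broken" fail, and every other URL "finds" one link:

```rust
use std::collections::HashSet;

// Hypothetical stand-in for the closure in the crawl loop
fn visit(url: &str) -> Result<HashSet<String>, String> {
    if url.contains("broken") {
        Err(format!("could not fetch {}", url))
    } else {
        let mut links = HashSet::new();
        links.insert(format!("{}next/", url));
        Ok(links)
    }
}

/// Partitions the results, then unwraps each side, as the crawl loop does.
fn split_results(urls: &[&str]) -> (HashSet<String>, Vec<String>) {
    let (oks, errs): (Vec<_>, Vec<_>) = urls
        .iter()
        .map(|url| visit(url))
        .partition(Result::is_ok);

    // After partition, `oks` holds only Ok values and `errs` only Err values,
    // so neither unwrap can panic
    let links = oks
        .into_iter()
        .map(Result::unwrap)
        .fold(HashSet::new(), |mut acc, x| {
            acc.extend(x);
            acc
        });
    let errors = errs.into_iter().map(Result::unwrap_err).collect();
    (links, errors)
}

fn main() {
    let (links, errors) = split_results(&["/a/", "/broken/", "/b/"]);
    println!("links: {:?}", links);
    println!("errors: {:?}", errors);
}
```

A design alternative that avoids the unwraps entirely is a plain loop that matches on Ok and Err and pushes each value into the right collection.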

And with that we have a small and simple web crawler, which runs fairly fast and which handles most (all?) errors correctly. The final version of the code can be found here.

Special thanks to Cedric Hutchings and lights0123 who reviewed my code on Code Review.