

RSS experiments

Mix.install([
  {:explorer, "~> 0.10.1"},
  {:fast_rss, "~> 0.5.0"},
  {:req, "~> 0.5.10"},
  {:bumblebee, "~> 0.6.0"},
  {:exla, ">= 0.0.0"},
  {:nx, "~> 0.9.0"},
  {:kino, "~> 0.16.0"}
])

Section

alias Req
Nx.global_default_backend(EXLA.Backend)
defmodule Montage.ImageEmbedding do
  # Builds an Nx.Serving that computes L2-normalized CLIP image embeddings.
  @model_repo "openai/clip-vit-base-patch32"

  def serving() do
    {:ok, clip} =
      Bumblebee.load_model({:hf, @model_repo},
        module: Bumblebee.Vision.ClipVision,
        architecture: :for_embedding
      )

    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, @model_repo})

    Bumblebee.Vision.image_embedding(clip, featurizer,
      defn_options: [compiler: EXLA],
      embedding_processor: :l2_norm,
      output_attribute: :embedding
    )
  end

  def predict(image) do
    # Runs a single image through a serving registered under __MODULE__
    # (it must be started first, e.g. via Kino.start_child as sketched below).
    Nx.Serving.batched_run(__MODULE__, image)
  end
end
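The CLIP embedding serving above is defined but never started in this notebook. A minimal sketch of how it could be wired up, where image is a placeholder for an image tensor or StbImage struct loaded elsewhere:

Kino.start_child(
  {Nx.Serving, serving: Montage.ImageEmbedding.serving(), name: Montage.ImageEmbedding}
)

# `image` is assumed to be an image tensor or StbImage struct loaded elsewhere:
# %{embedding: embedding} = Montage.ImageEmbedding.predict(image)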

defmodule BartModel do
  # Zero-shot classifier used to flag descriptions that mention the death of a famous person.
  @model_repo "facebook/bart-large-mnli"
  @label_positive "mentions death of a famous person"
  @label_negative "does not mention death of person by proper name"

  def serving() do
    {:ok, model} = Bumblebee.load_model({:hf, @model_repo})
    {:ok, tokenizer_info} = Bumblebee.load_tokenizer({:hf, @model_repo})
    labels = [@label_positive, @label_negative]
    Bumblebee.Text.zero_shot_classification(model, tokenizer_info, labels)
  end

  @spec predict(String.t()) :: map()
  def predict(string) do
    # Builds a fresh serving (and reloads the model) on every call; for repeated use,
    # prefer the named serving started below and Nx.Serving.batched_run.
    # Nx.Serving.batched_run(__MODULE__, string)
    Nx.Serving.run(serving(), string)
  end

  def positive_label() do
    @label_positive
  end

  def get_serving() do
    serving()
  end
end
BartModel.predict("A cat perished searching for some food at home")
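For reference, the zero-shot serving returns a map with a :predictions list of label/score maps, which is the shape DeathDetector destructures below (a sketch of the structure only, no captured scores):

# %{
#   predictions: [
#     %{label: "mentions death of a famous person", score: score_pos},
#     %{label: "does not mention death of person by proper name", score: score_neg}
#   ]
# }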
Kino.start_child({Nx.Serving, serving: BartModel.get_serving(), name: BartModelServing})
defmodule DeathDetector do
  # A module to detect the death of a famous person from a list of RSS feeds.

  # Defining a type; for now an RSS feed is just its URL string, later it can be expanded.
  @type rss_feed() :: String.t()

  @spec recent_deaths(list(rss_feed())) :: list(String.t())
  def recent_deaths(_rss_feeds) do
    # Takes a list of RSS feed URLs and returns strings describing recent deaths.
    # Still a stub, not implemented yet.
    []
  end

  @spec get_rss_body(String.t()) :: {:ok, map()} | {:error, :unreachable}
  def get_rss_body(url) do
    # Fetches the feed and returns the parsed RSS document as a map.
    resp = Req.get!(url)

    case resp do
      %{status: 200, body: body} ->
        {:ok, _map_of_rss} = FastRSS.parse_rss(body)

      %{status: _} ->
        IO.puts("RSS feed unreachable")
        {:error, :unreachable}
    end
  end

  @spec process_parsed_rss(map()) :: list({boolean(), String.t()})
  def process_parsed_rss(rss_map) do
    # FastRSS returns string keys, so the items live under "items", not :items.
    Map.get(rss_map, "items")
    |> Enum.map(fn x ->
      {String.contains?(x["description"], "Ukraine"), x["description"]}
    end)
    |> Enum.filter(fn {dead, _descript} -> dead end)
  end

  @spec detect_death(String.t(), float()) :: boolean()
  def detect_death(description, threshold \\ 0.8) do
    # Classifies a single description using the small zero-shot model.
    output = BartModel.predict(description)

    %{predictions: preds} = output

    score =
      preds
      |> Enum.find(fn x -> Map.get(x, :label) == BartModel.positive_label() end)
      |> Map.get(:score)

    score >= threshold
  end

  def process_prediction(model_output, threshold \\ 0.89) do
    # Extracts the score for the positive label from a zero-shot prediction and thresholds it.
    %{predictions: preds} = model_output

    score =
      preds
      |> Enum.find(fn x -> Map.get(x, :label) == BartModel.positive_label() end)
      |> Map.get(:score)

    score >= threshold
  end

  @spec detect_deaths_batch(list(String.t())) :: list(boolean())
  def detect_deaths_batch(descriptions) do
    # Runs the whole list through the shared BartModelServing in one batched call.
    all_preds = Nx.Serving.batched_run(BartModelServing, descriptions)

    Enum.map(all_preds, fn x ->
      process_prediction(x)
    end)
  end
end
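One way the pieces above could compose into recent_deaths/1 (a sketch only; the function is still a stub, and scan_feed is a hypothetical helper using the same feed URL that appears later in the notebook):

# Hypothetical end-to-end flow: fetch a feed, classify every item description in one
# batched call, and keep the titles of the items flagged as mentioning a death.
scan_feed = fn url ->
  {:ok, rss_map} = DeathDetector.get_rss_body(url)
  items = rss_map["items"]
  flags = DeathDetector.detect_deaths_batch(Enum.map(items, & &1["description"]))

  Enum.zip(flags, items)
  |> Enum.filter(fn {dead, _item} -> dead end)
  |> Enum.map(fn {_dead, item} -> item["title"] end)
end

# scan_feed.("https://rss.nytimes.com/services/xml/rss/nyt/World.xml")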

# url = "https://rss.nytimes.com/services/xml/rss/nyt/World.xml"
# resp = Req.get!(url)
# body = Map.get(resp, :body)
DeathDetector.detect_death("Donald Trump launches a new war against the cats")
samples = [
  "Donald Trump launches a new war against the cats",
  "Nelson Mandela dies again"
]

DeathDetector.detect_deaths_batch(samples)
url = "https://rss.nytimes.com/services/xml/rss/nyt/World.xml"
resp = Req.get!(url)
body = Map.get(resp, :body)
alias FastRSS
{:ok, map_of_rss} = FastRSS.parse_rss(body)
Map.keys(map_of_rss)
all_descriptions =
  Enum.map(
    map_of_rss["items"],
    fn x ->
      x["description"]
    end
  )

processed = DeathDetector.detect_deaths_batch(all_descriptions)
Enum.zip(processed, map_of_rss["items"])
|> Enum.filter(fn {dead, _m} -> dead end)
|> Enum.map(fn {_dead, rss_item} -> rss_item["title"] end)
Map.keys(Enum.at(map_of_rss["items"], 0))

Enum.at(map_of_rss["items"], 0)["description"]

r =
  Enum.map(
    map_of_rss["items"],
    fn x ->
      {DeathDetector.detect_death(x["description"]), x["description"]}
    end
  )
  |> Enum.filter(fn {dead, _descript} -> dead end)
Enum.at(r, 0)
Enum.at(map_of_rss["items"], 0)["link"]

r =
  Enum.map(
    map_of_rss["items"],
    fn x ->
      {String.contains?(x["description"], "Trump"), x["description"]}
    end
  )
  |> Enum.filter(fn {mentions, _descript} -> mentions end)